From c10d449e433d32551250e3862b773cb79ca5c21d Mon Sep 17 00:00:00 2001 From: Christopher Coverdale <18324680+ccdle12@users.noreply.github.com> Date: Mon, 27 Oct 2025 12:16:14 +0000 Subject: [PATCH] Add local TcpSocketPair (#2526) * Add extern fn socketpair() to posix * Add extern fn getsockname() for local socketpair loopback in windows * Add local TcpSocketPair * Add unit test for TcpSocketPair * Add implicit wsa startup --------- Co-authored-by: Christoffer Lerno --- lib/std/net/os/posix.c3 | 1 + lib/std/net/os/win32.c3 | 33 +++++++++++++- lib/std/net/socket_private.c3 | 16 +++++++ lib/std/net/tcp.c3 | 62 ++++++++++++++++++++++++++ releasenotes.md | 3 ++ test/unit/stdlib/net/tcp_socketpair.c3 | 19 ++++++++ 6 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 test/unit/stdlib/net/tcp_socketpair.c3 diff --git a/lib/std/net/os/posix.c3 b/lib/std/net/os/posix.c3 index 8cb86daf4..10085ab5f 100644 --- a/lib/std/net/os/posix.c3 +++ b/lib/std/net/os/posix.c3 @@ -22,6 +22,7 @@ extern fn CInt bind(NativeSocket socket, SockAddrPtr address, Socklen_t address_ extern fn CInt listen(NativeSocket socket, CInt backlog); extern fn NativeSocket accept(NativeSocket socket, SockAddrPtr address, Socklen_t* address_len); extern fn CInt poll(Posix_pollfd* fds, Posix_nfds_t nfds, CInt timeout); +extern fn CInt socketpair(AIFamily domain, AISockType type, CInt protocol, NativeSocket[2]* sv); const CUShort POLLIN = 0x0001; const CUShort POLLPRI = 0x0002; diff --git a/lib/std/net/os/win32.c3 b/lib/std/net/os/win32.c3 index c44942eaf..600511f61 100644 --- a/lib/std/net/os/win32.c3 +++ b/lib/std/net/os/win32.c3 @@ -1,5 +1,9 @@ module std::net::os @if(env::WIN32); -import std::os, std::io, libc; +import std::os, std::io, libc, std::thread; +import std::core::mem; +import std::os::win32; + + const AIFamily PLATFORM_AF_IPX = 6; const AIFamily PLATFORM_AF_APPLETALK = 16; @@ -21,6 +25,33 @@ extern fn int connect(NativeSocket, SockAddrPtr address, Socklen_t address_len); extern fn int bind(NativeSocket, SockAddrPtr address, Socklen_t address_len); extern fn int listen(NativeSocket, int backlog); extern fn NativeSocket accept(NativeSocket, SockAddrPtr address, Socklen_t* address_len); +extern fn CInt getsockname(NativeSocket socket, SockAddrPtr address, Socklen_t* address_len); + +char[408] wsa_data @local; +int wsa_init @local; + +macro void? start_wsa() +{ + if (mem::compare_exchange(&wsa_init, 0, 1) == 0) + { + Win32_WORD version = 0x0202; + CInt wsa_error = win32::wsaStartup(version, &wsa_data); + if (wsa_error > 0) + { + mem::@atomic_store(wsa_init, 0); + return os::socket_error()?; + } + } +} + +fn void close_wsa() @local @finalizer +{ + if (mem::compare_exchange(&wsa_init, 1, 0) == 1) + { + win32::wsaCleanup(); + mem::@atomic_store(wsa_init, 0); + } +} macro bool NativeSocket.is_valid(self) { diff --git a/lib/std/net/socket_private.c3 b/lib/std/net/socket_private.c3 index c663a349a..8da26130d 100644 --- a/lib/std/net/socket_private.c3 +++ b/lib/std/net/socket_private.c3 @@ -1,5 +1,9 @@ module std::net @if(os::SUPPORTS_INET); import std::time, libc, std::os; +import std::core::env; +import std::net::os; + + macro void? apply_sockoptions(sockfd, options) @private { @@ -9,6 +13,9 @@ macro void? apply_sockoptions(sockfd, options) @private fn Socket? connect_from_addrinfo(AddrInfo* addrinfo, SocketOption[] options) @private { + $if env::WIN32: + os::start_wsa()!; + $endif @loop_over_ai(addrinfo; NativeSocket sockfd, AddrInfo* ai) { apply_sockoptions(sockfd, options)!; @@ -37,6 +44,9 @@ fn bool last_error_is_delayed_connect() fn Socket? connect_with_timeout_from_addrinfo(AddrInfo* addrinfo, SocketOption[] options, Duration timeout) @private { + $if env::WIN32: + os::start_wsa()!; + $endif Clock c = 0; @loop_over_ai(addrinfo; NativeSocket sockfd, AddrInfo* ai) { @@ -82,6 +92,9 @@ fn Socket? connect_with_timeout_from_addrinfo(AddrInfo* addrinfo, SocketOption[] fn Socket? connect_async_from_addrinfo(AddrInfo* addrinfo, SocketOption[] options) @private { + $if env::WIN32: + os::start_wsa()!; + $endif @loop_over_ai(addrinfo; NativeSocket sockfd, AddrInfo* ai) { apply_sockoptions(sockfd, options)!; @@ -98,6 +111,9 @@ fn Socket? connect_async_from_addrinfo(AddrInfo* addrinfo, SocketOption[] option macro void @network_loop_over_ai(network, host, port; @body(fd, ai)) @private { + $if env::WIN32: + os::start_wsa()!; + $endif AddrInfo* ai = network.addrinfo(host, port)!; AddrInfo* first = ai; defer os::freeaddrinfo(first); diff --git a/lib/std/net/tcp.c3 b/lib/std/net/tcp.c3 index 33ce6c60c..644fcb594 100644 --- a/lib/std/net/tcp.c3 +++ b/lib/std/net/tcp.c3 @@ -1,6 +1,11 @@ module std::net::tcp @if(os::SUPPORTS_INET); import std::net @public; import std::time, libc; +import std::os::win32; +import std::core::env; +import std::net::os; + + typedef TcpSocket = inline Socket; typedef TcpServerSocket = inline Socket; @@ -44,6 +49,9 @@ fn TcpSocket? accept(TcpServerSocket* server_socket) { TcpSocket socket; socket.ai_addrlen = socket.ai_addr_storage.len; + $if env::WIN32: + os::start_wsa()!; + $endif socket.sock = os::accept(server_socket.sock, (SockAddrPtr)&socket.ai_addr_storage, &socket.ai_addrlen); if (!socket.sock.is_valid()) return net::ACCEPT_FAILED?; return socket; @@ -51,6 +59,9 @@ fn TcpSocket? accept(TcpServerSocket* server_socket) fn TcpServerSocket? listen_to(AddrInfo* ai, uint backlog, SocketOption... options) { + $if env::WIN32: + os::start_wsa()!; + $endif net::@loop_over_ai(ai; NativeSocket sockfd, AddrInfo* ai_candidate) { net::apply_sockoptions(sockfd, options)!; @@ -60,4 +71,55 @@ fn TcpServerSocket? listen_to(AddrInfo* ai, uint backlog, SocketOption... option return os::socket_error()?; } +struct TcpSocketPair +{ + TcpSocket send; + TcpSocket recv; +} +fn TcpSocketPair*? TcpSocketPair.init(&self) +{ + $if env::WIN32: + os::start_wsa()!; + + TcpServerSocket listen_sock = tcp::listen("127.0.0.1", 0, 0)!; + + TcpSocket listen_sock_info; + listen_sock_info.ai_addrlen = listen_sock.ai_addr_storage.len; + + int sock_result = os::getsockname(listen_sock.sock, (SockAddrPtr) &listen_sock_info.ai_addr_storage, &listen_sock_info.ai_addrlen); + if (sock_result < 0) return os::socket_error()?; + + char[] listen_port_bytes = listen_sock_info.ai_addr_storage[2:2]; + char msb = listen_port_bytes[0]; + char lsb = listen_port_bytes[1]; + int listen_port = (msb << 8) | lsb; + + defer (void)listen_sock.close(); + TcpSocket tcp_send_sock = tcp::connect_async("127.0.0.1", listen_port)!; + TcpSocket tcp_recv_sock = tcp::accept(&listen_sock)!; + + $else + NativeSocket[2] sockets; + isz sockpair_result = os::socketpair(os::AF_UNIX, os::SOCK_STREAM, 0, &sockets); + if (sockpair_result < 0) return os::socket_error()?; + + Socket send_sock = { .sock = sockets[0] }; + TcpSocket tcp_send_sock = (TcpSocket) send_sock; + + Socket recv_sock = { .sock = sockets[1] }; + TcpSocket tcp_recv_sock = (TcpSocket) recv_sock; + $endif + + *self = { .send = tcp_send_sock, .recv = tcp_recv_sock }; + return self; +} + +fn void? TcpSocketPair.destroy(&self) +{ + { + defer catch (void)self.recv.close(); + self.send.close()!; + } + self.recv.close()!; +} diff --git a/releasenotes.md b/releasenotes.md index 375165cd6..f6523f62f 100644 --- a/releasenotes.md +++ b/releasenotes.md @@ -36,6 +36,9 @@ ### Stdlib changes - Sorting functions correctly took slices by value, but also other types by value. Now, only slices are accepted by value, other containers are always by ref. - Added `@str_snakecase`, `@str_replace` and `@str_pascalcase` builtin compile time macros based on the `$$` builtins. +- Add TcpSocketPair to create a bidirectional local socket pair. +- Add `extern fn CInt socketpair(AIFamily domain, AISockType type, CInt protocol, NativeSocket[2]* sv)` binding to posix. +- Add `extern fn getsockname(NativeSocket socket, SockAddrPtr address, Socklen_t* address_len)` binding to win32. ## 0.7.6 Change list diff --git a/test/unit/stdlib/net/tcp_socketpair.c3 b/test/unit/stdlib/net/tcp_socketpair.c3 new file mode 100644 index 000000000..ca73714be --- /dev/null +++ b/test/unit/stdlib/net/tcp_socketpair.c3 @@ -0,0 +1,19 @@ +module tcpsockpairtest @test; +import std::net; + +fn void test_tcp_sock_pair() +{ + tcp::TcpSocketPair sockets; + sockets.init()!!; + defer sockets.destroy()!!; + + String expected = "hello, world!"; + sockets.send.write(expected)!!; + + char[100] recv_buf; + sockets.recv.read(&recv_buf)!!; + + String result = (String) recv_buf[0:expected.len]; + + assert(result.trim() == expected); +}