From 7441d4f3f33b4cc21912ff9a3b495f62abc3362e Mon Sep 17 00:00:00 2001
From: Alex Crichton <alex@alexcrichton.com>
Date: Tue, 15 Jul 2014 10:28:00 -0700
Subject: [PATCH] native: TCP close/close_accept for windows

This commit implements TcpAcceptor::{close, close_accept} for windows via
WSAEVENT types.
---
 src/libnative/io/c_windows.rs    |  33 ++++
 src/libnative/io/net.rs          | 255 ++++++++++++++++++++++---------
 src/libnative/io/pipe_windows.rs |  14 +-
 src/libnative/io/util.rs         |   4 +-
 4 files changed, 228 insertions(+), 78 deletions(-)

diff --git a/src/libnative/io/c_windows.rs b/src/libnative/io/c_windows.rs
index 80c9e91b48f..3bd850b5aac 100644
--- a/src/libnative/io/c_windows.rs
+++ b/src/libnative/io/c_windows.rs
@@ -26,6 +26,14 @@ pub static ENABLE_INSERT_MODE: libc::DWORD = 0x20;
 pub static ENABLE_LINE_INPUT: libc::DWORD = 0x2;
 pub static ENABLE_PROCESSED_INPUT: libc::DWORD = 0x1;
 pub static ENABLE_QUICK_EDIT_MODE: libc::DWORD = 0x40;
+pub static WSA_INVALID_EVENT: WSAEVENT = 0 as WSAEVENT;
+
+pub static FD_ACCEPT: libc::c_long = 0x08;
+pub static FD_MAX_EVENTS: uint = 10;
+pub static WSA_INFINITE: libc::DWORD = libc::INFINITE;
+pub static WSA_WAIT_TIMEOUT: libc::DWORD = libc::consts::os::extra::WAIT_TIMEOUT;
+pub static WSA_WAIT_EVENT_0: libc::DWORD = libc::consts::os::extra::WAIT_OBJECT_0;
+pub static WSA_WAIT_FAILED: libc::DWORD = libc::consts::os::extra::WAIT_FAILED;
 
 #[repr(C)]
 #[cfg(target_arch = "x86")]
@@ -52,6 +60,16 @@ pub struct WSADATA {
 
 pub type LPWSADATA = *mut WSADATA;
 
+#[repr(C)]
+pub struct WSANETWORKEVENTS {
+    pub lNetworkEvents: libc::c_long,
+    pub iErrorCode: [libc::c_int, ..FD_MAX_EVENTS],
+}
+
+pub type LPWSANETWORKEVENTS = *mut WSANETWORKEVENTS;
+
+pub type WSAEVENT = libc::HANDLE;
+
 #[repr(C)]
 pub struct fd_set {
     fd_count: libc::c_uint,
@@ -68,6 +86,21 @@ extern "system" {
     pub fn WSAStartup(wVersionRequested: libc::WORD,
                       lpWSAData: LPWSADATA) -> libc::c_int;
     pub fn WSAGetLastError() -> libc::c_int;
+    pub fn WSACloseEvent(hEvent: WSAEVENT) -> libc::BOOL;
+    pub fn WSACreateEvent() -> WSAEVENT;
+    pub fn WSAEventSelect(s: libc::SOCKET,
+                          hEventObject: WSAEVENT,
+                          lNetworkEvents: libc::c_long) -> libc::c_int;
+    pub fn WSASetEvent(hEvent: WSAEVENT) -> libc::BOOL;
+    pub fn WSAWaitForMultipleEvents(cEvents: libc::DWORD,
+                                    lphEvents: *const WSAEVENT,
+                                    fWaitAll: libc::BOOL,
+                                    dwTimeout: libc::DWORD,
+                                    fAltertable: libc::BOOL) -> libc::DWORD;
+    pub fn WSAEnumNetworkEvents(s: libc::SOCKET,
+                                hEventObject: WSAEVENT,
+                                lpNetworkEvents: LPWSANETWORKEVENTS)
+                                -> libc::c_int;
 
     pub fn ioctlsocket(s: libc::SOCKET, cmd: libc::c_long,
                        argp: *mut libc::c_ulong) -> libc::c_int;
diff --git a/src/libnative/io/net.rs b/src/libnative/io/net.rs
index 7a8a363a0a3..daa1b25e407 100644
--- a/src/libnative/io/net.rs
+++ b/src/libnative/io/net.rs
@@ -11,6 +11,7 @@
 use alloc::arc::Arc;
 use libc;
 use std::mem;
+use std::ptr;
 use std::rt::mutex;
 use std::rt::rtio;
 use std::rt::rtio::{IoResult, IoError};
@@ -19,16 +20,16 @@ use std::sync::atomics;
 use super::{retry, keep_going};
 use super::c;
 use super::util;
-use super::file::FileDesc;
-use super::process;
+
+#[cfg(unix)] use super::process;
+#[cfg(unix)] use super::file::FileDesc;
+
+pub use self::os::{init, sock_t, last_error};
 
 ////////////////////////////////////////////////////////////////////////////////
 // sockaddr and misc bindings
 ////////////////////////////////////////////////////////////////////////////////
 
-#[cfg(windows)] pub type sock_t = libc::SOCKET;
-#[cfg(unix)]    pub type sock_t = super::file::fd_t;
-
 pub fn htons(u: u16) -> u16 {
     u.to_be()
 }
@@ -100,7 +101,7 @@ fn socket(addr: rtio::SocketAddr, ty: libc::c_int) -> IoResult<sock_t> {
             rtio::Ipv6Addr(..) => libc::AF_INET6,
         };
         match libc::socket(fam, ty, 0) {
-            -1 => Err(super::last_error()),
+            -1 => Err(os::last_error()),
             fd => Ok(fd),
         }
     }
@@ -114,7 +115,7 @@ fn setsockopt<T>(fd: sock_t, opt: libc::c_int, val: libc::c_int,
                                    payload,
                                    mem::size_of::<T>() as libc::socklen_t);
         if ret != 0 {
-            Err(last_error())
+            Err(os::last_error())
         } else {
             Ok(())
         }
@@ -130,7 +131,7 @@ pub fn getsockopt<T: Copy>(fd: sock_t, opt: libc::c_int,
                                 &mut slot as *mut _ as *mut _,
                                 &mut len);
         if ret != 0 {
-            Err(last_error())
+            Err(os::last_error())
         } else {
             assert!(len as uint == mem::size_of::<T>());
             Ok(slot)
@@ -138,25 +139,6 @@ pub fn getsockopt<T: Copy>(fd: sock_t, opt: libc::c_int,
     }
 }
 
-#[cfg(windows)]
-pub fn last_error() -> IoError {
-    use std::os;
-    let code = unsafe { c::WSAGetLastError() as uint };
-    IoError {
-        code: code,
-        extra: 0,
-        detail: Some(os::error_string(code)),
-    }
-}
-
-#[cfg(not(windows))]
-fn last_error() -> IoError {
-    super::last_error()
-}
-
-#[cfg(windows)] unsafe fn close(sock: sock_t) { let _ = libc::closesocket(sock); }
-#[cfg(unix)]    unsafe fn close(sock: sock_t) { let _ = libc::close(sock); }
-
 fn sockname(fd: sock_t,
             f: unsafe extern "system" fn(sock_t, *mut libc::sockaddr,
                                          *mut libc::socklen_t) -> libc::c_int)
@@ -170,7 +152,7 @@ fn sockname(fd: sock_t,
                     storage as *mut libc::sockaddr,
                     &mut len as *mut libc::socklen_t);
         if ret != 0 {
-            return Err(last_error())
+            return Err(os::last_error())
         }
     }
     return sockaddr_to_addr(&storage, len as uint);
@@ -224,28 +206,6 @@ pub fn sockaddr_to_addr(storage: &libc::sockaddr_storage,
     }
 }
 
-#[cfg(unix)]
-pub fn init() {}
-
-#[cfg(windows)]
-pub fn init() {
-
-    unsafe {
-        use std::rt::mutex::{StaticNativeMutex, NATIVE_MUTEX_INIT};
-        static mut INITIALIZED: bool = false;
-        static mut LOCK: StaticNativeMutex = NATIVE_MUTEX_INIT;
-
-        let _guard = LOCK.lock();
-        if !INITIALIZED {
-            let mut data: c::WSADATA = mem::zeroed();
-            let ret = c::WSAStartup(0x202,      // version 2.2
-                                    &mut data);
-            assert_eq!(ret, 0);
-            INITIALIZED = true;
-        }
-    }
-}
-
 ////////////////////////////////////////////////////////////////////////////////
 // TCP streams
 ////////////////////////////////////////////////////////////////////////////////
@@ -292,7 +252,7 @@ impl TcpStream {
             },
             None => {
                 match retry(|| unsafe { libc::connect(fd, addrp, len) }) {
-                    -1 => Err(last_error()),
+                    -1 => Err(os::last_error()),
                     _ => Ok(ret),
                 }
             }
@@ -438,7 +398,7 @@ impl rtio::RtioSocket for TcpStream {
 }
 
 impl Drop for Inner {
-    fn drop(&mut self) { unsafe { close(self.fd); } }
+    fn drop(&mut self) { unsafe { os::close(self.fd); } }
 }
 
 #[unsafe_destructor]
@@ -474,7 +434,7 @@ impl TcpListener {
         }
 
         match unsafe { libc::bind(fd, addrp, len) } {
-            -1 => Err(last_error()),
+            -1 => Err(os::last_error()),
             _ => Ok(ret),
         }
     }
@@ -482,9 +442,8 @@ impl TcpListener {
     pub fn fd(&self) -> sock_t { self.inner.fd }
 
     pub fn native_listen(self, backlog: int) -> IoResult<TcpAcceptor> {
-        try!(util::set_nonblocking(self.fd(), true));
         match unsafe { libc::listen(self.fd(), backlog as libc::c_int) } {
-            -1 => Err(last_error()),
+            -1 => Err(os::last_error()),
 
             #[cfg(unix)]
             _ => {
@@ -502,6 +461,26 @@ impl TcpListener {
                     deadline: 0,
                 })
             }
+
+            #[cfg(windows)]
+            _ => {
+                let accept = try!(os::Event::new());
+                let ret = unsafe {
+                    c::WSAEventSelect(self.fd(), accept.handle(), c::FD_ACCEPT)
+                };
+                if ret != 0 {
+                    return Err(os::last_error())
+                }
+                Ok(TcpAcceptor {
+                    inner: Arc::new(AcceptorInner {
+                        listener: self,
+                        abort: try!(os::Event::new()),
+                        accept: accept,
+                        closed: atomics::AtomicBool::new(false),
+                    }),
+                    deadline: 0,
+                })
+            }
         }
     }
 }
@@ -534,6 +513,14 @@ struct AcceptorInner {
     closed: atomics::AtomicBool,
 }
 
+#[cfg(windows)]
+struct AcceptorInner {
+    listener: TcpListener,
+    abort: os::Event,
+    accept: os::Event,
+    closed: atomics::AtomicBool,
+}
+
 impl TcpAcceptor {
     pub fn fd(&self) -> sock_t { self.inner.listener.fd() }
 
@@ -542,20 +529,12 @@ impl TcpAcceptor {
         let deadline = if self.deadline == 0 {None} else {Some(self.deadline)};
 
         while !self.inner.closed.load(atomics::SeqCst) {
-            unsafe {
-                let mut storage: libc::sockaddr_storage = mem::zeroed();
-                let storagep = &mut storage as *mut libc::sockaddr_storage;
-                let size = mem::size_of::<libc::sockaddr_storage>();
-                let mut size = size as libc::socklen_t;
-                match retry(|| {
-                    libc::accept(self.fd(),
-                                 storagep as *mut libc::sockaddr,
-                                 &mut size as *mut libc::socklen_t) as libc::c_int
-                }) as sock_t {
-                    -1 if util::wouldblock() => {}
-                    -1 => return Err(last_error()),
-                    fd => return Ok(TcpStream::new(Inner::new(fd))),
-                }
+            match retry(|| unsafe {
+                libc::accept(self.fd(), ptr::mut_null(), ptr::mut_null())
+            }) {
+                -1 if util::wouldblock() => {}
+                -1 => return Err(os::last_error()),
+                fd => return Ok(TcpStream::new(Inner::new(fd as sock_t))),
             }
             try!(util::await([self.fd(), self.inner.reader.fd()],
                              deadline, util::Readable));
@@ -563,6 +542,50 @@ impl TcpAcceptor {
 
         Err(util::eof())
     }
+
+    #[cfg(windows)]
+    pub fn native_accept(&mut self) -> IoResult<TcpStream> {
+        let events = [self.inner.abort.handle(), self.inner.accept.handle()];
+
+        while !self.inner.closed.load(atomics::SeqCst) {
+            let ms = if self.deadline == 0 {
+                c::WSA_INFINITE as u64
+            } else {
+                let now = ::io::timer::now();
+                if self.deadline < now {0} else {now - self.deadline}
+            };
+            let ret = unsafe {
+                c::WSAWaitForMultipleEvents(2, events.as_ptr(), libc::FALSE,
+                                            ms as libc::DWORD, libc::FALSE)
+            };
+            match ret {
+                c::WSA_WAIT_TIMEOUT => {
+                    return Err(util::timeout("accept timed out"))
+                }
+                c::WSA_WAIT_FAILED => return Err(os::last_error()),
+                c::WSA_WAIT_EVENT_0 => break,
+                n => assert_eq!(n, c::WSA_WAIT_EVENT_0 + 1),
+            }
+            println!("woke up");
+
+            let mut wsaevents: c::WSANETWORKEVENTS = unsafe { mem::zeroed() };
+            let ret = unsafe {
+                c::WSAEnumNetworkEvents(self.fd(), events[1], &mut wsaevents)
+            };
+            if ret != 0 { return Err(os::last_error()) }
+
+            if wsaevents.lNetworkEvents & c::FD_ACCEPT == 0 { continue }
+            match unsafe {
+                libc::accept(self.fd(), ptr::mut_null(), ptr::mut_null())
+            } {
+                -1 if util::wouldblock() => {}
+                -1 => return Err(os::last_error()),
+                fd => return Ok(TcpStream::new(Inner::new(fd))),
+            }
+        }
+
+        Err(util::eof())
+    }
 }
 
 impl rtio::RtioSocket for TcpAcceptor {
@@ -599,6 +622,17 @@ impl rtio::RtioTcpAcceptor for TcpAcceptor {
             Err(e) => Err(e),
         }
     }
+
+    #[cfg(windows)]
+    fn close_accept(&mut self) -> IoResult<()> {
+        self.inner.closed.store(true, atomics::SeqCst);
+        let ret = unsafe { c::WSASetEvent(self.inner.abort.handle()) };
+        if ret == libc::TRUE {
+            Ok(())
+        } else {
+            Err(os::last_error())
+        }
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -625,7 +659,7 @@ impl UdpSocket {
         let addrp = &storage as *const _ as *const libc::sockaddr;
 
         match unsafe { libc::bind(fd, addrp, len) } {
-            -1 => Err(last_error()),
+            -1 => Err(os::last_error()),
             _ => Ok(ret),
         }
     }
@@ -881,7 +915,7 @@ pub fn read<T>(fd: sock_t,
             let _guard = lock();
             match retry(|| read(deadline.is_some())) {
                 -1 if util::wouldblock() => { assert!(deadline.is_some()); }
-                -1 => return Err(last_error()),
+                -1 => return Err(os::last_error()),
                n => { ret = n; break }
             }
         }
@@ -889,7 +923,7 @@ pub fn read<T>(fd: sock_t,
 
     match ret {
         0 => Err(util::eof()),
-        n if n < 0 => Err(last_error()),
+        n if n < 0 => Err(os::last_error()),
         n => Ok(n as uint)
     }
 }
@@ -940,15 +974,88 @@ pub fn write<T>(fd: sock_t,
             let len = buf.len() - written;
             match retry(|| write(deadline.is_some(), ptr, len) as libc::c_int) {
                 -1 if util::wouldblock() => {}
-                -1 => return Err(last_error()),
+                -1 => return Err(os::last_error()),
                 n => { written += n as uint; }
             }
         }
         ret = 0;
     }
     if ret < 0 {
-        Err(last_error())
+        Err(os::last_error())
     } else {
         Ok(written)
     }
 }
+
+#[cfg(windows)]
+mod os {
+    use libc;
+    use std::mem;
+    use std::rt::rtio::{IoError, IoResult};
+
+    use io::c;
+
+    pub type sock_t = libc::SOCKET;
+    pub struct Event(c::WSAEVENT);
+
+    impl Event {
+        pub fn new() -> IoResult<Event> {
+            let event = unsafe { c::WSACreateEvent() };
+            if event == c::WSA_INVALID_EVENT {
+                Err(last_error())
+            } else {
+                Ok(Event(event))
+            }
+        }
+
+        pub fn handle(&self) -> c::WSAEVENT { let Event(handle) = *self; handle }
+    }
+
+    impl Drop for Event {
+        fn drop(&mut self) {
+            unsafe { let _ = c::WSACloseEvent(self.handle()); }
+        }
+    }
+
+    pub fn init() {
+        unsafe {
+            use std::rt::mutex::{StaticNativeMutex, NATIVE_MUTEX_INIT};
+            static mut INITIALIZED: bool = false;
+            static mut LOCK: StaticNativeMutex = NATIVE_MUTEX_INIT;
+
+            let _guard = LOCK.lock();
+            if !INITIALIZED {
+                let mut data: c::WSADATA = mem::zeroed();
+                let ret = c::WSAStartup(0x202,      // version 2.2
+                                        &mut data);
+                assert_eq!(ret, 0);
+                INITIALIZED = true;
+            }
+        }
+    }
+
+    pub fn last_error() -> IoError {
+        use std::os;
+        let code = unsafe { c::WSAGetLastError() as uint };
+        IoError {
+            code: code,
+            extra: 0,
+            detail: Some(os::error_string(code)),
+        }
+    }
+
+    pub unsafe fn close(sock: sock_t) { let _ = libc::closesocket(sock); }
+}
+
+#[cfg(unix)]
+mod os {
+    use libc;
+    use std::rt::rtio::IoError;
+    use io;
+
+    pub type sock_t = io::file::fd_t;
+
+    pub fn init() {}
+    pub fn last_error() -> IoError { io::last_error() }
+    pub unsafe fn close(sock: sock_t) { let _ = libc::close(sock); }
+}
diff --git a/src/libnative/io/pipe_windows.rs b/src/libnative/io/pipe_windows.rs
index 6ad51ee586f..4d01230cbd9 100644
--- a/src/libnative/io/pipe_windows.rs
+++ b/src/libnative/io/pipe_windows.rs
@@ -99,10 +99,10 @@ use super::c;
 use super::util;
 use super::file::to_utf16;
 
-pub struct Event(libc::HANDLE);
+struct Event(libc::HANDLE);
 
 impl Event {
-    pub fn new(manual_reset: bool, initial_state: bool) -> IoResult<Event> {
+    fn new(manual_reset: bool, initial_state: bool) -> IoResult<Event> {
         let event = unsafe {
             libc::CreateEventW(ptr::mut_null(),
                                manual_reset as libc::BOOL,
@@ -116,7 +116,7 @@ impl Event {
         }
     }
 
-    pub fn handle(&self) -> libc::HANDLE { let Event(handle) = *self; handle }
+    fn handle(&self) -> libc::HANDLE { let Event(handle) = *self; handle }
 }
 
 impl Drop for Event {
@@ -709,5 +709,13 @@ impl rtio::RtioUnixAcceptor for UnixAcceptor {
     fn set_timeout(&mut self, timeout: Option<u64>) {
         self.deadline = timeout.map(|i| i + ::io::timer::now()).unwrap_or(0);
     }
+
+    fn clone(&self) -> Box<rtio::RtioUnixAcceptor + Send> {
+        fail!()
+    }
+
+    fn close_accept(&mut self) -> IoResult<()> {
+        fail!()
+    }
 }
 
diff --git a/src/libnative/io/util.rs b/src/libnative/io/util.rs
index aec29bc2d03..c5b1bbec4f1 100644
--- a/src/libnative/io/util.rs
+++ b/src/libnative/io/util.rs
@@ -194,7 +194,9 @@ pub fn await(fds: &[net::sock_t], deadline: Option<u64>,
                 &mut tv as *mut _
             }
         };
-        let r = unsafe { c::select(max, read, write, ptr::mut_null(), tvp) };
+        let r = unsafe {
+            c::select(max as libc::c_int, read, write, ptr::mut_null(), tvp)
+        };
         r
     }) {
         -1 => Err(last_error()),