from __future__ import annotations

import select
import socket
from functools import partial

__all__ = ["wait_for_read", "wait_for_write"]


# How should we wait on sockets?
#
# There are two types of APIs you can use for waiting on sockets: the fancy
# modern stateful APIs like epoll/kqueue, and the older stateless APIs like
# select/poll. The stateful APIs are more efficient when you have a lot of
# sockets to keep track of, because you can set them up once and then use them
# lots of times. But we only ever want to wait on a single socket at a time
# and don't want to keep track of state, so the stateless APIs are actually
# more efficient. So we want to use select() or poll().
#
# Now, how do we choose between select() and poll()? On traditional Unixes,
# select() has a strange calling convention that makes it slow, or fail
# altogether, for high-numbered file descriptors. The point of poll() is to fix
# that, so on Unixes, we prefer poll().
#
# On Windows, there is no poll() (or at least Python doesn't provide a wrapper
# for it), but that's OK, because on Windows, select() doesn't have this
# strange calling convention; plain select() works fine.
#
# So: on Windows we use select(), and everywhere else we use poll(). We also
# fall back to select() in case poll() is somehow broken or missing.


def select_wait_for_socket(
    sock: socket.socket,
    read: bool = False,
    write: bool = False,
    timeout: float | None = None,
) -> bool:
    """Wait for ``sock`` to become ready using :func:`select.select`.

    :param sock: The socket to wait on.
    :param read: Wait for the socket to become readable.
    :param write: Wait for the socket to become writable.
    :param timeout: Maximum time to wait, in seconds; ``None`` waits
        indefinitely.
    :returns: ``True`` if the socket became ready, ``False`` if the timeout
        expired.
    :raises RuntimeError: If neither ``read`` nor ``write`` is requested.
    """
    if not read and not write:
        raise RuntimeError("must specify at least one of read=True, write=True")
    rcheck = [sock] if read else []
    wcheck = [sock] if write else []
    # When doing a non-blocking connect, most systems signal success by
    # marking the socket writable. Windows, though, signals success by marking
    # it as "exceptional". We paper over the difference by checking the write
    # sockets for both conditions. (The stdlib selectors module does the same
    # thing.)
    rready, wready, xready = select.select(rcheck, wcheck, wcheck, timeout)
    return bool(rready or wready or xready)


def poll_wait_for_socket(
    sock: socket.socket,
    read: bool = False,
    write: bool = False,
    timeout: float | None = None,
) -> bool:
    """Wait for ``sock`` to become ready using :func:`select.poll`.

    :param sock: The socket to wait on.
    :param read: Wait for the socket to become readable (``POLLIN``).
    :param write: Wait for the socket to become writable (``POLLOUT``).
    :param timeout: Maximum time to wait, in seconds; ``None`` waits
        indefinitely.
    :returns: ``True`` if any requested event fired, ``False`` if the timeout
        expired.
    :raises RuntimeError: If neither ``read`` nor ``write`` is requested.
    """
    if not read and not write:
        raise RuntimeError("must specify at least one of read=True, write=True")
    mask = 0
    if read:
        mask |= select.POLLIN
    if write:
        mask |= select.POLLOUT
    poll_obj = select.poll()
    poll_obj.register(sock, mask)
    # poll() takes its timeout in milliseconds rather than seconds.
    timeout_ms = None if timeout is None else timeout * 1000
    return bool(poll_obj.poll(timeout_ms))


def _have_working_poll() -> bool:
    """Return True if ``select.poll`` exists and a zero-timeout poll succeeds."""
    # Apparently some systems have a select.poll that fails as soon as you try
    # to use it, either due to strange configuration or broken monkeypatching
    # from libraries like eventlet/greenlet.
    try:
        poll_obj = select.poll()
        poll_obj.poll(0)
    except (AttributeError, OSError):
        return False
    else:
        return True


def wait_for_socket(
    sock: socket.socket,
    read: bool = False,
    write: bool = False,
    timeout: float | None = None,
) -> bool:
    """Wait for ``sock`` to become readable and/or writable.

    On the first call this picks ``poll_wait_for_socket`` (if ``poll`` works)
    or ``select_wait_for_socket``, rebinds the module-level name to it, and
    delegates. Arguments and return value are those of the chosen
    implementation.

    :raises RuntimeError: If neither a working ``poll`` nor ``select.select``
        is available, or if neither ``read`` nor ``write`` is requested.
    """
    # We delay choosing which implementation to use until the first time we're
    # called. We could do it at import time, but then we might make the wrong
    # decision if someone goes wild with monkeypatching select.poll after
    # we're imported.
    global wait_for_socket
    if _have_working_poll():
        wait_for_socket = poll_wait_for_socket
    elif hasattr(select, "select"):
        wait_for_socket = select_wait_for_socket
    else:
        # Without this, the name is never rebound and the call below would
        # recurse into this dispatcher forever.
        raise RuntimeError("no select-equivalent available on this platform")
    return wait_for_socket(sock, read, write, timeout)


def wait_for_read(sock: socket.socket, timeout: float | None = None) -> bool:
    """Waits for reading to be available on a given socket.
    Returns True if the socket is readable, or False if the timeout expired.
    """
    return wait_for_socket(sock, read=True, timeout=timeout)


def wait_for_write(sock: socket.socket, timeout: float | None = None) -> bool:
    """Waits for writing to be available on a given socket.
    Returns True if the socket is writable, or False if the timeout expired.
    """
    return wait_for_socket(sock, write=True, timeout=timeout)