import errno
import os
import selectors
import signal
import socket
import struct
import sys
import threading
import warnings

from . import connection
from . import process
from .context import reduction
from . import resource_tracker
from . import spawn
from . import util

__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
#
#

# Upper bound on the number of fds carried by one fork request; checked by
# the client in connect_to_new_process() and by the server in main().
MAXFDS_TO_SEND = 256
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t

#
# Forkserver class
#

class ForkServer(object):
    '''Client-side state for a single fork server process.

    Records the address of the server's listening AF_UNIX socket, the write
    end of the "alive" pipe shared with the server, the server's pid, the
    list of module names the server should preload, and -- in a process
    that was itself started by a fork server -- the fds passed down to it.
    '''

    def __init__(self):
        self._forkserver_address = None
        self._forkserver_alive_fd = None
        self._forkserver_pid = None
        self._inherited_fds = None
        self._lock = threading.Lock()
        self._preload_modules = ['__main__']

    def _stop(self):
        # Method used by unit tests to stop the server
        with self._lock:
            self._stop_unlocked()

    def _stop_unlocked(self):
        # Expects the caller to hold self._lock (as _stop() does).
        if self._forkserver_pid is None:
            return

        # close the "alive" file descriptor asks the server to stop
        os.close(self._forkserver_alive_fd)
        self._forkserver_alive_fd = None

        os.waitpid(self._forkserver_pid, 0)
        self._forkserver_pid = None

        os.unlink(self._forkserver_address)
        self._forkserver_address = None

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any entry of *modules_names* is not a str.
        '''
        # Validate the *new* list.  Checking self._preload_modules (the old
        # value) let non-string entries through, and they would later reach
        # __import__() in the server, which only tolerates ImportError.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* plus the four fds always sent ahead of
        them would exceed MAXFDS_TO_SEND.
        '''
        self.ensure_running()
        # Four fds always precede the caller's: the child's two pipe ends,
        # the "alive" fd and the resource tracker fd.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            child_r, parent_w = os.pipe()
            # Order matters: main() unpacks child_r and child_w first, and
            # _serve_one() then takes the alive fd and the tracker fd.
            allfds = [child_r, child_w, self._forkserver_alive_fd,
                      resource_tracker.getfd()]
            allfds += fds
            try:
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                # Hand nothing back on failure: release our pipe ends.
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child-side ends are for the server only; our copies
                # are closed whether or not sending succeeded.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            resource_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                # Best effort: a dead server never removes its socket file,
                # so do what _stop_unlocked() would have done.
                try:
                    os.unlink(self._forkserver_address)
                except OSError:
                    pass
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Only main_path/sys_path are forwarded; main() uses them
                # solely for preloading.
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # Restrict the socket file to the owning user.
                os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # The server has its own copy (passed via spawnv_passfds).
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid

#
#
#

def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    *listener_fd* is the listening AF_UNIX socket and *alive_r* the read
    end of the "alive" pipe, both passed in by ForkServer.ensure_running().
    *preload* lists module names to import before serving; *main_path* and
    *sys_path* come from the parent's preparation data.  Serves fork
    requests until the "alive" pipe reports EOF, then raises SystemExit.
    '''
    if preload:
        # Apply the parent's import path first so that preloaded modules
        # resolve as they would in the parent.  sys_path is sent for this
        # purpose but was previously accepted and ignored.
        if sys_path is not None:
            sys.path[:] = sys_path
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    util._close_stdin()

    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): under -O this assert (and its read) is
                    # skipped; only the SystemExit below is essential.
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            break
                        if pid == 0:
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            # Negative return code means killed by signal.
                            if os.WIFSIGNALED(sts):
                                returncode = -os.WTERMSIG(sts)
                            else:
                                if not os.WIFEXITED(sts):
                                    raise AssertionError(
                                        "Child {0:n} status is {1:n}".format(
                                            pid, sts))
                                returncode = os.WEXITSTATUS(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client; ask for one extra so an
                        # oversized request can be detected.
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        child_r, child_w, *fds = fds
                        # Close before forking so the child doesn't hold it.
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # Keep child_w to report the exit code later.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                if e.errno != errno.ECONNABORTED:
                    raise


def _serve_one(child_r, fds, unused_fds, handlers):
    '''Run one child process in a freshly forked copy of the server.

    Clears the wakeup fd, restores the signal dispositions saved by main()
    (*handlers*), and closes the server-only fds in *unused_fds*.  It then
    records the "alive" fd, the resource tracker fd and the remaining
    inherited fds from *fds*.  Finally it runs the process data read from
    *child_r*.  Returns the result of spawn._main(), which main() uses as
    the exit status.
    '''
    # close unnecessary stuff and reset signal handlers
    signal.set_wakeup_fd(-1)
    for sig, val in handlers.items():
        signal.signal(sig, val)
    for fd in unused_fds:
        os.close(fd)

    (_forkserver._forkserver_alive_fd,
     resource_tracker._resource_tracker._fd,
     *_forkserver._inherited_fds) = fds

    # Run process object received over pipe
    parent_sentinel = os.dup(child_r)
    code = spawn._main(child_r, parent_sentinel)

    return code


#
# Read and write signed numbers
#

def read_signed(fd):
    '''Read one SIGNED_STRUCT-encoded integer from *fd* and return it.

    Keeps reading until SIGNED_STRUCT.size bytes have arrived; raises
    EOFError if *fd* hits end-of-file first.
    '''
    pieces = []
    remaining = SIGNED_STRUCT.size
    while remaining > 0:
        piece = os.read(fd, remaining)
        if not piece:
            raise EOFError('unexpected EOF')
        pieces.append(piece)
        remaining -= len(piece)
    (value,) = SIGNED_STRUCT.unpack(b''.join(pieces))
    return value

def write_signed(fd, n):
    '''Encode *n* with SIGNED_STRUCT and write every byte of it to *fd*.

    Raises RuntimeError if os.write() ever reports writing zero bytes.
    '''
    payload = SIGNED_STRUCT.pack(n)
    offset = 0
    while offset < len(payload):
        written = os.write(fd, payload[offset:])
        if written == 0:
            raise RuntimeError('should not get here')
        offset += written

#
#
#

# Module-wide singleton; the public API names are its bound methods.
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload