# Source code for radical.utils.zmq.server


import zmq

import threading as mt

from typing      import Optional, Union, Iterator, Any, Dict

from ..ids       import generate_id
from ..url       import Url
from ..misc      import as_string
from ..host      import get_hostip
from ..logger    import Logger
from ..profile   import Profiler
from ..debug     import get_exception_trace
from ..serialize import to_msgpack, from_msgpack

from .utils      import no_intr


# --------------------------------------------------------------------------
#
# Module-level defaults.  `_LINGER_TIMEOUT` and `_HIGH_WATER_MARK` are applied
# to the server's REP socket in `Server._work` (as `linger` and `hwm`).
# NOTE(review): `_DEFAULT_BULK_SIZE` is not referenced in the visible part of
#               this module - presumably kept for other importers; confirm.
_LINGER_TIMEOUT    =         250  # ms to linger after close
_HIGH_WATER_MARK   = 1024 * 1024  # number of messages to buffer before dropping
_DEFAULT_BULK_SIZE =        1024  # number of messages to put in a bulk


# ------------------------------------------------------------------------------
#
class Server(object):
    '''
    Synchronous request/reply server on top of a ZMQ `REP` socket.

    Requests are msgpack encoded dictionaries of the form

        {'cmd': <str>, 'args': <list>, 'kwargs': <dict>}

    and are dispatched to the callback registered for `cmd` via
    `register_request()`.  Each request is answered with a dictionary

        {'err': <str or None>, 'exc': <str or None>, 'res': <result>}

    before the next request is received.  The handlers `echo` and `fail` are
    registered by default.
    '''

    # --------------------------------------------------------------------------
    #
    def __init__(self, url : Optional[str] = None,
                       uid : Optional[str] = None,
                       path: Optional[str] = None) -> None:
        '''
        url : endpoint to bind to, `<proto>://<iface>:<ports>[/]` (see
              `_parse_url()`); defaults to `tcp://*:10000-11000`
        uid : unique ID of this server; generated if not given
        path: directory handed to logger and profiler; defaults to `./`

        Raises `RuntimeError` if the URL or its port spec cannot be parsed,
        and `ValueError` if a port is not a number.
        '''

        # this server offers only synchronous communication: a request will be
        # worked upon and answered before the next request is received.

        self._url  = url
        self._cbs  = dict()
        self._path = path

        if not self._path:
            self._path = './'

        if uid: self._uid = uid
        else  : self._uid = generate_id('server', ns='radical.utils')

        self._log    = Logger(self._uid, path=self._path)
        self._prof   = Profiler(self._uid, path=self._path)

        self._addr   : Optional[str]       = None
        self._thread : Optional[mt.Thread] = None
        self._up     = mt.Event()
        self._term   = mt.Event()

        self.register_request('echo', self._request_echo)
        self.register_request('fail', self._request_fail)

        if not self._url:
            self._url = 'tcp://*:10000-11000'

        self._parse_url(self._url)


    # --------------------------------------------------------------------------
    #
    def _parse_url(self, url: str) -> None:
        '''
        Split `url` into protocol, interface and port range.

        URLs can specify port ranges to use.  The URL is expected to have the
        form:

            <proto>://<iface>:<ports>/

        where
            <proto>: any protocol accepted by zmq
            <iface>: IP number of interface to bind to
            <ports>: port range to find port to bind to, defaults to `*`

        The port range can be formed as:

            '*'      : any port
            '100'    : exactly port 100
            '100+'   : any port equal or larger than 100
            '100-'   : any port equal or larger than 100
            '-110'   : any port from 1 up to 110
            '100-110': any port equal or larger than 100, up to 110

        Raises `RuntimeError` on malformed URLs and port specs, `ValueError`
        on non-numeric ports.
        '''

        tmp = url.split(':', 2)

        # explicit check instead of `assert`: asserts vanish under `python -O`
        if len(tmp) != 3:
            raise RuntimeError('cannot parse url %s' % url)

        self._proto = tmp[0]
        self._iface = tmp[1].lstrip('/')

        # the documented URL form ends in `/`: strip it before parsing ports
        self._ports = tmp[2].rstrip('/').replace('+', '-')

        self._port_this : Union[int, str, None] = None
        self._port_start: Optional[int]         = None
        self._port_stop : Optional[int]         = None

        # an empty port spec falls back to the documented default `*`
        if self._ports in ['', '*']:
            self._port_this = '*'
            return

        tmp = self._ports.split('-')

        if len(tmp) == 1:
            self._port_start = int(tmp[0])
            self._port_stop  = int(tmp[0])

        elif len(tmp) == 2:
            if tmp[0]: self._port_start = int(tmp[0])
            else     : self._port_start = 1

            if tmp[1]: self._port_stop  = int(tmp[1])
            else     : self._port_stop  = None

        else:
            raise RuntimeError('cannot parse port spec %s' % self._ports)


    # --------------------------------------------------------------------------
    #
    def _iterate_ports(self) -> Iterator[Union[int, str, None]]:
        '''
        Yield the ports to try, in order.  The position is kept in
        `self._port_this`, so a new iterator resumes at the last port yielded.
        '''

        if self._port_this == '*':
            # leave scanning to zmq - there is nothing else to iterate over
            yield self._port_this
            return

        if self._port_this is None:
            # initialize range iterator
            self._port_this = self._port_start

        # make type checker happy
        assert isinstance(self._port_this, int)

        # `_port_stop is None` denotes an open ended range
        while self._port_stop is None or self._port_this <= self._port_stop:
            yield self._port_this
            self._port_this += 1


    # --------------------------------------------------------------------------
    #
    def _iterate_urls(self) -> Iterator[str]:
        '''
        Yield one bind URL for each port from `_iterate_ports()`.
        '''

        for port in self._iterate_ports():
            yield '%s://%s:%s' % (self._proto, self._iface, port)


    # --------------------------------------------------------------------------
    #
    @property
    def uid(self) -> str:
        '''unique ID of this server'''
        return self._uid

    @property
    def addr(self) -> Optional[str]:
        '''address the server is bound to, `None` before `start()`'''
        return self._addr


    # --------------------------------------------------------------------------
    #
    def start(self) -> None:
        '''
        Run the server in a daemon thread and return once the socket is bound.

        Raises `RuntimeError` when called twice, or when the server thread
        could not bind to any of the requested ports.
        '''

        self._log.info('start bridge %s', self._uid)

        if self._thread:
            raise RuntimeError('`start()` can be called only once')

        self._thread = mt.Thread(target=self._work)
        self._thread.daemon = True
        self._thread.start()

        # `_work()` sets this event on success *and* on failure
        self._up.wait()

        if not self._addr:
            raise RuntimeError('server %s failed to start' % self._uid)


    # --------------------------------------------------------------------------
    #
    def stop(self) -> None:
        '''
        Signal the server thread to terminate.  Does not wait for it, see
        `wait()`.
        '''

        self._log.info('stop bridge %s', self._uid)
        self._term.set()


    # --------------------------------------------------------------------------
    #
    def wait(self) -> None:
        '''
        Block until the server thread has finished (no-op if never started).
        '''

        self._log.info('wait bridge %s', self._uid)

        if self._thread:
            self._thread.join()

        self._log.info('wait bridge %s', self._uid)


    # --------------------------------------------------------------------------
    #
    def register_request(self, req, cb) -> None:
        '''
        Register callable `cb` as handler for requests with command `req`.
        A handler already registered under that name is replaced.
        '''

        self._log.info('add handler: %s: %s', req, cb)
        self._cbs[req] = cb


    # --------------------------------------------------------------------------
    #
    def _request_fail(self, arg) -> None:
        '''default handler `fail`: always raises'''

        raise RuntimeError('task failed successfully')


    # --------------------------------------------------------------------------
    #
    def _request_echo(self, arg: Any) -> Any:
        '''default handler `echo`: returns its argument'''

        return arg


    # --------------------------------------------------------------------------
    #
    def _success(self, res: Any = None) -> Dict[str, Any]:
        '''create a reply message for a successful request'''

        return {'err': None,
                'exc': None,
                'res': res}


    # --------------------------------------------------------------------------
    #
    def _error(self, err: Optional[str] = None,
                     exc: Optional[str] = None) -> Dict[str, Optional[str]]:
        '''create a reply message for a failed request'''

        if not err:
            err = 'invalid request'

        return {'err': err,
                'exc': exc,
                'res': None}


    # --------------------------------------------------------------------------
    #
    def _bind(self) -> None:
        '''
        Bind the socket to the first usable URL and record the public address.
        Raises `RuntimeError` if all ports of the range are in use.
        '''

        for url in self._iterate_urls():

            try:
                self._log.debug('try url %s', url)
                self._sock.bind(url)
                self._log.debug('success')
                break

            except zmq.error.ZMQError as e:
                # prefer the error code, the message text may be localized
                if e.errno == zmq.EADDRINUSE or \
                        'Address already in use' in str(e):
                    self._log.warn('port in use - try next (%s)' % url)
                else:
                    raise

        else:
            # port range exhausted without a successful bind
            raise RuntimeError('no free port found for %s' % self._url)

        addr       = Url(as_string(self._sock.getsockopt(zmq.LAST_ENDPOINT)))
        addr.host  = get_hostip()
        self._addr = str(addr)


    # --------------------------------------------------------------------------
    #
    def _dispatch(self, req: Any) -> Dict[str, Any]:
        '''
        Validate the decoded request and run the registered handler.  Returns
        the reply message; exceptions raised by the handler propagate.
        '''

        if not isinstance(req, dict):
            return self._error(err='invalid message type')

        # missing fields are handled like empty ones
        cmd    = req.get('cmd')
        args   = req.get('args')   or []
        kwargs = req.get('kwargs') or {}

        if not cmd:
            return self._error(err='no command in request')

        if cmd not in self._cbs:
            return self._error(err='command [%s] unknown' % cmd)

        return self._success(self._cbs[cmd](*args, **kwargs))


    # --------------------------------------------------------------------------
    #
    def _send_reply(self, rep: Dict[str, Any]) -> None:
        '''
        Serialize and send the reply.  A REP socket must answer every request,
        so a reply which cannot be serialized is replaced by an error reply.
        '''

        try:
            msg = to_msgpack(rep)

        except Exception as e:
            # presumably raised for result types msgpack cannot encode
            self._log.exception('cannot serialize reply')
            msg = to_msgpack(self._error(err='cannot serialize reply: %s'
                                             % str(e)))

        no_intr(self._sock.send, msg)
        self._log.debug('rep: %s', str(rep)[:128])


    # --------------------------------------------------------------------------
    #
    def _serve_request(self) -> None:
        '''
        Receive one request, handle it and send exactly one reply.
        '''

        # default response
        rep = None
        req = None

        try:
            data = no_intr(self._sock.recv)
            req  = as_string(from_msgpack(data))
            self._log.debug('req: %s', str(req)[:128])

            rep  = self._dispatch(req)

        except Exception as e:
            self._log.exception('command failed: %s', req)
            rep = self._error(err='command failed: %s' % str(e),
                              exc='\n'.join(get_exception_trace()))

        finally:
            if not rep:
                rep = self._error('server error')
            self._send_reply(rep)


    # --------------------------------------------------------------------------
    #
    def _work(self) -> None:
        '''
        Thread body: bind the socket, then serve requests until `stop()` is
        called.  Socket and context are released on all exit paths.
        '''

        self._ctx  = None
        self._sock = None

        try:
            self._ctx  = zmq.Context()
            self._sock = self._ctx.socket(zmq.REP)

            self._sock.linger = _LINGER_TIMEOUT
            self._sock.hwm    = _HIGH_WATER_MARK

            self._bind()
            self._up.set()

            self._poll = zmq.Poller()
            self._poll.register(self._sock, zmq.POLLIN)

            while not self._term.is_set():

                # poll with timeout so that `stop()` is noticed in time
                event = dict(no_intr(self._poll.poll, timeout=100))

                if self._sock not in event:
                    continue

                self._serve_request()

        except Exception:
            self._log.exception('server %s failed', self._uid)
            raise

        finally:
            # never leave `start()` waiting if we failed before binding
            self._up.set()

            if self._sock is not None:
                self._sock.close()

            if self._ctx is not None:
                self._ctx.term()

            self._log.debug('term')
# ------------------------------------------------------------------------------