Source code for radical.utils.zmq.queue

# pylint: disable=protected-access, abstract-class-instantiated

import sys
import zmq
import time

import threading as mt

from typing      import Optional

from ..atfork    import atfork
from ..config    import Config
from ..ids       import generate_id, ID_CUSTOM
from ..misc      import as_string, as_bytes, as_list, noop
from ..logger    import Logger
from ..profile   import Profiler
from ..debug     import print_exception_trace
from ..serialize import to_msgpack, from_msgpack

from .bridge     import Bridge
from .utils      import zmq_bind, no_intr
from .utils      import log_bulk, LOG_ENABLED
# from .utils    import prof_bulk


# --------------------------------------------------------------------------
#
_LINGER_TIMEOUT    =  250  # ms to linger after close
_HIGH_WATER_MARK   =    0  # number of messages to buffer before dropping
_DEFAULT_BULK_SIZE = 1024  # number of messages to put in a bulk


# ------------------------------------------------------------------------------
#
def _atfork_child():
    Getter._callbacks = dict()                                            # noqa


atfork(noop, noop, _atfork_child)


# ------------------------------------------------------------------------------
#
# Communication between components is done via queues.  Queues are
# uni-directional, ie. Queues have an input-end for which one can call 'put()',
# and and output-end, for which one can call 'get()'.
#
# The semantics we expect (and which is what is matched by the native Python
# `Queue.Queue`), is:
#
#   - multiple upstream   components put messages onto the same queue (input)
#   - multiple downstream components get messages from the same queue (output)
#   - local order of messages is maintained: order of messages pushed onto the
#     *same* input is preserved when pulled on any output
#   - message routing is fair: whatever downstream component calls 'get' first
#     will get the next message
#
# We implement the interface of Queue.Queue:
#
#   put(msg)
#   get()
#   get_nowait()
#
# Not implemented is, at the moment:
#
#   qsize
#   empty
#   full
#   put(msg, block, timeout)
#   put_nowait
#   get(block, timeout)
#   task_done
#
# Our Queue additionally takes 'name', 'role' and 'address' parameter on the
# constructor.  'role' can be 'input', 'bridge' or 'output', where 'input' is
# the end of a queue one can 'put()' messages into, and 'output' the end of the
# queue where one can 'get()' messages from. A 'bridge' acts as as a message
# forwarder.  'address' denominates a connection endpoint, and 'name' is
# a unique identifier: if multiple instances in the current process space use
# the same identifier, they will get the same queue instance (are connected to
# the same bridge).
#
[docs] class Queue(Bridge): def __init__(self, channel: str, cfg: Optional[dict] = None, log=None): ''' This Queue type sets up an zmq channel of this kind: input \\ // output == bridge == input // \\ output ie. any number of inputs can 'zmq.push()' to a bridge (which 'zmq.pull()'s), and any number of outputs can 'zmq.request()' messages from the bridge (which 'zmq.response()'s). The bridge is the entity which 'bind()'s network interfaces, both input and output type endpoints 'connect()' to it. It is the callees responsibility to ensure that only one bridge of a given type exists. Addresses are of the form 'tcp://host:port'. Both 'host' and 'port' can be wildcards for BRIDGE roles -- the bridge will report the in and out addresses as obj.addr_put and obj.addr_get. ''' if cfg: # create deep copy cfg = Config(cfg=cfg) else: cfg = Config() # ensure channel is set in config if cfg.channel: assert cfg.channel == channel else: cfg.channel = channel if not cfg.uid: cfg.uid = generate_id('%s.bridge.%%(counter)04d' % cfg.channel, ID_CUSTOM) super().__init__(cfg, log=log) self._bulk_size = self._cfg.get('bulk_size', 0) if self._bulk_size <= 0: self._bulk_size = _DEFAULT_BULK_SIZE # -------------------------------------------------------------------------- # # protocol independent addr query @property def type_in(self): return 'put' @property def type_out(self): return 'get' @property def addr_in(self): return self._addr_put @property def addr_out(self): return self._addr_get # protocol dependent addr query @property def addr_put(self): return self._addr_put @property def addr_get(self): return self._addr_get # -------------------------------------------------------------------------- # def _bridge_initialize(self): self._log.info('start bridge %s', self._uid) self._lock = mt.Lock() self._ctx = zmq.Context() # rely on GC for destruction self._put = self._ctx.socket(zmq.PULL) self._put.linger = _LINGER_TIMEOUT self._put.hwm = _HIGH_WATER_MARK self._addr_put = zmq_bind(self._put) self._get = self._ctx.socket(zmq.REP) self._get.linger = _LINGER_TIMEOUT self._get.hwm = _HIGH_WATER_MARK self._addr_get = zmq_bind(self._get) self._log.info('bridge in %s: %s', self._uid, self._addr_put) self._log.info('bridge out %s: %s', self._uid, self._addr_get) # start polling senders self._poll_put = zmq.Poller() self._poll_put.register(self._put, zmq.POLLIN) # start polling receivers self._poll_get = zmq.Poller() self._poll_get.register(self._get, zmq.POLLIN) # -------------------------------------------------------------------------- # def _bridge_work(self): # TODO: *always* pull for messages and buffer them. Serve requests from # that buffer. try: self.nin = 0 self.nout = 0 self.last = 0 buf = dict() while not self._term.is_set(): active = False # check for incoming messages, and buffer them ev_put = dict(no_intr(self._poll_put.poll, timeout=10)) # self._prof.prof('poll_put', msg=len(ev_put)) self._log.debug_9('polled put: %s', ev_put) if self._put in ev_put: with self._lock: data = list(no_intr(self._put.recv_multipart)) self._log.debug_9('recvd put: %s', data) if len(data) != 2: raise RuntimeError('%d frames unsupported' % len(data)) qname = as_string(from_msgpack(data[0])) msgs = from_msgpack(data[1]) # prof_bulk(self._prof, 'poll_put_recv', msgs) log_bulk(self._log, '<> %s' % qname, msgs) self._log.debug_9('put %s: %s ! ', qname, len(msgs)) if qname not in buf: buf[qname] = list() buf[qname] += msgs self.nin += len(msgs) active = True # check if somebody wants our messages ev_get = dict(no_intr(self._poll_get.poll, timeout=10)) # self._prof.prof('poll_get', msg=len(ev_get)) self._log.debug_9('polled get: %s [%s]', ev_get, self._get) if self._get in ev_get: # send up to `bulk_size` messages from the buffer # NOTE: this sends partial bulks on buffer underrun with self._lock: # the actual req message is ignored - we only care # about who sent it qname = as_string(no_intr(self._get.recv)) if not qname: qname = 'default' if qname in buf: msgs = buf[qname][:self._bulk_size] else: self._log.debug_9('get: %s not in %s', qname, list(buf.keys())) msgs = list() log_bulk(self._log, '>< %s' % qname, msgs) data = [to_msgpack(qname), to_msgpack(msgs)] active = True # self._log.debug_9('==== get %s: %s', qname, list(buf.keys())) # self._log.debug_9('==== get %s: %s', qname, list(buf.values())) # self._log.debug_9('==== get %s: %s ! [%s]', qname, len(msgs), # [[x, len(y)] for x,y in buf.items()]) no_intr(self._get.send_multipart, data) # prof_bulk(self._prof, 'poll_get_send', msgs=msgs, msg=req) self.nout += len(msgs) self.last = time.time() # remove sent messages from buffer if msgs: del buf[qname][:self._bulk_size] if not active: # self._prof.prof('sleep', msg=len(buf)) # let CPU sleep a bit when there is nothing to do # We don't want to use poll timouts since we use two # competing polls and don't want the idle channel slow down # the busy one. time.sleep(0.01) except Exception: self._log.exception('bridge failed')
[docs] def stop(self): Bridge.stop(self)
# ------------------------------------------------------------------------------ #
[docs] class Putter(object): # -------------------------------------------------------------------------- # def __init__(self, channel, url=None, log=None, prof=None, path=None): self._channel = channel self._url = as_string(url) self._log = log self._prof = prof self._lock = mt.Lock() self._uid = generate_id('%s.put.%%(counter)04d' % self._channel, ID_CUSTOM) if not self._url: self._url = Bridge.get_config(channel, path).get('put') if not self._url: raise ValueError('no contact url specified, no config found') if not self._log: if LOG_ENABLED: level = 'DEBUG_9' else : level = 'ERROR' self._log = Logger(name=self._uid, ns='radical.utils.zmq', level=level, path=path) if not self._prof: self._prof = Profiler(name=self._uid, ns='radical.utils', path=path) self._prof.disable() if 'hb' in self._uid or 'heartbeat' in self._uid: self._prof.disable() self._log.info('connect put to %s: %s', self._channel, self._url) self._ctx = zmq.Context() # rely on GC for destruction self._q = self._ctx.socket(zmq.PUSH) self._q.linger = _LINGER_TIMEOUT self._q.hwm = _HIGH_WATER_MARK self._q.connect(self._url) # -------------------------------------------------------------------------- # def __str__(self): return 'Putter(%s @ %s)' % (self.channel, self._url) @property def name(self): return self._uid @property def uid(self): return self._uid @property def channel(self): return self._channel # -------------------------------------------------------------------------- #
[docs] def put(self, msgs, qname=None): msgs = as_list(msgs) if not qname: qname = 'default' log_bulk(self._log, '-> %s[%s]' % (self._channel, qname), msgs) data = [to_msgpack(qname), to_msgpack(msgs)] with self._lock: no_intr(self._q.send_multipart, data)
# prof_bulk(self._prof, 'put', msgs) # ------------------------------------------------------------------------------ #
[docs] class Getter(object): # instead of creating a new listener thread for each endpoint which then, on # incoming messages, calls a getter callback, we only create *one* # listening thread per ZMQ endpoint address and call *all* registered # callbacks in that thread. We hold those endpoints in a class dict, so # that all class instances share that information _callbacks = dict() # -------------------------------------------------------------------------- # @staticmethod def _get_nowait(url, qname=None, timeout=None, uid=None): # timeout in ms info = Getter._callbacks[url] if not qname: qname = 'default' with info['lock']: if LOG_ENABLED: level = 'DEBUG_9' else : level = 'ERROR' logger = Logger(name=qname, ns='radical.utils.zmq', level=level) if not info['requested']: # send the request *once* per recieval (got lock above) # FIXME: why is this sent repeatedly? logger.debug_9('=> from %s[%s]', uid, qname) no_intr(info['socket'].send, as_bytes(qname)) info['requested'] = True if no_intr(info['socket'].poll, flags=zmq.POLLIN, timeout=timeout): data = list(no_intr(info['socket'].recv_multipart)) info['requested'] = False qname = as_string(from_msgpack(data[0])) msgs = as_string(from_msgpack(data[1])) log_bulk(logger, '<-1 %s [%s]' % (uid, qname), msgs) return msgs else: return None # -------------------------------------------------------------------------- # @staticmethod def _listener(url, qname=None, uid=None): ''' other than the pubsub listener, the queue listener will not deliver an incoming message to all subscribers, but only to exactly *one* subscriber. We this perform a round-robin over all known callbacks ''' if not qname: qname = 'default' assert url in Getter._callbacks time.sleep(0.01) try: term = Getter._callbacks.get(url, {}).get('term') idx = 0 # round-robin cb index while not term.is_set(): # this list is dynamic callbacks = Getter._callbacks[url]['callbacks'] if not callbacks: time.sleep(0.01) continue msgs = Getter._get_nowait(url, qname=qname, timeout=500, uid=uid) BULK = True if msgs: if BULK: idx += 1 if idx >= len(callbacks): idx = 0 cb, _lock = callbacks[idx] if _lock: with _lock: cb(as_string(msgs)) else: cb(as_string(msgs)) else: for m in as_list(msgs): idx += 1 if idx >= len(callbacks): idx = 0 cb, _lock = callbacks[idx] if _lock: with _lock: cb(as_string(m)) else: cb(as_string(m)) except Exception as e: print_exception_trace() sys.stderr.write('listener died: %s : %s : %s\n' % (qname, url, repr(e))) sys.stderr.flush() # -------------------------------------------------------------------------- # def _start_listener(self, qname=None): if not qname: qname = 'default' # only start if needed if Getter._callbacks[self._url]['thread']: return t = mt.Thread(target=Getter._listener, args=[self._url, qname, self._uid]) t.daemon = True t.start() Getter._callbacks[self._url]['thread'] = t # -------------------------------------------------------------------------- # def _stop_listener(self, force=False): # only stop listener if no callbacks remain registered (unless forced) if force or not Getter._callbacks[self._url]['callbacks']: if Getter._callbacks[self._url]['thread']: Getter._callbacks[self._url]['term' ].set() Getter._callbacks[self._url]['thread'].join() Getter._callbacks[self._url]['term' ].unset() Getter._callbacks[self._url]['thread'] = None # -------------------------------------------------------------------------- # def __init__(self, channel, url=None, cb=None, log=None, prof=None, path=None): ''' When a callback `cb` is specified, then the Getter c'tor will spawn a separate thread which continues to listen on the channel, and the cb is invoked on any incoming message. The message will be the only argument to the cb. ''' self._channel = channel self._url = as_string(url) self._lock = mt.Lock() self._log = log self._prof = prof self._uid = generate_id('%s.get.%%(counter)04d' % self._channel, ID_CUSTOM) if not self._url: self._url = Bridge.get_config(channel, path).get('get') if not self._url: raise ValueError('no contact url specified, no config found') if not self._log: if LOG_ENABLED: level = 'DEBUG_9' else : level = 'ERROR' self._log = Logger(name=self._uid, ns='radical.utils.zmq', level=level, path=path) if not self._prof: self._prof = Profiler(name=self._uid, ns='radical.utils', path=path) self._prof.disable() if 'hb' in self._uid or 'heartbeat' in self._uid: self._prof.disable() self._log.info('connect get to %s: %s', self._channel, self._url) self._requested = False # send/recv sync self._ctx = zmq.Context() # rely on GC for destruction self._q = self._ctx.socket(zmq.REQ) self._q.linger = _LINGER_TIMEOUT self._q.hwm = _HIGH_WATER_MARK self._q.connect(self._url) if url not in Getter._callbacks: Getter._callbacks[url] = {'uid' : self._uid, 'socket' : self._q, 'channel' : self._channel, 'lock' : mt.Lock(), 'term' : mt.Event(), 'requested': self._requested, 'thread' : None, 'callbacks': list()} if cb: self.subscribe(cb) else: self._interactive = True # -------------------------------------------------------------------------- # def __str__(self): return 'Getter(%s @ %s)' % (self.channel, self._url) @property def name(self): return self._uid @property def uid(self): return self._uid @property def channel(self): return self._channel # -------------------------------------------------------------------------- #
[docs] def subscribe(self, cb, lock=None): # if we need to serve callbacks, then open a thread to watch the socket # and register the callbacks. If a thread is already runnning on that # channel, just register the callback. # # Note that once a thread is watching a socket, we cannot allow to use # `get()` and `get_nowait()` anymore, as those will interfere with the # thread consuming the messages, # # The given lock (if any) is used to shield concurrent cb invokations. # # FIXME: clean up lock usage - see self._lock if self._url not in Getter._callbacks: Getter._callbacks[self._url] = {'uid' : self._uid, 'socket' : self._q, 'channel' : self._channel, 'lock' : mt.Lock(), 'term' : mt.Event(), 'requested': self._requested, 'thread' : None, 'callbacks': list()} # we allow only one cb per queue getter process at the moment, until we # have more clarity on the RR behavior of concurrent callbacks. if Getter._callbacks[self._url]['callbacks']: raise RuntimeError('multiple callbacks not supported') Getter._callbacks[self._url]['callbacks'].append([cb, lock]) self._interactive = False self._start_listener()
# -------------------------------------------------------------------------- #
[docs] def unsubscribe(self, cb): if self._url in Getter._callbacks: for _cb, _lock in Getter._callbacks[self._url]['callbacks']: if cb == _cb: Getter._callbacks[self._url]['callbacks'].remove([_cb, _lock]) break self._stop_listener()
# -------------------------------------------------------------------------- #
[docs] def stop(self): self._stop_listener(force=True)
# -------------------------------------------------------------------------- #
[docs] def get(self, qname=None): if not self._interactive: raise RuntimeError('invalid get(): callbacks are registered') if not qname: qname = 'default' # double-check: minimize lock use which is only needed for a very # rare race anyway if not self._requested: with self._lock: if not self._requested: self._log.debug_9('=> from %s[%s]', self._channel, qname) no_intr(self._q.send, as_bytes(qname)) self._requested = True # self._prof.prof('requested') with self._lock: data = list(no_intr(self._q.recv_multipart)) self._requested = False qname = from_msgpack(data[0]) msgs = from_msgpack(data[1]) log_bulk(self._log, '<-2 %s [%s]' % (self._channel, qname), msgs) return as_string(msgs)
# -------------------------------------------------------------------------- #
[docs] def get_nowait(self, qname=None, timeout=None): # timeout in ms if not self._interactive: raise RuntimeError('invalid get(): callbacks are registered') # backward compatibility to `get_nowait(timeout=None)` if timeout is None and isinstance(qname, int): timeout = qname qname = None if not qname: qname = 'default' if not self._requested: with self._lock: # need to protect self._requested if not self._requested: self._log.debug_9('=> from %s[%s]', self._channel, qname) no_intr(self._q.send_multipart, [as_bytes(qname)]) self._requested = True if no_intr(self._q.poll, flags=zmq.POLLIN, timeout=timeout): with self._lock: data = list(no_intr(self._q.recv_multipart)) self._requested = False qname = from_msgpack(data[0]) msgs = from_msgpack(data[1]) log_bulk(self._log, '<-3 %s [%s]' % (self._channel, qname), msgs) return as_string(msgs) else: return None
# ------------------------------------------------------------------------------ #
[docs] def test_queue(channel, addr_pub, addr_sub): c_a = 200 c_b = 400 data = dict() for i in 'ABCD': data[i] = dict() for j in 'AB': data[i][j] = 0 def cb(uid, msg): if msg['idx'] is None: return False data[uid][msg['src']] += 1 cb_C = lambda t,m: cb('C', m) cb_D = lambda t,m: cb('D', m) Getter(channel=channel, url=addr_sub, cb=cb_C) Getter(channel=channel, url=addr_sub, cb=cb_D) # -------------------------------------------------------------------------- def work_pub(uid, n, delay): pub = Putter(channel=channel, url=addr_pub) idx = 0 while idx < n: time.sleep(delay) pub.put({'src': uid, 'idx': idx}) idx += 1 data[uid][uid] += 1 # send EOF pub.put({'src': uid, 'idx': None}) # -------------------------------------------------------------------------- t_a = mt.Thread(target=work_pub, args=['A', c_a, 0.001]) t_b = mt.Thread(target=work_pub, args=['B', c_b, 0.001]) t_a.start() t_b.start() t_a.join() t_b.join() time.sleep(0.01) import pprint pprint.pprint(data) assert data['A']['A'] == c_a assert data['B']['B'] == c_b assert data['C']['A'] + data['C']['B'] + \ data['D']['A'] + data['D']['B'] == 2 * (c_a + c_b) return data
# ------------------------------------------------------------------------------