# Source code for piwheels.transport

# The piwheels project
#   Copyright (c) 2017 Ben Nuttall <https://github.com/bennuttall>
#   Copyright (c) 2017 Dave Jones <dave@waveform.org.uk>
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the copyright holder nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""
This module augments the classes provided by pyzmq (the 0MQ Python bindings)
to use CBOR encoding, and voluptuous for message validation. It also tweaks a
few minor things like using seconds for timeouts.

.. autoclass:: Context
    :members:

.. autoclass:: Socket
    :members:

.. autoclass:: Poller
    :members:
"""

import logging
import ipaddress as ip
import datetime as dt
from binascii import hexlify

import zmq
from voluptuous import Invalid

import cbor2
from .protocols import Protocol, NoData


# 0MQ socket types, aliased from zmq
PUSH = zmq.PUSH
PULL = zmq.PULL
REQ = zmq.REQ
REP = zmq.REP
PUB = zmq.PUB
SUB = zmq.SUB
ROUTER = zmq.ROUTER
DEALER = zmq.DEALER

# Flags for send/recv (NOBLOCK) and for polling (POLLIN, POLLOUT)
NOBLOCK = zmq.NOBLOCK
POLLIN = zmq.POLLIN
POLLOUT = zmq.POLLOUT

# Socket options used by Socket.subscribe and Socket.unsubscribe
SUBSCRIBE = zmq.SUBSCRIBE
UNSUBSCRIBE = zmq.UNSUBSCRIBE

# Exception classes, aliased from zmq
Error = zmq.ZMQError
Again = zmq.error.Again


def default_encoder(encoder, value):
    """
    Fallback encoder handed to :func:`cbor2.dumps` (as ``default``) for values
    it cannot serialize natively.

    :class:`datetime.timedelta` values are written as tag 2001 wrapping a
    ``(days, seconds, microseconds)`` tuple, and the ``NoData`` sentinel as
    tag 2002 wrapping ``None``. Any other value raises
    :exc:`cbor2.CBOREncodeError`.
    """
    if isinstance(value, dt.timedelta):
        parts = (value.days, value.seconds, value.microseconds)
        encoder.encode(cbor2.CBORTag(2001, parts))
        return
    if value is NoData:
        encoder.encode(cbor2.CBORTag(2002, None))
        return
    type_name = value.__class__.__name__
    raise cbor2.CBOREncodeError('cannot serialize type %s' % type_name)


def default_decoder(decoder, tag):
    """
    Tag hook handed to :func:`cbor2.loads` (as ``tag_hook``); the inverse of
    :func:`default_encoder`.

    Tag 2001 is converted back to a :class:`datetime.timedelta` from its
    ``(days, seconds, microseconds)`` value, and tag 2002 to the ``NoData``
    sentinel. Any other tag is returned unchanged.
    """
    number = tag.tag
    if number == 2002:
        return NoData
    if number != 2001:
        # Not one of our tags; hand it back untouched
        return tag
    days, seconds, microseconds = tag.value
    return dt.timedelta(
        days=days, seconds=seconds, microseconds=microseconds)


class Context:
    """
    Wrapper for 0MQ :class:`zmq.Context`. This extends the :meth:`socket`
    method to include parameters for the socket's protocol and logger.
    """
    def __init__(self):
        # Context.instance() hands back pyzmq's shared global context rather
        # than creating a new one for each wrapper
        self._context = zmq.Context.instance()

    def socket(self, sock_type, *, protocol=None, logger=None):
        """
        Create a new :class:`Socket` of the specified *sock_type* (e.g.
        ``PUSH`` or ``REQ``). The optional keyword-only *protocol* and
        *logger* arguments are passed through to the :class:`Socket`
        constructor.
        """
        return Socket(self._context.socket(sock_type), protocol, logger)

    def close(self, linger=1):
        """
        Destroys the underlying 0MQ context and then terminates it. *linger*
        is the number of seconds (which may be fractional) passed to
        :meth:`zmq.Context.destroy` as the linger period. If *linger* is
        ``None`` it is passed through unchanged.
        """
        # 0MQ's LINGER option is an integer number of milliseconds; a float
        # (e.g. from linger=0.5) is rejected, so convert explicitly
        if linger is not None:
            linger = round(linger * 1000)
        self._context.destroy(linger=linger)
        self._context.term()
class Socket:
    """
    Wrapper for :class:`zmq.Socket`. This extends 0MQ's sockets to include a
    protocol which will be used to validate messages that are sent and
    received (via a voluptuous schema), and a logger which can be used to
    debug socket behaviour.
    """
    def __init__(self, socket, protocol=None, logger=None):
        if logger is None:
            # No logger supplied; fall back to the root logger
            logger = logging.getLogger()
        if protocol is None:
            # No protocol supplied; use a default-constructed one
            protocol = Protocol()
        self._logger = logger
        self._socket = socket
        self._protocol = protocol
        self._socket.ipv6 = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()

    @staticmethod
    def _find_schema(schemas, msg):
        """
        Return the entry for *msg* in the *schemas* mapping, raising
        :exc:`IOError` if there is none.
        """
        try:
            return schemas[msg]
        except (KeyError, TypeError) as exc:
            # TypeError covers unhashable names (e.g. a list or dict decoded
            # from a malformed message) which cannot be looked up at all. The
            # 1-tuple keeps the formatting safe should msg itself be a tuple
            raise IOError('unknown message: %s' % (msg,)) from exc

    def _dump_msg(self, msg, data=NoData):
        """
        Validate *msg* and *data* against the "send" side of the socket's
        protocol and return the CBOR encoding as a :class:`bytes` string.

        A message without data is encoded as the bare *msg*; otherwise the
        pair ``(msg, data)`` is encoded, with *data* first passed through the
        message's schema. Raises :exc:`IOError` if *msg* is unknown, if the
        presence of *data* does not match the protocol, if *data* fails
        validation, or if the pair cannot be serialized.
        """
        schema = self._find_schema(self._protocol.send, msg)
        if data is NoData:
            if schema is not NoData:
                raise IOError('data must be specified for %s' % msg)
            return cbor2.dumps(msg, default=default_encoder)
        if schema is NoData:
            raise IOError('no data expected for %s' % msg)
        try:
            data = schema(data)
        except Invalid as exc:
            raise IOError('invalid data for %s: %r' % (msg, data)) from exc
        try:
            return cbor2.dumps((msg, data), default=default_encoder)
        except cbor2.CBOREncodeError as exc:
            raise IOError('unable to serialize data') from exc

    def _load_msg(self, buf):
        """
        Decode the CBOR-encoded *buf* and validate it against the "recv" side
        of the socket's protocol, returning a ``(msg, data)`` tuple.

        For messages that carry no data, the returned *data* is ``None``.
        Raises :exc:`IOError` if *buf* cannot be decoded, is not a bare
        string or a two-element sequence, names an unknown message, or
        carries data which does not match the protocol.
        """
        try:
            msg = cbor2.loads(buf, tag_hook=default_decoder)
        except cbor2.CBORDecodeError as exc:
            raise IOError('unable to deserialize data') from exc
        if isinstance(msg, str):
            # A bare string is a message without associated data
            schema = self._find_schema(self._protocol.recv, msg)
            if schema is NoData:
                return msg, None
            raise IOError('missing data for: %s' % msg)
        try:
            msg, data = msg
        except (TypeError, ValueError) as exc:
            raise IOError('invalid message structure received') from exc
        schema = self._find_schema(self._protocol.recv, msg)
        if schema is NoData:
            raise IOError('data not expected for: %s' % msg)
        try:
            return msg, schema(data)
        except Invalid as exc:
            raise IOError('invalid data for %s: %r' % (msg, data)) from exc

    @property
    def hwm(self):
        """
        The high-water mark of the socket, i.e. the number of messages that
        can be queued before the socket blocks (or drops, depending on the
        socket type) messages.
        """
        return self._socket.hwm

    @hwm.setter
    def hwm(self, value):
        self._socket.hwm = value

    def bind(self, address):
        """
        Binds the socket to listen on the specified *address*.
        """
        return self._socket.bind(address)

    def connect(self, address):
        """
        Connects the socket to the listening socket at *address*.
        """
        return self._socket.connect(address)

    def close(self, linger=None):
        """
        Closes the socket. If *linger* is specified, it is the number of
        seconds (which may be fractional) to wait for pending messages to be
        flushed.
        """
        # 0MQ's LINGER option is an integer number of milliseconds; a float
        # (e.g. from linger=0.5) is rejected, so convert explicitly
        if linger is not None:
            linger = round(linger * 1000)
        return self._socket.close(linger=linger)

    def subscribe(self, topic):
        """
        Subscribes SUB type sockets to the specified *topic* (a string
        prefix).
        """
        self._socket.setsockopt_string(SUBSCRIBE, topic)

    def unsubscribe(self, topic):
        """
        Unsubscribes SUB type sockets from the specified *topic* (a string
        prefix).
        """
        self._socket.setsockopt_string(UNSUBSCRIBE, topic)

    def poll(self, timeout=None, flags=POLLIN):
        """
        Polls the socket for pending data (by default, when *flags* is
        POLLIN), waiting up to *timeout* seconds (or indefinitely if
        *timeout* is ``None``). Returns the result of the underlying
        :meth:`zmq.Socket.poll`: a false value (0) if nothing happened before
        the timeout, otherwise the (non-zero) mask of events that occurred.

        If *flags* is POLLOUT instead, tests whether the socket has available
        slots for queueing new messages.
        """
        return self._socket.poll(
            timeout if timeout is None else timeout * 1000, flags)

    def send(self, buf, flags=0):
        """
        Send *buf* (a :class:`bytes` string).
        """
        self._logger.debug('>> %s', buf)
        return self._socket.send(buf, flags)

    def recv(self, flags=0):
        """
        Receives the next message as a :class:`bytes` string.
        """
        buf = self._socket.recv(flags)
        self._logger.debug('<< %s', buf)
        return buf

    def drain(self):
        """
        Receives all pending messages in the queue and discards them. This
        is typically useful during shutdown routines or for testing.
        """
        while self.poll(0):
            self.recv()

    def send_multipart(self, msg_parts, flags=0):
        """
        Send *msg_parts*, a list of :class:`bytes` strings as a multi-part
        message which can be received intact with :meth:`recv_multipart`.
        """
        # One ' %s' placeholder per part so each is logged separately
        self._logger.debug('>>' + (' %s' * len(msg_parts)), *msg_parts)
        return self._socket.send_multipart(msg_parts, flags)

    def recv_multipart(self, flags=0):
        """
        Receives a multi-part message, returning its content as a list of
        :class:`bytes` strings.
        """
        msg_parts = self._socket.recv_multipart(flags)
        self._logger.debug('<<' + (' %s' * len(msg_parts)), *msg_parts)
        return msg_parts

    def send_msg(self, msg, data=NoData, flags=0):
        """
        Send the unicode string *msg* with its associated *data* as a
        CBOR-encoded message. This is the primary method used in piwheels
        for sending information between tasks.

        The message, and its associated data, must validate against the
        protocol associated with the socket on construction; :exc:`IOError`
        is raised if it does not.
        """
        self._logger.debug('>> %s %r', msg, data)
        return self._socket.send(self._dump_msg(msg, data), flags)

    def recv_msg(self, flags=0):
        """
        Receive a CBOR-encoded message, returning a tuple of the unicode
        message string and its associated data. This is the primary method
        used in piwheels for receiving information into a task.

        The message, and its associated data, will be validated against the
        protocol associated with the socket on construction; :exc:`IOError`
        is raised if it does not validate.
        """
        msg, data = self._load_msg(self._socket.recv(flags))
        self._logger.debug('<< %s %r', msg, data)
        return msg, data

    def send_addr_msg(self, addr, msg, data=NoData, flags=0):
        """
        Send a CBOR-encoded message (and associated data) to *addr*, a
        :class:`bytes` string.
        """
        self._logger.debug('>> %s %s %r',
                           hexlify(addr).decode('ascii'), msg, data)
        # Frames are: destination address, empty delimiter, encoded message
        self._socket.send_multipart(
            [addr, b'', self._dump_msg(msg, data)], flags)

    def recv_addr_msg(self, flags=0):
        """
        Receive a CBOR-encoded message (and associated data) along with the
        address it came from (represented as a :class:`bytes` string).
        Raises :exc:`IOError` if the message does not consist of exactly
        three frames, or if its content fails to decode or validate.
        """
        try:
            addr, empty, buf = self._socket.recv_multipart(flags)
        except ValueError as exc:
            raise IOError('invalid message structure received') from exc
        msg, data = self._load_msg(buf)
        self._logger.debug('<< %s %s %r',
                           hexlify(addr).decode('ascii'), msg, data)
        return addr, msg, data
class Poller:
    """
    Wrapper for 0MQ :class:`zmq.Poller`. This simply tweaks 0MQ's poller to
    use seconds for timeouts, and to return a :class:`dict` by default from
    :meth:`poll`.
    """
    def __init__(self):
        self._poller = zmq.Poller()
        # Maps each registered raw 0MQ socket back to its Socket wrapper so
        # that poll() can report results in terms of the wrappers
        self._map = {}

    def register(self, sock, flags=POLLIN | POLLOUT):
        """
        Register *sock* with the poller, watching for events as specified by
        *flags* (which defaults to POLLIN and POLLOUT events).
        """
        if not isinstance(sock, Socket):
            # Anything that is not one of our wrappers is handed to the
            # underlying poller as-is
            return self._poller.register(sock, flags)
        raw = sock._socket
        self._map[raw] = sock
        return self._poller.register(raw, flags)

    def unregister(self, sock):
        """
        Unregister *sock* from the poller. After this, calls to :meth:`poll`
        will never return references to *sock*.
        """
        if not isinstance(sock, Socket):
            self._poller.unregister(sock)
            return
        raw = sock._socket
        self._poller.unregister(raw)
        del self._map[raw]

    def poll(self, timeout=None):
        """
        Poll all registered sockets for the events they were registered
        with, for *timeout* seconds. Returns a dictionary mapping sockets to
        events, or an empty dictionary if the *timeout* elapsed with no
        events occurring.
        """
        # The underlying poller works in milliseconds; None is passed through
        millis = None if timeout is None else timeout * 1000
        events = {}
        for raw, event in self._poller.poll(millis):
            # Report the Socket wrapper where one was registered, otherwise
            # whatever object the underlying poller returned
            events[self._map.get(raw, raw)] = event
        return events