diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..95f8a12 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,13 @@ +[report] + +source = + structparse +# deadserver + +exclude_lines = + pragma: no cover + def __repr__ + raise NotImplementedError + if 0: + +show_missing = True diff --git a/.gitignore b/.gitignore index 077a157..600d8ce 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.py[co] -__pycache__/ venv/ +config.py +.coverage diff --git a/README.md b/README.md index c64e5c2..9a413a0 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +**Note: This README is out of date. TODO.** + server: the DB + "manager" ========================== @@ -15,5 +17,30 @@ Setup - fish: `. venv/bin/activate.fish` - csh, tcsh: `source venv/bin/activate.csh` 3. Install dependencies if necessary: `pip install -r requirements.txt` +4. configure: `cp config.py{.example,}; $EDITOR config.py` +5. create DB tables: `psql -U -f dbinit.sql` +6. run with `./runserver.py`; + run the HTTP API server with `./runhttp.py` + +Running tests +------------- + +1. Edit configuration: `cp tests/config.py{.example,}; $EDITOR tests/config.py` + A real Postgres DB is used, you need to specify the connection string. +2. create DB tables: `psql -U -f dbinit.sql` +3. run with `py.test tests/` + or `py.test --cov gateserver/ --cov-report term-missing tests/` for coverage report + +Next to do: +----------- + +- split DB table to controller and door +- on request arrival check if client IP matches the one in DB for this ID +- HTTP: rewrite to use Werkzeug instead of CherryPy +- fix DB singleton (who wants a singleton?!) +- CI + +Style Guide & such +------------------ -The rest doesn't exist yet. +[PEP-8](https://www.python.org/dev/peps/pep-0008/), `import this`. Also: code and design reviews. diff --git a/config.py.example b/config.py.example new file mode 100644 index 0000000..884cf31 --- /dev/null +++ b/config.py.example @@ -0,0 +1,9 @@ +"""The server configuration.""" + +http_host = '0.0.0.0' +http_port = 5047 + +udp_host = '0.0.0.0' # Use the actual IP address for UDP! +udp_port = 5042 + +db_url = 'postgresql://user:password@localhost/deadlock' diff --git a/controller_client.py b/controller_client.py new file mode 100644 index 0000000..c042671 --- /dev/null +++ b/controller_client.py @@ -0,0 +1,42 @@ +"""Quick & dirty client (i.e. the controller end), used for manual testing of the server.""" + +import socket +import os +import sys + +import records + +import config +from deadserver.api import * +from deadserver.protocol import * + +api = API(config=config, db=records.Database(config.db_url)) + +def msg(buf): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.sendto(buf, (config.udp_host, config.udp_port)) + return sock.recv(1024) + +def send(id, msgtype, data): + nonce = os.urandom(18) + req = Request(msgtype.value, data) + req_packet = make_packet(id, nonce, req, get_key=api.get_key) + res_packet = msg(req_packet.pack()) + return parse_packet(Response, res_packet, get_key=api.get_key) + +if __name__ == '__main__': + mac, msgtype = sys.argv[1:] + try: + t = MsgType[msgtype.upper()] + except KeyError: + sys.exit('No such message type: '+msgtype) + + indata = sys.stdin.buffer.read() + + hdr, res = send(str2id(mac), t, indata) + + print(' * * * as {} sent request: {}'.format(mac, str(t))) + print(indata) + print(' * * * received response: {}'.format(str(t))) + print(str(res.status)) + print(res.data) diff --git a/dbinit.sql b/dbinit.sql new file mode 100644 index 0000000..71a12cc --- /dev/null +++ b/dbinit.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS controller ( + id macaddr PRIMARY KEY, + ip inet UNIQUE NOT NULL, + key bytea NOT NULL, + name text +); +CREATE TABLE IF NOT EXISTS log ( + time timestamp NOT NULL, + ctrl_id macaddr REFERENCES controller, + message text +); diff --git a/deadserver/__init__.py b/deadserver/__init__.py new file mode 100644 index 0000000..8ca01a4 --- /dev/null +++ b/deadserver/__init__.py @@ -0,0 +1 @@ +"""The Deadlock server -- communicates with controllers.""" diff --git a/deadserver/api.py b/deadserver/api.py new file mode 100644 index 0000000..296edef --- /dev/null +++ b/deadserver/api.py @@ -0,0 +1,43 @@ +"""The controller ↔ server API -- the business logic. + +This knows what should happen for a given request. See `controller_protocol` for +the message format details. +""" + +from . import handlers +from . import protocol + +class API: + def __init__(self, config, db): + self.config = config + self.db = db + + def handle_packet(self, in_buf): + try: + request_header, request = protocol.parse_packet(protocol.Request, in_buf, self.get_key) + handler = handlers.get_handler_for(request.msg_type) + status, response = handler(request_header.controller_id, request.data.val, api=self) + self.log_message(request_header.controller_id, request, status) + response_packet = protocol.make_response_packet_for(request_header, request.msg_type, + status, response, get_key=self.get_key) + return response_packet.pack() + except protocol.BadMessageError as e: + self.log_bad_message(in_buf, e) + + # TODO if protocol crypto and insides were better separated, this could just create a + # {de,en}cryption black box and thereby avoid telling the key to anyone else. + def get_key(self, id): + """Loads the key for this controller from the DB.""" + rows = self.db.query('SELECT key FROM controller WHERE id = :id', + id=protocol.id2str(id)).all() + protocol.check(len(rows) == 1, 'unknown controller ID') + return bytes(rows[0]['key']) + + def log_message(self, controller_id, request, status): + """TODO""" + # print(utils.bytes2mac(controller_id), mtype.name, indata, '->', status.name) + print(protocol.id2str(controller_id), request, '->', status.name) + + def log_bad_message(self, buf, e): + """TODO""" + raise e diff --git a/deadserver/handlers/__init__.py b/deadserver/handlers/__init__.py new file mode 100644 index 0000000..49aa513 --- /dev/null +++ b/deadserver/handlers/__init__.py @@ -0,0 +1,35 @@ +"""Collects request handlers for the various message types. + +How to write a request handler: + +Your module must define `function(controller_id, request_data) -> status, response_data`. This must +be registered as a handler for a request type using the `@handles(deadserver.protocol.MsgType)` +decorator. See below for note on importing. + +Hello world example: + +```python +from deadserver.protocol import MsgType, ResponseStatus + +@handles(deadserver.protocol.MsgType.HELLO) # actually, this doesn't exist, but if it did... +def handle_hello(controller_id, data): + return ResponseStatus.OK, b'Hello ' + controller_id + b'! You sent: ' + data +``` + +See `./open.py` for a real-world example. + +---------------------------------------------------------------------------------------------------- + +In order to be executed (and therefore registered), your handler module must be imported somewhere. +This file is a good place for that, as it is imported by `deadserver.api`. Unless you have a reason +to do this differently, add your handlers below. +""" + +### LIST OF ALL STANDARD HANDLER IMPORTS ########################################################### + +from . import echotest +from . import open + +#################################################################################################### + +from .defs import get_handler_for # for more convenient access diff --git a/deadserver/handlers/defs.py b/deadserver/handlers/defs.py new file mode 100644 index 0000000..4360224 --- /dev/null +++ b/deadserver/handlers/defs.py @@ -0,0 +1,13 @@ +"""Provides functions for defining request handlers, such as the `handles(msg_type)` decorator.""" + +_all_handlers = {} + +def handles(msg_type): + def decorator(fn): + _all_handlers[msg_type] = fn + fn.handles = msg_type + return fn + return decorator + +def get_handler_for(msg_type): + return _all_handlers[msg_type] diff --git a/deadserver/handlers/echotest.py b/deadserver/handlers/echotest.py new file mode 100644 index 0000000..3555229 --- /dev/null +++ b/deadserver/handlers/echotest.py @@ -0,0 +1,9 @@ +"""Handler for ECHOTEST requests.""" + +from ..protocol import MsgType, ResponseStatus + +from .defs import handles + +@handles(MsgType.ECHOTEST) +def handle_hello(controller_id, data, api): + return ResponseStatus.OK, data diff --git a/deadserver/handlers/open.py b/deadserver/handlers/open.py new file mode 100644 index 0000000..a97666c --- /dev/null +++ b/deadserver/handlers/open.py @@ -0,0 +1,17 @@ +"""Handler for OPEN requests.""" + +from .defs import handles +from . import utils + +from structparse import struct, types +from deadserver.protocol import MsgType, ResponseStatus + +CardId = types.PascalStr(12) + +OpenRequest = struct('OpenRequest', (CardId, 'card_id')) + +@handles(MsgType.OPEN) +@utils.unpack_indata_as(OpenRequest) +def handle(controller_id, data, api): + status = ResponseStatus.OK if data.card_id == CardId('hello') else ResponseStatus.ERR + return status, None diff --git a/deadserver/handlers/utils.py b/deadserver/handlers/utils.py new file mode 100644 index 0000000..0cf024f --- /dev/null +++ b/deadserver/handlers/utils.py @@ -0,0 +1,26 @@ +"""Provides helpers for conveniently defining request handlers.""" + +import functools + +from .. import protocol + +def unpack_indata_as(struct): + """Decorator to first unpack the request data buffer into the given struct.""" + def decorator(fn): + @functools.wraps(fn) + def decorated(controller_id, indata, api): + try: + unpacked = struct.unpack(indata) + except ValueError as e: + raise protocol.BadMessageError('parsing data failed') from e + return fn(controller_id, unpacked, api) + return decorated + return decorator + +def pack_outdata(fn): + """Decorator to pack the structured response data into bytes.""" + @functools.wraps(fn) + def decorated(controller_id, indata, api): + status, outdata = fn(controller_id, indata, api) + return status, outdata.pack() + return decorated diff --git a/deadserver/protocol.py b/deadserver/protocol.py new file mode 100644 index 0000000..3ded85e --- /dev/null +++ b/deadserver/protocol.py @@ -0,0 +1,104 @@ +"""The controller ↔ server protocol message structure. + +This knows the data format for the various structures in the protocol. See +`controller_api` for the behavior / business logic. +""" + +# TODO: consider [CBOR](http://cbor.io/). +# TODO: With the faster processor, we probably can afford assymetric crypto. Switch if possible. +# TODO: separate the 2 layers of the protocol +# TODO: ... and then blackboxes instead of secret keys + +from structparse import struct, types +import nacl.secret +import enum + + +class BadMessageError(Exception): pass + +def check(expression, errmsg): + if not expression: raise BadMessageError(errmsg) + + +PROTOCOL_VERSION = types.Bytes(2)([0,1]) + +class MsgType(types.Uint8, enum.Enum): + OPEN = 1 + ECHOTEST = 254 + +class ResponseStatus(types.Uint8, enum.Enum): + OK = 0x01 + ERR = 0x10 + TRY_AGAIN = 0x11 + +Request = struct('Request', + (MsgType, 'msg_type'), + (types.Tail, 'data' )) + +Response = struct('Response', + (MsgType, 'msg_type'), + (ResponseStatus, 'status' ), + (types.Tail, 'data' )) + +PacketHeader = struct('PacketHeader', + (types.Bytes(2), 'protocol_version'), + (types.Bytes(6), 'controller_id' ), + (types.Bytes(18), 'nonce' )) + +Packet = struct('Packet', + (PacketHeader, 'header'), + (types.Tail, 'payload')) + + +def id2str(id): + return ':'.join('{:02x}'.format(x) for x in id.val) + +def str2id(s): + return types.Bytes(6)(bytes.fromhex(s.replace(':', ''))) + + +def crypto_unwrap_payload(nonce, payload, key): + return nacl.secret.SecretBox(key).decrypt(payload, nonce) + +def crypto_wrap_payload(nonce, payload, key): + # Note: encrypt returns the ciphertext prepended by nonce. We don't want this, so strip it. + return nacl.secret.SecretBox(key).encrypt(payload, nonce)[nacl.secret.SecretBox.NONCE_SIZE:] + +def parse_packet(struct, buf, get_key): + """Parses the buffer into PacketHeader and `struct`, which must be Request or Response. + + Decrypts the payload with the key returned by the `get_key` function. + """ + assert struct in [Request, Response] + + try: + hdr, tail = PacketHeader.unpack_from(buf) + check(hdr.protocol_version == PROTOCOL_VERSION, 'Invalid protocol version') + payload_buf = crypto_unwrap_payload(hdr.controller_id.val + hdr.nonce.val, + tail, get_key(hdr.controller_id)) + payload = struct.unpack(payload_buf) + except ValueError as e: + raise BadMessageError('parsing packet failed') from e + + return hdr, payload + +def make_packet(controller_id, nonce, payload, get_key): + """Packs and encrypts the packet headers, request/response headers and data.""" + encrypted = crypto_wrap_payload(controller_id.val + nonce, payload.pack(), get_key(controller_id)) + return Packet(PacketHeader(protocol_version=PROTOCOL_VERSION, + controller_id=controller_id, + nonce=nonce), + encrypted) + +def make_response_packet_for(request_header, msg_type, status, response_data, get_key): + """Creates a response packet from the status and data for the given request. + + Packs status and data into a response, encrypting according to the key returned by `get_key`. + Requires `request_packet` to be valid. + """ + if not response_data: response_data = b'' + response = Response(msg_type=msg_type, + status=status, + data=response_data) + response_nonce = bytearray(request_header.nonce.val); response_nonce[-1] ^= 0x1 + return make_packet(request_header.controller_id, response_nonce, response, get_key) diff --git a/deadserver/server.py b/deadserver/server.py new file mode 100644 index 0000000..9caec51 --- /dev/null +++ b/deadserver/server.py @@ -0,0 +1,26 @@ +"""The UDP server that handles controller messages. + +This only knows how to receive requests and send responses -- it is just a thin wrapper around +`api.API` and it is not concerned with the protocol details. + +""" + +import functools +import socketserver + +import records + +from . import api + +def serve(config): + app = api.API(config=config, db=records.Database(config.db_url)) + + class MessageHandler(socketserver.BaseRequestHandler): + def handle(self): + """Handles a request from the controller.""" + in_packet, socket = self.request + out_packet = app.handle_packet(in_packet) + if out_packet: socket.sendto(out_packet, self.client_address) + + server = socketserver.ThreadingUDPServer((config.udp_host, config.udp_port), MessageHandler) + server.serve_forever() diff --git a/fun_stuff.txt b/fun_stuff.txt new file mode 100644 index 0000000..a8d7204 --- /dev/null +++ b/fun_stuff.txt @@ -0,0 +1,7 @@ +Fun problems that I've overcome +=============================== + +- stateless is awesome +- avoid running around with secret keys by passing just a crypto black box (TODO but do it :D), and actually storing / loading the keys encrypted and decrypting only when really needed in order to avoid e.g. accidentally logging it (TODO also do it :D) +- extensibility + runtime configuration for packet types +- extensibility for "batch jobs" (e.g. for local dbs creation) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..d3a5838 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +norecursedirs = venv +testpaths = . diff --git a/requirements-fresh.txt b/requirements-fresh.txt new file mode 100644 index 0000000..6207840 --- /dev/null +++ b/requirements-fresh.txt @@ -0,0 +1,5 @@ +psycopg2 +pynacl +pytest +pytest-cov +records diff --git a/requirements.txt b/requirements.txt index 24ed765..90c35ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,13 @@ -psycopg2==2.5.4 -https://github.com/warner/python-tweetnacl/tarball/b48a25a33f +PyNaCl==1.0.1 +SQLAlchemy==1.0.12 +cffi==1.5.2 +coverage==4.0.3 +docopt==0.6.2 +psycopg2==2.6.1 +py==1.4.31 +pycparser==2.14 +pytest==2.9.1 +pytest-cov==2.2.1 +records==0.4.3 +six==1.10.0 +tablib==0.11.2 diff --git a/runserver.py b/runserver.py new file mode 100755 index 0000000..6aa028f --- /dev/null +++ b/runserver.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +"""Deadlock server runner.""" + +from deadserver import server +import config + +if __name__ == '__main__': + server.serve(config) diff --git a/structparse/__init__.py b/structparse/__init__.py new file mode 100644 index 0000000..c53097d --- /dev/null +++ b/structparse/__init__.py @@ -0,0 +1,3 @@ +"""Simplifies parsing "C structs" -- structured types serialized as byte arrays.""" + +from .structdef import struct # make this easily accessible diff --git a/structparse/structdef.py b/structparse/structdef.py new file mode 100644 index 0000000..08a47b7 --- /dev/null +++ b/structparse/structdef.py @@ -0,0 +1,68 @@ +"""Allows defining serializable simple types and structs.""" + +from collections import namedtuple + +def unzip(x): return zip(*x) + +class Type: + """Defines a type that can be serialized and unserialized.""" + + @staticmethod + def unpack_from(buf): + """Constructs a Python value from the given buffer. + + Signature: buffer -> (parsed data, rest of the buffer) + """ + raise NotImplementedError + + def pack(self): + """Packs itself into `bytes`, returns that buffer. + + Signature: data -> buffer + """ + raise NotImplementedError + + @classmethod + def unpack(cls, buf): + """Unpacks the whole buffer into this Type. + + Raises ValueError if len(buf) doesn't exactly match the type size. + """ + x, rest = cls.unpack_from(buf) + if rest: raise ValueError('buffer size != struct size') + return x + + +class _StructMixin(Type): + """Mixin providing the `pack`, `unpack` and `unpack_from` methods for a struct.""" + def __new__(cls, *args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and args[0].__class__ is cls: + return args[0] # assumes immutability + if len(args) + len(kwargs) != len(cls._fields): + raise TypeError('Must be initialized with exactly {} arguments'.format(len(cls._fields))) + field_types = dict(zip(cls._fields, cls._fieldtypes)) + to_convert = dict(zip(cls._fields, args), **kwargs) + converted = { n: field_types[n](v) for (n, v) in to_convert.items() } + return super().__new__(cls, **converted) + + @classmethod + def unpack_from(cls, buf): + """Constructs a new instance by unpacking the given buffer.""" + vals = [] + for t in cls._fieldtypes: + val, buf = t.unpack_from(buf) + vals.append(val) + return cls(*vals), buf + + def pack(self): + """Returns itself packed as `bytes`.""" + return b''.join([ t.pack(x) for t,x in zip(self._fieldtypes, self) ]) + + +def struct(name, *fields): + """Creates a "C struct" -- a namedtuple that can be packed and unpacked.""" + fieldtypes, fieldnames = unzip(fields) + class Cls(_StructMixin, namedtuple(name, fieldnames)): pass + Cls.__name__ = name + Cls._fieldtypes = fieldtypes + return Cls diff --git a/structparse/tests/__init__.py b/structparse/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/structparse/tests/test_struct.py b/structparse/tests/test_struct.py new file mode 100644 index 0000000..e4ca7f6 --- /dev/null +++ b/structparse/tests/test_struct.py @@ -0,0 +1,54 @@ +import pytest + +from .. import struct, types + + +def test_new(): + TestStruct = struct('TestStruct', + (types.Uint8, 'john'), + (types.Uint8, 'paul'), + (types.Uint8, 'george'), + (types.Uint8, 'ringo')) + def check(instance): + assert (instance.john == types.Uint8(12) and + instance.paul == types.Uint8(4) and + instance.george == types.Uint8(6) and + instance.ringo == types.Uint8(0)) + check(TestStruct(12, 4, 6, 0)) + check(TestStruct(paul=4, george=6, john=12, ringo=0)) + check(TestStruct(12, 4, ringo=0, george=6)) + with pytest.raises(TypeError) as err: + TestStruct(1, 2, 47) + assert 'exactly 4' in str(err.value) + with pytest.raises(TypeError) as err: + TestStruct(12, 4, john=42, paul=47) + assert 'arguments' in str(err.value) + + +@pytest.fixture +def Sample(): + return struct('Sample', + (types.Bytes(4), 'x'), + (types.Uint8, 'y'), + (types.PascalStr(5), 'z')) + +@pytest.fixture +def Nested(Sample): + return struct('Nested', + (Sample, 'a'), + (types.Tail, 'b')) + +@pytest.fixture +def strct(Sample, Nested): + return Nested(Sample(b'quux', 0x47, 'foo'), b'kaleraby') + +@pytest.fixture +def packed(): + return b'quux' + b'\x47' + b'\x03foo\x00' + b'kaleraby' + + +def test_pack(strct, packed): + assert strct.pack() == packed + +def test_unpack(Nested, packed, strct): + assert Nested.unpack(packed) == strct diff --git a/structparse/tests/test_types.py b/structparse/tests/test_types.py new file mode 100644 index 0000000..c534a9a --- /dev/null +++ b/structparse/tests/test_types.py @@ -0,0 +1,76 @@ +import pytest + +from .. import types +import enum + +def test_construction(): + x = types.Uint8(47) + y = types.Uint8(x) + assert x == y + +def test_tobytes(): + assert types._tobytes([97, 98, 99]) == b'abc' + assert types._tobytes('mňau') == b'm\xc5\x88au' + assert len(types._tobytes('mňau')) == 5 + +def test_eq(): + assert types.Uint8(47) == types.Uint8(47) + assert types.Uint8(42) != types.Uint8(47) + +def test_Uint8(): + assert types.Uint8(97).pack() == b'a' + assert types.Uint8.unpack_from(b'abcd') == (types.Uint8(97), b'bcd') + with pytest.raises(ValueError): types.Uint8(4742) + +def test_Tail(): + assert types.Tail(b'an arbitrarily long whatever').pack() == b'an arbitrarily long whatever' + assert types.Tail([97, 98, 99, 100]) == types.Tail(b'abcd') + assert types.Tail.unpack(b'mrkva') == types.Tail(b'mrkva') + +def test_Bytes(): + b4 = types.Bytes(4) + assert repr(b4(b'abcd')) == "Bytes[4](b'abcd')" + + assert b4([97, 98, 99, 100]) == b4(b'abcd') + + with pytest.raises(ValueError): b4(b'abc') + with pytest.raises(ValueError): b4(b'abcde') + + assert b4(b'abcd').pack() == b'abcd' + assert b4.unpack_from(b'abcdefg') == (b4(b'abcd'), b'efg') + with pytest.raises(ValueError): b4.unpack_from(b'abc') + +def test_PascalStr(): + p6 = types.PascalStr(6) + assert repr(p6(b'abcd')) == "PascalStr[6](b'abcd')" + + p6('Hello') # n-1 bytes fit + with pytest.raises(ValueError): p6('Hello!') # n and more don't + + assert p6('hello').pack() == b'\x05hello' + assert p6('hell').pack() == b'\x04hell\x00' + + assert p6.unpack_from(b'\x03hel\x00\x00 world') == (p6(b'hel'), b' world') + with pytest.raises(ValueError): p6.unpack_from(b'\x03hello world') + with pytest.raises(ValueError): p6.unpack_from(b'\x47anything') + +def test_hashable(): + assert hash(types.Uint8(47)) == hash(types.Uint8(47)) + assert hash(types.Uint8(47)) != hash(types.Uint8(42)) + assert hash(types.Bytes(2)([42,47])) != hash(types.Tail([42,47])) + +def test_Enum_works(): + class T(types.Uint8, enum.Enum): + A = 1 + B = 2 + Z = 255 + + assert T(1) == T.A + assert T.A.pack() == b'\x01' + assert T.unpack(b'\xff') == T.Z + with pytest.raises(ValueError): T(47) + with pytest.raises(ValueError): T.unpack_from(b'\x47') + + with pytest.raises(ValueError): + class T(types.Uint8, enum.Enum): + X = 4742 # does not fit into 1 byte diff --git a/structparse/types.py b/structparse/types.py new file mode 100644 index 0000000..4bb546f --- /dev/null +++ b/structparse/types.py @@ -0,0 +1,123 @@ +"""Defines useful simple types for structparse.""" + +from .structdef import Type + +class _SimpleType(Type): + def __init__(self, x): + if x.__class__ is self.__class__: self.__val = x.val + else: + self._validate(x) + self.__val = x + + @property + def val(self): + return self.__val + + def pack(self): + return bytes(self._pack()) + + @classmethod + def unpack_from(cls, buf): + val, rest = cls._unpack_from(buf) + return cls(val), rest + + def _validate(self, input): + """No-op if the input is valid or raises exception if invalid.""" + pass + + def _pack(self): + return self.val + + def __eq__(self, other): + return self.val == other.val + + def __hash__(self): + return hash(self.__class__.__name__) ^ hash(self.val) + + def __repr__(self): + return self.__class__.__name__+'('+repr(self.val)+')' + + +def _tobytes(x): + if isinstance(x, _SimpleType): return bytes(x.val) + if isinstance(x, str): return bytes(x, 'utf8') + return bytes(x) + + +class Uint8(_SimpleType): + @staticmethod + def _unpack_from(buf): + return int(buf[0]), buf[1:] + + @staticmethod + def _validate(x): + if not 0 <= x <= 0xff: raise ValueError('{} is not a 1-byte unsigned int'.format(x)) + + def _pack(self): + return [self.val] + + +class _BytesLike(_SimpleType): + def __init__(self, arg): + super().__init__(_tobytes(arg)) + + +class Tail(_BytesLike): + @staticmethod + def _unpack_from(buf): + return buf, b'' + + +def Bytes(n): + class Cls(_BytesLike): + @staticmethod + def _validate(arg): + if len(arg) != n: + raise ValueError('{} is not {} bytes'.format(arg, n)) + return bytes(arg) + + @staticmethod + def _unpack_from(buf): + return buf[:n], buf[n:] + + Cls.__name__ = 'Bytes[{}]'.format(n) + return Cls + + +def PascalStr(n): + """A fixed-length "Pascal string" -- byte 0 is length, the rest is null-padded string content. + + Note that n is the number of bytes in the resulting binary representation, so at most n-1 bytes + fit inside. + """ + assert n >= 1, 'Cannot create PascalStr that can fit -1 bytes' + assert n-1 <= 0xff, 'Max length of PascalStr must fit into 1 byte' + class Cls(_BytesLike): + @staticmethod + def _validate(arg): + if len(arg) > n-1: + raise ValueError('string too long (n = {})'.format(n)) + + @staticmethod + def _unpack_from(buf): + buf, tail = buf[:n], buf[n:] + s, data = buf[0], buf[1:] + if s > n-1: raise ValueError('packed string too long (n-1 = {}, s = {})'.format(n-1, s)) + result, padding = data[:s], data[s:] + if not all([b == 0 for b in padding]): raise ValueError('packed string not null-padded') + return result, tail + + def _pack(self): + padding = n-1 - len(self.val) + return bytes([len(self.val)]) + self.val + b'\0'*padding + + Cls.__name__ = 'PascalStr[{}]'.format(n) + return Cls + + +# Note: if you are looking for Enum, this Just Works with Python's enum.Enum: +# +# class T(types.Uint8, enum.Enum): +# A = 1 +# B = 2 +# Z = 255