diff --git a/msgpack_numpy.py b/msgpack_numpy.py index abb46ef..9712bb7 100644 --- a/msgpack_numpy.py +++ b/msgpack_numpy.py @@ -51,7 +51,7 @@ def ndarray_to_bytes(obj): def tostr(x): return x -def encode(obj, chain=None): +def encode(obj, chain=None, allow_pickle=True): """ Data encoder for serializing numpy data types. """ @@ -60,6 +60,8 @@ def encode(obj, chain=None): # If the dtype is structured, store the interface description; # otherwise, store the corresponding array protocol type string: if obj.dtype.kind in ('V', 'O'): + if obj.dtype.kind == 'O' and not allow_pickle: + raise ValueError("Can't pickle object arrays if allow_pickle is False") kind = bytes(obj.dtype.kind, 'ascii') descr = obj.dtype.descr else: @@ -81,7 +83,7 @@ def encode(obj, chain=None): else: return obj if chain is None else chain(obj) -def decode(obj, chain=None): +def decode(obj, chain=None, allow_pickle=True): """ Decoder for deserializing numpy data types. """ @@ -97,6 +99,8 @@ def decode(obj, chain=None): descr = [tuple(tostr(t) if type(t) is bytes else t for t in d) \ for d in obj[b'type']] elif b'kind' in obj and obj[b'kind'] == b'O': + if not allow_pickle: + raise ValueError("Can't unpickle object arrays if allow_pickle is False") return pickle.loads(obj[b'data']) else: descr = obj[b'type'] @@ -138,8 +142,9 @@ def __init__(self, default=None, encoding='utf-8', unicode_errors='strict', use_single_float=False, - autoreset=1): - default = functools.partial(encode, chain=default) + autoreset=1, + allow_pickle=True): + default = functools.partial(encode, chain=default, allow_pickle=allow_pickle) super(Packer, self).__init__(default=default, encoding=encoding, unicode_errors=unicode_errors, @@ -149,8 +154,8 @@ class Unpacker(_Unpacker): def __init__(self, file_like=None, read_size=0, use_list=None, object_hook=None, object_pairs_hook=None, list_hook=None, encoding='utf-8', - unicode_errors='strict', max_buffer_size=0): - object_hook = functools.partial(decode, chain=object_hook) + 
unicode_errors='strict', max_buffer_size=0, allow_pickle=True): + object_hook = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle) super(Unpacker, self).__init__(file_like=file_like, read_size=read_size, use_list=use_list, @@ -168,8 +173,9 @@ def __init__(self, default=None, use_single_float=False, autoreset=1, use_bin_type=True, - strict_types=False): - default = functools.partial(encode, chain=default) + strict_types=False, + allow_pickle=True): + default = functools.partial(encode, chain=default, allow_pickle=allow_pickle) super(Packer, self).__init__(default=default, unicode_errors=unicode_errors, use_single_float=use_single_float, @@ -183,8 +189,8 @@ def __init__(self, file_like=None, read_size=0, use_list=None, object_hook=None, object_pairs_hook=None, list_hook=None, unicode_errors='strict', max_buffer_size=0, - ext_hook=msgpack.ExtType): - object_hook = functools.partial(decode, chain=object_hook) + ext_hook=msgpack.ExtType, allow_pickle=True): + object_hook = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle) super(Unpacker, self).__init__(file_like=file_like, read_size=read_size, use_list=use_list, @@ -205,8 +211,9 @@ def __init__(self, use_bin_type=True, strict_types=False, datetime=False, - unicode_errors=None): - default = functools.partial(encode, chain=default) + unicode_errors=None, + allow_pickle=True): + default = functools.partial(encode, chain=default, allow_pickle=allow_pickle) super(Packer, self).__init__(default=default, use_single_float=use_single_float, autoreset=autoreset, @@ -233,8 +240,9 @@ def __init__(self, max_bin_len=-1, max_array_len=-1, max_map_len=-1, - max_ext_len=-1): - object_hook = functools.partial(decode, chain=object_hook) + max_ext_len=-1, + allow_pickle=True): + object_hook = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle) super(Unpacker, self).__init__(file_like=file_like, read_size=read_size, use_list=use_list, @@ -268,22 +276,22 @@ def packb(o, 
**kwargs): return Packer(**kwargs).pack(o) -def unpack(stream, **kwargs): +def unpack(stream, allow_pickle=True, **kwargs): """ Unpack a packed object from a stream. """ object_hook = kwargs.get('object_hook') - kwargs['object_hook'] = functools.partial(decode, chain=object_hook) + kwargs['object_hook'] = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle) return _unpack(stream, **kwargs) -def unpackb(packed, **kwargs): +def unpackb(packed, allow_pickle=True, **kwargs): """ Unpack a packed object. """ object_hook = kwargs.get('object_hook') - kwargs['object_hook'] = functools.partial(decode, chain=object_hook) + kwargs['object_hook'] = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle) return _unpackb(packed, **kwargs) load = unpack @@ -291,18 +299,30 @@ def unpackb(packed, **kwargs): dump = pack dumps = packb -def patch(): +def patch(allow_pickle=True): """ Monkey patch msgpack module to enable support for serializing numpy types. """ - - setattr(msgpack, 'Packer', Packer) - setattr(msgpack, 'Unpacker', Unpacker) - setattr(msgpack, 'load', unpack) - setattr(msgpack, 'loads', unpackb) - setattr(msgpack, 'dump', pack) - setattr(msgpack, 'dumps', packb) - setattr(msgpack, 'pack', pack) - setattr(msgpack, 'packb', packb) - setattr(msgpack, 'unpack', unpack) - setattr(msgpack, 'unpackb', unpackb) + class Packer_(Packer): + def __init__(self, *args, **kws): + super(Packer_, self).__init__(*args, **kws, allow_pickle=allow_pickle) + + class Unpacker_(Unpacker): + def __init__(self, *args, **kws): + super(Unpacker_, self).__init__(*args, **kws, allow_pickle=allow_pickle) + + pack_ = functools.partial(pack, allow_pickle=allow_pickle) + packb_ = functools.partial(packb, allow_pickle=allow_pickle) + unpack_ = functools.partial(unpack, allow_pickle=allow_pickle) + unpackb_ = functools.partial(unpackb, allow_pickle=allow_pickle) + + setattr(msgpack, 'Packer', Packer_) + setattr(msgpack, 'Unpacker', Unpacker_) + setattr(msgpack, 'load', 
unpack_) + setattr(msgpack, 'loads', unpackb_) + setattr(msgpack, 'dump', pack_) + setattr(msgpack, 'dumps', packb_) + setattr(msgpack, 'pack', pack_) + setattr(msgpack, 'packb', packb_) + setattr(msgpack, 'unpack', unpack_) + setattr(msgpack, 'unpackb', unpackb_) diff --git a/tests.py b/tests.py index 4bd4bfa..d9eaa1a 100644 --- a/tests.py +++ b/tests.py @@ -288,5 +288,13 @@ def test_numpy_nested_structured_array(self): assert_array_equal(x, x_rec) self.assertEqual(x.dtype, x_rec.dtype) +class test_numpy_msgpack_no_pickle(test_numpy_msgpack): + def setUp(self): + patch(allow_pickle=False) + + def test_numpy_array_object(self): + x = np.random.rand(5).astype(object) + self.assertRaises(ValueError, self.encode_decode, x) + if __name__ == '__main__': main()