Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 53 additions & 61 deletions jupyter_client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations

import functools
import hashlib
import hmac
import json
Expand All @@ -33,6 +34,7 @@
from traitlets import (
Any,
Bool,
Callable,
CBytes,
CUnicode,
Dict,
Expand Down Expand Up @@ -125,15 +127,37 @@ def json_unpacker(s: str | bytes) -> t.Any:
return json.loads(s)


try:
import orjson # type:ignore[import-not-found]
except ModuleNotFoundError:
orjson = None
_default_packer_unpacker = "json", "json"
_default_pack_unpack = (json_packer, json_unpacker)
else:
orjson_packer = functools.partial(
orjson.dumps, default=json_default, option=orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z
)
orjson_unpacker = orjson.loads
_default_packer_unpacker = "orjson", "orjson"
_default_pack_unpack = (orjson_packer, orjson_unpacker)

try:
import msgpack # type:ignore[import-not-found]

except ModuleNotFoundError:
msgpack = None
else:
msgpack_packer = functools.partial(msgpack.packb, default=json_default)
msgpack_unpacker = msgpack.unpackb


def pickle_packer(o: t.Any) -> bytes:
"""Pack an object using the pickle module."""
return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)


pickle_unpacker = pickle.loads

default_packer = json_packer
default_unpacker = json_unpacker

DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
Expand Down Expand Up @@ -316,7 +340,7 @@ class Session(Configurable):

debug : bool
whether to trigger extra debugging statements
packer/unpacker : str : 'json', 'pickle' or import_string
packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string
importstrings for methods to serialize message parts. If just
'json' or 'pickle', predefined JSON and pickle packers will be used.
Otherwise, the entire importstring must be used.
Expand Down Expand Up @@ -351,48 +375,40 @@ class Session(Configurable):
""",
)

# serialization traits:
packer = DottedObjectName(
"json",
_default_packer_unpacker[0],
config=True,
help="""The name of the packer for serializing messages.
Should be one of 'json', 'pickle', or an import name
for a custom callable serializer.""",
)

@observe("packer")
def _packer_changed(self, change: t.Any) -> None:
new = change["new"]
if new.lower() == "json":
self.pack = json_packer
self.unpack = json_unpacker
self.unpacker = new
elif new.lower() == "pickle":
self.pack = pickle_packer
self.unpack = pickle_unpacker
self.unpacker = new
else:
self.pack = import_item(str(new))

unpacker = DottedObjectName(
"json",
_default_packer_unpacker[1],
config=True,
help="""The name of the unpacker for unserializing messages.
Only used with custom functions for `packer`.""",
)

@observe("unpacker")
def _unpacker_changed(self, change: t.Any) -> None:
new = change["new"]
if new.lower() == "json":
self.pack = json_packer
self.unpack = json_unpacker
self.packer = new
elif new.lower() == "pickle":
self.pack = pickle_packer
self.unpack = pickle_unpacker
self.packer = new
pack = Callable(_default_pack_unpack[0]) # the actual packer function
unpack = Callable(_default_pack_unpack[1]) # the actual unpacker function

@observe("packer", "unpacker")
def _packer_unpacker_changed(self, change: t.Any) -> None:
new = change["new"].lower()
if new == "orjson" and orjson:
self.pack, self.unpack = orjson_packer, orjson_unpacker
elif new == "json" or new == "orjson":
self.pack, self.unpack = json_packer, json_unpacker
elif new == "pickle":
self.pack, self.unpack = pickle_packer, pickle_unpacker
elif new == "msgpack" and msgpack:
self.pack, self.unpack = msgpack_packer, msgpack_unpacker
else:
self.unpack = import_item(str(new))
obj = import_item(str(change["new"]))
name = "pack" if change["name"] == "packer" else "unpack"
self.set_trait(name, obj)
return
self.packer = self.unpacker = change["new"]

session = CUnicode("", config=True, help="""The UUID identifying this session.""")

Expand All @@ -417,8 +433,7 @@ def _session_changed(self, change: t.Any) -> None:
metadata = Dict(
{},
config=True,
help="Metadata dictionary, which serves as the default top-level metadata dict for each "
"message.",
help="Metadata dictionary, which serves as the default top-level metadata dict for each message.",
)

# if 0, no adapting to do.
Expand Down Expand Up @@ -487,25 +502,6 @@ def _keyfile_changed(self, change: t.Any) -> None:
# for protecting against sends from forks
pid = Integer()

# serialization traits:

pack = Any(default_packer) # the actual packer function

@observe("pack")
def _pack_changed(self, change: t.Any) -> None:
new = change["new"]
if not callable(new):
raise TypeError("packer must be callable, not %s" % type(new))

unpack = Any(default_unpacker) # the actual packer function

@observe("unpack")
def _unpack_changed(self, change: t.Any) -> None:
# unpacker is not checked - it is assumed to be
new = change["new"]
if not callable(new):
raise TypeError("unpacker must be callable, not %s" % type(new))

# thresholds:
copy_threshold = Integer(
2**16,
Expand All @@ -515,8 +511,7 @@ def _unpack_changed(self, change: t.Any) -> None:
buffer_threshold = Integer(
MAX_BYTES,
config=True,
help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid "
"pickling.",
help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.",
)
item_threshold = Integer(
MAX_ITEMS,
Expand All @@ -534,7 +529,7 @@ def __init__(self, **kwargs: t.Any) -> None:

debug : bool
whether to trigger extra debugging statements
packer/unpacker : str : 'json', 'pickle' or import_string
packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string
importstrings for methods to serialize message parts. If just
'json' or 'pickle', predefined JSON and pickle packers will be used.
Otherwise, the entire importstring must be used.
Expand Down Expand Up @@ -626,10 +621,7 @@ def _check_packers(self) -> None:
unpacked = unpack(packed)
assert unpacked == msg_list
except Exception as e:
msg = (
f"unpacker '{self.unpacker}' could not handle output from packer"
f" '{self.packer}': {e}"
)
msg = f"unpacker {self.unpacker!r} could not handle output from packer {self.packer!r}: {e}"
raise ValueError(msg) from e

# check datetime support
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ classifiers = [
requires-python = ">=3.10"
dependencies = [
"jupyter_core>=5.1",
"orjson>=3.10.18; implementation_name == 'cpython'",
"python-dateutil>=2.8.2",
"pyzmq>=25.0",
"tornado>=6.4.1",
Expand Down Expand Up @@ -55,6 +56,7 @@ test = [
"pytest-jupyter[client]>=0.6.2",
"pytest-cov",
"pytest-timeout",
"msgpack"
]
docs = [
"ipykernel",
Expand Down
74 changes: 58 additions & 16 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import uuid
import warnings
from datetime import datetime
from pickle import PicklingError
from unittest import mock

import pytest
import zmq
from dateutil.tz import tzlocal
from tornado import ioloop
from traitlets import TraitError
from zmq.eventloop.zmqstream import ZMQStream

from jupyter_client import jsonutil
Expand All @@ -41,6 +43,16 @@ def session():
return ss.Session()


serializers = [
("json", ss.json_packer, ss.json_unpacker),
("pickle", ss.pickle_packer, ss.pickle_unpacker),
]
if ss.orjson:
serializers.append(("orjson", ss.orjson_packer, ss.orjson_unpacker))
if ss.msgpack:
serializers.append(("msgpack", ss.msgpack_packer, ss.msgpack_unpacker))


@pytest.mark.usefixtures("no_copy_threshold")
class TestSession:
def assertEqual(self, a, b):
Expand All @@ -64,7 +76,11 @@ def test_msg(self, session):
self.assertEqual(msg["header"]["msg_type"], "execute")
self.assertEqual(msg["msg_type"], "execute")

def test_serialize(self, session):
@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
def test_serialize(self, session, packer, pack, unpack):
session.packer = packer
assert session.pack is pack
assert session.unpack is unpack
msg = session.msg("execute", content=dict(a=10, b=1.1))
msg_list = session.serialize(msg, ident=b"foo")
ident, msg_list = session.feed_identities(msg_list)
Expand Down Expand Up @@ -234,16 +250,16 @@ async def test_send(self, session):
def test_args(self, session):
"""initialization arguments for Session"""
s = session
self.assertTrue(s.pack is ss.default_packer)
self.assertTrue(s.unpack is ss.default_unpacker)
assert s.pack is ss._default_pack_unpack[0]
assert s.unpack is ss._default_pack_unpack[1]
self.assertEqual(s.username, os.environ.get("USER", "username"))

s = ss.Session()
self.assertEqual(s.username, os.environ.get("USER", "username"))

with pytest.raises(TypeError):
with pytest.raises(TraitError):
ss.Session(pack="hi")
with pytest.raises(TypeError):
with pytest.raises(TraitError):
ss.Session(unpack="hi")
u = str(uuid.uuid4())
s = ss.Session(username="carrot", session=u)
Expand Down Expand Up @@ -491,11 +507,6 @@ async def test_send_raw(self, session):
B.close()
ctx.term()

def test_set_packer(self, session):
s = session
s.packer = "json"
s.unpacker = "json"

def test_clone(self, session):
s = session
s._add_digest("initial")
Expand All @@ -515,14 +526,45 @@ def test_squash_unicode():
assert ss.squash_unicode("hi") == b"hi"


def test_json_packer():
ss.json_packer(dict(a=1))
with pytest.raises(ValueError):
ss.json_packer(dict(a=ss.Session()))
ss.json_packer(dict(a=datetime(2021, 4, 1, 12, tzinfo=tzlocal())))
@pytest.mark.parametrize(
["description", "data"],
[
("dict", [{"a": 1}, [{"a": 1}]]),
("infinite", [math.inf, ["inf", None]]),
("datetime", [datetime(2021, 4, 1, 12, tzinfo=tzlocal()), []]),
],
)
@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
def test_serialize_objects(packer, pack, unpack, description, data):
data_in, data_out_options = data
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ss.json_packer(dict(a=math.inf))
value = pack(data_in)
unpacked = unpack(value)
if (description == "infinite") and (packer in ["pickle", "msgpack"]):
assert math.isinf(unpacked)
elif description == "datetime":
assert data_in == jsonutil.parse_date(unpacked)
else:
assert unpacked in data_out_options


@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
def test_cannot_serialize(session, packer, pack, unpack):
data = {"a": session}
with pytest.raises((TypeError, ValueError, PicklingError)):
pack(data)


@pytest.mark.parametrize("mode", ["packer", "unpacker"])
@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
def test_pack_unpack(session, packer, pack, unpack, mode):
s: ss.Session = session
s.set_trait(mode, packer)
assert s.pack is pack
assert s.unpack is unpack
mode_reverse = "unpacker" if mode == "packer" else "packer"
assert getattr(s, mode_reverse) == packer


def test_message_cls():
Expand Down