Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion thunder/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from numbers import Number
from collections.abc import Callable

from thunder.core.langctx import langctx, Languages
from thunder.core.langctxs import langctx, Languages
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch!

from thunder.numpy.langctx import register_method

from thunder.core.proxies import TensorProxy
Expand Down
24 changes: 24 additions & 0 deletions thunder/tests/test_numpy_langctx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from thunder.numpy import size as np_size
from thunder.core.langctxs import langctx, Languages, resolve_language
from thunder.core.proxies import TensorProxy
from thunder.core.trace import detached_trace
from thunder.core.devices import cpu
from thunder.core.dtypes import float32


def test_numpy_langctx_registration_and_len_size():
with detached_trace():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice trick here

t = TensorProxy(shape=(2, 3), device=cpu, dtype=float32)

with langctx(Languages.NUMPY):
assert len(t) == 2 # axis 0 length
assert t.size() == 6 # total elements
assert np_size(t) == 6
Comment on lines +14 to +16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you choose to test these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I needed to check that any numpy function is actually working correctly under the context manager. If you have any better ideas for testing, please, suggest one, I'm ready to fix



def test_numpy_langctx_resolve_language():
numpy_ctx_by_enum = resolve_language(Languages.NUMPY)
numpy_ctx_by_name = resolve_language("numpy")

assert numpy_ctx_by_enum is numpy_ctx_by_name
assert numpy_ctx_by_enum.name == "numpy"
Loading