diff --git a/_doc/conf.py b/_doc/conf.py
index 80d9f1ae..e46703f1 100644
--- a/_doc/conf.py
+++ b/_doc/conf.py
@@ -78,7 +78,7 @@
 }
 
 intersphinx_mapping = {
-    "experimental_experiment": (
+    "_".join(["experimental", "experiment"]): (
         "https://sdpython.github.io/doc/experimental-experiment/dev/",
         None,
     ),
diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
index 8b92a760..4fbcb974 100644
--- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
+++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
@@ -66,7 +66,8 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
 
         cache = MambaCache(_config(), max_batch_size=1, device="cpu")
         self.assertEqual(
-            string_type(cache), "MambaCache(conv_states=[T10r3,...], ssm_states=[T10r3,...])"
+            string_type(cache),
+            "MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])",
         )
         x = torch.ones(2, 8, 16).to(torch.float16)
         model = Model()
diff --git a/_unittests/ut_xrun_doc/test_args.py b/_unittests/ut_xrun_doc/test_args.py
new file mode 100644
index 00000000..62b399f7
--- /dev/null
+++ b/_unittests/ut_xrun_doc/test_args.py
@@ -0,0 +1,30 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.args import get_parsed_args
+
+
+class TestHelpers(ExtTestCase):
+    def test_args(self):
+        try:
+            args = get_parsed_args(
+                "plot_custom_backend_llama",
+                config=(
+                    "medium",
+                    "large or medium depending, large means closer to the real model",
+                ),
+                num_hidden_layers=(1, "number of hidden layers"),
+                with_mask=(0, "tries with a mask as a secondary input"),
+                optim=("", "Optimization to apply, empty string for all"),
+                description="doc",
+                new_args=["--config", "m"],
+            )
+        except SystemExit as e:
+            raise AssertionError(f"SystemExit caught: {e}")
+        self.assertEqual(args.config, "m")
+        self.assertEqual(args.num_hidden_layers, 1)
+        self.assertEqual(args.with_mask, 0)
+        self.assertEqual(args.optim, "")
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_helpers.py b/_unittests/ut_xrun_doc/test_helpers.py
index 94f883ef..09c04244 100644
--- a/_unittests/ut_xrun_doc/test_helpers.py
+++ b/_unittests/ut_xrun_doc/test_helpers.py
@@ -26,6 +26,8 @@
     from_array_ml_dtypes,
     dtype_to_tensor_dtype,
     string_diff,
+    rename_dynamic_dimensions,
+    rename_dynamic_expression,
 )
 
 TFLOAT = onnx.TensorProto.FLOAT
@@ -241,6 +243,89 @@ def test_string_signature(self):
     def test_make_hash(self):
         self.assertIsInstance(make_hash([]), str)
 
+    def test_string_type_one(self):
+        self.assertEqual(string_type(None), "None")
+        self.assertEqual(string_type([4]), "#1[int]")
+        self.assertEqual(string_type((4, 5)), "(int,int)")
+        self.assertEqual(string_type([4] * 100), "#100[int,...]")
+        self.assertEqual(string_type((4,) * 100), "#100(int,...)")
+
+    def test_string_type_at(self):
+        self.assertEqual(string_type(None), "None")
+        a = np.array([4, 5], dtype=np.float32)
+        t = torch.tensor([4, 5], dtype=torch.float32)
+        self.assertEqual(string_type([a]), "#1[A1r1]")
+        self.assertEqual(string_type([t]), "#1[T1r1]")
+        self.assertEqual(string_type((a,)), "(A1r1,)")
+        self.assertEqual(string_type((t,)), "(T1r1,)")
+        self.assertEqual(string_type([a] * 100), "#100[A1r1,...]")
+        self.assertEqual(string_type([t] * 100), "#100[T1r1,...]")
+        self.assertEqual(string_type((a,) * 100), "#100(A1r1,...)")
+        self.assertEqual(string_type((t,) * 100), "#100(T1r1,...)")
+
+    def test_string_type_at_with_shape(self):
+        self.assertEqual(string_type(None), "None")
+        a = np.array([4, 5], dtype=np.float32)
+        t = torch.tensor([4, 5], dtype=torch.float32)
+        self.assertEqual(string_type([a], with_shape=True), "#1[A1s2]")
+        self.assertEqual(string_type([t], with_shape=True), "#1[T1s2]")
+        self.assertEqual(string_type((a,), with_shape=True), "(A1s2,)")
+        self.assertEqual(string_type((t,), with_shape=True), "(T1s2,)")
+        self.assertEqual(string_type([a] * 100, with_shape=True), "#100[A1s2,...]")
+        self.assertEqual(string_type([t] * 100, with_shape=True), "#100[T1s2,...]")
+        self.assertEqual(string_type((a,) * 100, with_shape=True), "#100(A1s2,...)")
+        self.assertEqual(string_type((t,) * 100, with_shape=True), "#100(T1s2,...)")
+
+    def test_string_type_at_with_shape_min_max(self):
+        self.assertEqual(string_type(None), "None")
+        a = np.array([4, 5], dtype=np.float32)
+        t = torch.tensor([4, 5], dtype=torch.float32)
+        self.assertEqual(
+            string_type([a], with_shape=True, with_min_max=True), "#1[A1s2[4.0,5.0:A4.5]]"
+        )
+        self.assertEqual(
+            string_type([t], with_shape=True, with_min_max=True), "#1[T1s2[4.0,5.0:A4.5]]"
+        )
+        self.assertEqual(
+            string_type((a,), with_shape=True, with_min_max=True), "(A1s2[4.0,5.0:A4.5],)"
+        )
+        self.assertEqual(
+            string_type((t,), with_shape=True, with_min_max=True), "(T1s2[4.0,5.0:A4.5],)"
+        )
+        self.assertEqual(
+            string_type([a] * 100, with_shape=True, with_min_max=True),
+            "#100[A1s2[4.0,5.0:A4.5],...]",
+        )
+        self.assertEqual(
+            string_type([t] * 100, with_shape=True, with_min_max=True),
+            "#100[T1s2[4.0,5.0:A4.5],...]",
+        )
+        self.assertEqual(
+            string_type((a,) * 100, with_shape=True, with_min_max=True),
+            "#100(A1s2[4.0,5.0:A4.5],...)",
+        )
+        self.assertEqual(
+            string_type((t,) * 100, with_shape=True, with_min_max=True),
+            "#100(T1s2[4.0,5.0:A4.5],...)",
+        )
+
+    def test_pretty_onnx_att(self):
+        node = oh.make_node("Cast", ["xm2c"], ["xm2"], to=1)
+        pretty_onnx(node.attribute[0])
+
+    def test_rename_dimension(self):
+        res = rename_dynamic_dimensions(
+            {"a": {"B", "C"}},
+            {
+                "B",
+            },
+        )
+        self.assertEqual(res, {"B": "B", "a": "B", "C": "B"})
+
+    def test_rename_dynamic_expression(self):
+        text = rename_dynamic_expression("a * 10 - a", {"a": "x"})
+        self.assertEqual(text, "x * 10 - x")
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_onnx_tools.py b/_unittests/ut_xrun_doc/test_onnx_tools.py
index 5ce8bb2d..608d238e 100644
--- a/_unittests/ut_xrun_doc/test_onnx_tools.py
+++ b/_unittests/ut_xrun_doc/test_onnx_tools.py
@@ -5,7 +5,12 @@
 from onnx import TensorProto
 from onnx.checker import check_model
 from onnx_diagnostic.ext_test_case import ExtTestCase
-from onnx_diagnostic.onnx_tools import onnx_lighten, onnx_unlighten, onnx_find
+from onnx_diagnostic.onnx_tools import (
+    onnx_lighten,
+    onnx_unlighten,
+    onnx_find,
+    _validate_function,
+)
 from onnx_diagnostic.torch_test_helper import check_model_ort
 
 TFLOAT = TensorProto.FLOAT
@@ -67,6 +72,23 @@ def test_onnx_find(self):
         self.assertIn("xm2", res[0].output)
         self.assertIn("xm2", res[1].input)
 
+    def test__validate_function(self):
+        new_domain = "custom"
+
+        linear_regression = oh.make_function(
+            new_domain,
+            "LinearRegression",
+            ["x", "a", "b"],
+            ["y"],
+            [
+                oh.make_node("MatMul", ["x", "a"], ["xa"]),
+                oh.make_node("Add", ["xa", "b"], ["y"]),
+            ],
+            [oh.make_opsetid("", 14)],
+            [],
+        )
+        _validate_function(linear_regression)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_torch_test_helper.py b/_unittests/ut_xrun_doc/test_torch_test_helper.py
index 8f7f4507..bbab1f92 100644
--- a/_unittests/ut_xrun_doc/test_torch_test_helper.py
+++ b/_unittests/ut_xrun_doc/test_torch_test_helper.py
@@ -1,10 +1,12 @@
 import unittest
 import numpy as np
+import ml_dtypes
 import onnx
 import onnx.helper as oh
 import onnx.numpy_helper as onh
+import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase
-from onnx_diagnostic.torch_test_helper import dummy_llm, check_model_ort
+from onnx_diagnostic.torch_test_helper import dummy_llm, check_model_ort, to_numpy
 
 
 TFLOAT = onnx.TensorProto.FLOAT
@@ -12,10 +14,19 @@
 
 class TestOrtSession(ExtTestCase):
     def test_dummy_llm(self):
-        for cls_name in ["AttentionBlock", "MultiAttentionBlock", "DecoderLayer"]:
+        for cls_name in ["AttentionBlock", "MultiAttentionBlock", "DecoderLayer", "LLM"]:
             model, inputs = dummy_llm(cls_name)
             model(*inputs)
 
+    def test_dummy_llm_ds(self):
+        for cls_name in ["AttentionBlock", "MultiAttentionBlock", "DecoderLayer", "LLM"]:
+            model, inputs, ds = dummy_llm(cls_name, dynamic_shapes=True)
+            model(*inputs)
+            self.assertIsInstance(ds, dict)
+
+    def test_dummy_llm_exc(self):
+        self.assertRaise(lambda: dummy_llm("LLLLLL"), NotImplementedError)
+
     def test_check_model_ort(self):
         model = oh.make_model(
             oh.make_graph(
@@ -47,6 +58,11 @@ def test_check_model_ort(self):
         )
         check_model_ort(model)
 
+    def test_to_numpy(self):
+        t = torch.tensor([0, 1], dtype=torch.bfloat16)
+        a = to_numpy(t)
+        self.assertEqual(a.dtype, ml_dtypes.bfloat16)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_unit_test.py b/_unittests/ut_xrun_doc/test_unit_test.py
index c6281ad4..5933429f 100644
--- a/_unittests/ut_xrun_doc/test_unit_test.py
+++ b/_unittests/ut_xrun_doc/test_unit_test.py
@@ -1,3 +1,4 @@
+import math
 import os
 import unittest
 import pandas
@@ -6,6 +7,12 @@
     ExtTestCase,
     statistics_on_file,
     statistics_on_folder,
+    is_apple,
+    is_windows,
+    is_azure,
+    is_linux,
+    unit_test_going,
+    measure_time,
 )
 
 
@@ -52,6 +59,31 @@ def test_statistics_on_folders(self):
         self.assertEqual(len(gr.columns), 4)
         self.assertEqual(total.shape, (2,))
 
+    def test_is(self):
+        is_apple()
+        is_windows()
+        is_azure()
+        is_linux()
+        unit_test_going()
+
+    def test_measure_time(self):
+        res = measure_time(lambda: math.cos(0.5))
+        self.assertIsInstance(res, dict)
+        self.assertEqual(
+            set(res),
+            {
+                "min_exec",
+                "max_exec",
+                "average",
+                "warmup_time",
+                "context_size",
+                "deviation",
+                "repeat",
+                "ttime",
+                "number",
+            },
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py
index ffa37269..3868b5fc 100644
--- a/onnx_diagnostic/ext_test_case.py
+++ b/onnx_diagnostic/ext_test_case.py
@@ -14,7 +14,7 @@
 from contextlib import redirect_stderr, redirect_stdout
 from io import StringIO
 from timeit import Timer
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy
 from numpy.testing import assert_allclose
@@ -38,28 +38,6 @@ def is_linux() -> bool:
     return sys.platform == "linux"
 
 
-def skipif_transformers(version_to_skip: Union[str, Set[str]], msg: str) -> Callable:
-    """Skips a unit test if transformers has a specific version."""
-    if isinstance(version_to_skip, str):
-        version_to_skip = {version_to_skip}
-    import transformers
-
-    if transformers.__version__ in version_to_skip:
-        msg = f"Unstable test. {msg}"
-        return unittest.skip(msg)
-    return lambda x: x
-
-
-def skipif_not_onnxrt(msg) -> Callable:
-    """Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`."""
-    UNITTEST_ONNXRT = os.environ.get("UNITTEST_ONNXRT", "0")
-    value = int(UNITTEST_ONNXRT)
-    if not value:
-        msg = f"Set UNITTEST_ONNXRT=1 to run the unittest. {msg}"
-        return unittest.skip(msg)
-    return lambda x: x
-
-
 def skipif_ci_windows(msg) -> Callable:
     """Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`."""
     if is_windows() and is_azure():
diff --git a/onnx_diagnostic/helpers.py b/onnx_diagnostic/helpers.py
index 43927d96..3d4435df 100644
--- a/onnx_diagnostic/helpers.py
+++ b/onnx_diagnostic/helpers.py
@@ -66,13 +66,13 @@ def size_type(dtype: Any) -> int:
         return 4
     if dtype == np.float16 or dtype == np.int16:
         return 2
-    if dtype == np.int16 or dtype == np.uint16:
+    if dtype == np.int16:
         return 2
-    if dtype == np.int32 or dtype == np.uint32:
+    if dtype == np.int32:
         return 4
-    if dtype == np.int64 or dtype == np.uint64:
+    if dtype == np.int64:
         return 8
-    if dtype == np.int8 or dtype == np.uint8:
+    if dtype == np.int8:
         return 1
     if hasattr(np, "uint64"):
         # it fails on mac
@@ -82,6 +82,8 @@ def size_type(dtype: Any) -> int:
         return 8
     if dtype == np.uint32:
         return 4
     if dtype == np.uint16:
         return 2
+    if dtype == np.uint8:
+        return 1
     import torch
@@ -225,7 +227,7 @@ def string_type(
         if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
             mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
             return f"({tt},...)#{len(obj)}[{mini},{maxi}:A[{avg}]]"
-        return f"({tt},...)#{len(obj)}" if with_shape else f"({tt},...)"
+        return f"#{len(obj)}({tt},...)"
     if isinstance(obj, list):
         if len(obj) < limit:
             js = ",".join(
@@ -250,8 +252,8 @@ def string_type(
         )
         if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
             mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
-            return f"[{tt},...]#{len(obj)}[{mini},{maxi}:{avg}]"
-        return f"[{tt},...]#{len(obj)}" if with_shape else f"[{tt},...]"
+            return f"#{len(obj)}[{tt},...][{mini},{maxi}:{avg}]"
+        return f"#{len(obj)}[{tt},...]"
     if isinstance(obj, set):
         if len(obj) < 10:
             js = ",".join(
@@ -932,7 +934,7 @@ def rename_dynamic_dimensions(
     many names for dynamic dimensions. When building the onnx model, some of them
     are redundant and can be replaced by the name provided by the user.
 
-    :param constraints: exhaustive list of used name and all the values equal to it
+    :param constraints: exhaustive list of used names and all the values equal to it
     :param original: the names to use if possible
     :param ban_prefix: avoid any rewriting by a constant starting with this prefix
     :return: replacement dictionary