diff --git a/test/infinicore/framework/datatypes.py b/test/infinicore/framework/datatypes.py index 608c8e35d..b40b0fe8e 100644 --- a/test/infinicore/framework/datatypes.py +++ b/test/infinicore/framework/datatypes.py @@ -10,10 +10,16 @@ def to_torch_dtype(infini_dtype): return torch.float32 elif infini_dtype == infinicore.bfloat16: return torch.bfloat16 + elif infini_dtype == infinicore.int8: + return torch.int8 + elif infini_dtype == infinicore.int16: + return torch.int16 elif infini_dtype == infinicore.int32: return torch.int32 elif infini_dtype == infinicore.int64: return torch.int64 + elif infini_dtype == infinicore.uint8: + return torch.uint8 else: raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}") @@ -26,9 +32,15 @@ def to_infinicore_dtype(torch_dtype): return infinicore.float16 elif torch_dtype == torch.bfloat16: return infinicore.bfloat16 + elif torch_dtype == torch.int8: + return infinicore.int8 + elif torch_dtype == torch.int16: + return infinicore.int16 elif torch_dtype == torch.int32: return infinicore.int32 elif torch_dtype == torch.int64: return infinicore.int64 + elif torch_dtype == torch.uint8: + return infinicore.uint8 else: raise ValueError(f"Unsupported torch dtype: {torch_dtype}") diff --git a/test/infinicore/framework/tensor.py b/test/infinicore/framework/tensor.py index 6aa5ca7b4..9015600af 100644 --- a/test/infinicore/framework/tensor.py +++ b/test/infinicore/framework/tensor.py @@ -1,7 +1,9 @@ import torch +import infinicore from pathlib import Path from .datatypes import to_torch_dtype from .devices import torch_device_map +from .utils import is_integer_dtype class TensorInitializer: @@ -38,6 +40,10 @@ def create_tensor( torch_device_str = torch_device_map[device] torch_dtype = to_torch_dtype(dtype) + # Handle integer types differently for random initialization + if mode == TensorInitializer.RANDOM and is_integer_dtype(dtype): + mode = TensorInitializer.RANDINT # Use randint for integer types + # Handle strided tensors - calculate required storage size if strides is not None: # Calculate the required storage size for strided tensor @@ -61,9 +67,22 @@ def create_tensor( storage_size, dtype=torch_dtype, device=torch_device_str ) elif mode == TensorInitializer.RANDINT: + # For integer types, use appropriate range + if is_integer_dtype(dtype): + if dtype == infinicore.uint8: + low, high = 0, 256 + elif dtype == infinicore.int8: + low, high = -128, 128 + elif dtype == infinicore.int16: + low, high = -32768, 32768 + else: # int32, int64, uint32 + low, high = -1000, 1000 + else: + low, high = -1000, 1000 + base_tensor = torch.randint( - -2000000000, - 2000000000, + low, + high, (storage_size,), dtype=torch_dtype, device=torch_device_str, @@ -92,9 +111,22 @@ def create_tensor( elif mode == TensorInitializer.ONES: tensor = torch.ones(shape, dtype=torch_dtype, device=torch_device_str) elif mode == TensorInitializer.RANDINT: + # For integer types, use appropriate range + if is_integer_dtype(dtype): + if dtype == infinicore.uint8: + low, high = 0, 256 + elif dtype == infinicore.int8: + low, high = -128, 128 + elif dtype == infinicore.int16: + low, high = -32768, 32768 + else: # int32, int64, uint32 + low, high = -1000, 1000 + else: + low, high = -1000, 1000 + tensor = torch.randint( - -2000000000, - 2000000000, + low, + high, shape, dtype=torch_dtype, device=torch_device_str, diff --git a/test/infinicore/framework/utils.py b/test/infinicore/framework/utils.py index 7e6a138bb..f4c36263e 100644 --- a/test/infinicore/framework/utils.py +++ b/test/infinicore/framework/utils.py @@ -37,25 +37,52 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations): print(f" {desc} time: {elapsed * 1000 :6f} ms") -def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): +def is_integer_dtype(dtype): + """Check if dtype is integer type""" + return dtype in [ + infinicore.int8, + infinicore.int16, + infinicore.int32, + infinicore.int64, + infinicore.uint8, + ] + + +def is_float_dtype(dtype): + """Check if dtype is floating point type""" + return dtype in [infinicore.float16, infinicore.float32, infinicore.bfloat16] + + +def debug( + actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True, dtype=None +): """ Debug function to compare two tensors and print differences """ + # Convert to float32 for bfloat16 comparison if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16: actual = actual.to(torch.float32) desired = desired.to(torch.float32) - print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose) + print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose, dtype) - import numpy as np + # Use appropriate comparison based on dtype + if dtype and is_integer_dtype(dtype): + # For integer types, require exact equality + import numpy as np - np.testing.assert_allclose( - actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True - ) + np.testing.assert_array_equal(actual.cpu(), desired.cpu()) + else: + # For float types, use allclose + import numpy as np + + np.testing.assert_allclose( + actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True + ) def print_discrepancy( - actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True + actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True, dtype=None ): """Print detailed tensor differences""" if actual.shape != expected.shape: @@ -69,13 +96,21 @@ def print_discrepancy( actual_isnan = torch.isnan(actual) expected_isnan = torch.isnan(expected) - # Calculate difference mask - nan_mismatch = ( - actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan - ) - diff_mask = nan_mismatch | ( - torch.abs(actual - expected) > (atol + rtol * torch.abs(expected)) - ) + # Calculate difference mask based on dtype + if dtype and is_integer_dtype(dtype): + # For integer types, exact equality required + diff_mask = actual != expected + else: + # For float types, use tolerance-based comparison + nan_mismatch = ( + actual_isnan ^ expected_isnan + if equal_nan + else actual_isnan | expected_isnan + ) + diff_mask = nan_mismatch | ( + torch.abs(actual - expected) > (atol + rtol * torch.abs(expected)) + ) + diff_indices = torch.nonzero(diff_mask, as_tuple=False) delta = actual - expected @@ -107,8 +142,9 @@ def add_color(text, color_code): print(f" - Actual dtype: {actual.dtype}") print(f" - Desired dtype: {expected.dtype}") - print(f" - Atol: {atol}") - print(f" - Rtol: {rtol}") + if not (dtype and is_integer_dtype(dtype)): + print(f" - Atol: {atol}") + print(f" - Rtol: {rtol}") print( f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)" ) @@ -130,6 +166,10 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3 """ Get tolerance settings based on data type """ + # For integer types, return zero tolerance (exact match required) + if is_integer_dtype(tensor_dtype): + return 0, 0 + tolerance = tolerance_map.get( tensor_dtype, {"atol": default_atol, "rtol": default_rtol} ) @@ -162,8 +202,6 @@ def convert_infinicore_to_torch(infini_result, torch_reference): Args: infini_result: infinicore tensor result torch_reference: PyTorch tensor reference (for shape and device) - dtype: infinicore data type - device_str: torch device string Returns: torch.Tensor: PyTorch tensor with infinicore data @@ -179,7 +217,7 @@ def convert_infinicore_to_torch(infini_result, torch_reference): def compare_results( - infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False + infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False, dtype=None ): """ Generic function to compare infinicore result with PyTorch reference result @@ -190,6 +228,7 @@ def compare_results( atol: absolute tolerance rtol: relative tolerance debug_mode: whether to enable debug output + dtype: infinicore data type for comparison logic Returns: bool: True if results match within tolerance @@ -197,12 +236,21 @@ def compare_results( # Convert infinicore result to PyTorch tensor for comparison torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result) + # Choose comparison method based on dtype + if dtype and is_integer_dtype(dtype): + # For integer types, require exact equality + result = torch.equal(torch_result_from_infini, torch_result) + else: + # For float types, use tolerance-based comparison + result = torch.allclose( + torch_result_from_infini, torch_result, atol=atol, rtol=rtol + ) + # Debug mode: detailed comparison if debug_mode: - debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol) + debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol, dtype=dtype) - # Check if results match within tolerance - return torch.allclose(torch_result_from_infini, torch_result, atol=atol, rtol=rtol) + return result def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""): @@ -227,7 +275,12 @@ def compare_test_results(infini_result, torch_result): if config.debug and mode_name: print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m") return compare_results( - infini_result, torch_result, atol=atol, rtol=rtol, debug_mode=config.debug + infini_result, + torch_result, + atol=atol, + rtol=rtol, + debug_mode=config.debug, + dtype=dtype, ) return compare_test_results