test/infinicore/framework/datatypes.py: 12 additions, 0 deletions
@@ -10,10 +10,16 @@ def to_torch_dtype(infini_dtype):
return torch.float32
elif infini_dtype == infinicore.bfloat16:
return torch.bfloat16
elif infini_dtype == infinicore.int8:
return torch.int8
elif infini_dtype == infinicore.int16:
return torch.int16
elif infini_dtype == infinicore.int32:
return torch.int32
elif infini_dtype == infinicore.int64:
return torch.int64
elif infini_dtype == infinicore.uint8:
return torch.uint8
else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")

@@ -26,9 +32,15 @@ def to_infinicore_dtype(torch_dtype):
return infinicore.float16
elif torch_dtype == torch.bfloat16:
return infinicore.bfloat16
elif torch_dtype == torch.int8:
return infinicore.int8
elif torch_dtype == torch.int16:
return infinicore.int16
elif torch_dtype == torch.int32:
return infinicore.int32
elif torch_dtype == torch.int64:
return infinicore.int64
elif torch_dtype == torch.uint8:
return infinicore.uint8
else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
test/infinicore/framework/tensor.py: 36 additions, 4 deletions
@@ -1,7 +1,9 @@
import torch
import infinicore
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
from .utils import is_integer_dtype


class TensorInitializer:
@@ -38,6 +40,10 @@ def create_tensor(
torch_device_str = torch_device_map[device]
torch_dtype = to_torch_dtype(dtype)

# Handle integer types differently for random initialization
if mode == TensorInitializer.RANDOM and is_integer_dtype(dtype):
mode = TensorInitializer.RANDINT # Use randint for integer types

# Handle strided tensors - calculate required storage size
if strides is not None:
# Calculate the required storage size for strided tensor
@@ -61,9 +67,22 @@
storage_size, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.RANDINT:
# Choose a range suited to the integer width; torch.randint's high bound is exclusive
if is_integer_dtype(dtype):
if dtype == infinicore.uint8:
low, high = 0, 256
elif dtype == infinicore.int8:
low, high = -128, 128
elif dtype == infinicore.int16:
low, high = -32768, 32768
else: # int32, int64
low, high = -1000, 1000
else:
low, high = -1000, 1000

base_tensor = torch.randint(
- -2000000000,
- 2000000000,
low,
high,
(storage_size,),
dtype=torch_dtype,
device=torch_device_str,
@@ -92,9 +111,22 @@ def create_tensor(
elif mode == TensorInitializer.ONES:
tensor = torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.RANDINT:
# Choose a range suited to the integer width; torch.randint's high bound is exclusive
if is_integer_dtype(dtype):
if dtype == infinicore.uint8:
low, high = 0, 256
elif dtype == infinicore.int8:
low, high = -128, 128
elif dtype == infinicore.int16:
low, high = -32768, 32768
else: # int32, int64
low, high = -1000, 1000
else:
low, high = -1000, 1000

tensor = torch.randint(
- -2000000000,
- 2000000000,
low,
high,
shape,
dtype=torch_dtype,
device=torch_device_str,
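A note on the bounds chosen above: torch.randint treats high as exclusive, so (0, 256) and (-128, 128) cover the full uint8 and int8 ranges exactly, while wider types get a modest (-1000, 1000) window. A standalone sketch of the same selection logic; randint_bounds is a hypothetical helper, not part of the framework:

import torch
import infinicore

def randint_bounds(dtype):
    # Hypothetical mirror of the branch above: full range for narrow
    # integer types, a modest window for int32/int64.
    if dtype == infinicore.uint8:
        return 0, 256      # values 0..255 (high is exclusive)
    if dtype == infinicore.int8:
        return -128, 128   # values -128..127
    if dtype == infinicore.int16:
        return -32768, 32768
    return -1000, 1000

low, high = randint_bounds(infinicore.int8)
t = torch.randint(low, high, (4, 4), dtype=torch.int8)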
test/infinicore/framework/utils.py: 76 additions, 23 deletions
@@ -37,25 +37,52 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
print(f" {desc} time: {elapsed * 1000 :6f} ms")


-def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
def is_integer_dtype(dtype):
"""Check if dtype is integer type"""
return dtype in [
infinicore.int8,
infinicore.int16,
infinicore.int32,
infinicore.int64,
infinicore.uint8,
]


def is_float_dtype(dtype):
"""Check if dtype is floating point type"""
return dtype in [infinicore.float16, infinicore.float32, infinicore.bfloat16]


def debug(
actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True, dtype=None
):
"""
Compare two tensors and print any mismatches. Integer dtypes must match exactly; float dtypes are checked against atol/rtol.
"""
# Convert to float32 for bfloat16 comparison
if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)

-print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose, dtype)

-import numpy as np
-np.testing.assert_allclose(
-actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
-)
# Use appropriate comparison based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, require exact equality
import numpy as np

np.testing.assert_array_equal(actual.cpu(), desired.cpu())
else:
# For float types, use allclose
import numpy as np

np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)


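The practical difference between the two branches, shown directly with the numpy calls debug now dispatches to (tensor values are illustrative):

import numpy as np
import torch

a = torch.tensor([1, 2, 3], dtype=torch.int32)
b = torch.tensor([1, 2, 4], dtype=torch.int32)

# Integer path: one differing element fails, no matter how loose the tolerances.
# np.testing.assert_array_equal(a.cpu(), b.cpu())  # raises AssertionError

x = torch.tensor([1.0, 2.0, 3.0])
y = x * (1 + 1e-3)

# Float path: a small relative error passes under the default rtol=1e-2.
np.testing.assert_allclose(x.cpu(), y.cpu(), rtol=1e-2, atol=0)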
def print_discrepancy(
-actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True, dtype=None
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
@@ -69,13 +96,21 @@ def print_discrepancy(
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)

-# Calculate difference mask
-nan_mismatch = (
-actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
-)
-diff_mask = nan_mismatch | (
-torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
-)
# Calculate difference mask based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, exact equality required
diff_mask = actual != expected
else:
# For float types, use tolerance-based comparison
nan_mismatch = (
actual_isnan ^ expected_isnan
if equal_nan
else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)

diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected

@@ -107,8 +142,9 @@ def add_color(text, color_code):

print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
if not (dtype and is_integer_dtype(dtype)):
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
@@ -130,6 +166,10 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
"""
Get tolerance settings based on data type
"""
# For integer types, return zero tolerance (exact match required)
if is_integer_dtype(tensor_dtype):
return 0, 0

tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
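Integer dtypes now short-circuit before the map lookup, so per-dtype entries only matter for floats. A usage sketch; the tolerance_map shape follows the get call above, and the float return path is assumed to unpack the dict into (atol, rtol):

import infinicore

tolerance_map = {
    infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
    infinicore.bfloat16: {"atol": 1e-3, "rtol": 1e-2},
}

atol, rtol = get_tolerance(tolerance_map, infinicore.float16)  # (1e-4, 1e-2)
atol, rtol = get_tolerance(tolerance_map, infinicore.int32)    # (0, 0): exact match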
@@ -162,8 +202,6 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
Args:
infini_result: infinicore tensor result
torch_reference: PyTorch tensor reference (for shape and device)
-dtype: infinicore data type
-device_str: torch device string

Returns:
torch.Tensor: PyTorch tensor with infinicore data
@@ -179,7 +217,7 @@


def compare_results(
-infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False, dtype=None
):
"""
Generic comparison of an infinicore result against a PyTorch reference result
@@ -190,19 +228,29 @@
atol: absolute tolerance
rtol: relative tolerance
debug_mode: whether to enable debug output
dtype: infinicore data type for comparison logic

Returns:
bool: True if results match within tolerance
"""
# Convert infinicore result to PyTorch tensor for comparison
torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result)

# Choose comparison method based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, require exact equality
result = torch.equal(torch_result_from_infini, torch_result)
else:
# For float types, use tolerance-based comparison
result = torch.allclose(
torch_result_from_infini, torch_result, atol=atol, rtol=rtol
)

# Debug mode: detailed comparison
if debug_mode:
-debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol, dtype=dtype)

-# Check if results match within tolerance
-return torch.allclose(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
return result


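Why the exact-equality branch matters: with integer data, a loose tolerance can silently absorb real errors such as an off-by-one. A small illustration of the two checks compare_results now chooses between:

import torch

a = torch.tensor([10, 20, 30], dtype=torch.int64)
b = torch.tensor([10, 20, 31], dtype=torch.int64)

# Integer path: torch.equal catches the off-by-one.
assert not torch.equal(a, b)

# Float-style path: the same data squeaks past a loose relative tolerance.
assert torch.allclose(a.float(), b.float(), atol=0, rtol=0.05)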
def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
@@ -227,7 +275,12 @@ def compare_test_results(infini_result, torch_result):
if config.debug and mode_name:
print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m")
return compare_results(
-infini_result, torch_result, atol=atol, rtol=rtol, debug_mode=config.debug
infini_result,
torch_result,
atol=atol,
rtol=rtol,
debug_mode=config.debug,
dtype=dtype,
)

return compare_test_results
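Finally, the factory only needs a config-like object exposing a debug flag; everything dtype-specific is threaded through automatically. A hedged sketch (SimpleNamespace stands in for the framework's real test config, whose other fields are not shown here):

from types import SimpleNamespace
import infinicore

config = SimpleNamespace(debug=False)  # stand-in for the real test config
compare = create_test_comparator(config, infinicore.int32, tolerance_map={})
# compare(infini_result, torch_result) -> True only on an exact int32 match.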