| 
15 | 15 | import torch  | 
16 | 16 | from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \  | 
17 | 17 |     skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \  | 
18 |  | -    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, \  | 
 | 18 | +    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, TEST_HPU, \  | 
19 | 19 |     _TestParametrizer, compose_parametrize_fns, dtype_name, \  | 
20 | 20 |     TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \  | 
21 | 21 |     get_tracked_input, clear_tracked_input, PRINT_REPRO_ON_FAILURE, \  | 
@@ -590,6 +590,18 @@ def setUpClass(cls):  | 
590 | 590 |     def _should_stop_test_suite(self):  | 
591 | 591 |         return False  | 
592 | 592 | 
 
  | 
 | 593 | +class HPUTestBase(DeviceTypeTestBase):  | 
 | 594 | +    device_type = 'hpu'  | 
 | 595 | +    primary_device: ClassVar[str]  | 
 | 596 | + | 
 | 597 | +    @classmethod  | 
 | 598 | +    def get_primary_device(cls):  | 
 | 599 | +        return cls.primary_device  | 
 | 600 | + | 
 | 601 | +    @classmethod  | 
 | 602 | +    def setUpClass(cls):  | 
 | 603 | +        cls.primary_device = 'hpu:0'  | 
 | 604 | + | 
593 | 605 | class PrivateUse1TestBase(DeviceTypeTestBase):  | 
594 | 606 |     primary_device: ClassVar[str]  | 
595 | 607 |     device_mod = None  | 
@@ -701,6 +713,8 @@ def get_desired_device_type_test_bases(except_for=None, only_for=None, include_l  | 
701 | 713 |         test_bases.append(MPSTestBase)  | 
702 | 714 |     if only_for == 'xpu' and TEST_XPU and XPUTestBase not in test_bases:  | 
703 | 715 |         test_bases.append(XPUTestBase)  | 
 | 716 | +    if TEST_HPU and HPUTestBase not in test_bases:  | 
 | 717 | +        test_bases.append(HPUTestBase)  | 
704 | 718 |     # Filter out the device types based on user inputs  | 
705 | 719 |     desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for)  | 
706 | 720 |     if include_lazy:  | 
@@ -1060,6 +1074,10 @@ class skipMPSIf(skipIf):  | 
1060 | 1074 |     def __init__(self, dep, reason):  | 
1061 | 1075 |         super().__init__(dep, reason, device_type='mps')  | 
1062 | 1076 | 
 
  | 
 | 1077 | +class skipHPUIf(skipIf):  | 
 | 1078 | +    def __init__(self, dep, reason):  | 
 | 1079 | +        super().__init__(dep, reason, device_type='hpu')  | 
 | 1080 | + | 
1063 | 1081 | # Skips a test on XLA if the condition is true.  | 
1064 | 1082 | class skipXLAIf(skipIf):  | 
1065 | 1083 | 
 
  | 
@@ -1343,6 +1361,9 @@ def onlyMPS(fn):  | 
1343 | 1361 | def onlyXPU(fn):  | 
1344 | 1362 |     return onlyOn('xpu')(fn)  | 
1345 | 1363 | 
 
  | 
 | 1364 | +def onlyHPU(fn):  | 
 | 1365 | +    return onlyOn('hpu')(fn)  | 
 | 1366 | + | 
1346 | 1367 | def onlyPRIVATEUSE1(fn):  | 
1347 | 1368 |     device_type = torch._C._get_privateuse1_backend_name()  | 
1348 | 1369 |     device_mod = getattr(torch, device_type, None)  | 
@@ -1401,6 +1422,9 @@ def expectedFailureMeta(fn):  | 
1401 | 1422 | def expectedFailureXLA(fn):  | 
1402 | 1423 |     return expectedFailure('xla')(fn)  | 
1403 | 1424 | 
 
  | 
 | 1425 | +def expectedFailureHPU(fn):  | 
 | 1426 | +    return expectedFailure('hpu')(fn)  | 
 | 1427 | + | 
1404 | 1428 | # Skips a test on CPU if LAPACK is not available.  | 
1405 | 1429 | def skipCPUIfNoLapack(fn):  | 
1406 | 1430 |     return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)  | 
@@ -1578,6 +1602,9 @@ def skipXLA(fn):  | 
1578 | 1602 | def skipMPS(fn):  | 
1579 | 1603 |     return skipMPSIf(True, "test doesn't work on MPS backend")(fn)  | 
1580 | 1604 | 
 
  | 
 | 1605 | +def skipHPU(fn):  | 
 | 1606 | +    return skipHPUIf(True, "test doesn't work on HPU backend")(fn)  | 
 | 1607 | + | 
1581 | 1608 | def skipPRIVATEUSE1(fn):  | 
1582 | 1609 |     return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)  | 
1583 | 1610 | 
 
  | 
 | 
0 commit comments