
Commit 184ea24

exclude apex.transformer if distributed is not available (#1936)
1 parent 4b46e00 commit 184ea24

1 file changed: +20 -20 lines changed

1 file changed

+20
-20
lines changed

apex/__init__.py

Lines changed: 20 additions & 20 deletions
@@ -13,27 +13,27 @@
 # load time) the error message is timely and visible.
 from . import optimizers
 from . import normalization
-from . import transformer
-
-
-__all__ = ["optimizers", "normalization", "transformer"]
-
-
-# Logging utilities for apex.transformer module
-class RankInfoFormatter(logging.Formatter):
-
-    def format(self, record):
-        from apex.transformer.parallel_state import get_rank_info
-        record.rank_info = get_rank_info()
-        return super().format(record)
-
-
-_library_root_logger = logging.getLogger(__name__)
-handler = logging.StreamHandler()
-handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S"))
-_library_root_logger.addHandler(handler)
-_library_root_logger.propagate = False
 
+if torch.distributed.is_available():
+    from . import transformer
+    __all__ = ["optimizers", "normalization", "transformer"]
+
+    # Logging utilities for apex.transformer module
+    class RankInfoFormatter(logging.Formatter):
+
+        def format(self, record):
+            from apex.transformer.parallel_state import get_rank_info
+            record.rank_info = get_rank_info()
+            return super().format(record)
+
+    _library_root_logger = logging.getLogger(__name__)
+    handler = logging.StreamHandler()
+    handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S"))
+    _library_root_logger.addHandler(handler)
+    _library_root_logger.propagate = False
+else:
+    # Transformers require PyTorch built with distributed support
+    __all__ = ["optimizers", "normalization"]
 
 
 def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
     cudnn_available = torch.backends.cudnn.is_available()
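
For reference, the pattern this commit introduces (conditioning an optional submodule's import and export on `torch.distributed.is_available()`) can be sketched from the caller's side. A minimal, illustrative example; the `apex` module names follow the diff above, everything else is hypothetical:

import torch

# Mirror the check apex/__init__.py now performs: the transformer
# submodule is only imported (and listed in __all__) when PyTorch
# was built with distributed support.
if torch.distributed.is_available():
    import apex.transformer
    print("transformer exported:", "transformer" in apex.__all__)
else:
    # e.g. a PyTorch build without distributed support
    print("torch.distributed unavailable; apex.transformer is not exported")

The logging setup moves inside the guard as well, because `RankInfoFormatter.format` imports `apex.transformer.parallel_state` at format time. A standalone sketch of that formatter idea, with a hardcoded placeholder standing in for the real `get_rank_info()`:

import logging

class DemoRankInfoFormatter(logging.Formatter):
    # Inject an extra 'rank_info' field into every record before the
    # normal format string is applied; apex reads the real value from
    # apex.transformer.parallel_state.get_rank_info().
    def format(self, record):
        record.rank_info = "0/1"  # placeholder rank for this sketch
        return super().format(record)

logger = logging.getLogger("demo")
handler = logging.StreamHandler()
handler.setFormatter(DemoRankInfoFormatter(
    "%(asctime)s - rank:%(rank_info)s - %(levelname)s - %(message)s",
    "%y-%m-%d %H:%M:%S"))
logger.addHandler(handler)
logger.propagate = False
logger.warning("hello")  # -> e.g. 24-06-01 12:00:00 - rank:0/1 - WARNING - hello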

0 commit comments
