|
13 | 13 | # load time) the error message is timely and visible. |
14 | 14 | from . import optimizers |
15 | 15 | from . import normalization |
16 | | -from . import transformer |
17 | | - |
18 | | - |
19 | | -__all__ = ["optimizers", "normalization", "transformer"] |
20 | | - |
21 | | - |
22 | | -# Logging utilities for apex.transformer module |
23 | | -class RankInfoFormatter(logging.Formatter): |
24 | | - |
25 | | - def format(self, record): |
26 | | - from apex.transformer.parallel_state import get_rank_info |
27 | | - record.rank_info = get_rank_info() |
28 | | - return super().format(record) |
29 | | - |
30 | | - |
31 | | -_library_root_logger = logging.getLogger(__name__) |
32 | | -handler = logging.StreamHandler() |
33 | | -handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S")) |
34 | | -_library_root_logger.addHandler(handler) |
35 | | -_library_root_logger.propagate = False |
36 | 16 |
|
# apex.transformer needs torch.distributed; only expose it — and its
# rank-aware logging setup — when this PyTorch build ships distributed support.
if not torch.distributed.is_available():
    # Transformers require PyTorch built with distributed support
    __all__ = ["optimizers", "normalization"]
else:
    from . import transformer

    __all__ = ["optimizers", "normalization", "transformer"]

    # Logging utilities for apex.transformer module
    class RankInfoFormatter(logging.Formatter):
        """Formatter that stamps every record with the emitting process's rank info."""

        def format(self, record):
            # Imported lazily so module load does not require parallel state
            # to be initialized (and to avoid a circular import).
            from apex.transformer.parallel_state import get_rank_info

            record.rank_info = get_rank_info()
            return super().format(record)

    _library_root_logger = logging.getLogger(__name__)
    handler = logging.StreamHandler()
    handler.setFormatter(
        RankInfoFormatter(
            "%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s",
            "%y-%m-%d %H:%M:%S",
        )
    )
    _library_root_logger.addHandler(handler)
    # Keep apex's records out of the root logger; this handler owns them.
    _library_root_logger.propagate = False
37 | 37 |
|
38 | 38 | def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: |
39 | 39 | cudnn_available = torch.backends.cudnn.is_available() |
|
0 commit comments