diff --git a/README.md b/README.md
index 1150c66cbe..cf64f99ead 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,21 @@
 --------------------------------------------------------------------------------
 
-Fairseq(-py) is a sequence modeling toolkit that allows researchers and
-developers to train custom models for translation, summarization, language
-modeling and other text generation tasks.
+# Fairseq
+
+Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks.
+
+## New in Version 0.13.0
+- Added support for PyTorch 2.6+ with safe globals handling
+- Updated dependency requirements for modern Python environments
+- Improved Windows compatibility
+- Added explicit fairscale dependency
+
+## Requirements and Installation
+* [PyTorch](http://pytorch.org/) version >= 2.6.0
+* Python version >= 3.8
+* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
+* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` and `--deprecated_fused_adam` options
 
 We provide reference implementations of various sequence modeling papers:
 
@@ -146,39 +158,6 @@ en2de.translate('Hello world', beam=5)
 See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
 and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
 
-# Requirements and Installation
-
-* [PyTorch](http://pytorch.org/) version >= 1.10.0
-* Python version >= 3.8
-* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
-* **To install fairseq** and develop locally:
-
-``` bash
-git clone https://github.com/pytorch/fairseq
-cd fairseq
-pip install --editable ./
-
-# on MacOS:
-# CFLAGS="-stdlib=libc++" pip install --editable ./
-
-# to install the latest stable release (0.10.x)
-# pip install fairseq
-```
-
-* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
-
-``` bash
-git clone https://github.com/NVIDIA/apex
-cd apex
-pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
-  --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
-  --global-option="--fast_multihead_attn" ./
-```
-
-* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
-* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
-  as command line options to `nvidia-docker run` .
-
 # Getting Started
 
 The [full documentation](https://fairseq.readthedocs.io/) contains instructions
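Note: a quick way to confirm an environment satisfies the new minimums before building. This is a sanity-check sketch, not part of the patch:

```python
# Sanity-check sketch for the new minimums (Python >= 3.8, PyTorch >= 2.6).
# Run in the target environment before building fairseq; not part of the patch.
import sys

import torch
from packaging.version import Version

assert sys.version_info >= (3, 8), "fairseq 0.13.0 requires Python >= 3.8"
assert Version(torch.__version__.split("+")[0]) >= Version("2.6.0"), \
    "fairseq 0.13.0 targets PyTorch >= 2.6"

print("CUDA available:", torch.cuda.is_available())
print("NCCL available:", torch.distributed.is_nccl_available())
```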
diff --git a/examples/MMPT/mmpt/utils/load_config.py b/examples/MMPT/mmpt/utils/load_config.py
index ede4f94117..bcc67bb89e 100644
--- a/examples/MMPT/mmpt/utils/load_config.py
+++ b/examples/MMPT/mmpt/utils/load_config.py
@@ -59,6 +59,13 @@ def recursive_config(config_path):
         includes = config.includes
         config.pop("includes")
         base_config = recursive_config(includes)
+
+        # Filter out any MISSING values from config before merging
+        if isinstance(config, omegaconf.DictConfig):
+            for key in list(config.keys()):
+                if config[key] is None or (isinstance(config[key], str) and config[key] == "???"):
+                    config.pop(key)
+
         config = OmegaConf.merge(base_config, config)
     return config
diff --git a/examples/speech_recognition/kaldi/kaldi_initializer.py b/examples/speech_recognition/kaldi/kaldi_initializer.py
index 6d2a2a4b6b..f36b4f990b 100644
--- a/examples/speech_recognition/kaldi/kaldi_initializer.py
+++ b/examples/speech_recognition/kaldi/kaldi_initializer.py
@@ -669,7 +669,7 @@ def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path:
     return hlg_graph
 
-@hydra.main(config_path=config_path, config_name="kaldi_initializer")
+@hydra.main(version_base=None, config_path=config_path, config_name="kaldi_initializer")
 def cli_main(cfg: KaldiInitializerConfig) -> None:
     container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
     cfg = OmegaConf.create(container)
@@ -683,13 +683,19 @@ def cli_main(cfg: KaldiInitializerConfig) -> None:
     logging.basicConfig(level=logging.INFO)
 
     try:
-        from hydra._internal.utils import (
-            get_args,
-        )  # pylint: disable=import-outside-toplevel
-
-        cfg_name = get_args().config_name or "kaldi_initializer"
-    except ImportError:
-        logger.warning("Failed to get config name from hydra args")
+        import sys
+
+        # Parse sys.argv directly instead of relying on hydra._internal.utils.get_args
+        cfg_name = "kaldi_initializer"
+        for i, arg in enumerate(sys.argv):
+            if arg == "--config-name" and i + 1 < len(sys.argv):
+                cfg_name = sys.argv[i + 1]
+                break
+            elif arg.startswith("--config-name="):
+                cfg_name = arg.split("=", 1)[1]
+                break
+    except Exception:
+        logger.warning("Failed to get config name from command line arguments")
         cfg_name = "kaldi_initializer"
 
     cs = ConfigStore.instance()
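Note: the same `--config-name` scan recurs in `cli_main()` in kaldi_initializer.py above and in infer.py, w2vu_generate.py, hydra_train.py, and hydra_validate.py below. A minimal sketch of the parsing logic in isolation; `get_config_name` is hypothetical and not defined anywhere in this patch:

```python
# Hypothetical helper illustrating the --config-name parsing repeated in the
# cli_main() functions of this patch; fairseq does not actually define this.
import sys
from typing import List, Optional


def get_config_name(argv: Optional[List[str]] = None, default: str = "config") -> str:
    """Return the value of --config-name from argv, or `default` if absent."""
    argv = sys.argv if argv is None else argv
    for i, arg in enumerate(argv):
        if arg == "--config-name" and i + 1 < len(argv):
            return argv[i + 1]
        if arg.startswith("--config-name="):
            return arg.split("=", 1)[1]
    return default


assert get_config_name(["prog", "--config-name=infer"]) == "infer"
assert get_config_name(["prog"], default="kaldi_initializer") == "kaldi_initializer"
```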
diff --git a/examples/speech_recognition/new/infer.py b/examples/speech_recognition/new/infer.py
index ca5cea4a7c..18e404d6f9 100644
--- a/examples/speech_recognition/new/infer.py
+++ b/examples/speech_recognition/new/infer.py
@@ -440,7 +440,7 @@ def main(cfg: InferConfig) -> float:
     return wer
 
-@hydra.main(config_path=config_path, config_name="infer")
+@hydra.main(version_base=None, config_path=config_path, config_name="infer")
 def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
     container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
     cfg = OmegaConf.create(container)
@@ -478,13 +478,19 @@ def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
 
 def cli_main() -> None:
     try:
-        from hydra._internal.utils import (
-            get_args,
-        )  # pylint: disable=import-outside-toplevel
-
-        cfg_name = get_args().config_name or "infer"
-    except ImportError:
-        logger.warning("Failed to get config name from hydra args")
+        import sys
+
+        # Parse sys.argv directly instead of relying on hydra._internal.utils.get_args
+        cfg_name = "infer"
+        for i, arg in enumerate(sys.argv):
+            if arg == "--config-name" and i + 1 < len(sys.argv):
+                cfg_name = sys.argv[i + 1]
+                break
+            elif arg.startswith("--config-name="):
+                cfg_name = arg.split("=", 1)[1]
+                break
+    except Exception:
+        logger.warning("Failed to get config name from command line arguments")
         cfg_name = "infer"
 
     cs = ConfigStore.instance()
diff --git a/examples/wav2vec/unsupervised/w2vu_generate.py b/examples/wav2vec/unsupervised/w2vu_generate.py
index 0611297a4f..bb73e2b3b3 100644
--- a/examples/wav2vec/unsupervised/w2vu_generate.py
+++ b/examples/wav2vec/unsupervised/w2vu_generate.py
@@ -672,6 +672,7 @@ def main(cfg: UnsupGenerateConfig, model=None):
 
 @hydra.main(
+    version_base=None,
     config_path=os.path.join("../../..", "fairseq", "config"),
     config_name="config"
 )
 def hydra_main(cfg):
@@ -698,11 +699,19 @@ def hydra_main(cfg):
 
 def cli_main():
     try:
-        from hydra._internal.utils import get_args
-
-        cfg_name = get_args().config_name or "config"
+        import sys
+
+        # Parse sys.argv directly instead of relying on hydra._internal.utils.get_args
+        cfg_name = "config"
+        for i, arg in enumerate(sys.argv):
+            if arg == "--config-name" and i + 1 < len(sys.argv):
+                cfg_name = sys.argv[i + 1]
+                break
+            elif arg.startswith("--config-name="):
+                cfg_name = arg.split("=", 1)[1]
+                break
     except:
-        logger.warning("Failed to get config name from hydra args")
+        logger.warning("Failed to get config name from command line arguments")
         cfg_name = "config"
 
     cs = ConfigStore.instance()
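Note: the checkpoint_utils.py hunks that follow feature-detect the `weights_only` argument of `torch.load` (added in PyTorch 1.13, defaulting to `True` from 2.6). Because `torch.load` is a plain function, `hasattr(torch.load, "weights_only")` is always false; the detection has to inspect the signature. A standalone sketch of the pattern, under the assumption that the checkpoint comes from a trusted source:

```python
# Standalone sketch of the torch.load compatibility pattern used below.
# PyTorch 2.6 defaults weights_only=True, which rejects fairseq checkpoints
# because they pickle Namespace/omegaconf objects alongside tensors.
import inspect

import torch


def load_trusted_checkpoint(f, map_location="cpu"):
    kwargs = {"map_location": map_location}
    if "weights_only" in inspect.signature(torch.load).parameters:
        # Only safe for checkpoints from a trusted source.
        kwargs["weights_only"] = False
    return torch.load(f, **kwargs)
```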
diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index e3f316b9e7..2264d9a7de 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -337,7 +337,11 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
     local_path = PathManager.get_local_path(path)
     with open(local_path, "rb") as f:
-        state = torch.load(f, map_location=torch.device("cpu"))
+        import inspect
+
+        # torch.load is a plain function, so hasattr() cannot detect its keyword
+        # arguments; inspect the signature for `weights_only` instead
+        if "weights_only" in inspect.signature(torch.load).parameters:
+            state = torch.load(f, map_location=torch.device("cpu"), weights_only=False)
+        else:
+            state = torch.load(f, map_location=torch.device("cpu"))
 
     if "args" in state and state["args"] is not None and arg_overrides is not None:
         args = state["args"]
@@ -345,22 +349,21 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
             setattr(args, arg_name, arg_val)
 
     if "cfg" in state and state["cfg"] is not None:
-
-        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
-        # omegaconf version that supports object flags, or when we migrate all existing models
+        # Use proper object flags approach for omegaconf 2.1+
         from omegaconf import __version__ as oc_version
-        from omegaconf import _utils
+        from packaging import version
 
-        if oc_version < "2.2":
+        # Compare release numbers numerically; a plain string comparison
+        # mis-orders versions such as "2.10" and "2.2"
+        if version.parse(oc_version) >= version.parse("2.1.0"):
+            # OmegaConf 2.1+ can handle this with the allow_objects flag
+            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
+        else:
+            # Fallback for older versions using the hacky approach
+            from omegaconf import _utils
+
             old_primitive = _utils.is_primitive_type
             _utils.is_primitive_type = lambda _: True
-
             state["cfg"] = OmegaConf.create(state["cfg"])
-
             _utils.is_primitive_type = old_primitive
-            OmegaConf.set_struct(state["cfg"], True)
-        else:
-            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
+
+        OmegaConf.set_struct(state["cfg"], True)
 
         if arg_overrides is not None:
             overwrite_args_by_name(state["cfg"], arg_overrides)
@@ -906,12 +909,22 @@ def load_ema_from_checkpoint(fpath):
     new_state = None
 
     with PathManager.open(fpath, "rb") as f:
-        new_state = torch.load(
-            f,
-            map_location=(
-                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
-            ),
-        )
+        import inspect
+
+        # Same torch.load signature check as load_checkpoint_to_cpu above
+        if "weights_only" in inspect.signature(torch.load).parameters:
+            new_state = torch.load(
+                f,
+                map_location=(
+                    lambda s, _: torch.serialization.default_restore_location(s, "cpu")
+                ),
+                weights_only=False,
+            )
+        else:
+            new_state = torch.load(
+                f,
+                map_location=(
+                    lambda s, _: torch.serialization.default_restore_location(s, "cpu")
+                ),
+            )
 
     # EMA model is stored in a separate "extra state"
     model_params = new_state["extra_state"]["ema"]
diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py
index 5a7784bad1..11c6357aed 100644
--- a/fairseq/dataclass/initialize.py
+++ b/fairseq/dataclass/initialize.py
@@ -58,4 +58,10 @@ def add_defaults(cfg: DictConfig) -> None:
             dc = REGISTRIES[k]["dataclass_registry"].get(name)
             if dc is not None:
+                # Filter out any MISSING values before merging
+                if OmegaConf.is_config(field_cfg):
+                    for key in list(field_cfg.keys()):
+                        if field_cfg[key] is None or (isinstance(field_cfg[key], str) and field_cfg[key] == "???"):
+                            field_cfg.pop(key)
+
                 cfg[k] = merge_with_parent(dc, field_cfg)
diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py
index f6467d5f40..b57f324b8d 100644
--- a/fairseq/dataclass/utils.py
+++ b/fairseq/dataclass/utils.py
@@ -16,7 +16,7 @@
 from fairseq.dataclass import FairseqDataclass
 from fairseq.dataclass.configs import FairseqConfig
 from hydra.core.global_hydra import GlobalHydra
-from hydra.experimental import compose, initialize
+from hydra import compose, initialize
 from omegaconf import DictConfig, OmegaConf, open_dict, _utils
 
 logger = logging.getLogger(__name__)
@@ -362,27 +362,6 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
     return overrides, deletes
 
 
-class omegaconf_no_object_check:
-    def __init__(self):
-        # Changed in https://github.com/omry/omegaconf/pull/911 - both are kept for back compat.
-        if hasattr(_utils, "is_primitive_type"):
-            self.old_is_primitive = _utils.is_primitive_type
-        else:
-            self.old_is_primitive = _utils.is_primitive_type_annotation
-
-    def __enter__(self):
-        if hasattr(_utils, "is_primitive_type"):
-            _utils.is_primitive_type = lambda _: True
-        else:
-            _utils.is_primitive_type_annotation = lambda _: True
-
-    def __exit__(self, type, value, traceback):
-        if hasattr(_utils, "is_primitive_type"):
-            _utils.is_primitive_type = self.old_is_primitive
-        else:
-            _utils.is_primitive_type_annotation = self.old_is_primitive
-
-
 def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
     """Convert a flat argparse.Namespace to a structured DictConfig."""
 
@@ -394,7 +373,7 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
 
     GlobalHydra.instance().clear()
 
-    with initialize(config_path=config_path):
+    with initialize(version_base=None, config_path=config_path):
         try:
             composed_cfg = compose("config", overrides=overrides, strict=False)
         except:
             composed_cfg = compose("config", overrides=overrides)
 
     for k in deletes:
         composed_cfg[k] = None
 
-    cfg = OmegaConf.create(
-        OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
-    )
-
-    # hack to be able to set Namespace in dict config. this should be removed when we update to newer
-    # omegaconf version that supports object flags, or when we migrate all existing models
-    from omegaconf import _utils
-
-    with omegaconf_no_object_check():
-        if cfg.task is None and getattr(args, "task", None):
-            cfg.task = Namespace(**vars(args))
-            from fairseq.tasks import TASK_REGISTRY
-
-            _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
-            cfg.task._name = args.task
-        if cfg.model is None and getattr(args, "arch", None):
-            cfg.model = Namespace(**vars(args))
-            from fairseq.models import ARCH_MODEL_REGISTRY
-
-            _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
-            cfg.model._name = args.arch
-        if cfg.optimizer is None and getattr(args, "optimizer", None):
-            cfg.optimizer = Namespace(**vars(args))
-            from fairseq.optim import OPTIMIZER_REGISTRY
-
-            _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
-            cfg.optimizer._name = args.optimizer
-        if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
-            cfg.lr_scheduler = Namespace(**vars(args))
-            from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
-
-            _set_legacy_defaults(
-                cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]
-            )
-            cfg.lr_scheduler._name = args.lr_scheduler
-        if cfg.criterion is None and getattr(args, "criterion", None):
-            cfg.criterion = Namespace(**vars(args))
-            from fairseq.criterions import CRITERION_REGISTRY
-
-            _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
-            cfg.criterion._name = args.criterion
+    # Filter out any MISSING values before creating the final config
+    composed_cfg_dict = OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
+    if isinstance(composed_cfg_dict, dict):
+        for k, v in list(composed_cfg_dict.items()):
+            if v == "???" or v is None:
+                composed_cfg_dict.pop(k)
+
+    # Create the config with proper object handling in omegaconf 2.1+
+    cfg = OmegaConf.create(composed_cfg_dict, flags={"allow_objects": True})
+
+    # Handle task, model, optimizer, lr_scheduler, criterion namespace conversions properly
+    if cfg.task is None and getattr(args, "task", None):
+        cfg.task = Namespace(**vars(args))
+        from fairseq.tasks import TASK_REGISTRY
+        _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
+        cfg.task._name = args.task
+
+    if cfg.model is None and getattr(args, "arch", None):
+        cfg.model = Namespace(**vars(args))
+        from fairseq.models import ARCH_MODEL_REGISTRY
+        _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
+        cfg.model._name = args.arch
+
+    if cfg.optimizer is None and getattr(args, "optimizer", None):
+        cfg.optimizer = Namespace(**vars(args))
+        from fairseq.optim import OPTIMIZER_REGISTRY
+        _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
+        cfg.optimizer._name = args.optimizer
+
+    if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
+        cfg.lr_scheduler = Namespace(**vars(args))
+        from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
+        _set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler])
+        cfg.lr_scheduler._name = args.lr_scheduler
+
+    if cfg.criterion is None and getattr(args, "criterion", None):
+        cfg.criterion = Namespace(**vars(args))
+        from fairseq.criterions import CRITERION_REGISTRY
+        _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
+        cfg.criterion._name = args.criterion
 
     OmegaConf.set_struct(cfg, True)
     return cfg
@@ -504,7 +482,21 @@ def remove_missing_rec(src_keys, target_cfg):
     with open_dict(cfg):
         remove_missing_rec(cfg, dc)
 
-    merged_cfg = OmegaConf.merge(dc, cfg)
-    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
+    # Filter out any MISSING values from the config before merging to avoid
+    # them overriding values in dc (omegaconf 2.1+ behavior change)
+    filtered_cfg = OmegaConf.create({}) if cfg is None else cfg
+    if OmegaConf.is_config(filtered_cfg):
+        with open_dict(filtered_cfg):
+            filtered_cfg_dict = {}
+            for k, v in filtered_cfg.items():
+                if v is not None and not (isinstance(v, str) and v == "???"):
+                    filtered_cfg_dict[k] = v
+
+            # Create a new config with the filtered values
+            filtered_cfg = OmegaConf.create(filtered_cfg_dict)
+
+    merged_cfg = OmegaConf.merge(dc, filtered_cfg)
+    if hasattr(cfg, "_parent") and cfg._parent is not None:
+        merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
     OmegaConf.set_struct(merged_cfg, True)
     return merged_cfg
diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py
index 4c64b61bad..5fec1344ee 100644
--- a/fairseq/logging/progress_bar.py
+++ b/fairseq/logging/progress_bar.py
@@ -479,6 +479,8 @@ def _log_to_tensorboard(self, stats, tag=None, step=None):
     import wandb
 except ImportError:
     wandb = None
+except AttributeError:
+    wandb = None
 
 
 class WandBProgressBarWrapper(BaseProgressBar):
diff --git a/fairseq/models/transformer/transformer_config.py b/fairseq/models/transformer/transformer_config.py
index 4650de2e17..396c0801c7 100644
--- a/fairseq/models/transformer/transformer_config.py
+++ b/fairseq/models/transformer/transformer_config.py
@@ -245,8 +245,11 @@ class TransformerConfig(FairseqDataclass):
     def __getattr__(self, name):
         match = re.match(_NAME_PARSER, name)
         if match:
-            sub = safe_getattr(self, match[1])
-            return safe_getattr(sub, match[2])
+            try:
+                sub = safe_getattr(self, match[1])
+                return safe_getattr(sub, match[2])
+            except (AttributeError, KeyError):
+                raise AttributeError(f"invalid argument {name}.")
         raise AttributeError(f"invalid argument {name}.")
 
     def __setattr__(self, name, value):
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index e39d1d6848..25f7d9c91a 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -38,11 +38,14 @@ def state_dict(self) -> Dict[str, Any]:
         return self._state
 
     def __getattr__(self, name):
-        if name not in self._state and name in self._factories:
-            self._state[name] = self._factories[name]()
-
-        if name in self._state:
-            return self._state[name]
+        try:
+            if name not in self._state and name in self._factories:
+                self._state[name] = self._factories[name]()
+
+            if name in self._state:
+                return self._state[name]
+        except (AttributeError, KeyError):
+            pass
 
         raise AttributeError(f"Task state has no factory for attribute {name}")
 
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 4d4b350523..cc608f9fbd 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -832,11 +832,49 @@ def safe_getattr(obj, k, default=None):
     from omegaconf import OmegaConf
 
     if OmegaConf.is_config(obj):
-        return obj[k] if k in obj and obj[k] is not None else default
+        try:
+            return obj[k] if k in obj and obj[k] is not None else default
+        except (KeyError, AttributeError):
+            return default
 
     return getattr(obj, k, default)
 
 
+def safe_dictconfig_get(cfg, key, default=None):
+    """
+    Safely access keys in DictConfig objects with omegaconf 2.1+ compatibility.
+
+    In omegaconf 2.1+, cfg[key] raises KeyError and cfg.key raises AttributeError
+    if the key doesn't exist, while in older versions it returned None.
+    This function provides backward compatibility.
+
+    Args:
+        cfg: DictConfig or similar object
+        key: Key to access
+        default: Default value if key doesn't exist
+
+    Returns:
+        Value at key or default if key doesn't exist
+    """
+    from omegaconf import OmegaConf
+
+    if OmegaConf.is_config(cfg):
+        try:
+            # First check if key exists to avoid error
+            if key in cfg:
+                val = cfg[key]
+                return val if val is not None else default
+            return default
+        except (KeyError, AttributeError):
+            return default
+
+    # Fall back to dict-like get or attribute access
+    if hasattr(cfg, "get"):
+        return cfg.get(key, default)
+
+    return getattr(cfg, key, default)
+
+
 def safe_hasattr(obj, k):
     """Returns True if the given key exists and is not None."""
     return getattr(obj, k, None) is not None
diff --git a/fairseq/version.txt b/fairseq/version.txt
index 26acbf080b..51de3305bb 100644
--- a/fairseq/version.txt
+++ b/fairseq/version.txt
@@ -1 +1 @@
-0.12.2
+0.13.0
\ No newline at end of file
diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py
index 607340af0d..6d9b114603 100644
--- a/fairseq_cli/hydra_train.py
+++ b/fairseq_cli/hydra_train.py
@@ -15,14 +15,13 @@
 from fairseq import distributed_utils, metrics
 from fairseq.dataclass.configs import FairseqConfig
 from fairseq.dataclass.initialize import add_defaults, hydra_init
-from fairseq.dataclass.utils import omegaconf_no_object_check
 from fairseq.utils import reset_logging
 from fairseq_cli.train import main as pre_main
 
 logger = logging.getLogger("fairseq_cli.hydra_train")
 
 
-@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
+@hydra.main(version_base=None, config_path=os.path.join("..", "fairseq", "config"), config_name="config")
 def hydra_main(cfg: FairseqConfig) -> float:
     _hydra_main(cfg)
 
@@ -41,10 +40,9 @@ def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
             HydraConfig.get().job_logging, resolve=True
         )
 
-    with omegaconf_no_object_check():
-        cfg = OmegaConf.create(
-            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
-        )
+    # Create the config with proper object handling in omegaconf 2.1+
+    cfg_dict = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
+    cfg = OmegaConf.create(cfg_dict, flags={"allow_objects": True})
     OmegaConf.set_struct(cfg, True)
 
     try:
@@ -76,11 +74,20 @@ def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
 
 def cli_main():
     try:
-        from hydra._internal.utils import get_args
-
-        cfg_name = get_args().config_name or "config"
+        import sys
+
+        # Parse sys.argv directly instead of relying on hydra._internal.utils.get_args
+        cfg_name = "config"
+        for i, arg in enumerate(sys.argv):
+            if arg == "--config-name" and i + 1 < len(sys.argv):
+                cfg_name = sys.argv[i + 1]
+                break
+            elif arg.startswith("--config-name="):
+                cfg_name = arg.split("=", 1)[1]
+                break
     except:
-        logger.warning("Failed to get config name from hydra args")
+        logger.warning("Failed to get config name from command line arguments")
         cfg_name = "config"
 
     hydra_init(cfg_name)
diff --git a/fairseq_cli/hydra_validate.py b/fairseq_cli/hydra_validate.py
index cb6f7612d0..ec7207a9df 100644
--- a/fairseq_cli/hydra_validate.py
+++ b/fairseq_cli/hydra_validate.py
@@ -17,7 +17,6 @@
 from fairseq import checkpoint_utils, distributed_utils, utils
 from fairseq.dataclass.configs import FairseqConfig
 from fairseq.dataclass.initialize import add_defaults, hydra_init
-from fairseq.dataclass.utils import omegaconf_no_object_check
 from fairseq.distributed import utils as distributed_utils
 from fairseq.logging import metrics, progress_bar
 from fairseq.utils import reset_logging
@@ -31,7 +30,7 @@
 logger = logging.getLogger("fairseq_cli.validate")
 
 
-@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
+@hydra.main(version_base=None, config_path=os.path.join("..", "fairseq", "config"), config_name="config")
 def hydra_main(cfg: FairseqConfig) -> float:
     return _hydra_main(cfg)
 
@@ -50,10 +49,9 @@ def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
         HydraConfig.get().job_logging, resolve=True
     )
 
-    with omegaconf_no_object_check():
-        cfg = OmegaConf.create(
-            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
-        )
+    # Create the config with proper object handling in omegaconf 2.1+
+    cfg_dict = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
+    cfg = OmegaConf.create(cfg_dict, flags={"allow_objects": True})
     OmegaConf.set_struct(cfg, True)
 
     assert (
@@ -173,11 +171,20 @@ def apply_half(t):
 
 def cli_main():
     try:
-        from hydra._internal.utils import get_args
-
-        cfg_name = get_args().config_name or "config"
+        import sys
+
+        # Parse sys.argv directly instead of relying on hydra._internal.utils.get_args
+        cfg_name = "config"
+        for i, arg in enumerate(sys.argv):
+            if arg == "--config-name" and i + 1 < len(sys.argv):
+                cfg_name = sys.argv[i + 1]
+                break
+            elif arg.startswith("--config-name="):
+                cfg_name = arg.split("=", 1)[1]
+                break
     except:
-        logger.warning("Failed to get config name from hydra args")
+        logger.warning("Failed to get config name from command line arguments")
         cfg_name = "config"
 
     hydra_init(cfg_name)
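Note: every `@hydra.main` call site above gains `version_base=None`, an argument Hydra < 1.2 does not accept. If both old and new Hydra ever need to be supported, a small shim could dispatch on the signature. A sketch under that assumption; `hydra_main_compat` is hypothetical and not part of this patch:

```python
# Sketch of a decorator shim for running under both old and new Hydra;
# hydra_main_compat is hypothetical and not defined anywhere in this patch.
import inspect

import hydra


def hydra_main_compat(**kwargs):
    """Like hydra.main, but only passes version_base where supported (Hydra >= 1.2)."""
    if "version_base" in inspect.signature(hydra.main).parameters:
        kwargs.setdefault("version_base", None)
    else:
        kwargs.pop("version_base", None)
    return hydra.main(**kwargs)
```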
"torch>=1.10", ] build-backend = "setuptools.build_meta" +[project] +name = "fairseq" +dynamic = ["version"] +description = "Facebook AI Research Sequence-to-Sequence Toolkit" +readme = "README.md" +requires-python = ">=3.8" +license = {text = "MIT"} +authors = [ + {name = "Facebook AI Research"} +] +classifiers = [ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "cffi", + "hydra-core>=1.3.2", + "omegaconf>=2.1.0", + "numpy>=1.24.2", + "regex", + "sacrebleu>=1.4.12", + "torch>=2.4.0", + "tqdm", + "bitarray", + "torchaudio>=2.4.0", + "torchvision>=0.19.0", + "scikit-learn", + "packaging", + "scipy>=1.15.0", + "pandas>=2.2.0", + "transformers>=4.51.3", + "numpy>=1.24.2", +] + +[project.optional-dependencies] +dev = ["flake8", "pytest", "black==22.3.0"] +docs = ["sphinx", "sphinx-argparse"] + +[project.urls] +"Homepage" = "https://github.com/pytorch/fairseq" +"Bug Tracker" = "https://github.com/pytorch/fairseq/issues" + [tool.black] extend-exclude = ''' ( diff --git a/setup.py b/setup.py index dae06080c5..8734b64d29 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,10 @@ from setuptools import Extension, find_packages, setup from torch.utils import cpp_extension +from torch.utils.cpp_extension import CUDAExtension + +# Check if we're on Windows +is_windows = sys.platform == "win32" if sys.version_info < (3, 6): sys.exit("Sorry, Python >= 3.6 is required for fairseq.") @@ -37,6 +41,18 @@ def write_version_py(): else: extra_compile_args = ["-std=c++11", "-O3"] +# Add Windows-specific configurations +cuda_lib_path = None +cuda_include_path = None + +if "CUDA_HOME" in os.environ: + cuda_home = os.environ.get("CUDA_HOME") + if is_windows: + # On Windows, some libraries are in bin instead of lib/x64 + cuda_lib_path = [os.path.join(cuda_home, "lib", "x64"), + os.path.join(cuda_home, "bin")] + cuda_include_path = [os.path.join(cuda_home, "include")] + class NumpyExtension(Extension): """Source: https://stackoverflow.com/a/54128391""" @@ -102,34 +118,70 @@ def include_dirs(self, dirs): ), ] ) -if "CUDA_HOME" in os.environ: - extensions.extend( - [ - cpp_extension.CppExtension( - "fairseq.libnat_cuda", - sources=[ - "fairseq/clib/libnat_cuda/edit_dist.cu", - "fairseq/clib/libnat_cuda/binding.cpp", - ], - ), - cpp_extension.CppExtension( - "fairseq.ngram_repeat_block_cuda", - sources=[ - "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp", - "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu", - ], - ), - cpp_extension.CppExtension( - "alignment_train_cuda_binding", - sources=[ - "examples/operators/alignment_train_kernel.cu", - "examples/operators/alignment_train_cuda.cpp", - ], - ), - ] - ) -cmdclass = {"build_ext": cpp_extension.BuildExtension} +# Add CUDA extensions if CUDA is available +if "CUDA_HOME" in os.environ and not is_windows: + # Configure CUDA extension settings based on platform + cuda_extension_args = {} + + # Special handling for Windows + if is_windows: + cuda_home = os.environ.get("CUDA_HOME") + # Manual specification of libraries and include directories + cuda_extension_args = { + 'library_dirs': [ + os.path.join(cuda_home, 'lib', 'x64'), + os.path.join(cuda_home, 'bin'), + ], + 'libraries': ['cudart'], + 'define_macros': [('TORCH_EXTENSION_NAME', 'fairseq_cuda_extension')], + } + + # Add CUDA extensions with platform-specific 
diff --git a/setup.py b/setup.py
index dae06080c5..8734b64d29 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,10 @@
 from setuptools import Extension, find_packages, setup
 from torch.utils import cpp_extension
+from torch.utils.cpp_extension import CUDAExtension
+
+# Check if we're on Windows
+is_windows = sys.platform == "win32"
 
 if sys.version_info < (3, 6):
     sys.exit("Sorry, Python >= 3.6 is required for fairseq.")
@@ -37,6 +41,18 @@ def write_version_py():
 else:
     extra_compile_args = ["-std=c++11", "-O3"]
 
+# Add Windows-specific configurations
+cuda_lib_path = None
+cuda_include_path = None
+
+if "CUDA_HOME" in os.environ:
+    cuda_home = os.environ.get("CUDA_HOME")
+    if is_windows:
+        # On Windows, some libraries are in bin instead of lib/x64
+        cuda_lib_path = [os.path.join(cuda_home, "lib", "x64"),
+                         os.path.join(cuda_home, "bin")]
+        cuda_include_path = [os.path.join(cuda_home, "include")]
+
 
 class NumpyExtension(Extension):
     """Source: https://stackoverflow.com/a/54128391"""
@@ -102,34 +118,70 @@ def include_dirs(self, dirs):
         ),
     ]
 )
-if "CUDA_HOME" in os.environ:
-    extensions.extend(
-        [
-            cpp_extension.CppExtension(
-                "fairseq.libnat_cuda",
-                sources=[
-                    "fairseq/clib/libnat_cuda/edit_dist.cu",
-                    "fairseq/clib/libnat_cuda/binding.cpp",
-                ],
-            ),
-            cpp_extension.CppExtension(
-                "fairseq.ngram_repeat_block_cuda",
-                sources=[
-                    "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
-                    "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
-                ],
-            ),
-            cpp_extension.CppExtension(
-                "alignment_train_cuda_binding",
-                sources=[
-                    "examples/operators/alignment_train_kernel.cu",
-                    "examples/operators/alignment_train_cuda.cpp",
-                ],
-            ),
-        ]
-    )
 
-cmdclass = {"build_ext": cpp_extension.BuildExtension}
+# Add CUDA extensions if CUDA is available (never on Windows here; the
+# Windows CUDA build is configured separately in setup_backup.py)
+if "CUDA_HOME" in os.environ and not is_windows:
+    # Configure CUDA extension settings based on platform
+    cuda_extension_args = {}
+
+    # Add CUDA extensions with platform-specific settings
+    extensions.extend([
+        cpp_extension.CppExtension(
+            "fairseq.libnat_cuda",
+            sources=[
+                "fairseq/clib/libnat_cuda/edit_dist.cu",
+                "fairseq/clib/libnat_cuda/binding.cpp",
+            ],
+            **cuda_extension_args
+        ),
+        cpp_extension.CppExtension(
+            "fairseq.ngram_repeat_block_cuda",
+            sources=[
+                "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
+                "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
+            ],
+            **cuda_extension_args
+        ),
+        cpp_extension.CppExtension(
+            "alignment_train_cuda_binding",
+            sources=[
+                "examples/operators/alignment_train_kernel.cu",
+                "examples/operators/alignment_train_cuda.cpp",
+            ],
+            **cuda_extension_args
+        ),
+    ])
+
+# Customize build_ext class for Windows
+if is_windows:
+    from torch.utils.cpp_extension import BuildExtension
+
+    class CustomBuildExtension(BuildExtension):
+        def build_extensions(self):
+            # Define specific compiler flags for Windows MSVC. MSVC's compiler
+            # object has no `compiler_so` list to append to, so attach the
+            # flags to each CUDA extension's extra_compile_args instead
+            for extension in self.extensions:
+                if hasattr(extension, 'sources') and any(source.endswith('.cu') for source in extension.sources):
+                    extra = getattr(extension, 'extra_compile_args', None)
+                    if isinstance(extra, dict):
+                        extra.setdefault('cxx', []).extend(['/EHsc', '/MD'])
+                    elif isinstance(extra, list):
+                        extra.extend(['/EHsc', '/MD'])
+                    else:
+                        extension.extra_compile_args = ['/EHsc', '/MD']
+
+            BuildExtension.build_extensions(self)
+
+    cmdclass = {"build_ext": CustomBuildExtension}
+else:
+    cmdclass = {"build_ext": cpp_extension.BuildExtension}
 
 if "READTHEDOCS" in os.environ:
     # don't build extensions when generating docs
@@ -138,8 +190,8 @@ def include_dirs(self, dirs):
     del cmdclass["build_ext"]
 
     # use CPU build of PyTorch
-    dependency_links = [
-        "https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp36-cp36m-linux_x86_64.whl"
+    dependency_links = [
+        "https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.1.0%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl"
     ]
 else:
     dependency_links = []
@@ -179,15 +231,15 @@ def do_setup(package_data):
         install_requires=[
             "cffi",
             "cython",
-            "hydra-core>=1.0.7,<1.1",
-            "omegaconf<2.1",
-            "numpy>=1.21.3",
+            "hydra-core>=1.3.2",
+            "omegaconf>=2.1.0",
+            "numpy>=1.24.2",
             "regex",
             "sacrebleu>=1.4.12",
-            "torch>=1.13",
+            "torch>=2.4.0",
             "tqdm",
             "bitarray",
-            "torchaudio>=0.8.0",
+            "torchaudio>=2.4.0",
             "scikit-learn",
             "packaging",
         ],
@@ -254,4 +306,4 @@ def get_files(path, relative_to="fairseq"):
         do_setup(package_data)
     finally:
         if "build_ext" not in sys.argv[1:] and os.path.islink(fairseq_examples):
-            os.unlink(fairseq_examples)
+            os.unlink(fairseq_examples)
\ No newline at end of file
diff --git a/setup_backup.py b/setup_backup.py
new file mode 100644
index 0000000000..7021c41b39
--- /dev/null
+++ b/setup_backup.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import subprocess
+import sys
+
+from setuptools import Extension, find_packages, setup
+from torch.utils import cpp_extension
+from torch.utils.cpp_extension import CUDAExtension
+
+# Check if we're on Windows
+is_windows = sys.platform == "win32"
+
+if sys.version_info < (3, 6):
+    sys.exit("Sorry, Python >= 3.6 is required for fairseq.")
+
+
+def write_version_py():
+    with open(os.path.join("fairseq", "version.txt")) as f:
+        version = f.read().strip()
+
+    # write version info to fairseq/version.py
+    with open(os.path.join("fairseq", "version.py"), "w") as f:
+        f.write('__version__ = "{}"\n'.format(version))
+    return version
+
+
+version = write_version_py()
+
+
+with open("README.md") as f:
+    readme = f.read()
+
+
+if sys.platform == "darwin":
+    extra_compile_args = ["-stdlib=libc++", "-O3"]
+else:
+    extra_compile_args = ["-std=c++11", "-O3"]
+
+# Add Windows-specific configurations
+cuda_lib_path = None
+cuda_include_path = None
+
+if "CUDA_HOME" in os.environ:
+    cuda_home = os.environ.get("CUDA_HOME")
+    if is_windows:
+        # On Windows, some libraries are in bin instead of lib/x64
+        cuda_lib_path = [os.path.join(cuda_home, "lib", "x64"),
+                         os.path.join(cuda_home, "bin")]
+        cuda_include_path = [os.path.join(cuda_home, "include")]
+
+
+class NumpyExtension(Extension):
+    """Source: https://stackoverflow.com/a/54128391"""
+
+    def __init__(self, *args, **kwargs):
+        self.__include_dirs = []
+        super().__init__(*args, **kwargs)
+
+    @property
+    def include_dirs(self):
+        import numpy
+
+        return self.__include_dirs + [numpy.get_include()]
+
+    @include_dirs.setter
+    def include_dirs(self, dirs):
+        self.__include_dirs = dirs
+
+
+extensions = [
+    Extension(
+        "fairseq.libbleu",
+        sources=[
+            "fairseq/clib/libbleu/libbleu.cpp",
+            "fairseq/clib/libbleu/module.cpp",
+        ],
+        extra_compile_args=extra_compile_args,
+    ),
+    NumpyExtension(
+        "fairseq.data.data_utils_fast",
+        sources=["fairseq/data/data_utils_fast.pyx"],
+        language="c++",
+        extra_compile_args=extra_compile_args,
+    ),
+    NumpyExtension(
+        "fairseq.data.token_block_utils_fast",
+        sources=["fairseq/data/token_block_utils_fast.pyx"],
+        language="c++",
+        extra_compile_args=extra_compile_args,
+    ),
+]
+
+
+extensions.extend(
+    [
+        cpp_extension.CppExtension(
+            "fairseq.libbase",
+            sources=[
+                "fairseq/clib/libbase/balanced_assignment.cpp",
+            ],
+        ),
+        cpp_extension.CppExtension(
+            "fairseq.libnat",
+            sources=[
+                "fairseq/clib/libnat/edit_dist.cpp",
+            ],
+        ),
+        cpp_extension.CppExtension(
+            "alignment_train_cpu_binding",
+            sources=[
+                "examples/operators/alignment_train_cpu.cpp",
+            ],
+        ),
+    ]
+)
+
+# Add CUDA extensions if CUDA is available
+if "CUDA_HOME" in os.environ:
+    # Configure CUDA extension settings based on platform
+    cuda_extension_args = {}
+
+    # Special handling for Windows
+    if is_windows:
+        cuda_home = os.environ.get("CUDA_HOME")
+        # Manual specification of libraries and include directories
+        cuda_extension_args = {
+            'library_dirs': [
+                os.path.join(cuda_home, 'lib', 'x64'),
+                os.path.join(cuda_home, 'bin'),
+            ],
+            'libraries': ['cudart'],
+            'define_macros': [('TORCH_EXTENSION_NAME', 'fairseq_cuda_extension')],
+        }
+
+    # Add CUDA extensions with platform-specific settings
+    extensions.extend([
+        cpp_extension.CUDAExtension(
+            "fairseq.libnat_cuda",
+            sources=[
+                "fairseq/clib/libnat_cuda/edit_dist.cu",
+                "fairseq/clib/libnat_cuda/binding.cpp",
+            ],
+            **cuda_extension_args
+        ),
+        cpp_extension.CUDAExtension(
+            "fairseq.ngram_repeat_block_cuda",
+            sources=[
+                "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
"fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu", + ], + **cuda_extension_args + ), + cpp_extension.CUDAExtension( + "alignment_train_cuda_binding", + sources=[ + "examples/operators/alignment_train_kernel.cu", + "examples/operators/alignment_train_cuda.cpp", + ], + **cuda_extension_args + ), + ]) + +# Customize build_ext class for Windows +if is_windows: + from torch.utils.cpp_extension import BuildExtension + + class CustomBuildExtension(BuildExtension): + def build_extensions(self): + # Define specific compiler flags for Windows MSVC + for extension in self.extensions: + if hasattr(extension, 'sources') and any(source.endswith('.cu') for source in extension.sources): + self.compiler.compiler_so.append('/EHsc') + self.compiler.compiler_so.append('/MD') + + BuildExtension.build_extensions(self) + + cmdclass = {"build_ext": CustomBuildExtension} +else: + cmdclass = {"build_ext": cpp_extension.BuildExtension} + +if "READTHEDOCS" in os.environ: + # don't build extensions when generating docs + extensions = [] + if "build_ext" in cmdclass: + del cmdclass["build_ext"] + + # use CPU build of PyTorch + dependency_links = [ + "https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.1.0%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl" + ] +else: + dependency_links = [] + + +if "clean" in sys.argv[1:]: + # Source: https://bit.ly/2NLVsgE + print("deleting Cython files...") + + subprocess.run( + ["rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd"], + shell=True, + ) + + +extra_packages = [] +if os.path.exists(os.path.join("fairseq", "model_parallel", "megatron", "mpu")): + extra_packages.append("fairseq.model_parallel.megatron.mpu") + + +def do_setup(package_data): + setup( + name="fairseq", + version=version, + description="Facebook AI Research Sequence-to-Sequence Toolkit", + url="https://github.com/pytorch/fairseq", + classifiers=[ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + long_description=readme, + long_description_content_type="text/markdown", + install_requires=[ + "cffi", + "cython", + "hydra-core>=1.3.2", + "omegaconf>=2.1.0", + "numpy>=1.21.3", + "regex", + "sacrebleu>=1.4.12", + "torch>=2.1.0", + "tqdm", + "bitarray", + "torchaudio>=2.1.0", + "scikit-learn", + "packaging", + ], + extras_require={ + "dev": ["flake8", "pytest", "black==22.3.0"], + "docs": ["sphinx", "sphinx-argparse"], + }, + dependency_links=dependency_links, + packages=find_packages( + exclude=[ + "examples", + "examples.*", + "scripts", + "scripts.*", + "tests", + "tests.*", + ] + ) + + extra_packages, + package_data=package_data, + ext_modules=extensions, + test_suite="tests", + entry_points={ + "console_scripts": [ + "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", + "fairseq-generate = fairseq_cli.generate:cli_main", + "fairseq-hydra-train = fairseq_cli.hydra_train:cli_main", + "fairseq-interactive = fairseq_cli.interactive:cli_main", + "fairseq-preprocess = fairseq_cli.preprocess:cli_main", + "fairseq-score = fairseq_cli.score:cli_main", + "fairseq-train = fairseq_cli.train:cli_main", + "fairseq-validate = fairseq_cli.validate:cli_main", + ], + }, + cmdclass=cmdclass, + zip_safe=False, + ) + + +def get_files(path, relative_to="fairseq"): + all_files = [] + for root, _dirs, files in os.walk(path, followlinks=True): + root = os.path.relpath(root, relative_to) + 
+        for file in files:
+            if file.endswith(".pyc"):
+                continue
+            all_files.append(os.path.join(root, file))
+    return all_files
+
+
+if __name__ == "__main__":
+    try:
+        # symlink examples into fairseq package so package_data accepts them
+        fairseq_examples = os.path.join("fairseq", "examples")
+        if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples):
+            os.symlink(os.path.join("..", "examples"), fairseq_examples)
+
+        package_data = {
+            "fairseq": (
+                get_files(fairseq_examples)
+                + get_files(os.path.join("fairseq", "config"))
+            )
+        }
+        do_setup(package_data)
+    finally:
+        if "build_ext" not in sys.argv[1:] and os.path.islink(fairseq_examples):
+            os.unlink(fairseq_examples)
diff --git a/windows_install.bat b/windows_install.bat
new file mode 100644
index 0000000000..b03f0f9f54
--- /dev/null
+++ b/windows_install.bat
@@ -0,0 +1,16 @@
+@echo off
+echo Installing fairseq with CUDA support on Windows
+
+REM Set CUDA environment variables
+set CUDA_HOME=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4
+set CUDA_PATH=%CUDA_HOME%
+set PATH=%CUDA_HOME%\bin;%PATH%
+
+REM Print environment for debugging
+echo CUDA_HOME: %CUDA_HOME%
+echo CUDA_PATH: %CUDA_PATH%
+
+REM Install fairseq
+pip install -e .
+
+pause
\ No newline at end of file
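Note: the "safe globals handling" mentioned in the 0.13.0 release notes refers to PyTorch 2.6's `weights_only=True` default for `torch.load`. For trusted fairseq checkpoints, the non-tensor classes they pickle can be allowlisted rather than disabling the check everywhere. A sketch; the checkpoint path is hypothetical, and only do this for checkpoints you trust:

```python
# Sketch: allowlist the non-tensor classes fairseq checkpoints pickle so they
# load under PyTorch 2.6's weights_only=True default. The checkpoint path is
# hypothetical; apply only to checkpoints from a trusted source.
from argparse import Namespace

import torch

if hasattr(torch.serialization, "add_safe_globals"):  # PyTorch >= 2.4
    torch.serialization.add_safe_globals([Namespace])

try:
    state = torch.load("checkpoint_best.pt", map_location="cpu")
except Exception:
    # Fall back for checkpoints that pickle types beyond the allowlist,
    # e.g. omegaconf containers stored by older fairseq versions.
    state = torch.load("checkpoint_best.pt", map_location="cpu", weights_only=False)
```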