diff --git a/nemo2riva/args.py b/nemo2riva/args.py index bdee602..81a8863 100644 --- a/nemo2riva/args.py +++ b/nemo2riva/args.py @@ -15,7 +15,7 @@ def get_args(argv): parser.add_argument("--validate", action="store_true", help="Validate using schemas") parser.add_argument("--runtime-check", action="store_true", help="Runtime check of exported net result") parser.add_argument("--schema", default=None, help="Schema file to use for validation") - parser.add_argument("--format", default=None, help="Force specific export format: ONNX|TS|CKPT|NEMO") + parser.add_argument("--format", default='default', help="Force specific export format: ONNX|TS|CKPT|NEMO") parser.add_argument("--verbose", default=None, help="Verbose level for logging") parser.add_argument("--key", default=None, help="Encryption key or file, default is None") parser.add_argument( diff --git a/nemo2riva/artifacts.py b/nemo2riva/artifacts.py index 1e750db..2f53233 100644 --- a/nemo2riva/artifacts.py +++ b/nemo2riva/artifacts.py @@ -93,14 +93,15 @@ def create_artifact(reg, key, do_encrypt, **af_dict): return af -def get_artifacts(restore_path: str, model=None, passphrase=None, **patch_kwargs): +def get_artifacts(restore_path: str, model=None, passphrase=None, format='default', **patch_kwargs): artifacts = retrieve_artifacts_as_dict(obj=model, restore_path=restore_path) # NOTE: when servicemaker calls into get_artifacts, model is always None so this code section # is never run. # check if this model has one or more patches to apply, if yes go ahead and run it if model is not None and _HAVE_PATCHES and model.__class__.__name__ in patches: - for patch in patches[model.__class__.__name__]: + # Apply patches for the given format; the lookup is case-insensitive and None selects 'default'. 
+ for patch in patches[model.__class__.__name__].get((format or 'default').lower(), []): patch(model, artifacts, **patch_kwargs) elif model is not None and not _HAVE_PATCHES: logging.error( diff --git a/nemo2riva/convert.py b/nemo2riva/convert.py index 7ff18b9..f00d88b 100644 --- a/nemo2riva/convert.py +++ b/nemo2riva/convert.py @@ -13,7 +13,7 @@ from nemo.core.config.pytorch_lightning import TrainerConfig from nemo.utils import logging from omegaconf import OmegaConf -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo2riva.artifacts import get_artifacts from nemo2riva.cookbook import export_model, save_archive @@ -64,7 +64,7 @@ def Nemo2Riva(args): key = None if args.key is not None: try: - with open(args.key, read_mode) as f: + with open(args.key, 'r') as f: key = f.read() except Exception: # literal key @@ -80,7 +80,9 @@ def Nemo2Riva(args): patch_kwargs = {"import_config" : cfg} if args.export_subnet: patch_kwargs['export_subnet'] = args.export_subnet - artifacts, manifest = get_artifacts(restore_path=nemo_in, model=model, passphrase=key, **patch_kwargs) + artifacts, manifest = get_artifacts( + restore_path=nemo_in, model=model, passphrase=key, format=args.format, **patch_kwargs + ) for export_cfg in cfg.exports: subnet = get_subnet(model, export_cfg.export_subnet) diff --git a/nemo2riva/cookbook.py b/nemo2riva/cookbook.py index a583ef1..3d5f5d3 100644 --- a/nemo2riva/cookbook.py +++ b/nemo2riva/cookbook.py @@ -48,6 +48,8 @@ def export_model(model, cfg, args, artifacts, metadata): format_meta = {"has_pytorch_checkpoint": True, "runtime": "PyTorch"} elif cfg.export_format == "NEMO": format_meta = {"has_pytorch_checkpoint": True, "runtime": "Python"} + elif cfg.export_format == "STATE": + format_meta = {"has_pytorch_checkpoint": False, "runtime": "Python"} # TODO: use submodel sections metadata.update(format_meta) runtime = format_meta["runtime"] @@ -140,6 +142,15 @@ def export_model(model, cfg, args, artifacts, metadata): elif cfg.export_format == 
"NEMO": model.save_to(export_file) + elif cfg.export_format == "STATE": + if not isinstance(model, Exportable): + logging.error("Your NeMo model class ({}) is not Exportable.".format(metadata['obj_cls'])) + sys.exit(1) + model.freeze() + model_params = model.state_dict() + torch.save(model_params, export_file) + + # Add exported file to the artifact registry diff --git a/nemo2riva/patches/__init__.py b/nemo2riva/patches/__init__.py index a70d4ce..c073d5b 100644 --- a/nemo2riva/patches/__init__.py +++ b/nemo2riva/patches/__init__.py @@ -3,13 +3,33 @@ from nemo2riva.patches.ctc import set_decoder_num_classes from nemo2riva.patches.ctc_bpe import bpe_check_inputs_and_version +from nemo2riva.patches.aed_canary import config_for_trtllm from nemo2riva.patches.mtencdec import change_tokenizer_names from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning patches = { - "EncDecCTCModel": [set_decoder_num_classes], - "EncDecCTCModelBPE": [bpe_check_inputs_and_version], - "MTEncDecModel": [change_tokenizer_names], - "FastPitchModel": [generate_vocab_mapping, fastpitch_model_versioning], - "RadTTSModel": [generate_vocab_mapping, radtts_model_versioning], + 'EncDecCTCModel': { + 'default': [set_decoder_num_classes], + 'onnx': [set_decoder_num_classes], + }, + 'EncDecCTCModelBPE': { + 'default': [bpe_check_inputs_and_version], + 'onnx': [bpe_check_inputs_and_version], + }, + 'EncDecMultiTaskModel': { + 'default': [config_for_trtllm], + 'nemo': [], + }, + 'MTEncDecModel': { + 'default': [change_tokenizer_names], + 'onnx': [change_tokenizer_names], + }, + 'FastPitchModel': { + 'default': [generate_vocab_mapping, fastpitch_model_versioning], + 'onnx': [generate_vocab_mapping, fastpitch_model_versioning], + }, + 'RadTTSModel': { + 'default': [generate_vocab_mapping, radtts_model_versioning], + 'ts': [generate_vocab_mapping, radtts_model_versioning], + }, } diff --git a/nemo2riva/patches/aed_canary.py 
b/nemo2riva/patches/aed_canary.py
new file mode 100644
index 0000000..6bf8ddf
--- /dev/null
+++ b/nemo2riva/patches/aed_canary.py
@@ -0,0 +1,83 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import logging
+
+import yaml
+
+
+def config_for_trtllm(model, artifacts, **kwargs):
+    """Rewrite the ``model_config.yaml`` artifact for the ``trtllm.canary`` target.
+
+    Does nothing unless ``model`` is an ``EncDecMultiTaskModel``. Otherwise the YAML stored in
+    ``artifacts['model_config.yaml']['content']`` is parsed, reduced to the required sections plus a
+    generated ``decoder`` section, and replaced by the UTF-8 encoded dump of the result.
+
+    Raises:
+        KeyError: if the model config lacks one of the required sections.
+    """
+    if model.__class__.__name__ != 'EncDecMultiTaskModel':
+        return
+
+    model_config = yaml.safe_load(artifacts['model_config.yaml']['content'])
+
+    keys_required = [
+        'beam_search',
+        'encoder',
+        'head',
+        'model_defaults',
+        'prompt_format',
+        'sample_rate',
+        'target',
+        'preprocessor',
+    ]
+    # Configs without `beam_search` take it from `decoding.beam`, or from fixed defaults.
+    if 'beam_search' not in model_config and 'decoding' in model_config:
+        model_config['beam_search'] = model_config['decoding'].get(
+            'beam', {'beam_size': 1, 'len_pen': 0.0, 'max_generation_delta': 50}
+        )
+    config = {k: model_config[k] for k in keys_required}
+    config['decoder'] = {
+        'transf_decoder': model_config['transf_decoder'],
+        'transf_encoder': model_config['transf_encoder'],
+        'vocabulary': make_vocabulary_file(model, artifacts),
+        'num_classes': model_config['head']['num_classes'],
+        'feat_in': model_config['model_defaults']['asr_enc_hidden'],
+        'n_layers': model_config['transf_decoder']['config_dict']['num_layers'],
+    }
+    config['target'] = 'trtllm.canary'
+
+    artifacts['model_config.yaml']['content'] = yaml.safe_dump(config, encoding='utf-8')
+
+
+def make_vocabulary_file(model, artifacts, **kwargs):
+    """Describe the tokenizer vocabulary of an ``EncDecMultiTaskModel`` as a plain dict.
+
+    Returns:
+        A dict with ``tokens`` (language -> {token id: token}), ``offsets``, ``size`` and each of
+        ``bos_id``/``eos_id``/``nospeech_id``/``pad_id`` that could be read from the tokenizer;
+        ``None`` when ``model`` is not an ``EncDecMultiTaskModel``.
+    """
+    if model.__class__.__name__ != 'EncDecMultiTaskModel':
+        return None
+
+    tokenizer = model.tokenizer
+    tokenizer_vocab = {
+        'tokens': {lang: {} for lang in tokenizer.langs},
+        'offsets': tokenizer.token_id_offset,
+        'size': tokenizer.vocab_size,
+    }
+
+    # Special ids are optional: keep the ones that can be read and warn about the others. The broad
+    # `except` is deliberate so that a failing tokenizer property never aborts the export.
+    for id_name in ('bos_id', 'eos_id', 'nospeech_id', 'pad_id'):
+        try:
+            tokenizer_vocab[id_name] = getattr(tokenizer, id_name)
+        except Exception:
+            logging.warning(f"Tokenizer is missing {id_name}. Could affect accuracy")
+
+    for t_id in range(tokenizer.vocab_size):
+        lang = tokenizer.ids_to_lang([t_id])
+        tokenizer_vocab['tokens'][lang][t_id] = tokenizer.ids_to_tokens([t_id])[0]
+
+    return tokenizer_vocab
diff --git a/nemo2riva/schema.py b/nemo2riva/schema.py index f881c96..c0eb99f 100644 --- a/nemo2riva/schema.py +++ b/nemo2riva/schema.py @@ -17,7 +17,7 @@ schema_dict = None -supported_formats = ["ONNX", "CKPT", "TS", "NEMO"] +supported_formats = ["ONNX", "CKPT", "TS", "NEMO", "PYTORCH", "STATE"] @dataclass @@ -46,17 +46,29 @@ class ImportConfig: def get_export_config(export_obj, args): conf = ExportConfig() need_autocast = False - if export_obj: + if export_obj is not None: conf.export_file = list(export_obj)[0] + attribs = export_obj[conf.export_file] + conf.export_subnet = attribs.get('export_subnet', None) + conf.is_onnx=attribs.get('onnx', False) + + if not conf.is_onnx: + conf.states_only = attribs.get('states_only', False) + conf.is_torch = attribs.get('torch', False) + if conf.export_file.endswith('.onnx'): conf.export_format = "ONNX" elif conf.export_file.endswith('.ts'): conf.export_format = "TS" elif conf.export_file.endswith('.nemo'): conf.export_format = "NEMO" + elif conf.is_torch: + if conf.states_only: + conf.export_format = "STATE" + else: + conf.export_format = "PYTORCH" else: conf.export_format = "CKPT" - attribs = export_obj[conf.export_file] conf.autocast = attribs.get('autocast', False) need_autocast = conf.autocast @@ -66,8 +78,6 
@@ def get_export_config(export_obj, args): if conf.encryption and args.key is None: raise Exception(f"{conf.export_file} requires encryption and no key was given") - conf.export_subnet = attribs.get('export_subnet', None) - if args.export_subnet: if conf.export_subnet: raise Exception("Can't combine schema's export_subnet and export-subnet argument!") @@ -83,7 +93,8 @@ def get_export_config(export_obj, args): conf.max_dim = args.max_dim # Optional export format override - if args.format is not None: + if (args.format or 'default').lower() != 'default' and export_obj is None: + # When export_obj is None, the root of the network is exported and the format needs to be overridden. conf.export_format = args.format.upper() conf.export_file = os.path.splitext(conf.export_file)[0] + "." + conf.export_format.lower() @@ -141,27 +152,28 @@ def get_subnet(model, subnet): def load_schemas(): - spec_root = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.dirname(os.path.abspath(__file__)) # Get schema path. - direc = os.path.join(spec_root, "validation_schemas") - ext = '.yaml' + schema_dir = os.path.join(root_dir, 'validation_schemas') global schema_dict - schema_dict = {} # Create an empty dict - - # Select only .yaml files - yaml_files = [os.path.join(direc, i) for i in os.listdir(direc) if os.path.splitext(i)[1] == '.yaml'] - - # Iterate over your txt files - for f in yaml_files: - conf = OmegaConf.load(f) - key = '' - for meta in conf.metadata: - if 'obj_cls' in meta.keys(): - key = meta['obj_cls'] - schema_dict[key] = f - logging.info(f"Loaded schema file {f} for {key}") - + schema_dict = OmegaConf.load(os.path.join(schema_dir, 'index.yaml')) + for key in schema_dict: + for fmt in schema_dict[key]: + # None means default export. None can be specified in the index YAML to avoid warning logs. 
+ if schema_dict[key][fmt] is not None: + schema_dict[key][fmt] = os.path.join(schema_dir, schema_dict[key][fmt]) + logging.info(f'Indexing validation schema file "{schema_dict[key][fmt]}" for model "{key}" [{fmt}]') + +def get_schema_path(key, format=None): + format = 'default' if format is None else format.lower()  # index keys are lowercase + if key in schema_dict and format in schema_dict[key]: + return schema_dict[key][format] + return None + +def is_schema_exists(key, format=None): + format = 'default' if format is None else format.lower()  # index keys are lowercase + return key in schema_dict and format in schema_dict[key] def get_exports(schema_path): # Load the schema. @@ -183,33 +195,33 @@ def get_exports(schema_path): def get_import_config(model, args): - - # Explicit schema name passed in args - schema = args.schema - if schema_dict is None: load_schemas() - # create config object with default values (ONNX) - conf = ImportConfig() - key = get_schema_key(model) - - # - # Now check if there is a schema defined for target model class - # - if schema is None and key in schema_dict: - schema = schema_dict[key] - logging.info("Found validation schema for {} at {}".format(key, schema)) + if args.schema is not None: + # Explicit schema name passed in args + schema = args.schema + else: + key = get_schema_key(model) + fmt = args.format + if is_schema_exists(key, format=fmt): + schema = get_schema_path(key, format=fmt) + logging.info(f'Using validation schema "{schema}" for "{key}" [{fmt}]') + else: + logging.warning( + f'Validation schema not found for "{key}" [{fmt}]\n' + + 'Riva does not guarantee support for this network and likely will not work with it.' + ) + schema = None if schema is None: - logging.warning( - "Validation schema not found for {}.\n".format(key) - + "That means Riva does not yet support a pipeline for this network and likely will not work with it." 
- ) exports = [None] else: exports = get_exports(schema) + # create config object with default values (ONNX) + conf = ImportConfig() + conf.exports = [get_export_config(export_obj, args) for export_obj in exports] conf.validation_schema = schema diff --git a/nemo2riva/scripts/__init__.py b/nemo2riva/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nemo2riva/validation_schemas/asr-scr-exported-aedmodel.yaml b/nemo2riva/validation_schemas/asr-scr-exported-aedmodel.yaml new file mode 100644 index 0000000..10685e1 --- /dev/null +++ b/nemo2riva/validation_schemas/asr-scr-exported-aedmodel.yaml @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +# Define required metadata fields expected in the archive (optional). +metadata: + - obj_cls: nemo.collections.asr.models.EncDecMultiTaskModel + + +# Define list of files that are expected (optional). +artifact_properties: + # List of files. + - model_config.yaml + - encoder.onnx: + export_subnet: encoder + onnx: True + - decoder.pt: + export_subnet: transf_decoder + states_only: True + torch: True + onnx: False + - log_softmax.pt: + export_subnet: log_softmax + states_only: True + torch: True + onnx: False + - encoder_decoder_proj.pt: + export_subnet: encoder_decoder_proj + torch: False + onnx: False + + +# Define list of files with expected content (optional). +# Functionality limited to yaml files (e.g. model_config.yaml). +artifact_content: + # List of files. + - model_config.yaml: + # List of sections.subsections. ... that are required. 
+ # (Optional `: True` instructs to check that the file indicated as the leaf is present in the archive) + - transf_decoder + - transf_encoder + - vocabulary + - num_classes + - feat_in + - n_layers + - target + - beam_search + - encoder + - head + - model_defaults + - prompt_format + - sample_rate + - target + - preprocessor diff --git a/nemo2riva/validation_schemas/index.yaml b/nemo2riva/validation_schemas/index.yaml new file mode 100644 index 0000000..504d423 --- /dev/null +++ b/nemo2riva/validation_schemas/index.yaml @@ -0,0 +1,56 @@ +# Specify the conversion schema for each NeMo model class +nemo.collections.asr.models.EncDecMultiTaskModel: + default: asr-scr-exported-aedmodel.yaml + nemo: null # Default export + +nemo.collections.asr.models.classification_models.EncDecClassificationModel: + default: asr-scr-exported-encdecclsmodel.yaml + onnx: asr-scr-exported-encdecclsmodel.yaml + +nemo.collections.asr.models.EncDecCTCModel: + default: asr-stt-exported-encdecctcmodel.yaml + onnx: asr-stt-exported-encdecctcmodel.yaml + +nemo.collections.asr.models.EncDecCTCModelBPE: + default: asr-stt-exported-encdectcmodelbpe.yaml + onnx: asr-stt-exported-encdectcmodelbpe.yaml + +nemo.collections.nlp.models.IntentSlotClassificationModel: + default: nlp-isc-exported-bert.yaml + onnx: nlp-isc-exported-bert.yaml + +nemo.collections.nlp.models.MTEncDecModel: + default: nlp-mt-exported-encdecmtmodel.yaml + onnx: nlp-mt-exported-encdecmtmodel.yaml + +nemo.collections.nlp.models.MegatronNMTModel: + default: nlp-mt-exported-megatronnmtmodel.yaml + onnx: nlp-mt-exported-megatronnmtmodel.yaml + +nemo.collections.nlp.models.PunctuationCapitalizationModel: + default: nlp-pc-exported-bert.yaml + onnx: nlp-pc-exported-bert.yaml + +nemo.collections.nlp.models.QAModel: + default: nlp-qa-exported-bert.yaml + onnx: nlp-qa-exported-bert.yaml + +nemo.collections.nlp.models.TextClassificationModel: + default: nlp-tc-exported-bert.yaml + onnx: nlp-tc-exported-bert.yaml + 
+nemo.collections.nlp.models.TokenClassificationModel: + default: nlp-tkc-exported-bert.yaml + onnx: nlp-tkc-exported-bert.yaml + +nemo.collections.tts.models.FastPitchModel: + default: tts-exported-fastpitchmodel.yaml + onnx: tts-exported-fastpitchmodel.yaml + +nemo.collections.tts.models.HifiGanModel: + default: tts-exported-hifiganmodel.yaml + onnx: tts-exported-hifiganmodel.yaml + +nemo.collections.tts.models.RadTTSModel: + default: tts-exported-radttsmodel.yaml + ts: tts-exported-radttsmodel.yaml \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index dffdec7..f369637 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,11 @@ # SPDX-License-Identifier: MIT nemo_toolkit>=1.6.0 +torch>=2.5.0 nvidia-eff>=0.6.4 nvidia-eff-tao-encryption>=0.1.8 nvidia-pyindex==1.0.6 -onnx==1.14.1 -onnxruntime==1.16.3 -onnxruntime-gpu==1.16.3 -onnx-graphsurgeon==0.3.27 +onnx>=1.17.0 +onnxruntime>=1.17.0 +onnxruntime-gpu>=1.17.0 +onnx-graphsurgeon>=0.3.27