diff --git a/nemo2riva/cookbook.py b/nemo2riva/cookbook.py index a583ef1..3d5f5d3 100644 --- a/nemo2riva/cookbook.py +++ b/nemo2riva/cookbook.py @@ -48,6 +48,8 @@ def export_model(model, cfg, args, artifacts, metadata): format_meta = {"has_pytorch_checkpoint": True, "runtime": "PyTorch"} elif cfg.export_format == "NEMO": format_meta = {"has_pytorch_checkpoint": True, "runtime": "Python"} + elif cfg.export_format == "STATE": + format_meta = {"has_pytorch_checkpoint": False, "runtime": "Python"} # TODO: use submodel sections metadata.update(format_meta) runtime = format_meta["runtime"] @@ -140,6 +142,15 @@ def export_model(model, cfg, args, artifacts, metadata): elif cfg.export_format == "NEMO": model.save_to(export_file) + elif cfg.export_format == "STATE": + if not isinstance(model, Exportable): + logging.error("Your NeMo model class ({}) is not Exportable.".format(metadata['obj_cls'])) + sys.exit(1) + model.freeze() + model_params = model.state_dict() + torch.save(model_params, export_file) + + # Add exported file to the artifact registry diff --git a/nemo2riva/patches/__init__.py b/nemo2riva/patches/__init__.py index a70d4ce..3d5a747 100644 --- a/nemo2riva/patches/__init__.py +++ b/nemo2riva/patches/__init__.py @@ -3,12 +3,14 @@ from nemo2riva.patches.ctc import set_decoder_num_classes from nemo2riva.patches.ctc_bpe import bpe_check_inputs_and_version +from nemo2riva.patches.aed_canary import config_for_trtllm from nemo2riva.patches.mtencdec import change_tokenizer_names from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning patches = { "EncDecCTCModel": [set_decoder_num_classes], "EncDecCTCModelBPE": [bpe_check_inputs_and_version], + "EncDecMultiTaskModel": [config_for_trtllm], "MTEncDecModel": [change_tokenizer_names], "FastPitchModel": [generate_vocab_mapping, fastpitch_model_versioning], "RadTTSModel": [generate_vocab_mapping, radtts_model_versioning], diff --git a/nemo2riva/patches/aed_canary.py 
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

import logging

# Class name of the model these patches act on; any other model is left untouched.
_CANARY_MODEL_CLASS = 'EncDecMultiTaskModel'

# Beam-search settings used when the config has a `decoding` section without a `beam` entry.
_DEFAULT_BEAM_SEARCH = {'beam_size': 1, 'len_pen': 0.0, 'max_generation_delta': 50}

# Top-level config sections copied unchanged into the rewritten config.
# 'target' is intentionally absent: it is always overwritten below, so it need not exist in the input.
_PASSTHROUGH_KEYS = (
    'beam_search',
    'encoder',
    'head',
    'model_defaults',
    'prompt_format',
    'sample_rate',
    'preprocessor',
)

# Optional tokenizer attributes exported into the vocabulary when available.
_SPECIAL_TOKEN_IDS = ('bos_id', 'eos_id', 'nospeech_id', 'pad_id')


def config_for_trtllm(model, artifacts, **kwargs):
    """Rewrite ``model_config.yaml`` in *artifacts* for the ``trtllm.canary`` target.

    Does nothing (and returns ``None``) unless the model's class name is
    ``EncDecMultiTaskModel``.

    Args:
        model: model being exported; its ``tokenizer`` is read to build the
            vocabulary stored under ``decoder.vocabulary``.
        artifacts: mapping of artifact name to a dict with a ``'content'``
            entry. ``artifacts['model_config.yaml']['content']`` is parsed as
            YAML and replaced in place by the UTF-8 encoded YAML dump of the
            rewritten config.
        **kwargs: accepted and ignored.

    Raises:
        KeyError: if the model config lacks a section needed for the rewrite.
    """
    if model.__class__.__name__ != _CANARY_MODEL_CLASS:
        return None

    # Imported here so that importing this module does not require PyYAML
    # unless the patch actually runs.
    import yaml

    model_config = yaml.safe_load(artifacts['model_config.yaml']['content'])

    # Configs without an explicit `beam_search` section get it from `decoding.beam`,
    # falling back to the default beam settings when `beam` is absent.
    if 'beam_search' not in model_config and 'decoding' in model_config:
        model_config['beam_search'] = model_config['decoding'].get('beam', dict(_DEFAULT_BEAM_SEARCH))

    # Fail early with one message naming every missing section instead of a bare KeyError.
    required = _PASSTHROUGH_KEYS + ('transf_decoder', 'transf_encoder')
    missing = [key for key in required if key not in model_config]
    if missing:
        raise KeyError(f"model_config.yaml is missing sections required for TRT-LLM export: {missing}")

    transf_decoder = model_config['transf_decoder']
    config = {key: model_config[key] for key in _PASSTHROUGH_KEYS}
    config['decoder'] = {
        'transf_decoder': transf_decoder,
        'transf_encoder': model_config['transf_encoder'],
        'vocabulary': make_vocabulary_file(model, artifacts),
        'num_classes': model_config['head']['num_classes'],
        'feat_in': model_config['model_defaults']['asr_enc_hidden'],
        'n_layers': transf_decoder['config_dict']['num_layers'],
    }
    config['target'] = 'trtllm.canary'

    # With `encoding` set, safe_dump returns bytes rather than str.
    artifacts['model_config.yaml']['content'] = yaml.safe_dump(config, encoding='utf-8')
    return None


def make_vocabulary_file(model, artifacts, **kwargs):
    """Build a vocabulary description from ``model.tokenizer``.

    Args:
        model: model whose ``tokenizer`` provides ``token_id_offset``,
            ``langs``, ``vocab_size``, ``ids_to_lang`` and ``ids_to_tokens``.
        artifacts: unused; kept for signature compatibility with callers.
        **kwargs: accepted and ignored.

    Returns:
        ``None`` when the model's class name is not ``EncDecMultiTaskModel``.
        Otherwise a dict with:
          * ``'tokens'``: ``{lang: {token_id: token}}``,
          * ``'offsets'``: the tokenizer's ``token_id_offset``,
          * ``'size'``: the tokenizer's ``vocab_size``,
          * ``'bos_id'`` / ``'eos_id'`` / ``'nospeech_id'`` / ``'pad_id'``:
            present only when the tokenizer exposes them.
    """
    if model.__class__.__name__ != _CANARY_MODEL_CLASS:
        return None

    tokenizer = model.tokenizer
    vocab_size = tokenizer.vocab_size
    tokenizer_vocab = {
        'tokens': {lang: {} for lang in tokenizer.langs},
        'offsets': tokenizer.token_id_offset,
        'size': vocab_size,
    }

    for name in _SPECIAL_TOKEN_IDS:
        try:
            tokenizer_vocab[name] = getattr(tokenizer, name)
        except Exception:
            # Deliberately broad and best-effort: the id may be a computed attribute,
            # and which exception a tokenizer raises for a missing id is not known here.
            logging.warning("Tokenizer is missing %s. Could affect accuracy", name)
            logging.debug("Reading tokenizer.%s failed", name, exc_info=True)

    tokens_by_lang = tokenizer_vocab['tokens']
    for t_id in range(vocab_size):
        lang = tokenizer.ids_to_lang([t_id])
        # setdefault tolerates a language tag that is not listed in tokenizer.langs.
        tokens_by_lang.setdefault(lang, {})[t_id] = tokenizer.ids_to_tokens([t_id])[0]

    return tokenizer_vocab
args.key is None: raise Exception(f"{conf.export_file} requires encryption and no key was given") - conf.export_subnet = attribs.get('export_subnet', None) - if args.export_subnet: if conf.export_subnet: raise Exception("Can't combine schema's export_subnet and export-subnet argument!") diff --git a/nemo2riva/scripts/__init__.py b/nemo2riva/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nemo2riva/validation_schemas/asr-scr-exported-aedmodel.yaml b/nemo2riva/validation_schemas/asr-scr-exported-aedmodel.yaml new file mode 100644 index 0000000..5a6f627 --- /dev/null +++ b/nemo2riva/validation_schemas/asr-scr-exported-aedmodel.yaml @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +# Define required metadata fields expected in the archive (optional). +metadata: + - obj_cls: nemo.collections.asr.models.EncDecMultiTaskModel + + +# Define list of files that are expected (optional). +artifact_properties: + # List of files. + - model_config.yaml + - encoder.onnx: + export_subnet: encoder + onnx: True + - decoder.pt: + export_subnet: transf_decoder + states_only: True + torch: True + onnx: False + - log_softmax.pt: + export_subnet: log_softmax + states_only: True + torch: True + onnx: False + + +# Define list of files with expected content (optional). +# Functionality limited to yaml files (e.g. model_config.yaml). +artifact_content: + # List of files. + - model_config.yaml: + # List of sections.subsections. ... that are required. 
+ # (An optional `: True` requests a check that the file named by the leaf is present in the archive) + - transf_decoder + - transf_encoder + - vocabulary + - num_classes + - feat_in + - n_layers + - target + - beam_search + - encoder + - head + - model_defaults + - prompt_format + - sample_rate + - target + - preprocessor diff --git a/requirements.txt b/requirements.txt index dffdec7..f369637 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,11 @@ # SPDX-License-Identifier: MIT nemo_toolkit>=1.6.0 +torch>=2.5.0 nvidia-eff>=0.6.4 nvidia-eff-tao-encryption>=0.1.8 nvidia-pyindex==1.0.6 -onnx==1.14.1 -onnxruntime==1.16.3 -onnxruntime-gpu==1.16.3 -onnx-graphsurgeon==0.3.27 +onnx>=1.17.0 +onnxruntime>=1.17.0 +onnxruntime-gpu>=1.17.0 +onnx-graphsurgeon>=0.3.27