Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo2riva/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_args(argv):
parser.add_argument("--validate", action="store_true", help="Validate using schemas")
parser.add_argument("--runtime-check", action="store_true", help="Runtime check of exported net result")
parser.add_argument("--schema", default=None, help="Schema file to use for validation")
parser.add_argument("--format", default=None, help="Force specific export format: ONNX|TS|CKPT|NEMO")
parser.add_argument("--format", default='default', help="Force specific export format: ONNX|TS|CKPT|NEMO")
parser.add_argument("--verbose", default=None, help="Verbose level for logging")
parser.add_argument("--key", default=None, help="Encryption key or file, default is None")
parser.add_argument(
Expand Down
5 changes: 3 additions & 2 deletions nemo2riva/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ def create_artifact(reg, key, do_encrypt, **af_dict):
return af


def get_artifacts(restore_path: str, model=None, passphrase=None, **patch_kwargs):
def get_artifacts(restore_path: str, model=None, passphrase=None, format='default', **patch_kwargs):
artifacts = retrieve_artifacts_as_dict(obj=model, restore_path=restore_path)

# NOTE: when servicemaker calls into get_artifacts, model is always None so this code section
# is never run.
# check if this model has one or more patches to apply, if yes go ahead and run it
if model is not None and _HAVE_PATCHES and model.__class__.__name__ in patches:
for patch in patches[model.__class__.__name__]:
# Apply patches for the given format.
for patch in patches[model.__class__.__name__].get(format, []):
patch(model, artifacts, **patch_kwargs)
elif model is not None and not _HAVE_PATCHES:
logging.error(
Expand Down
8 changes: 5 additions & 3 deletions nemo2riva/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import logging
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from lightning.pytorch import Trainer

from nemo2riva.artifacts import get_artifacts
from nemo2riva.cookbook import export_model, save_archive
Expand Down Expand Up @@ -64,7 +64,7 @@ def Nemo2Riva(args):
key = None
if args.key is not None:
try:
with open(args.key, read_mode) as f:
with open(args.key, 'r') as f:
key = f.read()
except Exception:
# literal key
Expand All @@ -80,7 +80,9 @@ def Nemo2Riva(args):
patch_kwargs = {"import_config" : cfg}
if args.export_subnet:
patch_kwargs['export_subnet'] = args.export_subnet
artifacts, manifest = get_artifacts(restore_path=nemo_in, model=model, passphrase=key, **patch_kwargs)
artifacts, manifest = get_artifacts(
restore_path=nemo_in, model=model, passphrase=key, format=args.format, **patch_kwargs
)

for export_cfg in cfg.exports:
subnet = get_subnet(model, export_cfg.export_subnet)
Expand Down
11 changes: 11 additions & 0 deletions nemo2riva/cookbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def export_model(model, cfg, args, artifacts, metadata):
format_meta = {"has_pytorch_checkpoint": True, "runtime": "PyTorch"}
elif cfg.export_format == "NEMO":
format_meta = {"has_pytorch_checkpoint": True, "runtime": "Python"}
elif cfg.export_format == "STATE":
format_meta = {"has_pytorch_checkpoint": False, "runtime": "Python"}
# TODO: use submodel sections
metadata.update(format_meta)
runtime = format_meta["runtime"]
Expand Down Expand Up @@ -140,6 +142,15 @@ def export_model(model, cfg, args, artifacts, metadata):

elif cfg.export_format == "NEMO":
model.save_to(export_file)
elif cfg.export_format == "STATE":
if not isinstance(model, Exportable):
logging.error("Your NeMo model class ({}) is not Exportable.".format(metadata['obj_cls']))
sys.exit(1)
model.freeze()
model_params = model.state_dict()
torch.save(model_params, export_file)



# Add exported file to the artifact registry

Expand Down
30 changes: 25 additions & 5 deletions nemo2riva/patches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,33 @@

from nemo2riva.patches.ctc import set_decoder_num_classes
from nemo2riva.patches.ctc_bpe import bpe_check_inputs_and_version
from nemo2riva.patches.aed_canary import config_for_trtllm
from nemo2riva.patches.mtencdec import change_tokenizer_names
from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning

patches = {
"EncDecCTCModel": [set_decoder_num_classes],
"EncDecCTCModelBPE": [bpe_check_inputs_and_version],
"MTEncDecModel": [change_tokenizer_names],
"FastPitchModel": [generate_vocab_mapping, fastpitch_model_versioning],
"RadTTSModel": [generate_vocab_mapping, radtts_model_versioning],
'EncDecCTCModel': {
'default': [set_decoder_num_classes],
'onnx': [set_decoder_num_classes],
},
'EncDecCTCModelBPE': {
'default': [bpe_check_inputs_and_version],
'onnx': [bpe_check_inputs_and_version],
},
'EncDecMultiTaskModel': {
'default': [config_for_trtllm],
'nemo': [],
},
'MTEncDecModel': {
'default': [change_tokenizer_names],
'onnx': [change_tokenizer_names],
},
'FastPitchModel': {
'default': [generate_vocab_mapping, fastpitch_model_versioning],
'onnx': [generate_vocab_mapping, fastpitch_model_versioning],
},
'RadTTSModel': {
'default': [generate_vocab_mapping, radtts_model_versioning],
'ts': [generate_vocab_mapping, radtts_model_versioning],
},
}
74 changes: 74 additions & 0 deletions nemo2riva/patches/aed_canary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

import yaml
import logging


def config_for_trtllm(model, artifacts, **kwargs):
    """Rewrite the ``model_config.yaml`` artifact into the layout used for ``trtllm.canary``.

    Only models whose class is named ``EncDecMultiTaskModel`` are touched; for any
    other model the function returns without modifying ``artifacts``.

    Args:
        model: The model being exported; its tokenizer is read through
            ``make_vocabulary_file``.
        artifacts: Artifact registry. ``artifacts['model_config.yaml']['content']``
            is parsed as YAML and replaced in place with the reduced config.
        **kwargs: Unused; accepted so the function can be called like other patches.

    Raises:
        KeyError: If the model config lacks any of the required top-level keys
            (or ``transf_decoder`` / ``transf_encoder`` and the nested keys read below).
    """
    if model.__class__.__name__ != 'EncDecMultiTaskModel':
        return

    model_config = yaml.safe_load(artifacts['model_config.yaml']['content'])

    # NOTE: 'target' is required in the source config even though its value is
    # overwritten with 'trtllm.canary' further down.
    keys_required = [
        'beam_search',
        'encoder',
        'head',
        'model_defaults',
        'prompt_format',
        'sample_rate',
        'target',
        'preprocessor',
    ]

    # Derive 'beam_search' from the 'decoding' section when it is not given explicitly,
    # falling back to fixed defaults if 'decoding' has no 'beam' entry.
    if 'beam_search' not in model_config and 'decoding' in model_config:
        model_config['beam_search'] = model_config['decoding'].get(
            'beam', {'beam_size': 1, 'len_pen': 0.0, 'max_generation_delta': 50}
        )

    # Report every missing key at once instead of failing on the first one.
    missing = [k for k in keys_required if k not in model_config]
    if missing:
        raise KeyError(f"model_config.yaml is missing required keys: {missing}")

    config = {k: model_config[k] for k in keys_required}
    config['decoder'] = {
        'transf_decoder': model_config['transf_decoder'],
        'transf_encoder': model_config['transf_encoder'],
        'vocabulary': make_vocabulary_file(model, artifacts),
        'num_classes': model_config['head']['num_classes'],
        'feat_in': model_config['model_defaults']['asr_enc_hidden'],
        'n_layers': model_config['transf_decoder']['config_dict']['num_layers'],
    }
    config['target'] = 'trtllm.canary'

    # safe_dump with an explicit encoding yields bytes rather than str.
    artifacts['model_config.yaml']['content'] = yaml.safe_dump(config, encoding='utf-8')


def make_vocabulary_file(model, artifacts, **kwargs):
    """Build a vocabulary description from the tokenizer of an ``EncDecMultiTaskModel``.

    Args:
        model: Model whose ``tokenizer`` is inspected. Only instances whose class
            is named ``EncDecMultiTaskModel`` are processed.
        artifacts: Unused; kept so the signature matches ``config_for_trtllm``.
        **kwargs: Unused.

    Returns:
        A dict with the keys ``tokens`` (one ``{token_id: token}`` map per language),
        ``offsets``, ``size`` and whichever of ``bos_id`` / ``eos_id`` /
        ``nospeech_id`` / ``pad_id`` could be read from the tokenizer.
        ``None`` when the model is of any other class.
    """
    if model.__class__.__name__ != 'EncDecMultiTaskModel':
        return None

    tokenizer = model.tokenizer
    tokenizer_vocab = {
        'tokens': {},
        'offsets': tokenizer.token_id_offset,
    }
    tokens_by_lang = tokenizer_vocab['tokens']
    for lang in tokenizer.langs:
        tokens_by_lang[lang] = {}

    vocab_size = tokenizer.vocab_size
    tokenizer_vocab['size'] = vocab_size

    # Special ids are optional: a tokenizer that cannot provide one only triggers
    # a warning and the key is left out of the result.
    for id_name in ('bos_id', 'eos_id', 'nospeech_id', 'pad_id'):
        try:
            tokenizer_vocab[id_name] = getattr(tokenizer, id_name)
        except Exception:
            logging.warning(f"Tokenizer is missing {id_name}. Could affect accuracy")

    # File every token id under the language the tokenizer reports for it.
    for t_id in range(vocab_size):
        lang = tokenizer.ids_to_lang([t_id])
        tokens_by_lang[lang][t_id] = tokenizer.ids_to_tokens([t_id])[0]

    return tokenizer_vocab

96 changes: 54 additions & 42 deletions nemo2riva/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

schema_dict = None

supported_formats = ["ONNX", "CKPT", "TS", "NEMO"]
supported_formats = ["ONNX", "CKPT", "TS", "NEMO", "PYTORCH", "STATE"]


@dataclass
Expand Down Expand Up @@ -46,17 +46,29 @@ class ImportConfig:
def get_export_config(export_obj, args):
conf = ExportConfig()
need_autocast = False
if export_obj:
if export_obj is not None:
conf.export_file = list(export_obj)[0]
attribs = export_obj[conf.export_file]
conf.export_subnet = attribs.get('export_subnet', None)
conf.is_onnx=attribs.get('onnx', False)

if not conf.is_onnx:
conf.states_only = attribs.get('states_only', False)
conf.is_torch = attribs.get('torch', False)

if conf.export_file.endswith('.onnx'):
conf.export_format = "ONNX"
elif conf.export_file.endswith('.ts'):
conf.export_format = "TS"
elif conf.export_file.endswith('.nemo'):
conf.export_format = "NEMO"
elif conf.is_torch:
if conf.states_only:
conf.export_format = "STATE"
else:
conf.export_format = "PYTORCH"
else:
conf.export_format = "CKPT"
attribs = export_obj[conf.export_file]
conf.autocast = attribs.get('autocast', False)
need_autocast = conf.autocast

Expand All @@ -66,8 +78,6 @@ def get_export_config(export_obj, args):
if conf.encryption and args.key is None:
raise Exception(f"{conf.export_file} requires encryption and no key was given")

conf.export_subnet = attribs.get('export_subnet', None)

if args.export_subnet:
if conf.export_subnet:
raise Exception("Can't combine schema's export_subnet and export-subnet argument!")
Expand All @@ -83,7 +93,8 @@ def get_export_config(export_obj, args):
conf.max_dim = args.max_dim

# Optional export format override
if args.format is not None:
if args.format != 'default' and export_obj is None:
# When export_obj is None, the root of the network is exported and the format needs to be overridden.
conf.export_format = args.format.upper()
conf.export_file = os.path.splitext(conf.export_file)[0] + "." + conf.export_format.lower()

Expand Down Expand Up @@ -141,27 +152,28 @@ def get_subnet(model, subnet):


def load_schemas():
spec_root = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(os.path.abspath(__file__))
# Get schema path.
direc = os.path.join(spec_root, "validation_schemas")
ext = '.yaml'
schema_dir = os.path.join(root_dir, 'validation_schemas')

global schema_dict
schema_dict = {} # Create an empty dict

# Select only .yaml files
yaml_files = [os.path.join(direc, i) for i in os.listdir(direc) if os.path.splitext(i)[1] == '.yaml']

# Iterate over your txt files
for f in yaml_files:
conf = OmegaConf.load(f)
key = ''
for meta in conf.metadata:
if 'obj_cls' in meta.keys():
key = meta['obj_cls']
schema_dict[key] = f
logging.info(f"Loaded schema file {f} for {key}")

schema_dict = OmegaConf.load(os.path.join(schema_dir, 'index.yaml'))
for key in schema_dict:
for format in schema_dict[key]:
# None means default export. None can be specified in the index YAML to avoid warning logs.
if schema_dict[key][format] is not None:
schema_dict[key][format] = os.path.join(schema_dir, schema_dict[key][format])
logging.info(f'Indexing validation schema file "{schema_dict[key][format]}" for model "{key}" [{format}]')

def get_schema_path(key, format=None):
    """Look up the validation schema registered for ``key`` and ``format``.

    ``format=None`` selects the ``'default'`` entry. Reads the module-level
    ``schema_dict``, which must already be populated.

    Returns:
        The stored entry for ``(key, format)``, or ``None`` when either the key
        or the format is not indexed.
    """
    if format is None:
        format = 'default'
    if key not in schema_dict:
        return None
    if format not in schema_dict[key]:
        return None
    return schema_dict[key][format]

def is_schema_exists(key, format=None):
    """Tell whether ``schema_dict`` has an entry for ``key`` under ``format``.

    ``format=None`` is treated as ``'default'``. Reads the module-level
    ``schema_dict``, which must already be populated.
    """
    if format is None:
        format = 'default'
    if key not in schema_dict:
        return False
    return format in schema_dict[key]

def get_exports(schema_path):
# Load the schema.
Expand All @@ -183,33 +195,33 @@ def get_exports(schema_path):


def get_import_config(model, args):

# Explicit schema name passed in args
schema = args.schema

if schema_dict is None:
load_schemas()

# create config object with default values (ONNX)
conf = ImportConfig()
key = get_schema_key(model)

#
# Now check if there is a schema defined for target model class
#
if schema is None and key in schema_dict:
schema = schema_dict[key]
logging.info("Found validation schema for {} at {}".format(key, schema))
if args.schema is not None:
# Explicit schema name passed in args
schema = args.schema
else:
key = get_schema_key(model)
format = args.format
if is_schema_exists(key, format=format):
schema = get_schema_path(key, format=format)
logging.info(f'Using validation schema "{schema}" for "{key}" [{format}]')
else:
logging.warning(
f'Validation schema not found for "{key}" [{format}]\n'
+ 'Riva does not guarantee support for this network and likely will not work with it.'
)
schema = None

if schema is None:
logging.warning(
"Validation schema not found for {}.\n".format(key)
+ "That means Riva does not yet support a pipeline for this network and likely will not work with it."
)
exports = [None]
else:
exports = get_exports(schema)

# create config object with default values (ONNX)
conf = ImportConfig()

conf.exports = [get_export_config(export_obj, args) for export_obj in exports]

conf.validation_schema = schema
Expand Down
Empty file added nemo2riva/scripts/__init__.py
Empty file.
Loading