diff --git a/README.md b/README.md
index aa6d9d9..d107f95 100644
--- a/README.md
+++ b/README.md
@@ -131,6 +131,14 @@ For additional information and usage, run:
nemo2riva --key tlt_encode --out <path to save .riva model> --onnx-opset 14 <path to .nemo model> |
+ | MagpieTTS Decoder |
+ nemo2riva <model.ckpt path$gt; --load_ckpt --model_config <hparams_file$gt; --audio_codecpath <codec .nemo ckpt> --key tlt_encode --out magpie_decoder.riva --submodel decoder |
+
+
+ | MagpieTTS Encoder |
+ nemo2riva <model.ckpt path> --load_ckpt --model_config <hparams_file> --audio_codecpath <codec .nemo ckpt> --key tlt_encode --out magpie_encoder.riva --submodel encoder |
+
+
| Voice Activity Detection |
Segment VAD |
nemo2riva --key tlt_encode --out <path to save .riva model> --onnx-opset 18 <path to .nemo model> |
diff --git a/nemo2riva/args.py b/nemo2riva/args.py
index bdee602..5a52505 100644
--- a/nemo2riva/args.py
+++ b/nemo2riva/args.py
@@ -10,8 +10,12 @@ def get_args(argv):
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=f"Convert NeMo models to Riva EFF input format",
)
- parser.add_argument("source", help="Source .nemo file")
+ parser.add_argument("source", help="Source .nemo or ckpt file")
parser.add_argument("--out", default=None, help="Location to write resulting Riva EFF input to")
+ parser.add_argument("--load_ckpt", action="store_true", help="Load using checkpoint instead of .nemo file")
+ parser.add_argument("--submodel", default="decoder", help="Submodel to export. Default is decoder for MagpieTTSModel.")
+ parser.add_argument("--model_config", default=None, help="Hparams file")
+ parser.add_argument("--audio_codecpath", default=None, help="Audiocodec path. Needed only for magpietts models.")
parser.add_argument("--validate", action="store_true", help="Validate using schemas")
parser.add_argument("--runtime-check", action="store_true", help="Runtime check of exported net result")
parser.add_argument("--schema", default=None, help="Schema file to use for validation")
diff --git a/nemo2riva/artifacts.py b/nemo2riva/artifacts.py
index 1e750db..49811e1 100644
--- a/nemo2riva/artifacts.py
+++ b/nemo2riva/artifacts.py
@@ -77,6 +77,36 @@ def retrieve_artifacts_as_dict(restore_path: str, obj: Optional["ModelPT"] = Non
logging.error(f"Could not retrieve the artifact {file_key} at {member.name}. Error occured:\n{tb}")
return artifacts
+def retrieve_artifacts_as_dict_from_ckpt(ckpt: dict, model_cfg: dict, obj: Optional["ModelPT"] = None):
+ """ Retrieves all NeMo artifacts and returns them as dict
+ Args:
+ ckpt: dict containing the checkpoint
+ model_cfg: dict containing the model config
+ obj: ModelPT object (Optional, DEFAULT: None)
+ """
+ artifacts = {}
+ ## ckpt
+ #f = open(ckpt, "rb")
+ #artifact_content = f.read()
+ #aname = "model.ckpt"
+ #artifacts[aname] = {
+ # "conf_path": aname,
+ # "path_type": "TAR_PATH",
+ # "content": artifact_content,
+ #}
+ #f.close()
+
+ f = open(model_cfg, "rb")
+ artifact_content = f.read()
+ aname = "model_config.yaml"
+ artifacts[aname] = {
+ "conf_path": aname,
+ "path_type": "TAR_PATH",
+ "content": artifact_content,
+ }
+ f.close()
+ return artifacts
+
def create_artifact(reg, key, do_encrypt, **af_dict):
# only works for plain content now - no encryption in Nemo
@@ -93,8 +123,11 @@ def create_artifact(reg, key, do_encrypt, **af_dict):
return af
-def get_artifacts(restore_path: str, model=None, passphrase=None, **patch_kwargs):
- artifacts = retrieve_artifacts_as_dict(obj=model, restore_path=restore_path)
+def get_artifacts(restore_path: str, model=None, passphrase=None, model_cfg=None, from_ckpt=False, **patch_kwargs):
+ if not from_ckpt:
+ artifacts = retrieve_artifacts_as_dict(obj=model, restore_path=restore_path)
+ else:
+ artifacts = retrieve_artifacts_as_dict_from_ckpt(ckpt=restore_path, model_cfg=model_cfg, obj=model)
# NOTE: when servicemaker calls into get_artifacts, model is always None so this code section
# is never run.
@@ -119,7 +152,7 @@ def get_artifacts(restore_path: str, model=None, passphrase=None, **patch_kwargs
nemo_manifest = {'files': artifacts, 'metadata': {'format_version': 1}}
if 'model_config.yaml' in artifacts.keys():
nemo_manifest['has_nemo_config'] = True
-
+
nemo_files = nemo_manifest['files']
nemo_metadata = nemo_manifest['metadata']
reg = ArtifactRegistry(passphrase=passphrase)
@@ -134,4 +167,4 @@ def get_artifacts(restore_path: str, model=None, passphrase=None, **patch_kwargs
create_artifact(reg, key, False, content_callback=cb_override, **af_dict)
logging.info(f"Retrieved artifacts: {artifacts.keys()}")
- return reg, nemo_manifest
+ return reg, nemo_manifest
\ No newline at end of file
diff --git a/nemo2riva/convert.py b/nemo2riva/convert.py
index 7221173..19661ec 100644
--- a/nemo2riva/convert.py
+++ b/nemo2riva/convert.py
@@ -12,14 +12,16 @@
from nemo.core import ModelPT
from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import logging
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, open_dict
from lightning.pytorch import Trainer
+
from nemo2riva.artifacts import get_artifacts
from nemo2riva.cookbook import export_model, save_archive
from nemo2riva.schema import get_import_config, get_subnet, validate_archive
+
def Nemo2Riva(args):
"""Convert a .nemo saved model into .riva Riva input format."""
nemo_in = args.source
@@ -48,8 +50,37 @@ def Nemo2Riva(args):
try:
with torch.inference_mode():
- # Restore instance from .nemo file using generic model restore_from
- model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
+ if args.load_ckpt:
+ if not args.model_config:
+ raise ValueError("Hparams file is required when loading from checkpoint")
+ model_cfg = OmegaConf.load(args.model_config)
+ ckpt = torch.load(nemo_in, weights_only=False)
+ if "state_dict" in ckpt.keys():
+ ckpt = ckpt["state_dict"]
+
+ if "cfg" in model_cfg:
+ model_cfg = model_cfg.cfg
+ with open_dict(model_cfg):
+ if model_cfg.target.split(".")[-1] == "MagpieTTSModel":
+ from nemo2riva.patches.tts.magpietts import update_config, update_ckpt
+ from nemo.collections.tts.models.magpietts import MagpieTTSModel
+ legacy_codebooks = False
+ if not args.audio_codecpath:
+ raise ValueError("Audio codec path is required when loading from checkpoint for MagpieTTSModel.")
+ model_cfg = update_config(model_cfg, args.audio_codecpath, legacy_codebooks)
+ state_dict = update_ckpt(ckpt)
+
+ model = MagpieTTSModel(cfg=model_cfg)
+ model.load_state_dict(state_dict)
+ model.cuda()
+ model.eval()
+ model = model.half()
+ else:
+ model = ModelPT(cfg=model_cfg)
+ model.load_state_dict(ckpt)
+ else:
+ # Restore instance from .nemo file using generic model restore_from
+ model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
except Exception as e:
logging.error(
"Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
@@ -78,9 +109,11 @@ def Nemo2Riva(args):
warnings.filterwarnings('ignore', category=UserWarning)
# TODO: revisit export_subnet cli arg
patch_kwargs = {"import_config" : cfg}
+ if model.__class__.__name__ == "MagpieTTSModel":
+ patch_kwargs['is_encoder'] = args.submodel == "encoder"
if args.export_subnet:
patch_kwargs['export_subnet'] = args.export_subnet
- artifacts, manifest = get_artifacts(restore_path=nemo_in, model=model, passphrase=key, **patch_kwargs)
+ artifacts, manifest = get_artifacts(restore_path=nemo_in, model=model, passphrase=key, model_cfg=args.model_config, from_ckpt=args.load_ckpt, **patch_kwargs)
for export_cfg in cfg.exports:
subnet = get_subnet(model, export_cfg.export_subnet)
diff --git a/nemo2riva/cookbook.py b/nemo2riva/cookbook.py
index a583ef1..1d15b2c 100644
--- a/nemo2riva/cookbook.py
+++ b/nemo2riva/cookbook.py
@@ -64,9 +64,9 @@ def export_model(model, cfg, args, artifacts, metadata):
export_filename = cfg.export_file
export_file = os.path.join(tmpdir, export_filename)
- if cfg.export_format in ["ONNX", "TS"]:
+ if cfg.export_format in ["ONNX", "TS"] or (model.__class__.__name__ == "MagpieTTSModel" and args.submodel == "encoder"):
# Export the model, get the descriptions.
- if not isinstance(model, Exportable):
+ if not isinstance(model, Exportable) and not model.__class__.__name__ == "MagpieTTSModel":
logging.error("Your NeMo model class ({}) is not Exportable.".format(metadata['obj_cls']))
sys.exit(1)
@@ -86,15 +86,38 @@ def export_model(model, cfg, args, artifacts, metadata):
if cfg.max_dim is not None:
in_args["max_dim"] = cfg.max_dim
- input_example = model.input_module.input_example(**in_args)
- _, descriptions = model.export(
- export_file,
- input_example=input_example,
- check_trace=args.runtime_check,
- onnx_opset_version=args.onnx_opset,
- verbose=bool(args.verbose),
- )
- del model
+ if model.__class__.__name__ == "MagpieTTSModel" and cfg.export_format == "ONNX":
+ from nemo2riva.patches.tts.magpietts import EncoderOnnxModel
+ with torch.no_grad():
+ model.eval()
+ model = model.half()
+ encoder_model = EncoderOnnxModel(model)
+ input_args, dynamic_axes, output_names, input_names = encoder_model._prepare_for_export()
+
+ torch.onnx.export(encoder_model,
+ input_args,
+ export_file,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ opset_version=17)
+
+ enc_gs = gs.import_onnx(onnx.load(export_file))
+ outputs = enc_gs.outputs
+ fix_outputs = [outputs[0]]
+ enc_gs.outputs = fix_outputs
+ onnx.save(gs.export_onnx(enc_gs), export_file)
+ del model, encoder_model
+ else:
+ input_example = model.input_module.input_example(**in_args)
+ _, descriptions = model.export(
+ export_file,
+ input_example=input_example,
+ check_trace=args.runtime_check,
+ onnx_opset_version=args.onnx_opset,
+ verbose=bool(args.verbose),
+ )
+ del model
if cfg.export_format == 'ONNX':
o_list = os.listdir(tmpdir)
save_as_external_data = len(o_list) > 1
diff --git a/nemo2riva/patches/__init__.py b/nemo2riva/patches/__init__.py
index a70d4ce..ec63d6f 100644
--- a/nemo2riva/patches/__init__.py
+++ b/nemo2riva/patches/__init__.py
@@ -4,7 +4,7 @@
from nemo2riva.patches.ctc import set_decoder_num_classes
from nemo2riva.patches.ctc_bpe import bpe_check_inputs_and_version
from nemo2riva.patches.mtencdec import change_tokenizer_names
-from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning
+from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning, magpietts_model_versioning
patches = {
"EncDecCTCModel": [set_decoder_num_classes],
@@ -12,4 +12,5 @@
"MTEncDecModel": [change_tokenizer_names],
"FastPitchModel": [generate_vocab_mapping, fastpitch_model_versioning],
"RadTTSModel": [generate_vocab_mapping, radtts_model_versioning],
+ "MagpieTTSModel": [magpietts_model_versioning],
}
diff --git a/nemo2riva/patches/tts/__init__.py b/nemo2riva/patches/tts/__init__.py
index d6c851c..d9e7ca0 100644
--- a/nemo2riva/patches/tts/__init__.py
+++ b/nemo2riva/patches/tts/__init__.py
@@ -4,9 +4,11 @@
from nemo2riva.patches.tts.fastpitch import fastpitch_model_versioning
from nemo2riva.patches.tts.general import generate_vocab_mapping
from nemo2riva.patches.tts.radtts import radtts_model_versioning
+from nemo2riva.patches.tts.magpietts import magpietts_model_versioning
__all__ = [
fastpitch_model_versioning,
generate_vocab_mapping,
- radtts_model_versioning
+ radtts_model_versioning,
+ magpietts_model_versioning
]
diff --git a/nemo2riva/patches/tts/magpieTTS_README.md b/nemo2riva/patches/tts/magpieTTS_README.md
new file mode 100644
index 0000000..f6023ad
--- /dev/null
+++ b/nemo2riva/patches/tts/magpieTTS_README.md
@@ -0,0 +1,7 @@
+# Decoder
+Command to export decoder:
+`nemo2riva --load_ckpt --model_config --audio_codecpath --key tlt_encode --out magpie_decoder.riva --submodel decoder`
+
+# Encoder
+Command to export encoder:
+`nemo2riva --load_ckpt --model_config --audio_codecpath --key tlt_encode --out magpie_encoder.riva --submodel encoder`
diff --git a/nemo2riva/patches/tts/magpietts.py b/nemo2riva/patches/tts/magpietts.py
new file mode 100644
index 0000000..18de43a
--- /dev/null
+++ b/nemo2riva/patches/tts/magpietts.py
@@ -0,0 +1,140 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import random
+import nemo
+import torch
+import yaml
+from nemo.core.neural_types.neural_type import NeuralType
+from packaging.version import Version
+from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths
+from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors
+
+
+def update_ckpt(state_dict):
+ new_state_dict = {}
+ for key in state_dict.keys():
+ if 't5_encoder' in key:
+ new_key = key.replace('t5_encoder', 'encoder')
+ new_state_dict[new_key] = state_dict[key]
+ elif 't5_decoder' in key:
+ new_key = key.replace('t5_decoder', 'decoder')
+ new_state_dict[new_key] = state_dict[key]
+ else:
+ new_state_dict[key] = state_dict[key]
+ return new_state_dict
+
+
+def update_config(model_cfg, codecmodel_path, legacy_codebooks=False):
+ ''' helper function to rename older yamls from t5 to magpie '''
+ model_cfg.codecmodel_path = codecmodel_path
+ if hasattr(model_cfg, 'text_tokenizer'):
+ # Backward compatibility for models trained with absolute paths in text_tokenizer
+ model_cfg.text_tokenizer.g2p.phoneme_dict = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
+ model_cfg.text_tokenizer.g2p.heteronyms = "scripts/tts_dataset_files/heteronyms-052722"
+ model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0
+ model_cfg.train_ds = None
+ model_cfg.validation_ds = None
+ if "t5_encoder" in model_cfg:
+ model_cfg.encoder = model_cfg.t5_encoder
+ del model_cfg.t5_encoder
+ if "t5_decoder" in model_cfg:
+ model_cfg.decoder = model_cfg.t5_decoder
+ del model_cfg.t5_decoder
+ if hasattr(model_cfg, 'decoder') and hasattr(model_cfg.decoder, 'prior_eps'):
+ # Added to prevent crash after removing arg from transformer_2501.py in https://github.com/blisc/NeMo/pull/56
+ del model_cfg.decoder.prior_eps
+ if legacy_codebooks:
+ # Added to address backward compatibility arising from
+ # https://github.com/blisc/NeMo/pull/64
+ print("WARNING: Using legacy codebook indices for backward compatibility. Should only be used with old checkpoints.")
+ num_audio_tokens_per_codebook = model_cfg.num_audio_tokens_per_codebook
+ model_cfg.forced_num_all_tokens_per_codebook = num_audio_tokens_per_codebook
+ model_cfg.forced_audio_eos_id = num_audio_tokens_per_codebook - 1
+ model_cfg.forced_audio_bos_id = num_audio_tokens_per_codebook - 2
+ if model_cfg.model_type == 'decoder_context_tts':
+ model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 3
+ model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 4
+ model_cfg.forced_mask_token_id = num_audio_tokens_per_codebook - 5
+ else:
+ model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 1
+ model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 2
+
+ return model_cfg
+
+
+class EncoderOnnxModel(torch.nn.Module):
+ def __init__(self, model, tokenizer_name="english_phoneme"):
+ super().__init__()
+ model = model.eval().half()
+ self.tokenizer_name=tokenizer_name
+ self.tokenizer=model.tokenizer
+ self.bos_id=model.bos_id
+ self.eos_id=model.eos_id
+ self.text_embedding=model.text_embedding
+ self.encoder=model.encoder
+
+
+ def forward(self,tokens, token_mask):
+ emb_text=self.text_embedding(tokens)
+ output=self.encoder(emb_text, token_mask, None, None, None, None, None)
+ return output
+
+ def _prepare_for_export(self):
+ text = "Hello world! How are you doing today?"
+ n_batches = 2
+ text_encoding = [self.bos_id] + self.tokenizer.encode(text, self.tokenizer_name) + [self.eos_id]
+ text_encoding = torch.IntTensor([text_encoding for _ in range(n_batches)]).cuda()
+
+ text_lens = torch.IntTensor([text_encoding.shape[1] for _ in range(n_batches)]).cuda()
+ max_text_len = torch.max(text_lens).item()
+ text_mask = get_mask_from_lengths(text_lens).cuda() # (B, T)
+
+ dummy_output = self(text_encoding, text_mask)
+
+ input_names = ["tokens", "token_mask"]
+ output_names = ["output"]
+ dynamic_axes = {
+ "tokens": {
+ 0: "batch_size",
+ 1: "n_texts"
+ },
+ "token_mask": {
+ 0: "batch_size",
+ 1: "n_texts"
+ }
+ }
+ inputs_args = {
+ 'tokens': text_encoding,
+ 'token_mask': text_mask,
+
+ }
+ return inputs_args, dynamic_axes, output_names, input_names
+
+
+def magpietts_model_versioning(model, artifacts, **kwargs):
+ # Patch for generating magpieTTS model versions
+ try:
+ nemo_version = Version(nemo.__version__)
+ except NameError:
+ # If can't find the nemo version, return without patching
+ return None
+
+ # Don't override built-in format
+ # export_format is read from schemas, radtts is still currently torchscript in the schema
+ format_= kwargs['import_config'].exports[0].export_format
+
+ # Patch the model config yaml to add the volume and ragged batch flags
+ for art in artifacts:
+ from nemo.collections.tts.modules.magpietts_modules import SpecialAudioToken
+ if art == 'model_config.yaml':
+ model_config = yaml.safe_load(artifacts['model_config.yaml']['content'])["cfg"]
+ model_config["target"] = "nemo.collections.tts.models.t5tts.T5TTS_Model"
+ if kwargs['is_encoder']:
+ model_config["target"] = model_config["target"] + ".text_encoder"
+ else:
+ model_config['num_audio_tokens_per_codebook'] = model.num_all_tokens_per_codebook # - len(SpecialAudioToken)
+ model_config['num_audio_codebooks'] = model.num_audio_codebooks
+ model_config["export_config"] = {'enable_ragged_batches': True}
+ artifacts['model_config.yaml']['content'] = yaml.dump(model_config).encode()
+
\ No newline at end of file
diff --git a/nemo2riva/schema.py b/nemo2riva/schema.py
index f881c96..391e850 100644
--- a/nemo2riva/schema.py
+++ b/nemo2riva/schema.py
@@ -209,6 +209,10 @@ def get_import_config(model, args):
exports = [None]
else:
exports = get_exports(schema)
+ if key == "nemo.collections.tts.models.MagpieTTSModel" and args.submodel == "encoder":
+ exports = [
+ {'model_graph.onnx': {'onnx': True, 'autocast': True, 'encryption': True}},
+ ]
conf.exports = [get_export_config(export_obj, args) for export_obj in exports]
diff --git a/nemo2riva/validation_schemas/tts-exported-magpiemodel.yaml b/nemo2riva/validation_schemas/tts-exported-magpiemodel.yaml
new file mode 100644
index 0000000..b02f53b
--- /dev/null
+++ b/nemo2riva/validation_schemas/tts-exported-magpiemodel.yaml
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+# Define required metadata fields expeced in the archive (optional).
+metadata:
+ - obj_cls: nemo.collections.tts.models.MagpieTTSModel
+ - min_nemo_version: 1.1
+
+# Define list of files that are expected (optional).
+artifact_properties:
+ # List of files.
+ - model_config.yaml
+ - model.ckpt:
+ # Dictionary of expected properties (name:value) (optional).
+ onnx: False
+ torch: True
+ states_only: True
+
+# Define list of files with expected content (optional).
+# Functionality limited to yaml files (e.g. model_config.yaml).
+artifact_content:
+ # List of files.
+ - model_config.yaml:
+ # List of sections.subsections. ... that are required.
+