Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: MIT

try-import %workspace%/user.bazelrc
build --show_timestamps --keep_going --color=yes --cxxopt='-std=c++1z' --linkopt='-lstdc++fs'
build --show_timestamps --keep_going --color=yes --cxxopt='-std=c++17' --linkopt='-lstdc++fs'
build --workspace_status_command=scripts/workspace_status.sh

build:debug --compilation_mode dbg
Expand Down
2 changes: 1 addition & 1 deletion .bazelversion
Original file line number Diff line number Diff line change
@@ -1 +1 @@
5.0.0
6.2.1
14 changes: 8 additions & 6 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# syntax = docker/dockerfile:1.2
ARG BAZEL_VERSION=5.0.0
ARG BAZEL_VERSION=6.2.1
ARG TARGET_ARCH=x86_64 # valid values: x86_64, aarch64
ARG TARGET_OS=linux # valid values: linux, l4t

FROM ubuntu:20.04 AS base
FROM ubuntu:24.04 AS base
ARG DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y \
libasound2 \
libasound2t64 \
libogg0 \
openssl \
ca-certificates

FROM base AS builddep
Expand All @@ -22,7 +23,8 @@ RUN apt-get update && apt-get install -y \
unzip \
build-essential \
libasound2-dev \
libogg-dev
libogg-dev \
libssl-dev

RUN if [ "$TARGET_ARCH" = "aarch64" ] && [ "$TARGET_OS" = "l4t" ]; then \
apt-get update && apt-get install -y --no-install-recommends openjdk-11-jdk-headless; \
Expand All @@ -48,8 +50,8 @@ COPY scripts /work/scripts
COPY third_party /work/third_party
COPY riva /work/riva
ARG BAZEL_CACHE_ARG=""
RUN bazel test $BAZEL_CACHE_ARG --config=${TARGET_OS}/${TARGET_ARCH} //riva/clients/... --test_summary=detailed --test_output=all && \
bazel build --stamp --config=release --config=${TARGET_OS}/${TARGET_ARCH} $BAZEL_CACHE_ARG //... && \
RUN bazel test $BAZEL_CACHE_ARG --config=${TARGET_OS}/${TARGET_ARCH} //riva/clients/... --test_summary=detailed --test_output=all
RUN bazel build --stamp --config=release --config=${TARGET_OS}/${TARGET_ARCH} $BAZEL_CACHE_ARG //... && \
cp -R /work/bazel-bin/riva /opt

RUN ls -lah /work; ls -lah /work/.git; cat /work/.bazelrc
Expand Down
30 changes: 22 additions & 8 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,24 @@ new_local_repository(

http_archive(
name = "com_google_absl",
urls = ["https://github.com/abseil/abseil-cpp/archive/c22c032a353b5dc16d86ddc879e628344e591e77.zip"],
strip_prefix = "abseil-cpp-c22c032a353b5dc16d86ddc879e628344e591e77",
sha256 = "88e79f5b7e3f92d3f19ad470cb38ef6becaf9bf195206ca9dba1a23d4017bc1a"
urls = ["https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.tar.gz"],
strip_prefix = "abseil-cpp-20220623.1",
sha256 = "91ac87d30cc6d79f9ab974c51874a704de9c2647c40f6932597329a282217ba8"
)

# hack to force grpc to use local openssl instead of boringssl (to avoid conflicts with libcurl)
new_local_repository(
name = "boringssl",
path = "/usr/include/openssl",
build_file = "//third_party:BUILD.boringssl"
)

http_archive(
name = "com_github_grpc_grpc",
sha256 = "8c05641b9f91cbc92f51cc4a5b3a226788d7a63f20af4ca7aaca50d92cc94a0d",
strip_prefix = "grpc-1.44.0",
sha256 = "fb1ed98eb3555877d55eb2b948caca44bc8601c6704896594de81558639709ef",
strip_prefix = "grpc-1.50.1",
urls = [
"https://github.com/grpc/grpc/archive/v1.44.0.tar.gz",
"https://github.com/grpc/grpc/archive/refs/tags/v1.50.1.tar.gz",
],
)

Expand Down Expand Up @@ -69,8 +76,8 @@ grpc_extra_deps()

git_repository(
name = "nvriva_common",
remote = "https://github.com/nvidia-riva/common.git",
commit = "1301af41cbf429dda8204b22d817c0e17cf8b369"
remote = "https://github.com/atomer-nvidia/common.git",
commit = "60e67e8ba30eac99d8cfb30275b03b76b6562a29"
)

http_archive(
Expand All @@ -88,3 +95,10 @@ http_archive(
strip_prefix = "opusfile-0.12",
build_file = "//third_party:BUILD.libopusfile"
)

http_archive(
name = "platforms",
urls = ["https://github.com/bazelbuild/platforms/archive/refs/tags/1.0.0.tar.gz"],
strip_prefix = "platforms-1.0.0",
sha256 = "852b71bfa15712cec124e4a57179b6bc95d59fdf5052945f5d550e072501a769",
)
12 changes: 9 additions & 3 deletions riva/clients/asr/riva_asr_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ DEFINE_double(
DEFINE_string(
custom_configuration, "",
"Custom configurations to be sent to the server as key value pairs <key:value,key:value,...>");
DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation");
DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size");

class RecognizeClient {
public:
Expand Down Expand Up @@ -477,6 +479,8 @@ main(int argc, char** argv)
str_usage << " --stop_threshold=<float>" << std::endl;
str_usage << " --stop_threshold_eou=<float>" << std::endl;
str_usage << " --custom_configuration=<key:value,key:value,...>" << std::endl;
str_usage << " --timeout_ms=<uint64_t>" << std::endl;
str_usage << " --max_grpc_message_size=<uint64_t>" << std::endl;
gflags::SetUsageMessage(str_usage.str());
gflags::SetVersionString(::riva::utils::kBuildScmRevision);

Expand Down Expand Up @@ -507,9 +511,11 @@ main(int argc, char** argv)

std::shared_ptr<grpc::Channel> grpc_channel;
try {
auto creds =
riva::clients::CreateChannelCredentials(FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert, FLAGS_metadata);
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds);
auto creds = riva::clients::CreateChannelCredentials(
FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
FLAGS_metadata);
grpc_channel = riva::clients::CreateChannelBlocking(
FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
}
catch (const std::exception& e) {
std::cerr << "Error creating GRPC channel: " << e.what() << std::endl;
Expand Down
20 changes: 12 additions & 8 deletions riva/clients/asr/riva_streaming_asr_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ DEFINE_double(
DEFINE_string(
custom_configuration, "",
"Custom configurations to be sent to the server as key value pairs <key:value,key:value,...>");
DEFINE_bool(
speaker_diarization, false,
"Flag that controls if speaker diarization is requested");
DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested");
DEFINE_int32(
diarization_max_speakers, 4,
"Max number of speakers to detect when performing speaker diarization. Default is 4 (Max)");
DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation");
DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size");

void
signal_handler(int signal_num)
Expand Down Expand Up @@ -156,6 +156,8 @@ main(int argc, char** argv)
str_usage << " --custom_configuration=<key:value,key:value,...>" << std::endl;
str_usage << " --speaker_diarization=<true|false>" << std::endl;
str_usage << " --diarization_max_speakers=<int>" << std::endl;
str_usage << " --timeout_ms=<uint64_t>" << std::endl;
str_usage << " --max_grpc_message_size=<uint64_t>" << std::endl;
gflags::SetUsageMessage(str_usage.str());
gflags::SetVersionString(::riva::utils::kBuildScmRevision);

Expand Down Expand Up @@ -187,9 +189,11 @@ main(int argc, char** argv)

std::shared_ptr<grpc::Channel> grpc_channel;
try {
auto creds =
riva::clients::CreateChannelCredentials(FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert, FLAGS_metadata);
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds);
auto creds = riva::clients::CreateChannelCredentials(
FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
FLAGS_metadata);
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms,
FLAGS_max_grpc_message_size);
}
catch (const std::exception& e) {
std::cerr << "Error creating GRPC channel: " << e.what() << std::endl;
Expand Down Expand Up @@ -222,8 +226,8 @@ main(int argc, char** argv)
FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime,
FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score,
FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou,
FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration, FLAGS_speaker_diarization,
FLAGS_diarization_max_speakers);
FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration,
FLAGS_speaker_diarization, FLAGS_diarization_max_speakers);

if (FLAGS_audio_file.size()) {
return recognize_client.DoStreamingFromFile(
Expand Down
13 changes: 12 additions & 1 deletion riva/clients/tts/riva_tts_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ DEFINE_string(
DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges between 1-40.");
DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words");
DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt.");
DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation");
DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size");
DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0");

static const std::string LC_enUS = "en-US";

Expand Down Expand Up @@ -114,6 +117,9 @@ main(int argc, char** argv)
str_usage << " --zero_shot_quality=<quality>" << std::endl;
str_usage << " --zero_shot_transcript=<text>" << std::endl;
str_usage << " --custom_dictionary=<filename> " << std::endl;
str_usage << " --timeout_ms=<timeout_ms> " << std::endl;
str_usage << " --max_grpc_message_size=<max_grpc_message_size> " << std::endl;
str_usage << " --exaggeration_factor=<exaggeration_factor> " << std::endl;
gflags::SetUsageMessage(str_usage.str());
gflags::SetVersionString(::riva::utils::kBuildScmRevision);

Expand Down Expand Up @@ -148,7 +154,7 @@ main(int argc, char** argv)
auto creds = riva::clients::CreateChannelCredentials(
FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
FLAGS_metadata);
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds);
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
}
catch (const std::exception& e) {
std::cerr << "Error creating GRPC channel: " << e.what() << std::endl;
Expand Down Expand Up @@ -214,6 +220,11 @@ main(int argc, char** argv)
if (not FLAGS_online and not FLAGS_zero_shot_transcript.empty()) {
zero_shot_data->set_transcript(FLAGS_zero_shot_transcript);
}
if (FLAGS_exaggeration_factor < 0.0 || FLAGS_exaggeration_factor > 2.0) {
LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl;
return -1;
}
zero_shot_data->set_exaggeration_factor(FLAGS_exaggeration_factor);
}

// Send text content using Synthesize().
Expand Down
25 changes: 18 additions & 7 deletions riva/clients/tts/riva_tts_perf_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ DEFINE_string(
DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges between 1-40.");
DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words");
DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt.");
DEFINE_double(
exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0");

static const std::string LC_enUS = "en-US";

Expand Down Expand Up @@ -114,15 +116,14 @@ synthesizeBatch(
std::unique_ptr<nr_tts::RivaSpeechSynthesis::Stub> tts, std::string text, std::string language,
uint32_t rate, std::string voice_name, std::string filepath,
std::string zero_shot_prompt_filename, int32_t zero_shot_quality, std::string custom_dictionary,
std::string zero_shot_transcript)
std::string zero_shot_transcript, double exaggeration_factor)
{
// Parse command line arguments.
nr_tts::SynthesizeSpeechRequest request;
request.set_text(text);
request.set_language_code(language);
request.set_sample_rate_hz(rate);
request.set_voice_name(voice_name);

if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
request.set_encoding(nr::LINEAR_PCM);
} else if (FLAGS_audio_encoding == "opus") {
Expand Down Expand Up @@ -163,6 +164,11 @@ synthesizeBatch(
if (not FLAGS_zero_shot_transcript.empty()) {
zero_shot_data->set_transcript(FLAGS_zero_shot_transcript);
}
if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) {
LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl;
return -1;
}
zero_shot_data->set_exaggeration_factor(exaggeration_factor);
}

// Send text content using Synthesize().
Expand Down Expand Up @@ -206,14 +212,13 @@ synthesizeOnline(
std::unique_ptr<nr_tts::RivaSpeechSynthesis::Stub> tts, std::string text, std::string language,
uint32_t rate, std::string voice_name, double* time_to_first_chunk,
std::vector<double>* time_to_next_chunk, size_t* num_samples, std::string filepath,
std::string zero_shot_prompt_filename, int32_t zero_shot_quality)
std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double exaggeration_factor)
{
nr_tts::SynthesizeSpeechRequest request;
request.set_text(text);
request.set_language_code(language);
request.set_sample_rate_hz(rate);
request.set_voice_name(voice_name);

auto ae = nr::AudioEncoding::ENCODING_UNSPECIFIED;
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
ae = nr::LINEAR_PCM;
Expand Down Expand Up @@ -251,6 +256,11 @@ synthesizeOnline(
}
zero_shot_data->set_sample_rate_hz(zero_shot_sample_rate);
zero_shot_data->set_quality(zero_shot_quality);
if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) {
LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl;
return;
}
zero_shot_data->set_exaggeration_factor(exaggeration_factor);
}


Expand Down Expand Up @@ -357,7 +367,7 @@ main(int argc, char** argv)
str_usage << " --zero_shot_quality=<quality>" << std::endl;
str_usage << " --zero_shot_transcript=<text>" << std::endl;
str_usage << " --custom_dictionary=<filename> " << std::endl;

str_usage << " --exaggeration_factor=<exaggeration_factor> " << std::endl;
gflags::SetUsageMessage(str_usage.str());
gflags::SetVersionString(::riva::utils::kBuildScmRevision);

Expand Down Expand Up @@ -484,7 +494,7 @@ main(int argc, char** argv)
std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
&time_to_first_chunk, time_to_next_chunk, &num_samples,
std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
FLAGS_zero_shot_quality);
FLAGS_zero_shot_quality, FLAGS_exaggeration_factor);
latencies_first_chunk[i]->push_back(time_to_first_chunk);
latencies_next_chunks[i]->insert(
latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
Expand Down Expand Up @@ -560,7 +570,8 @@ main(int argc, char** argv)
int32_t num_samples = synthesizeBatch(
std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript);
FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript,
FLAGS_exaggeration_factor);
results_num_samples[i]->push_back(num_samples);
}
}));
Expand Down
9 changes: 4 additions & 5 deletions riva/clients/utils/grpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "riva/utils/files/files.h"
#include "riva/utils/string_processing.h"

constexpr int MAX_GRPC_MESSAGE_SIZE = 64 * 1024 * 1024;
constexpr int MAX_GRPC_MESSAGE_SIZE = 128 * 1024 * 1024;

using grpc::Status;
using grpc::StatusCode;
Expand Down Expand Up @@ -58,11 +58,11 @@ class CustomAuthenticator : public grpc::MetadataCredentialsPlugin {
std::shared_ptr<grpc::Channel>
CreateChannelBlocking(
const std::string& uri, const std::shared_ptr<grpc::ChannelCredentials> credentials,
uint64_t timeout_ms = 10000)
uint64_t timeout_ms = 10000, uint64_t max_grpc_message_size = MAX_GRPC_MESSAGE_SIZE)
{
grpc::ChannelArguments channel_args;
channel_args.SetMaxReceiveMessageSize(MAX_GRPC_MESSAGE_SIZE);
channel_args.SetMaxSendMessageSize(MAX_GRPC_MESSAGE_SIZE);
channel_args.SetMaxReceiveMessageSize(max_grpc_message_size);
channel_args.SetMaxSendMessageSize(max_grpc_message_size);
auto channel = grpc::CreateCustomChannel(uri, credentials, channel_args);

auto deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms);
Expand Down Expand Up @@ -120,4 +120,3 @@ CreateChannelCredentials(
}

} // namespace riva::clients

24 changes: 24 additions & 0 deletions third_party/BUILD.boringssl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
"""

cc_library(
name = "ssl",
visibility = ["//visibility:public"],
hdrs = glob(["*"]),

linkopts = ["-lssl"],
)

cc_library(
name = "crypto",
visibility = ["//visibility:public"],
hdrs = glob(["*"]),

linkopts = ["-lcrypto"],
)