diff --git a/.bazelrc b/.bazelrc index 7a42b69..223a550 100644 --- a/.bazelrc +++ b/.bazelrc @@ -2,7 +2,7 @@ # SPDX-License-Identifier: MIT try-import %workspace%/user.bazelrc -build --show_timestamps --keep_going --color=yes --cxxopt='-std=c++1z' --linkopt='-lstdc++fs' +build --show_timestamps --keep_going --color=yes --cxxopt='-std=c++17' --linkopt='-lstdc++fs' build --workspace_status_command=scripts/workspace_status.sh build:debug --compilation_mode dbg diff --git a/.bazelversion b/.bazelversion index 0062ac9..024b066 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -5.0.0 +6.2.1 diff --git a/Dockerfile b/Dockerfile index 3a2b568..a7c6e0a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,15 @@ # syntax = docker/dockerfile:1.2 -ARG BAZEL_VERSION=5.0.0 +ARG BAZEL_VERSION=6.2.1 ARG TARGET_ARCH=x86_64 # valid values: x86_64, aarch64 ARG TARGET_OS=linux # valid values: linux, l4t -FROM ubuntu:20.04 AS base +FROM ubuntu:24.04 AS base ARG DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y \ - libasound2 \ + libasound2t64 \ libogg0 \ + openssl \ ca-certificates FROM base AS builddep @@ -22,7 +23,8 @@ RUN apt-get update && apt-get install -y \ unzip \ build-essential \ libasound2-dev \ - libogg-dev + libogg-dev \ + libssl-dev RUN if [ "$TARGET_ARCH" = "aarch64" ] && [ "$TARGET_OS" = "l4t" ]; then \ apt-get update && apt-get install -y --no-install-recommends openjdk-11-jdk-headless; \ @@ -48,8 +50,8 @@ COPY scripts /work/scripts COPY third_party /work/third_party COPY riva /work/riva ARG BAZEL_CACHE_ARG="" -RUN bazel test $BAZEL_CACHE_ARG --config=${TARGET_OS}/${TARGET_ARCH} //riva/clients/... --test_summary=detailed --test_output=all && \ - bazel build --stamp --config=release --config=${TARGET_OS}/${TARGET_ARCH} $BAZEL_CACHE_ARG //... && \ +RUN bazel test $BAZEL_CACHE_ARG --config=${TARGET_OS}/${TARGET_ARCH} //riva/clients/... --test_summary=detailed --test_output=all +RUN bazel build --stamp --config=release --config=${TARGET_OS}/${TARGET_ARCH} $BAZEL_CACHE_ARG //... && \ cp -R /work/bazel-bin/riva /opt RUN ls -lah /work; ls -lah /work/.git; cat /work/.bazelrc diff --git a/WORKSPACE b/WORKSPACE index 2cda2c3..4ff14ab 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -17,17 +17,24 @@ new_local_repository( http_archive( name = "com_google_absl", - urls = ["https://github.com/abseil/abseil-cpp/archive/c22c032a353b5dc16d86ddc879e628344e591e77.zip"], - strip_prefix = "abseil-cpp-c22c032a353b5dc16d86ddc879e628344e591e77", - sha256 = "88e79f5b7e3f92d3f19ad470cb38ef6becaf9bf195206ca9dba1a23d4017bc1a" + urls = ["https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.tar.gz"], + strip_prefix = "abseil-cpp-20220623.1", + sha256 = "91ac87d30cc6d79f9ab974c51874a704de9c2647c40f6932597329a282217ba8" +) + +# hack to force grpc to use local openssl instead of boringssl (to avoid conflicts with libcurl) +new_local_repository( + name = "boringssl", + path = "/usr/include/openssl", + build_file = "//third_party:BUILD.boringssl" ) http_archive( name = "com_github_grpc_grpc", - sha256 = "8c05641b9f91cbc92f51cc4a5b3a226788d7a63f20af4ca7aaca50d92cc94a0d", - strip_prefix = "grpc-1.44.0", + sha256 = "fb1ed98eb3555877d55eb2b948caca44bc8601c6704896594de81558639709ef", + strip_prefix = "grpc-1.50.1", urls = [ - "https://github.com/grpc/grpc/archive/v1.44.0.tar.gz", + "https://github.com/grpc/grpc/archive/refs/tags/v1.50.1.tar.gz", ], ) @@ -69,8 +76,8 @@ grpc_extra_deps() git_repository( name = "nvriva_common", - remote = "https://github.com/nvidia-riva/common.git", - commit = "1301af41cbf429dda8204b22d817c0e17cf8b369" + remote = "https://github.com/atomer-nvidia/common.git", + commit = "60e67e8ba30eac99d8cfb30275b03b76b6562a29" ) http_archive( @@ -88,3 +95,10 @@ http_archive( strip_prefix = "opusfile-0.12", build_file = "//third_party:BUILD.libopusfile" ) + +http_archive( + name = "platforms", + urls = ["https://github.com/bazelbuild/platforms/archive/refs/tags/1.0.0.tar.gz"], + strip_prefix = "platforms-1.0.0", + sha256 = "852b71bfa15712cec124e4a57179b6bc95d59fdf5052945f5d550e072501a769", +) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index d9f2ab5..8f82c9a 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -87,6 +87,8 @@ DEFINE_double( DEFINE_string( custom_configuration, "", "Custom configurations to be sent to the server as key value pairs "); +DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation"); +DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size"); class RecognizeClient { public: @@ -477,6 +479,8 @@ main(int argc, char** argv) str_usage << " --stop_threshold=" << std::endl; str_usage << " --stop_threshold_eou=" << std::endl; str_usage << " --custom_configuration=" << std::endl; + str_usage << " --timeout_ms=" << std::endl; + str_usage << " --max_grpc_message_size=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -507,9 +511,11 @@ main(int argc, char** argv) std::shared_ptr grpc_channel; try { - auto creds = - riva::clients::CreateChannelCredentials(FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert, FLAGS_metadata); - grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds); + auto creds = riva::clients::CreateChannelCredentials( + FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert, + FLAGS_metadata); + grpc_channel = riva::clients::CreateChannelBlocking( + FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size); } catch (const std::exception& e) { std::cerr << "Error creating GRPC channel: " << e.what() << std::endl; diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index 8bf408a..0790738 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -95,12 +95,12 @@ DEFINE_double( DEFINE_string( custom_configuration, "", "Custom configurations to be sent to the server as key value pairs "); -DEFINE_bool( - speaker_diarization, false, - "Flag that controls if speaker diarization is requested"); +DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested"); DEFINE_int32( diarization_max_speakers, 4, "Max number of speakers to detect when performing speaker diarization. Default is 4 (Max)"); +DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation"); +DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size"); void signal_handler(int signal_num) @@ -156,6 +156,8 @@ main(int argc, char** argv) str_usage << " --custom_configuration=" << std::endl; str_usage << " --speaker_diarization=" << std::endl; str_usage << " --diarization_max_speakers=" << std::endl; + str_usage << " --timeout_ms=" << std::endl; + str_usage << " --max_grpc_message_size=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -187,9 +189,11 @@ main(int argc, char** argv) std::shared_ptr grpc_channel; try { - auto creds = - riva::clients::CreateChannelCredentials(FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert, FLAGS_metadata); - grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds); + auto creds = riva::clients::CreateChannelCredentials( + FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert, + FLAGS_metadata); + grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, + FLAGS_max_grpc_message_size); } catch (const std::exception& e) { std::cerr << "Error creating GRPC channel: " << e.what() << std::endl; @@ -222,8 +226,8 @@ main(int argc, char** argv) FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, - FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration, FLAGS_speaker_diarization, - FLAGS_diarization_max_speakers); + FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration, + FLAGS_speaker_diarization, FLAGS_diarization_max_speakers); if (FLAGS_audio_file.size()) { return recognize_client.DoStreamingFromFile( diff --git a/riva/clients/tts/riva_tts_client.cc b/riva/clients/tts/riva_tts_client.cc index 20c7181..a9ef3e9 100644 --- a/riva/clients/tts/riva_tts_client.cc +++ b/riva/clients/tts/riva_tts_client.cc @@ -52,6 +52,9 @@ DEFINE_string( DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges between 1-40."); DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words"); DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt."); +DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation"); +DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size"); +DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -114,6 +117,9 @@ main(int argc, char** argv) str_usage << " --zero_shot_quality=" << std::endl; str_usage << " --zero_shot_transcript=" << std::endl; str_usage << " --custom_dictionary= " << std::endl; + str_usage << " --timeout_ms= " << std::endl; + str_usage << " --max_grpc_message_size= " << std::endl; + str_usage << " --exaggeration_factor= " << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -148,7 +154,7 @@ main(int argc, char** argv) auto creds = riva::clients::CreateChannelCredentials( FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert, FLAGS_metadata); - grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds); + grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size); } catch (const std::exception& e) { std::cerr << "Error creating GRPC channel: " << e.what() << std::endl; @@ -214,6 +220,11 @@ main(int argc, char** argv) if (not FLAGS_online and not FLAGS_zero_shot_transcript.empty()) { zero_shot_data->set_transcript(FLAGS_zero_shot_transcript); } + if (FLAGS_exaggeration_factor < 0.0 || FLAGS_exaggeration_factor > 2.0) { + LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return -1; + } + zero_shot_data->set_exaggeration_factor(FLAGS_exaggeration_factor); } // Send text content using Synthesize(). diff --git a/riva/clients/tts/riva_tts_perf_client.cc b/riva/clients/tts/riva_tts_perf_client.cc index 0f15e6e..adf867e 100644 --- a/riva/clients/tts/riva_tts_perf_client.cc +++ b/riva/clients/tts/riva_tts_perf_client.cc @@ -63,6 +63,8 @@ DEFINE_string( DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges between 1-40."); DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words"); DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt."); +DEFINE_double( + exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -114,7 +116,7 @@ synthesizeBatch( std::unique_ptr tts, std::string text, std::string language, uint32_t rate, std::string voice_name, std::string filepath, std::string zero_shot_prompt_filename, int32_t zero_shot_quality, std::string custom_dictionary, - std::string zero_shot_transcript) + std::string zero_shot_transcript, double exaggeration_factor) { // Parse command line arguments. nr_tts::SynthesizeSpeechRequest request; @@ -122,7 +124,6 @@ synthesizeBatch( request.set_language_code(language); request.set_sample_rate_hz(rate); request.set_voice_name(voice_name); - if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { request.set_encoding(nr::LINEAR_PCM); } else if (FLAGS_audio_encoding == "opus") { @@ -163,6 +164,11 @@ synthesizeBatch( if (not FLAGS_zero_shot_transcript.empty()) { zero_shot_data->set_transcript(FLAGS_zero_shot_transcript); } + if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) { + LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return -1; + } + zero_shot_data->set_exaggeration_factor(exaggeration_factor); } // Send text content using Synthesize(). @@ -206,14 +212,13 @@ synthesizeOnline( std::unique_ptr tts, std::string text, std::string language, uint32_t rate, std::string voice_name, double* time_to_first_chunk, std::vector* time_to_next_chunk, size_t* num_samples, std::string filepath, - std::string zero_shot_prompt_filename, int32_t zero_shot_quality) + std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double exaggeration_factor) { nr_tts::SynthesizeSpeechRequest request; request.set_text(text); request.set_language_code(language); request.set_sample_rate_hz(rate); request.set_voice_name(voice_name); - auto ae = nr::AudioEncoding::ENCODING_UNSPECIFIED; if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { ae = nr::LINEAR_PCM; @@ -251,6 +256,11 @@ synthesizeOnline( } zero_shot_data->set_sample_rate_hz(zero_shot_sample_rate); zero_shot_data->set_quality(zero_shot_quality); + if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) { + LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return; + } + zero_shot_data->set_exaggeration_factor(exaggeration_factor); } @@ -357,7 +367,7 @@ main(int argc, char** argv) str_usage << " --zero_shot_quality=" << std::endl; str_usage << " --zero_shot_transcript=" << std::endl; str_usage << " --custom_dictionary= " << std::endl; - + str_usage << " --exaggeration_factor= " << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -484,7 +494,7 @@ main(int argc, char** argv) std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, &time_to_first_chunk, time_to_next_chunk, &num_samples, std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, - FLAGS_zero_shot_quality); + FLAGS_zero_shot_quality, FLAGS_exaggeration_factor); latencies_first_chunk[i]->push_back(time_to_first_chunk); latencies_next_chunks[i]->insert( latencies_next_chunks[i]->end(), time_to_next_chunk->begin(), @@ -560,7 +570,8 @@ main(int argc, char** argv) int32_t num_samples = synthesizeBatch( std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, - FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript); + FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript, + FLAGS_exaggeration_factor); results_num_samples[i]->push_back(num_samples); } })); diff --git a/riva/clients/utils/grpc.h b/riva/clients/utils/grpc.h index 5aa4ff3..16e64a3 100644 --- a/riva/clients/utils/grpc.h +++ b/riva/clients/utils/grpc.h @@ -16,7 +16,7 @@ #include "riva/utils/files/files.h" #include "riva/utils/string_processing.h" -constexpr int MAX_GRPC_MESSAGE_SIZE = 64 * 1024 * 1024; +constexpr int MAX_GRPC_MESSAGE_SIZE = 128 * 1024 * 1024; using grpc::Status; using grpc::StatusCode; @@ -58,11 +58,11 @@ class CustomAuthenticator : public grpc::MetadataCredentialsPlugin { std::shared_ptr CreateChannelBlocking( const std::string& uri, const std::shared_ptr credentials, - uint64_t timeout_ms = 10000) + uint64_t timeout_ms = 10000, uint64_t max_grpc_message_size = MAX_GRPC_MESSAGE_SIZE) { grpc::ChannelArguments channel_args; - channel_args.SetMaxReceiveMessageSize(MAX_GRPC_MESSAGE_SIZE); - channel_args.SetMaxSendMessageSize(MAX_GRPC_MESSAGE_SIZE); + channel_args.SetMaxReceiveMessageSize(max_grpc_message_size); + channel_args.SetMaxSendMessageSize(max_grpc_message_size); auto channel = grpc::CreateCustomChannel(uri, credentials, channel_args); auto deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms); @@ -120,4 +120,3 @@ CreateChannelCredentials( } } // namespace riva::clients - diff --git a/third_party/BUILD.boringssl b/third_party/BUILD.boringssl new file mode 100644 index 0000000..213af6a --- /dev/null +++ b/third_party/BUILD.boringssl @@ -0,0 +1,24 @@ +""" +Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +NVIDIA CORPORATION and its licensors retain all intellectual property +and proprietary rights in and to this software, related documentation +and any modifications thereto. Any use, reproduction, disclosure or +distribution of this software and related documentation without an express +license agreement from NVIDIA CORPORATION is strictly prohibited. +""" + +cc_library( + name = "ssl", + visibility = ["//visibility:public"], + hdrs = glob(["*"]), + + linkopts = ["-lssl"], +) + +cc_library( + name = "crypto", + visibility = ["//visibility:public"], + hdrs = glob(["*"]), + + linkopts = ["-lcrypto"], +) \ No newline at end of file