From e01f63afad6ca46306c9910cf808bce65ad8ce78 Mon Sep 17 00:00:00 2001 From: manishEMS47 Date: Tue, 9 Jun 2026 11:01:46 +0530 Subject: [PATCH 1/2] Added 60dB integration --- README.md | 21 +- cmd/api_key.go | 40 +++- cmd/provider.go | 103 +++++++++ cmd/provider_test.go | 82 +++++++ cmd/root.go | 2 +- cmd/speak.go | 57 +++-- cmd/speak_request_test.go | 16 +- cmd/voices.go | 18 +- cmd/voices_cache.go | 14 +- cmd/voices_query.go | 16 +- docs/providers.md | 67 ++++++ internal/elevenlabs/client.go | 45 ++-- internal/sixtydb/client.go | 397 ++++++++++++++++++++++++++++++++ internal/sixtydb/client_test.go | 232 +++++++++++++++++++ internal/tts/tts.go | 58 +++++ 15 files changed, 1072 insertions(+), 96 deletions(-) create mode 100644 cmd/provider.go create mode 100644 cmd/provider_test.go create mode 100644 docs/providers.md create mode 100644 internal/sixtydb/client.go create mode 100644 internal/sixtydb/client_test.go create mode 100644 internal/tts/tts.go diff --git a/README.md b/README.md index e51c879..04a9061 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,24 @@ sudo apt install build-essential pkg-config libasound2-dev ``` ## Configuration -- `ELEVENLABS_API_KEY` (required) -- `--api-key-file` or `ELEVENLABS_API_KEY_FILE`/`SAG_API_KEY_FILE` to load the key from a file -- Optional defaults: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID` + +`sag` supports two TTS providers and auto-selects one from whichever API key is set: + +- **ElevenLabs** — `ELEVENLABS_API_KEY` (or `--api-key`, or `--api-key-file` / `ELEVENLABS_API_KEY_FILE` / `SAG_API_KEY_FILE`) +- **60db** (`api.60db.ai`) — `SIXTYDB_API_KEY` (or `SIXTYDB_API_KEY_FILE`) + +Selection rules: +- Only one key set → that provider is used. +- Both keys set → ElevenLabs is used (unset `ELEVENLABS_API_KEY` to use 60db); a note is printed. +- Neither set → error. + +Optional defaults: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID`. Override a provider's host with `--base-url`. 
+ +The same flags work for both providers; `sag` translates them to each API. A few flags are +ElevenLabs-only and are accepted-but-ignored on 60db (a note is printed): `--model-id`, +`--style`, `--speaker-boost`/`--no-speaker-boost`, `--seed`, `--normalize`, `--lang`, +`--latency-tier`. The `--stability`/`--similarity` `0..1` values are scaled to 60db's `0..100` +range automatically. See [docs/providers.md](docs/providers.md) for details. ## Usage diff --git a/cmd/api_key.go b/cmd/api_key.go index 8b22232..ff8e9c0 100644 --- a/cmd/api_key.go +++ b/cmd/api_key.go @@ -6,23 +6,39 @@ import ( "strings" ) -func ensureAPIKey() error { - if cfg.APIKey == "" { - key, err := resolveAPIKeyFromFile() - if err != nil { - return err - } - cfg.APIKey = key +// resolveElevenLabsKey resolves the ElevenLabs API key without erroring when +// absent (returns ""). Order: --api-key, key file, ELEVENLABS_API_KEY, +// SAG_API_KEY. +func resolveElevenLabsKey() (string, error) { + if cfg.APIKey != "" { + return cfg.APIKey, nil + } + key, err := resolveAPIKeyFromFile() + if err != nil { + return "", err + } + if key != "" { + return key, nil + } + if v := os.Getenv("ELEVENLABS_API_KEY"); v != "" { + return v, nil } - if cfg.APIKey == "" { - cfg.APIKey = os.Getenv("ELEVENLABS_API_KEY") + if v := os.Getenv("SAG_API_KEY"); v != "" { + return v, nil } - if cfg.APIKey == "" { - cfg.APIKey = os.Getenv("SAG_API_KEY") + return "", nil +} + +// ensureAPIKey resolves and stores the ElevenLabs API key, erroring if missing. 
+func ensureAPIKey() error { + key, err := resolveElevenLabsKey() + if err != nil { + return err } - if cfg.APIKey == "" { + if key == "" { return fmt.Errorf("missing ElevenLabs API key (set --api-key, --api-key-file, or ELEVENLABS_API_KEY)") } + cfg.APIKey = key return nil } diff --git a/cmd/provider.go b/cmd/provider.go new file mode 100644 index 0000000..4ddab6a --- /dev/null +++ b/cmd/provider.go @@ -0,0 +1,103 @@ +package cmd + +import ( + "fmt" + "os" + "strings" + + "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/sixtydb" + "github.com/steipete/sag/internal/tts" +) + +const ( + providerElevenLabs = "elevenlabs" + providerSixtyDB = "60db" +) + +// resolveSixtyDBKey resolves the 60db API key from its env vars. +// Order: SIXTYDB_API_KEY, then SIXTYDB_API_KEY_FILE. +func resolveSixtyDBKey() (string, error) { + if key := strings.TrimSpace(os.Getenv("SIXTYDB_API_KEY")); key != "" { + return key, nil + } + if path := strings.TrimSpace(os.Getenv("SIXTYDB_API_KEY_FILE")); path != "" { + data, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("read 60db api key file: %w", err) + } + key := strings.TrimSpace(string(data)) + if key == "" { + return "", fmt.Errorf("60db api key file %q is empty", path) + } + return key, nil + } + return "", nil +} + +// ensureProviderConfigured verifies that at least one provider's key is set. +// Used by PreRunE so 60db-only users aren't rejected for lacking an +// ElevenLabs key. +func ensureProviderConfigured() error { + elKey, err := resolveElevenLabsKey() + if err != nil { + return err + } + sdKey, err := resolveSixtyDBKey() + if err != nil { + return err + } + if elKey == "" && sdKey == "" { + return fmt.Errorf("missing API key (set ELEVENLABS_API_KEY or SIXTYDB_API_KEY; or --api-key / --api-key-file)") + } + return nil +} + +// selectProvider auto-detects the active provider from whichever API key is +// present. 
If both are set, ElevenLabs wins (preserving prior default) and a +// note is printed. The chosen client is built with cfg.BaseURL, which each +// client treats as a per-provider override (empty => provider default host). +func selectProvider() (tts.Provider, string, error) { + elKey, err := resolveElevenLabsKey() + if err != nil { + return nil, "", err + } + sdKey, err := resolveSixtyDBKey() + if err != nil { + return nil, "", err + } + + switch { + case elKey != "" && sdKey != "": + fmt.Fprintln(os.Stderr, "note: both ElevenLabs and 60db API keys set; using ElevenLabs (unset ELEVENLABS_API_KEY to use 60db)") + return elevenlabs.NewClient(elKey, cfg.BaseURL), providerElevenLabs, nil + case elKey != "": + return elevenlabs.NewClient(elKey, cfg.BaseURL), providerElevenLabs, nil + case sdKey != "": + return sixtydb.NewClient(sdKey, cfg.BaseURL), providerSixtyDB, nil + default: + return nil, "", fmt.Errorf("missing API key (set ELEVENLABS_API_KEY or SIXTYDB_API_KEY; or --api-key / --api-key-file)") + } +} + +// sixtyDBIgnoredFlags lists flags that ElevenLabs honors but 60db has no +// equivalent for. When the active provider is 60db and the user set one, we +// note that it is ignored rather than failing. +var sixtyDBIgnoredFlags = []string{ + "model-id", "style", "speaker-boost", "no-speaker-boost", + "seed", "normalize", "lang", "latency-tier", +} + +// noteUnsupportedSixtyDBFlags prints a single stderr note if the user set any +// flag that 60db ignores. 
+func noteUnsupportedSixtyDBFlags(changed func(string) bool) { + var ignored []string + for _, name := range sixtyDBIgnoredFlags { + if changed(name) { + ignored = append(ignored, "--"+name) + } + } + if len(ignored) > 0 { + fmt.Fprintf(os.Stderr, "note: 60db ignores %s\n", strings.Join(ignored, ", ")) + } +} diff --git a/cmd/provider_test.go b/cmd/provider_test.go new file mode 100644 index 0000000..5b914aa --- /dev/null +++ b/cmd/provider_test.go @@ -0,0 +1,82 @@ +package cmd + +import ( + "testing" +) + +// resetProviderEnv neutralizes every key source so each case starts clean. +// Setting an env var to "" makes the resolvers treat it as absent. +func resetProviderEnv(t *testing.T) { + t.Helper() + cfg.APIKey = "" + cfg.APIKeyFile = "" + cfg.BaseURL = "" + t.Cleanup(func() { cfg.APIKey = ""; cfg.APIKeyFile = ""; cfg.BaseURL = "" }) + for _, k := range []string{ + "ELEVENLABS_API_KEY", "SAG_API_KEY", + "ELEVENLABS_API_KEY_FILE", "SAG_API_KEY_FILE", + "SIXTYDB_API_KEY", "SIXTYDB_API_KEY_FILE", + } { + t.Setenv(k, "") + } +} + +func TestSelectProvider_ElevenLabsOnly(t *testing.T) { + resetProviderEnv(t) + t.Setenv("ELEVENLABS_API_KEY", "el-key") + + _, name, err := selectProvider() + if err != nil { + t.Fatalf("selectProvider error: %v", err) + } + if name != providerElevenLabs { + t.Fatalf("expected %q, got %q", providerElevenLabs, name) + } +} + +func TestSelectProvider_SixtyDBOnly(t *testing.T) { + resetProviderEnv(t) + t.Setenv("SIXTYDB_API_KEY", "sd-key") + + _, name, err := selectProvider() + if err != nil { + t.Fatalf("selectProvider error: %v", err) + } + if name != providerSixtyDB { + t.Fatalf("expected %q, got %q", providerSixtyDB, name) + } +} + +func TestSelectProvider_BothPrefersElevenLabs(t *testing.T) { + resetProviderEnv(t) + t.Setenv("ELEVENLABS_API_KEY", "el-key") + t.Setenv("SIXTYDB_API_KEY", "sd-key") + + _, name, err := selectProvider() + if err != nil { + t.Fatalf("selectProvider error: %v", err) + } + if name != providerElevenLabs { + 
t.Fatalf("expected ElevenLabs to win tiebreak, got %q", name) + } +} + +func TestSelectProvider_NeitherErrors(t *testing.T) { + resetProviderEnv(t) + + _, _, err := selectProvider() + if err == nil { + t.Fatal("expected error when no API key is set") + } +} + +func TestEnsureProviderConfigured(t *testing.T) { + resetProviderEnv(t) + if err := ensureProviderConfigured(); err == nil { + t.Fatal("expected error with no keys") + } + t.Setenv("SIXTYDB_API_KEY", "sd-key") + if err := ensureProviderConfigured(); err != nil { + t.Fatalf("expected 60db key to satisfy configuration, got %v", err) + } +} diff --git a/cmd/root.go b/cmd/root.go index 14a506c..28b9663 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -45,7 +45,7 @@ func Execute() { func init() { rootCmd.PersistentFlags().StringVar(&cfg.APIKey, "api-key", "", "ElevenLabs API key (or ELEVENLABS_API_KEY)") rootCmd.PersistentFlags().StringVar(&cfg.APIKeyFile, "api-key-file", "", "Read ElevenLabs API key from file (or ELEVENLABS_API_KEY_FILE)") - rootCmd.PersistentFlags().StringVar(&cfg.BaseURL, "base-url", "https://api.elevenlabs.io", "Override ElevenLabs API base URL") + rootCmd.PersistentFlags().StringVar(&cfg.BaseURL, "base-url", "", "Override the provider API base URL (empty = provider default)") rootCmd.PersistentFlags().BoolVarP(&versionFlag, "version", "V", false, "Print version and exit") } diff --git a/cmd/speak.go b/cmd/speak.go index ec3925f..3ae1e52 100644 --- a/cmd/speak.go +++ b/cmd/speak.go @@ -12,11 +12,15 @@ import ( "time" "github.com/steipete/sag/internal/audio" - "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/tts" "github.com/spf13/cobra" ) +// playbackFormat is the format requested when audio must be decoded for +// speaker playback (the oto/afplay path handles MP3). 
+const playbackFormat = "mp3_44100_128" + type speakOptions struct { voiceID string modelID string @@ -61,7 +65,7 @@ func init() { Long: "If no text argument is provided, the command reads from stdin.\n\nTip: run `sag prompting` for model-specific prompting tips and recommended flag combinations.", Args: cobra.ArbitraryArgs, PreRunE: func(_ *cobra.Command, _ []string) error { - return ensureAPIKey() + return ensureProviderConfigured() }, RunE: func(cmd *cobra.Command, args []string) error { if err := applyRateAndSpeed(&opts); err != nil { @@ -85,7 +89,10 @@ func init() { forceVoiceID = true } } - client := elevenlabs.NewClient(cfg.APIKey, cfg.BaseURL) + client, providerName, err := selectProvider() + if err != nil { + return err + } voiceID, err := resolveVoice(cmd.Context(), client, voiceInput, forceVoiceID) if err != nil { @@ -113,13 +120,22 @@ func init() { } } + if providerName == providerSixtyDB { + noteUnsupportedSixtyDBFlags(cmd.Flags().Changed) + // Speaker playback needs MP3; 60db's stream picks its own format + // and convert honors this request. 
+ if opts.play { + opts.outputFmt = playbackFormat + } + } + ctx, cancel, err := ttsContext(cmd.Context(), opts.timeout) if err != nil { return err } defer cancel() - payload, err := buildTTSRequest(cmd, opts, text) + payload, err := buildTTSRequest(cmd, opts, text, providerName) if err != nil { return err } @@ -272,17 +288,18 @@ func applyRateAndSpeed(opts *speakOptions) error { return nil } -func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (elevenlabs.TTSRequest, error) { +func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName string) (tts.TTSRequest, error) { flags := cmd.Flags() var stabilityPtr *float64 if flags.Changed("stability") { if opts.stability < 0 || opts.stability > 1 { - return elevenlabs.TTSRequest{}, errors.New("stability must be between 0 and 1") + return tts.TTSRequest{}, errors.New("stability must be between 0 and 1") } - if opts.modelID == "eleven_v3" { + // The discrete 0/0.5/1 constraint is specific to ElevenLabs eleven_v3. 
+ if providerName == providerElevenLabs && opts.modelID == "eleven_v3" { if !floatEqualsOneOf(opts.stability, []float64{0, 0.5, 1}) { - return elevenlabs.TTSRequest{}, errors.New("for eleven_v3, stability must be one of 0.0, 0.5, 1.0 (Creative/Natural/Robust)") + return tts.TTSRequest{}, errors.New("for eleven_v3, stability must be one of 0.0, 0.5, 1.0 (Creative/Natural/Robust)") } } stabilityPtr = &opts.stability @@ -291,7 +308,7 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (eleven var similarityPtr *float64 if flags.Changed("similarity") || flags.Changed("similarity-boost") { if opts.similarity < 0 || opts.similarity > 1 { - return elevenlabs.TTSRequest{}, errors.New("similarity must be between 0 and 1") + return tts.TTSRequest{}, errors.New("similarity must be between 0 and 1") } similarityPtr = &opts.similarity } @@ -299,13 +316,13 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (eleven var stylePtr *float64 if flags.Changed("style") { if opts.style < 0 || opts.style > 1 { - return elevenlabs.TTSRequest{}, errors.New("style must be between 0 and 1") + return tts.TTSRequest{}, errors.New("style must be between 0 and 1") } stylePtr = &opts.style } if flags.Changed("speaker-boost") && flags.Changed("no-speaker-boost") { - return elevenlabs.TTSRequest{}, errors.New("choose only one of --speaker-boost or --no-speaker-boost") + return tts.TTSRequest{}, errors.New("choose only one of --speaker-boost or --no-speaker-boost") } var speakerBoostPtr *bool if flags.Changed("speaker-boost") { @@ -319,7 +336,7 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (eleven var seedPtr *uint32 if flags.Changed("seed") { if opts.seed > 4294967295 { - return elevenlabs.TTSRequest{}, errors.New("seed must be between 0 and 4294967295") + return tts.TTSRequest{}, errors.New("seed must be between 0 and 4294967295") } v := uint32(opts.seed) seedPtr = &v @@ -330,7 +347,7 @@ func buildTTSRequest(cmd 
*cobra.Command, opts speakOptions, text string) (eleven switch normalize { case "auto", "on", "off": default: - return elevenlabs.TTSRequest{}, errors.New("normalize must be one of: auto, on, off") + return tts.TTSRequest{}, errors.New("normalize must be one of: auto, on, off") } } else { normalize = "" @@ -339,11 +356,11 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (eleven lang := strings.ToLower(strings.TrimSpace(opts.lang)) if flags.Changed("lang") { if len(lang) != 2 { - return elevenlabs.TTSRequest{}, errors.New("lang must be a 2-letter ISO 639-1 code (e.g. en, de, fr)") + return tts.TTSRequest{}, errors.New("lang must be a 2-letter ISO 639-1 code (e.g. en, de, fr)") } for _, r := range lang { if r < 'a' || r > 'z' { - return elevenlabs.TTSRequest{}, errors.New("lang must be a 2-letter ISO 639-1 code (e.g. en, de, fr)") + return tts.TTSRequest{}, errors.New("lang must be a 2-letter ISO 639-1 code (e.g. en, de, fr)") } } } else { @@ -351,14 +368,14 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (eleven } speed := opts.speed - return elevenlabs.TTSRequest{ + return tts.TTSRequest{ Text: text, ModelID: opts.modelID, OutputFormat: opts.outputFmt, Seed: seedPtr, ApplyTextNormalization: normalize, LanguageCode: lang, - VoiceSettings: &elevenlabs.VoiceSettings{ + VoiceSettings: &tts.VoiceSettings{ Speed: &speed, Stability: stabilityPtr, SimilarityBoost: similarityPtr, @@ -427,7 +444,7 @@ func isStdinTTY() bool { return (stat.Mode() & os.ModeCharDevice) != 0 } -func streamAndPlay(ctx context.Context, client *elevenlabs.Client, opts speakOptions, payload elevenlabs.TTSRequest) (int64, error) { +func streamAndPlay(ctx context.Context, client tts.Provider, opts speakOptions, payload tts.TTSRequest) (int64, error) { resp, err := client.StreamTTS(ctx, opts.voiceID, payload, opts.latencyTier) if err != nil { return 0, err @@ -488,7 +505,7 @@ func streamAndPlay(ctx context.Context, client *elevenlabs.Client, opts 
speakOpt return n, err } -func convertAndPlay(ctx context.Context, client *elevenlabs.Client, opts speakOptions, payload elevenlabs.TTSRequest) (int64, error) { +func convertAndPlay(ctx context.Context, client tts.Provider, opts speakOptions, payload tts.TTSRequest) (int64, error) { data, err := client.ConvertTTS(ctx, opts.voiceID, payload) if err != nil { return 0, err @@ -522,7 +539,7 @@ func convertAndPlay(ctx context.Context, client *elevenlabs.Client, opts speakOp return n, nil } -func resolveVoice(ctx context.Context, client *elevenlabs.Client, voiceInput string, forceID bool) (string, error) { +func resolveVoice(ctx context.Context, client tts.Provider, voiceInput string, forceID bool) (string, error) { voiceInput = strings.TrimSpace(voiceInput) if voiceInput == "" { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) diff --git a/cmd/speak_request_test.go b/cmd/speak_request_test.go index 4293666..b818968 100644 --- a/cmd/speak_request_test.go +++ b/cmd/speak_request_test.go @@ -33,7 +33,7 @@ func newSpeakTestCommand(t *testing.T) (*cobra.Command, *speakOptions) { func TestBuildTTSRequest_DefaultsOmitOptionalFields(t *testing.T) { cmd, opts := newSpeakTestCommand(t) - req, err := buildTTSRequest(cmd, *opts, "hello") + req, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) if err != nil { t.Fatalf("buildTTSRequest error: %v", err) } @@ -70,7 +70,7 @@ func TestBuildTTSRequest_SimilarityBoostAlias(t *testing.T) { t.Fatalf("parse flags: %v", err) } - req, err := buildTTSRequest(cmd, *opts, "hello") + req, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) if err != nil { t.Fatalf("buildTTSRequest error: %v", err) } @@ -85,7 +85,7 @@ func TestBuildTTSRequest_SpeakerBoostSetsJSONKey(t *testing.T) { t.Fatalf("parse flags: %v", err) } - req, err := buildTTSRequest(cmd, *opts, "hello") + req, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) if err != nil { t.Fatalf("buildTTSRequest error: %v", err) } @@ -107,7 +107,7 @@ 
func TestBuildTTSRequest_InvalidNormalize(t *testing.T) { if err := cmd.Flags().Parse([]string{"--normalize", "wat"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) if err == nil || !strings.Contains(err.Error(), "normalize must be one of") { t.Fatalf("expected normalize error, got %v", err) } @@ -118,7 +118,7 @@ func TestBuildTTSRequest_InvalidLang(t *testing.T) { if err := cmd.Flags().Parse([]string{"--lang", "eng"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) if err == nil || !strings.Contains(err.Error(), "lang must be a 2-letter") { t.Fatalf("expected lang error, got %v", err) } @@ -129,7 +129,7 @@ func TestBuildTTSRequest_InvalidSeed(t *testing.T) { if err := cmd.Flags().Parse([]string{"--seed", "4294967296"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) if err == nil || !strings.Contains(err.Error(), "seed must be between") { t.Fatalf("expected seed error, got %v", err) } @@ -140,7 +140,7 @@ func TestBuildTTSRequest_SpeakerBoostConflict(t *testing.T) { if err := cmd.Flags().Parse([]string{"--speaker-boost", "--no-speaker-boost"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) if err == nil || !strings.Contains(err.Error(), "choose only one") { t.Fatalf("expected conflict error, got %v", err) } @@ -152,7 +152,7 @@ func TestBuildTTSRequest_V3StabilityPresetsOnly(t *testing.T) { if err := cmd.Flags().Parse([]string{"--stability", "0.55"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildTTSRequest(cmd, *opts, "hello", 
providerElevenLabs) if err == nil || !strings.Contains(err.Error(), "for eleven_v3, stability must be one of") { t.Fatalf("expected v3 stability preset error, got %v", err) } diff --git a/cmd/voices.go b/cmd/voices.go index e683383..7dd9889 100644 --- a/cmd/voices.go +++ b/cmd/voices.go @@ -12,7 +12,7 @@ import ( "time" "github.com/steipete/sag/internal/audio" - "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/tts" "github.com/spf13/cobra" ) @@ -39,7 +39,7 @@ func init() { Use: "voices", Short: "List available ElevenLabs voices", PreRunE: func(_ *cobra.Command, _ []string) error { - return ensureAPIKey() + return ensureProviderConfigured() }, RunE: func(cmd *cobra.Command, _ []string) error { hasLabelFilters := false @@ -53,12 +53,14 @@ func init() { return errors.New("--try requires --search, --query, --label, or --limit to avoid playing all voices") } - client := elevenlabs.NewClient(cfg.APIKey, cfg.BaseURL) + client, _, err := selectProvider() + if err != nil { + return err + } ctx, cancel := context.WithTimeout(cmd.Context(), 30*time.Second) defer cancel() - var voices []elevenlabs.Voice - var err error + var voices []tts.Voice if opts.search != "" { voices, err = client.SearchVoices(ctx, opts.search, opts.limit) if err != nil { @@ -169,9 +171,9 @@ func init() { rootCmd.AddCommand(cmd) } -func filterVoicesByName(voices []elevenlabs.Voice, search string) []elevenlabs.Voice { +func filterVoicesByName(voices []tts.Voice, search string) []tts.Voice { searchLower := strings.ToLower(search) - filtered := make([]elevenlabs.Voice, 0, len(voices)) + filtered := make([]tts.Voice, 0, len(voices)) for _, v := range voices { if strings.Contains(strings.ToLower(v.Name), searchLower) { filtered = append(filtered, v) @@ -180,7 +182,7 @@ func filterVoicesByName(voices []elevenlabs.Voice, search string) []elevenlabs.V return filtered } -func playVoicePreviewImpl(ctx context.Context, client *elevenlabs.Client, voice elevenlabs.Voice) error { +func 
playVoicePreviewImpl(ctx context.Context, client tts.Provider, voice tts.Voice) error { ctx, cancel := context.WithTimeout(ctx, 45*time.Second) defer cancel() diff --git a/cmd/voices_cache.go b/cmd/voices_cache.go index 46abd91..2e61d11 100644 --- a/cmd/voices_cache.go +++ b/cmd/voices_cache.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/tts" ) const ( @@ -25,8 +25,8 @@ type voiceCache struct { } type cachedVoice struct { - Voice elevenlabs.Voice `json:"voice"` - UpdatedAt time.Time `json:"updated_at"` + Voice tts.Voice `json:"voice"` + UpdatedAt time.Time `json:"updated_at"` } func newVoiceCache() *voiceCache { @@ -80,7 +80,7 @@ func saveVoiceCache(path string, cache *voiceCache) error { return os.WriteFile(path, data, 0o644) } -func hydrateVoices(ctx context.Context, client *elevenlabs.Client, voices []elevenlabs.Voice, cache *voiceCache, ttl time.Duration) ([]elevenlabs.Voice, int) { +func hydrateVoices(ctx context.Context, client tts.Provider, voices []tts.Voice, cache *voiceCache, ttl time.Duration) ([]tts.Voice, int) { if ttl <= 0 { ttl = voiceCacheTTL } @@ -88,7 +88,7 @@ func hydrateVoices(ctx context.Context, client *elevenlabs.Client, voices []elev cache = newVoiceCache() } - results := make([]elevenlabs.Voice, len(voices)) + results := make([]tts.Voice, len(voices)) now := time.Now() var metaCount int @@ -109,7 +109,7 @@ func hydrateVoices(ctx context.Context, client *elevenlabs.Client, voices []elev wg.Add(1) sem <- struct{}{} - go func(index int, voice elevenlabs.Voice) { + go func(index int, voice tts.Voice) { defer wg.Done() defer func() { <-sem }() @@ -135,7 +135,7 @@ func hydrateVoices(ctx context.Context, client *elevenlabs.Client, voices []elev return results, metaCount } -func mergeVoice(base elevenlabs.Voice, details elevenlabs.Voice) elevenlabs.Voice { +func mergeVoice(base tts.Voice, details tts.Voice) tts.Voice { merged := base if details.VoiceID != "" { 
merged.VoiceID = details.VoiceID diff --git a/cmd/voices_query.go b/cmd/voices_query.go index 7ec576b..1794067 100644 --- a/cmd/voices_query.go +++ b/cmd/voices_query.go @@ -6,7 +6,7 @@ import ( "strings" "unicode" - "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/tts" ) type labelFilter struct { @@ -38,11 +38,11 @@ func parseLabelFilters(filters []string) ([]labelFilter, error) { return parsed, nil } -func filterVoicesByLabels(voices []elevenlabs.Voice, filters []labelFilter) []elevenlabs.Voice { +func filterVoicesByLabels(voices []tts.Voice, filters []labelFilter) []tts.Voice { if len(filters) == 0 { return voices } - filtered := make([]elevenlabs.Voice, 0, len(voices)) + filtered := make([]tts.Voice, 0, len(voices)) for _, v := range voices { if matchesAllLabels(v, filters) { filtered = append(filtered, v) @@ -51,7 +51,7 @@ func filterVoicesByLabels(voices []elevenlabs.Voice, filters []labelFilter) []el return filtered } -func matchesAllLabels(voice elevenlabs.Voice, filters []labelFilter) bool { +func matchesAllLabels(voice tts.Voice, filters []labelFilter) bool { if len(filters) == 0 { return true } @@ -85,7 +85,7 @@ func labelValue(labels map[string]string, key string) (string, bool) { return "", false } -func rankVoicesByQuery(voices []elevenlabs.Voice, query string) []elevenlabs.Voice { +func rankVoicesByQuery(voices []tts.Voice, query string) []tts.Voice { query = strings.TrimSpace(query) if query == "" { return voices @@ -104,7 +104,7 @@ func rankVoicesByQuery(voices []elevenlabs.Voice, query string) []elevenlabs.Voi } return scored[i].score > scored[j].score }) - ranked := make([]elevenlabs.Voice, 0, len(scored)) + ranked := make([]tts.Voice, 0, len(scored)) for _, s := range scored { ranked = append(ranked, s.voice) } @@ -112,11 +112,11 @@ func rankVoicesByQuery(voices []elevenlabs.Voice, query string) []elevenlabs.Voi } type scoredVoice struct { - voice elevenlabs.Voice + voice tts.Voice score int } -func 
scoreVoice(voice elevenlabs.Voice, query string, tokens []string) int { +func scoreVoice(voice tts.Voice, query string, tokens []string) int { name := strings.ToLower(voice.Name) desc := strings.ToLower(voice.Description) labels := strings.ToLower(flattenLabels(voice.Labels)) diff --git a/docs/providers.md b/docs/providers.md new file mode 100644 index 0000000..2ddae7a --- /dev/null +++ b/docs/providers.md @@ -0,0 +1,67 @@ +# Providers + +`sag` speaks to two text-to-speech backends behind one consistent CLI. A small +provider abstraction (`internal/tts.Provider`) lets the command layer and the +audio player stay backend-agnostic; each provider translates the shared request +to and from its own wire format. + +## Selecting a provider + +The provider is auto-detected from whichever API key is present — there is no +`--provider` flag. + +| Keys set | Active provider | +|---|---| +| `ELEVENLABS_API_KEY` (or `--api-key`/file) only | ElevenLabs | +| `SIXTYDB_API_KEY` (or `SIXTYDB_API_KEY_FILE`) only | 60db | +| both | ElevenLabs (note printed; unset `ELEVENLABS_API_KEY` to use 60db) | +| neither | error | + +Override the host for the active provider with `--base-url`. + +## What each provider implements + +| Capability | ElevenLabs | 60db | +|---|---|---| +| Auth | `xi-api-key: ` | `Authorization: Bearer ` | +| Default host | `https://api.elevenlabs.io` | `https://api.60db.ai` | +| List voices | `GET /v1/voices` | `GET /myvoices` | +| Full synthesis | `POST /v1/text-to-speech/{id}` (raw audio) | `POST /tts-synthesize` (base64 in JSON) | +| Streaming | `POST /v1/text-to-speech/{id}/stream` (raw audio) | `POST /tts-stream` (NDJSON, base64 chunks) | + +For 60db, `sag` decodes the base64/NDJSON envelope internally, so streaming and +file output behave the same as ElevenLabs. 60db's WebSocket API is not used. + +## Flag behavior + +Flags are written in ElevenLabs terms and translated per provider so the same +command works on both. 
+ +| Flag | ElevenLabs | 60db | +|---|---|---| +| `--speed` / `--rate` | speed multiplier `0.5–2.0` | passthrough (same range) | +| `--stability` | `0..1` | scaled to `0..100` | +| `--similarity` / `--similarity-boost` | `0..1` | scaled to `0..100` | +| `--format` / `-o` extension | `mp3_44100_128`, `pcm_44100`, … | mapped to `mp3` / `wav` / `ogg` / `flac` | +| `--model-id` | model id (e.g. `eleven_v3`) | ignored (model is tied to the voice) | +| `--style` | style exaggeration | ignored | +| `--speaker-boost` / `--no-speaker-boost` | speaker boost | ignored | +| `--seed` | best-effort determinism | ignored | +| `--normalize` | text normalization | ignored | +| `--lang` | language code | ignored | +| `--latency-tier` | streaming latency tier | ignored | + +When you set an ignored flag while 60db is active, `sag` prints a one-line note +to stderr rather than failing. + +## Notes and limits + +- **Playback format:** speaker playback decodes MP3, so when playing through + speakers `sag` requests MP3 from 60db regardless of `--format`. Use + `--no-play -o file.<ext>` to save other formats. +- **Voice previews:** `sag voices --try` is not available for 60db (its + `/myvoices` response has no preview URL); the affected voice is skipped with a + message. +- **Voice metadata:** 60db's per-voice `model` (`60db Fast` / `60db Quality`) is + exposed as a `model` label, so `--label model=...` and `--query` can match it. +- **Text limit:** 60db caps synthesis at 5,000 characters per request. diff --git a/internal/elevenlabs/client.go b/internal/elevenlabs/client.go index 0546ca2..e584d56 100644 --- a/internal/elevenlabs/client.go +++ b/internal/elevenlabs/client.go @@ -10,6 +10,8 @@ import ( "net/url" "path" "strings" + + "github.com/steipete/sag/internal/tts" ) // Client talks to the ElevenLabs HTTP API. @@ -19,6 +21,9 @@ type Client struct { httpClient *http.Client } +// Ensure the ElevenLabs client satisfies the shared provider contract. 
+var _ tts.Provider = (*Client)(nil) + // NewClient returns a Client configured with the given API key and base URL. func NewClient(apiKey, baseURL string) *Client { if baseURL == "" { @@ -31,15 +36,17 @@ func NewClient(apiKey, baseURL string) *Client { } } -// Voice represents a voice entry returned by ElevenLabs. -type Voice struct { - VoiceID string `json:"voice_id"` - Name string `json:"name"` - Category string `json:"category"` - Description string `json:"description"` - Labels map[string]string `json:"labels,omitempty"` - PreviewURL string `json:"preview_url"` -} +// Voice, VoiceSettings, and TTSRequest are re-exported from the shared tts +// package so existing callers (and tests) can keep using elevenlabs.Voice etc. +// while every provider speaks the same types. +type ( + // Voice represents a voice entry returned by ElevenLabs. + Voice = tts.Voice + // VoiceSettings tunes synthesis parameters for a request. + VoiceSettings = tts.VoiceSettings + // TTSRequest configures a text-to-speech request payload. + TTSRequest = tts.TTSRequest +) type listVoicesResponse struct { Voices []Voice `json:"voices"` @@ -183,26 +190,6 @@ func (c *Client) GetVoice(ctx context.Context, voiceID string) (Voice, error) { return voice, nil } -// TTSRequest configures a text-to-speech request payload. -type TTSRequest struct { - Text string `json:"text"` - ModelID string `json:"model_id,omitempty"` - VoiceSettings *VoiceSettings `json:"voice_settings,omitempty"` - OutputFormat string `json:"output_format,omitempty"` - Seed *uint32 `json:"seed,omitempty"` - ApplyTextNormalization string `json:"apply_text_normalization,omitempty"` - LanguageCode string `json:"language_code,omitempty"` -} - -// VoiceSettings tunes synthesis parameters for a request. 
-type VoiceSettings struct { - Stability *float64 `json:"stability,omitempty"` - SimilarityBoost *float64 `json:"similarity_boost,omitempty"` - Style *float64 `json:"style,omitempty"` - UseSpeakerBoost *bool `json:"use_speaker_boost,omitempty"` - Speed *float64 `json:"speed,omitempty"` -} - // StreamTTS requests streaming audio for text-to-speech. func (c *Client) StreamTTS(ctx context.Context, voiceID string, payload TTSRequest, latency int) (io.ReadCloser, error) { u, err := url.Parse(c.baseURL) diff --git a/internal/sixtydb/client.go b/internal/sixtydb/client.go new file mode 100644 index 0000000..74b1a80 --- /dev/null +++ b/internal/sixtydb/client.go @@ -0,0 +1,397 @@ +// Package sixtydb provides a small client for the 60db (api.60db.ai) TTS API. +// +// Unlike ElevenLabs, 60db never returns raw audio: the synthesize endpoint +// wraps it as base64 inside a JSON object, and the stream endpoint emits +// newline-delimited JSON frames whose audio is base64 too. This client decodes +// both so StreamTTS/ConvertTTS hand back plain audio bytes, exactly like the +// ElevenLabs client — keeping the command and audio layers provider-agnostic. +package sixtydb + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strings" + + "github.com/steipete/sag/internal/tts" +) + +// DefaultBaseURL is the public 60db API host. +const DefaultBaseURL = "https://api.60db.ai" + +// Client talks to the 60db HTTP API. +type Client struct { + baseURL string + apiKey string + httpClient *http.Client +} + +// Ensure the 60db client satisfies the shared provider contract. +var _ tts.Provider = (*Client)(nil) + +// NewClient returns a Client configured with the given API key and base URL. 
+func NewClient(apiKey, baseURL string) *Client { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return &Client{ + baseURL: baseURL, + apiKey: apiKey, + httpClient: &http.Client{}, + } +} + +func (c *Client) newRequest(ctx context.Context, method, endpoint string, body io.Reader) (*http.Request, error) { + u, err := url.Parse(c.baseURL) + if err != nil { + return nil, err + } + u.Path = path.Join(u.Path, endpoint) + + req, err := http.NewRequestWithContext(ctx, method, u.String(), body) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + return req, nil +} + +// --- Voices --------------------------------------------------------------- + +type voiceEntry struct { + VoiceID string `json:"voice_id"` + Name string `json:"name"` + Category string `json:"category"` + Model string `json:"model"` + Labels map[string]string `json:"labels"` + Description *string `json:"description"` +} + +type myVoicesResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + Data []voiceEntry `json:"data"` +} + +func (e voiceEntry) toVoice() tts.Voice { + labels := make(map[string]string, len(e.Labels)+1) + for k, v := range e.Labels { + labels[k] = v + } + // Fold the model name into labels so `--label model=...` and query ranking + // can match on it, mirroring how ElevenLabs exposes metadata via labels. + if e.Model != "" { + if _, ok := labels["model"]; !ok { + labels["model"] = e.Model + } + } + desc := "" + if e.Description != nil { + desc = *e.Description + } + return tts.Voice{ + VoiceID: e.VoiceID, + Name: e.Name, + Category: e.Category, + Description: desc, + Labels: labels, + // 60db /myvoices exposes no preview URL, so previews are unavailable. + PreviewURL: "", + } +} + +// ListVoices fetches the caller's available 60db voices. 
+func (c *Client) ListVoices(ctx context.Context) ([]tts.Voice, error) { + req, err := c.newRequest(ctx, http.MethodGet, "/myvoices", nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("list voices failed: %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + + var body myVoicesResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return nil, err + } + voices := make([]tts.Voice, 0, len(body.Data)) + for _, e := range body.Data { + voices = append(voices, e.toVoice()) + } + return voices, nil +} + +// SearchVoices filters the voice list by a case-insensitive name substring. +// 60db has no server-side search, so this is done client-side. +func (c *Client) SearchVoices(ctx context.Context, search string, limit int) ([]tts.Voice, error) { + voices, err := c.ListVoices(ctx) + if err != nil { + return nil, err + } + search = strings.TrimSpace(search) + if search != "" { + searchLower := strings.ToLower(search) + filtered := make([]tts.Voice, 0, len(voices)) + for _, v := range voices { + if strings.Contains(strings.ToLower(v.Name), searchLower) { + filtered = append(filtered, v) + } + } + voices = filtered + } + if limit > 0 && len(voices) > limit { + voices = voices[:limit] + } + return voices, nil +} + +// GetVoice returns metadata for a specific voice. 60db has no single-voice +// endpoint, so it resolves against the full list. 
+func (c *Client) GetVoice(ctx context.Context, voiceID string) (tts.Voice, error) { + voices, err := c.ListVoices(ctx) + if err != nil { + return tts.Voice{}, err + } + for _, v := range voices { + if v.VoiceID == voiceID { + return v, nil + } + } + return tts.Voice{}, fmt.Errorf("voice %q not found", voiceID) +} + +// --- Text to speech ------------------------------------------------------- + +type synthesizeRequest struct { + Text string `json:"text"` + VoiceID string `json:"voice_id,omitempty"` + Speed *float64 `json:"speed,omitempty"` + Stability *float64 `json:"stability,omitempty"` + Similarity *float64 `json:"similarity,omitempty"` + OutputFormat string `json:"output_format,omitempty"` +} + +// buildBody translates a provider-neutral request into 60db's payload. +// includeFormat is false for the streaming endpoint, whose spec omits +// output_format. Fields with no 60db equivalent (model, style, speaker boost, +// seed, normalization, language) are intentionally dropped. +func buildBody(voiceID string, req tts.TTSRequest, includeFormat bool) synthesizeRequest { + body := synthesizeRequest{ + Text: req.Text, + VoiceID: voiceID, + } + if vs := req.VoiceSettings; vs != nil { + if vs.Speed != nil { + body.Speed = vs.Speed // both APIs use 0.5..2.0 + } + if vs.Stability != nil { + body.Stability = scaleToHundred(*vs.Stability) + } + if vs.SimilarityBoost != nil { + body.Similarity = scaleToHundred(*vs.SimilarityBoost) + } + } + if includeFormat { + body.OutputFormat = toSixtyDBFormat(req.OutputFormat) + } + return body +} + +// scaleToHundred converts a 0..1 knob to 60db's 0..100 scale. +func scaleToHundred(v float64) *float64 { + scaled := v * 100 + return &scaled +} + +// toSixtyDBFormat maps ElevenLabs-style format strings to 60db's simple set +// (mp3|wav|ogg|flac). Empty input yields empty (provider default). 
+func toSixtyDBFormat(format string) string { + format = strings.ToLower(strings.TrimSpace(format)) + switch { + case format == "": + return "" + case strings.HasPrefix(format, "mp3"): + return "mp3" + case strings.HasPrefix(format, "pcm"), strings.HasPrefix(format, "wav"): + return "wav" + case strings.HasPrefix(format, "opus"), strings.HasPrefix(format, "ogg"): + return "ogg" + case strings.HasPrefix(format, "flac"): + return "flac" + default: + return format + } +} + +type synthesizeResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + AudioBase64 string `json:"audio_base64"` +} + +// ConvertTTS downloads the full audio and returns decoded bytes. +func (c *Client) ConvertTTS(ctx context.Context, voiceID string, payload tts.TTSRequest) ([]byte, error) { + bodyBytes, err := json.Marshal(buildBody(voiceID, payload, true)) + if err != nil { + return nil, err + } + req, err := c.newRequest(ctx, http.MethodPost, "/tts-synthesize", bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("convert TTS failed: %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + + var body synthesizeResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return nil, err + } + if !body.Success && body.Message != "" { + return nil, fmt.Errorf("convert TTS failed: %s", body.Message) + } + data, err := base64.StdEncoding.DecodeString(body.AudioBase64) + if err != nil { + return nil, fmt.Errorf("decode audio_base64: %w", err) + } + return data, nil +} + +// StreamTTS requests streaming audio and returns a reader that yields decoded +// audio bytes. The latency argument is ignored (60db's stream has no tier). 
+func (c *Client) StreamTTS(ctx context.Context, voiceID string, payload tts.TTSRequest, _ int) (io.ReadCloser, error) { + bodyBytes, err := json.Marshal(buildBody(voiceID, payload, false)) + if err != nil { + return nil, err + } + req, err := c.newRequest(ctx, http.MethodPost, "/tts-stream", bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/x-ndjson") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + defer func() { _ = resp.Body.Close() }() + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("stream TTS failed: %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + return newNDJSONAudioReader(resp.Body), nil +} + +// --- NDJSON streaming decoder --------------------------------------------- + +type streamFrame struct { + Type string `json:"type"` + Result struct { + AudioContent string `json:"audioContent"` + } `json:"result"` + Message string `json:"message"` +} + +// ndjsonAudioReader unwraps 60db's newline-delimited JSON stream into a plain +// audio byte stream. Each "chunk" frame's base64 audio is decoded and served; +// "complete" ends the stream; "error" surfaces the message. +type ndjsonAudioReader struct { + src io.ReadCloser + reader *bufio.Reader + pending []byte + err error +} + +func newNDJSONAudioReader(src io.ReadCloser) *ndjsonAudioReader { + return &ndjsonAudioReader{ + src: src, + reader: bufio.NewReader(src), + } +} + +func (r *ndjsonAudioReader) Read(p []byte) (int, error) { + for len(r.pending) == 0 { + if r.err != nil { + return 0, r.err + } + if err := r.fill(); err != nil { + r.err = err + if len(r.pending) == 0 { + return 0, err + } + } + } + n := copy(p, r.pending) + r.pending = r.pending[n:] + return n, nil +} + +// fill reads and decodes the next non-empty frame into r.pending. 
+func (r *ndjsonAudioReader) fill() error { + for { + line, err := r.reader.ReadBytes('\n') + trimmed := bytes.TrimSpace(line) + if len(trimmed) > 0 { + var frame streamFrame + if jerr := json.Unmarshal(trimmed, &frame); jerr != nil { + return fmt.Errorf("decode stream frame: %w", jerr) + } + switch frame.Type { + case "chunk": + audio, derr := base64.StdEncoding.DecodeString(frame.Result.AudioContent) + if derr != nil { + return fmt.Errorf("decode audio chunk: %w", derr) + } + if len(audio) > 0 { + r.pending = audio + return nil + } + case "complete": + return io.EOF + case "error": + if frame.Message != "" { + return fmt.Errorf("stream error: %s", frame.Message) + } + return fmt.Errorf("stream error") + } + // Unknown frame types are ignored; keep reading. + } + if err != nil { + if err == io.EOF { + return io.EOF + } + return err + } + } +} + +func (r *ndjsonAudioReader) Close() error { + return r.src.Close() +} diff --git a/internal/sixtydb/client_test.go b/internal/sixtydb/client_test.go new file mode 100644 index 0000000..cebcd8a --- /dev/null +++ b/internal/sixtydb/client_test.go @@ -0,0 +1,232 @@ +package sixtydb + +import ( + "context" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/steipete/sag/internal/tts" +) + +func TestNewClientDefaultsBase(t *testing.T) { + c := NewClient("key", "") + if c.baseURL != DefaultBaseURL { + t.Fatalf("unexpected baseURL: %s", c.baseURL) + } +} + +func TestListVoicesUnwrapsData(t *testing.T) { + desc := "warm narrator" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/myvoices" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer key" { + t.Fatalf("unexpected auth header: %q", got) + } + _, _ = w.Write([]byte(`{"success":true,"message":"ok","data":[ + {"voice_id":"v1","name":"Aria","category":"professional","model":"60db 
Quality","labels":{"gender":"female","accent":"American"},"description":"` + desc + `"}, + {"voice_id":"v2","name":"Ravi","category":"cloned","model":"60db Fast","labels":{"gender":"male"},"description":null} + ]}`)) + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + voices, err := c.ListVoices(context.Background()) + if err != nil { + t.Fatalf("ListVoices error: %v", err) + } + if len(voices) != 2 { + t.Fatalf("expected 2 voices, got %d", len(voices)) + } + if voices[0].VoiceID != "v1" || voices[0].Name != "Aria" || voices[0].Category != "professional" { + t.Fatalf("unexpected voice[0]: %+v", voices[0]) + } + if voices[0].Description != desc { + t.Fatalf("expected description %q, got %q", desc, voices[0].Description) + } + if voices[0].Labels["model"] != "60db Quality" { + t.Fatalf("expected model folded into labels, got %+v", voices[0].Labels) + } + if voices[1].Description != "" { + t.Fatalf("expected empty description for null, got %q", voices[1].Description) + } +} + +func TestSearchVoicesFiltersAndLimits(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{"success":true,"data":[ + {"voice_id":"v1","name":"Roger"}, + {"voice_id":"v2","name":"Rogue"}, + {"voice_id":"v3","name":"Sarah"} + ]}`)) + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + voices, err := c.SearchVoices(context.Background(), "rog", 1) + if err != nil { + t.Fatalf("SearchVoices error: %v", err) + } + if len(voices) != 1 { + t.Fatalf("expected 1 voice after limit, got %d", len(voices)) + } + if voices[0].VoiceID != "v1" { + t.Fatalf("expected v1, got %s", voices[0].VoiceID) + } +} + +func TestConvertTTSDecodesBase64AndTranslatesParams(t *testing.T) { + want := []byte("decoded-audio-bytes") + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/tts-synthesize" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := 
r.Header.Get("Authorization"); got != "Bearer key" { + t.Fatalf("unexpected auth header: %q", got) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body["text"] != "hi" { + t.Fatalf("expected text hi, got %v", body["text"]) + } + if body["voice_id"] != "v1" { + t.Fatalf("expected voice_id v1, got %v", body["voice_id"]) + } + // 0.5 stability -> 50, 0.8 similarity -> 80 (0..1 -> 0..100) + if body["stability"] != float64(50) { + t.Fatalf("expected stability 50, got %v", body["stability"]) + } + if body["similarity"] != float64(80) { + t.Fatalf("expected similarity 80, got %v", body["similarity"]) + } + if body["speed"] != 1.1 { + t.Fatalf("expected speed 1.1, got %v", body["speed"]) + } + // mp3_44100_128 -> mp3 + if body["output_format"] != "mp3" { + t.Fatalf("expected output_format mp3, got %v", body["output_format"]) + } + // ElevenLabs-only fields must not appear. + for _, k := range []string{"model_id", "style", "use_speaker_boost", "seed", "language_code"} { + if _, ok := body[k]; ok { + t.Fatalf("expected %q to be absent from 60db body", k) + } + } + resp := map[string]any{"success": true, "audio_base64": base64.StdEncoding.EncodeToString(want)} + _ = json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + stability := 0.5 + similarity := 0.8 + speed := 1.1 + c := NewClient("key", srv.URL) + data, err := c.ConvertTTS(context.Background(), "v1", tts.TTSRequest{ + Text: "hi", + ModelID: "eleven_v3", + OutputFormat: "mp3_44100_128", + VoiceSettings: &tts.VoiceSettings{ + Stability: &stability, + SimilarityBoost: &similarity, + Speed: &speed, + }, + }) + if err != nil { + t.Fatalf("ConvertTTS error: %v", err) + } + if string(data) != string(want) { + t.Fatalf("unexpected decoded audio: %q", string(data)) + } +} + +func TestStreamTTSDecodesNDJSON(t *testing.T) { + chunk1 := base64.StdEncoding.EncodeToString([]byte("hello-")) + chunk2 := 
base64.StdEncoding.EncodeToString([]byte("world")) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/tts-stream" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + // streaming body omits output_format + if _, ok := body["output_format"]; ok { + t.Fatalf("expected output_format omitted from stream body") + } + _, _ = io.WriteString(w, `{"type":"chunk","result":{"audioContent":"`+chunk1+`"}}`+"\n") + _, _ = io.WriteString(w, `{"type":"chunk","result":{"audioContent":"`+chunk2+`"}}`+"\n") + _, _ = io.WriteString(w, `{"type":"complete"}`+"\n") + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + rc, err := c.StreamTTS(context.Background(), "v1", tts.TTSRequest{Text: "hi"}, 0) + if err != nil { + t.Fatalf("StreamTTS error: %v", err) + } + defer func() { _ = rc.Close() }() + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("read stream: %v", err) + } + if string(got) != "hello-world" { + t.Fatalf("unexpected decoded stream: %q", string(got)) + } +} + +func TestStreamTTSSurfacesErrorFrame(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"type":"error","message":"voice not found"}`+"\n") + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + rc, err := c.StreamTTS(context.Background(), "v1", tts.TTSRequest{Text: "hi"}, 0) + if err != nil { + t.Fatalf("StreamTTS error: %v", err) + } + defer func() { _ = rc.Close() }() + _, err = io.ReadAll(rc) + if err == nil || !strings.Contains(err.Error(), "voice not found") { + t.Fatalf("expected stream error surfaced, got %v", err) + } +} + +func TestStreamTTSHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "bad", http.StatusBadRequest) + })) + 
defer srv.Close() + + c := NewClient("key", srv.URL) + _, err := c.StreamTTS(context.Background(), "v1", tts.TTSRequest{Text: "hi"}, 0) + if err == nil || !strings.Contains(err.Error(), "400") { + t.Fatalf("expected 400 error, got %v", err) + } +} + +func TestToSixtyDBFormat(t *testing.T) { + cases := map[string]string{ + "mp3_44100_128": "mp3", + "pcm_44100": "wav", + "wav": "wav", + "opus_48000_64": "ogg", + "ogg": "ogg", + "flac": "flac", + "": "", + } + for in, want := range cases { + if got := toSixtyDBFormat(in); got != want { + t.Fatalf("toSixtyDBFormat(%q) = %q, want %q", in, got, want) + } + } +} diff --git a/internal/tts/tts.go b/internal/tts/tts.go new file mode 100644 index 0000000..a5e4cb4 --- /dev/null +++ b/internal/tts/tts.go @@ -0,0 +1,58 @@ +// Package tts defines provider-neutral types and the Provider interface shared +// by every text-to-speech backend (ElevenLabs, 60db, ...). +// +// Keeping the types here lets each provider implementation translate its own +// wire format to and from a single shared shape, so the command layer and the +// audio player never need to know which backend produced the audio. +package tts + +import ( + "context" + "io" +) + +// Voice represents a single voice entry, normalized across providers. +type Voice struct { + VoiceID string `json:"voice_id"` + Name string `json:"name"` + Category string `json:"category"` + Description string `json:"description"` + Labels map[string]string `json:"labels,omitempty"` + PreviewURL string `json:"preview_url"` +} + +// VoiceSettings tunes synthesis parameters for a request. All fields are +// pointers so unset knobs are omitted from the wire payload and the provider's +// own defaults apply. Stability/SimilarityBoost/Style use the 0..1 scale; each +// provider translates to its native range. 
+type VoiceSettings struct { + Stability *float64 `json:"stability,omitempty"` + SimilarityBoost *float64 `json:"similarity_boost,omitempty"` + Style *float64 `json:"style,omitempty"` + UseSpeakerBoost *bool `json:"use_speaker_boost,omitempty"` + Speed *float64 `json:"speed,omitempty"` +} + +// TTSRequest configures a text-to-speech request. Some fields are honored only +// by certain providers (e.g. ModelID/Seed/LanguageCode are ElevenLabs-specific +// and ignored by 60db); the provider implementation decides what to send. +type TTSRequest struct { + Text string `json:"text"` + ModelID string `json:"model_id,omitempty"` + VoiceSettings *VoiceSettings `json:"voice_settings,omitempty"` + OutputFormat string `json:"output_format,omitempty"` + Seed *uint32 `json:"seed,omitempty"` + ApplyTextNormalization string `json:"apply_text_normalization,omitempty"` + LanguageCode string `json:"language_code,omitempty"` +} + +// Provider is the contract every TTS backend implements. StreamTTS and +// ConvertTTS must return raw, ready-to-play audio bytes (decoded/unwrapped from +// any provider-specific envelope) so the audio layer stays provider-agnostic. 
+type Provider interface { + ListVoices(ctx context.Context) ([]Voice, error) + SearchVoices(ctx context.Context, search string, limit int) ([]Voice, error) + GetVoice(ctx context.Context, voiceID string) (Voice, error) + StreamTTS(ctx context.Context, voiceID string, req TTSRequest, latency int) (io.ReadCloser, error) + ConvertTTS(ctx context.Context, voiceID string, req TTSRequest) ([]byte, error) +} From d64cbe81ed3c0258e3e4872c3db389bf925019f1 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Thu, 11 Jun 2026 00:24:10 +0100 Subject: [PATCH 2/2] feat: tighten 60db integration Co-authored-by: manishEMS47 --- CHANGELOG.md | 2 + README.md | 25 +- cmd/prompting_test.go | 2 + cmd/provider.go | 85 +++-- cmd/provider_test.go | 26 +- cmd/root.go | 4 +- cmd/speak.go | 185 +++++++--- cmd/speak_integration_test.go | 63 ++++ cmd/speak_request_test.go | 175 +++++++--- cmd/speak_test.go | 28 +- cmd/test_helpers_test.go | 33 ++ cmd/voices.go | 5 +- cmd/voices_cache.go | 9 +- cmd/voices_cache_test.go | 36 ++ cmd/voices_test.go | 4 + docs/configuration.md | 32 +- docs/index.md | 9 +- docs/providers.md | 113 +++--- docs/spec.md | 22 +- internal/elevenlabs/client.go | 32 +- internal/sixtydb/client.go | 589 +++++++++++++++++++++++--------- internal/sixtydb/client_test.go | 419 +++++++++++++++++------ internal/tts/tts.go | 45 +-- 23 files changed, 1378 insertions(+), 565 deletions(-) create mode 100644 cmd/test_helpers_test.go create mode 100644 cmd/voices_cache_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index ee85293..92f377f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Changelog ## Unreleased +### Added +- Documented 60db provider support via `/default-voices`, `/myvoices`, `/tts-stream`, and `/tts-synthesize`, with strict provider selection and response validation. (#20, thanks @manishEMS47) ### Changed - Release archives now include target-specific macOS and Linux assets for Homebrew and aqua installers. 
diff --git a/README.md b/README.md index 04a9061..b0b79b6 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# sag 🗣️ — “Mac-style speech with ElevenLabs” +# sag 🗣️ — “Mac-style speech with ElevenLabs or 60db” One-liner TTS that works like `say`: stream to speakers by default, list voices, or save audio files. @@ -25,23 +25,27 @@ sudo apt install build-essential pkg-config libasound2-dev ## Configuration -`sag` supports two TTS providers and auto-selects one from whichever API key is set: +`sag` supports two TTS providers and auto-selects one from your configured credentials: - **ElevenLabs** — `ELEVENLABS_API_KEY` (or `--api-key`, or `--api-key-file` / `ELEVENLABS_API_KEY_FILE` / `SAG_API_KEY_FILE`) - **60db** (`api.60db.ai`) — `SIXTYDB_API_KEY` (or `SIXTYDB_API_KEY_FILE`) Selection rules: - Only one key set → that provider is used. -- Both keys set → ElevenLabs is used (unset `ELEVENLABS_API_KEY` to use 60db); a note is printed. +- Both keys set → error; unset one provider key and retry. - Neither set → error. -Optional defaults: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID`. Override a provider's host with `--base-url`. +Optional ElevenLabs defaults: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID`. Override the active provider host with `--base-url`. -The same flags work for both providers; `sag` translates them to each API. A few flags are -ElevenLabs-only and are accepted-but-ignored on 60db (a note is printed): `--model-id`, -`--style`, `--speaker-boost`/`--no-speaker-boost`, `--seed`, `--normalize`, `--lang`, -`--latency-tier`. The `--stability`/`--similarity` `0..1` values are scaled to 60db's `0..100` -range automatically. See [docs/providers.md](docs/providers.md) for details. 
+60db support is intentionally narrow and follows the documented HTTP API: +- voice discovery merges `GET /default-voices` and `GET /myvoices` +- default MP3 streaming uses `POST /tts-stream` +- explicit file formats and full downloads use `POST /tts-synthesize` +- the `success` envelope is validated even on HTTP 200 responses + +Shared speak flags that work on both providers include `--voice`, `--speed` / `--rate`, `--stability`, `--similarity`, `--format`, `--stream`, `--play`, `--output`, `--timeout`, and `--metrics`. + +ElevenLabs-only speak flags fail fast on 60db: `--model-id`, `--style`, `--speaker-boost` / `--no-speaker-boost`, `--seed`, `--normalize`, `--lang`, and `--latency-tier`. `--stability` / `--similarity` still use the CLI's `0..1` range and are scaled to 60db's documented `0..100` API values. See [docs/providers.md](docs/providers.md) for provider-specific details. ## Usage @@ -76,6 +80,7 @@ sag speak -v Roger --stream --latency-tier 3 "Faster start" sag speak -v Roger --speed 1.2 "Talk a bit faster" sag speak -v Roger --model-id eleven_multilingual_v2 "Use stable v2 baseline" sag speak -v Roger --output out.wav --format pcm_44100 "Wave output" +SIXTYDB_API_KEY=... sag speak -v Aria --output out.wav --no-play "60db WAV output" ``` Key flags (subset): @@ -164,6 +169,6 @@ ffprobe -v quiet -show_entries format=duration -of csv=p=0 long.mp3 - Build: `go build ./cmd/sag` ## Limitations -- ElevenLabs account and API key required. +- One provider API key is required. - Voice defaults to first available if not provided. - Non-mac platforms: playback still works via `go-mp3` + `oto`, but device selection flags are no-ops. 
diff --git a/cmd/prompting_test.go b/cmd/prompting_test.go index 756f091..5ac5012 100644 --- a/cmd/prompting_test.go +++ b/cmd/prompting_test.go @@ -6,6 +6,8 @@ import ( ) func TestPromptingCommandOutputsGuide(t *testing.T) { + resetRootCommandState() + restore, read := captureStdout(t) defer restore() diff --git a/cmd/provider.go b/cmd/provider.go index 4ddab6a..4fcf500 100644 --- a/cmd/provider.go +++ b/cmd/provider.go @@ -15,7 +15,15 @@ const ( providerSixtyDB = "60db" ) -// resolveSixtyDBKey resolves the 60db API key from its env vars. +type activeProvider struct { + name string + voices tts.VoiceCatalog + + elevenlabs *elevenlabs.Client + sixtydb *sixtydb.Client +} + +// resolveSixtyDBKey resolves the 60db API key from its dedicated env vars. // Order: SIXTYDB_API_KEY, then SIXTYDB_API_KEY_FILE. func resolveSixtyDBKey() (string, error) { if key := strings.TrimSpace(os.Getenv("SIXTYDB_API_KEY")); key != "" { @@ -35,69 +43,60 @@ func resolveSixtyDBKey() (string, error) { return "", nil } -// ensureProviderConfigured verifies that at least one provider's key is set. -// Used by PreRunE so 60db-only users aren't rejected for lacking an -// ElevenLabs key. func ensureProviderConfigured() error { - elKey, err := resolveElevenLabsKey() - if err != nil { - return err - } - sdKey, err := resolveSixtyDBKey() - if err != nil { - return err - } - if elKey == "" && sdKey == "" { - return fmt.Errorf("missing API key (set ELEVENLABS_API_KEY or SIXTYDB_API_KEY; or --api-key / --api-key-file)") - } - return nil + _, err := selectProvider() + return err } -// selectProvider auto-detects the active provider from whichever API key is -// present. If both are set, ElevenLabs wins (preserving prior default) and a -// note is printed. The chosen client is built with cfg.BaseURL, which each -// client treats as a per-provider override (empty => provider default host). 
-func selectProvider() (tts.Provider, string, error) { +func selectProvider() (activeProvider, error) { elKey, err := resolveElevenLabsKey() if err != nil { - return nil, "", err + return activeProvider{}, err } sdKey, err := resolveSixtyDBKey() if err != nil { - return nil, "", err + return activeProvider{}, err } switch { case elKey != "" && sdKey != "": - fmt.Fprintln(os.Stderr, "note: both ElevenLabs and 60db API keys set; using ElevenLabs (unset ELEVENLABS_API_KEY to use 60db)") - return elevenlabs.NewClient(elKey, cfg.BaseURL), providerElevenLabs, nil + return activeProvider{}, fmt.Errorf("ambiguous provider configuration: both ElevenLabs and 60db keys are set; unset one provider key and retry") case elKey != "": - return elevenlabs.NewClient(elKey, cfg.BaseURL), providerElevenLabs, nil + client := elevenlabs.NewClient(elKey, cfg.BaseURL) + return activeProvider{ + name: providerElevenLabs, + voices: client, + elevenlabs: client, + }, nil case sdKey != "": - return sixtydb.NewClient(sdKey, cfg.BaseURL), providerSixtyDB, nil + client := sixtydb.NewClient(sdKey, cfg.BaseURL) + return activeProvider{ + name: providerSixtyDB, + voices: client, + sixtydb: client, + }, nil default: - return nil, "", fmt.Errorf("missing API key (set ELEVENLABS_API_KEY or SIXTYDB_API_KEY; or --api-key / --api-key-file)") + return activeProvider{}, fmt.Errorf("missing API key (set ELEVENLABS_API_KEY or SIXTYDB_API_KEY)") } } -// sixtyDBOnlyFlags lists flags that ElevenLabs honors but 60db has no -// equivalent for. When the active provider is 60db and the user set one, we -// note that it is ignored rather than failing. 
-var sixtyDBIgnoredFlags = []string{ - "model-id", "style", "speaker-boost", "no-speaker-boost", - "seed", "normalize", "lang", "latency-tier", +var sixtyDBUnsupportedFlags = []string{ + "model-id", + "style", + "speaker-boost", + "no-speaker-boost", + "seed", + "normalize", + "lang", + "latency-tier", } -// noteUnsupportedSixtyDBFlags prints a single stderr note if the user set any -// flag that 60db ignores. -func noteUnsupportedSixtyDBFlags(changed func(string) bool) { - var ignored []string - for _, name := range sixtyDBIgnoredFlags { +func changedSixtyDBUnsupportedFlags(changed func(string) bool) []string { + var unsupported []string + for _, name := range sixtyDBUnsupportedFlags { if changed(name) { - ignored = append(ignored, "--"+name) + unsupported = append(unsupported, "--"+name) } } - if len(ignored) > 0 { - fmt.Fprintf(os.Stderr, "note: 60db ignores %s\n", strings.Join(ignored, ", ")) - } + return unsupported } diff --git a/cmd/provider_test.go b/cmd/provider_test.go index 5b914aa..5dd504a 100644 --- a/cmd/provider_test.go +++ b/cmd/provider_test.go @@ -1,6 +1,7 @@ package cmd import ( + "strings" "testing" ) @@ -25,12 +26,12 @@ func TestSelectProvider_ElevenLabsOnly(t *testing.T) { resetProviderEnv(t) t.Setenv("ELEVENLABS_API_KEY", "el-key") - _, name, err := selectProvider() + provider, err := selectProvider() if err != nil { t.Fatalf("selectProvider error: %v", err) } - if name != providerElevenLabs { - t.Fatalf("expected %q, got %q", providerElevenLabs, name) + if provider.name != providerElevenLabs || provider.elevenlabs == nil || provider.voices == nil { + t.Fatalf("unexpected provider: %+v", provider) } } @@ -38,33 +39,30 @@ func TestSelectProvider_SixtyDBOnly(t *testing.T) { resetProviderEnv(t) t.Setenv("SIXTYDB_API_KEY", "sd-key") - _, name, err := selectProvider() + provider, err := selectProvider() if err != nil { t.Fatalf("selectProvider error: %v", err) } - if name != providerSixtyDB { - t.Fatalf("expected %q, got %q", providerSixtyDB, 
name) + if provider.name != providerSixtyDB || provider.sixtydb == nil || provider.voices == nil { + t.Fatalf("unexpected provider: %+v", provider) } } -func TestSelectProvider_BothPrefersElevenLabs(t *testing.T) { +func TestSelectProvider_BothKeysError(t *testing.T) { resetProviderEnv(t) t.Setenv("ELEVENLABS_API_KEY", "el-key") t.Setenv("SIXTYDB_API_KEY", "sd-key") - _, name, err := selectProvider() - if err != nil { - t.Fatalf("selectProvider error: %v", err) - } - if name != providerElevenLabs { - t.Fatalf("expected ElevenLabs to win tiebreak, got %q", name) + _, err := selectProvider() + if err == nil || !strings.Contains(err.Error(), "ambiguous provider configuration") { + t.Fatalf("expected ambiguity error, got %v", err) } } func TestSelectProvider_NeitherErrors(t *testing.T) { resetProviderEnv(t) - _, _, err := selectProvider() + _, err := selectProvider() if err == nil { t.Fatal("expected error when no API key is set") } diff --git a/cmd/root.go b/cmd/root.go index 28b9663..7d1868c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -19,8 +19,8 @@ var ( versionFlag bool rootCmd = &cobra.Command{ Use: "sag", - Short: "🗣️ ElevenLabs speech, mac-style ease", - Long: "Command-line ElevenLabs TTS with macOS playback. Call it like macOS 'say': if you skip the subcommand, text args are passed to 'speak' (e.g. `sag \"Hello\"`).\n\nTip: run `sag prompting` for model-specific prompting tips.\nModels: `eleven_v3` (default), `eleven_multilingual_v2` (stable), `eleven_flash_v2_5` (fast/cheap), `eleven_turbo_v2_5` (balanced).", + Short: "🗣️ TTS speech, mac-style ease", + Long: "Command-line TTS with macOS-style playback and voice flags. Call it like macOS 'say': if you skip the subcommand, text args are passed to 'speak' (e.g. `sag \"Hello\"`).\n\nTip: run `sag prompting` for ElevenLabs prompting tips. 
Provider selection is automatic: configure exactly one of ElevenLabs or 60db.", Example: " sag \"Hi Peter\"\n echo 'piped input' | sag\n sag speak -v Roger --rate 200 \"Faster speech\"\n sag prompting", Version: "0.3.0", PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { diff --git a/cmd/speak.go b/cmd/speak.go index 3ae1e52..0a68b3f 100644 --- a/cmd/speak.go +++ b/cmd/speak.go @@ -12,15 +12,13 @@ import ( "time" "github.com/steipete/sag/internal/audio" + "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/sixtydb" "github.com/steipete/sag/internal/tts" "github.com/spf13/cobra" ) -// playbackFormat is the format requested when audio must be decoded for -// speaker playback (the oto/afplay path handles MP3). -const playbackFormat = "mp3_44100_128" - type speakOptions struct { voiceID string modelID string @@ -78,9 +76,14 @@ func init() { return err } + provider, err := selectProvider() + if err != nil { + return err + } + forceVoiceID := cmd.Flags().Changed("voice-id") voiceInput := opts.voiceID - if voiceInput == "" { + if provider.name == providerElevenLabs && voiceInput == "" { if env := os.Getenv("ELEVENLABS_VOICE_ID"); env != "" { voiceInput = env forceVoiceID = true @@ -89,12 +92,8 @@ func init() { forceVoiceID = true } } - client, providerName, err := selectProvider() - if err != nil { - return err - } - voiceID, err := resolveVoice(cmd.Context(), client, voiceInput, forceVoiceID) + voiceID, err := resolveVoice(cmd.Context(), provider.voices, voiceInput, forceVoiceID) if err != nil { return err } @@ -120,12 +119,9 @@ func init() { } } - if providerName == providerSixtyDB { - noteUnsupportedSixtyDBFlags(cmd.Flags().Changed) - // Speaker playback needs MP3; 60db's stream picks its own format - // and convert honors this request. 
- if opts.play { - opts.outputFmt = playbackFormat + if provider.name == providerSixtyDB { + if err := prepareSixtyDBOptions(cmd, &opts); err != nil { + return err } } @@ -135,29 +131,60 @@ func init() { } defer cancel() - payload, err := buildTTSRequest(cmd, opts, text, providerName) - if err != nil { - return err + var ( + streamFunc func(context.Context) (io.ReadCloser, error) + convertFunc func(context.Context) ([]byte, error) + ) + switch provider.name { + case providerElevenLabs: + payload, err := buildElevenLabsTTSRequest(cmd, opts, text) + if err != nil { + return err + } + streamFunc = func(ctx context.Context) (io.ReadCloser, error) { + return provider.elevenlabs.StreamTTS(ctx, opts.voiceID, payload, opts.latencyTier) + } + convertFunc = func(ctx context.Context) ([]byte, error) { + return provider.elevenlabs.ConvertTTS(ctx, opts.voiceID, payload) + } + case providerSixtyDB: + payload, err := buildSixtyDBTTSRequest(cmd, opts, text) + if err != nil { + return err + } + streamFunc = func(ctx context.Context) (io.ReadCloser, error) { + return provider.sixtydb.StreamTTS(ctx, payload) + } + convertFunc = func(ctx context.Context) ([]byte, error) { + return provider.sixtydb.ConvertTTS(ctx, payload) + } + default: + return fmt.Errorf("unsupported provider %q", provider.name) } start := time.Now() var bytes int64 if opts.stream { - n, err := streamAndPlay(ctx, client, opts, payload) + n, err := streamAndPlay(ctx, opts, streamFunc) bytes = n if err != nil { return err } } else { - n, err := convertAndPlay(ctx, client, opts, payload) + n, err := convertAndPlay(ctx, opts, convertFunc) bytes = n if err != nil { return err } } if opts.metrics { - fmt.Fprintf(os.Stderr, "metrics: chars=%d bytes=%d model=%s voice=%s stream=%t latencyTier=%d dur=%s\n", - len([]rune(text)), bytes, opts.modelID, opts.voiceID, opts.stream, opts.latencyTier, time.Since(start).Truncate(time.Millisecond)) + if provider.name == providerElevenLabs { + fmt.Fprintf(os.Stderr, "metrics: chars=%d 
bytes=%d provider=%s model=%s voice=%s stream=%t latencyTier=%d dur=%s\n", + len([]rune(text)), bytes, provider.name, opts.modelID, opts.voiceID, opts.stream, opts.latencyTier, time.Since(start).Truncate(time.Millisecond)) + } else { + fmt.Fprintf(os.Stderr, "metrics: chars=%d bytes=%d provider=%s voice=%s stream=%t dur=%s\n", + len([]rune(text)), bytes, provider.name, opts.voiceID, opts.stream, time.Since(start).Truncate(time.Millisecond)) + } } return nil }, @@ -288,19 +315,17 @@ func applyRateAndSpeed(opts *speakOptions) error { return nil } -func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName string) (tts.TTSRequest, error) { +func buildElevenLabsTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (elevenlabs.TTSRequest, error) { flags := cmd.Flags() var stabilityPtr *float64 if flags.Changed("stability") { if opts.stability < 0 || opts.stability > 1 { - return tts.TTSRequest{}, errors.New("stability must be between 0 and 1") + return elevenlabs.TTSRequest{}, errors.New("stability must be between 0 and 1") } // The discrete 0/0.5/1 constraint is specific to ElevenLabs eleven_v3. 
- if providerName == providerElevenLabs && opts.modelID == "eleven_v3" { - if !floatEqualsOneOf(opts.stability, []float64{0, 0.5, 1}) { - return tts.TTSRequest{}, errors.New("for eleven_v3, stability must be one of 0.0, 0.5, 1.0 (Creative/Natural/Robust)") - } + if opts.modelID == "eleven_v3" && !floatEqualsOneOf(opts.stability, []float64{0, 0.5, 1}) { + return elevenlabs.TTSRequest{}, errors.New("for eleven_v3, stability must be one of 0.0, 0.5, 1.0 (Creative/Natural/Robust)") } stabilityPtr = &opts.stability } @@ -308,7 +333,7 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName s var similarityPtr *float64 if flags.Changed("similarity") || flags.Changed("similarity-boost") { if opts.similarity < 0 || opts.similarity > 1 { - return tts.TTSRequest{}, errors.New("similarity must be between 0 and 1") + return elevenlabs.TTSRequest{}, errors.New("similarity must be between 0 and 1") } similarityPtr = &opts.similarity } @@ -316,13 +341,13 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName s var stylePtr *float64 if flags.Changed("style") { if opts.style < 0 || opts.style > 1 { - return tts.TTSRequest{}, errors.New("style must be between 0 and 1") + return elevenlabs.TTSRequest{}, errors.New("style must be between 0 and 1") } stylePtr = &opts.style } if flags.Changed("speaker-boost") && flags.Changed("no-speaker-boost") { - return tts.TTSRequest{}, errors.New("choose only one of --speaker-boost or --no-speaker-boost") + return elevenlabs.TTSRequest{}, errors.New("choose only one of --speaker-boost or --no-speaker-boost") } var speakerBoostPtr *bool if flags.Changed("speaker-boost") { @@ -336,7 +361,7 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName s var seedPtr *uint32 if flags.Changed("seed") { if opts.seed > 4294967295 { - return tts.TTSRequest{}, errors.New("seed must be between 0 and 4294967295") + return elevenlabs.TTSRequest{}, errors.New("seed must be between 0 and 
4294967295") } v := uint32(opts.seed) seedPtr = &v @@ -347,7 +372,7 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName s switch normalize { case "auto", "on", "off": default: - return tts.TTSRequest{}, errors.New("normalize must be one of: auto, on, off") + return elevenlabs.TTSRequest{}, errors.New("normalize must be one of: auto, on, off") } } else { normalize = "" @@ -356,11 +381,11 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName s lang := strings.ToLower(strings.TrimSpace(opts.lang)) if flags.Changed("lang") { if len(lang) != 2 { - return tts.TTSRequest{}, errors.New("lang must be a 2-letter ISO 639-1 code (e.g. en, de, fr)") + return elevenlabs.TTSRequest{}, errors.New("lang must be a 2-letter ISO 639-1 code (e.g. en, de, fr)") } for _, r := range lang { if r < 'a' || r > 'z' { - return tts.TTSRequest{}, errors.New("lang must be a 2-letter ISO 639-1 code (e.g. en, de, fr)") + return elevenlabs.TTSRequest{}, errors.New("lang must be a 2-letter ISO 639-1 code (e.g. 
en, de, fr)") } } } else { @@ -368,14 +393,14 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName s } speed := opts.speed - return tts.TTSRequest{ + return elevenlabs.TTSRequest{ Text: text, ModelID: opts.modelID, OutputFormat: opts.outputFmt, Seed: seedPtr, ApplyTextNormalization: normalize, LanguageCode: lang, - VoiceSettings: &tts.VoiceSettings{ + VoiceSettings: &elevenlabs.VoiceSettings{ Speed: &speed, Stability: stabilityPtr, SimilarityBoost: similarityPtr, @@ -385,6 +410,63 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text, providerName s }, nil } +func prepareSixtyDBOptions(cmd *cobra.Command, opts *speakOptions) error { + unsupported := changedSixtyDBUnsupportedFlags(cmd.Flags().Changed) + if len(unsupported) > 0 { + return fmt.Errorf("60db does not support %s", strings.Join(unsupported, ", ")) + } + + if !cmd.Flags().Changed("format") && strings.EqualFold(filepath.Ext(opts.outputPath), ".flac") { + opts.outputFmt = "flac" + } + + format := sixtydb.CanonicalOutputFormat(opts.outputFmt) + if format != "" { + opts.outputFmt = format + } + + if opts.play && format != "" && format != "mp3" { + return fmt.Errorf("60db speaker playback requires mp3 audio; use --no-play to save %s output", format) + } + + if opts.stream && format != "" && format != "mp3" { + if cmd.Flags().Changed("stream") { + return fmt.Errorf("60db streaming does not support %s output; use --no-stream", format) + } + opts.stream = false + } + return nil +} + +func buildSixtyDBTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (sixtydb.TTSRequest, error) { + flags := cmd.Flags() + speed := opts.speed + req := sixtydb.TTSRequest{ + Text: text, + VoiceID: opts.voiceID, + Speed: &speed, + } + + if flags.Changed("stability") { + if opts.stability < 0 || opts.stability > 1 { + return sixtydb.TTSRequest{}, errors.New("stability must be between 0 and 1") + } + stability := opts.stability * 100 + req.Stability = &stability + } + if 
flags.Changed("similarity") || flags.Changed("similarity-boost") { + if opts.similarity < 0 || opts.similarity > 1 { + return sixtydb.TTSRequest{}, errors.New("similarity must be between 0 and 1") + } + similarity := opts.similarity * 100 + req.Similarity = &similarity + } + if !opts.stream { + req.OutputFormat = opts.outputFmt + } + return req, nil +} + func floatEqualsOneOf(v float64, allowed []float64) bool { const eps = 1e-9 for _, a := range allowed { @@ -444,8 +526,12 @@ func isStdinTTY() bool { return (stat.Mode() & os.ModeCharDevice) != 0 } -func streamAndPlay(ctx context.Context, client tts.Provider, opts speakOptions, payload tts.TTSRequest) (int64, error) { - resp, err := client.StreamTTS(ctx, opts.voiceID, payload, opts.latencyTier) +func streamAndPlay(ctx context.Context, opts speakOptions, stream func(context.Context) (io.ReadCloser, error)) (int64, error) { + if !opts.play && opts.outputPath == "" { + return 0, errors.New("nothing to do: enable --play or provide --output") + } + + resp, err := stream(ctx) if err != nil { return 0, err } @@ -496,17 +582,17 @@ func streamAndPlay(ctx context.Context, client tts.Provider, opts speakOptions, return copyNVal, playErr } - if len(writers) == 0 { - return 0, errors.New("nothing to do: enable --play or provide --output") - } - mw := io.MultiWriter(writers...) 
n, err := io.Copy(mw, resp) return n, err } -func convertAndPlay(ctx context.Context, client tts.Provider, opts speakOptions, payload tts.TTSRequest) (int64, error) { - data, err := client.ConvertTTS(ctx, opts.voiceID, payload) +func convertAndPlay(ctx context.Context, opts speakOptions, convert func(context.Context) ([]byte, error)) (int64, error) { + if !opts.play && opts.outputPath == "" { + return 0, errors.New("nothing to do: enable --play or provide --output") + } + + data, err := convert(ctx) if err != nil { return 0, err } @@ -533,13 +619,10 @@ func convertAndPlay(ctx context.Context, client tts.Provider, opts speakOptions, }() return n, player(ctx, pr) } - if opts.outputPath == "" { - return n, errors.New("nothing to do: enable --play or provide --output") - } return n, nil } -func resolveVoice(ctx context.Context, client tts.Provider, voiceInput string, forceID bool) (string, error) { +func resolveVoice(ctx context.Context, client tts.VoiceCatalog, voiceInput string, forceID bool) (string, error) { voiceInput = strings.TrimSpace(voiceInput) if voiceInput == "" { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) @@ -549,7 +632,7 @@ func resolveVoice(ctx context.Context, client tts.Provider, voiceInput string, f return "", fmt.Errorf("voice not specified and failed to fetch voices: %w", err) } if len(voices) == 0 { - return "", errors.New("no voices available; specify --voice or set ELEVENLABS_VOICE_ID") + return "", errors.New("no voices available; specify --voice") } fmt.Fprintf(os.Stderr, "defaulting to voice %s (%s)\n", voices[0].Name, voices[0].VoiceID) return voices[0].VoiceID, nil diff --git a/cmd/speak_integration_test.go b/cmd/speak_integration_test.go index 1684a45..328657d 100644 --- a/cmd/speak_integration_test.go +++ b/cmd/speak_integration_test.go @@ -1,6 +1,7 @@ package cmd import ( + "encoding/base64" "encoding/json" "net/http" "net/http/httptest" @@ -12,6 +13,8 @@ import ( func TestSpeakCommand_FlagsBuildRequestAndMetrics(t 
*testing.T) { t.Helper() + resetProviderEnv(t) + resetRootCommandState() const voiceID = "abc1234567890123" @@ -106,4 +109,64 @@ func TestSpeakCommand_FlagsBuildRequestAndMetrics(t *testing.T) { if !strings.Contains(stderr, "metrics: chars=") || !strings.Contains(stderr, "bytes=") || !strings.Contains(stderr, "dur=") { t.Fatalf("expected metrics output, got %q", stderr) } + if !strings.Contains(stderr, "provider=elevenlabs") || !strings.Contains(stderr, "model=eleven_v3") || !strings.Contains(stderr, "latencyTier=0") { + t.Fatalf("expected provider-specific metrics output, got %q", stderr) + } +} + +func TestSpeakCommand_SixtyDBMetricsOmitElevenLabsModel(t *testing.T) { + t.Helper() + resetProviderEnv(t) + resetRootCommandState() + t.Setenv("SIXTYDB_API_KEY", "sd-key") + + const voiceID = "voice-001" + audio := base64.StdEncoding.EncodeToString([]byte("ID3audio-bytes")) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/tts-synthesize" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + var got map[string]any + if err := json.NewDecoder(r.Body).Decode(&got); err != nil { + t.Fatalf("decode body: %v", err) + } + if got["voice_id"] != voiceID || got["output_format"] != "mp3" { + t.Fatalf("unexpected request body: %+v", got) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "success": true, + "audio_base64": audio, + "output_format": "mp3", + "encoding": "mp3", + }) + })) + defer srv.Close() + + tmp := t.TempDir() + outPath := tmp + "/out.mp3" + restore, read := captureStderr(t) + defer restore() + + rootCmd.SetArgs([]string{ + "--base-url", srv.URL, + "speak", + "--voice-id", voiceID, + "--stream=false", + "--play=false", + "--output", outPath, + "--metrics", + "Hello world", + }) + + if err := rootCmd.Execute(); err != nil { + t.Fatalf("speak command failed: %v", err) + } + + stderr := read() + if !strings.Contains(stderr, "provider=60db") { + t.Fatalf("expected 60db provider in metrics, got %q", 
stderr) + } + if strings.Contains(stderr, "model=eleven_v3") || strings.Contains(stderr, "latencyTier=") { + t.Fatalf("expected 60db metrics to omit ElevenLabs metadata, got %q", stderr) + } } diff --git a/cmd/speak_request_test.go b/cmd/speak_request_test.go index b818968..802ce1f 100644 --- a/cmd/speak_request_test.go +++ b/cmd/speak_request_test.go @@ -16,8 +16,13 @@ func newSpeakTestCommand(t *testing.T) (*cobra.Command, *speakOptions) { modelID: "eleven_multilingual_v2", outputFmt: "mp3_44100_128", speed: 1.0, + stream: true, + play: true, } cmd := &cobra.Command{Use: "speak"} + cmd.Flags().StringVar(&opts.modelID, "model-id", opts.modelID, "") + cmd.Flags().StringVar(&opts.outputFmt, "format", opts.outputFmt, "") + cmd.Flags().BoolVar(&opts.stream, "stream", opts.stream, "") cmd.Flags().Float64Var(&opts.stability, "stability", 0, "") cmd.Flags().Float64Var(&opts.similarity, "similarity", 0, "") cmd.Flags().Float64Var(&opts.similarity, "similarity-boost", 0, "") @@ -27,15 +32,16 @@ func newSpeakTestCommand(t *testing.T) (*cobra.Command, *speakOptions) { cmd.Flags().Uint64Var(&opts.seed, "seed", 0, "") cmd.Flags().StringVar(&opts.normalize, "normalize", "", "") cmd.Flags().StringVar(&opts.lang, "lang", "", "") + cmd.Flags().IntVar(&opts.latencyTier, "latency-tier", 0, "") return cmd, opts } -func TestBuildTTSRequest_DefaultsOmitOptionalFields(t *testing.T) { +func TestBuildElevenLabsTTSRequest_DefaultsOmitOptionalFields(t *testing.T) { cmd, opts := newSpeakTestCommand(t) - req, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) + req, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err != nil { - t.Fatalf("buildTTSRequest error: %v", err) + t.Fatalf("buildElevenLabsTTSRequest error: %v", err) } if req.Seed != nil { @@ -64,30 +70,30 @@ func TestBuildTTSRequest_DefaultsOmitOptionalFields(t *testing.T) { } } -func TestBuildTTSRequest_SimilarityBoostAlias(t *testing.T) { +func TestBuildElevenLabsTTSRequest_SimilarityBoostAlias(t *testing.T) { 
cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--similarity-boost", "0.9"}); err != nil { t.Fatalf("parse flags: %v", err) } - req, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) + req, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err != nil { - t.Fatalf("buildTTSRequest error: %v", err) + t.Fatalf("buildElevenLabsTTSRequest error: %v", err) } if req.VoiceSettings.SimilarityBoost == nil || *req.VoiceSettings.SimilarityBoost != 0.9 { t.Fatalf("expected similarity_boost 0.9, got %#v", req.VoiceSettings.SimilarityBoost) } } -func TestBuildTTSRequest_SpeakerBoostSetsJSONKey(t *testing.T) { +func TestBuildElevenLabsTTSRequest_SpeakerBoostSetsJSONKey(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--speaker-boost"}); err != nil { t.Fatalf("parse flags: %v", err) } - req, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) + req, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err != nil { - t.Fatalf("buildTTSRequest error: %v", err) + t.Fatalf("buildElevenLabsTTSRequest error: %v", err) } if req.VoiceSettings.UseSpeakerBoost == nil || *req.VoiceSettings.UseSpeakerBoost != true { t.Fatalf("expected use_speaker_boost true, got %#v", req.VoiceSettings.UseSpeakerBoost) @@ -102,62 +108,174 @@ func TestBuildTTSRequest_SpeakerBoostSetsJSONKey(t *testing.T) { } } -func TestBuildTTSRequest_InvalidNormalize(t *testing.T) { +func TestBuildElevenLabsTTSRequest_InvalidNormalize(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--normalize", "wat"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "normalize must be one of") { t.Fatalf("expected normalize error, got %v", err) } } -func TestBuildTTSRequest_InvalidLang(t *testing.T) { +func 
TestBuildElevenLabsTTSRequest_InvalidLang(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--lang", "eng"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "lang must be a 2-letter") { t.Fatalf("expected lang error, got %v", err) } } -func TestBuildTTSRequest_InvalidSeed(t *testing.T) { +func TestBuildElevenLabsTTSRequest_InvalidSeed(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--seed", "4294967296"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "seed must be between") { t.Fatalf("expected seed error, got %v", err) } } -func TestBuildTTSRequest_SpeakerBoostConflict(t *testing.T) { +func TestBuildElevenLabsTTSRequest_SpeakerBoostConflict(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--speaker-boost", "--no-speaker-boost"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "choose only one") { t.Fatalf("expected conflict error, got %v", err) } } -func TestBuildTTSRequest_V3StabilityPresetsOnly(t *testing.T) { +func TestBuildElevenLabsTTSRequest_V3StabilityPresetsOnly(t *testing.T) { cmd, opts := newSpeakTestCommand(t) opts.modelID = "eleven_v3" if err := cmd.Flags().Parse([]string{"--stability", "0.55"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello", providerElevenLabs) + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || 
!strings.Contains(err.Error(), "for eleven_v3, stability must be one of") { t.Fatalf("expected v3 stability preset error, got %v", err) } } +func TestPrepareSixtyDBOptionsRejectsUnsupportedFlags(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + if err := cmd.Flags().Parse([]string{"--model-id", "foo", "--style", "0.3", "--latency-tier", "2"}); err != nil { + t.Fatalf("parse flags: %v", err) + } + + err := prepareSixtyDBOptions(cmd, opts) + if err == nil { + t.Fatal("expected unsupported flag error") + } + for _, flag := range []string{"--model-id", "--style", "--latency-tier"} { + if !strings.Contains(err.Error(), flag) { + t.Fatalf("expected %s in error, got %v", flag, err) + } + } +} + +func TestPrepareSixtyDBOptionsFallsBackToConvertForNonMP3(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + opts.outputFmt = "wav" + opts.play = false + + if err := prepareSixtyDBOptions(cmd, opts); err != nil { + t.Fatalf("prepareSixtyDBOptions error: %v", err) + } + if opts.outputFmt != "wav" { + t.Fatalf("expected canonical wav format, got %q", opts.outputFmt) + } + if opts.stream { + t.Fatalf("expected stream to be disabled for non-mp3 output") + } +} + +func TestPrepareSixtyDBOptionsInfersFLACFromOutputPath(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + opts.outputPath = "voice.FLAC" + opts.play = false + + if err := prepareSixtyDBOptions(cmd, opts); err != nil { + t.Fatalf("prepareSixtyDBOptions error: %v", err) + } + if opts.outputFmt != "flac" { + t.Fatalf("expected flac output format, got %q", opts.outputFmt) + } + if opts.stream { + t.Fatalf("expected stream to be disabled for flac output") + } +} + +func TestPrepareSixtyDBOptionsRejectsExplicitStreamForNonMP3(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + if err := cmd.Flags().Parse([]string{"--stream", "--format", "wav"}); err != nil { + t.Fatalf("parse flags: %v", err) + } + opts.play = false + + err := prepareSixtyDBOptions(cmd, opts) + if err == nil || !strings.Contains(err.Error(), 
"use --no-stream") { + t.Fatalf("expected explicit stream rejection, got %v", err) + } +} + +func TestPrepareSixtyDBOptionsRejectsPlaybackForNonMP3(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + opts.outputFmt = "flac" + + err := prepareSixtyDBOptions(cmd, opts) + if err == nil || !strings.Contains(err.Error(), "requires mp3 audio") { + t.Fatalf("expected playback format error, got %v", err) + } +} + +func TestBuildSixtyDBTTSRequestScalesDocumentedFields(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + opts.voiceID = "voice_123" + opts.stream = false + if err := cmd.Flags().Parse([]string{"--stability", "0.5", "--similarity-boost", "0.8"}); err != nil { + t.Fatalf("parse flags: %v", err) + } + + req, err := buildSixtyDBTTSRequest(cmd, *opts, "hello") + if err != nil { + t.Fatalf("buildSixtyDBTTSRequest error: %v", err) + } + if req.Text != "hello" || req.VoiceID != "voice_123" { + t.Fatalf("unexpected request identity: %+v", req) + } + if req.Speed == nil || *req.Speed != 1.0 { + t.Fatalf("expected speed 1.0, got %#v", req.Speed) + } + if req.Stability == nil || *req.Stability != 50 { + t.Fatalf("expected stability 50, got %#v", req.Stability) + } + if req.Similarity == nil || *req.Similarity != 80 { + t.Fatalf("expected similarity 80, got %#v", req.Similarity) + } + if req.OutputFormat != "mp3_44100_128" { + t.Fatalf("expected raw output format to be passed to client canonicalizer, got %q", req.OutputFormat) + } +} + +func TestBuildSixtyDBTTSRequestOmitsOutputFormatWhenStreaming(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + req, err := buildSixtyDBTTSRequest(cmd, *opts, "hello") + if err != nil { + t.Fatalf("buildSixtyDBTTSRequest error: %v", err) + } + if req.OutputFormat != "" { + t.Fatalf("expected stream request to omit output format, got %q", req.OutputFormat) + } +} + func TestApplyCompatibilityFlagsNoPlayNoStream(t *testing.T) { opts := &speakOptions{play: true, stream: true} cmd := &cobra.Command{Use: "speak"} @@ -258,22 +376,3 @@ 
func TestTTSContextNoDeadlineByDefault(t *testing.T) { t.Fatalf("expected no deadline for zero timeout") } } - -func TestTTSContextWithTimeout(t *testing.T) { - ctx, cancel, err := ttsContext(context.Background(), time.Minute) - if err != nil { - t.Fatalf("ttsContext error: %v", err) - } - defer cancel() - - if _, ok := ctx.Deadline(); !ok { - t.Fatalf("expected deadline for non-zero timeout") - } -} - -func TestTTSContextRejectsNegativeTimeout(t *testing.T) { - _, _, err := ttsContext(context.Background(), -time.Second) - if err == nil || !strings.Contains(err.Error(), "timeout must be") { - t.Fatalf("expected timeout error, got %v", err) - } -} diff --git a/cmd/speak_test.go b/cmd/speak_test.go index 1a494b0..e68da27 100644 --- a/cmd/speak_test.go +++ b/cmd/speak_test.go @@ -316,9 +316,10 @@ func TestStreamAndPlayWritesOutput(t *testing.T) { tmp := t.TempDir() out := tmp + "/out.mp3" opts := speakOptions{voiceID: "v1", outputPath: out, stream: true, play: false} - payload := elevenlabs.TTSRequest{Text: "hi"} - if _, err := streamAndPlay(context.Background(), client, opts, payload); err != nil { + if _, err := streamAndPlay(context.Background(), opts, func(ctx context.Context) (io.ReadCloser, error) { + return client.StreamTTS(ctx, opts.voiceID, elevenlabs.TTSRequest{Text: "hi"}, 0) + }); err != nil { t.Fatalf("streamAndPlay error: %v", err) } data, err := os.ReadFile(out) @@ -343,9 +344,10 @@ func TestConvertAndPlayWritesOutput(t *testing.T) { tmp := t.TempDir() out := tmp + "/out.mp3" opts := speakOptions{voiceID: "v1", outputPath: out, play: false} - payload := elevenlabs.TTSRequest{Text: "hi"} - if _, err := convertAndPlay(context.Background(), client, opts, payload); err != nil { + if _, err := convertAndPlay(context.Background(), opts, func(ctx context.Context) ([]byte, error) { + return client.ConvertTTS(ctx, opts.voiceID, elevenlabs.TTSRequest{Text: "hi"}) + }); err != nil { t.Fatalf("convertAndPlay error: %v", err) } data, err := os.ReadFile(out) @@ 
-358,11 +360,11 @@ func TestConvertAndPlayWritesOutput(t *testing.T) { } func TestStreamAndPlayRequiresWork(t *testing.T) { - client := elevenlabs.NewClient("key", "http://invalid") opts := speakOptions{voiceID: "v1", play: false, stream: true} - payload := elevenlabs.TTSRequest{Text: "hi"} - - _, err := streamAndPlay(context.Background(), client, opts, payload) + _, err := streamAndPlay(context.Background(), opts, func(context.Context) (io.ReadCloser, error) { + t.Fatal("stream should not be invoked when nothing will consume it") + return nil, nil + }) if err == nil { t.Fatalf("expected error when no output and play disabled") } @@ -385,9 +387,10 @@ func TestStreamAndPlayWithPlayback(t *testing.T) { client := elevenlabs.NewClient("key", srv.URL) opts := speakOptions{voiceID: "v1", play: true, stream: true} - payload := elevenlabs.TTSRequest{Text: "hi"} - if _, err := streamAndPlay(context.Background(), client, opts, payload); err != nil { + if _, err := streamAndPlay(context.Background(), opts, func(ctx context.Context) (io.ReadCloser, error) { + return client.StreamTTS(ctx, opts.voiceID, elevenlabs.TTSRequest{Text: "hi"}, 0) + }); err != nil { t.Fatalf("streamAndPlay error: %v", err) } if !called { @@ -412,9 +415,10 @@ func TestConvertAndPlayWithPlayback(t *testing.T) { client := elevenlabs.NewClient("key", srv.URL) opts := speakOptions{voiceID: "v1", play: true, outputPath: "", stream: false} - payload := elevenlabs.TTSRequest{Text: "hi"} - if _, err := convertAndPlay(context.Background(), client, opts, payload); err != nil { + if _, err := convertAndPlay(context.Background(), opts, func(ctx context.Context) ([]byte, error) { + return client.ConvertTTS(ctx, opts.voiceID, elevenlabs.TTSRequest{Text: "hi"}) + }); err != nil { t.Fatalf("convertAndPlay error: %v", err) } if !called { diff --git a/cmd/test_helpers_test.go b/cmd/test_helpers_test.go new file mode 100644 index 0000000..5123dfa --- /dev/null +++ b/cmd/test_helpers_test.go @@ -0,0 +1,33 @@ +package cmd + 
+import ( + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +func resetRootCommandState() { + resetFlags := func(flags *pflag.FlagSet) { + flags.VisitAll(func(flag *pflag.Flag) { + if replacer, ok := flag.Value.(interface{ Replace([]string) error }); ok { + _ = replacer.Replace([]string{}) + } else { + _ = flags.Set(flag.Name, flag.DefValue) + } + flag.Changed = false + }) + } + + var reset func(*cobra.Command) + reset = func(cmd *cobra.Command) { + cmd.SetArgs(nil) + resetFlags(cmd.Flags()) + resetFlags(cmd.PersistentFlags()) + for _, sub := range cmd.Commands() { + reset(sub) + } + } + + cfg = rootConfig{} + versionFlag = false + reset(rootCmd) +} diff --git a/cmd/voices.go b/cmd/voices.go index 7dd9889..d4f98fd 100644 --- a/cmd/voices.go +++ b/cmd/voices.go @@ -53,10 +53,11 @@ func init() { return errors.New("--try requires --search, --query, --label, or --limit to avoid playing all voices") } - client, _, err := selectProvider() + provider, err := selectProvider() if err != nil { return err } + client := provider.voices ctx, cancel := context.WithTimeout(cmd.Context(), 30*time.Second) defer cancel() @@ -182,7 +183,7 @@ func filterVoicesByName(voices []tts.Voice, search string) []tts.Voice { return filtered } -func playVoicePreviewImpl(ctx context.Context, client tts.Provider, voice tts.Voice) error { +func playVoicePreviewImpl(ctx context.Context, client tts.VoiceCatalog, voice tts.Voice) error { ctx, cancel := context.WithTimeout(ctx, 45*time.Second) defer cancel() diff --git a/cmd/voices_cache.go b/cmd/voices_cache.go index 2e61d11..738d2c2 100644 --- a/cmd/voices_cache.go +++ b/cmd/voices_cache.go @@ -80,7 +80,7 @@ func saveVoiceCache(path string, cache *voiceCache) error { return os.WriteFile(path, data, 0o644) } -func hydrateVoices(ctx context.Context, client tts.Provider, voices []tts.Voice, cache *voiceCache, ttl time.Duration) ([]tts.Voice, int) { +func hydrateVoices(ctx context.Context, client tts.VoiceCatalog, voices []tts.Voice, cache 
*voiceCache, ttl time.Duration) ([]tts.Voice, int) { if ttl <= 0 { ttl = voiceCacheTTL } @@ -150,7 +150,12 @@ func mergeVoice(base tts.Voice, details tts.Voice) tts.Voice { merged.Description = details.Description } if len(details.Labels) > 0 { - merged.Labels = details.Labels + if len(merged.Labels) == 0 { + merged.Labels = map[string]string{} + } + for key, value := range details.Labels { + merged.Labels[key] = value + } } if details.PreviewURL != "" { merged.PreviewURL = details.PreviewURL diff --git a/cmd/voices_cache_test.go b/cmd/voices_cache_test.go new file mode 100644 index 0000000..070e2a9 --- /dev/null +++ b/cmd/voices_cache_test.go @@ -0,0 +1,36 @@ +package cmd + +import ( + "testing" + + "github.com/steipete/sag/internal/tts" +) + +func TestMergeVoicePreservesAndOverlaysLabels(t *testing.T) { + base := tts.Voice{ + VoiceID: "v1", + Labels: map[string]string{ + "source": "default", + "model": "60db Quality", + }, + } + details := tts.Voice{ + VoiceID: "v1", + Labels: map[string]string{ + "accent": "American", + "model": "override", + }, + PreviewURL: "https://cdn.example.com/sample.mp3", + } + + merged := mergeVoice(base, details) + if merged.Labels["source"] != "default" { + t.Fatalf("expected source label preserved, got %+v", merged.Labels) + } + if merged.Labels["model"] != "override" || merged.Labels["accent"] != "American" { + t.Fatalf("expected detail labels merged, got %+v", merged.Labels) + } + if merged.PreviewURL != details.PreviewURL { + t.Fatalf("expected preview URL copied, got %q", merged.PreviewURL) + } +} diff --git a/cmd/voices_test.go b/cmd/voices_test.go index ead0001..0851158 100644 --- a/cmd/voices_test.go +++ b/cmd/voices_test.go @@ -13,6 +13,8 @@ import ( ) func TestVoicesCommand(t *testing.T) { + resetRootCommandState() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/voices" { t.Fatalf("unexpected path: %s", r.URL.Path) @@ -63,6 +65,8 @@ func 
TestFilterVoicesByName(t *testing.T) { } func TestVoicesCommandTryRequiresFilter(t *testing.T) { + resetRootCommandState() + cfg.APIKey = "key" cfg.BaseURL = "http://example.invalid" t.Cleanup(func() { diff --git a/docs/configuration.md b/docs/configuration.md index 0d088a8..2f09c58 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,24 +1,28 @@ --- title: Configuration -description: "API keys, default voices, timeouts, base URL, and player selection — everything sag reads from the environment." +description: "Provider keys, default voices, timeouts, base URL, and player selection — everything sag reads from flags or the environment." --- # Configuration `sag` reads configuration from CLI flags first, then environment variables. There is no config file: the binary stays single-purpose and friendly to ephemeral CI runners. -## API key +## Provider key Required for any TTS or voice call. `sag --help`, `sag prompting`, and `sag --version` work without one. -| Flag / variable | Notes | +| Provider | Flag / variable | Notes | | --- | --- | -| `--api-key` | Inline override. Avoid in shell history; prefer env or `--api-key-file`. | -| `ELEVENLABS_API_KEY` | Primary env var. | -| `SAG_API_KEY` | Accepted alias. | -| `--api-key-file ` | Read the key from a file. | -| `ELEVENLABS_API_KEY_FILE` | Same as `--api-key-file`. | -| `SAG_API_KEY_FILE` | Alias. | +| ElevenLabs | `--api-key` | Inline override. Avoid in shell history; prefer env or `--api-key-file`. | +| ElevenLabs | `ELEVENLABS_API_KEY` | Primary env var. | +| ElevenLabs | `SAG_API_KEY` | Accepted alias. | +| ElevenLabs | `--api-key-file ` | Read the key from a file. | +| ElevenLabs | `ELEVENLABS_API_KEY_FILE` | Same as `--api-key-file`. | +| ElevenLabs | `SAG_API_KEY_FILE` | Alias. | +| 60db | `SIXTYDB_API_KEY` | Primary env var. | +| 60db | `SIXTYDB_API_KEY_FILE` | Read the key from a file. | + +Configure exactly one provider at a time. 
If both ElevenLabs and 60db credentials are present, `sag` errors instead of guessing. The file form is handy for agents and containers: @@ -34,7 +38,9 @@ When `--voice` / `--voice-id` is omitted, `sag` resolves in this order: 1. `ELEVENLABS_VOICE_ID` 2. `SAG_VOICE_ID` -3. The first voice returned by `/v1/voices` (logged on stderr so you notice). +3. The first voice returned by the active provider's voice listing (logged on stderr so you notice). + +The env defaults apply only to ElevenLabs. With 60db, `sag` falls back to the first merged result from `/default-voices` and `/myvoices`. ```bash export SAG_VOICE_ID=21m00Tcm4TlvDq8ikWAM @@ -77,18 +83,18 @@ Pick a backend explicitly via `--player oto` or `SAG_PLAYER=oto`. See [Streaming ## API base URL -Override the ElevenLabs endpoint when you’re routing through a proxy or talking to a regional/staging deployment: +Override the active provider endpoint when you’re routing through a proxy or talking to a regional/staging deployment: ```bash sag --base-url https://api.elevenlabs.io "Default." sag --base-url https://your-proxy.internal "Routed." ``` -The default is `https://api.elevenlabs.io`. There is no env var for this; it’s deliberate so the API target is always visible in the command line. +The default is `https://api.elevenlabs.io` for ElevenLabs and `https://api.60db.ai` for 60db. There is no env var for this; it’s deliberate so the API target is always visible in the command line. ## Voice metadata cache -`sag voices --query` and `--label` need full voice descriptors. Metadata is cached in your platform-default config directory (`$XDG_CONFIG_HOME/sag/voices.json` on Linux, `~/Library/Application Support/sag/voices.json` on macOS) for 24 hours. Delete the file or pass `--limit 0` after a voice update to force a refresh. +`sag voices --query` and `--label` need full voice descriptors. Metadata is cached in your platform-default cache directory for 24 hours. Delete the file if you need an immediate refresh. 
## Compatibility flags (no-ops) diff --git a/docs/index.md b/docs/index.md index 763e01c..4c11c0f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,12 +1,12 @@ --- title: Overview permalink: / -description: "sag is a Go CLI that turns text into ElevenLabs speech the way macOS `say` does. Stream to speakers, save to files, swap voices, choose models — one binary, terminal-first." +description: "sag is a Go CLI that turns text into macOS `say`-style speech with ElevenLabs or 60db. Stream to speakers, save to files, swap voices, choose models — one binary, terminal-first." --- ## Try it -After [installing](install.md) and exporting `ELEVENLABS_API_KEY`, every TTS call is a one-liner. +After [installing](install.md) and exporting exactly one provider key, every TTS call is a one-liner. ```bash # Speak through the speakers, macOS-say style. @@ -31,8 +31,9 @@ sag voices --search english --limit 5 --try ## What sag does - **Drop-in `say` replacement.** Same flag shapes (`-v`, `-r`, `-f`, `-o`), same default of streaming to the speakers. Compatibility no-ops (`--progress`, `--audio-device`, `--quality`, …) keep existing scripts working. -- **Stream-while-you-generate.** Audio plays as bytes arrive over `/v1/text-to-speech/{voice}/stream`. Latency tiers let you trade quality for first-byte time. +- **Stream-while-you-generate.** Audio plays as bytes arrive over ElevenLabs `/v1/text-to-speech/{voice}/stream` or 60db `/tts-stream`, with automatic fallback to full synthesis when the 60db route cannot represent the requested file format. - **Voice discovery.** Server-side name search, semantic `--query` over name/description/labels, repeatable `--label key=value` filters, plus `--try` to play preview clips for the matches. +- **Optional 60db backend.** `sag` merges 60db `/default-voices` and `/myvoices`, validates JSON success envelopes even on HTTP 200, and decodes the documented base64/NDJSON audio responses internally. 
- **Every ElevenLabs model.** Defaults to `eleven_v3` for expressive output; switch to `eleven_multilingual_v2`, `eleven_flash_v2_5`, or `eleven_turbo_v2_5` with `--model-id`. Stability, similarity, style, speaker-boost, seed, normalization, and language are all flag-controlled. - **Format inference.** `.mp3` → `mp3_44100_128`, `.wav` → `pcm_44100`, `.ogg`/`.opus` → `opus_48000_64`. Override with `--format` when you need something else. - **Cross-platform playback.** macOS uses `afplay` for AirPlay-friendly routing; Linux and Windows fall back to a `go-mp3` + `oto` decoder. Pick explicitly with `--player auto|afplay|oto`. @@ -49,4 +50,4 @@ sag voices --search english --limit 5 --try ## Project -`sag` is MIT-licensed and not affiliated with ElevenLabs or Apple. The [changelog](https://github.com/steipete/sag/blob/main/CHANGELOG.md) tracks releases; the [spec](spec.md) records goals and non-goals. Source: <https://github.com/steipete/sag>. +`sag` is MIT-licensed and not affiliated with ElevenLabs, 60db, or Apple. The [changelog](https://github.com/steipete/sag/blob/main/CHANGELOG.md) tracks releases; the [spec](spec.md) records goals and non-goals. Source: <https://github.com/steipete/sag>. diff --git a/docs/providers.md b/docs/providers.md index 2ddae7a..f89ad2c 100644 --- a/docs/providers.md +++ b/docs/providers.md @@ -1,67 +1,74 @@ # Providers -`sag` speaks to two text-to-speech backends behind one consistent CLI. A small -provider abstraction (`internal/tts.Provider`) lets the command layer and the -audio player stay backend-agnostic; each provider translates the shared request -to and from its own wire format. +`sag` supports two HTTP TTS backends. The CLI auto-selects the provider from your configured credentials; there is no `--provider` flag. ## Selecting a provider -The provider is auto-detected from whichever API key is present — there is no -`--provider` flag. 
- | Keys set | Active provider | -|---|---| -| `ELEVENLABS_API_KEY` (or `--api-key`/file) only | ElevenLabs | -| `SIXTYDB_API_KEY` (or `SIXTYDB_API_KEY_FILE`) only | 60db | -| both | ElevenLabs (note printed; unset `ELEVENLABS_API_KEY` to use 60db) | +| --- | --- | +| `ELEVENLABS_API_KEY` / `SAG_API_KEY` or `--api-key` / `--api-key-file` only | ElevenLabs | +| `SIXTYDB_API_KEY` or `SIXTYDB_API_KEY_FILE` only | 60db | +| both | error | | neither | error | -Override the host for the active provider with `--base-url`. +Use `--base-url` to override the active provider host. + +## 60db routes sag uses -## What each provider implements +The 60db integration is deliberately limited to the documented REST contract: -| Capability | ElevenLabs | 60db | -|---|---|---| -| Auth | `xi-api-key: ` | `Authorization: Bearer ` | -| Default host | `https://api.elevenlabs.io` | `https://api.60db.ai` | -| List voices | `GET /v1/voices` | `GET /myvoices` | -| Full synthesis | `POST /v1/text-to-speech/{id}` (raw audio) | `POST /tts-synthesize` (base64 in JSON) | -| Streaming | `POST /v1/text-to-speech/{id}/stream` (raw audio) | `POST /tts-stream` (NDJSON, base64 chunks) | +| Capability | Route | Notes | +| --- | --- | --- | +| Voice listing | `GET /default-voices` | Workspace-default voices; queried first. | +| Voice listing | `GET /myvoices` | User-created voices; appended after defaults and deduped by `voice_id`. | +| Streaming speak | `POST /tts-stream` | NDJSON stream of base64 chunks; no `output_format` field in the request. | +| Full speak | `POST /tts-synthesize` | JSON envelope with `audio_base64`, `encoding`, and `output_format`. | -For 60db, `sag` decodes the base64/NDJSON envelope internally, so streaming and -file output behave the same as ElevenLabs. 60db's WebSocket API is not used. 
+For 60db, `sag` validates the JSON `success` envelope even on HTTP 200 responses, decodes base64/NDJSON audio internally, rejects malformed or empty audio, and enforces per-chunk, per-frame, and total-audio limits. ## Flag behavior -Flags are written in ElevenLabs terms and translated per provider so the same -command works on both. - -| Flag | ElevenLabs | 60db | -|---|---|---| -| `--speed` / `--rate` | speed multiplier `0.5–2.0` | passthrough (same range) | -| `--stability` | `0..1` | scaled to `0..100` | -| `--similarity` / `--similarity-boost` | `0..1` | scaled to `0..100` | -| `--format` / `-o` extension | `mp3_44100_128`, `pcm_44100`, … | mapped to `mp3` / `wav` / `ogg` / `flac` | -| `--model-id` | model id (e.g. `eleven_v3`) | ignored (model is tied to the voice) | -| `--style` | style exaggeration | ignored | -| `--speaker-boost` / `--no-speaker-boost` | speaker boost | ignored | -| `--seed` | best-effort determinism | ignored | -| `--normalize` | text normalization | ignored | -| `--lang` | language code | ignored | -| `--latency-tier` | streaming latency tier | ignored | - -When you set an ignored flag while 60db is active, `sag` prints a one-line note -to stderr rather than failing. - -## Notes and limits - -- **Playback format:** speaker playback decodes MP3, so when playing through - speakers `sag` requests MP3 from 60db regardless of `--format`. Use - `--no-play -o file.` to save other formats. -- **Voice previews:** `sag voices --try` is not available for 60db (its - `/myvoices` response has no preview URL); the affected voice is skipped with a - message. -- **Voice metadata:** 60db's per-voice `model` (`60db Fast` / `60db Quality`) is - exposed as a `model` label, so `--label model=...` and `--query` can match it. -- **Text limit:** 60db caps synthesis at 5,000 characters per request. 
+### Supported on both providers + +- `-v, --voice` +- `-r, --rate` +- `--speed` +- `--stability` +- `--similarity` / `--similarity-boost` +- `--format` +- `--stream` / `--no-stream` +- `--play` / `--no-play` +- `-o, --output` +- `--timeout` +- `--metrics` + +### ElevenLabs-only + +These flags are passed through to ElevenLabs and rejected on 60db: + +- `--model-id` +- `--style` +- `--speaker-boost` +- `--no-speaker-boost` +- `--seed` +- `--normalize` +- `--lang` +- `--latency-tier` + +### Parameter translation + +- `--stability` and `--similarity` stay in the CLI's `0..1` range and are scaled to 60db's documented `0..100` request values. +- `--format` is canonicalized for 60db full synthesis: `mp3_*` → `mp3`, `pcm_*` / `wav` → `wav`, `opus_*` / `ogg` → `ogg`, `flac` → `flac`. + +## Streaming, files, and playback + +- ElevenLabs can stream in the requested output format, so `--stream` and `--format` work together. +- 60db streaming is used only for the default MP3 path because `/tts-stream` does not document `output_format`. +- On 60db, if the effective output format is non-MP3 and streaming was only enabled by the default, `sag` automatically falls back to `POST /tts-synthesize`. +- On 60db, explicitly passing `--stream` with a non-MP3 format is an error. +- On 60db, `--play` requires MP3 output. Use `--no-play -o out.wav` (or `ogg` / `flac`) for other formats. + +## Voice metadata notes + +- `sag voices --try` uses `GET /voices/:id` on 60db to fetch `sample_url` when the list responses do not include preview URLs. +- 60db voice `model` and `categories` values are exposed as CLI labels so `--query` and `--label model=...` still work. diff --git a/docs/spec.md b/docs/spec.md index e844162..5f6f08e 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -1,11 +1,11 @@ # sag specification -CLI that mirrors macOS `say` but uses ElevenLabs for synthesis. Defaults to streaming directly to speakers and can also write audio files. 
+CLI that mirrors macOS `say` but uses ElevenLabs or 60db for synthesis. Defaults to streaming directly to speakers and can also write audio files. ## Runtime & deps - Go 1.24+ - Playback uses built-in Go audio (go-mp3 + oto) and should work on macOS/Linux/Windows with a default output device. -- Auth via `ELEVENLABS_API_KEY` (or `--api-key` flag). +- Auth via exactly one configured provider: ElevenLabs (`ELEVENLABS_API_KEY`, `SAG_API_KEY`, `--api-key`) or 60db (`SIXTYDB_API_KEY`). ## Commands @@ -16,7 +16,7 @@ CLI that mirrors macOS `say` but uses ElevenLabs for synthesis. Defaults to stre - `-r/--rate` words-per-minute (default 175) maps to ElevenLabs speed. - `-o/--output` same meaning; format inferred by extension when possible. - Accepts but ignores `--progress`, `--audio-device`, `--network-send`, `--interactive`, `--file-format`, `--data-format`, `--channels`, `--bit-rate`, `--quality`. -- Required: voice (via `-v/--voice` or `ELEVENLABS_VOICE_ID`/`SAG_VOICE_ID`). +- Required: voice (via `-v/--voice` or the active provider default resolution). - Flags: - `--model-id` (default `eleven_v3`; common: `eleven_multilingual_v2`, `eleven_flash_v2_5`, `eleven_turbo_v2_5`) - `--format` (default `mp3_44100_128`; `.wav` infers `pcm_44100`) @@ -34,8 +34,10 @@ CLI that mirrors macOS `say` but uses ElevenLabs for synthesis. Defaults to stre - `--metrics` print basic stats to stderr - `--output ` save audio while optionally playing - Behavior: - - Streaming path calls `POST /v1/text-to-speech/{voice_id}/stream` with JSON body. - - Non-streaming path calls `POST /v1/text-to-speech/{voice_id}` and then plays/saves. + - ElevenLabs streaming path calls `POST /v1/text-to-speech/{voice_id}/stream` with JSON body. + - 60db streaming path calls `POST /tts-stream` and decodes the documented NDJSON chunk stream. + - ElevenLabs non-streaming path calls `POST /v1/text-to-speech/{voice_id}` and then plays/saves. 
+ - 60db non-streaming path calls `POST /tts-synthesize` and decodes `audio_base64`. - Errors if neither playback nor output is selected. Usage examples: @@ -49,7 +51,7 @@ sag speak -v "Roger" -r 200 "mac say style flags" ``` ### `sag voices` -- Lists voices via `GET /v1/voices` (server-side search when supported). +- Lists voices via the active provider. ElevenLabs uses `GET /v1/voices`; 60db merges `GET /default-voices` and `GET /myvoices`. - Flags: - `--search `: search by name (server-side when available) - `--query `: semantic query across name/description/labels (client-side) @@ -67,9 +69,11 @@ sag voices --search "english" - Does not require an API key. ## Config sources -- `ELEVENLABS_API_KEY` for auth (required). -- Default voice env: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID`. -- `--base-url` flag for alternate API host (defaults to `https://api.elevenlabs.io`). +- Exactly one provider key is required. +- ElevenLabs auth: `ELEVENLABS_API_KEY`, `SAG_API_KEY`, `--api-key`, `--api-key-file`. +- 60db auth: `SIXTYDB_API_KEY`, `SIXTYDB_API_KEY_FILE`. +- ElevenLabs default voice env: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID`. +- `--base-url` flag for an alternate provider API host. ## Notes & future polish - Add cross-platform playback backends. diff --git a/internal/elevenlabs/client.go b/internal/elevenlabs/client.go index e584d56..b7ccdb0 100644 --- a/internal/elevenlabs/client.go +++ b/internal/elevenlabs/client.go @@ -21,9 +21,6 @@ type Client struct { httpClient *http.Client } -// Ensure the ElevenLabs client satisfies the shared provider contract. -var _ tts.Provider = (*Client)(nil) - // NewClient returns a Client configured with the given API key and base URL. func NewClient(apiKey, baseURL string) *Client { if baseURL == "" { @@ -36,16 +33,11 @@ func NewClient(apiKey, baseURL string) *Client { } } -// Voice, VoiceSettings, and TTSRequest are re-exported from the shared tts -// package so existing callers (and tests) can keep using elevenlabs.Voice etc. 
-// while every provider speaks the same types. +// Voice is re-exported from the shared voice-catalog package so existing +// callers can keep using elevenlabs.Voice while query/filter code stays shared. type ( // Voice represents a voice entry returned by ElevenLabs. Voice = tts.Voice - // VoiceSettings tunes synthesis parameters for a request. - VoiceSettings = tts.VoiceSettings - // TTSRequest configures a text-to-speech request payload. - TTSRequest = tts.TTSRequest ) type listVoicesResponse struct { @@ -190,6 +182,26 @@ func (c *Client) GetVoice(ctx context.Context, voiceID string) (Voice, error) { return voice, nil } +// TTSRequest configures a text-to-speech request payload. +type TTSRequest struct { + Text string `json:"text"` + ModelID string `json:"model_id,omitempty"` + VoiceSettings *VoiceSettings `json:"voice_settings,omitempty"` + OutputFormat string `json:"output_format,omitempty"` + Seed *uint32 `json:"seed,omitempty"` + ApplyTextNormalization string `json:"apply_text_normalization,omitempty"` + LanguageCode string `json:"language_code,omitempty"` +} + +// VoiceSettings tunes synthesis parameters for a request. +type VoiceSettings struct { + Stability *float64 `json:"stability,omitempty"` + SimilarityBoost *float64 `json:"similarity_boost,omitempty"` + Style *float64 `json:"style,omitempty"` + UseSpeakerBoost *bool `json:"use_speaker_boost,omitempty"` + Speed *float64 `json:"speed,omitempty"` +} + // StreamTTS requests streaming audio for text-to-speech. func (c *Client) StreamTTS(ctx context.Context, voiceID string, payload TTSRequest, latency int) (io.ReadCloser, error) { u, err := url.Parse(c.baseURL) diff --git a/internal/sixtydb/client.go b/internal/sixtydb/client.go index 74b1a80..ad30b41 100644 --- a/internal/sixtydb/client.go +++ b/internal/sixtydb/client.go @@ -1,10 +1,5 @@ -// Package sixtydb provides a small client for the 60db (api.60db.ai) TTS API. 
-// -// Unlike ElevenLabs, 60db never returns raw audio: the synthesize endpoint -// wraps it as base64 inside a JSON object, and the stream endpoint emits -// newline-delimited JSON frames whose audio is base64 too. This client decodes -// both so StreamTTS/ConvertTTS hand back plain audio bytes, exactly like the -// ElevenLabs client — keeping the command and audio layers provider-agnostic. +// Package sixtydb provides a strict adapter for the documented 60db HTTP TTS +// endpoints. package sixtydb import ( @@ -13,6 +8,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -26,6 +22,21 @@ import ( // DefaultBaseURL is the public 60db API host. const DefaultBaseURL = "https://api.60db.ai" +const ( + maxDecodedAudioBytes = 96 << 20 + maxDecodedChunkBytes = 8 << 20 + maxStreamFrameBytes = 12 << 20 +) + +type audioFormat string + +const ( + audioFormatMP3 audioFormat = "mp3" + audioFormatWAV audioFormat = "wav" + audioFormatOGG audioFormat = "ogg" + audioFormatFLAC audioFormat = "flac" +) + // Client talks to the 60db HTTP API. type Client struct { baseURL string @@ -33,8 +44,15 @@ type Client struct { httpClient *http.Client } -// Ensure the 60db client satisfies the shared provider contract. -var _ tts.Provider = (*Client)(nil) +// TTSRequest is the documented 60db TTS payload shape exposed by the CLI. +type TTSRequest struct { + Text string `json:"text"` + VoiceID string `json:"voice_id,omitempty"` + Speed *float64 `json:"speed,omitempty"` + Stability *float64 `json:"stability,omitempty"` + Similarity *float64 `json:"similarity,omitempty"` + OutputFormat string `json:"output_format,omitempty"` +} // NewClient returns a Client configured with the given API key and base URL. 
func NewClient(apiKey, baseURL string) *Client { @@ -63,7 +81,95 @@ func (c *Client) newRequest(ctx context.Context, method, endpoint string, body i return req, nil } -// --- Voices --------------------------------------------------------------- +type envelope struct { + Success *bool `json:"success"` + Message string `json:"message"` + Error string `json:"error"` +} + +func (c *Client) errorFromEnvelope(op string, env envelope, fallback string) error { + msg := strings.TrimSpace(env.Message) + if msg == "" { + msg = strings.TrimSpace(env.Error) + } + if msg == "" { + msg = fallback + } + return fmt.Errorf("%s: %s", op, c.sanitize(msg)) +} + +func (c *Client) httpError(op string, status string, body []byte) error { + var env envelope + if len(body) > 0 && json.Unmarshal(body, &env) == nil { + return c.errorFromEnvelope(op, env, status) + } + return fmt.Errorf("%s: %s", op, status) +} + +func (c *Client) sanitize(msg string) string { + msg = strings.TrimSpace(msg) + if msg == "" || c.apiKey == "" { + return msg + } + msg = strings.ReplaceAll(msg, "Bearer "+c.apiKey, "Bearer [redacted]") + msg = strings.ReplaceAll(msg, c.apiKey, "[redacted]") + return msg +} + +// CanonicalOutputFormat maps CLI/ElevenLabs-style output names to the simple +// 60db formats documented for /tts-synthesize. 
+func CanonicalOutputFormat(format string) string { + format = strings.ToLower(strings.TrimSpace(format)) + switch { + case format == "": + return "" + case strings.HasPrefix(format, "mp3"): + return string(audioFormatMP3) + case strings.HasPrefix(format, "pcm"), strings.HasPrefix(format, "wav"): + return string(audioFormatWAV) + case strings.HasPrefix(format, "opus"), strings.HasPrefix(format, "ogg"): + return string(audioFormatOGG) + case strings.HasPrefix(format, "flac"): + return string(audioFormatFLAC) + default: + return format + } +} + +func parseAudioFormat(value string) audioFormat { + switch strings.ToLower(strings.TrimSpace(value)) { + case string(audioFormatMP3): + return audioFormatMP3 + case string(audioFormatWAV): + return audioFormatWAV + case string(audioFormatOGG): + return audioFormatOGG + case string(audioFormatFLAC): + return audioFormatFLAC + default: + return "" + } +} + +func sniffAudioFormat(data []byte) (audioFormat, error) { + if len(data) == 0 { + return "", errors.New("empty audio") + } + switch { + case len(data) >= 3 && string(data[:3]) == "ID3": + return audioFormatMP3, nil + case len(data) >= 2 && data[0] == 0xff && data[1]&0xe0 == 0xe0: + return audioFormatMP3, nil + case len(data) >= 12 && string(data[:4]) == "RIFF" && string(data[8:12]) == "WAVE": + return audioFormatWAV, nil + case len(data) >= 4 && string(data[:4]) == "OggS": + return audioFormatOGG, nil + case len(data) >= 4 && string(data[:4]) == "fLaC": + return audioFormatFLAC, nil + default: + return "", errors.New("unrecognized audio format") + } +} type voiceEntry struct { VoiceID string `json:"voice_id"` @@ -72,44 +178,85 @@ type voiceEntry struct { Model string `json:"model"` Labels map[string]string `json:"labels"` Description *string `json:"description"` + Categories []string `json:"categories"` +} + +type voicesResponse struct { + envelope + Data []voiceEntry `json:"data"` } -type myVoicesResponse struct { - Success bool `json:"success"` - Message string 
`json:"message"` - Data []voiceEntry `json:"data"` +type voiceDetailsResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Language string `json:"language"` + Gender string `json:"gender"` + Age string `json:"age"` + Accent string `json:"accent"` + UseCase []string `json:"use_case"` + SampleURL string `json:"sample_url"` + IsCustom bool `json:"is_custom"` } -func (e voiceEntry) toVoice() tts.Voice { - labels := make(map[string]string, len(e.Labels)+1) +func (e voiceEntry) toVoice(source string) tts.Voice { + labels := make(map[string]string, len(e.Labels)+2) for k, v := range e.Labels { labels[k] = v } - // Fold the model name into labels so `--label model=...` and query ranking - // can match on it, mirroring how ElevenLabs exposes metadata via labels. if e.Model != "" { - if _, ok := labels["model"]; !ok { - labels["model"] = e.Model - } + labels["model"] = e.Model + } + if len(e.Categories) > 0 { + labels["categories"] = strings.Join(e.Categories, ", ") } - desc := "" + if source != "" { + labels["source"] = source + } + description := "" if e.Description != nil { - desc = *e.Description + description = strings.TrimSpace(*e.Description) } return tts.Voice{ VoiceID: e.VoiceID, Name: e.Name, Category: e.Category, - Description: desc, + Description: description, Labels: labels, - // 60db /myvoices exposes no preview URL, so previews are unavailable. - PreviewURL: "", } } -// ListVoices fetches the caller's available 60db voices. 
-func (c *Client) ListVoices(ctx context.Context) ([]tts.Voice, error) { - req, err := c.newRequest(ctx, http.MethodGet, "/myvoices", nil) +func (v voiceDetailsResponse) toVoice() tts.Voice { + labels := map[string]string{} + if v.Language != "" { + labels["language"] = v.Language + } + if v.Gender != "" { + labels["gender"] = v.Gender + } + if v.Age != "" { + labels["age"] = v.Age + } + if v.Accent != "" { + labels["accent"] = v.Accent + } + if len(v.UseCase) > 0 { + labels["use_case"] = strings.Join(v.UseCase, ", ") + } + if v.IsCustom { + labels["source"] = "myvoices" + } + return tts.Voice{ + VoiceID: v.ID, + Name: v.Name, + Description: strings.TrimSpace(v.Description), + Labels: labels, + PreviewURL: strings.TrimSpace(v.SampleURL), + } +} + +func (c *Client) fetchVoices(ctx context.Context, endpoint, source string) ([]tts.Voice, error) { + req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } @@ -122,23 +269,52 @@ func (c *Client) ListVoices(ctx context.Context) ([]tts.Voice, error) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { - b, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("list voices failed: %s: %s", resp.Status, strings.TrimSpace(string(b))) + body, _ := io.ReadAll(resp.Body) + return nil, c.httpError("list voices", resp.Status, body) } - var body myVoicesResponse + var body voicesResponse if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { - return nil, err + return nil, fmt.Errorf("list voices: decode response: %w", err) + } + if body.Success == nil || !*body.Success { + return nil, c.errorFromEnvelope("list voices", body.envelope, "unexpected response envelope") } + voices := make([]tts.Voice, 0, len(body.Data)) - for _, e := range body.Data { - voices = append(voices, e.toVoice()) + for _, entry := range body.Data { + voices = append(voices, entry.toVoice(source)) } return voices, nil } -// SearchVoices filters the voice list by a case-insensitive name substring. 
-// 60db has no server-side search, so this is done client-side. +// ListVoices merges the documented default and user-created voice catalogs. +func (c *Client) ListVoices(ctx context.Context) ([]tts.Voice, error) { + defaultVoices, err := c.fetchVoices(ctx, "/default-voices", "default") + if err != nil { + return nil, err + } + myVoices, err := c.fetchVoices(ctx, "/myvoices", "myvoices") + if err != nil { + return nil, err + } + + merged := make([]tts.Voice, 0, len(defaultVoices)+len(myVoices)) + seen := make(map[string]struct{}, len(defaultVoices)+len(myVoices)) + for _, voice := range append(defaultVoices, myVoices...) { + if voice.VoiceID == "" { + continue + } + if _, ok := seen[voice.VoiceID]; ok { + continue + } + seen[voice.VoiceID] = struct{}{} + merged = append(merged, voice) + } + return merged, nil +} + +// SearchVoices filters the merged voice catalog by name. func (c *Client) SearchVoices(ctx context.Context, search string, limit int) ([]tts.Voice, error) { voices, err := c.ListVoices(ctx) if err != nil { @@ -148,9 +324,9 @@ func (c *Client) SearchVoices(ctx context.Context, search string, limit int) ([] if search != "" { searchLower := strings.ToLower(search) filtered := make([]tts.Voice, 0, len(voices)) - for _, v := range voices { - if strings.Contains(strings.ToLower(v.Name), searchLower) { - filtered = append(filtered, v) + for _, voice := range voices { + if strings.Contains(strings.ToLower(voice.Name), searchLower) { + filtered = append(filtered, voice) } } voices = filtered @@ -161,93 +337,83 @@ func (c *Client) SearchVoices(ctx context.Context, search string, limit int) ([] return voices, nil } -// GetVoice returns metadata for a specific voice. 60db has no single-voice -// endpoint, so it resolves against the full list. +// GetVoice resolves a voice from the documented per-voice endpoint. 
func (c *Client) GetVoice(ctx context.Context, voiceID string) (tts.Voice, error) { - voices, err := c.ListVoices(ctx) + req, err := c.newRequest(ctx, http.MethodGet, path.Join("/voices", voiceID), nil) if err != nil { return tts.Voice{}, err } - for _, v := range voices { - if v.VoiceID == voiceID { - return v, nil - } - } - return tts.Voice{}, fmt.Errorf("voice %q not found", voiceID) -} - -// --- Text to speech ------------------------------------------------------- + req.Header.Set("Accept", "application/json") -type synthesizeRequest struct { - Text string `json:"text"` - VoiceID string `json:"voice_id,omitempty"` - Speed *float64 `json:"speed,omitempty"` - Stability *float64 `json:"stability,omitempty"` - Similarity *float64 `json:"similarity,omitempty"` - OutputFormat string `json:"output_format,omitempty"` -} + resp, err := c.httpClient.Do(req) + if err != nil { + return tts.Voice{}, err + } + defer func() { _ = resp.Body.Close() }() -// buildBody translates a provider-neutral request into 60db's payload. -// includeFormat is false for the streaming endpoint, whose spec omits -// output_format. Fields with no 60db equivalent (model, style, speaker boost, -// seed, normalization, language) are intentionally dropped. 
-func buildBody(voiceID string, req tts.TTSRequest, includeFormat bool) synthesizeRequest { - body := synthesizeRequest{ - Text: req.Text, - VoiceID: voiceID, + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return tts.Voice{}, c.httpError("get voice", resp.Status, body) } - if vs := req.VoiceSettings; vs != nil { - if vs.Speed != nil { - body.Speed = vs.Speed // both APIs use 0.5..2.0 - } - if vs.Stability != nil { - body.Stability = scaleToHundred(*vs.Stability) - } - if vs.SimilarityBoost != nil { - body.Similarity = scaleToHundred(*vs.SimilarityBoost) - } + + var body voiceDetailsResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return tts.Voice{}, fmt.Errorf("get voice: decode response: %w", err) } - if includeFormat { - body.OutputFormat = toSixtyDBFormat(req.OutputFormat) + if strings.TrimSpace(body.ID) == "" { + return tts.Voice{}, errors.New("get voice: missing voice id") } - return body + return body.toVoice(), nil } -// scaleToHundred converts a 0..1 knob to 60db's 0..100 scale. -func scaleToHundred(v float64) *float64 { - scaled := v * 100 - return &scaled +type synthResponse struct { + envelope + AudioBase64 string `json:"audio_base64"` + SampleRate int `json:"sample_rate"` + DurationSeconds float64 `json:"duration_seconds"` + Encoding string `json:"encoding"` + OutputFormat string `json:"output_format"` } -// toSixtyDBFormat maps ElevenLabs-style format strings to 60db's simple set -// (mp3|wav|ogg|flac). Empty input yields empty (provider default). 
-func toSixtyDBFormat(format string) string { - format = strings.ToLower(strings.TrimSpace(format)) - switch { - case format == "": - return "" - case strings.HasPrefix(format, "mp3"): - return "mp3" - case strings.HasPrefix(format, "pcm"), strings.HasPrefix(format, "wav"): - return "wav" - case strings.HasPrefix(format, "opus"), strings.HasPrefix(format, "ogg"): - return "ogg" - case strings.HasPrefix(format, "flac"): - return "flac" - default: - return format +func validateRequestedFormat(requested string) error { + if requested == "" { + return nil + } + if parseAudioFormat(requested) == "" { + return fmt.Errorf("unsupported 60db output format %q", requested) } + return nil } -type synthesizeResponse struct { - Success bool `json:"success"` - Message string `json:"message"` - AudioBase64 string `json:"audio_base64"` +func validateConvertResponseFormat(requested string, resp synthResponse, data []byte) error { + sniffed, err := sniffAudioFormat(data) + if err != nil { + return err + } + + declared := parseAudioFormat(resp.OutputFormat) + if declared == "" { + declared = parseAudioFormat(resp.Encoding) + } + if declared != "" && declared != sniffed { + return fmt.Errorf("response format mismatch: declared %s, decoded %s", declared, sniffed) + } + + expected := parseAudioFormat(requested) + if expected != "" && sniffed != expected { + return fmt.Errorf("response format mismatch: requested %s, decoded %s", expected, sniffed) + } + return nil } // ConvertTTS downloads the full audio and returns decoded bytes. 
-func (c *Client) ConvertTTS(ctx context.Context, voiceID string, payload tts.TTSRequest) ([]byte, error) { - bodyBytes, err := json.Marshal(buildBody(voiceID, payload, true)) +func (c *Client) ConvertTTS(ctx context.Context, reqBody TTSRequest) ([]byte, error) { + reqBody.OutputFormat = CanonicalOutputFormat(reqBody.OutputFormat) + if err := validateRequestedFormat(reqBody.OutputFormat); err != nil { + return nil, err + } + + bodyBytes, err := json.Marshal(reqBody) if err != nil { return nil, err } @@ -265,28 +431,52 @@ func (c *Client) ConvertTTS(ctx context.Context, voiceID string, payload tts.TTS defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { - b, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("convert TTS failed: %s: %s", resp.Status, strings.TrimSpace(string(b))) + body, _ := io.ReadAll(resp.Body) + return nil, c.httpError("synthesize audio", resp.Status, body) } - var body synthesizeResponse + var body synthResponse if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { - return nil, err + return nil, fmt.Errorf("synthesize audio: decode response: %w", err) + } + if body.Success == nil || !*body.Success { + return nil, c.errorFromEnvelope("synthesize audio", body.envelope, "unexpected response envelope") } - if !body.Success && body.Message != "" { - return nil, fmt.Errorf("convert TTS failed: %s", body.Message) + if strings.TrimSpace(body.AudioBase64) == "" { + return nil, errors.New("synthesize audio: empty audio_base64") } - data, err := base64.StdEncoding.DecodeString(body.AudioBase64) + + decodedLen := base64.StdEncoding.DecodedLen(len(body.AudioBase64)) + if decodedLen <= 0 { + return nil, errors.New("synthesize audio: empty decoded audio") + } + if decodedLen > maxDecodedAudioBytes { + return nil, fmt.Errorf("synthesize audio: decoded audio exceeds %d bytes", maxDecodedAudioBytes) + } + + data := make([]byte, decodedLen) + n, err := base64.StdEncoding.Decode(data, []byte(body.AudioBase64)) if err != nil { - return 
nil, fmt.Errorf("decode audio_base64: %w", err) + return nil, fmt.Errorf("synthesize audio: decode audio_base64: %w", err) + } + data = data[:n] + if len(data) == 0 { + return nil, errors.New("synthesize audio: decoded audio was empty") + } + if err := validateConvertResponseFormat(reqBody.OutputFormat, body, data); err != nil { + return nil, fmt.Errorf("synthesize audio: %w", err) } return data, nil } -// StreamTTS requests streaming audio and returns a reader that yields decoded -// audio bytes. The latency argument is ignored (60db's stream has no tier). -func (c *Client) StreamTTS(ctx context.Context, voiceID string, payload tts.TTSRequest, _ int) (io.ReadCloser, error) { - bodyBytes, err := json.Marshal(buildBody(voiceID, payload, false)) +// StreamTTS requests streaming audio. The documented stream API does not +// accept output_format. +func (c *Client) StreamTTS(ctx context.Context, reqBody TTSRequest) (io.ReadCloser, error) { + if CanonicalOutputFormat(reqBody.OutputFormat) != "" { + return nil, errors.New("stream audio: output_format is not supported by /tts-stream") + } + + bodyBytes, err := json.Marshal(reqBody) if err != nil { return nil, err } @@ -303,14 +493,12 @@ func (c *Client) StreamTTS(ctx context.Context, voiceID string, payload tts.TTSR } if resp.StatusCode >= 400 { defer func() { _ = resp.Body.Close() }() - b, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("stream TTS failed: %s: %s", resp.Status, strings.TrimSpace(string(b))) + body, _ := io.ReadAll(resp.Body) + return nil, c.httpError("stream audio", resp.Status, body) } - return newNDJSONAudioReader(resp.Body), nil + return newNDJSONAudioReader(ctx, resp.Body, c.sanitize), nil } -// --- NDJSON streaming decoder --------------------------------------------- - type streamFrame struct { Type string `json:"type"` Result struct { @@ -319,21 +507,37 @@ type streamFrame struct { Message string `json:"message"` } -// ndjsonAudioReader unwraps 60db's newline-delimited JSON stream into a plain 
-// audio byte stream. Each "chunk" frame's base64 audio is decoded and served; -// "complete" ends the stream; "error" surfaces the message. type ndjsonAudioReader struct { - src io.ReadCloser - reader *bufio.Reader - pending []byte - err error + ctx context.Context + src io.ReadCloser + scanner *bufio.Scanner + pending []byte + err error + stop func() bool + sanitize func(string) string + + totalBytes int64 + sawChunk bool + sawComplete bool } -func newNDJSONAudioReader(src io.ReadCloser) *ndjsonAudioReader { - return &ndjsonAudioReader{ - src: src, - reader: bufio.NewReader(src), +func newNDJSONAudioReader(ctx context.Context, src io.ReadCloser, sanitize func(string) string) *ndjsonAudioReader { + scanner := bufio.NewScanner(src) + scanner.Buffer(make([]byte, 64*1024), maxStreamFrameBytes) + if sanitize == nil { + sanitize = func(msg string) string { return msg } + } + + reader := &ndjsonAudioReader{ + ctx: ctx, + src: src, + scanner: scanner, + sanitize: sanitize, } + reader.stop = context.AfterFunc(ctx, func() { + _ = src.Close() + }) + return reader } func (r *ndjsonAudioReader) Read(p []byte) (int, error) { @@ -353,45 +557,98 @@ func (r *ndjsonAudioReader) Read(p []byte) (int, error) { return n, nil } -// fill reads and decodes the next non-empty frame into r.pending. 
func (r *ndjsonAudioReader) fill() error { - for { - line, err := r.reader.ReadBytes('\n') - trimmed := bytes.TrimSpace(line) - if len(trimmed) > 0 { - var frame streamFrame - if jerr := json.Unmarshal(trimmed, &frame); jerr != nil { - return fmt.Errorf("decode stream frame: %w", jerr) + if err := r.ctx.Err(); err != nil { + return err + } + for r.scanner.Scan() { + line := bytes.TrimSpace(r.scanner.Bytes()) + if len(line) == 0 { + continue + } + + var frame streamFrame + if err := json.Unmarshal(line, &frame); err != nil { + return fmt.Errorf("decode stream frame: %w", err) + } + + switch frame.Type { + case "chunk": + audio, err := decodeChunk(frame.Result.AudioContent, r.totalBytes) + if err != nil { + return err } - switch frame.Type { - case "chunk": - audio, derr := base64.StdEncoding.DecodeString(frame.Result.AudioContent) - if derr != nil { - return fmt.Errorf("decode audio chunk: %w", derr) + if !r.sawChunk { + if _, err := sniffAudioFormat(audio); err != nil { + return fmt.Errorf("unknown streamed audio format: %w", err) } - if len(audio) > 0 { - r.pending = audio - return nil - } - case "complete": - return io.EOF - case "error": - if frame.Message != "" { - return fmt.Errorf("stream error: %s", frame.Message) - } - return fmt.Errorf("stream error") } - // Unknown frame types are ignored; keep reading. 
- } - if err != nil { - if err == io.EOF { - return io.EOF + r.sawChunk = true + r.totalBytes += int64(len(audio)) + r.pending = audio + return nil + case "complete": + r.sawComplete = true + if !r.sawChunk { + return errors.New("stream completed without audio") } - return err + return io.EOF + case "error": + if msg := strings.TrimSpace(frame.Message); msg != "" { + return fmt.Errorf("stream error: %s", r.sanitize(msg)) + } + return errors.New("stream error") + default: + return fmt.Errorf("unknown stream frame type %q", frame.Type) + } + } + + if err := r.scanner.Err(); err != nil { + if ctxErr := r.ctx.Err(); ctxErr != nil { + return ctxErr } + return fmt.Errorf("read stream frame: %w", err) + } + if ctxErr := r.ctx.Err(); ctxErr != nil { + return ctxErr + } + if r.sawComplete { + return io.EOF + } + return io.ErrUnexpectedEOF +} + +func decodeChunk(encoded string, totalBytes int64) ([]byte, error) { + encoded = strings.TrimSpace(encoded) + if encoded == "" { + return nil, errors.New("stream chunk missing audioContent") + } + decodedLen := base64.StdEncoding.DecodedLen(len(encoded)) + if decodedLen <= 0 { + return nil, errors.New("stream chunk decoded to empty audio") + } + if decodedLen > maxDecodedChunkBytes { + return nil, fmt.Errorf("stream chunk exceeds %d bytes", maxDecodedChunkBytes) + } + if totalBytes+int64(decodedLen) > maxDecodedAudioBytes { + return nil, fmt.Errorf("stream audio exceeds %d bytes", maxDecodedAudioBytes) } + + audio := make([]byte, decodedLen) + n, err := base64.StdEncoding.Decode(audio, []byte(encoded)) + if err != nil { + return nil, fmt.Errorf("decode audio chunk: %w", err) + } + audio = audio[:n] + if len(audio) == 0 { + return nil, errors.New("stream chunk decoded to empty audio") + } + return audio, nil } func (r *ndjsonAudioReader) Close() error { + if r.stop != nil { + r.stop() + } return r.src.Close() } diff --git a/internal/sixtydb/client_test.go b/internal/sixtydb/client_test.go index cebcd8a..dd9faf4 100644 --- 
a/internal/sixtydb/client_test.go +++ b/internal/sixtydb/client_test.go @@ -1,16 +1,17 @@ package sixtydb import ( + "bytes" "context" "encoding/base64" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "strings" "testing" - - "github.com/steipete/sag/internal/tts" + "time" ) func TestNewClientDefaultsBase(t *testing.T) { @@ -20,19 +21,27 @@ func TestNewClientDefaultsBase(t *testing.T) { } } -func TestListVoicesUnwrapsData(t *testing.T) { - desc := "warm narrator" +func TestListVoicesMergesDefaultAndMyVoices(t *testing.T) { + var calls []string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/myvoices" { - t.Fatalf("unexpected path: %s", r.URL.Path) - } + calls = append(calls, r.URL.Path) if got := r.Header.Get("Authorization"); got != "Bearer key" { t.Fatalf("unexpected auth header: %q", got) } - _, _ = w.Write([]byte(`{"success":true,"message":"ok","data":[ - {"voice_id":"v1","name":"Aria","category":"professional","model":"60db Quality","labels":{"gender":"female","accent":"American"},"description":"` + desc + `"}, - {"voice_id":"v2","name":"Ravi","category":"cloned","model":"60db Fast","labels":{"gender":"male"},"description":null} - ]}`)) + switch r.URL.Path { + case "/default-voices": + _, _ = io.WriteString(w, `{"success":true,"data":[ + {"voice_id":"v1","name":"Aria","category":"default","model":"60db Quality","labels":{"accent":"US"}}, + {"voice_id":"dup","name":"Default Dup","category":"default"} + ]}`) + case "/myvoices": + _, _ = io.WriteString(w, `{"success":true,"data":[ + {"voice_id":"dup","name":"My Dup","category":"cloned"}, + {"voice_id":"v2","name":"Ravi","category":"cloned","categories":["narration","warm"]} + ]}`) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } })) defer srv.Close() @@ -41,30 +50,36 @@ func TestListVoicesUnwrapsData(t *testing.T) { if err != nil { t.Fatalf("ListVoices error: %v", err) } - if len(voices) != 2 { - t.Fatalf("expected 2 voices, 
got %d", len(voices)) + if want := []string{"/default-voices", "/myvoices"}; strings.Join(calls, ",") != strings.Join(want, ",") { + t.Fatalf("route order = %v, want %v", calls, want) } - if voices[0].VoiceID != "v1" || voices[0].Name != "Aria" || voices[0].Category != "professional" { - t.Fatalf("unexpected voice[0]: %+v", voices[0]) + if len(voices) != 3 { + t.Fatalf("expected 3 merged voices, got %d", len(voices)) } - if voices[0].Description != desc { - t.Fatalf("expected description %q, got %q", desc, voices[0].Description) + if voices[0].VoiceID != "v1" || voices[0].Labels["source"] != "default" || voices[0].Labels["model"] != "60db Quality" { + t.Fatalf("unexpected default voice: %+v", voices[0]) } - if voices[0].Labels["model"] != "60db Quality" { - t.Fatalf("expected model folded into labels, got %+v", voices[0].Labels) + if voices[1].VoiceID != "dup" || voices[1].Name != "Default Dup" { + t.Fatalf("expected default duplicate to win, got %+v", voices[1]) } - if voices[1].Description != "" { - t.Fatalf("expected empty description for null, got %q", voices[1].Description) + if voices[2].VoiceID != "v2" || voices[2].Labels["source"] != "myvoices" || voices[2].Labels["categories"] != "narration, warm" { + t.Fatalf("unexpected user voice: %+v", voices[2]) } } func TestSearchVoicesFiltersAndLimits(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, _ = w.Write([]byte(`{"success":true,"data":[ - {"voice_id":"v1","name":"Roger"}, - {"voice_id":"v2","name":"Rogue"}, - {"voice_id":"v3","name":"Sarah"} - ]}`)) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/default-voices": + _, _ = io.WriteString(w, `{"success":true,"data":[ + {"voice_id":"v1","name":"Roger"}, + {"voice_id":"v2","name":"Rogue"} + ]}`) + case "/myvoices": + _, _ = io.WriteString(w, `{"success":true,"data":[{"voice_id":"v3","name":"Sarah"}]}`) + default: + 
t.Fatalf("unexpected path: %s", r.URL.Path) + } })) defer srv.Close() @@ -73,16 +88,61 @@ func TestSearchVoicesFiltersAndLimits(t *testing.T) { if err != nil { t.Fatalf("SearchVoices error: %v", err) } - if len(voices) != 1 { - t.Fatalf("expected 1 voice after limit, got %d", len(voices)) + if len(voices) != 1 || voices[0].VoiceID != "v1" { + t.Fatalf("unexpected voices: %+v", voices) } - if voices[0].VoiceID != "v1" { - t.Fatalf("expected v1, got %s", voices[0].VoiceID) +} + +func TestGetVoiceUsesDocumentedPerVoiceRoute(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/voices/voice-001" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + _, _ = io.WriteString(w, `{ + "id":"voice-001", + "name":"Sarah", + "description":"Professional female voice", + "language":"en-US", + "gender":"female", + "age":"middle", + "accent":"American", + "use_case":["narration","customer-service"], + "sample_url":"https://cdn.60db.com/samples/voice-001.mp3", + "is_custom":false + }`) + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + voice, err := c.GetVoice(context.Background(), "voice-001") + if err != nil { + t.Fatalf("GetVoice error: %v", err) } + if voice.VoiceID != "voice-001" || voice.Name != "Sarah" { + t.Fatalf("unexpected voice: %+v", voice) + } + if voice.PreviewURL != "https://cdn.60db.com/samples/voice-001.mp3" { + t.Fatalf("unexpected preview URL: %q", voice.PreviewURL) + } + if voice.Labels["use_case"] != "narration, customer-service" || voice.Labels["accent"] != "American" { + t.Fatalf("unexpected voice labels: %+v", voice.Labels) + } +} + +func TestListVoicesRejectsHTTP200ErrorEnvelopeWithoutLeakingToken(t *testing.T) { + const secret = "secret-token" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"success":false,"message":"Bearer `+secret+` invalid token"}`) + })) + defer srv.Close() + + c := 
NewClient(secret, srv.URL) + _, err := c.ListVoices(context.Background()) + assertSanitizedError(t, err, secret, "invalid token") } -func TestConvertTTSDecodesBase64AndTranslatesParams(t *testing.T) { - want := []byte("decoded-audio-bytes") +func TestConvertTTSUsesDocumentedRouteAndValidatesResponse(t *testing.T) { + audio := []byte("ID3converted-audio") srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/tts-synthesize" { t.Fatalf("unexpected path: %s", r.URL.Path) @@ -94,71 +154,108 @@ func TestConvertTTSDecodesBase64AndTranslatesParams(t *testing.T) { if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("decode body: %v", err) } - if body["text"] != "hi" { - t.Fatalf("expected text hi, got %v", body["text"]) - } - if body["voice_id"] != "v1" { - t.Fatalf("expected voice_id v1, got %v", body["voice_id"]) - } - // 0.5 stability -> 50, 0.8 similarity -> 80 (0..1 -> 0..100) - if body["stability"] != float64(50) { - t.Fatalf("expected stability 50, got %v", body["stability"]) - } - if body["similarity"] != float64(80) { - t.Fatalf("expected similarity 80, got %v", body["similarity"]) - } - if body["speed"] != 1.1 { - t.Fatalf("expected speed 1.1, got %v", body["speed"]) + if body["text"] != "hi" || body["voice_id"] != "v1" { + t.Fatalf("unexpected request body: %+v", body) } - // mp3_44100_128 -> mp3 - if body["output_format"] != "mp3" { - t.Fatalf("expected output_format mp3, got %v", body["output_format"]) + if body["speed"] != 1.1 || body["stability"] != 50.0 || body["similarity"] != 80.0 || body["output_format"] != "mp3" { + t.Fatalf("unexpected mapped request body: %+v", body) } - // ElevenLabs-only fields must not appear. 
- for _, k := range []string{"model_id", "style", "use_speaker_boost", "seed", "language_code"} { - if _, ok := body[k]; ok { - t.Fatalf("expected %q to be absent from 60db body", k) - } + + resp := map[string]any{ + "success": true, + "audio_base64": base64.StdEncoding.EncodeToString(audio), + "encoding": "mp3", + "output_format": "mp3", } - resp := map[string]any{"success": true, "audio_base64": base64.StdEncoding.EncodeToString(want)} _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() - stability := 0.5 - similarity := 0.8 speed := 1.1 + stability := 50.0 + similarity := 80.0 c := NewClient("key", srv.URL) - data, err := c.ConvertTTS(context.Background(), "v1", tts.TTSRequest{ + got, err := c.ConvertTTS(context.Background(), TTSRequest{ Text: "hi", - ModelID: "eleven_v3", + VoiceID: "v1", + Speed: &speed, + Stability: &stability, + Similarity: &similarity, OutputFormat: "mp3_44100_128", - VoiceSettings: &tts.VoiceSettings{ - Stability: &stability, - SimilarityBoost: &similarity, - Speed: &speed, - }, }) if err != nil { t.Fatalf("ConvertTTS error: %v", err) } - if string(data) != string(want) { - t.Fatalf("unexpected decoded audio: %q", string(data)) + if !bytes.Equal(got, audio) { + t.Fatalf("decoded audio = %q, want %q", got, audio) + } +} + +func TestConvertTTSRejectsHTTP200ErrorEnvelopeWithoutLeakingToken(t *testing.T) { + const secret = "secret-token" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"success":false,"message":"Bearer `+secret+` invalid token"}`) + })) + defer srv.Close() + + c := NewClient(secret, srv.URL) + _, err := c.ConvertTTS(context.Background(), TTSRequest{Text: "hi"}) + assertSanitizedError(t, err, secret, "invalid token") +} + +func TestConvertTTSRejectsMalformedOrMismatchedAudio(t *testing.T) { + tests := []struct { + name string + body string + want string + }{ + { + name: "empty audio", + body: `{"success":true,"audio_base64":""}`, + want: "empty 
audio_base64", + }, + { + name: "unknown format", + body: `{"success":true,"audio_base64":"` + mustBase64([]byte("not-audio")) + `"}`, + want: "unrecognized audio format", + }, + { + name: "declared format mismatch", + body: `{"success":true,"audio_base64":"` + mustBase64([]byte("ID3mp3-data")) + `","output_format":"wav"}`, + want: "declared wav, decoded mp3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, tt.body) + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + _, err := c.ConvertTTS(context.Background(), TTSRequest{Text: "hi"}) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected %q, got %v", tt.want, err) + } + }) } } -func TestStreamTTSDecodesNDJSON(t *testing.T) { - chunk1 := base64.StdEncoding.EncodeToString([]byte("hello-")) - chunk2 := base64.StdEncoding.EncodeToString([]byte("world")) +func TestStreamTTSUsesDocumentedRouteAndDecodesNDJSON(t *testing.T) { + chunk1 := mustBase64([]byte("ID3hello-")) + chunk2 := mustBase64([]byte("world")) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/tts-stream" { t.Fatalf("unexpected path: %s", r.URL.Path) } + if got := r.Header.Get("Authorization"); got != "Bearer key" { + t.Fatalf("unexpected auth header: %q", got) + } var body map[string]any if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("decode body: %v", err) } - // streaming body omits output_format if _, ok := body["output_format"]; ok { t.Fatalf("expected output_format omitted from stream body") } @@ -169,52 +266,160 @@ func TestStreamTTSDecodesNDJSON(t *testing.T) { defer srv.Close() c := NewClient("key", srv.URL) - rc, err := c.StreamTTS(context.Background(), "v1", tts.TTSRequest{Text: "hi"}, 0) + rc, err := c.StreamTTS(context.Background(), TTSRequest{Text: "hi"}) if err != nil { 
t.Fatalf("StreamTTS error: %v", err) } defer func() { _ = rc.Close() }() + got, err := io.ReadAll(rc) if err != nil { t.Fatalf("read stream: %v", err) } - if string(got) != "hello-world" { - t.Fatalf("unexpected decoded stream: %q", string(got)) + if string(got) != "ID3hello-world" { + t.Fatalf("unexpected stream body: %q", got) } } -func TestStreamTTSSurfacesErrorFrame(t *testing.T) { +func TestStreamTTSRejectsInvalidTokenWithoutLeakingToken(t *testing.T) { + const secret = "secret-token" srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, `{"type":"error","message":"voice not found"}`+"\n") + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, `{"success":false,"message":"Bearer `+secret+` invalid token"}`) })) defer srv.Close() - c := NewClient("key", srv.URL) - rc, err := c.StreamTTS(context.Background(), "v1", tts.TTSRequest{Text: "hi"}, 0) - if err != nil { - t.Fatalf("StreamTTS error: %v", err) + c := NewClient(secret, srv.URL) + _, err := c.StreamTTS(context.Background(), TTSRequest{Text: "hi"}) + assertSanitizedError(t, err, secret, "invalid token") +} + +func TestStreamTTSRejectsMalformedStreams(t *testing.T) { + tests := []struct { + name string + lines []string + want string + isErr error + }{ + { + name: "bad json", + lines: []string{`not-json`}, + want: "decode stream frame", + }, + { + name: "unknown frame type", + lines: []string{`{"type":"wat"}`}, + want: `unknown stream frame type "wat"`, + }, + { + name: "missing audio content", + lines: []string{`{"type":"chunk","result":{"audioContent":""}}`}, + want: "missing audioContent", + }, + { + name: "unknown audio format", + lines: []string{`{"type":"chunk","result":{"audioContent":"` + mustBase64([]byte("bad")) + `"}}`}, + want: "unknown streamed audio format", + }, + { + name: "complete without audio", + lines: []string{`{"type":"complete"}`}, + want: "stream completed without audio", + }, + { + name: "empty stream", + lines: 
nil, + isErr: io.ErrUnexpectedEOF, + }, + { + name: "missing complete frame", + lines: []string{`{"type":"chunk","result":{"audioContent":"` + mustBase64([]byte("ID3chunk")) + `"}}`}, + isErr: io.ErrUnexpectedEOF, + }, } - defer func() { _ = rc.Close() }() - _, err = io.ReadAll(rc) - if err == nil || !strings.Contains(err.Error(), "voice not found") { - t.Fatalf("expected stream error surfaced, got %v", err) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reader := newNDJSONAudioReader(context.Background(), io.NopCloser(strings.NewReader(strings.Join(tt.lines, "\n"))), nil) + defer func() { _ = reader.Close() }() + + _, err := io.ReadAll(reader) + if tt.isErr != nil { + if !errors.Is(err, tt.isErr) { + t.Fatalf("expected %v, got %v", tt.isErr, err) + } + return + } + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected %q, got %v", tt.want, err) + } + }) } } -func TestStreamTTSHTTPError(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - http.Error(w, "bad", http.StatusBadRequest) - })) - defer srv.Close() +func TestNDJSONAudioReaderHonorsCancellation(t *testing.T) { + pr, pw := io.Pipe() + ctx, cancel := context.WithCancel(context.Background()) + reader := newNDJSONAudioReader(ctx, pr, nil) + defer func() { _ = reader.Close() }() - c := NewClient("key", srv.URL) - _, err := c.StreamTTS(context.Background(), "v1", tts.TTSRequest{Text: "hi"}, 0) - if err == nil || !strings.Contains(err.Error(), "400") { - t.Fatalf("expected 400 error, got %v", err) + done := make(chan error, 1) + go func() { + _, err := io.ReadAll(reader) + done <- err + }() + + cancel() + _ = pw.Close() + + select { + case err := <-done: + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("stream read did not unblock on cancellation") } } -func TestToSixtyDBFormat(t *testing.T) { +func 
TestNDJSONAudioReaderRejectsOversizedFrame(t *testing.T) { + line := strings.Repeat("a", maxStreamFrameBytes+1) + reader := newNDJSONAudioReader(context.Background(), io.NopCloser(strings.NewReader(line)), nil) + defer func() { _ = reader.Close() }() + + _, err := io.ReadAll(reader) + if err == nil || !strings.Contains(err.Error(), "token too long") { + t.Fatalf("expected oversized frame error, got %v", err) + } +} + +func TestNDJSONAudioReaderSanitizesErrorFrames(t *testing.T) { + const secret = "secret-token" + reader := newNDJSONAudioReader( + context.Background(), + io.NopCloser(strings.NewReader(`{"type":"error","message":"Bearer `+secret+` invalid token"}`)), + func(msg string) string { + return strings.ReplaceAll(msg, secret, "[redacted]") + }, + ) + defer func() { _ = reader.Close() }() + + _, err := io.ReadAll(reader) + assertSanitizedError(t, err, secret, "invalid token") +} + +func TestDecodeChunkRejectsPerChunkAndTotalLimits(t *testing.T) { + tooLargeChunk := mustBase64(bytes.Repeat([]byte{'a'}, maxDecodedChunkBytes+1)) + if _, err := decodeChunk(tooLargeChunk, 0); err == nil || !strings.Contains(err.Error(), "stream chunk exceeds") { + t.Fatalf("expected chunk limit error, got %v", err) + } + + if _, err := decodeChunk(mustBase64([]byte("abc")), maxDecodedAudioBytes-2); err == nil || !strings.Contains(err.Error(), "stream audio exceeds") { + t.Fatalf("expected total limit error, got %v", err) + } +} + +func TestCanonicalOutputFormat(t *testing.T) { cases := map[string]string{ "mp3_44100_128": "mp3", "pcm_44100": "wav", @@ -225,8 +430,28 @@ func TestToSixtyDBFormat(t *testing.T) { "": "", } for in, want := range cases { - if got := toSixtyDBFormat(in); got != want { - t.Fatalf("toSixtyDBFormat(%q) = %q, want %q", in, got, want) + if got := CanonicalOutputFormat(in); got != want { + t.Fatalf("CanonicalOutputFormat(%q) = %q, want %q", in, got, want) } } } + +func mustBase64(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +func 
assertSanitizedError(t *testing.T, err error, secret, want string) { + t.Helper() + if err == nil { + t.Fatal("expected error") + } + if strings.Contains(err.Error(), secret) { + t.Fatalf("error leaked secret: %v", err) + } + if !strings.Contains(err.Error(), "[redacted]") { + t.Fatalf("expected redacted marker, got %v", err) + } + if !strings.Contains(err.Error(), want) { + t.Fatalf("expected %q in error, got %v", want, err) + } +} diff --git a/internal/tts/tts.go b/internal/tts/tts.go index a5e4cb4..c426570 100644 --- a/internal/tts/tts.go +++ b/internal/tts/tts.go @@ -1,17 +1,13 @@ -// Package tts defines provider-neutral types and the Provider interface shared -// by every text-to-speech backend (ElevenLabs, 60db, ...). -// -// Keeping the types here lets each provider implementation translate its own -// wire format to and from a single shared shape, so the command layer and the -// audio player never need to know which backend produced the audio. +// Package tts holds the small voice-catalog types shared by the existing +// `voices` and voice-resolution command paths. package tts import ( "context" - "io" ) -// Voice represents a single voice entry, normalized across providers. +// Voice represents a single voice entry normalized for the CLI's voice listing, +// filtering, query, and resolution paths. type Voice struct { VoiceID string `json:"voice_id"` Name string `json:"name"` @@ -21,38 +17,9 @@ type Voice struct { PreviewURL string `json:"preview_url"` } -// VoiceSettings tunes synthesis parameters for a request. All fields are -// pointers so unset knobs are omitted from the wire payload and the provider's -// own defaults apply. Stability/SimilarityBoost/Style use the 0..1 scale; each -// provider translates to its native range. 
-type VoiceSettings struct { - Stability *float64 `json:"stability,omitempty"` - SimilarityBoost *float64 `json:"similarity_boost,omitempty"` - Style *float64 `json:"style,omitempty"` - UseSpeakerBoost *bool `json:"use_speaker_boost,omitempty"` - Speed *float64 `json:"speed,omitempty"` -} - -// TTSRequest configures a text-to-speech request. Some fields are honored only -// by certain providers (e.g. ModelID/Seed/LanguageCode are ElevenLabs-specific -// and ignored by 60db); the provider implementation decides what to send. -type TTSRequest struct { - Text string `json:"text"` - ModelID string `json:"model_id,omitempty"` - VoiceSettings *VoiceSettings `json:"voice_settings,omitempty"` - OutputFormat string `json:"output_format,omitempty"` - Seed *uint32 `json:"seed,omitempty"` - ApplyTextNormalization string `json:"apply_text_normalization,omitempty"` - LanguageCode string `json:"language_code,omitempty"` -} - -// Provider is the contract every TTS backend implements. StreamTTS and -// ConvertTTS must return raw, ready-to-play audio bytes (decoded/unwrapped from -// any provider-specific envelope) so the audio layer stays provider-agnostic. -type Provider interface { +// VoiceCatalog is the minimal shared interface needed by existing commands. +type VoiceCatalog interface { ListVoices(ctx context.Context) ([]Voice, error) SearchVoices(ctx context.Context, search string, limit int) ([]Voice, error) GetVoice(ctx context.Context, voiceID string) (Voice, error) - StreamTTS(ctx context.Context, voiceID string, req TTSRequest, latency int) (io.ReadCloser, error) - ConvertTTS(ctx context.Context, voiceID string, req TTSRequest) ([]byte, error) }