diff --git a/CHANGELOG.md b/CHANGELOG.md index ee85293..92f377f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Changelog ## Unreleased +### Added +- Documented 60db provider support via `/default-voices`, `/myvoices`, `/tts-stream`, and `/tts-synthesize`, with strict provider selection and response validation. (#20, thanks @manishEMS47) ### Changed - Release archives now include target-specific macOS and Linux assets for Homebrew and aqua installers. diff --git a/README.md b/README.md index e51c879..b0b79b6 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# sag 🗣️ — “Mac-style speech with ElevenLabs” +# sag 🗣️ — “Mac-style speech with ElevenLabs or 60db” One-liner TTS that works like `say`: stream to speakers by default, list voices, or save audio files. @@ -24,9 +24,28 @@ sudo apt install build-essential pkg-config libasound2-dev ``` ## Configuration -- `ELEVENLABS_API_KEY` (required) -- `--api-key-file` or `ELEVENLABS_API_KEY_FILE`/`SAG_API_KEY_FILE` to load the key from a file -- Optional defaults: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID` + +`sag` supports two TTS providers and auto-selects one from your configured credentials: + +- **ElevenLabs** — `ELEVENLABS_API_KEY` (or `--api-key`, or `--api-key-file` / `ELEVENLABS_API_KEY_FILE` / `SAG_API_KEY_FILE`) +- **60db** (`api.60db.ai`) — `SIXTYDB_API_KEY` (or `SIXTYDB_API_KEY_FILE`) + +Selection rules: +- Only one key set → that provider is used. +- Both keys set → error; unset one provider key and retry. +- Neither set → error. + +Optional ElevenLabs defaults: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID`. Override the active provider host with `--base-url`. + +60db support is intentionally narrow and follows the documented HTTP API: +- voice discovery merges `GET /default-voices` and `GET /myvoices` +- default MP3 streaming uses `POST /tts-stream` +- explicit file formats and full downloads use `POST /tts-synthesize` +- the `success` envelope is validated even on HTTP 200 responses + +Shared speak flags that work on both providers include `--voice`, `--speed` / `--rate`, `--stability`, `--similarity`, `--format`, `--stream`, `--play`, `--output`, `--timeout`, and `--metrics`. + +ElevenLabs-only speak flags fail fast on 60db: `--model-id`, `--style`, `--speaker-boost` / `--no-speaker-boost`, `--seed`, `--normalize`, `--lang`, and `--latency-tier`. `--stability` / `--similarity` still use the CLI's `0..1` range and are scaled to 60db's documented `0..100` API values. See [docs/providers.md](docs/providers.md) for provider-specific details. ## Usage @@ -61,6 +80,7 @@ sag speak -v Roger --stream --latency-tier 3 "Faster start" sag speak -v Roger --speed 1.2 "Talk a bit faster" sag speak -v Roger --model-id eleven_multilingual_v2 "Use stable v2 baseline" sag speak -v Roger --output out.wav --format pcm_44100 "Wave output" +SIXTYDB_API_KEY=... sag speak -v Aria --output out.wav --no-play "60db WAV output" ``` Key flags (subset): @@ -149,6 +169,6 @@ ffprobe -v quiet -show_entries format=duration -of csv=p=0 long.mp3 - Build: `go build ./cmd/sag` ## Limitations -- ElevenLabs account and API key required. +- One provider API key is required. - Voice defaults to first available if not provided. - Non-mac platforms: playback still works via `go-mp3` + `oto`, but device selection flags are no-ops. diff --git a/cmd/api_key.go b/cmd/api_key.go index 8b22232..ff8e9c0 100644 --- a/cmd/api_key.go +++ b/cmd/api_key.go @@ -6,23 +6,39 @@ import ( "strings" ) -func ensureAPIKey() error { - if cfg.APIKey == "" { - key, err := resolveAPIKeyFromFile() - if err != nil { - return err - } - cfg.APIKey = key +// resolveElevenLabsKey resolves the ElevenLabs API key without erroring when +// absent (returns ""). Order: --api-key, key file, ELEVENLABS_API_KEY, +// SAG_API_KEY. +func resolveElevenLabsKey() (string, error) { + if cfg.APIKey != "" { + return cfg.APIKey, nil + } + key, err := resolveAPIKeyFromFile() + if err != nil { + return "", err + } + if key != "" { + return key, nil + } + if v := os.Getenv("ELEVENLABS_API_KEY"); v != "" { + return v, nil } - if cfg.APIKey == "" { - cfg.APIKey = os.Getenv("ELEVENLABS_API_KEY") + if v := os.Getenv("SAG_API_KEY"); v != "" { + return v, nil } - if cfg.APIKey == "" { - cfg.APIKey = os.Getenv("SAG_API_KEY") + return "", nil +} + +// ensureAPIKey resolves and stores the ElevenLabs API key, erroring if missing. +func ensureAPIKey() error { + key, err := resolveElevenLabsKey() + if err != nil { + return err } - if cfg.APIKey == "" { + if key == "" { return fmt.Errorf("missing ElevenLabs API key (set --api-key, --api-key-file, or ELEVENLABS_API_KEY)") } + cfg.APIKey = key return nil } diff --git a/cmd/prompting_test.go b/cmd/prompting_test.go index 756f091..5ac5012 100644 --- a/cmd/prompting_test.go +++ b/cmd/prompting_test.go @@ -6,6 +6,8 @@ import ( ) func TestPromptingCommandOutputsGuide(t *testing.T) { + resetRootCommandState() + restore, read := captureStdout(t) defer restore() diff --git a/cmd/provider.go b/cmd/provider.go new file mode 100644 index 0000000..4fcf500 --- /dev/null +++ b/cmd/provider.go @@ -0,0 +1,102 @@ +package cmd + +import ( + "fmt" + "os" + "strings" + + "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/sixtydb" + "github.com/steipete/sag/internal/tts" +) + +const ( + providerElevenLabs = "elevenlabs" + providerSixtyDB = "60db" +) + +type activeProvider struct { + name string + voices tts.VoiceCatalog + + elevenlabs *elevenlabs.Client + sixtydb *sixtydb.Client +} + +// resolveSixtyDBKey resolves the 60db API key from its dedicated env vars. +// Order: SIXTYDB_API_KEY, then SIXTYDB_API_KEY_FILE. +func resolveSixtyDBKey() (string, error) { + if key := strings.TrimSpace(os.Getenv("SIXTYDB_API_KEY")); key != "" { + return key, nil + } + if path := strings.TrimSpace(os.Getenv("SIXTYDB_API_KEY_FILE")); path != "" { + data, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("read 60db api key file: %w", err) + } + key := strings.TrimSpace(string(data)) + if key == "" { + return "", fmt.Errorf("60db api key file %q is empty", path) + } + return key, nil + } + return "", nil +} + +func ensureProviderConfigured() error { + _, err := selectProvider() + return err +} + +func selectProvider() (activeProvider, error) { + elKey, err := resolveElevenLabsKey() + if err != nil { + return activeProvider{}, err + } + sdKey, err := resolveSixtyDBKey() + if err != nil { + return activeProvider{}, err + } + + switch { + case elKey != "" && sdKey != "": + return activeProvider{}, fmt.Errorf("ambiguous provider configuration: both ElevenLabs and 60db keys are set; unset one provider key and retry") + case elKey != "": + client := elevenlabs.NewClient(elKey, cfg.BaseURL) + return activeProvider{ + name: providerElevenLabs, + voices: client, + elevenlabs: client, + }, nil + case sdKey != "": + client := sixtydb.NewClient(sdKey, cfg.BaseURL) + return activeProvider{ + name: providerSixtyDB, + voices: client, + sixtydb: client, + }, nil + default: + return activeProvider{}, fmt.Errorf("missing API key (set ELEVENLABS_API_KEY or SIXTYDB_API_KEY)") + } +} + +var sixtyDBUnsupportedFlags = []string{ + "model-id", + "style", + "speaker-boost", + "no-speaker-boost", + "seed", + "normalize", + "lang", + "latency-tier", +} + +func changedSixtyDBUnsupportedFlags(changed func(string) bool) []string { + var unsupported []string + for _, name := range sixtyDBUnsupportedFlags { + if changed(name) { + unsupported = append(unsupported, "--"+name) + } + } + return unsupported +} diff --git a/cmd/provider_test.go b/cmd/provider_test.go new file mode 100644 index 0000000..5dd504a --- /dev/null +++ b/cmd/provider_test.go @@ -0,0 +1,80 @@ +package cmd + +import ( + "strings" + "testing" +) + +// resetProviderEnv neutralizes every key source so each case starts clean. +// Setting an env var to "" makes the resolvers treat it as absent. +func resetProviderEnv(t *testing.T) { + t.Helper() + cfg.APIKey = "" + cfg.APIKeyFile = "" + cfg.BaseURL = "" + t.Cleanup(func() { cfg.APIKey = ""; cfg.APIKeyFile = ""; cfg.BaseURL = "" }) + for _, k := range []string{ + "ELEVENLABS_API_KEY", "SAG_API_KEY", + "ELEVENLABS_API_KEY_FILE", "SAG_API_KEY_FILE", + "SIXTYDB_API_KEY", "SIXTYDB_API_KEY_FILE", + } { + t.Setenv(k, "") + } +} + +func TestSelectProvider_ElevenLabsOnly(t *testing.T) { + resetProviderEnv(t) + t.Setenv("ELEVENLABS_API_KEY", "el-key") + + provider, err := selectProvider() + if err != nil { + t.Fatalf("selectProvider error: %v", err) + } + if provider.name != providerElevenLabs || provider.elevenlabs == nil || provider.voices == nil { + t.Fatalf("unexpected provider: %+v", provider) + } +} + +func TestSelectProvider_SixtyDBOnly(t *testing.T) { + resetProviderEnv(t) + t.Setenv("SIXTYDB_API_KEY", "sd-key") + + provider, err := selectProvider() + if err != nil { + t.Fatalf("selectProvider error: %v", err) + } + if provider.name != providerSixtyDB || provider.sixtydb == nil || provider.voices == nil { + t.Fatalf("unexpected provider: %+v", provider) + } +} + +func TestSelectProvider_BothKeysError(t *testing.T) { + resetProviderEnv(t) + t.Setenv("ELEVENLABS_API_KEY", "el-key") + t.Setenv("SIXTYDB_API_KEY", "sd-key") + + _, err := selectProvider() + if err == nil || !strings.Contains(err.Error(), "ambiguous provider configuration") { + t.Fatalf("expected ambiguity error, got %v", err) + } +} + +func TestSelectProvider_NeitherErrors(t *testing.T) { + resetProviderEnv(t) + + _, err := selectProvider() + if err == nil { + t.Fatal("expected error when no API key is set") + } +} + +func TestEnsureProviderConfigured(t *testing.T) { + resetProviderEnv(t) + if err := ensureProviderConfigured(); err == nil { + t.Fatal("expected error with no keys") + } + t.Setenv("SIXTYDB_API_KEY", "sd-key") + if err := ensureProviderConfigured(); err != nil { + t.Fatalf("expected 60db key to satisfy configuration, got %v", err) + } +} diff --git a/cmd/root.go b/cmd/root.go index 14a506c..7d1868c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -19,8 +19,8 @@ var ( versionFlag bool rootCmd = &cobra.Command{ Use: "sag", - Short: "🗣️ ElevenLabs speech, mac-style ease", - Long: "Command-line ElevenLabs TTS with macOS playback. Call it like macOS 'say': if you skip the subcommand, text args are passed to 'speak' (e.g. `sag \"Hello\"`).\n\nTip: run `sag prompting` for model-specific prompting tips.\nModels: `eleven_v3` (default), `eleven_multilingual_v2` (stable), `eleven_flash_v2_5` (fast/cheap), `eleven_turbo_v2_5` (balanced).", + Short: "🗣️ TTS speech, mac-style ease", + Long: "Command-line TTS with macOS-style playback and voice flags. Call it like macOS 'say': if you skip the subcommand, text args are passed to 'speak' (e.g. `sag \"Hello\"`).\n\nTip: run `sag prompting` for ElevenLabs prompting tips. Provider selection is automatic: configure exactly one of ElevenLabs or 60db.", Example: " sag \"Hi Peter\"\n echo 'piped input' | sag\n sag speak -v Roger --rate 200 \"Faster speech\"\n sag prompting", Version: "0.3.0", PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { @@ -45,7 +45,7 @@ func Execute() { func init() { rootCmd.PersistentFlags().StringVar(&cfg.APIKey, "api-key", "", "ElevenLabs API key (or ELEVENLABS_API_KEY)") rootCmd.PersistentFlags().StringVar(&cfg.APIKeyFile, "api-key-file", "", "Read ElevenLabs API key from file (or ELEVENLABS_API_KEY_FILE)") - rootCmd.PersistentFlags().StringVar(&cfg.BaseURL, "base-url", "https://api.elevenlabs.io", "Override ElevenLabs API base URL") + rootCmd.PersistentFlags().StringVar(&cfg.BaseURL, "base-url", "", "Override the provider API base URL (empty = provider default)") rootCmd.PersistentFlags().BoolVarP(&versionFlag, "version", "V", false, "Print version and exit") } diff --git a/cmd/speak.go b/cmd/speak.go index ec3925f..0a68b3f 100644 --- a/cmd/speak.go +++ b/cmd/speak.go @@ -13,6 +13,8 @@ import ( "github.com/steipete/sag/internal/audio" "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/sixtydb" + "github.com/steipete/sag/internal/tts" "github.com/spf13/cobra" ) @@ -61,7 +63,7 @@ func init() { Long: "If no text argument is provided, the command reads from stdin.\n\nTip: run `sag prompting` for model-specific prompting tips and recommended flag combinations.", Args: cobra.ArbitraryArgs, PreRunE: func(_ *cobra.Command, _ []string) error { - return ensureAPIKey() + return ensureProviderConfigured() }, RunE: func(cmd *cobra.Command, args []string) error { if err := applyRateAndSpeed(&opts); err != nil { @@ -74,9 +76,14 @@ func init() { return err } + provider, err := selectProvider() + if err != nil { + return err + } + forceVoiceID := cmd.Flags().Changed("voice-id") voiceInput := opts.voiceID - if voiceInput == "" { + if provider.name == providerElevenLabs && voiceInput == "" { if env := os.Getenv("ELEVENLABS_VOICE_ID"); env != "" { voiceInput = env forceVoiceID = true @@ -85,9 +92,8 @@ func init() { forceVoiceID = true } } - client := elevenlabs.NewClient(cfg.APIKey, cfg.BaseURL) - voiceID, err := resolveVoice(cmd.Context(), client, voiceInput, forceVoiceID) + voiceID, err := resolveVoice(cmd.Context(), provider.voices, voiceInput, forceVoiceID) if err != nil { return err } @@ -113,35 +119,72 @@ func init() { } } + if provider.name == providerSixtyDB { + if err := prepareSixtyDBOptions(cmd, &opts); err != nil { + return err + } + } + ctx, cancel, err := ttsContext(cmd.Context(), opts.timeout) if err != nil { return err } defer cancel() - payload, err := buildTTSRequest(cmd, opts, text) - if err != nil { - return err + var ( + streamFunc func(context.Context) (io.ReadCloser, error) + convertFunc func(context.Context) ([]byte, error) + ) + switch provider.name { + case providerElevenLabs: + payload, err := buildElevenLabsTTSRequest(cmd, opts, text) + if err != nil { + return err + } + streamFunc = func(ctx context.Context) (io.ReadCloser, error) { + return provider.elevenlabs.StreamTTS(ctx, opts.voiceID, payload, opts.latencyTier) + } + convertFunc = func(ctx context.Context) ([]byte, error) { + return provider.elevenlabs.ConvertTTS(ctx, opts.voiceID, payload) + } + case providerSixtyDB: + payload, err := buildSixtyDBTTSRequest(cmd, opts, text) + if err != nil { + return err + } + streamFunc = func(ctx context.Context) (io.ReadCloser, error) { + return provider.sixtydb.StreamTTS(ctx, payload) + } + convertFunc = func(ctx context.Context) ([]byte, error) { + return provider.sixtydb.ConvertTTS(ctx, payload) + } + default: + return fmt.Errorf("unsupported provider %q", provider.name) } start := time.Now() var bytes int64 if opts.stream { - n, err := streamAndPlay(ctx, client, opts, payload) + n, err := streamAndPlay(ctx, opts, streamFunc) bytes = n if err != nil { return err } } else { - n, err := convertAndPlay(ctx, client, opts, payload) + n, err := convertAndPlay(ctx, opts, convertFunc) bytes = n if err != nil { return err } } if opts.metrics { - fmt.Fprintf(os.Stderr, "metrics: chars=%d bytes=%d model=%s voice=%s stream=%t latencyTier=%d dur=%s\n", - len([]rune(text)), bytes, opts.modelID, opts.voiceID, opts.stream, opts.latencyTier, time.Since(start).Truncate(time.Millisecond)) + if provider.name == providerElevenLabs { + fmt.Fprintf(os.Stderr, "metrics: chars=%d bytes=%d provider=%s model=%s voice=%s stream=%t latencyTier=%d dur=%s\n", + len([]rune(text)), bytes, provider.name, opts.modelID, opts.voiceID, opts.stream, opts.latencyTier, time.Since(start).Truncate(time.Millisecond)) + } else { + fmt.Fprintf(os.Stderr, "metrics: chars=%d bytes=%d provider=%s voice=%s stream=%t dur=%s\n", + len([]rune(text)), bytes, provider.name, opts.voiceID, opts.stream, time.Since(start).Truncate(time.Millisecond)) + } } return nil }, @@ -272,7 +315,7 @@ func applyRateAndSpeed(opts *speakOptions) error { return nil } -func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (elevenlabs.TTSRequest, error) { +func buildElevenLabsTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (elevenlabs.TTSRequest, error) { flags := cmd.Flags() var stabilityPtr *float64 @@ -280,10 +323,9 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (eleven if opts.stability < 0 || opts.stability > 1 { return elevenlabs.TTSRequest{}, errors.New("stability must be between 0 and 1") } - if opts.modelID == "eleven_v3" { - if !floatEqualsOneOf(opts.stability, []float64{0, 0.5, 1}) { - return elevenlabs.TTSRequest{}, errors.New("for eleven_v3, stability must be one of 0.0, 0.5, 1.0 (Creative/Natural/Robust)") - } + // The discrete 0/0.5/1 constraint is specific to ElevenLabs eleven_v3. + if opts.modelID == "eleven_v3" && !floatEqualsOneOf(opts.stability, []float64{0, 0.5, 1}) { + return elevenlabs.TTSRequest{}, errors.New("for eleven_v3, stability must be one of 0.0, 0.5, 1.0 (Creative/Natural/Robust)") } stabilityPtr = &opts.stability } @@ -368,6 +410,63 @@ func buildTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (eleven }, nil } +func prepareSixtyDBOptions(cmd *cobra.Command, opts *speakOptions) error { + unsupported := changedSixtyDBUnsupportedFlags(cmd.Flags().Changed) + if len(unsupported) > 0 { + return fmt.Errorf("60db does not support %s", strings.Join(unsupported, ", ")) + } + + if !cmd.Flags().Changed("format") && strings.EqualFold(filepath.Ext(opts.outputPath), ".flac") { + opts.outputFmt = "flac" + } + + format := sixtydb.CanonicalOutputFormat(opts.outputFmt) + if format != "" { + opts.outputFmt = format + } + + if opts.play && format != "" && format != "mp3" { + return fmt.Errorf("60db speaker playback requires mp3 audio; use --no-play to save %s output", format) + } + + if opts.stream && format != "" && format != "mp3" { + if cmd.Flags().Changed("stream") { + return fmt.Errorf("60db streaming does not support %s output; use --no-stream", format) + } + opts.stream = false + } + return nil +} + +func buildSixtyDBTTSRequest(cmd *cobra.Command, opts speakOptions, text string) (sixtydb.TTSRequest, error) { + flags := cmd.Flags() + speed := opts.speed + req := sixtydb.TTSRequest{ + Text: text, + VoiceID: opts.voiceID, + Speed: &speed, + } + + if flags.Changed("stability") { + if opts.stability < 0 || opts.stability > 1 { + return sixtydb.TTSRequest{}, errors.New("stability must be between 0 and 1") + } + stability := opts.stability * 100 + req.Stability = &stability + } + if flags.Changed("similarity") || flags.Changed("similarity-boost") { + if opts.similarity < 0 || opts.similarity > 1 { + return sixtydb.TTSRequest{}, errors.New("similarity must be between 0 and 1") + } + similarity := opts.similarity * 100 + req.Similarity = &similarity + } + if !opts.stream { + req.OutputFormat = opts.outputFmt + } + return req, nil +} + func floatEqualsOneOf(v float64, allowed []float64) bool { const eps = 1e-9 for _, a := range allowed { @@ -427,8 +526,12 @@ func isStdinTTY() bool { return (stat.Mode() & os.ModeCharDevice) != 0 } -func streamAndPlay(ctx context.Context, client *elevenlabs.Client, opts speakOptions, payload elevenlabs.TTSRequest) (int64, error) { - resp, err := client.StreamTTS(ctx, opts.voiceID, payload, opts.latencyTier) +func streamAndPlay(ctx context.Context, opts speakOptions, stream func(context.Context) (io.ReadCloser, error)) (int64, error) { + if !opts.play && opts.outputPath == "" { + return 0, errors.New("nothing to do: enable --play or provide --output") + } + + resp, err := stream(ctx) if err != nil { return 0, err } @@ -479,17 +582,17 @@ func streamAndPlay(ctx context.Context, client *elevenlabs.Client, opts speakOpt return copyNVal, playErr } - if len(writers) == 0 { - return 0, errors.New("nothing to do: enable --play or provide --output") - } - mw := io.MultiWriter(writers...) n, err := io.Copy(mw, resp) return n, err } -func convertAndPlay(ctx context.Context, client *elevenlabs.Client, opts speakOptions, payload elevenlabs.TTSRequest) (int64, error) { - data, err := client.ConvertTTS(ctx, opts.voiceID, payload) +func convertAndPlay(ctx context.Context, opts speakOptions, convert func(context.Context) ([]byte, error)) (int64, error) { + if !opts.play && opts.outputPath == "" { + return 0, errors.New("nothing to do: enable --play or provide --output") + } + + data, err := convert(ctx) if err != nil { return 0, err } @@ -516,13 +619,10 @@ func convertAndPlay(ctx context.Context, client *elevenlabs.Client, opts speakOp }() return n, player(ctx, pr) } - if opts.outputPath == "" { - return n, errors.New("nothing to do: enable --play or provide --output") - } return n, nil } -func resolveVoice(ctx context.Context, client *elevenlabs.Client, voiceInput string, forceID bool) (string, error) { +func resolveVoice(ctx context.Context, client tts.VoiceCatalog, voiceInput string, forceID bool) (string, error) { voiceInput = strings.TrimSpace(voiceInput) if voiceInput == "" { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) @@ -532,7 +632,7 @@ func resolveVoice(ctx context.Context, client *elevenlabs.Client, voiceInput str return "", fmt.Errorf("voice not specified and failed to fetch voices: %w", err) } if len(voices) == 0 { - return "", errors.New("no voices available; specify --voice or set ELEVENLABS_VOICE_ID") + return "", errors.New("no voices available; specify --voice") } fmt.Fprintf(os.Stderr, "defaulting to voice %s (%s)\n", voices[0].Name, voices[0].VoiceID) return voices[0].VoiceID, nil diff --git a/cmd/speak_integration_test.go b/cmd/speak_integration_test.go index 1684a45..328657d 100644 --- a/cmd/speak_integration_test.go +++ b/cmd/speak_integration_test.go @@ -1,6 +1,7 @@ package cmd import ( + "encoding/base64" "encoding/json" "net/http" "net/http/httptest" @@ -12,6 +13,8 @@ import ( func TestSpeakCommand_FlagsBuildRequestAndMetrics(t *testing.T) { t.Helper() + resetProviderEnv(t) + resetRootCommandState() const voiceID = "abc1234567890123" @@ -106,4 +109,64 @@ func TestSpeakCommand_FlagsBuildRequestAndMetrics(t *testing.T) { if !strings.Contains(stderr, "metrics: chars=") || !strings.Contains(stderr, "bytes=") || !strings.Contains(stderr, "dur=") { t.Fatalf("expected metrics output, got %q", stderr) } + if !strings.Contains(stderr, "provider=elevenlabs") || !strings.Contains(stderr, "model=eleven_v3") || !strings.Contains(stderr, "latencyTier=0") { + t.Fatalf("expected provider-specific metrics output, got %q", stderr) + } +} + +func TestSpeakCommand_SixtyDBMetricsOmitElevenLabsModel(t *testing.T) { + t.Helper() + resetProviderEnv(t) + resetRootCommandState() + t.Setenv("SIXTYDB_API_KEY", "sd-key") + + const voiceID = "voice-001" + audio := base64.StdEncoding.EncodeToString([]byte("ID3audio-bytes")) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/tts-synthesize" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + var got map[string]any + if err := json.NewDecoder(r.Body).Decode(&got); err != nil { + t.Fatalf("decode body: %v", err) + } + if got["voice_id"] != voiceID || got["output_format"] != "mp3" { + t.Fatalf("unexpected request body: %+v", got) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "success": true, + "audio_base64": audio, + "output_format": "mp3", + "encoding": "mp3", + }) + })) + defer srv.Close() + + tmp := t.TempDir() + outPath := tmp + "/out.mp3" + restore, read := captureStderr(t) + defer restore() + + rootCmd.SetArgs([]string{ + "--base-url", srv.URL, + "speak", + "--voice-id", voiceID, + "--stream=false", + "--play=false", + "--output", outPath, + "--metrics", + "Hello world", + }) + + if err := rootCmd.Execute(); err != nil { + t.Fatalf("speak command failed: %v", err) + } + + stderr := read() + if !strings.Contains(stderr, "provider=60db") { + t.Fatalf("expected 60db provider in metrics, got %q", stderr) + } + if strings.Contains(stderr, "model=eleven_v3") || strings.Contains(stderr, "latencyTier=") { + t.Fatalf("expected 60db metrics to omit ElevenLabs metadata, got %q", stderr) + } } diff --git a/cmd/speak_request_test.go b/cmd/speak_request_test.go index 4293666..802ce1f 100644 --- a/cmd/speak_request_test.go +++ b/cmd/speak_request_test.go @@ -16,8 +16,13 @@ func newSpeakTestCommand(t *testing.T) (*cobra.Command, *speakOptions) { modelID: "eleven_multilingual_v2", outputFmt: "mp3_44100_128", speed: 1.0, + stream: true, + play: true, } cmd := &cobra.Command{Use: "speak"} + cmd.Flags().StringVar(&opts.modelID, "model-id", opts.modelID, "") + cmd.Flags().StringVar(&opts.outputFmt, "format", opts.outputFmt, "") + cmd.Flags().BoolVar(&opts.stream, "stream", opts.stream, "") cmd.Flags().Float64Var(&opts.stability, "stability", 0, "") cmd.Flags().Float64Var(&opts.similarity, "similarity", 0, "") cmd.Flags().Float64Var(&opts.similarity, "similarity-boost", 0, "") @@ -27,15 +32,16 @@ func newSpeakTestCommand(t *testing.T) (*cobra.Command, *speakOptions) { cmd.Flags().Uint64Var(&opts.seed, "seed", 0, "") cmd.Flags().StringVar(&opts.normalize, "normalize", "", "") cmd.Flags().StringVar(&opts.lang, "lang", "", "") + cmd.Flags().IntVar(&opts.latencyTier, "latency-tier", 0, "") return cmd, opts } -func TestBuildTTSRequest_DefaultsOmitOptionalFields(t *testing.T) { +func TestBuildElevenLabsTTSRequest_DefaultsOmitOptionalFields(t *testing.T) { cmd, opts := newSpeakTestCommand(t) - req, err := buildTTSRequest(cmd, *opts, "hello") + req, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err != nil { - t.Fatalf("buildTTSRequest error: %v", err) + t.Fatalf("buildElevenLabsTTSRequest error: %v", err) } if req.Seed != nil { @@ -64,30 +70,30 @@ func TestBuildTTSRequest_DefaultsOmitOptionalFields(t *testing.T) { } } -func TestBuildTTSRequest_SimilarityBoostAlias(t *testing.T) { +func TestBuildElevenLabsTTSRequest_SimilarityBoostAlias(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--similarity-boost", "0.9"}); err != nil { t.Fatalf("parse flags: %v", err) } - req, err := buildTTSRequest(cmd, *opts, "hello") + req, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err != nil { - t.Fatalf("buildTTSRequest error: %v", err) + t.Fatalf("buildElevenLabsTTSRequest error: %v", err) } if req.VoiceSettings.SimilarityBoost == nil || *req.VoiceSettings.SimilarityBoost != 0.9 { t.Fatalf("expected similarity_boost 0.9, got %#v", req.VoiceSettings.SimilarityBoost) } } -func TestBuildTTSRequest_SpeakerBoostSetsJSONKey(t *testing.T) { +func TestBuildElevenLabsTTSRequest_SpeakerBoostSetsJSONKey(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--speaker-boost"}); err != nil { t.Fatalf("parse flags: %v", err) } - req, err := buildTTSRequest(cmd, *opts, "hello") + req, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err != nil { - t.Fatalf("buildTTSRequest error: %v", err) + t.Fatalf("buildElevenLabsTTSRequest error: %v", err) } if req.VoiceSettings.UseSpeakerBoost == nil || *req.VoiceSettings.UseSpeakerBoost != true { t.Fatalf("expected use_speaker_boost true, got %#v", req.VoiceSettings.UseSpeakerBoost) @@ -102,62 +108,174 @@ func TestBuildTTSRequest_SpeakerBoostSetsJSONKey(t *testing.T) { } } -func TestBuildTTSRequest_InvalidNormalize(t *testing.T) { +func TestBuildElevenLabsTTSRequest_InvalidNormalize(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--normalize", "wat"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "normalize must be one of") { t.Fatalf("expected normalize error, got %v", err) } } -func TestBuildTTSRequest_InvalidLang(t *testing.T) { +func TestBuildElevenLabsTTSRequest_InvalidLang(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--lang", "eng"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "lang must be a 2-letter") { t.Fatalf("expected lang error, got %v", err) } } -func TestBuildTTSRequest_InvalidSeed(t *testing.T) { +func TestBuildElevenLabsTTSRequest_InvalidSeed(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--seed", "4294967296"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "seed must be between") { t.Fatalf("expected seed error, got %v", err) } } -func TestBuildTTSRequest_SpeakerBoostConflict(t *testing.T) { +func TestBuildElevenLabsTTSRequest_SpeakerBoostConflict(t *testing.T) { cmd, opts := newSpeakTestCommand(t) if err := cmd.Flags().Parse([]string{"--speaker-boost", "--no-speaker-boost"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "choose only one") { t.Fatalf("expected conflict error, got %v", err) } } -func TestBuildTTSRequest_V3StabilityPresetsOnly(t *testing.T) { +func TestBuildElevenLabsTTSRequest_V3StabilityPresetsOnly(t *testing.T) { cmd, opts := newSpeakTestCommand(t) opts.modelID = "eleven_v3" if err := cmd.Flags().Parse([]string{"--stability", "0.55"}); err != nil { t.Fatalf("parse flags: %v", err) } - _, err := buildTTSRequest(cmd, *opts, "hello") + _, err := buildElevenLabsTTSRequest(cmd, *opts, "hello") if err == nil || !strings.Contains(err.Error(), "for eleven_v3, stability must be one of") { t.Fatalf("expected v3 stability preset error, got %v", err) } } +func TestPrepareSixtyDBOptionsRejectsUnsupportedFlags(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + if err := cmd.Flags().Parse([]string{"--model-id", "foo", "--style", "0.3", "--latency-tier", "2"}); err != nil { + t.Fatalf("parse flags: %v", err) + } + + err := prepareSixtyDBOptions(cmd, opts) + if err == nil { + t.Fatal("expected unsupported flag error") + } + for _, flag := range []string{"--model-id", "--style", "--latency-tier"} { + if !strings.Contains(err.Error(), flag) { + t.Fatalf("expected %s in error, got %v", flag, err) + } + } +} + +func TestPrepareSixtyDBOptionsFallsBackToConvertForNonMP3(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + opts.outputFmt = "wav" + opts.play = false + + if err := prepareSixtyDBOptions(cmd, opts); err != nil { + t.Fatalf("prepareSixtyDBOptions error: %v", err) + } + if opts.outputFmt != "wav" { + t.Fatalf("expected canonical wav format, got %q", opts.outputFmt) + } + if opts.stream { + t.Fatalf("expected stream to be disabled for non-mp3 output") + } +} + +func TestPrepareSixtyDBOptionsInfersFLACFromOutputPath(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + opts.outputPath = "voice.FLAC" + opts.play = false + + if err := prepareSixtyDBOptions(cmd, opts); err != nil { + t.Fatalf("prepareSixtyDBOptions error: %v", err) + } + if opts.outputFmt != "flac" { + t.Fatalf("expected flac output format, got %q", opts.outputFmt) + } + if opts.stream { + t.Fatalf("expected stream to be disabled for flac output") + } +} + +func TestPrepareSixtyDBOptionsRejectsExplicitStreamForNonMP3(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + if err := cmd.Flags().Parse([]string{"--stream", "--format", "wav"}); err != nil { + t.Fatalf("parse flags: %v", err) + } + opts.play = false + + err := prepareSixtyDBOptions(cmd, opts) + if err == nil || !strings.Contains(err.Error(), "use --no-stream") { + t.Fatalf("expected explicit stream rejection, got %v", err) + } +} + +func TestPrepareSixtyDBOptionsRejectsPlaybackForNonMP3(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + opts.outputFmt = "flac" + + err := prepareSixtyDBOptions(cmd, opts) + if err == nil || !strings.Contains(err.Error(), "requires mp3 audio") { + t.Fatalf("expected playback format error, got %v", err) + } +} + +func TestBuildSixtyDBTTSRequestScalesDocumentedFields(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + opts.voiceID = "voice_123" + opts.stream = false + if err := cmd.Flags().Parse([]string{"--stability", "0.5", "--similarity-boost", "0.8"}); err != nil { + t.Fatalf("parse flags: %v", err) + } + + req, err := buildSixtyDBTTSRequest(cmd, *opts, "hello") + if err != nil { + t.Fatalf("buildSixtyDBTTSRequest error: %v", err) + } + if req.Text != "hello" || req.VoiceID != "voice_123" { + t.Fatalf("unexpected request identity: %+v", req) + } + if req.Speed == nil || *req.Speed != 1.0 { + t.Fatalf("expected speed 1.0, got %#v", req.Speed) + } + if req.Stability == nil || *req.Stability != 50 { + t.Fatalf("expected stability 50, got %#v", req.Stability) + } + if req.Similarity == nil || *req.Similarity != 80 { + t.Fatalf("expected similarity 80, got %#v", req.Similarity) + } + if req.OutputFormat != "mp3_44100_128" { + t.Fatalf("expected raw output format to be passed to client canonicalizer, got %q", req.OutputFormat) + } +} + +func TestBuildSixtyDBTTSRequestOmitsOutputFormatWhenStreaming(t *testing.T) { + cmd, opts := newSpeakTestCommand(t) + req, err := buildSixtyDBTTSRequest(cmd, *opts, "hello") + if err != nil { + t.Fatalf("buildSixtyDBTTSRequest error: %v", err) + } + if req.OutputFormat != "" { + t.Fatalf("expected stream request to omit output format, got %q", req.OutputFormat) + } +} + func TestApplyCompatibilityFlagsNoPlayNoStream(t *testing.T) { opts := &speakOptions{play: true, stream: true} cmd := &cobra.Command{Use: "speak"} @@ -258,22 +376,3 @@ func TestTTSContextNoDeadlineByDefault(t *testing.T) { t.Fatalf("expected no deadline for zero timeout") } } - -func TestTTSContextWithTimeout(t *testing.T) { - ctx, cancel, err := ttsContext(context.Background(), time.Minute) - if err != nil { - t.Fatalf("ttsContext error: %v", err) - } - defer cancel() - - if _, ok := ctx.Deadline(); !ok { - t.Fatalf("expected deadline for non-zero timeout") - } -} - -func TestTTSContextRejectsNegativeTimeout(t *testing.T) { - _, _, err := ttsContext(context.Background(), -time.Second) - if err == nil || !strings.Contains(err.Error(), "timeout must be") { - t.Fatalf("expected timeout error, got %v", err) - } -} diff --git a/cmd/speak_test.go b/cmd/speak_test.go index 1a494b0..e68da27 100644 --- a/cmd/speak_test.go +++ b/cmd/speak_test.go @@ -316,9 +316,10 @@ func TestStreamAndPlayWritesOutput(t *testing.T) { tmp := t.TempDir() out := tmp + "/out.mp3" opts := speakOptions{voiceID: "v1", outputPath: out, stream: true, play: false} - payload := elevenlabs.TTSRequest{Text: "hi"} - if _, err := streamAndPlay(context.Background(), client, opts, payload); err != nil { + if _, err := streamAndPlay(context.Background(), opts, func(ctx context.Context) (io.ReadCloser, error) { + return client.StreamTTS(ctx, opts.voiceID, elevenlabs.TTSRequest{Text: "hi"}, 0) + }); err != nil { t.Fatalf("streamAndPlay error: %v", err) } data, err := os.ReadFile(out) @@ -343,9 +344,10 @@ func TestConvertAndPlayWritesOutput(t *testing.T) { tmp := t.TempDir() out := tmp + "/out.mp3" opts := speakOptions{voiceID: "v1", outputPath: out, play: false} - payload := elevenlabs.TTSRequest{Text: "hi"} - if _, err := convertAndPlay(context.Background(), client, opts, payload); err != nil { + if _, err := convertAndPlay(context.Background(), opts, func(ctx context.Context) ([]byte, error) { + return client.ConvertTTS(ctx, opts.voiceID, elevenlabs.TTSRequest{Text: "hi"}) + }); err != nil { t.Fatalf("convertAndPlay error: %v", err) } data, err := os.ReadFile(out) @@ -358,11 +360,11 @@ func TestConvertAndPlayWritesOutput(t *testing.T) { } func TestStreamAndPlayRequiresWork(t *testing.T) { - client := elevenlabs.NewClient("key", "http://invalid") opts := speakOptions{voiceID: "v1", play: false, stream: true} - payload := elevenlabs.TTSRequest{Text: "hi"} - - _, err := streamAndPlay(context.Background(), client, opts, payload) + _, err := streamAndPlay(context.Background(), opts, func(context.Context) (io.ReadCloser, error) { + t.Fatal("stream should not be invoked when nothing will consume it") + return nil, nil + }) if err == nil { t.Fatalf("expected error when no output and play disabled") } @@ -385,9 +387,10 @@ func TestStreamAndPlayWithPlayback(t *testing.T) { client := elevenlabs.NewClient("key", srv.URL) opts := speakOptions{voiceID: "v1", play: true, stream: true} - payload := elevenlabs.TTSRequest{Text: "hi"} - if _, err := streamAndPlay(context.Background(), client, opts, payload); err != nil { + if _, err := streamAndPlay(context.Background(), opts, func(ctx context.Context) (io.ReadCloser, error) { + return client.StreamTTS(ctx, opts.voiceID, elevenlabs.TTSRequest{Text: "hi"}, 0) + }); err != nil { t.Fatalf("streamAndPlay error: %v", err) } if !called { @@ -412,9 +415,10 @@ func TestConvertAndPlayWithPlayback(t *testing.T) { client := elevenlabs.NewClient("key", srv.URL) opts := speakOptions{voiceID: "v1", play: true, outputPath: "", stream: false} - payload := elevenlabs.TTSRequest{Text: "hi"} - if _, err := convertAndPlay(context.Background(), client, opts, payload); err != nil { + if _, err := convertAndPlay(context.Background(), opts, func(ctx context.Context) ([]byte, error) { + return client.ConvertTTS(ctx, opts.voiceID, elevenlabs.TTSRequest{Text: "hi"}) + }); err != nil { t.Fatalf("convertAndPlay error: %v", err) } if !called { diff --git a/cmd/test_helpers_test.go b/cmd/test_helpers_test.go new file mode 100644 index 0000000..5123dfa --- /dev/null +++ b/cmd/test_helpers_test.go @@ -0,0 +1,33 @@ +package cmd + +import ( + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +func resetRootCommandState() { + resetFlags := func(flags *pflag.FlagSet) { + flags.VisitAll(func(flag *pflag.Flag) { + if replacer, ok := flag.Value.(interface{ Replace([]string) error }); ok { + _ = replacer.Replace([]string{}) + } else { + _ = flags.Set(flag.Name, flag.DefValue) + } + flag.Changed = false + }) + } + + var reset func(*cobra.Command) + reset = func(cmd *cobra.Command) { + cmd.SetArgs(nil) + resetFlags(cmd.Flags()) + resetFlags(cmd.PersistentFlags()) + for _, sub := range cmd.Commands() { + reset(sub) + } + } + + cfg = rootConfig{} + versionFlag = false + reset(rootCmd) +} diff --git a/cmd/voices.go b/cmd/voices.go index e683383..d4f98fd 100644 --- a/cmd/voices.go +++ b/cmd/voices.go @@ -12,7 +12,7 @@ import ( "time" "github.com/steipete/sag/internal/audio" - "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/tts" "github.com/spf13/cobra" ) @@ -39,7 +39,7 @@ func init() { Use: "voices", Short: "List available ElevenLabs voices", PreRunE: func(_ *cobra.Command, _ []string) error { - return ensureAPIKey() + return ensureProviderConfigured() }, RunE: func(cmd *cobra.Command, _ []string) error { hasLabelFilters := false @@ -53,12 +53,15 @@ func init() { return errors.New("--try requires --search, --query, --label, or --limit to avoid playing all voices") } - client := elevenlabs.NewClient(cfg.APIKey, cfg.BaseURL) + provider, err := selectProvider() + if err != nil { + return err + } + client := provider.voices ctx, cancel := context.WithTimeout(cmd.Context(), 30*time.Second) defer cancel() - var voices []elevenlabs.Voice - var err error + var voices []tts.Voice if opts.search != "" { voices, err = client.SearchVoices(ctx, opts.search, opts.limit) if err != nil { @@ -169,9 +172,9 @@ func init() { rootCmd.AddCommand(cmd) } -func filterVoicesByName(voices []elevenlabs.Voice, search string) []elevenlabs.Voice { +func filterVoicesByName(voices []tts.Voice, search string) []tts.Voice { searchLower := strings.ToLower(search) - filtered := make([]elevenlabs.Voice, 0, len(voices)) + filtered := make([]tts.Voice, 0, len(voices)) for _, v := range voices { if strings.Contains(strings.ToLower(v.Name), searchLower) { filtered = append(filtered, v) @@ -180,7 +183,7 @@ func filterVoicesByName(voices []elevenlabs.Voice, search string) []elevenlabs.V return filtered } -func playVoicePreviewImpl(ctx context.Context, client *elevenlabs.Client, voice elevenlabs.Voice) error { +func playVoicePreviewImpl(ctx context.Context, client tts.VoiceCatalog, voice tts.Voice) error { ctx, cancel := context.WithTimeout(ctx, 45*time.Second) defer cancel() diff --git a/cmd/voices_cache.go b/cmd/voices_cache.go index 46abd91..738d2c2 100644 --- a/cmd/voices_cache.go +++ b/cmd/voices_cache.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/tts" ) const ( @@ -25,8 +25,8 @@ type voiceCache struct { } type cachedVoice struct { - Voice elevenlabs.Voice `json:"voice"` - UpdatedAt time.Time `json:"updated_at"` + Voice tts.Voice `json:"voice"` + UpdatedAt time.Time `json:"updated_at"` } func newVoiceCache() *voiceCache { @@ -80,7 +80,7 @@ func saveVoiceCache(path string, cache *voiceCache) error { return os.WriteFile(path, data, 0o644) } -func hydrateVoices(ctx context.Context, client *elevenlabs.Client, voices []elevenlabs.Voice, cache *voiceCache, ttl time.Duration) ([]elevenlabs.Voice, int) { +func hydrateVoices(ctx context.Context, client tts.VoiceCatalog, voices []tts.Voice, cache *voiceCache, ttl time.Duration) ([]tts.Voice, int) { if ttl <= 0 { ttl = voiceCacheTTL } @@ -88,7 +88,7 @@ func hydrateVoices(ctx context.Context, client *elevenlabs.Client, voices []elev cache = newVoiceCache() } - results := make([]elevenlabs.Voice, len(voices)) + results := make([]tts.Voice, len(voices)) now := time.Now() var metaCount int @@ -109,7 +109,7 @@ func hydrateVoices(ctx context.Context, client *elevenlabs.Client, voices []elev wg.Add(1) sem <- struct{}{} - go func(index int, voice elevenlabs.Voice) { + go func(index int, voice tts.Voice) { defer wg.Done() defer func() { <-sem }() @@ -135,7 +135,7 @@ func hydrateVoices(ctx context.Context, client *elevenlabs.Client, voices []elev return results, metaCount } -func mergeVoice(base elevenlabs.Voice, details elevenlabs.Voice) elevenlabs.Voice { +func mergeVoice(base tts.Voice, details tts.Voice) tts.Voice { merged := base if details.VoiceID != "" { merged.VoiceID = details.VoiceID @@ -150,7 +150,12 @@ func mergeVoice(base elevenlabs.Voice, details elevenlabs.Voice) elevenlabs.Voic merged.Description = details.Description } if len(details.Labels) > 0 { - merged.Labels = details.Labels + if len(merged.Labels) == 0 { + merged.Labels = map[string]string{} + } + for key, value := range details.Labels { + merged.Labels[key] = value + } } if details.PreviewURL != "" { merged.PreviewURL = details.PreviewURL diff --git a/cmd/voices_cache_test.go b/cmd/voices_cache_test.go new file mode 100644 index 0000000..070e2a9 --- /dev/null +++ b/cmd/voices_cache_test.go @@ -0,0 +1,36 @@ +package cmd + +import ( + "testing" + + "github.com/steipete/sag/internal/tts" +) + +func TestMergeVoicePreservesAndOverlaysLabels(t *testing.T) { + base := tts.Voice{ + VoiceID: "v1", + Labels: map[string]string{ + "source": "default", + "model": "60db Quality", + }, + } + details := tts.Voice{ + VoiceID: "v1", + Labels: map[string]string{ + "accent": "American", + "model": "override", + }, + PreviewURL: "https://cdn.example.com/sample.mp3", + } + + merged := mergeVoice(base, details) + if merged.Labels["source"] != "default" { + t.Fatalf("expected source label preserved, got %+v", merged.Labels) + } + if merged.Labels["model"] != "override" || merged.Labels["accent"] != "American" { + t.Fatalf("expected detail labels merged, got %+v", merged.Labels) + } + if merged.PreviewURL != details.PreviewURL { + t.Fatalf("expected preview URL copied, got %q", merged.PreviewURL) + } +} diff --git a/cmd/voices_query.go b/cmd/voices_query.go index 7ec576b..1794067 100644 --- a/cmd/voices_query.go +++ b/cmd/voices_query.go @@ -6,7 +6,7 @@ import ( "strings" "unicode" - "github.com/steipete/sag/internal/elevenlabs" + "github.com/steipete/sag/internal/tts" ) type labelFilter struct { @@ -38,11 +38,11 @@ func parseLabelFilters(filters []string) ([]labelFilter, error) { return parsed, nil } -func filterVoicesByLabels(voices []elevenlabs.Voice, filters []labelFilter) []elevenlabs.Voice { +func filterVoicesByLabels(voices []tts.Voice, filters []labelFilter) []tts.Voice { if len(filters) == 0 { return voices } - filtered := make([]elevenlabs.Voice, 0, len(voices)) + filtered := make([]tts.Voice, 0, len(voices)) for _, v := range voices { if matchesAllLabels(v, filters) { filtered = append(filtered, v) @@ -51,7 +51,7 @@ func filterVoicesByLabels(voices []elevenlabs.Voice, filters []labelFilter) []el return filtered } -func matchesAllLabels(voice elevenlabs.Voice, filters []labelFilter) bool { +func matchesAllLabels(voice tts.Voice, filters []labelFilter) bool { if len(filters) == 0 { return true } @@ -85,7 +85,7 @@ func labelValue(labels map[string]string, key string) (string, bool) { return "", false } -func rankVoicesByQuery(voices []elevenlabs.Voice, query string) []elevenlabs.Voice { +func rankVoicesByQuery(voices []tts.Voice, query string) []tts.Voice { query = strings.TrimSpace(query) if query == "" { return voices @@ -104,7 +104,7 @@ func rankVoicesByQuery(voices []elevenlabs.Voice, query string) []elevenlabs.Voi } return scored[i].score > scored[j].score }) - ranked := make([]elevenlabs.Voice, 0, len(scored)) + ranked := make([]tts.Voice, 0, len(scored)) for _, s := range scored { ranked = append(ranked, s.voice) } @@ -112,11 +112,11 @@ func rankVoicesByQuery(voices []elevenlabs.Voice, query string) []elevenlabs.Voi } type scoredVoice struct { - voice elevenlabs.Voice + voice tts.Voice score int } -func scoreVoice(voice elevenlabs.Voice, query string, tokens []string) int { +func scoreVoice(voice tts.Voice, query string, tokens []string) int { name := strings.ToLower(voice.Name) desc := strings.ToLower(voice.Description) labels := strings.ToLower(flattenLabels(voice.Labels)) diff --git a/cmd/voices_test.go b/cmd/voices_test.go index ead0001..0851158 100644 --- a/cmd/voices_test.go +++ b/cmd/voices_test.go @@ -13,6 +13,8 @@ import ( ) func TestVoicesCommand(t *testing.T) { + resetRootCommandState() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/voices" { t.Fatalf("unexpected path: %s", r.URL.Path) @@ -63,6 +65,8 @@ func TestFilterVoicesByName(t *testing.T) { } func TestVoicesCommandTryRequiresFilter(t *testing.T) { + resetRootCommandState() + cfg.APIKey = "key" cfg.BaseURL = "http://example.invalid" t.Cleanup(func() { diff --git a/docs/configuration.md b/docs/configuration.md index 0d088a8..2f09c58 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,24 +1,28 @@ --- title: Configuration -description: "API keys, default voices, timeouts, base URL, and player selection — everything sag reads from the environment." +description: "Provider keys, default voices, timeouts, base URL, and player selection — everything sag reads from flags or the environment." --- # Configuration `sag` reads configuration from CLI flags first, then environment variables. There is no config file: the binary stays single-purpose and friendly to ephemeral CI runners. -## API key +## Provider key Required for any TTS or voice call. `sag --help`, `sag prompting`, and `sag --version` work without one. -| Flag / variable | Notes | +| Provider | Flag / variable | Notes | | --- | --- | -| `--api-key` | Inline override. Avoid in shell history; prefer env or `--api-key-file`. | -| `ELEVENLABS_API_KEY` | Primary env var. | -| `SAG_API_KEY` | Accepted alias. | -| `--api-key-file ` | Read the key from a file. | -| `ELEVENLABS_API_KEY_FILE` | Same as `--api-key-file`. | -| `SAG_API_KEY_FILE` | Alias. | +| ElevenLabs | `--api-key` | Inline override. Avoid in shell history; prefer env or `--api-key-file`. | +| ElevenLabs | `ELEVENLABS_API_KEY` | Primary env var. | +| ElevenLabs | `SAG_API_KEY` | Accepted alias. | +| ElevenLabs | `--api-key-file ` | Read the key from a file. | +| ElevenLabs | `ELEVENLABS_API_KEY_FILE` | Same as `--api-key-file`. | +| ElevenLabs | `SAG_API_KEY_FILE` | Alias. | +| 60db | `SIXTYDB_API_KEY` | Primary env var. | +| 60db | `SIXTYDB_API_KEY_FILE` | Read the key from a file. | + +Configure exactly one provider at a time. If both ElevenLabs and 60db credentials are present, `sag` errors instead of guessing. The file form is handy for agents and containers: @@ -34,7 +38,9 @@ When `--voice` / `--voice-id` is omitted, `sag` resolves in this order: 1. `ELEVENLABS_VOICE_ID` 2. `SAG_VOICE_ID` -3. The first voice returned by `/v1/voices` (logged on stderr so you notice). +3. The first voice returned by the active provider's voice listing (logged on stderr so you notice). + +The env defaults apply only to ElevenLabs. With 60db, `sag` falls back to the first merged result from `/default-voices` and `/myvoices`. ```bash export SAG_VOICE_ID=21m00Tcm4TlvDq8ikWAM @@ -77,18 +83,18 @@ Pick a backend explicitly via `--player oto` or `SAG_PLAYER=oto`. See [Streaming ## API base URL -Override the ElevenLabs endpoint when you’re routing through a proxy or talking to a regional/staging deployment: +Override the active provider endpoint when you’re routing through a proxy or talking to a regional/staging deployment: ```bash sag --base-url https://api.elevenlabs.io "Default." sag --base-url https://your-proxy.internal "Routed." ``` -The default is `https://api.elevenlabs.io`. There is no env var for this; it’s deliberate so the API target is always visible in the command line. +The default is `https://api.elevenlabs.io` for ElevenLabs and `https://api.60db.ai` for 60db. There is no env var for this; it’s deliberate so the API target is always visible in the command line. ## Voice metadata cache -`sag voices --query` and `--label` need full voice descriptors. Metadata is cached in your platform-default config directory (`$XDG_CONFIG_HOME/sag/voices.json` on Linux, `~/Library/Application Support/sag/voices.json` on macOS) for 24 hours. Delete the file or pass `--limit 0` after a voice update to force a refresh. +`sag voices --query` and `--label` need full voice descriptors. Metadata is cached in your platform-default cache directory for 24 hours. Delete the file if you need an immediate refresh. ## Compatibility flags (no-ops) diff --git a/docs/index.md b/docs/index.md index 763e01c..4c11c0f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,12 +1,12 @@ --- title: Overview permalink: / -description: "sag is a Go CLI that turns text into ElevenLabs speech the way macOS `say` does. Stream to speakers, save to files, swap voices, choose models — one binary, terminal-first." +description: "sag is a Go CLI that turns text into macOS `say`-style speech with ElevenLabs or 60db. Stream to speakers, save to files, swap voices, choose models — one binary, terminal-first." --- ## Try it -After [installing](install.md) and exporting `ELEVENLABS_API_KEY`, every TTS call is a one-liner. +After [installing](install.md) and exporting exactly one provider key, every TTS call is a one-liner. ```bash # Speak through the speakers, macOS-say style. @@ -31,8 +31,9 @@ sag voices --search english --limit 5 --try ## What sag does - **Drop-in `say` replacement.** Same flag shapes (`-v`, `-r`, `-f`, `-o`), same default of streaming to the speakers. Compatibility no-ops (`--progress`, `--audio-device`, `--quality`, …) keep existing scripts working. -- **Stream-while-you-generate.** Audio plays as bytes arrive over `/v1/text-to-speech/{voice}/stream`. Latency tiers let you trade quality for first-byte time. +- **Stream-while-you-generate.** Audio plays as bytes arrive over ElevenLabs `/v1/text-to-speech/{voice}/stream` or 60db `/tts-stream`, with automatic fallback to full synthesis when the 60db route cannot represent the requested file format. - **Voice discovery.** Server-side name search, semantic `--query` over name/description/labels, repeatable `--label key=value` filters, plus `--try` to play preview clips for the matches. +- **Optional 60db backend.** `sag` merges 60db `/default-voices` and `/myvoices`, validates JSON success envelopes even on HTTP 200, and decodes the documented base64/NDJSON audio responses internally. - **Every ElevenLabs model.** Defaults to `eleven_v3` for expressive output; switch to `eleven_multilingual_v2`, `eleven_flash_v2_5`, or `eleven_turbo_v2_5` with `--model-id`. Stability, similarity, style, speaker-boost, seed, normalization, and language are all flag-controlled. - **Format inference.** `.mp3` → `mp3_44100_128`, `.wav` → `pcm_44100`, `.ogg`/`.opus` → `opus_48000_64`. Override with `--format` when you need something else. - **Cross-platform playback.** macOS uses `afplay` for AirPlay-friendly routing; Linux and Windows fall back to a `go-mp3` + `oto` decoder. Pick explicitly with `--player auto|afplay|oto`. @@ -49,4 +50,4 @@ sag voices --search english --limit 5 --try ## Project -`sag` is MIT-licensed and not affiliated with ElevenLabs or Apple. The [changelog](https://github.com/steipete/sag/blob/main/CHANGELOG.md) tracks releases; the [spec](spec.md) records goals and non-goals. Source: . +`sag` is MIT-licensed and not affiliated with ElevenLabs, 60db, or Apple. The [changelog](https://github.com/steipete/sag/blob/main/CHANGELOG.md) tracks releases; the [spec](spec.md) records goals and non-goals. Source: . diff --git a/docs/providers.md b/docs/providers.md new file mode 100644 index 0000000..f89ad2c --- /dev/null +++ b/docs/providers.md @@ -0,0 +1,74 @@ +# Providers + +`sag` supports two HTTP TTS backends. The CLI auto-selects the provider from your configured credentials; there is no `--provider` flag. + +## Selecting a provider + +| Keys set | Active provider | +| --- | --- | +| `ELEVENLABS_API_KEY` / `SAG_API_KEY` or `--api-key` / `--api-key-file` only | ElevenLabs | +| `SIXTYDB_API_KEY` or `SIXTYDB_API_KEY_FILE` only | 60db | +| both | error | +| neither | error | + +Use `--base-url` to override the active provider host. + +## 60db routes sag uses + +The 60db integration is deliberately limited to the documented REST contract: + +| Capability | Route | Notes | +| --- | --- | --- | +| Voice listing | `GET /default-voices` | Workspace-default voices; queried first. | +| Voice listing | `GET /myvoices` | User-created voices; appended after defaults and deduped by `voice_id`. | +| Streaming speak | `POST /tts-stream` | NDJSON stream of base64 chunks; no `output_format` field in the request. | +| Full speak | `POST /tts-synthesize` | JSON envelope with `audio_base64`, `encoding`, and `output_format`. | + +For 60db, `sag` validates the JSON `success` envelope even on HTTP 200 responses, decodes base64/NDJSON audio internally, rejects malformed or empty audio, and enforces per-chunk, per-frame, and total-audio limits. + +## Flag behavior + +### Supported on both providers + +- `-v, --voice` +- `-r, --rate` +- `--speed` +- `--stability` +- `--similarity` / `--similarity-boost` +- `--format` +- `--stream` / `--no-stream` +- `--play` / `--no-play` +- `-o, --output` +- `--timeout` +- `--metrics` + +### ElevenLabs-only + +These flags are passed through to ElevenLabs and rejected on 60db: + +- `--model-id` +- `--style` +- `--speaker-boost` +- `--no-speaker-boost` +- `--seed` +- `--normalize` +- `--lang` +- `--latency-tier` + +### Parameter translation + +- `--stability` and `--similarity` stay in the CLI's `0..1` range and are scaled to 60db's documented `0..100` request values. +- `--format` is canonicalized for 60db full synthesis: `mp3_*` → `mp3`, `pcm_*` / `wav` → `wav`, `opus_*` / `ogg` → `ogg`, `flac` → `flac`. + +## Streaming, files, and playback + +- ElevenLabs can stream in the requested output format, so `--stream` and `--format` work together. +- 60db streaming is used only for the default MP3 path because `/tts-stream` does not document `output_format`. +- On 60db, if the effective output format is non-MP3 and streaming was only enabled by the default, `sag` automatically falls back to `POST /tts-synthesize`. +- On 60db, `--stream` plus a non-MP3 format is an error when you explicitly force `--stream`. +- On 60db, `--play` requires MP3 output. Use `--no-play -o out.wav` (or `ogg` / `flac`) for other formats. + +## Voice metadata notes + +- `sag voices --try` uses `GET /voices/:id` on 60db to fetch `sample_url` when the list responses do not include preview URLs. +- 60db voice `model` and `categories` values are exposed as CLI labels so `--query` and `--label model=...` still work. diff --git a/docs/spec.md b/docs/spec.md index e844162..5f6f08e 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -1,11 +1,11 @@ # sag specification -CLI that mirrors macOS `say` but uses ElevenLabs for synthesis. Defaults to streaming directly to speakers and can also write audio files. +CLI that mirrors macOS `say` but uses ElevenLabs or 60db for synthesis. Defaults to streaming directly to speakers and can also write audio files. ## Runtime & deps - Go 1.24+ - Playback uses built-in Go audio (go-mp3 + oto) and should work on macOS/Linux/Windows with a default output device. -- Auth via `ELEVENLABS_API_KEY` (or `--api-key` flag). +- Auth via exactly one configured provider: ElevenLabs (`ELEVENLABS_API_KEY`, `SAG_API_KEY`, `--api-key`) or 60db (`SIXTYDB_API_KEY`). ## Commands @@ -16,7 +16,7 @@ CLI that mirrors macOS `say` but uses ElevenLabs for synthesis. Defaults to stre - `-r/--rate` words-per-minute (default 175) maps to ElevenLabs speed. - `-o/--output` same meaning; format inferred by extension when possible. - Accepts but ignores `--progress`, `--audio-device`, `--network-send`, `--interactive`, `--file-format`, `--data-format`, `--channels`, `--bit-rate`, `--quality`. -- Required: voice (via `-v/--voice` or `ELEVENLABS_VOICE_ID`/`SAG_VOICE_ID`). +- Required: voice (via `-v/--voice` or the active provider default resolution). - Flags: - `--model-id` (default `eleven_v3`; common: `eleven_multilingual_v2`, `eleven_flash_v2_5`, `eleven_turbo_v2_5`) - `--format` (default `mp3_44100_128`; `.wav` infers `pcm_44100`) @@ -34,8 +34,10 @@ CLI that mirrors macOS `say` but uses ElevenLabs for synthesis. Defaults to stre - `--metrics` print basic stats to stderr - `--output ` save audio while optionally playing - Behavior: - - Streaming path calls `POST /v1/text-to-speech/{voice_id}/stream` with JSON body. - - Non-streaming path calls `POST /v1/text-to-speech/{voice_id}` and then plays/saves. + - ElevenLabs streaming path calls `POST /v1/text-to-speech/{voice_id}/stream` with JSON body. + - 60db streaming path calls `POST /tts-stream` and decodes the documented NDJSON chunk stream. + - ElevenLabs non-streaming path calls `POST /v1/text-to-speech/{voice_id}` and then plays/saves. + - 60db non-streaming path calls `POST /tts-synthesize` and decodes `audio_base64`. - Errors if neither playback nor output is selected. Usage examples: @@ -49,7 +51,7 @@ sag speak -v "Roger" -r 200 "mac say style flags" ``` ### `sag voices` -- Lists voices via `GET /v1/voices` (server-side search when supported). +- Lists voices via the active provider. ElevenLabs uses `GET /v1/voices`; 60db merges `GET /default-voices` and `GET /myvoices`. - Flags: - `--search `: search by name (server-side when available) - `--query `: semantic query across name/description/labels (client-side) @@ -67,9 +69,11 @@ sag voices --search "english" - Does not require an API key. ## Config sources -- `ELEVENLABS_API_KEY` for auth (required). -- Default voice env: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID`. -- `--base-url` flag for alternate API host (defaults to `https://api.elevenlabs.io`). +- Exactly one provider key is required. +- ElevenLabs auth: `ELEVENLABS_API_KEY`, `SAG_API_KEY`, `--api-key`, `--api-key-file`. +- 60db auth: `SIXTYDB_API_KEY`, `SIXTYDB_API_KEY_FILE`. +- ElevenLabs default voice env: `ELEVENLABS_VOICE_ID` or `SAG_VOICE_ID`. +- `--base-url` flag for an alternate provider API host. ## Notes & future polish - Add cross-platform playback backends. diff --git a/internal/elevenlabs/client.go b/internal/elevenlabs/client.go index 0546ca2..b7ccdb0 100644 --- a/internal/elevenlabs/client.go +++ b/internal/elevenlabs/client.go @@ -10,6 +10,8 @@ import ( "net/url" "path" "strings" + + "github.com/steipete/sag/internal/tts" ) // Client talks to the ElevenLabs HTTP API. @@ -31,15 +33,12 @@ func NewClient(apiKey, baseURL string) *Client { } } -// Voice represents a voice entry returned by ElevenLabs. -type Voice struct { - VoiceID string `json:"voice_id"` - Name string `json:"name"` - Category string `json:"category"` - Description string `json:"description"` - Labels map[string]string `json:"labels,omitempty"` - PreviewURL string `json:"preview_url"` -} +// Voice is re-exported from the shared voice-catalog package so existing +// callers can keep using elevenlabs.Voice while query/filter code stays shared. +type ( + // Voice represents a voice entry returned by ElevenLabs. + Voice = tts.Voice +) type listVoicesResponse struct { Voices []Voice `json:"voices"` diff --git a/internal/sixtydb/client.go b/internal/sixtydb/client.go new file mode 100644 index 0000000..ad30b41 --- /dev/null +++ b/internal/sixtydb/client.go @@ -0,0 +1,654 @@ +// Package sixtydb provides a strict adapter for the documented 60db HTTP TTS +// endpoints. +package sixtydb + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strings" + + "github.com/steipete/sag/internal/tts" +) + +// DefaultBaseURL is the public 60db API host. +const DefaultBaseURL = "https://api.60db.ai" + +const ( + maxDecodedAudioBytes = 96 << 20 + maxDecodedChunkBytes = 8 << 20 + maxStreamFrameBytes = 12 << 20 +) + +type audioFormat string + +const ( + audioFormatMP3 audioFormat = "mp3" + audioFormatWAV audioFormat = "wav" + audioFormatOGG audioFormat = "ogg" + audioFormatFLAC audioFormat = "flac" +) + +// Client talks to the 60db HTTP API. +type Client struct { + baseURL string + apiKey string + httpClient *http.Client +} + +// TTSRequest is the documented 60db TTS payload shape exposed by the CLI. +type TTSRequest struct { + Text string `json:"text"` + VoiceID string `json:"voice_id,omitempty"` + Speed *float64 `json:"speed,omitempty"` + Stability *float64 `json:"stability,omitempty"` + Similarity *float64 `json:"similarity,omitempty"` + OutputFormat string `json:"output_format,omitempty"` +} + +// NewClient returns a Client configured with the given API key and base URL. +func NewClient(apiKey, baseURL string) *Client { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return &Client{ + baseURL: baseURL, + apiKey: apiKey, + httpClient: &http.Client{}, + } +} + +func (c *Client) newRequest(ctx context.Context, method, endpoint string, body io.Reader) (*http.Request, error) { + u, err := url.Parse(c.baseURL) + if err != nil { + return nil, err + } + u.Path = path.Join(u.Path, endpoint) + + req, err := http.NewRequestWithContext(ctx, method, u.String(), body) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + return req, nil +} + +type envelope struct { + Success *bool `json:"success"` + Message string `json:"message"` + Error string `json:"error"` +} + +func (c *Client) errorFromEnvelope(op string, env envelope, fallback string) error { + msg := strings.TrimSpace(env.Message) + if msg == "" { + msg = strings.TrimSpace(env.Error) + } + if msg == "" { + msg = fallback + } + return fmt.Errorf("%s: %s", op, c.sanitize(msg)) +} + +func (c *Client) httpError(op string, status string, body []byte) error { + var env envelope + if len(body) > 0 && json.Unmarshal(body, &env) == nil { + return c.errorFromEnvelope(op, env, status) + } + return fmt.Errorf("%s: %s", op, status) +} + +func (c *Client) sanitize(msg string) string { + msg = strings.TrimSpace(msg) + if msg == "" || c.apiKey == "" { + return msg + } + msg = strings.ReplaceAll(msg, "Bearer "+c.apiKey, "Bearer [redacted]") + msg = strings.ReplaceAll(msg, c.apiKey, "[redacted]") + return msg +} + +// CanonicalOutputFormat maps CLI/ElevenLabs-style output names to the simple +// 60db formats documented for /tts-synthesize. +func CanonicalOutputFormat(format string) string { + format = strings.ToLower(strings.TrimSpace(format)) + switch { + case format == "": + return "" + case strings.HasPrefix(format, "mp3"): + return string(audioFormatMP3) + case strings.HasPrefix(format, "pcm"), strings.HasPrefix(format, "wav"): + return string(audioFormatWAV) + case strings.HasPrefix(format, "opus"), strings.HasPrefix(format, "ogg"): + return string(audioFormatOGG) + case strings.HasPrefix(format, "flac"): + return string(audioFormatFLAC) + default: + return format + } +} + +func parseAudioFormat(value string) audioFormat { + switch strings.ToLower(strings.TrimSpace(value)) { + case string(audioFormatMP3): + return audioFormatMP3 + case string(audioFormatWAV): + return audioFormatWAV + case string(audioFormatOGG): + return audioFormatOGG + case string(audioFormatFLAC): + return audioFormatFLAC + default: + return "" + } +} + +func sniffAudioFormat(data []byte) (audioFormat, error) { + if len(data) == 0 { + return "", errors.New("empty audio") + } + switch { + case len(data) >= 3 && string(data[:3]) == "ID3": + return audioFormatMP3, nil + case len(data) >= 2 && data[0] == 0xff && data[1]&0xe0 == 0xe0: + return audioFormatMP3, nil + case len(data) >= 12 && string(data[:4]) == "RIFF" && string(data[8:12]) == "WAVE": + return audioFormatWAV, nil + case len(data) >= 4 && string(data[:4]) == "OggS": + return audioFormatOGG, nil + case len(data) >= 4 && string(data[:4]) == "fLaC": + return audioFormatFLAC, nil + default: + return "", errors.New("unrecognized audio format") + } +} + +type voiceEntry struct { + VoiceID string `json:"voice_id"` + Name string `json:"name"` + Category string `json:"category"` + Model string `json:"model"` + Labels map[string]string `json:"labels"` + Description *string `json:"description"` + Categories []string `json:"categories"` +} + +type voicesResponse struct { + envelope + Data []voiceEntry `json:"data"` +} + +type voiceDetailsResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Language string `json:"language"` + Gender string `json:"gender"` + Age string `json:"age"` + Accent string `json:"accent"` + UseCase []string `json:"use_case"` + SampleURL string `json:"sample_url"` + IsCustom bool `json:"is_custom"` +} + +func (e voiceEntry) toVoice(source string) tts.Voice { + labels := make(map[string]string, len(e.Labels)+2) + for k, v := range e.Labels { + labels[k] = v + } + if e.Model != "" { + labels["model"] = e.Model + } + if len(e.Categories) > 0 { + labels["categories"] = strings.Join(e.Categories, ", ") + } + if source != "" { + labels["source"] = source + } + description := "" + if e.Description != nil { + description = strings.TrimSpace(*e.Description) + } + return tts.Voice{ + VoiceID: e.VoiceID, + Name: e.Name, + Category: e.Category, + Description: description, + Labels: labels, + } +} + +func (v voiceDetailsResponse) toVoice() tts.Voice { + labels := map[string]string{} + if v.Language != "" { + labels["language"] = v.Language + } + if v.Gender != "" { + labels["gender"] = v.Gender + } + if v.Age != "" { + labels["age"] = v.Age + } + if v.Accent != "" { + labels["accent"] = v.Accent + } + if len(v.UseCase) > 0 { + labels["use_case"] = strings.Join(v.UseCase, ", ") + } + if v.IsCustom { + labels["source"] = "myvoices" + } + return tts.Voice{ + VoiceID: v.ID, + Name: v.Name, + Description: strings.TrimSpace(v.Description), + Labels: labels, + PreviewURL: strings.TrimSpace(v.SampleURL), + } +} + +func (c *Client) fetchVoices(ctx context.Context, endpoint, source string) ([]tts.Voice, error) { + req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return nil, c.httpError("list voices", resp.Status, body) + } + + var body voicesResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return nil, fmt.Errorf("list voices: decode response: %w", err) + } + if body.Success == nil || !*body.Success { + return nil, c.errorFromEnvelope("list voices", body.envelope, "unexpected response envelope") + } + + voices := make([]tts.Voice, 0, len(body.Data)) + for _, entry := range body.Data { + voices = append(voices, entry.toVoice(source)) + } + return voices, nil +} + +// ListVoices merges the documented default and user-created voice catalogs. +func (c *Client) ListVoices(ctx context.Context) ([]tts.Voice, error) { + defaultVoices, err := c.fetchVoices(ctx, "/default-voices", "default") + if err != nil { + return nil, err + } + myVoices, err := c.fetchVoices(ctx, "/myvoices", "myvoices") + if err != nil { + return nil, err + } + + merged := make([]tts.Voice, 0, len(defaultVoices)+len(myVoices)) + seen := make(map[string]struct{}, len(defaultVoices)+len(myVoices)) + for _, voice := range append(defaultVoices, myVoices...) { + if voice.VoiceID == "" { + continue + } + if _, ok := seen[voice.VoiceID]; ok { + continue + } + seen[voice.VoiceID] = struct{}{} + merged = append(merged, voice) + } + return merged, nil +} + +// SearchVoices filters the merged voice catalog by name. +func (c *Client) SearchVoices(ctx context.Context, search string, limit int) ([]tts.Voice, error) { + voices, err := c.ListVoices(ctx) + if err != nil { + return nil, err + } + search = strings.TrimSpace(search) + if search != "" { + searchLower := strings.ToLower(search) + filtered := make([]tts.Voice, 0, len(voices)) + for _, voice := range voices { + if strings.Contains(strings.ToLower(voice.Name), searchLower) { + filtered = append(filtered, voice) + } + } + voices = filtered + } + if limit > 0 && len(voices) > limit { + voices = voices[:limit] + } + return voices, nil +} + +// GetVoice resolves a voice from the documented per-voice endpoint. +func (c *Client) GetVoice(ctx context.Context, voiceID string) (tts.Voice, error) { + req, err := c.newRequest(ctx, http.MethodGet, path.Join("/voices", voiceID), nil) + if err != nil { + return tts.Voice{}, err + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return tts.Voice{}, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return tts.Voice{}, c.httpError("get voice", resp.Status, body) + } + + var body voiceDetailsResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return tts.Voice{}, fmt.Errorf("get voice: decode response: %w", err) + } + if strings.TrimSpace(body.ID) == "" { + return tts.Voice{}, errors.New("get voice: missing voice id") + } + return body.toVoice(), nil +} + +type synthResponse struct { + envelope + AudioBase64 string `json:"audio_base64"` + SampleRate int `json:"sample_rate"` + DurationSeconds float64 `json:"duration_seconds"` + Encoding string `json:"encoding"` + OutputFormat string `json:"output_format"` +} + +func validateRequestedFormat(requested string) error { + if requested == "" { + return nil + } + if parseAudioFormat(requested) == "" { + return fmt.Errorf("unsupported 60db output format %q", requested) + } + return nil +} + +func validateConvertResponseFormat(requested string, resp synthResponse, data []byte) error { + sniffed, err := sniffAudioFormat(data) + if err != nil { + return err + } + + declared := parseAudioFormat(resp.OutputFormat) + if declared == "" { + declared = parseAudioFormat(resp.Encoding) + } + if declared != "" && declared != sniffed { + return fmt.Errorf("response format mismatch: declared %s, decoded %s", declared, sniffed) + } + + expected := parseAudioFormat(requested) + if expected != "" && sniffed != expected { + return fmt.Errorf("response format mismatch: requested %s, decoded %s", expected, sniffed) + } + return nil +} + +// ConvertTTS downloads the full audio and returns decoded bytes. +func (c *Client) ConvertTTS(ctx context.Context, reqBody TTSRequest) ([]byte, error) { + reqBody.OutputFormat = CanonicalOutputFormat(reqBody.OutputFormat) + if err := validateRequestedFormat(reqBody.OutputFormat); err != nil { + return nil, err + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + req, err := c.newRequest(ctx, http.MethodPost, "/tts-synthesize", bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return nil, c.httpError("synthesize audio", resp.Status, body) + } + + var body synthResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return nil, fmt.Errorf("synthesize audio: decode response: %w", err) + } + if body.Success == nil || !*body.Success { + return nil, c.errorFromEnvelope("synthesize audio", body.envelope, "unexpected response envelope") + } + if strings.TrimSpace(body.AudioBase64) == "" { + return nil, errors.New("synthesize audio: empty audio_base64") + } + + decodedLen := base64.StdEncoding.DecodedLen(len(body.AudioBase64)) + if decodedLen <= 0 { + return nil, errors.New("synthesize audio: empty decoded audio") + } + if decodedLen > maxDecodedAudioBytes { + return nil, fmt.Errorf("synthesize audio: decoded audio exceeds %d bytes", maxDecodedAudioBytes) + } + + data := make([]byte, decodedLen) + n, err := base64.StdEncoding.Decode(data, []byte(body.AudioBase64)) + if err != nil { + return nil, fmt.Errorf("synthesize audio: decode audio_base64: %w", err) + } + data = data[:n] + if len(data) == 0 { + return nil, errors.New("synthesize audio: decoded audio was empty") + } + if err := validateConvertResponseFormat(reqBody.OutputFormat, body, data); err != nil { + return nil, fmt.Errorf("synthesize audio: %w", err) + } + return data, nil +} + +// StreamTTS requests streaming audio. The documented stream API does not +// accept output_format. +func (c *Client) StreamTTS(ctx context.Context, reqBody TTSRequest) (io.ReadCloser, error) { + if CanonicalOutputFormat(reqBody.OutputFormat) != "" { + return nil, errors.New("stream audio: output_format is not supported by /tts-stream") + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + req, err := c.newRequest(ctx, http.MethodPost, "/tts-stream", bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/x-ndjson") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + defer func() { _ = resp.Body.Close() }() + body, _ := io.ReadAll(resp.Body) + return nil, c.httpError("stream audio", resp.Status, body) + } + return newNDJSONAudioReader(ctx, resp.Body, c.sanitize), nil +} + +type streamFrame struct { + Type string `json:"type"` + Result struct { + AudioContent string `json:"audioContent"` + } `json:"result"` + Message string `json:"message"` +} + +type ndjsonAudioReader struct { + ctx context.Context + src io.ReadCloser + scanner *bufio.Scanner + pending []byte + err error + stop func() bool + sanitize func(string) string + + totalBytes int64 + sawChunk bool + sawComplete bool +} + +func newNDJSONAudioReader(ctx context.Context, src io.ReadCloser, sanitize func(string) string) *ndjsonAudioReader { + scanner := bufio.NewScanner(src) + scanner.Buffer(make([]byte, 64*1024), maxStreamFrameBytes) + if sanitize == nil { + sanitize = func(msg string) string { return msg } + } + + reader := &ndjsonAudioReader{ + ctx: ctx, + src: src, + scanner: scanner, + sanitize: sanitize, + } + reader.stop = context.AfterFunc(ctx, func() { + _ = src.Close() + }) + return reader +} + +func (r *ndjsonAudioReader) Read(p []byte) (int, error) { + for len(r.pending) == 0 { + if r.err != nil { + return 0, r.err + } + if err := r.fill(); err != nil { + r.err = err + if len(r.pending) == 0 { + return 0, err + } + } + } + n := copy(p, r.pending) + r.pending = r.pending[n:] + return n, nil +} + +func (r *ndjsonAudioReader) fill() error { + if err := r.ctx.Err(); err != nil { + return err + } + for r.scanner.Scan() { + line := bytes.TrimSpace(r.scanner.Bytes()) + if len(line) == 0 { + continue + } + + var frame streamFrame + if err := json.Unmarshal(line, &frame); err != nil { + return fmt.Errorf("decode stream frame: %w", err) + } + + switch frame.Type { + case "chunk": + audio, err := decodeChunk(frame.Result.AudioContent, r.totalBytes) + if err != nil { + return err + } + if !r.sawChunk { + if _, err := sniffAudioFormat(audio); err != nil { + return fmt.Errorf("unknown streamed audio format: %w", err) + } + } + r.sawChunk = true + r.totalBytes += int64(len(audio)) + r.pending = audio + return nil + case "complete": + r.sawComplete = true + if !r.sawChunk { + return errors.New("stream completed without audio") + } + return io.EOF + case "error": + if msg := strings.TrimSpace(frame.Message); msg != "" { + return fmt.Errorf("stream error: %s", r.sanitize(msg)) + } + return errors.New("stream error") + default: + return fmt.Errorf("unknown stream frame type %q", frame.Type) + } + } + + if err := r.scanner.Err(); err != nil { + if ctxErr := r.ctx.Err(); ctxErr != nil { + return ctxErr + } + return fmt.Errorf("read stream frame: %w", err) + } + if ctxErr := r.ctx.Err(); ctxErr != nil { + return ctxErr + } + if r.sawComplete { + return io.EOF + } + return io.ErrUnexpectedEOF +} + +func decodeChunk(encoded string, totalBytes int64) ([]byte, error) { + encoded = strings.TrimSpace(encoded) + if encoded == "" { + return nil, errors.New("stream chunk missing audioContent") + } + decodedLen := base64.StdEncoding.DecodedLen(len(encoded)) + if decodedLen <= 0 { + return nil, errors.New("stream chunk decoded to empty audio") + } + if decodedLen > maxDecodedChunkBytes { + return nil, fmt.Errorf("stream chunk exceeds %d bytes", maxDecodedChunkBytes) + } + if totalBytes+int64(decodedLen) > maxDecodedAudioBytes { + return nil, fmt.Errorf("stream audio exceeds %d bytes", maxDecodedAudioBytes) + } + + audio := make([]byte, decodedLen) + n, err := base64.StdEncoding.Decode(audio, []byte(encoded)) + if err != nil { + return nil, fmt.Errorf("decode audio chunk: %w", err) + } + audio = audio[:n] + if len(audio) == 0 { + return nil, errors.New("stream chunk decoded to empty audio") + } + return audio, nil +} + +func (r *ndjsonAudioReader) Close() error { + if r.stop != nil { + r.stop() + } + return r.src.Close() +} diff --git a/internal/sixtydb/client_test.go b/internal/sixtydb/client_test.go new file mode 100644 index 0000000..dd9faf4 --- /dev/null +++ b/internal/sixtydb/client_test.go @@ -0,0 +1,457 @@ +package sixtydb + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestNewClientDefaultsBase(t *testing.T) { + c := NewClient("key", "") + if c.baseURL != DefaultBaseURL { + t.Fatalf("unexpected baseURL: %s", c.baseURL) + } +} + +func TestListVoicesMergesDefaultAndMyVoices(t *testing.T) { + var calls []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, r.URL.Path) + if got := r.Header.Get("Authorization"); got != "Bearer key" { + t.Fatalf("unexpected auth header: %q", got) + } + switch r.URL.Path { + case "/default-voices": + _, _ = io.WriteString(w, `{"success":true,"data":[ + {"voice_id":"v1","name":"Aria","category":"default","model":"60db Quality","labels":{"accent":"US"}}, + {"voice_id":"dup","name":"Default Dup","category":"default"} + ]}`) + case "/myvoices": + _, _ = io.WriteString(w, `{"success":true,"data":[ + {"voice_id":"dup","name":"My Dup","category":"cloned"}, + {"voice_id":"v2","name":"Ravi","category":"cloned","categories":["narration","warm"]} + ]}`) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + voices, err := c.ListVoices(context.Background()) + if err != nil { + t.Fatalf("ListVoices error: %v", err) + } + if want := []string{"/default-voices", "/myvoices"}; strings.Join(calls, ",") != strings.Join(want, ",") { + t.Fatalf("route order = %v, want %v", calls, want) + } + if len(voices) != 3 { + t.Fatalf("expected 3 merged voices, got %d", len(voices)) + } + if voices[0].VoiceID != "v1" || voices[0].Labels["source"] != "default" || voices[0].Labels["model"] != "60db Quality" { + t.Fatalf("unexpected default voice: %+v", voices[0]) + } + if voices[1].VoiceID != "dup" || voices[1].Name != "Default Dup" { + t.Fatalf("expected default duplicate to win, got %+v", voices[1]) + } + if voices[2].VoiceID != "v2" || voices[2].Labels["source"] != "myvoices" || voices[2].Labels["categories"] != "narration, warm" { + t.Fatalf("unexpected user voice: %+v", voices[2]) + } +} + +func TestSearchVoicesFiltersAndLimits(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/default-voices": + _, _ = io.WriteString(w, `{"success":true,"data":[ + {"voice_id":"v1","name":"Roger"}, + {"voice_id":"v2","name":"Rogue"} + ]}`) + case "/myvoices": + _, _ = io.WriteString(w, `{"success":true,"data":[{"voice_id":"v3","name":"Sarah"}]}`) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + voices, err := c.SearchVoices(context.Background(), "rog", 1) + if err != nil { + t.Fatalf("SearchVoices error: %v", err) + } + if len(voices) != 1 || voices[0].VoiceID != "v1" { + t.Fatalf("unexpected voices: %+v", voices) + } +} + +func TestGetVoiceUsesDocumentedPerVoiceRoute(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/voices/voice-001" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + _, _ = io.WriteString(w, `{ + "id":"voice-001", + "name":"Sarah", + "description":"Professional female voice", + "language":"en-US", + "gender":"female", + "age":"middle", + "accent":"American", + "use_case":["narration","customer-service"], + "sample_url":"https://cdn.60db.com/samples/voice-001.mp3", + "is_custom":false + }`) + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + voice, err := c.GetVoice(context.Background(), "voice-001") + if err != nil { + t.Fatalf("GetVoice error: %v", err) + } + if voice.VoiceID != "voice-001" || voice.Name != "Sarah" { + t.Fatalf("unexpected voice: %+v", voice) + } + if voice.PreviewURL != "https://cdn.60db.com/samples/voice-001.mp3" { + t.Fatalf("unexpected preview URL: %q", voice.PreviewURL) + } + if voice.Labels["use_case"] != "narration, customer-service" || voice.Labels["accent"] != "American" { + t.Fatalf("unexpected voice labels: %+v", voice.Labels) + } +} + +func TestListVoicesRejectsHTTP200ErrorEnvelopeWithoutLeakingToken(t *testing.T) { + const secret = "secret-token" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"success":false,"message":"Bearer `+secret+` invalid token"}`) + })) + defer srv.Close() + + c := NewClient(secret, srv.URL) + _, err := c.ListVoices(context.Background()) + assertSanitizedError(t, err, secret, "invalid token") +} + +func TestConvertTTSUsesDocumentedRouteAndValidatesResponse(t *testing.T) { + audio := []byte("ID3converted-audio") + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/tts-synthesize" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer key" { + t.Fatalf("unexpected auth header: %q", got) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body["text"] != "hi" || body["voice_id"] != "v1" { + t.Fatalf("unexpected request body: %+v", body) + } + if body["speed"] != 1.1 || body["stability"] != 50.0 || body["similarity"] != 80.0 || body["output_format"] != "mp3" { + t.Fatalf("unexpected mapped request body: %+v", body) + } + + resp := map[string]any{ + "success": true, + "audio_base64": base64.StdEncoding.EncodeToString(audio), + "encoding": "mp3", + "output_format": "mp3", + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + speed := 1.1 + stability := 50.0 + similarity := 80.0 + c := NewClient("key", srv.URL) + got, err := c.ConvertTTS(context.Background(), TTSRequest{ + Text: "hi", + VoiceID: "v1", + Speed: &speed, + Stability: &stability, + Similarity: &similarity, + OutputFormat: "mp3_44100_128", + }) + if err != nil { + t.Fatalf("ConvertTTS error: %v", err) + } + if !bytes.Equal(got, audio) { + t.Fatalf("decoded audio = %q, want %q", got, audio) + } +} + +func TestConvertTTSRejectsHTTP200ErrorEnvelopeWithoutLeakingToken(t *testing.T) { + const secret = "secret-token" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"success":false,"message":"Bearer `+secret+` invalid token"}`) + })) + defer srv.Close() + + c := NewClient(secret, srv.URL) + _, err := c.ConvertTTS(context.Background(), TTSRequest{Text: "hi"}) + assertSanitizedError(t, err, secret, "invalid token") +} + +func TestConvertTTSRejectsMalformedOrMismatchedAudio(t *testing.T) { + tests := []struct { + name string + body string + want string + }{ + { + name: "empty audio", + body: `{"success":true,"audio_base64":""}`, + want: "empty audio_base64", + }, + { + name: "unknown format", + body: `{"success":true,"audio_base64":"` + mustBase64([]byte("not-audio")) + `"}`, + want: "unrecognized audio format", + }, + { + name: "declared format mismatch", + body: `{"success":true,"audio_base64":"` + mustBase64([]byte("ID3mp3-data")) + `","output_format":"wav"}`, + want: "declared wav, decoded mp3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, tt.body) + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + _, err := c.ConvertTTS(context.Background(), TTSRequest{Text: "hi"}) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected %q, got %v", tt.want, err) + } + }) + } +} + +func TestStreamTTSUsesDocumentedRouteAndDecodesNDJSON(t *testing.T) { + chunk1 := mustBase64([]byte("ID3hello-")) + chunk2 := mustBase64([]byte("world")) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/tts-stream" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer key" { + t.Fatalf("unexpected auth header: %q", got) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if _, ok := body["output_format"]; ok { + t.Fatalf("expected output_format omitted from stream body") + } + _, _ = io.WriteString(w, `{"type":"chunk","result":{"audioContent":"`+chunk1+`"}}`+"\n") + _, _ = io.WriteString(w, `{"type":"chunk","result":{"audioContent":"`+chunk2+`"}}`+"\n") + _, _ = io.WriteString(w, `{"type":"complete"}`+"\n") + })) + defer srv.Close() + + c := NewClient("key", srv.URL) + rc, err := c.StreamTTS(context.Background(), TTSRequest{Text: "hi"}) + if err != nil { + t.Fatalf("StreamTTS error: %v", err) + } + defer func() { _ = rc.Close() }() + + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("read stream: %v", err) + } + if string(got) != "ID3hello-world" { + t.Fatalf("unexpected stream body: %q", got) + } +} + +func TestStreamTTSRejectsInvalidTokenWithoutLeakingToken(t *testing.T) { + const secret = "secret-token" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, `{"success":false,"message":"Bearer `+secret+` invalid token"}`) + })) + defer srv.Close() + + c := NewClient(secret, srv.URL) + _, err := c.StreamTTS(context.Background(), TTSRequest{Text: "hi"}) + assertSanitizedError(t, err, secret, "invalid token") +} + +func TestStreamTTSRejectsMalformedStreams(t *testing.T) { + tests := []struct { + name string + lines []string + want string + isErr error + }{ + { + name: "bad json", + lines: []string{`not-json`}, + want: "decode stream frame", + }, + { + name: "unknown frame type", + lines: []string{`{"type":"wat"}`}, + want: `unknown stream frame type "wat"`, + }, + { + name: "missing audio content", + lines: []string{`{"type":"chunk","result":{"audioContent":""}}`}, + want: "missing audioContent", + }, + { + name: "unknown audio format", + lines: []string{`{"type":"chunk","result":{"audioContent":"` + mustBase64([]byte("bad")) + `"}}`}, + want: "unknown streamed audio format", + }, + { + name: "complete without audio", + lines: []string{`{"type":"complete"}`}, + want: "stream completed without audio", + }, + { + name: "empty stream", + lines: nil, + isErr: io.ErrUnexpectedEOF, + }, + { + name: "missing complete frame", + lines: []string{`{"type":"chunk","result":{"audioContent":"` + mustBase64([]byte("ID3chunk")) + `"}}`}, + isErr: io.ErrUnexpectedEOF, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reader := newNDJSONAudioReader(context.Background(), io.NopCloser(strings.NewReader(strings.Join(tt.lines, "\n"))), nil) + defer func() { _ = reader.Close() }() + + _, err := io.ReadAll(reader) + if tt.isErr != nil { + if !errors.Is(err, tt.isErr) { + t.Fatalf("expected %v, got %v", tt.isErr, err) + } + return + } + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected %q, got %v", tt.want, err) + } + }) + } +} + +func TestNDJSONAudioReaderHonorsCancellation(t *testing.T) { + pr, pw := io.Pipe() + ctx, cancel := context.WithCancel(context.Background()) + reader := newNDJSONAudioReader(ctx, pr, nil) + defer func() { _ = reader.Close() }() + + done := make(chan error, 1) + go func() { + _, err := io.ReadAll(reader) + done <- err + }() + + cancel() + _ = pw.Close() + + select { + case err := <-done: + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("stream read did not unblock on cancellation") + } +} + +func TestNDJSONAudioReaderRejectsOversizedFrame(t *testing.T) { + line := strings.Repeat("a", maxStreamFrameBytes+1) + reader := newNDJSONAudioReader(context.Background(), io.NopCloser(strings.NewReader(line)), nil) + defer func() { _ = reader.Close() }() + + _, err := io.ReadAll(reader) + if err == nil || !strings.Contains(err.Error(), "token too long") { + t.Fatalf("expected oversized frame error, got %v", err) + } +} + +func TestNDJSONAudioReaderSanitizesErrorFrames(t *testing.T) { + const secret = "secret-token" + reader := newNDJSONAudioReader( + context.Background(), + io.NopCloser(strings.NewReader(`{"type":"error","message":"Bearer `+secret+` invalid token"}`)), + func(msg string) string { + return strings.ReplaceAll(msg, secret, "[redacted]") + }, + ) + defer func() { _ = reader.Close() }() + + _, err := io.ReadAll(reader) + assertSanitizedError(t, err, secret, "invalid token") +} + +func TestDecodeChunkRejectsPerChunkAndTotalLimits(t *testing.T) { + tooLargeChunk := mustBase64(bytes.Repeat([]byte{'a'}, maxDecodedChunkBytes+1)) + if _, err := decodeChunk(tooLargeChunk, 0); err == nil || !strings.Contains(err.Error(), "stream chunk exceeds") { + t.Fatalf("expected chunk limit error, got %v", err) + } + + if _, err := decodeChunk(mustBase64([]byte("abc")), maxDecodedAudioBytes-2); err == nil || !strings.Contains(err.Error(), "stream audio exceeds") { + t.Fatalf("expected total limit error, got %v", err) + } +} + +func TestCanonicalOutputFormat(t *testing.T) { + cases := map[string]string{ + "mp3_44100_128": "mp3", + "pcm_44100": "wav", + "wav": "wav", + "opus_48000_64": "ogg", + "ogg": "ogg", + "flac": "flac", + "": "", + } + for in, want := range cases { + if got := CanonicalOutputFormat(in); got != want { + t.Fatalf("CanonicalOutputFormat(%q) = %q, want %q", in, got, want) + } + } +} + +func mustBase64(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +func assertSanitizedError(t *testing.T, err error, secret, want string) { + t.Helper() + if err == nil { + t.Fatal("expected error") + } + if strings.Contains(err.Error(), secret) { + t.Fatalf("error leaked secret: %v", err) + } + if !strings.Contains(err.Error(), "[redacted]") { + t.Fatalf("expected redacted marker, got %v", err) + } + if !strings.Contains(err.Error(), want) { + t.Fatalf("expected %q in error, got %v", want, err) + } +} diff --git a/internal/tts/tts.go b/internal/tts/tts.go new file mode 100644 index 0000000..c426570 --- /dev/null +++ b/internal/tts/tts.go @@ -0,0 +1,25 @@ +// Package tts holds the small voice-catalog types shared by the existing +// `voices` and voice-resolution command paths. +package tts + +import ( + "context" +) + +// Voice represents a single voice entry normalized for the CLI's voice listing, +// filtering, query, and resolution paths. +type Voice struct { + VoiceID string `json:"voice_id"` + Name string `json:"name"` + Category string `json:"category"` + Description string `json:"description"` + Labels map[string]string `json:"labels,omitempty"` + PreviewURL string `json:"preview_url"` +} + +// VoiceCatalog is the minimal shared interface needed by existing commands. +type VoiceCatalog interface { + ListVoices(ctx context.Context) ([]Voice, error) + SearchVoices(ctx context.Context, search string, limit int) ([]Voice, error) + GetVoice(ctx context.Context, voiceID string) (Voice, error) +}