Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/CONFIG_ENV_VARS.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ Total variables: **369**
| `JIRA_CLOSE_TRANSITIONS` | `getEnv` | `"Done,Closed,Resolve Issue"` | `JiraCloseTransitions` | `-` |
| `JIRA_EMAIL` | `getEnv` | `""` | `JiraEmail` | `-` |
| `JIRA_PROJECT` | `getEnv` | `"SEC"` | `JiraProject` | `-` |
| `JOB_DATABASE_URL` | `getEnv` | `""` | `JobDatabaseURL` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` |
| `JOB_DATABASE_URL` | `getEnv` | `""` | `-` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` |
| `JOB_MAX_ATTEMPTS` | `getEnvInt` | `3` | `JobMaxAttempts` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` |
| `JOB_NATS_CONSUMER` | `getEnv` | `"job-worker"` | `JobNATSConsumer` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` |
| `JOB_NATS_STREAM` | `getEnv` | `"CEREBRO_JOBS"` | `JobNATSStream` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` |
Expand Down Expand Up @@ -352,7 +352,7 @@ Total variables: **369**
| `VAULT_TOKEN` | `getEnv` | `""` | `VaultToken` | `-` |
| `VULNDB_STATE_FILE` | `getEnv` | `filepath.Join(".cerebro", "vulndb.db")` | `VulnDBStateFile` | `-` |
| `WAREHOUSE_BACKEND` | `getEnv` | `defaultWarehouseBackend` | `WarehouseBackend` | `backend-specific connection settings must be present when an alternative warehouse backend is selected`, `must be one of snowflake, sqlite, postgres` |
| `WAREHOUSE_POSTGRES_DSN` | `getEnv` | `""` | `WarehousePostgresDSN` | `backend-specific connection settings must be present when an alternative warehouse backend is selected` |
| `WAREHOUSE_POSTGRES_DSN` | `getEnv` | `""` | `-` | `backend-specific connection settings must be present when an alternative warehouse backend is selected` |
| `WAREHOUSE_SQLITE_PATH` | `getEnv` | `defaultWarehouseSQLitePath` | `WarehouseSQLitePath` | `backend-specific connection settings must be present when an alternative warehouse backend is selected` |
| `WEBHOOK_URLS` | `getEnv` | `""` | `WebhookURLs` | `-` |
| `WIZ_API_URL` | `getEnv` | `""` | `WizAPIURL` | `-` |
Expand Down
161 changes: 161 additions & 0 deletions internal/agents/session_store_postgres.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package agents

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
)

const postgresSessionTable = "cerebro_agent_sessions"

type PostgresSessionStore struct {
db *sql.DB
rewriteSQL func(string) string
}

func NewPostgresSessionStore(db *sql.DB) *PostgresSessionStore {
return &PostgresSessionStore{db: db}
}

func (s *PostgresSessionStore) EnsureSchema(ctx context.Context) error {
if s == nil || s.db == nil {
return fmt.Errorf("postgres session store is not initialized")
}
_, err := s.db.ExecContext(ctx, s.q(`
CREATE TABLE IF NOT EXISTS `+postgresSessionTable+` (
id TEXT PRIMARY KEY,
agent_id TEXT NOT NULL,
user_id TEXT NOT NULL,
status TEXT NOT NULL,
messages TEXT NOT NULL DEFAULT '[]',
context TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_`+postgresSessionTable+`_updated_at ON `+postgresSessionTable+` (updated_at);
`))
return err
}

func (s *PostgresSessionStore) Save(ctx context.Context, session *Session) error {
if s == nil || s.db == nil {
return fmt.Errorf("postgres session store is not initialized")
}
if session == nil {
return fmt.Errorf("session is required")
}
if err := s.EnsureSchema(ctx); err != nil {
return err
}

messagesJSON, err := json.Marshal(session.Messages)
if err != nil {
return err
}
if len(messagesJSON) == 0 {
messagesJSON = []byte("[]")
}

contextJSON, err := json.Marshal(session.Context)
if err != nil {
return err
}
if len(contextJSON) == 0 {
contextJSON = []byte("{}")
}

createdAt := session.CreatedAt.UTC()
if createdAt.IsZero() {
createdAt = time.Now().UTC()
session.CreatedAt = createdAt
}
updatedAt := session.UpdatedAt.UTC()
if updatedAt.IsZero() {
updatedAt = createdAt
session.UpdatedAt = updatedAt
}

_, err = s.db.ExecContext(ctx, s.q(`
INSERT INTO `+postgresSessionTable+` (
id, agent_id, user_id, status, messages, context, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (id) DO UPDATE SET
agent_id = EXCLUDED.agent_id,
user_id = EXCLUDED.user_id,
status = EXCLUDED.status,
messages = EXCLUDED.messages,
context = EXCLUDED.context,
updated_at = EXCLUDED.updated_at
`),
session.ID,
session.AgentID,
session.UserID,
session.Status,
string(messagesJSON),
string(contextJSON),
createdAt,
updatedAt,
)
return err
}

func (s *PostgresSessionStore) Get(ctx context.Context, id string) (*Session, error) {
if s == nil || s.db == nil {
return nil, fmt.Errorf("postgres session store is not initialized")
}
if err := s.EnsureSchema(ctx); err != nil {
return nil, err
}

var session Session
var messagesRaw string
var contextRaw string

err := s.db.QueryRowContext(ctx, s.q(`
SELECT id, agent_id, user_id, status, messages, context, created_at, updated_at
FROM `+postgresSessionTable+`
WHERE id = $1
`), id).Scan(
&session.ID,
&session.AgentID,
&session.UserID,
&session.Status,
&messagesRaw,
&contextRaw,
&session.CreatedAt,
&session.UpdatedAt,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
if err != nil {
return nil, err
}

if messagesJSON := normalizeVariantJSON(messagesRaw); len(messagesJSON) > 0 {
if err := json.Unmarshal(messagesJSON, &session.Messages); err != nil {
return nil, err
}
}
if contextJSON := normalizeVariantJSON(contextRaw); len(contextJSON) > 0 {
if err := json.Unmarshal(contextJSON, &session.Context); err != nil {
return nil, err
}
}

session.CreatedAt = session.CreatedAt.UTC()
session.UpdatedAt = session.UpdatedAt.UTC()
return &session, nil
}

func (s *PostgresSessionStore) q(query string) string {
if s != nil && s.rewriteSQL != nil {
return s.rewriteSQL(query)
}
return query
}

var _ SessionStore = (*PostgresSessionStore)(nil)
109 changes: 109 additions & 0 deletions internal/agents/session_store_postgres_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package agents

import (
"context"
"database/sql"
"regexp"
"testing"
"time"

_ "modernc.org/sqlite"
)

var sessionStoreDollarPlaceholderRe = regexp.MustCompile(`\$\d+`)

func sessionStoreSQLiteRewrite(query string) string {
return sessionStoreDollarPlaceholderRe.ReplaceAllString(query, "?")
}

func newTestPostgresSessionStore(t *testing.T) *PostgresSessionStore {
t.Helper()
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
store := &PostgresSessionStore{
db: db,
rewriteSQL: sessionStoreSQLiteRewrite,
}
if err := store.EnsureSchema(context.Background()); err != nil {
_ = db.Close()
t.Fatalf("EnsureSchema() error = %v", err)
}
t.Cleanup(func() { _ = db.Close() })
return store
}

func TestPostgresSessionStoreSaveAndGet(t *testing.T) {
store := newTestPostgresSessionStore(t)
now := time.Now().UTC().Truncate(time.Second)
session := &Session{
ID: "session-1",
AgentID: "agent-1",
UserID: "user-1",
Status: "active",
Messages: []Message{
{Role: "user", Content: "investigate bucket"},
},
Context: SessionContext{
FindingIDs: []string{"finding-1"},
Metadata: map[string]interface{}{"tenant": "tenant-a"},
},
CreatedAt: now,
UpdatedAt: now.Add(time.Minute),
}

if err := store.Save(context.Background(), session); err != nil {
t.Fatalf("Save() error = %v", err)
}

got, err := store.Get(context.Background(), session.ID)
if err != nil {
t.Fatalf("Get() error = %v", err)
}
if got == nil {
t.Fatal("expected persisted session")
}
if got.UserID != session.UserID {
t.Fatalf("UserID = %q, want %q", got.UserID, session.UserID)
}
if len(got.Messages) != 1 || got.Messages[0].Content != "investigate bucket" {
t.Fatalf("unexpected messages: %#v", got.Messages)
}
if len(got.Context.FindingIDs) != 1 || got.Context.FindingIDs[0] != "finding-1" {
t.Fatalf("unexpected context finding ids: %#v", got.Context.FindingIDs)
}
}

func TestPostgresSessionStoreSaveUpdatesExistingRow(t *testing.T) {
store := newTestPostgresSessionStore(t)
session := &Session{
ID: "session-2",
AgentID: "agent-1",
UserID: "user-1",
Status: "active",
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
if err := store.Save(context.Background(), session); err != nil {
t.Fatalf("initial Save() error = %v", err)
}

session.Status = "completed"
session.Messages = []Message{{Role: "assistant", Content: "done"}}
session.UpdatedAt = session.UpdatedAt.Add(2 * time.Minute)
if err := store.Save(context.Background(), session); err != nil {
t.Fatalf("update Save() error = %v", err)
}

got, err := store.Get(context.Background(), session.ID)
if err != nil {
t.Fatalf("Get() error = %v", err)
}
if got.Status != "completed" {
t.Fatalf("Status = %q, want completed", got.Status)
}
if len(got.Messages) != 1 || got.Messages[0].Content != "done" {
t.Fatalf("unexpected updated messages: %#v", got.Messages)
}
}
41 changes: 38 additions & 3 deletions internal/agents/session_store_snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,44 @@ func (s *SnowflakeSessionStore) Get(ctx context.Context, id string) (*Session, e

row := s.db.QueryRowContext(ctx, query, id)

session, err := scanSnowflakeSession(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
return session, nil
}

func (s *SnowflakeSessionStore) ListAll(ctx context.Context) ([]*Session, error) {
// #nosec G202 -- s.tableRef is validated by snowflake.SafeTableRef before interpolation.
query := `
SELECT id, agent_id, user_id, status, messages, context, created_at, updated_at
FROM ` + s.tableRef + `
ORDER BY updated_at DESC
`

rows, err := s.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()

sessions := make([]*Session, 0)
for rows.Next() {
session, err := scanSnowflakeSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, session)
}
return sessions, rows.Err()
}

func scanSnowflakeSession(row interface {
Scan(dest ...any) error
}) (*Session, error) {
var session Session
var messagesRaw any
var contextRaw any
Expand All @@ -131,9 +169,6 @@ func (s *SnowflakeSessionStore) Get(ctx context.Context, id string) (*Session, e
&updatedAt,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}

Expand Down
12 changes: 6 additions & 6 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,28 +628,28 @@ func (s *Server) adminHealth(w http.ResponseWriter, r *http.Request) {
"timestamp": time.Now().UTC(),
}

// Snowflake status
if s.app.Snowflake != nil {
// Warehouse status
if s.app.Warehouse != nil && s.app.Warehouse.DB() != nil {
ctx, cancel := context.WithTimeout(r.Context(), s.healthCheckTimeout())
start := time.Now()
err := s.app.Snowflake.Ping(ctx)
err := s.app.Warehouse.DB().PingContext(ctx)
cancel()
latency := time.Since(start).Milliseconds()

if err != nil {
health["snowflake"] = map[string]interface{}{
health["warehouse"] = map[string]interface{}{
"status": "unhealthy",
"error": err.Error(),
"latency_ms": latency,
}
} else {
health["snowflake"] = map[string]interface{}{
health["warehouse"] = map[string]interface{}{
"status": "healthy",
"latency_ms": latency,
}
}
} else {
health["snowflake"] = map[string]interface{}{"status": "not_configured"}
health["warehouse"] = map[string]interface{}{"status": "not_configured"}
}

// Findings stats
Expand Down
Loading