diff --git a/docs/CONFIG_ENV_VARS.md b/docs/CONFIG_ENV_VARS.md index 13d0ed227..df9b3a899 100644 --- a/docs/CONFIG_ENV_VARS.md +++ b/docs/CONFIG_ENV_VARS.md @@ -182,7 +182,7 @@ Total variables: **369** | `JIRA_CLOSE_TRANSITIONS` | `getEnv` | `"Done,Closed,Resolve Issue"` | `JiraCloseTransitions` | `-` | | `JIRA_EMAIL` | `getEnv` | `""` | `JiraEmail` | `-` | | `JIRA_PROJECT` | `getEnv` | `"SEC"` | `JiraProject` | `-` | -| `JOB_DATABASE_URL` | `getEnv` | `""` | `JobDatabaseURL` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` | +| `JOB_DATABASE_URL` | `getEnv` | `""` | `-` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` | | `JOB_MAX_ATTEMPTS` | `getEnvInt` | `3` | `JobMaxAttempts` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` | | `JOB_NATS_CONSUMER` | `getEnv` | `"job-worker"` | `JobNATSConsumer` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` | | `JOB_NATS_STREAM` | `getEnv` | `"CEREBRO_JOBS"` | `JobNATSStream` | `when JOB_DATABASE_URL is configured, NATS settings must be present and worker timing/count controls must be positive` | @@ -352,7 +352,7 @@ Total variables: **369** | `VAULT_TOKEN` | `getEnv` | `""` | `VaultToken` | `-` | | `VULNDB_STATE_FILE` | `getEnv` | `filepath.Join(".cerebro", "vulndb.db")` | `VulnDBStateFile` | `-` | | `WAREHOUSE_BACKEND` | `getEnv` | `defaultWarehouseBackend` | `WarehouseBackend` | `backend-specific connection settings must be present when an alternative warehouse backend is selected`, `must be one of snowflake, sqlite, postgres` | -| `WAREHOUSE_POSTGRES_DSN` | `getEnv` | `""` | `WarehousePostgresDSN` | `backend-specific connection settings must be present when an alternative warehouse backend is selected` | +| `WAREHOUSE_POSTGRES_DSN` | 
`getEnv` | `""` | `-` | `backend-specific connection settings must be present when an alternative warehouse backend is selected` | | `WAREHOUSE_SQLITE_PATH` | `getEnv` | `defaultWarehouseSQLitePath` | `WarehouseSQLitePath` | `backend-specific connection settings must be present when an alternative warehouse backend is selected` | | `WEBHOOK_URLS` | `getEnv` | `""` | `WebhookURLs` | `-` | | `WIZ_API_URL` | `getEnv` | `""` | `WizAPIURL` | `-` | diff --git a/internal/agents/session_store_postgres.go b/internal/agents/session_store_postgres.go new file mode 100644 index 000000000..73d5158e9 --- /dev/null +++ b/internal/agents/session_store_postgres.go @@ -0,0 +1,161 @@ +package agents + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "time" +) + +const postgresSessionTable = "cerebro_agent_sessions" + +type PostgresSessionStore struct { + db *sql.DB + rewriteSQL func(string) string +} + +func NewPostgresSessionStore(db *sql.DB) *PostgresSessionStore { + return &PostgresSessionStore{db: db} +} + +func (s *PostgresSessionStore) EnsureSchema(ctx context.Context) error { + if s == nil || s.db == nil { + return fmt.Errorf("postgres session store is not initialized") + } + _, err := s.db.ExecContext(ctx, s.q(` +CREATE TABLE IF NOT EXISTS `+postgresSessionTable+` ( + id TEXT PRIMARY KEY, + agent_id TEXT NOT NULL, + user_id TEXT NOT NULL, + status TEXT NOT NULL, + messages TEXT NOT NULL DEFAULT '[]', + context TEXT NOT NULL DEFAULT '{}', + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_`+postgresSessionTable+`_updated_at ON `+postgresSessionTable+` (updated_at); +`)) + return err +} + +func (s *PostgresSessionStore) Save(ctx context.Context, session *Session) error { + if s == nil || s.db == nil { + return fmt.Errorf("postgres session store is not initialized") + } + if session == nil { + return fmt.Errorf("session is required") + } + if err := s.EnsureSchema(ctx); err != nil { + return err + } + 
+ messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + return err + } + if len(messagesJSON) == 0 { + messagesJSON = []byte("[]") + } + + contextJSON, err := json.Marshal(session.Context) + if err != nil { + return err + } + if len(contextJSON) == 0 { + contextJSON = []byte("{}") + } + + createdAt := session.CreatedAt.UTC() + if createdAt.IsZero() { + createdAt = time.Now().UTC() + session.CreatedAt = createdAt + } + updatedAt := session.UpdatedAt.UTC() + if updatedAt.IsZero() { + updatedAt = createdAt + session.UpdatedAt = updatedAt + } + + _, err = s.db.ExecContext(ctx, s.q(` +INSERT INTO `+postgresSessionTable+` ( + id, agent_id, user_id, status, messages, context, created_at, updated_at +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +ON CONFLICT (id) DO UPDATE SET + agent_id = EXCLUDED.agent_id, + user_id = EXCLUDED.user_id, + status = EXCLUDED.status, + messages = EXCLUDED.messages, + context = EXCLUDED.context, + updated_at = EXCLUDED.updated_at +`), + session.ID, + session.AgentID, + session.UserID, + session.Status, + string(messagesJSON), + string(contextJSON), + createdAt, + updatedAt, + ) + return err +} + +func (s *PostgresSessionStore) Get(ctx context.Context, id string) (*Session, error) { + if s == nil || s.db == nil { + return nil, fmt.Errorf("postgres session store is not initialized") + } + if err := s.EnsureSchema(ctx); err != nil { + return nil, err + } + + var session Session + var messagesRaw string + var contextRaw string + + err := s.db.QueryRowContext(ctx, s.q(` +SELECT id, agent_id, user_id, status, messages, context, created_at, updated_at +FROM `+postgresSessionTable+` +WHERE id = $1 +`), id).Scan( + &session.ID, + &session.AgentID, + &session.UserID, + &session.Status, + &messagesRaw, + &contextRaw, + &session.CreatedAt, + &session.UpdatedAt, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, err + } + + if messagesJSON := normalizeVariantJSON(messagesRaw); len(messagesJSON) > 0 
{ + if err := json.Unmarshal(messagesJSON, &session.Messages); err != nil { + return nil, err + } + } + if contextJSON := normalizeVariantJSON(contextRaw); len(contextJSON) > 0 { + if err := json.Unmarshal(contextJSON, &session.Context); err != nil { + return nil, err + } + } + + session.CreatedAt = session.CreatedAt.UTC() + session.UpdatedAt = session.UpdatedAt.UTC() + return &session, nil +} + +func (s *PostgresSessionStore) q(query string) string { + if s != nil && s.rewriteSQL != nil { + return s.rewriteSQL(query) + } + return query +} + +var _ SessionStore = (*PostgresSessionStore)(nil) diff --git a/internal/agents/session_store_postgres_test.go b/internal/agents/session_store_postgres_test.go new file mode 100644 index 000000000..ddd5c2a67 --- /dev/null +++ b/internal/agents/session_store_postgres_test.go @@ -0,0 +1,109 @@ +package agents + +import ( + "context" + "database/sql" + "regexp" + "testing" + "time" + + _ "modernc.org/sqlite" +) + +var sessionStoreDollarPlaceholderRe = regexp.MustCompile(`\$\d+`) + +func sessionStoreSQLiteRewrite(query string) string { + return sessionStoreDollarPlaceholderRe.ReplaceAllString(query, "?") +} + +func newTestPostgresSessionStore(t *testing.T) *PostgresSessionStore { + t.Helper() + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + store := &PostgresSessionStore{ + db: db, + rewriteSQL: sessionStoreSQLiteRewrite, + } + if err := store.EnsureSchema(context.Background()); err != nil { + _ = db.Close() + t.Fatalf("EnsureSchema() error = %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return store +} + +func TestPostgresSessionStoreSaveAndGet(t *testing.T) { + store := newTestPostgresSessionStore(t) + now := time.Now().UTC().Truncate(time.Second) + session := &Session{ + ID: "session-1", + AgentID: "agent-1", + UserID: "user-1", + Status: "active", + Messages: []Message{ + {Role: "user", Content: "investigate bucket"}, + }, + Context: SessionContext{ + FindingIDs: 
[]string{"finding-1"}, + Metadata: map[string]interface{}{"tenant": "tenant-a"}, + }, + CreatedAt: now, + UpdatedAt: now.Add(time.Minute), + } + + if err := store.Save(context.Background(), session); err != nil { + t.Fatalf("Save() error = %v", err) + } + + got, err := store.Get(context.Background(), session.ID) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got == nil { + t.Fatal("expected persisted session") + } + if got.UserID != session.UserID { + t.Fatalf("UserID = %q, want %q", got.UserID, session.UserID) + } + if len(got.Messages) != 1 || got.Messages[0].Content != "investigate bucket" { + t.Fatalf("unexpected messages: %#v", got.Messages) + } + if len(got.Context.FindingIDs) != 1 || got.Context.FindingIDs[0] != "finding-1" { + t.Fatalf("unexpected context finding ids: %#v", got.Context.FindingIDs) + } +} + +func TestPostgresSessionStoreSaveUpdatesExistingRow(t *testing.T) { + store := newTestPostgresSessionStore(t) + session := &Session{ + ID: "session-2", + AgentID: "agent-1", + UserID: "user-1", + Status: "active", + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + if err := store.Save(context.Background(), session); err != nil { + t.Fatalf("initial Save() error = %v", err) + } + + session.Status = "completed" + session.Messages = []Message{{Role: "assistant", Content: "done"}} + session.UpdatedAt = session.UpdatedAt.Add(2 * time.Minute) + if err := store.Save(context.Background(), session); err != nil { + t.Fatalf("update Save() error = %v", err) + } + + got, err := store.Get(context.Background(), session.ID) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.Status != "completed" { + t.Fatalf("Status = %q, want completed", got.Status) + } + if len(got.Messages) != 1 || got.Messages[0].Content != "done" { + t.Fatalf("unexpected updated messages: %#v", got.Messages) + } +} diff --git a/internal/agents/session_store_snowflake.go b/internal/agents/session_store_snowflake.go index 99d9a7cb9..7d0c7bd7c 100644 
--- a/internal/agents/session_store_snowflake.go +++ b/internal/agents/session_store_snowflake.go @@ -114,6 +114,44 @@ func (s *SnowflakeSessionStore) Get(ctx context.Context, id string) (*Session, e row := s.db.QueryRowContext(ctx, query, id) + session, err := scanSnowflakeSession(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return session, nil +} + +func (s *SnowflakeSessionStore) ListAll(ctx context.Context) ([]*Session, error) { + // #nosec G202 -- s.tableRef is validated by snowflake.SafeTableRef before interpolation. + query := ` + SELECT id, agent_id, user_id, status, messages, context, created_at, updated_at + FROM ` + s.tableRef + ` + ORDER BY updated_at DESC + ` + + rows, err := s.db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + sessions := make([]*Session, 0) + for rows.Next() { + session, err := scanSnowflakeSession(rows) + if err != nil { + return nil, err + } + sessions = append(sessions, session) + } + return sessions, rows.Err() +} + +func scanSnowflakeSession(row interface { + Scan(dest ...any) error +}) (*Session, error) { var session Session var messagesRaw any var contextRaw any @@ -131,9 +169,6 @@ func (s *SnowflakeSessionStore) Get(ctx context.Context, id string) (*Session, e &updatedAt, ) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } return nil, err } diff --git a/internal/api/server.go b/internal/api/server.go index 37e9cdbb5..f51132f22 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -628,28 +628,28 @@ func (s *Server) adminHealth(w http.ResponseWriter, r *http.Request) { "timestamp": time.Now().UTC(), } - // Snowflake status - if s.app.Snowflake != nil { + // Warehouse status + if s.app.Warehouse != nil && s.app.Warehouse.DB() != nil { ctx, cancel := context.WithTimeout(r.Context(), s.healthCheckTimeout()) start := time.Now() - err := s.app.Snowflake.Ping(ctx) + err := 
s.app.Warehouse.DB().PingContext(ctx) cancel() latency := time.Since(start).Milliseconds() if err != nil { - health["snowflake"] = map[string]interface{}{ + health["warehouse"] = map[string]interface{}{ "status": "unhealthy", "error": err.Error(), "latency_ms": latency, } } else { - health["snowflake"] = map[string]interface{}{ + health["warehouse"] = map[string]interface{}{ "status": "healthy", "latency_ms": latency, } } } else { - health["snowflake"] = map[string]interface{}{"status": "not_configured"} + health["warehouse"] = map[string]interface{}{"status": "not_configured"} } // Findings stats diff --git a/internal/api/server_dependencies.go b/internal/api/server_dependencies.go index 4c69e14fc..f5a21d38a 100644 --- a/internal/api/server_dependencies.go +++ b/internal/api/server_dependencies.go @@ -98,10 +98,19 @@ type serverDependencies struct { Notifications *notifications.Manager Scheduler *scheduler.Scheduler - AuditRepo *snowflake.AuditRepository - PolicyHistoryRepo *snowflake.PolicyHistoryRepository - RiskEngineStateRepo *snowflake.RiskEngineStateRepository - ScanWatermarks *scanner.WatermarkStore + AuditRepo interface { + Log(ctx context.Context, entry *snowflake.AuditEntry) error + List(ctx context.Context, resourceType, resourceID string, limit int) ([]*snowflake.AuditEntry, error) + } + PolicyHistoryRepo interface { + Upsert(ctx context.Context, record *snowflake.PolicyHistoryRecord) error + List(ctx context.Context, policyID string, limit int) ([]*snowflake.PolicyHistoryRecord, error) + } + RiskEngineStateRepo interface { + SaveSnapshot(ctx context.Context, graphID string, snapshot []byte) error + LoadSnapshot(ctx context.Context, graphID string) ([]byte, error) + } + ScanWatermarks *scanner.WatermarkStore RBAC *auth.RBAC ThreatIntel *threatintel.ThreatIntelService @@ -145,7 +154,7 @@ type serverDependencies struct { type graphRuntimeAdapter struct { deps *serverDependencies - fallback graphRuntimeService + delegate graphRuntimeService logger 
*slog.Logger originalGraph *graph.Graph originalBuilder *builders.Builder @@ -200,7 +209,7 @@ func newServerDependenciesFromApp(application *app.App) serverDependencies { } deps.graphRuntime = &graphRuntimeAdapter{ deps: &deps, - fallback: application, + delegate: application, logger: application.Logger, originalGraph: application.SecurityGraph, originalBuilder: application.SecurityGraphBuilder, @@ -438,7 +447,7 @@ func (r *graphRuntimeAdapter) useLocalGraph() bool { if r == nil || r.deps == nil { return false } - if r.fallback != nil { + if r.delegate != nil { return r.originalBuilder == nil && r.deps.SecurityGraphBuilder != nil && (r.deps.SecurityGraph != nil || r.deps.SecurityGraphBuilder != nil) } return r.deps.SecurityGraph != nil || r.deps.SecurityGraphBuilder != nil @@ -454,8 +463,8 @@ func (r *graphRuntimeAdapter) CurrentSecurityGraph() *graph.Graph { if r.useLocalGraph() && r.deps != nil { return r.deps.SecurityGraph } - if r.fallback != nil { - return r.fallback.CurrentSecurityGraph() + if r.delegate != nil { + return r.delegate.CurrentSecurityGraph() } if r.deps != nil { return r.deps.SecurityGraph @@ -473,7 +482,7 @@ func (r *graphRuntimeAdapter) CurrentSecurityGraphStore() graph.GraphStore { if r.useLocalGraph() && r.deps != nil && r.deps.SecurityGraph != nil { return r.deps.SecurityGraph } - if scoped, ok := r.fallback.(interface { + if scoped, ok := r.delegate.(interface { CurrentSecurityGraphStore() graph.GraphStore }); ok { return scoped.CurrentSecurityGraphStore() @@ -491,7 +500,7 @@ func (r *graphRuntimeAdapter) CurrentSecurityGraphForTenant(tenantID string) *gr } return r.deps.SecurityGraph.SubgraphForTenant(tenantID) } - if scoped, ok := r.fallback.(interface { + if scoped, ok := r.delegate.(interface { CurrentSecurityGraphForTenant(string) *graph.Graph }); ok { return scoped.CurrentSecurityGraphForTenant(tenantID) @@ -517,7 +526,7 @@ func (r *graphRuntimeAdapter) CurrentSecurityGraphStoreForTenant(tenantID string } return 
r.deps.SecurityGraph.SubgraphForTenant(tenantID) } - if scoped, ok := r.fallback.(interface { + if scoped, ok := r.delegate.(interface { CurrentSecurityGraphStoreForTenant(string) graph.GraphStore }); ok { return scoped.CurrentSecurityGraphStoreForTenant(tenantID) @@ -532,8 +541,8 @@ func (r *graphRuntimeAdapter) GraphBuildSnapshot() app.GraphBuildSnapshot { if r == nil { return app.GraphBuildSnapshot{} } - if !r.useLocalGraph() && r.fallback != nil { - return r.fallback.GraphBuildSnapshot() + if !r.useLocalGraph() && r.delegate != nil { + return r.delegate.GraphBuildSnapshot() } r.snapshotMu.RLock() snapshot := r.snapshot @@ -545,17 +554,17 @@ func (r *graphRuntimeAdapter) GraphBuildSnapshot() app.GraphBuildSnapshot { } func (r *graphRuntimeAdapter) CurrentRetentionStatus() app.RetentionStatus { - if r == nil || r.fallback == nil { + if r == nil || r.delegate == nil { return app.RetentionStatus{} } - return r.fallback.CurrentRetentionStatus() + return r.delegate.CurrentRetentionStatus() } func (r *graphRuntimeAdapter) GraphFreshnessStatusSnapshot(now time.Time) app.GraphFreshnessStatus { - if r == nil || r.fallback == nil { + if r == nil || r.delegate == nil { return app.GraphFreshnessStatus{} } - return r.fallback.GraphFreshnessStatusSnapshot(now) + return r.delegate.GraphFreshnessStatusSnapshot(now) } func (r *graphRuntimeAdapter) GraphHealthSnapshot(now time.Time) app.GraphHealthSnapshot { @@ -567,7 +576,7 @@ func (r *graphRuntimeAdapter) GraphHealthSnapshot(now time.Time) app.GraphHealth } else { now = now.UTC() } - if provider, ok := r.fallback.(interface { + if provider, ok := r.delegate.(interface { GraphHealthSnapshot(time.Time) app.GraphHealthSnapshot }); ok { snapshot := provider.GraphHealthSnapshot(now) @@ -611,10 +620,10 @@ func (r *graphRuntimeAdapter) RebuildSecurityGraph(ctx context.Context) error { if r == nil { return errors.New("security graph runtime not configured") } - if !r.useLocalGraph() && r.fallback != nil { - err := 
r.fallback.RebuildSecurityGraph(ctx) + if !r.useLocalGraph() && r.delegate != nil { + err := r.delegate.RebuildSecurityGraph(ctx) if r.deps != nil { - r.deps.SecurityGraph = r.fallback.CurrentSecurityGraph() + r.deps.SecurityGraph = r.delegate.CurrentSecurityGraph() } return err } @@ -650,7 +659,7 @@ func (r *graphRuntimeAdapter) CanApplySecurityGraphChanges() bool { if r.useLocalGraph() { return r.localBuilder() != nil } - if capable, ok := r.fallback.(graphChangeApplyCapability); ok { + if capable, ok := r.delegate.(graphChangeApplyCapability); ok { return capable.CanApplySecurityGraphChanges() } return r.originalBuilder != nil @@ -660,10 +669,10 @@ func (r *graphRuntimeAdapter) TryApplySecurityGraphChanges(ctx context.Context, if r == nil { return graph.GraphMutationSummary{}, false, errors.New("security graph runtime not configured") } - if !r.useLocalGraph() && r.fallback != nil { - summary, applied, err := r.fallback.TryApplySecurityGraphChanges(ctx, trigger) + if !r.useLocalGraph() && r.delegate != nil { + summary, applied, err := r.delegate.TryApplySecurityGraphChanges(ctx, trigger) if r.deps != nil { - r.deps.SecurityGraph = r.fallback.CurrentSecurityGraph() + r.deps.SecurityGraph = r.delegate.CurrentSecurityGraph() } return summary, applied, err } @@ -679,27 +688,13 @@ func (r *graphRuntimeAdapter) TryApplySecurityGraphChanges(ctx context.Context, summary, err := builder.ApplyChanges(ctx, time.Time{}) if err != nil { if r.logger != nil { - r.logger.Warn("incremental graph apply failed, falling back to full rebuild", + r.logger.Warn("incremental graph apply failed", "trigger", trigger, "error", err, ) } - r.setSnapshot(app.GraphBuildBuilding, time.Time{}, nil) - if buildErr := builder.Build(ctx); buildErr != nil { - r.setSnapshot(app.GraphBuildFailed, time.Now().UTC(), buildErr) - return graph.GraphMutationSummary{}, true, buildErr - } - summary = builder.LastMutation() - graphValue := builder.Graph() - if r.deps != nil { - r.deps.SecurityGraph = 
graphValue - } - builtAt := time.Now().UTC() - if graphValue != nil { - builtAt = graphValue.Metadata().BuiltAt - } - r.setSnapshot(app.GraphBuildSuccess, builtAt, nil) - return summary, true, nil + r.setSnapshot(app.GraphBuildFailed, time.Now().UTC(), err) + return graph.GraphMutationSummary{}, true, err } if r.deps != nil { @@ -742,7 +737,7 @@ func (r *graphRuntimeAdapter) detachedLocalGraph() *graph.Graph { if r == nil || r.deps == nil || r.deps.SecurityGraph == nil || r.deps.SecurityGraphBuilder == nil { return nil } - if r.fallback != nil && r.originalBuilder != nil { + if r.delegate != nil && r.originalBuilder != nil { return nil } return r.deps.SecurityGraph diff --git a/internal/api/server_dependencies_test.go b/internal/api/server_dependencies_test.go index 64209e7ef..0abb1de09 100644 --- a/internal/api/server_dependencies_test.go +++ b/internal/api/server_dependencies_test.go @@ -193,7 +193,7 @@ func TestGraphRuntimeAdapterCanApplySecurityGraphChangesAfterFallbackRefresh(t * deps := &serverDependencies{} runtime := &graphRuntimeAdapter{ deps: deps, - fallback: fallback, + delegate: fallback, } if !runtime.CanApplySecurityGraphChanges() { @@ -261,7 +261,7 @@ func TestGraphRuntimeAdapterGraphHealthSnapshotRecalculatesMemoryEstimate(t *tes SecurityGraph: localGraph, SecurityGraphBuilder: &builders.Builder{}, }, - fallback: stubGraphRuntime{ + delegate: stubGraphRuntime{ graph: providerGraph, healthSnapshot: app.GraphHealthSnapshot{ MemoryUsageEstimateBytes: app.EstimateGraphMemoryUsageBytes(providerGraph.NodeCount(), providerGraph.EdgeCount()), @@ -289,7 +289,7 @@ func TestGraphRuntimeAdapterGraphHealthSnapshotEmptyLocalGraphIsNotHot(t *testin SecurityGraph: graph.New(), SecurityGraphBuilder: &builders.Builder{}, }, - fallback: stubGraphRuntime{}, + delegate: stubGraphRuntime{}, } snapshot := runtime.GraphHealthSnapshot(now) @@ -333,7 +333,7 @@ func TestGraphRuntimeAdapterGraphHealthSnapshotFallbackUsesLocalBuilderLastMutat SecurityGraph: localGraph, 
SecurityGraphBuilder: builder, }, - fallback: &mutatingFallbackGraphRuntime{current: graph.New()}, + delegate: &mutatingFallbackGraphRuntime{current: graph.New()}, } snapshot := runtime.GraphHealthSnapshot(now) diff --git a/internal/api/server_fault_matrix_test.go b/internal/api/server_fault_matrix_test.go index b730a12d6..8baf961bb 100644 --- a/internal/api/server_fault_matrix_test.go +++ b/internal/api/server_fault_matrix_test.go @@ -64,7 +64,7 @@ func TestAuditEndpoint_ReturnsDegradedResponseWithoutAuditRepo(t *testing.T) { if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } - if !strings.Contains(w.Body.String(), "snowflake not configured") { + if !strings.Contains(w.Body.String(), "audit log persistence not configured") { t.Fatalf("expected degraded audit response, got %s", w.Body.String()) } } diff --git a/internal/api/server_handlers_data.go b/internal/api/server_handlers_data.go index 9fe5f196f..c4e272528 100644 --- a/internal/api/server_handlers_data.go +++ b/internal/api/server_handlers_data.go @@ -19,7 +19,7 @@ import ( func (s *Server) syncStatus(w http.ResponseWriter, r *http.Request) { if s.app.Warehouse == nil { - s.error(w, http.StatusServiceUnavailable, "snowflake not configured") + s.error(w, http.StatusServiceUnavailable, "warehouse not configured") return } diff --git a/internal/api/server_handlers_graph_intelligence_test.go b/internal/api/server_handlers_graph_intelligence_test.go index d4bf79379..9e429d718 100644 --- a/internal/api/server_handlers_graph_intelligence_test.go +++ b/internal/api/server_handlers_graph_intelligence_test.go @@ -614,7 +614,7 @@ func TestGraphIntelligenceQualityEndpoint_InvalidParams(t *testing.T) { func TestGraphIntelligenceAgentActionEffectivenessEndpoint(t *testing.T) { s := newTestServer(t) g := s.app.SecurityGraph - now := time.Date(2026, 3, 22, 18, 0, 0, 0, time.UTC) + now := time.Now().UTC().Add(-2 * time.Hour).Truncate(time.Second) g.AddNode(&graph.Node{ ID: 
"thread:evaluation:run-1:conv-1", @@ -959,6 +959,7 @@ func TestPlatformIntelligenceEvaluationTemporalAnalysisReportDefinition(t *testi func TestGraphIntelligencePlaybookEffectivenessEndpoint(t *testing.T) { s := newTestServer(t) + startedAt := time.Now().UTC().Add(-90 * time.Minute).Truncate(time.Second) addPlaybookEffectivenessEndpointFixture(t, s.app.SecurityGraph, playbookEffectivenessEndpointFixture{ RunID: "run-a1", PlaybookID: "pb-remediate", @@ -966,15 +967,15 @@ func TestGraphIntelligencePlaybookEffectivenessEndpoint(t *testing.T) { TenantID: "tenant-acme", TargetID: "service:payments", TargetKind: graph.NodeKindService, - StartedAt: time.Date(2026, 3, 23, 15, 0, 0, 0, time.UTC), + StartedAt: startedAt, Stages: []playbookEffectivenessEndpointStage{ - {ID: "approve", Name: "Approve Fix", Order: 1, Status: "completed", ApprovalRequired: true, ApprovalStatus: "approved", ObservedAt: time.Date(2026, 3, 23, 15, 10, 0, 0, time.UTC)}, + {ID: "approve", Name: "Approve Fix", Order: 1, Status: "completed", ApprovalRequired: true, ApprovalStatus: "approved", ObservedAt: startedAt.Add(10 * time.Minute)}, }, Outcome: &playbookEffectivenessEndpointOutcome{ Verdict: "positive", Status: "completed", RollbackState: "stable", - ObservedAt: time.Date(2026, 3, 23, 15, 40, 0, 0, time.UTC), + ObservedAt: startedAt.Add(40 * time.Minute), }, }) @@ -1032,11 +1033,12 @@ func TestPlatformIntelligencePlaybookEffectivenessReportDefinition(t *testing.T) func TestGraphIntelligenceUnifiedExecutionTimelineEndpoint(t *testing.T) { s := newTestServer(t) + baseAt := time.Now().UTC().Add(-24 * time.Hour).Truncate(time.Second) addEvaluationTemporalAnalysisEndpointFixture(t, s.app.SecurityGraph, evaluationTemporalAnalysisEndpointFixture{ RunID: "run-1", Conversation: "conv-1", ServiceID: "service:payments:conv-1", - BaseAt: time.Date(2026, 3, 23, 16, 0, 0, 0, time.UTC), + BaseAt: baseAt, }) tagEvaluationTemporalAnalysisEndpointTenant(t, s.app.SecurityGraph, "run-1", "conv-1", "tenant-acme") @@ 
-1047,15 +1049,15 @@ func TestGraphIntelligenceUnifiedExecutionTimelineEndpoint(t *testing.T) { TenantID: "tenant-acme", TargetID: "database:orders", TargetKind: graph.NodeKind("database"), - StartedAt: time.Date(2026, 3, 23, 17, 0, 0, 0, time.UTC), + StartedAt: baseAt.Add(time.Hour), Stages: []playbookEffectivenessEndpointStage{ - {ID: "repair", Name: "Repair", Order: 1, Status: "completed", ObservedAt: time.Date(2026, 3, 23, 17, 10, 0, 0, time.UTC)}, + {ID: "repair", Name: "Repair", Order: 1, Status: "completed", ObservedAt: baseAt.Add(70 * time.Minute)}, }, Outcome: &playbookEffectivenessEndpointOutcome{ Verdict: "positive", Status: "completed", RollbackState: "stable", - ObservedAt: time.Date(2026, 3, 23, 17, 30, 0, 0, time.UTC), + ObservedAt: baseAt.Add(90 * time.Minute), }, }) s.app.SecurityGraph.AddNode(&graph.Node{ @@ -1074,8 +1076,8 @@ func TestGraphIntelligenceUnifiedExecutionTimelineEndpoint(t *testing.T) { "tenant_id": "tenant-acme", "target_ids": []string{"database:orders"}, "source_system": "platform_playbook", - "observed_at": time.Date(2026, 3, 23, 17, 20, 0, 0, time.UTC).Format(time.RFC3339), - "valid_from": time.Date(2026, 3, 23, 17, 20, 0, 0, time.UTC).Format(time.RFC3339), + "observed_at": baseAt.Add(80 * time.Minute).Format(time.RFC3339), + "valid_from": baseAt.Add(80 * time.Minute).Format(time.RFC3339), }, }) diff --git a/internal/api/server_handlers_identity_attack_webhooks.go b/internal/api/server_handlers_identity_attack_webhooks.go index fb2f76424..e9a5b7e20 100644 --- a/internal/api/server_handlers_identity_attack_webhooks.go +++ b/internal/api/server_handlers_identity_attack_webhooks.go @@ -504,7 +504,7 @@ func (s *Server) listAuditLogs(w http.ResponseWriter, r *http.Request) { s.json(w, http.StatusOK, map[string]interface{}{ "logs": []interface{}{}, "count": 0, - "message": "snowflake not configured", + "message": "audit log persistence not configured", "pagination": PaginationResponse{Limit: limit, Offset: offset, HasMore: false}, }) 
return diff --git a/internal/api/server_handlers_sync.go b/internal/api/server_handlers_sync.go index 0552a117f..794a25f82 100644 --- a/internal/api/server_handlers_sync.go +++ b/internal/api/server_handlers_sync.go @@ -18,8 +18,8 @@ import ( "github.com/aws/aws-sdk-go-v2/service/organizations" orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/writer/cerebro/internal/snowflake" nativesync "github.com/writer/cerebro/internal/sync" + "github.com/writer/cerebro/internal/warehouse" "golang.org/x/sync/errgroup" ) @@ -37,7 +37,7 @@ func (s *Server) backfillRelationshipIDs(w http.ResponseWriter, r *http.Request) stats, err := s.syncHandlers.BackfillRelationshipIDs(r.Context(), req.BatchSize) if err != nil { - if errors.Is(err, errSyncSnowflakeUnavailable) { + if errors.Is(err, errSyncWarehouseUnavailable) { s.error(w, http.StatusServiceUnavailable, err.Error()) return } @@ -57,7 +57,7 @@ type azureSyncRequest struct { Validate bool `json:"validate"` } -var runAzureSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req azureSyncRequest) ([]nativesync.SyncResult, error) { +var runAzureSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req azureSyncRequest) ([]nativesync.SyncResult, error) { opts := []nativesync.AzureEngineOption{} switch len(req.Subscriptions) { case 1: @@ -118,7 +118,7 @@ func (s *Server) syncAzure(w http.ResponseWriter, r *http.Request) { result, err := s.syncHandlers.SyncAzure(r.Context(), req) if err != nil { - if errors.Is(err, errSyncSnowflakeUnavailable) { + if errors.Is(err, errSyncWarehouseUnavailable) { s.error(w, http.StatusServiceUnavailable, err.Error()) return } @@ -146,7 +146,7 @@ type k8sSyncRequest struct { Validate bool `json:"validate"` } -var runK8sSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req k8sSyncRequest) ([]nativesync.SyncResult, error) { +var runK8sSyncWithOptions = func(ctx 
context.Context, client warehouse.SyncWarehouse, req k8sSyncRequest) ([]nativesync.SyncResult, error) { opts := []nativesync.K8sEngineOption{} if req.Kubeconfig != "" { opts = append(opts, nativesync.WithK8sKubeconfig(req.Kubeconfig)) @@ -194,7 +194,7 @@ func (s *Server) syncK8s(w http.ResponseWriter, r *http.Request) { result, err := s.syncHandlers.SyncK8s(r.Context(), req) if err != nil { - if errors.Is(err, errSyncSnowflakeUnavailable) { + if errors.Is(err, errSyncWarehouseUnavailable) { s.error(w, http.StatusServiceUnavailable, err.Error()) return } @@ -232,7 +232,7 @@ type awsSyncOutcome struct { RelationshipsSkippedReason string } -var runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { +var runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { loadOptions := make([]func(*config.LoadOptions) error, 0, 2) if req.Profile != "" { loadOptions = append(loadOptions, config.WithSharedConfigProfile(req.Profile)) @@ -318,7 +318,7 @@ func (s *Server) syncAWS(w http.ResponseWriter, r *http.Request) { result, err := s.syncHandlers.SyncAWS(r.Context(), req) if err != nil { - if errors.Is(err, errSyncSnowflakeUnavailable) { + if errors.Is(err, errSyncWarehouseUnavailable) { s.error(w, http.StatusServiceUnavailable, err.Error()) return } @@ -364,7 +364,7 @@ type awsOrgSyncOutcome struct { AccountErrors []string } -var runAWSOrgSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsOrgSyncRequest) (*awsOrgSyncOutcome, error) { +var runAWSOrgSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsOrgSyncRequest) (*awsOrgSyncOutcome, error) { loadOptions := make([]func(*config.LoadOptions) error, 0, 2) if req.Profile != "" { loadOptions = append(loadOptions, config.WithSharedConfigProfile(req.Profile)) @@ -494,7 +494,7 @@ func (s *Server) syncAWSOrg(w http.ResponseWriter, r 
*http.Request) { result, err := s.syncHandlers.SyncAWSOrg(r.Context(), req) if err != nil { - if errors.Is(err, errSyncSnowflakeUnavailable) { + if errors.Is(err, errSyncWarehouseUnavailable) { s.error(w, http.StatusServiceUnavailable, err.Error()) return } @@ -660,7 +660,7 @@ type gcpSyncOutcome struct { RelationshipsSkippedReason string } -var runGCPSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req gcpSyncRequest) (*gcpSyncOutcome, error) { +var runGCPSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req gcpSyncRequest) (*gcpSyncOutcome, error) { if req.Project == "" { return nil, fmt.Errorf("project is required") } @@ -722,7 +722,7 @@ func (s *Server) syncGCP(w http.ResponseWriter, r *http.Request) { result, err := s.syncHandlers.SyncGCP(r.Context(), req) if err != nil { - if errors.Is(err, errSyncSnowflakeUnavailable) { + if errors.Is(err, errSyncWarehouseUnavailable) { s.error(w, http.StatusServiceUnavailable, err.Error()) return } @@ -767,7 +767,7 @@ type gcpAssetSyncRequest struct { Validate bool `json:"validate"` } -var runGCPAssetSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req gcpAssetSyncRequest) ([]nativesync.SyncResult, error) { +var runGCPAssetSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req gcpAssetSyncRequest) ([]nativesync.SyncResult, error) { organization := strings.TrimSpace(req.Organization) if len(req.Projects) == 0 && organization == "" { return nil, fmt.Errorf("projects or organization are required") @@ -823,7 +823,7 @@ func (s *Server) syncGCPAsset(w http.ResponseWriter, r *http.Request) { result, err := s.syncHandlers.SyncGCPAsset(r.Context(), req) if err != nil { - if errors.Is(err, errSyncSnowflakeUnavailable) { + if errors.Is(err, errSyncWarehouseUnavailable) { s.error(w, http.StatusServiceUnavailable, err.Error()) return } diff --git a/internal/api/server_handlers_sync_test.go b/internal/api/server_handlers_sync_test.go index 
48798c677..40dd8194c 100644 --- a/internal/api/server_handlers_sync_test.go +++ b/internal/api/server_handlers_sync_test.go @@ -14,6 +14,7 @@ import ( "github.com/writer/cerebro/internal/graph/builders" "github.com/writer/cerebro/internal/snowflake" nativesync "github.com/writer/cerebro/internal/sync" + "github.com/writer/cerebro/internal/warehouse" ) type syncGraphSource struct { @@ -105,7 +106,7 @@ func TestSyncAzure_UsesRequestOptions(t *testing.T) { t.Cleanup(func() { runAzureSyncWithOptions = originalRun }) called := false - runAzureSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req azureSyncRequest) ([]nativesync.SyncResult, error) { + runAzureSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req azureSyncRequest) ([]nativesync.SyncResult, error) { called = true if client != s.app.Snowflake { t.Fatalf("expected server snowflake client to be passed through") @@ -167,7 +168,7 @@ func TestSyncAzure_NormalizesSubscriptionsCaseInsensitively(t *testing.T) { originalRun := runAzureSyncWithOptions t.Cleanup(func() { runAzureSyncWithOptions = originalRun }) - runAzureSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req azureSyncRequest) ([]nativesync.SyncResult, error) { + runAzureSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req azureSyncRequest) ([]nativesync.SyncResult, error) { if client != s.app.Snowflake { t.Fatalf("expected server snowflake client to be passed through") } @@ -222,7 +223,7 @@ func TestSyncK8s_UsesRequestOptions(t *testing.T) { t.Cleanup(func() { runK8sSyncWithOptions = originalRun }) called := false - runK8sSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req k8sSyncRequest) ([]nativesync.SyncResult, error) { + runK8sSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req k8sSyncRequest) ([]nativesync.SyncResult, error) { called = true if client != s.app.Snowflake { t.Fatalf("expected server snowflake 
client to be passed through") @@ -297,7 +298,7 @@ func TestSyncAWS_UsesRequestOptions(t *testing.T) { t.Cleanup(func() { runAWSSyncWithOptions = originalRun }) called := false - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { called = true if client != s.app.Snowflake { t.Fatalf("expected server snowflake client to be passed through") @@ -381,7 +382,7 @@ func TestSyncAWS_AppliesIncrementalGraphChangesAfterSync(t *testing.T) { originalRun := runAWSSyncWithOptions t.Cleanup(func() { runAWSSyncWithOptions = originalRun }) - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { return &awsSyncOutcome{ Results: []nativesync.SyncResult{{Table: "aws_s3_buckets", Synced: 1}}, }, nil @@ -418,7 +419,7 @@ func TestSyncAWS_GraphUpdateFailureIsSanitized(t *testing.T) { originalRun := runAWSSyncWithOptions t.Cleanup(func() { runAWSSyncWithOptions = originalRun }) - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { return &awsSyncOutcome{ Results: []nativesync.SyncResult{{Table: "aws_s3_buckets", Synced: 1}}, }, nil @@ -482,7 +483,7 @@ func TestSyncAWS_GraphUpdateBusyReturnsBusyStatus(t *testing.T) { originalRun := runAWSSyncWithOptions t.Cleanup(func() { runAWSSyncWithOptions = originalRun }) - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client 
warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { return &awsSyncOutcome{ Results: []nativesync.SyncResult{{Table: "aws_s3_buckets", Synced: 1}}, }, nil @@ -529,7 +530,7 @@ func TestSyncAWS_GraphUpdateNoopSummaryUsesEmptyTablesArray(t *testing.T) { originalRun := runAWSSyncWithOptions t.Cleanup(func() { runAWSSyncWithOptions = originalRun }) - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { return &awsSyncOutcome{ Results: []nativesync.SyncResult{{Table: "aws_s3_buckets", Synced: 1}}, }, nil @@ -563,7 +564,7 @@ func TestSyncAWS_GraphUpdateNoopSummaryUsesEmptyTablesArray(t *testing.T) { } } -func TestSyncAWS_FullRebuildFallbackReportsAppliedStatus(t *testing.T) { +func TestSyncAWS_GraphUpdateFailureReportsFailedStatus(t *testing.T) { s := newTestServer(t) s.app.Snowflake = &snowflake.Client{} @@ -574,7 +575,7 @@ func TestSyncAWS_FullRebuildFallbackReportsAppliedStatus(t *testing.T) { originalRun := runAWSSyncWithOptions t.Cleanup(func() { runAWSSyncWithOptions = originalRun }) - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { return &awsSyncOutcome{ Results: []nativesync.SyncResult{{Table: "aws_s3_buckets", Synced: 1}}, }, nil @@ -592,15 +593,17 @@ func TestSyncAWS_FullRebuildFallbackReportsAppliedStatus(t *testing.T) { if !ok { t.Fatalf("expected graph_update payload, got %#v", body["graph_update"]) } - if graphUpdate["status"] != "applied" { - t.Fatalf("expected graph update status applied after full rebuild fallback, got %#v", graphUpdate) + if graphUpdate["status"] != "failed" { + t.Fatalf("expected graph update status failed, got %#v", 
graphUpdate) } - summary, ok := graphUpdate["summary"].(map[string]any) - if !ok { - t.Fatalf("expected graph update summary, got %#v", graphUpdate["summary"]) + if graphUpdate["error"] != "graph update failed" { + t.Fatalf("expected sanitized graph update error, got %#v", graphUpdate["error"]) + } + if graphUpdate["error_code"] != "GRAPH_UPDATE_FAILED" { + t.Fatalf("expected graph update error code, got %#v", graphUpdate["error_code"]) } - if summary["mode"] != graph.GraphMutationModeFullRebuild { - t.Fatalf("expected full rebuild summary mode, got %#v", summary["mode"]) + if _, ok := graphUpdate["summary"]; ok { + t.Fatalf("expected failed graph update to omit summary, got %#v", graphUpdate["summary"]) } } @@ -630,7 +633,7 @@ func TestSyncAWS_AppliesGraphUpdateUsingRuntimeWithoutLocalBuilder(t *testing.T) originalRun := runAWSSyncWithOptions t.Cleanup(func() { runAWSSyncWithOptions = originalRun }) - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { return &awsSyncOutcome{ Results: []nativesync.SyncResult{{Table: "aws_s3_buckets", Synced: 1}}, }, nil @@ -674,7 +677,7 @@ func TestSyncAWS_SkipsGraphUpdateWithoutRuntimeOrLocalBuilder(t *testing.T) { originalRun := runAWSSyncWithOptions t.Cleanup(func() { runAWSSyncWithOptions = originalRun }) - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { return &awsSyncOutcome{ Results: []nativesync.SyncResult{{Table: "aws_s3_buckets", Synced: 1}}, }, nil @@ -706,7 +709,7 @@ func TestSyncAWS_SkipsGraphUpdateWhenRuntimeAdapterHasNoApplyCapability(t *testi originalRun := runAWSSyncWithOptions t.Cleanup(func() { 
runAWSSyncWithOptions = originalRun }) - runAWSSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsSyncRequest) (*awsSyncOutcome, error) { + runAWSSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsSyncRequest) (*awsSyncOutcome, error) { return &awsSyncOutcome{ Results: []nativesync.SyncResult{{Table: "aws_s3_buckets", Synced: 1}}, }, nil @@ -753,7 +756,7 @@ func TestSyncAWSOrg_UsesRequestOptions(t *testing.T) { t.Cleanup(func() { runAWSOrgSyncWithOptions = originalRun }) called := false - runAWSOrgSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req awsOrgSyncRequest) (*awsOrgSyncOutcome, error) { + runAWSOrgSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req awsOrgSyncRequest) (*awsOrgSyncOutcome, error) { called = true if client != s.app.Snowflake { t.Fatalf("expected server snowflake client to be passed through") @@ -863,7 +866,7 @@ func TestSyncGCP_UsesRequestOptions(t *testing.T) { t.Cleanup(func() { runGCPSyncWithOptions = originalRun }) called := false - runGCPSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req gcpSyncRequest) (*gcpSyncOutcome, error) { + runGCPSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req gcpSyncRequest) (*gcpSyncOutcome, error) { called = true if client != s.app.Snowflake { t.Fatalf("expected server snowflake client to be passed through") @@ -963,7 +966,7 @@ func TestSyncGCPAsset_UsesRequestOptions(t *testing.T) { t.Cleanup(func() { runGCPAssetSyncWithOptions = originalRun }) called := false - runGCPAssetSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req gcpAssetSyncRequest) ([]nativesync.SyncResult, error) { + runGCPAssetSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req gcpAssetSyncRequest) ([]nativesync.SyncResult, error) { called = true if client != s.app.Snowflake { t.Fatalf("expected server snowflake client to be passed 
through") @@ -1008,7 +1011,7 @@ func TestSyncGCPAsset_UsesOrganizationScope(t *testing.T) { t.Cleanup(func() { runGCPAssetSyncWithOptions = originalRun }) called := false - runGCPAssetSyncWithOptions = func(ctx context.Context, client *snowflake.Client, req gcpAssetSyncRequest) ([]nativesync.SyncResult, error) { + runGCPAssetSyncWithOptions = func(ctx context.Context, client warehouse.SyncWarehouse, req gcpAssetSyncRequest) ([]nativesync.SyncResult, error) { called = true if client != s.app.Snowflake { t.Fatalf("expected server snowflake client to be passed through") diff --git a/internal/api/server_services_sync.go b/internal/api/server_services_sync.go index a8242bf74..41cde772e 100644 --- a/internal/api/server_services_sync.go +++ b/internal/api/server_services_sync.go @@ -9,11 +9,11 @@ import ( "github.com/writer/cerebro/internal/app" "github.com/writer/cerebro/internal/graph" - "github.com/writer/cerebro/internal/snowflake" nativesync "github.com/writer/cerebro/internal/sync" + "github.com/writer/cerebro/internal/warehouse" ) -var errSyncSnowflakeUnavailable = errors.New("snowflake not configured") +var errSyncWarehouseUnavailable = errors.New("warehouse not configured") type syncHandlerService interface { BackfillRelationshipIDs(ctx context.Context, batchSize int) (syncBackfillResult, error) @@ -66,7 +66,7 @@ func newSyncHandlerService(deps *serverDependencies) syncHandlerService { } func (s serverSyncHandlerService) BackfillRelationshipIDs(ctx context.Context, batchSize int) (syncBackfillResult, error) { - client, err := s.snowflake() + client, err := s.warehouse() if err != nil { return syncBackfillResult{}, err } @@ -84,7 +84,7 @@ func (s serverSyncHandlerService) BackfillRelationshipIDs(ctx context.Context, b } func (s serverSyncHandlerService) SyncAzure(ctx context.Context, req azureSyncRequest) (syncRunResult, error) { - client, err := s.snowflake() + client, err := s.warehouse() if err != nil { return syncRunResult{}, err } @@ -99,7 +99,7 @@ func (s 
serverSyncHandlerService) SyncAzure(ctx context.Context, req azureSyncRe } func (s serverSyncHandlerService) SyncK8s(ctx context.Context, req k8sSyncRequest) (syncRunResult, error) { - client, err := s.snowflake() + client, err := s.warehouse() if err != nil { return syncRunResult{}, err } @@ -114,7 +114,7 @@ func (s serverSyncHandlerService) SyncK8s(ctx context.Context, req k8sSyncReques } func (s serverSyncHandlerService) SyncAWS(ctx context.Context, req awsSyncRequest) (awsSyncRunResult, error) { - client, err := s.snowflake() + client, err := s.warehouse() if err != nil { return awsSyncRunResult{}, err } @@ -134,7 +134,7 @@ func (s serverSyncHandlerService) SyncAWS(ctx context.Context, req awsSyncReques } func (s serverSyncHandlerService) SyncAWSOrg(ctx context.Context, req awsOrgSyncRequest) (awsOrgSyncRunResult, error) { - client, err := s.snowflake() + client, err := s.warehouse() if err != nil { return awsOrgSyncRunResult{}, err } @@ -153,7 +153,7 @@ func (s serverSyncHandlerService) SyncAWSOrg(ctx context.Context, req awsOrgSync } func (s serverSyncHandlerService) SyncGCP(ctx context.Context, req gcpSyncRequest) (gcpSyncRunResult, error) { - client, err := s.snowflake() + client, err := s.warehouse() if err != nil { return gcpSyncRunResult{}, err } @@ -173,7 +173,7 @@ func (s serverSyncHandlerService) SyncGCP(ctx context.Context, req gcpSyncReques } func (s serverSyncHandlerService) SyncGCPAsset(ctx context.Context, req gcpAssetSyncRequest) (syncRunResult, error) { - client, err := s.snowflake() + client, err := s.warehouse() if err != nil { return syncRunResult{}, err } @@ -187,11 +187,17 @@ func (s serverSyncHandlerService) SyncGCPAsset(ctx context.Context, req gcpAsset }, nil } -func (s serverSyncHandlerService) snowflake() (*snowflake.Client, error) { - if s.deps == nil || s.deps.Snowflake == nil { - return nil, errSyncSnowflakeUnavailable +func (s serverSyncHandlerService) warehouse() (warehouse.SyncWarehouse, error) { + if s.deps == nil { + return 
nil, errSyncWarehouseUnavailable } - return s.deps.Snowflake, nil + if s.deps.Warehouse != nil { + return s.deps.Warehouse, nil + } + if s.deps.Snowflake != nil { + return s.deps.Snowflake, nil + } + return nil, errSyncWarehouseUnavailable } func (s serverSyncHandlerService) logger() *slog.Logger { diff --git a/internal/app/app.go b/internal/app/app.go index 794023d69..f6b567b84 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -36,6 +36,7 @@ package app import ( "context" + "database/sql" "errors" "fmt" "log/slog" @@ -80,6 +81,21 @@ type retentionCleaner interface { CleanupAccessReviewData(ctx context.Context, olderThan time.Time) (reviewsDeleted, itemsDeleted int64, err error) } +type auditRepository interface { + Log(ctx context.Context, entry *snowflake.AuditEntry) error + List(ctx context.Context, resourceType, resourceID string, limit int) ([]*snowflake.AuditEntry, error) +} + +type policyHistoryRepository interface { + Upsert(ctx context.Context, record *snowflake.PolicyHistoryRecord) error + List(ctx context.Context, policyID string, limit int) ([]*snowflake.PolicyHistoryRecord, error) +} + +type riskEngineStateRepository interface { + SaveSnapshot(ctx context.Context, graphID string, snapshot []byte) error + LoadSnapshot(ctx context.Context, graphID string) ([]byte, error) +} + // App is the main application container that holds references to all initialized // services. Create a new App using the New() function which handles all service // initialization and wiring based on environment configuration. 
@@ -91,15 +107,17 @@ type App struct { Logger *slog.Logger // Core services - Snowflake *snowflake.Client - Warehouse warehouse.DataWarehouse - Policy *policy.Engine - Findings findings.FindingStore - Scanner *scanner.Scanner - DSPM *dspm.Scanner - Cache *cache.PolicyCache - ExecutionStore executionstore.Store - GraphSnapshots *graph.GraphPersistenceStore + Snowflake *snowflake.Client + LegacySnowflake *snowflake.Client + Warehouse warehouse.DataWarehouse + Policy *policy.Engine + Findings findings.FindingStore + Scanner *scanner.Scanner + DSPM *dspm.Scanner + Cache *cache.PolicyCache + ExecutionStore executionstore.Store + GraphSnapshots *graph.GraphPersistenceStore + appStateDB *sql.DB // Feature services Agents *agents.AgentRegistry @@ -116,15 +134,13 @@ type App struct { Notifications *notifications.Manager Scheduler *scheduler.Scheduler - // Repositories (for Snowflake persistence) - FindingsRepo *snowflake.FindingRepository - TicketsRepo *snowflake.TicketRepository - AuditRepo *snowflake.AuditRepository - PolicyHistoryRepo *snowflake.PolicyHistoryRepository - RiskEngineStateRepo *snowflake.RiskEngineStateRepository + // Durable app-state repositories. + AuditRepo auditRepository + PolicyHistoryRepo policyHistoryRepository + RiskEngineStateRepo riskEngineStateRepository RetentionRepo retentionCleaner - // Snowflake-backed stores (when available) + // Legacy Snowflake-backed findings store for deployments without Postgres app-state. 
SnowflakeFindings *findings.SnowflakeStore // Incremental scanning @@ -167,12 +183,6 @@ type App struct { graphWriterLeaseTransitionWG sync.WaitGroup tenantShardMu sync.Mutex tenantSecurityGraphShards *tenantGraphShardManager - passiveSnapshotStoreMu sync.RWMutex - passiveSnapshotStoreOwner *graph.GraphPersistenceStore - passiveSnapshotStoreSource string - passiveSnapshotStoreID string - passiveSnapshotStoreStatusID string - passiveSnapshotStore *graph.SnapshotGraphStore eventCorrelationRefreshQueue *eventCorrelationRefreshQueue eventCorrelationRefreshCancel context.CancelFunc eventCorrelationRefreshWG sync.WaitGroup diff --git a/internal/app/app_cerebro_tools_temporal_test.go b/internal/app/app_cerebro_tools_temporal_test.go index 4ed9e5127..c8fd4af9d 100644 --- a/internal/app/app_cerebro_tools_temporal_test.go +++ b/internal/app/app_cerebro_tools_temporal_test.go @@ -190,7 +190,7 @@ func TestCerebroTemporalAliasTools(t *testing.T) { } } -func TestCerebroEntityHistoryToolUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestCerebroEntityHistoryToolUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { base := time.Date(2026, 3, 10, 9, 0, 0, 0, time.UTC) g := graph.New() g.AddNode(&graph.Node{ @@ -219,9 +219,9 @@ func TestCerebroEntityHistoryToolUsesPersistedSnapshotWhenLiveGraphUnavailable(t } application := &App{ - GraphSnapshots: mustPersistToolGraph(t, g), - Config: &Config{}, + Config: &Config{}, } + setConfiguredSnapshotGraphFromGraph(t, application, g) tool := findCerebroTool(application.AgentSDKTools(), "cerebro.entity_history") if tool == nil { t.Fatal("expected cerebro.entity_history tool") @@ -346,8 +346,8 @@ func TestCerebroGraphChangelogTool(t *testing.T) { if got := toSnapshot["captured_at"]; got != latest.CreatedAt.Format(time.RFC3339) { t.Fatalf("expected newest changelog entry, got to=%#v", toSnapshot) } - if got := toSnapshot["current"]; got != true { - t.Fatalf("expected newest persisted snapshot to be current, got %#v", 
got) + if got := toSnapshot["current"]; got != nil { + t.Fatalf("expected no current snapshot marker without live/configured graph, got %#v", got) } summary := entry["summary"].(map[string]any) if summary["nodes_modified"] != float64(1) || summary["nodes_added"] != float64(0) { @@ -370,8 +370,8 @@ func TestCerebroGraphChangelogTool(t *testing.T) { t.Fatalf("expected filtered detail nodes_added=1, got %#v", detailSummary) } detailTo := detailPayload["to"].(map[string]any) - if got := detailTo["current"]; got != true { - t.Fatalf("expected detail target snapshot to be current, got %#v", got) + if got := detailTo["current"]; got != nil { + t.Fatalf("expected no detail current snapshot marker without live/configured graph, got %#v", got) } diffDir := filepath.Join(dir, "diffs") @@ -384,7 +384,7 @@ func TestCerebroGraphChangelogTool(t *testing.T) { } } -func TestCerebroGraphChangelogToolMarksPersistedSnapshotCurrentWithoutBuiltAt(t *testing.T) { +func TestCerebroGraphChangelogToolOmitsCurrentMarkerWithoutLiveGraph(t *testing.T) { dir := t.TempDir() t.Setenv("GRAPH_SNAPSHOT_PATH", dir) @@ -446,8 +446,8 @@ func TestCerebroGraphChangelogToolMarksPersistedSnapshotCurrentWithoutBuiltAt(t if got := toSnapshot["captured_at"]; got != latest.CreatedAt.Format(time.RFC3339) { t.Fatalf("expected latest persisted snapshot in changelog, got %#v", toSnapshot) } - if got := toSnapshot["current"]; got != true { - t.Fatalf("expected latest persisted snapshot to stay current without built_at, got %#v", got) + if got := toSnapshot["current"]; got != nil { + t.Fatalf("expected latest persisted snapshot to omit current marker without live graph, got %#v", got) } } diff --git a/internal/app/app_cerebro_tools_test.go b/internal/app/app_cerebro_tools_test.go index 059f6eecf..df48f8c4f 100644 --- a/internal/app/app_cerebro_tools_test.go +++ b/internal/app/app_cerebro_tools_test.go @@ -211,7 +211,7 @@ func TestCerebroBlastRadiusTool(t *testing.T) { } } -func 
TestCerebroAnalysisToolsUsePersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestCerebroAnalysisToolsUseConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { base := time.Date(2026, 3, 12, 10, 0, 0, 0, time.UTC) g := graph.New() g.AddNode(&graph.Node{ID: "user:alice", Kind: graph.NodeKindUser, Name: "Alice"}) @@ -263,9 +263,9 @@ func TestCerebroAnalysisToolsUsePersistedSnapshotWhenLiveGraphUnavailable(t *tes g.BuildIndex() application := &App{ - GraphSnapshots: mustPersistToolGraph(t, g), - Identity: identity.NewService(identity.WithGraphResolver(func(context.Context) *graph.Graph { return nil })), + Identity: identity.NewService(identity.WithGraphResolver(func(context.Context) *graph.Graph { return nil })), } + setConfiguredSnapshotGraphFromGraph(t, application, g) tests := []struct { name string tool string @@ -1328,22 +1328,11 @@ func TestCerebroEvaluatePolicyTool(t *testing.T) { } } -func TestCerebroEvaluatePolicyToolUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { - store, err := graph.NewGraphPersistenceStore(graph.GraphPersistenceOptions{ - LocalPath: filepath.Join(t.TempDir(), "graph-snapshots"), - MaxSnapshots: 4, - }) - if err != nil { - t.Fatalf("NewGraphPersistenceStore() error = %v", err) - } - if _, err := store.SaveGraph(graph.New()); err != nil { - t.Fatalf("SaveGraph() error = %v", err) - } - +func TestCerebroEvaluatePolicyToolUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { application := &App{ - Policy: policy.NewEngine(), - GraphSnapshots: store, + Policy: policy.NewEngine(), } + setConfiguredSnapshotGraphFromGraph(t, application, graph.New()) tool := findCerebroTool(application.AgentSDKTools(), "evaluate_policy") if tool == nil { t.Fatal("expected evaluate_policy tool") @@ -1356,7 +1345,7 @@ func TestCerebroEvaluatePolicyToolUsesPersistedSnapshotWhenLiveGraphUnavailable( "proposed_change":{ "id":"chg-1", "source":"tool", - "reason":"test snapshot fallback", + "reason":"test configured graph base", 
"nodes":[{"action":"add","node":{"id":"service:payments","kind":"service","name":"payments"}}] } }`)) @@ -1615,7 +1604,7 @@ func TestCerebroAccessReviewTool(t *testing.T) { } } -func TestCerebroAccessReviewToolUsesTenantScopedPersistedSnapshot(t *testing.T) { +func TestCerebroAccessReviewToolUsesTenantScopedConfiguredStore(t *testing.T) { g := graph.New() g.AddNode(&graph.Node{ID: "user:alice", Kind: graph.NodeKindUser, Name: "Alice"}) g.AddNode(&graph.Node{ID: "bucket:tenant-a", Kind: graph.NodeKindBucket, Name: "Tenant A Bucket", TenantID: "tenant-a", Risk: graph.RiskHigh}) @@ -1623,7 +1612,8 @@ func TestCerebroAccessReviewToolUsesTenantScopedPersistedSnapshot(t *testing.T) g.AddEdge(&graph.Edge{ID: "alice-tenant-a", Source: "user:alice", Target: "bucket:tenant-a", Kind: graph.EdgeKindCanRead, Effect: graph.EdgeEffectAllow}) g.AddEdge(&graph.Edge{ID: "alice-tenant-b", Source: "user:alice", Target: "bucket:tenant-b", Kind: graph.EdgeKindCanRead, Effect: graph.EdgeEffectAllow}) - application := &App{GraphSnapshots: mustPersistToolGraph(t, g)} + application := &App{} + setConfiguredSnapshotGraphFromGraph(t, application, g) tool := findCerebroTool(application.cerebroTools(), "cerebro.access_review") if tool == nil { t.Fatal("expected access_review tool") @@ -1736,7 +1726,7 @@ func TestCerebroAutonomousCredentialResponseTool_AwaitingApproval(t *testing.T) } } -func TestCerebroAutonomousCredentialResponseTool_UsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestCerebroAutonomousCredentialResponseTool_UsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { dir := t.TempDir() store, err := executionstore.NewSQLiteStore(filepath.Join(dir, "executions.db")) if err != nil { @@ -1747,8 +1737,8 @@ func TestCerebroAutonomousCredentialResponseTool_UsesPersistedSnapshotWhenLiveGr application := &App{ Config: &Config{ExecutionStoreFile: filepath.Join(dir, "executions.db")}, ExecutionStore: store, - GraphSnapshots: mustPersistToolGraph(t, 
autonomousCredentialWorkflowGraph()), } + setConfiguredGraphFromGraph(t, application, autonomousCredentialWorkflowGraph()) tool := findCerebroTool(application.cerebroTools(), "cerebro.autonomous_credential_response") if tool == nil { t.Fatal("expected autonomous credential response tool") @@ -1788,10 +1778,10 @@ func TestCerebroAutonomousCredentialResponseTool_UsesPersistedSnapshotWhenLiveGr current := application.CurrentSecurityGraph() if current == nil { - t.Fatal("expected persisted snapshot base to hydrate a live graph during mutation") + t.Fatal("expected configured graph base to hydrate a live graph during mutation") } if _, ok := current.GetNode("secret:public-repo:1"); !ok { - t.Fatal("expected original persisted secret node to remain present") + t.Fatal("expected original configured secret node to remain present") } if _, ok := current.GetNode(run.ObservationID); !ok { t.Fatalf("expected observation node %q", run.ObservationID) diff --git a/internal/app/app_config.go b/internal/app/app_config.go index 2775a024f..2f9996fb3 100644 --- a/internal/app/app_config.go +++ b/internal/app/app_config.go @@ -557,8 +557,12 @@ func LoadConfig() *Config { snowflakeAccount := getEnv("SNOWFLAKE_ACCOUNT", "") snowflakeUser := getEnv("SNOWFLAKE_USER", "") snowflakePrivateKey := normalizePrivateKey(getEnv("SNOWFLAKE_PRIVATE_KEY", "")) + warehousePostgresDSN := getEnv("WAREHOUSE_POSTGRES_DSN", "") + jobDatabaseURL := getEnv("JOB_DATABASE_URL", "") defaultWarehouseBackend := "sqlite" - if strings.TrimSpace(snowflakeAccount) != "" || strings.TrimSpace(snowflakeUser) != "" || strings.TrimSpace(snowflakePrivateKey) != "" { + if strings.TrimSpace(warehousePostgresDSN) != "" { + defaultWarehouseBackend = "postgres" + } else if strings.TrimSpace(snowflakeAccount) != "" || strings.TrimSpace(snowflakeUser) != "" || strings.TrimSpace(snowflakePrivateKey) != "" { defaultWarehouseBackend = "snowflake" } defaultWarehouseSQLitePath := filepath.Join(filepath.Dir(findings.DefaultFilePath()), 
"warehouse.db") @@ -589,7 +593,7 @@ func LoadConfig() *Config { CredentialVaultKVVersion: bootstrapConfigInt("CEREBRO_CREDENTIAL_VAULT_KV_VERSION", 2), WarehouseBackend: strings.ToLower(strings.TrimSpace(getEnv("WAREHOUSE_BACKEND", defaultWarehouseBackend))), WarehouseSQLitePath: getEnv("WAREHOUSE_SQLITE_PATH", defaultWarehouseSQLitePath), - WarehousePostgresDSN: getEnv("WAREHOUSE_POSTGRES_DSN", ""), + WarehousePostgresDSN: warehousePostgresDSN, SnowflakeAccount: snowflakeAccount, SnowflakeUser: snowflakeUser, SnowflakePrivateKey: snowflakePrivateKey, @@ -860,7 +864,7 @@ func LoadConfig() *Config { FindingAttestationLogURL: getEnv("FINDING_ATTESTATION_LOG_URL", ""), FindingAttestationTimeout: getEnvDuration("FINDING_ATTESTATION_TIMEOUT", 3*time.Second), FindingAttestationAttestReobserved: getEnvBool("FINDING_ATTESTATION_ATTEST_REOBSERVED", false), - JobDatabaseURL: getEnv("JOB_DATABASE_URL", ""), + JobDatabaseURL: jobDatabaseURL, JobNATSStream: getEnv("JOB_NATS_STREAM", "CEREBRO_JOBS"), JobNATSSubject: getEnv("JOB_NATS_SUBJECT", "cerebro.jobs"), JobNATSConsumer: getEnv("JOB_NATS_CONSUMER", "job-worker"), diff --git a/internal/app/app_config_validation.go b/internal/app/app_config_validation.go index 43478de37..8f0f596e2 100644 --- a/internal/app/app_config_validation.go +++ b/internal/app/app_config_validation.go @@ -386,8 +386,8 @@ func (c *Config) Validate() error { } if strings.EqualFold(strings.TrimSpace(c.WarehouseBackend), "postgres") { - if strings.TrimSpace(c.WarehousePostgresDSN) == "" { - problems = addConfigProblem(problems, "WAREHOUSE_POSTGRES_DSN is required when WAREHOUSE_BACKEND=postgres") + if strings.TrimSpace(c.WarehousePostgresDSN) == "" && strings.TrimSpace(c.JobDatabaseURL) == "" { + problems = addConfigProblem(problems, "WAREHOUSE_POSTGRES_DSN or JOB_DATABASE_URL is required when WAREHOUSE_BACKEND=postgres") } } diff --git a/internal/app/app_dspm_graph_test.go b/internal/app/app_dspm_graph_test.go index 5e1e1674e..4628fa8fe 100644 --- 
a/internal/app/app_dspm_graph_test.go +++ b/internal/app/app_dspm_graph_test.go @@ -85,7 +85,7 @@ func TestEnrichSecurityGraphWithDSPMResult_UsesCopyOnWriteForLiveGraph(t *testin } } -func TestEnrichSecurityGraphWithDSPMResult_UsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestEnrichSecurityGraphWithDSPMResult_UsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { logger := testutil.Logger() nodeID := "arn:aws:s3:::customer-card-bucket" @@ -101,10 +101,10 @@ func TestEnrichSecurityGraphWithDSPMResult_UsesPersistedSnapshotWhenLiveGraphUna base.BuildIndex() app := &App{ - Logger: logger, - DSPM: dspm.NewScanner(nil, logger, dspm.DefaultScannerConfig()), - GraphSnapshots: mustPersistToolGraph(t, base), + Logger: logger, + DSPM: dspm.NewScanner(nil, logger, dspm.DefaultScannerConfig()), } + setConfiguredGraphFromGraph(t, app, base) app.enrichSecurityGraphWithDSPMResult(&dspm.ScanTarget{ Provider: "aws", @@ -123,7 +123,7 @@ func TestEnrichSecurityGraphWithDSPMResult_UsesPersistedSnapshotWhenLiveGraphUna current := app.CurrentSecurityGraph() if current == nil { - t.Fatal("expected persisted snapshot base to hydrate a live graph during DSPM enrichment") + t.Fatal("expected configured graph base to hydrate a live graph during DSPM enrichment") } if !current.IsIndexBuilt() { t.Fatal("expected enriched live graph index to be rebuilt") @@ -142,7 +142,7 @@ func TestEnrichSecurityGraphWithDSPMResult_UsesPersistedSnapshotWhenLiveGraphUna t.Fatalf("expected original base node %q to exist", nodeID) } if _, exists := baseNode.Properties["dspm_scanned"]; exists { - t.Fatal("expected persisted base graph value to remain unchanged in memory") + t.Fatal("expected original base graph value to remain unchanged in memory") } } diff --git a/internal/app/app_event_remediation_test.go b/internal/app/app_event_remediation_test.go index ad43c35b0..3af2d467c 100644 --- a/internal/app/app_event_remediation_test.go +++ b/internal/app/app_event_remediation_test.go @@ 
-129,7 +129,7 @@ func TestStartEventRemediation_PropagationCanGateExecution(t *testing.T) { } } -func TestStartEventRemediation_PropagationUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestStartEventRemediation_PropagationUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) engine := remediation.NewEngine(logger) if err := engine.AddRule(remediation.Rule{ @@ -145,15 +145,6 @@ func TestStartEventRemediation_PropagationUsesPersistedSnapshotWhenLiveGraphUnav t.Fatalf("failed to add remediation rule: %v", err) } - dir := t.TempDir() - store, err := graph.NewGraphPersistenceStore(graph.GraphPersistenceOptions{ - LocalPath: dir, - MaxSnapshots: 4, - }) - if err != nil { - t.Fatalf("new graph persistence store: %v", err) - } - g := graph.New() g.AddNode(&graph.Node{ ID: "customer-1", @@ -165,9 +156,6 @@ func TestStartEventRemediation_PropagationUsesPersistedSnapshotWhenLiveGraphUnav BuiltAt: time.Date(2026, 3, 19, 15, 30, 0, 0, time.UTC), NodeCount: 1, }) - if _, err := store.SaveGraph(g); err != nil { - t.Fatalf("save graph snapshot: %v", err) - } hooks := webhooks.NewServiceForTesting() capture := &captureNotifier{} @@ -180,8 +168,8 @@ func TestStartEventRemediation_PropagationUsesPersistedSnapshotWhenLiveGraphUnav Remediation: engine, Notifications: notifier, RemediationExecutor: remediation.NewExecutor(engine, nil, notifier, nil, hooks), - GraphSnapshots: store, } + setConfiguredSnapshotGraphFromGraph(t, app, g) app.startEventRemediation(context.Background()) if err := hooks.EmitWithErrors(context.Background(), webhooks.EventSignalCreated, map[string]any{ @@ -196,17 +184,17 @@ func TestStartEventRemediation_PropagationUsesPersistedSnapshotWhenLiveGraphUnav if execution.RuleID == "signal-snapshot-graph-action" { found = true if execution.Status != remediation.ExecutionPending { - t.Fatalf("expected execution to remain %q when persisted propagation gates execution, got %q", 
remediation.ExecutionPending, execution.Status) + t.Fatalf("expected execution to remain %q when configured propagation gates execution, got %q", remediation.ExecutionPending, execution.Status) } break } } if !found { - t.Fatal("expected snapshot-backed graph-action remediation execution") + t.Fatal("expected configured graph-action remediation execution") } if len(capture.events) == 0 { - t.Fatal("expected persisted propagation gating to emit review notification") + t.Fatal("expected configured propagation gating to emit review notification") } if capture.events[0].Type != notifications.EventReviewRequired { t.Fatalf("expected review required notification, got %q", capture.events[0].Type) diff --git a/internal/app/app_event_routing_test.go b/internal/app/app_event_routing_test.go index 3cd4e0a64..2f8b545f7 100644 --- a/internal/app/app_event_routing_test.go +++ b/internal/app/app_event_routing_test.go @@ -6,22 +6,20 @@ import ( "github.com/writer/cerebro/internal/graph" ) -func TestCurrentOrStoredEventRoutingGraphUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { - persisted := graph.New() - persisted.AddNode(&graph.Node{ID: "service:payments", Kind: graph.NodeKindService, Name: "payments"}) - persisted.BuildIndex() +func TestCurrentOrStoredEventRoutingGraphUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { + configured := graph.New() + configured.AddNode(&graph.Node{ID: "service:payments", Kind: graph.NodeKindService, Name: "payments"}) + configured.BuildIndex() - application := &App{ - Config: &Config{}, - GraphSnapshots: mustPersistToolGraph(t, persisted), - } + application := &App{Config: &Config{}} + setConfiguredSnapshotGraphFromGraph(t, application, configured) resolved := application.currentOrStoredEventRoutingGraph() if resolved == nil { - t.Fatal("expected persisted snapshot graph for event routing") + t.Fatal("expected configured graph for event routing") } if _, ok := resolved.GetNode("service:payments"); !ok { - t.Fatal("expected 
persisted snapshot node to be available for event routing") + t.Fatal("expected configured graph node to be available for event routing") } } @@ -35,10 +33,10 @@ func TestCurrentOrStoredEventRoutingGraphPrefersLiveGraph(t *testing.T) { persisted.BuildIndex() application := &App{ - Config: &Config{}, - SecurityGraph: live, - GraphSnapshots: mustPersistToolGraph(t, persisted), + Config: &Config{}, + SecurityGraph: live, } + setConfiguredSnapshotGraphFromGraph(t, application, persisted) resolved := application.currentOrStoredEventRoutingGraph() if resolved != live { @@ -48,6 +46,6 @@ func TestCurrentOrStoredEventRoutingGraphPrefersLiveGraph(t *testing.T) { t.Fatal("expected live graph node to be available for event routing") } if _, ok := resolved.GetNode("service:stored"); ok { - t.Fatal("expected persisted snapshot graph to be ignored when live graph is present") + t.Fatal("expected configured graph to be ignored when live graph is present") } } diff --git a/internal/app/app_graph_backend.go b/internal/app/app_graph_backend.go index cdf06b343..0b5223d59 100644 --- a/internal/app/app_graph_backend.go +++ b/internal/app/app_graph_backend.go @@ -57,5 +57,7 @@ func (a *App) probeConfiguredSecurityGraphStore(ctx context.Context, store graph if err != nil { return false, fmt.Errorf("probe configured graph store edges: %w", err) } - return nodes > 0 || edges > 0, nil + _ = nodes + _ = edges + return true, nil } diff --git a/internal/app/app_graph_helpers_test.go b/internal/app/app_graph_helpers_test.go new file mode 100644 index 000000000..c9a63e648 --- /dev/null +++ b/internal/app/app_graph_helpers_test.go @@ -0,0 +1,47 @@ +package app + +import ( + "context" + "testing" + + "github.com/writer/cerebro/internal/graph" +) + +func mustSnapshotGraphStore(t *testing.T, g *graph.Graph) *graph.SnapshotGraphStore { + t.Helper() + snapshot, err := g.Snapshot(context.Background()) + if err != nil { + t.Fatalf("Snapshot() error = %v", err) + } + return 
graph.NewSnapshotGraphStore(snapshot) +} + +func mustConfiguredGraphStore(t *testing.T, g *graph.Graph) *graph.Graph { + t.Helper() + store := g.Clone() + store.BuildIndex() + return store +} + +func setConfiguredGraphStore(t *testing.T, application *App, store graph.GraphStore) { + t.Helper() + if application == nil { + t.Fatal("expected app") + } + application.configuredSecurityGraphStore = store + application.configuredSecurityGraphReady = true +} + +func setConfiguredGraphFromGraph(t *testing.T, application *App, g *graph.Graph) graph.GraphStore { + t.Helper() + store := mustConfiguredGraphStore(t, g) + setConfiguredGraphStore(t, application, store) + return store +} + +func setConfiguredSnapshotGraphFromGraph(t *testing.T, application *App, g *graph.Graph) *graph.SnapshotGraphStore { + t.Helper() + store := mustSnapshotGraphStore(t, g) + setConfiguredGraphStore(t, application, store) + return store +} diff --git a/internal/app/app_graph_status.go b/internal/app/app_graph_status.go index 46a445552..ca2380212 100644 --- a/internal/app/app_graph_status.go +++ b/internal/app/app_graph_status.go @@ -56,7 +56,9 @@ func (a *App) setGraphBuildState(state GraphBuildState, builtAt time.Time, err e func (a *App) CurrentSecurityGraph() *graph.Graph { if current := a.currentLiveSecurityGraph(); current != nil { - return current + if current.NodeCount() > 0 || current.EdgeCount() > 0 { + return current + } } if a == nil { return nil @@ -64,14 +66,7 @@ func (a *App) CurrentSecurityGraph() *graph.Graph { if view, err := a.currentConfiguredSecurityGraphView(context.Background()); err == nil && view != nil { return view } - view, err := a.storedSecurityGraphViewWithSnapshotLoader(func(store *graph.GraphPersistenceStore) (*graph.Snapshot, error) { - snapshot, _, _, err := store.PeekLatestSnapshot() - return snapshot, err - }) - if err != nil { - return nil - } - return view + return a.currentLiveSecurityGraph() } func (a *App) CurrentSecurityGraphForTenant(tenantID string) 
*graph.Graph { diff --git a/internal/app/app_graph_store.go b/internal/app/app_graph_store.go index 05b3f66e7..2f633bda7 100644 --- a/internal/app/app_graph_store.go +++ b/internal/app/app_graph_store.go @@ -41,7 +41,7 @@ func (a *App) CurrentSecurityGraphStore() graph.GraphStore { if a == nil { return nil } - layers := make([]graphStoreLayer, 0, 3) + layers := make([]graphStoreLayer, 0, 2) layers = append(layers, graphStoreLayer{ writable: true, resolve: func(ctx context.Context) (graph.GraphStore, error) { @@ -54,11 +54,6 @@ func (a *App) CurrentSecurityGraphStore() graph.GraphStore { return resolveCurrentGraphStore(ctx, a.currentLiveSecurityGraph()) }, }) - layers = append(layers, graphStoreLayer{ - resolve: func(ctx context.Context) (graph.GraphStore, error) { - return a.currentPassiveSnapshotStore(ctx) - }, - }) return tieredGraphStore{layers: layers} } @@ -95,15 +90,6 @@ func (a *App) CurrentSecurityGraphStoreForTenant(tenantID string) graph.GraphSto return a.currentWarmTenantGraphStore(ctx, tenantID) }, }, - { - resolve: func(ctx context.Context) (graph.GraphStore, error) { - store, err := a.currentPassiveSnapshotGraphStore(ctx) - if err != nil { - return nil, err - } - return passiveResolver.Resolve(ctx, store) - }, - }, }, } } @@ -435,54 +421,6 @@ func (r *tenantGraphStoreResolver) storeResult(current *graph.Graph, version uin r.unavailable = unavailable } -func (a *App) currentPassiveSnapshotStore(ctx context.Context) (graph.GraphStore, error) { - if err := graphStoreContextErr(ctx); err != nil { - return nil, err - } - if a == nil || a.GraphSnapshots == nil { - return nil, graph.ErrStoreUnavailable - } - snapshotStore := a.GraphSnapshots - if cached := a.cachedPassiveSnapshotStore(snapshotStore); cached != nil { - return cached, nil - } - a.passiveSnapshotStoreMu.Lock() - defer a.passiveSnapshotStoreMu.Unlock() - if cached := a.cachedPassiveSnapshotStoreLocked(snapshotStore); cached != nil { - return cached, nil - } - snapshot, record, source, err := 
snapshotStore.PeekLatestSnapshot() - if err != nil { - if isNoSnapshotsGraphStoreErr(err) { - return nil, graph.ErrStoreUnavailable - } - return nil, err - } - if snapshot == nil { - return nil, graph.ErrStoreUnavailable - } - store := graph.NewSnapshotGraphStore(snapshot) - status := snapshotStore.Status() - a.passiveSnapshotStoreOwner = snapshotStore - a.passiveSnapshotStoreSource = strings.TrimSpace(source) - a.passiveSnapshotStoreID = passiveSnapshotStoreCacheID(a.passiveSnapshotStoreSource, record, status) - a.passiveSnapshotStoreStatusID = passiveSnapshotStoreCacheID(a.passiveSnapshotStoreSource, nil, status) - a.passiveSnapshotStore = store - return store, nil -} - -func (a *App) currentPassiveSnapshotGraphStore(ctx context.Context) (*graph.SnapshotGraphStore, error) { - store, err := a.currentPassiveSnapshotStore(ctx) - if err != nil { - return nil, err - } - snapshotStore, ok := store.(*graph.SnapshotGraphStore) - if !ok || snapshotStore == nil { - return nil, graph.ErrStoreUnavailable - } - return snapshotStore, nil -} - func (a *App) currentConfiguredSnapshotGraphStore(ctx context.Context) (*graph.SnapshotGraphStore, error) { snapshot, err := a.currentConfiguredSecurityGraphSnapshot(ctx) if err != nil { @@ -575,43 +513,3 @@ func graphStoreContextErr(ctx context.Context) error { } return ctx.Err() } - -func isNoSnapshotsGraphStoreErr(err error) bool { - if err == nil { - return false - } - return strings.Contains(strings.ToLower(err.Error()), "no snapshots found") -} - -func (a *App) cachedPassiveSnapshotStore(store *graph.GraphPersistenceStore) *graph.SnapshotGraphStore { - if a == nil || store == nil { - return nil - } - a.passiveSnapshotStoreMu.RLock() - defer a.passiveSnapshotStoreMu.RUnlock() - return a.cachedPassiveSnapshotStoreLocked(store) -} - -func (a *App) cachedPassiveSnapshotStoreLocked(store *graph.GraphPersistenceStore) *graph.SnapshotGraphStore { - if a == nil || store == nil || a.passiveSnapshotStore == nil { - return nil - } - if 
a.passiveSnapshotStoreOwner != store { - return nil - } - cacheID := passiveSnapshotStoreCacheID(a.passiveSnapshotStoreSource, nil, store.Status()) - if cacheID != "" && cacheID != a.passiveSnapshotStoreID && cacheID != a.passiveSnapshotStoreStatusID { - return nil - } - return a.passiveSnapshotStore -} - -func passiveSnapshotStoreCacheID(source string, record *graph.GraphSnapshotRecord, status graph.GraphPersistenceStatus) string { - if record != nil { - if id := strings.TrimSpace(record.ID); id != "" { - return id - } - } - _ = source - return strings.TrimSpace(status.LastPersistedSnapshot) -} diff --git a/internal/app/app_graph_store_test.go b/internal/app/app_graph_store_test.go index 51dd4429f..4e45eb9bd 100644 --- a/internal/app/app_graph_store_test.go +++ b/internal/app/app_graph_store_test.go @@ -7,7 +7,6 @@ import ( "reflect" "testing" "time" - "unsafe" "github.com/writer/cerebro/internal/graph" ) @@ -53,26 +52,21 @@ func TestCurrentSecurityGraphStoreTracksLiveGraphSwap(t *testing.T) { } } -func TestCurrentSecurityGraphStoreUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { - persisted := graph.New() - persisted.AddNode(&graph.Node{ID: "service:persisted", Kind: graph.NodeKindService}) - persisted.BuildIndex() +func TestCurrentSecurityGraphStoreUsesConfiguredBackendWhenLiveGraphUnavailable(t *testing.T) { + configured := graph.New() + configured.AddNode(&graph.Node{ID: "service:configured", Kind: graph.NodeKindService}) + configured.BuildIndex() - store := mustPersistToolGraph(t, persisted) - application := &App{GraphSnapshots: store} + application := &App{} + setConfiguredGraphFromGraph(t, application, configured) graphStore := application.CurrentSecurityGraphStore() if graphStore == nil { t.Fatal("expected graph store wrapper") } - if _, ok, err := graphStore.LookupNode(context.Background(), "service:persisted"); err != nil || !ok { - t.Fatalf("LookupNode(persisted) = (%v, %v), want present; err=%v", ok, err, err) - } - - status := 
store.Status() - if status.LastRecoveredSnapshot != "" || status.LastRecoverySource != "" { - t.Fatalf("expected passive snapshot read to avoid recovery bookkeeping, got %+v", status) + if _, ok, err := graphStore.LookupNode(context.Background(), "service:configured"); err != nil || !ok { + t.Fatalf("LookupNode(configured) = (%v, %v), want present; err=%v", ok, err, err) } } @@ -102,85 +96,6 @@ func TestCurrentSecurityGraphStorePrefersConfiguredBackendWhenReady(t *testing.T } } -func TestCurrentPassiveSnapshotStoreReusesCachedStoreUntilSnapshotChanges(t *testing.T) { - first := graph.New() - first.AddNode(&graph.Node{ID: "service:first", Kind: graph.NodeKindService}) - first.BuildIndex() - - snapshots := mustPersistToolGraph(t, first) - application := &App{GraphSnapshots: snapshots} - - firstStore, err := application.currentPassiveSnapshotStore(context.Background()) - if err != nil { - t.Fatalf("currentPassiveSnapshotStore() first error = %v", err) - } - if _, ok, err := firstStore.LookupNode(context.Background(), "service:first"); err != nil || !ok { - t.Fatalf("LookupNode(service:first) = (%v, %v), want present; err=%v", ok, err, err) - } - - secondStore, err := application.currentPassiveSnapshotStore(context.Background()) - if err != nil { - t.Fatalf("currentPassiveSnapshotStore() second error = %v", err) - } - if firstStore != secondStore { - t.Fatal("expected passive snapshot store to be reused while snapshot is unchanged") - } - - next := graph.New() - next.AddNode(&graph.Node{ID: "service:next", Kind: graph.NodeKindService}) - next.BuildIndex() - if _, err := snapshots.SaveGraph(next); err != nil { - t.Fatalf("SaveGraph(next) error = %v", err) - } - - thirdStore, err := application.currentPassiveSnapshotStore(context.Background()) - if err != nil { - t.Fatalf("currentPassiveSnapshotStore() third error = %v", err) - } - if thirdStore == secondStore { - t.Fatal("expected passive snapshot store to refresh after snapshot changes") - } - if _, ok, err := 
thirdStore.LookupNode(context.Background(), "service:next"); err != nil || !ok { - t.Fatalf("LookupNode(service:next) = (%v, %v), want present; err=%v", ok, err, err) - } -} - -func TestCurrentPassiveSnapshotStoreReusesCachedStoreWhileStatusCatchesUp(t *testing.T) { - source := graph.New() - source.AddNode(&graph.Node{ID: "service:cached", Kind: graph.NodeKindService}) - source.BuildIndex() - - snapshots := mustPersistToolGraph(t, source) - application := &App{GraphSnapshots: snapshots} - - caughtUpStatus := snapshots.Status() - laggingStatus := caughtUpStatus - laggingStatus.LastPersistedSnapshot = "stale-snapshot-id" - setGraphPersistenceStoreStatus(t, snapshots, laggingStatus) - - firstStore, err := application.currentPassiveSnapshotStore(context.Background()) - if err != nil { - t.Fatalf("currentPassiveSnapshotStore() first error = %v", err) - } - secondStore, err := application.currentPassiveSnapshotStore(context.Background()) - if err != nil { - t.Fatalf("currentPassiveSnapshotStore() second error = %v", err) - } - if firstStore != secondStore { - t.Fatal("expected passive snapshot cache to survive lagging status identifier") - } - - setGraphPersistenceStoreStatus(t, snapshots, caughtUpStatus) - - thirdStore, err := application.currentPassiveSnapshotStore(context.Background()) - if err != nil { - t.Fatalf("currentPassiveSnapshotStore() third error = %v", err) - } - if firstStore != thirdStore { - t.Fatal("expected passive snapshot cache to survive status catch-up") - } -} - func TestTenantGraphStoreResolverCachesScopedStoreUntilSourceChanges(t *testing.T) { resolver := &tenantGraphStoreResolver{tenantID: "tenant-a"} source := buildTenantShardTestGraph(time.Date(2026, time.March, 17, 23, 0, 0, 0, time.UTC)) @@ -373,20 +288,10 @@ func TestCurrentSecurityGraphStoreForTenantScopesAndTracksLiveGraphSwap(t *testi } } -func setGraphPersistenceStoreStatus(t *testing.T, store *graph.GraphPersistenceStore, status graph.GraphPersistenceStatus) { - t.Helper() - value := 
reflect.ValueOf(store) - if !value.IsValid() || value.IsNil() { - t.Fatal("expected graph persistence store") - } - field := value.Elem().FieldByName("status") - reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(status)) -} - -func TestCurrentSecurityGraphStoreForTenantUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestCurrentSecurityGraphStoreForTenantUsesConfiguredBackendWhenLiveGraphUnavailable(t *testing.T) { source := buildTenantShardTestGraph(time.Date(2026, time.March, 17, 23, 0, 0, 0, time.UTC)) - store := mustPersistToolGraph(t, source) - application := &App{GraphSnapshots: store} + application := &App{} + setConfiguredSnapshotGraphFromGraph(t, application, source) graphStore := application.CurrentSecurityGraphStoreForTenant("tenant-a") if graphStore == nil { @@ -401,11 +306,6 @@ func TestCurrentSecurityGraphStoreForTenantUsesPersistedSnapshotWhenLiveGraphUna } else if ok { t.Fatal("expected tenant store to exclude foreign-tenant nodes") } - - status := store.Status() - if status.LastRecoveredSnapshot != "" || status.LastRecoverySource != "" { - t.Fatalf("expected passive snapshot read to avoid recovery bookkeeping, got %+v", status) - } } func TestCurrentSecurityGraphStoreForTenantScopesConfiguredSnapshotReads(t *testing.T) { @@ -437,10 +337,39 @@ func TestCurrentSecurityGraphStoreForTenantScopesConfiguredSnapshotReads(t *test } } +func TestCurrentSecurityGraphStoreDoesNotFallbackToPersistedSnapshots(t *testing.T) { + persisted := graph.New() + persisted.AddNode(&graph.Node{ID: "service:persisted", Kind: graph.NodeKindService}) + persisted.BuildIndex() + + application := &App{GraphSnapshots: mustPersistToolGraph(t, persisted)} + + graphStore := application.CurrentSecurityGraphStore() + if graphStore == nil { + t.Fatal("expected graph store wrapper") + } + if _, _, err := graphStore.LookupNode(context.Background(), "service:persisted"); !errors.Is(err, graph.ErrStoreUnavailable) { + 
t.Fatalf("LookupNode() error = %v, want ErrStoreUnavailable", err) + } +} + +func TestCurrentSecurityGraphStoreForTenantDoesNotFallbackToPersistedSnapshots(t *testing.T) { + source := buildTenantShardTestGraph(time.Date(2026, time.March, 17, 23, 0, 0, 0, time.UTC)) + application := &App{GraphSnapshots: mustPersistToolGraph(t, source)} + + graphStore := application.CurrentSecurityGraphStoreForTenant("tenant-a") + if graphStore == nil { + t.Fatal("expected tenant-scoped graph store wrapper") + } + if _, _, err := graphStore.LookupNode(context.Background(), "service:tenant-a"); !errors.Is(err, graph.ErrStoreUnavailable) { + t.Fatalf("LookupNode() error = %v, want ErrStoreUnavailable", err) + } +} + func TestCurrentSecurityGraphStoreForTenantReturnsUnavailableWhenTenantMissingFromSnapshot(t *testing.T) { source := buildTenantShardTestGraph(time.Date(2026, time.March, 17, 23, 0, 0, 0, time.UTC)) - store := mustPersistToolGraph(t, source) - application := &App{GraphSnapshots: store} + application := &App{} + setConfiguredSnapshotGraphFromGraph(t, application, source) graphStore := application.CurrentSecurityGraphStoreForTenant("tenant-missing") if graphStore == nil { @@ -551,13 +480,13 @@ func TestCurrentSecurityGraphStoreForTenantRejectsWrites(t *testing.T) { } } -func TestCurrentSecurityGraphStoreTreatsPersistedSnapshotAsReadOnly(t *testing.T) { - persisted := graph.New() - persisted.AddNode(&graph.Node{ID: "service:persisted", Kind: graph.NodeKindService}) - persisted.BuildIndex() +func TestCurrentSecurityGraphStoreTreatsConfiguredSnapshotAsReadOnly(t *testing.T) { + configured := graph.New() + configured.AddNode(&graph.Node{ID: "service:configured", Kind: graph.NodeKindService}) + configured.BuildIndex() - store := mustPersistToolGraph(t, persisted) - application := &App{GraphSnapshots: store} + application := &App{} + setConfiguredSnapshotGraphFromGraph(t, application, configured) graphStore := application.CurrentSecurityGraphStore() if graphStore == nil { @@ 
-572,12 +501,7 @@ func TestCurrentSecurityGraphStoreTreatsPersistedSnapshotAsReadOnly(t *testing.T if _, ok, err := graphStore.LookupNode(context.Background(), "service:new"); err != nil { t.Fatalf("LookupNode(service:new) error = %v", err) } else if ok { - t.Fatal("expected persisted snapshot fallback to remain unchanged after rejected write") - } - - status := store.Status() - if status.LastRecoveredSnapshot != "" || status.LastRecoverySource != "" { - t.Fatalf("expected passive snapshot fallback to avoid recovery bookkeeping, got %+v", status) + t.Fatal("expected configured snapshot store to remain unchanged after rejected write") } } diff --git a/internal/app/app_graph_updates.go b/internal/app/app_graph_updates.go index d0a0ea8fc..92523d5f3 100644 --- a/internal/app/app_graph_updates.go +++ b/internal/app/app_graph_updates.go @@ -3,14 +3,12 @@ package app import ( "context" "fmt" - "strings" "time" "github.com/writer/cerebro/internal/graph" ) -// ApplySecurityGraphChanges applies CDC-backed graph mutations and falls back to -// a copy-on-write full rebuild only when incremental mutation fails. +// ApplySecurityGraphChanges applies CDC-backed graph mutations. 
func (a *App) ApplySecurityGraphChanges(ctx context.Context, trigger string) (graph.GraphMutationSummary, error) { if a == nil || a.SecurityGraphBuilder == nil { return graph.GraphMutationSummary{}, errGraphNotInitialized() @@ -41,39 +39,17 @@ func (a *App) applySecurityGraphChangesLocked(ctx context.Context, trigger strin if err := a.requireGraphWriterLease("apply security graph changes"); err != nil { return graph.GraphMutationSummary{}, err } - start := time.Now() if err := a.prepareSecurityGraphBuilderForIncrementalApply(ctx); err != nil { return graph.GraphMutationSummary{}, err } summary, err := a.SecurityGraphBuilder.ApplyChanges(ctx, time.Time{}) if err != nil { - a.Logger.Warn("incremental graph apply failed, falling back to full rebuild", + a.Logger.Warn("incremental graph apply failed", "trigger", trigger, "error", err, ) - a.setGraphBuildState(GraphBuildBuilding, time.Time{}, nil) - if buildErr := a.SecurityGraphBuilder.Build(ctx); buildErr != nil { - a.setGraphBuildState(GraphBuildFailed, time.Now().UTC(), buildErr) - return graph.GraphMutationSummary{}, buildErr - } - - securityGraph := a.SecurityGraphBuilder.Graph() - meta, activateErr := a.activateBuiltSecurityGraph(ctx, securityGraph) - if activateErr != nil { - return graph.GraphMutationSummary{}, activateErr - } - - summary = a.SecurityGraphBuilder.LastMutation() - duration := time.Since(start) - a.Logger.Info("security graph rebuilt after incremental apply failure", - "trigger", trigger, - "nodes", meta.NodeCount, - "edges", meta.EdgeCount, - "duration", duration, - ) - a.emitGraphRebuiltEvent(ctx, meta, duration) - a.emitGraphMutationEvent(ctx, summary, trigger) - return summary, nil + a.setGraphBuildState(GraphBuildFailed, time.Now().UTC(), err) + return graph.GraphMutationSummary{}, err } if summary.EventsProcessed == 0 { @@ -254,15 +230,5 @@ func (a *App) currentIncrementalBuilderSnapshot(ctx context.Context) (*graph.Sna } else if snapshot != nil { return snapshot, nil } - if 
a.GraphSnapshots == nil { - return nil, nil - } - snapshot, _, _, err := a.GraphSnapshots.PeekLatestSnapshot() - if err != nil { - if isNoSnapshotsGraphStoreErr(err) || strings.Contains(strings.ToLower(err.Error()), "no snapshots found") { - return nil, nil - } - return nil, err - } - return snapshot, nil + return nil, nil } diff --git a/internal/app/app_graph_view.go b/internal/app/app_graph_view.go index 360e1a27a..4f8847cc5 100644 --- a/internal/app/app_graph_view.go +++ b/internal/app/app_graph_view.go @@ -10,17 +10,22 @@ import ( ) func (a *App) currentOrStoredSecurityGraphView() (*graph.Graph, error) { - return a.currentOrStoredSecurityGraphViewWithSnapshotLoader(func(store *graph.GraphPersistenceStore) (*graph.Snapshot, error) { - snapshot, _, _, err := store.LoadLatestSnapshot() - return snapshot, err - }) + if a == nil { + return nil, nil + } + if current := a.currentLiveSecurityGraph(); current != nil && (current.NodeCount() > 0 || current.EdgeCount() > 0) { + return current, nil + } + if view, err := a.currentConfiguredSecurityGraphView(context.Background()); err != nil { + return nil, err + } else if view != nil { + return view, nil + } + return a.currentLiveSecurityGraph(), nil } func (a *App) currentOrStoredPassiveSecurityGraphView() (*graph.Graph, error) { - return a.currentOrStoredSecurityGraphViewWithSnapshotLoader(func(store *graph.GraphPersistenceStore) (*graph.Snapshot, error) { - snapshot, _, _, err := store.PeekLatestSnapshot() - return snapshot, err - }) + return a.currentOrStoredSecurityGraphView() } func (a *App) currentOrStoredPassiveGraphSnapshotRecord() (*graph.GraphSnapshotRecord, error) { @@ -40,61 +45,7 @@ func (a *App) currentOrStoredPassiveGraphSnapshotRecord() (*graph.GraphSnapshotR if current := graph.CurrentGraphSnapshotRecord(a.CurrentSecurityGraph()); current != nil { return current, nil } - store := a.platformGraphSnapshotStoreForTool() - if store == nil { - return nil, nil - } - snapshot, record, _, err := 
store.PeekLatestSnapshot() - if err != nil { - if strings.Contains(strings.ToLower(err.Error()), "no snapshots found") { - return nil, nil - } - return nil, err - } - if record != nil { - current := *record - current.Current = true - return ¤t, nil - } - if snapshot == nil { - return nil, nil - } - return graph.CurrentGraphSnapshotRecord(graph.GraphViewFromSnapshot(snapshot)), nil -} - -func (a *App) currentOrStoredSecurityGraphViewWithSnapshotLoader(loadSnapshot func(store *graph.GraphPersistenceStore) (*graph.Snapshot, error)) (*graph.Graph, error) { - if a == nil { - return nil, nil - } - if current := a.currentLiveSecurityGraph(); current != nil { - return current, nil - } - if view, err := a.currentConfiguredSecurityGraphView(context.Background()); err != nil { - return nil, err - } else if view != nil { - return view, nil - } - return a.storedSecurityGraphViewWithSnapshotLoader(loadSnapshot) -} - -func (a *App) storedSecurityGraphViewWithSnapshotLoader(loadSnapshot func(store *graph.GraphPersistenceStore) (*graph.Snapshot, error)) (*graph.Graph, error) { - if a == nil { - return nil, nil - } - if a.GraphSnapshots == nil || loadSnapshot == nil { - return nil, nil - } - snapshot, err := loadSnapshot(a.GraphSnapshots) - if err != nil { - if strings.Contains(strings.ToLower(err.Error()), "no snapshots found") { - return nil, nil - } - return nil, err - } - if snapshot == nil { - return nil, nil - } - return graph.GraphViewFromSnapshot(snapshot), nil + return nil, nil } func (a *App) currentOrStoredSecurityGraphViewForTenant(tenantID string) (*graph.Graph, error) { diff --git a/internal/app/app_init_agents.go b/internal/app/app_init_agents.go index 525b6beed..e64a0a932 100644 --- a/internal/app/app_init_agents.go +++ b/internal/app/app_init_agents.go @@ -13,7 +13,10 @@ func (a *App) initAgents(ctx context.Context) { ctx = context.Background() } a.Agents = agents.NewAgentRegistry() - if a.Snowflake != nil { + switch { + case a.appStateDB != nil: + 
a.Agents.SetSessionStore(agents.NewPostgresSessionStore(a.appStateDB)) + case a.Snowflake != nil: store, err := agents.NewSnowflakeSessionStore(a.Snowflake) if err != nil { a.Logger.Warn("failed to initialize persistent agent session store, using in-memory store", "error", err) diff --git a/internal/app/app_init_core.go b/internal/app/app_init_core.go index ce51fae9b..865c4a232 100644 --- a/internal/app/app_init_core.go +++ b/internal/app/app_init_core.go @@ -26,6 +26,13 @@ import ( "github.com/writer/cerebro/internal/webhooks" ) +var ( + newSnowflakeClient = snowflake.NewClient + pingSnowflake = func(ctx context.Context, client *snowflake.Client) error { + return client.Ping(ctx) + } +) + func (a *App) initWarehouse(ctx context.Context) error { switch strings.ToLower(strings.TrimSpace(a.Config.WarehouseBackend)) { case "", "snowflake": @@ -39,34 +46,66 @@ func (a *App) initWarehouse(ctx context.Context) error { } } +func hasSnowflakeCredentials(cfg *Config) bool { + return cfg != nil && + strings.TrimSpace(cfg.SnowflakePrivateKey) != "" && + strings.TrimSpace(cfg.SnowflakeAccount) != "" && + strings.TrimSpace(cfg.SnowflakeUser) != "" +} + +func snowflakeClientConfig(cfg *Config) snowflake.ClientConfig { + return snowflake.ClientConfig{ + Account: cfg.SnowflakeAccount, + User: cfg.SnowflakeUser, + PrivateKey: cfg.SnowflakePrivateKey, + Database: cfg.SnowflakeDatabase, + Schema: cfg.SnowflakeSchema, + Warehouse: cfg.SnowflakeWarehouse, + Role: cfg.SnowflakeRole, + } +} + +func openConfiguredSnowflakeClient(ctx context.Context, cfg *Config) (*snowflake.Client, error) { + client, err := newSnowflakeClient(snowflakeClientConfig(cfg)) + if err != nil { + return nil, err + } + if err := pingSnowflake(ctx, client); err != nil { + _ = client.Close() + return nil, err + } + return client, nil +} + func (a *App) initSnowflake(ctx context.Context) error { // Require key-pair auth - if a.Config.SnowflakePrivateKey == "" || a.Config.SnowflakeAccount == "" || 
a.Config.SnowflakeUser == "" { + if !hasSnowflakeCredentials(a.Config) { return fmt.Errorf("snowflake not configured: set SNOWFLAKE_PRIVATE_KEY, SNOWFLAKE_ACCOUNT, and SNOWFLAKE_USER") } - client, err := snowflake.NewClient(snowflake.ClientConfig{ - Account: a.Config.SnowflakeAccount, - User: a.Config.SnowflakeUser, - PrivateKey: a.Config.SnowflakePrivateKey, - Database: a.Config.SnowflakeDatabase, - Schema: a.Config.SnowflakeSchema, - Warehouse: a.Config.SnowflakeWarehouse, - Role: a.Config.SnowflakeRole, - }) + client, err := openConfiguredSnowflakeClient(ctx, a.Config) if err != nil { return err } - if err := client.Ping(ctx); err != nil { - return err - } - a.Snowflake = client + a.LegacySnowflake = nil a.Warehouse = client return nil } +func (a *App) initLegacySnowflake(ctx context.Context) error { + if a == nil || a.Config == nil || a.Snowflake != nil || !hasSnowflakeCredentials(a.Config) { + return nil + } + client, err := openConfiguredSnowflakeClient(ctx, a.Config) + if err != nil { + return err + } + a.LegacySnowflake = client + return nil +} + func (a *App) initSQLiteWarehouse(_ context.Context) error { store, err := warehouse.NewSQLiteWarehouse(warehouse.SQLiteWarehouseConfig{ Path: strings.TrimSpace(a.Config.WarehouseSQLitePath), @@ -84,7 +123,7 @@ func (a *App) initSQLiteWarehouse(_ context.Context) error { func (a *App) initPostgresWarehouse(_ context.Context) error { store, err := warehouse.NewPostgresWarehouse(warehouse.PostgresWarehouseConfig{ - DSN: strings.TrimSpace(a.Config.WarehousePostgresDSN), + DSN: a.warehousePostgresDSN(), AppSchema: "cerebro", }) if err != nil { @@ -95,6 +134,13 @@ func (a *App) initPostgresWarehouse(_ context.Context) error { return nil } +func (a *App) warehousePostgresDSN() string { + if a == nil || a.Config == nil { + return "" + } + return strings.TrimSpace(a.Config.WarehousePostgresDSN) +} + func (a *App) initPolicy() error { a.Policy = policy.NewEngine() if err := a.Policy.LoadPolicies(a.Config.PoliciesPath); err 
!= nil { @@ -114,6 +160,19 @@ func (a *App) initPolicy() error { } func (a *App) initFindings() { + if a.appStateDB != nil { + store := findings.NewPostgresStore(a.appStateDB) + store.SetSemanticDedup(a.Config.FindingsSemanticDedupEnabled) + if err := store.Load(context.Background()); err != nil { + a.Logger.Warn("failed to load postgres findings store", "error", err) + } + a.Findings = store + a.SnowflakeFindings = nil + a.configureFindingAttestation() + a.Logger.Info("using postgres findings store") + return + } + warehouseDB := (*sql.DB)(nil) if a.Warehouse != nil { warehouseDB = a.Warehouse.DB() @@ -230,6 +289,10 @@ func (a *App) configureFindingAttestation() { store.SetAttestor(attestor, a.Config.FindingAttestationAttestReobserved) configured = true } + if store, ok := a.Findings.(*findings.PostgresStore); ok { + store.SetAttestor(attestor, a.Config.FindingAttestationAttestReobserved) + configured = true + } if store, ok := a.Findings.(*findings.FileStore); ok { store.SetAttestor(attestor, a.Config.FindingAttestationAttestReobserved) configured = true diff --git a/internal/app/app_init_core_test.go b/internal/app/app_init_core_test.go index dba6a88d9..e42e11665 100644 --- a/internal/app/app_init_core_test.go +++ b/internal/app/app_init_core_test.go @@ -12,9 +12,12 @@ import ( "testing" "time" + _ "modernc.org/sqlite" + "github.com/writer/cerebro/internal/findings" "github.com/writer/cerebro/internal/graph" "github.com/writer/cerebro/internal/identity" + "github.com/writer/cerebro/internal/snowflake" "github.com/writer/cerebro/internal/warehouse" ) @@ -59,6 +62,33 @@ func TestInitFindings_FallsBackToSQLiteWhenWarehouseHasNoDB(t *testing.T) { } } +func TestInitFindings_UsesPostgresStoreWhenAppStateDatabaseConfigured(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + if err := findings.NewPostgresStore(db).EnsureSchema(context.Background()); err != nil { + 
t.Fatalf("EnsureSchema() error = %v", err) + } + + a := &App{ + Config: &Config{}, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + a.appStateDB = db + + a.initFindings() + + if _, ok := a.Findings.(*findings.PostgresStore); !ok { + t.Fatalf("expected postgres findings store, got %T", a.Findings) + } + if a.SnowflakeFindings != nil { + t.Fatal("expected legacy snowflake findings store to stay nil when app-state postgres is configured") + } +} + func TestNewInMemoryFindingsStore_UsesConfiguredBounds(t *testing.T) { var logs bytes.Buffer a := &App{ @@ -128,6 +158,107 @@ func TestInitWarehouse_UsesSQLiteBackend(t *testing.T) { } } +func TestWarehousePostgresDSN_DoesNotFallBackToJobDatabaseURL(t *testing.T) { + a := &App{ + Config: &Config{ + JobDatabaseURL: "postgres://jobs", + }, + } + + if got := a.warehousePostgresDSN(); got != "" { + t.Fatalf("expected warehouse DSN to ignore JOB_DATABASE_URL, got %q", got) + } +} + +func TestInitLegacySnowflake_UsesSeparateClientForNonSnowflakeWarehouse(t *testing.T) { + originalNewSnowflakeClient := newSnowflakeClient + originalPingSnowflake := pingSnowflake + newSnowflakeClient = func(snowflake.ClientConfig) (*snowflake.Client, error) { + return new(snowflake.Client), nil + } + pingSnowflake = func(context.Context, *snowflake.Client) error { return nil } + t.Cleanup(func() { + newSnowflakeClient = originalNewSnowflakeClient + pingSnowflake = originalPingSnowflake + }) + + a := &App{ + Config: &Config{ + WarehouseBackend: "postgres", + SnowflakeAccount: "acct", + SnowflakeUser: "user", + SnowflakePrivateKey: "key", + WarehousePostgresDSN: "postgres://warehouse", + }, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + + if err := a.initLegacySnowflake(context.Background()); err != nil { + t.Fatalf("initLegacySnowflake() error = %v", err) + } + + if a.LegacySnowflake == nil { + t.Fatal("expected legacy snowflake client to be initialized") + } + if a.Snowflake != nil { + t.Fatalf("expected active 
snowflake warehouse client to stay nil, got %T", a.Snowflake) + } + if a.Warehouse != nil { + t.Fatalf("expected warehouse selection to remain unchanged, got %T", a.Warehouse) + } +} + +func TestAppStateMigrationSnowflakePrefersLegacyClient(t *testing.T) { + active := new(snowflake.Client) + legacy := new(snowflake.Client) + a := &App{Snowflake: active, LegacySnowflake: legacy} + + if got := a.appStateMigrationSnowflake(); got != legacy { + t.Fatalf("expected legacy snowflake migration source, got %p want %p", got, legacy) + } +} + +func TestRotateSnowflakeClientPreservesWarehouseWhenUsingLegacySource(t *testing.T) { + originalNewSnowflakeClient := newSnowflakeClient + originalPingSnowflake := pingSnowflake + newSnowflakeClient = func(snowflake.ClientConfig) (*snowflake.Client, error) { + return new(snowflake.Client), nil + } + pingSnowflake = func(context.Context, *snowflake.Client) error { return nil } + t.Cleanup(func() { + newSnowflakeClient = originalNewSnowflakeClient + pingSnowflake = originalPingSnowflake + }) + + existingWarehouse := &warehouse.MemoryWarehouse{} + oldLegacy := new(snowflake.Client) + a := &App{ + Config: &Config{ + WarehouseBackend: "postgres", + SnowflakeAccount: "acct", + SnowflakeUser: "user", + SnowflakePrivateKey: "key", + }, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + Warehouse: existingWarehouse, + LegacySnowflake: oldLegacy, + } + + if err := a.rotateSnowflakeClient(context.Background(), a.Config); err != nil { + t.Fatalf("rotateSnowflakeClient() error = %v", err) + } + + if a.Warehouse != existingWarehouse { + t.Fatalf("expected warehouse selection to stay unchanged, got %T", a.Warehouse) + } + if a.Snowflake != nil { + t.Fatalf("expected active snowflake client to remain nil, got %T", a.Snowflake) + } + if a.LegacySnowflake == nil || a.LegacySnowflake == oldLegacy { + t.Fatal("expected legacy snowflake client to rotate independently") + } +} + func TestInitFindings_UsesSQLiteStoreForSQLiteWarehouse(t *testing.T) { 
a := &App{ Config: &Config{ @@ -215,7 +346,7 @@ func TestInitIdentityGraphResolverUsesTenantReadScope(t *testing.T) { } } -func TestInitIdentityGraphResolverUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestInitIdentityGraphResolverUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { a := &App{ Config: &Config{ GraphTenantShardIdleTTL: 10 * time.Minute, @@ -260,7 +391,7 @@ func TestInitIdentityGraphResolverUsesPersistedSnapshotWhenLiveGraphUnavailable( }) g.AddEdge(&graph.Edge{ID: "alice-tenant-a", Source: "user:alice", Target: "bucket:tenant-a", Kind: graph.EdgeKindCanRead, Effect: graph.EdgeEffectAllow}) g.AddEdge(&graph.Edge{ID: "alice-tenant-b", Source: "user:alice", Target: "bucket:tenant-b", Kind: graph.EdgeKindCanRead, Effect: graph.EdgeEffectAllow}) - a.GraphSnapshots = mustPersistToolGraph(t, g) + setConfiguredSnapshotGraphFromGraph(t, a, g) a.initIdentity() if a.currentLiveSecurityGraph() != nil { diff --git a/internal/app/app_init_phases.go b/internal/app/app_init_phases.go index 9b6686580..17ed5436b 100644 --- a/internal/app/app_init_phases.go +++ b/internal/app/app_init_phases.go @@ -65,6 +65,12 @@ func (a *App) initPhase1(ctx context.Context) error { func (a *App) initPhase2a(ctx context.Context) error { a.initExecutionStore() a.initGraphPersistenceStore() + if err := runInitErrorStep("app_state_db", func() error { return a.initAppStateDB(ctx) }); err != nil { + return err + } + if err := runInitErrorStep("legacy_snowflake", func() error { return a.initLegacySnowflake(ctx) }); err != nil { + a.Logger.Warn("legacy snowflake initialization failed", "error", err) + } if err := runInitErrorStep("graph_store_backend", func() error { return a.initConfiguredSecurityGraphStore(ctx) }); err != nil { return err } @@ -95,6 +101,9 @@ func (a *App) initPhase2a(ctx context.Context) error { }); err != nil { return fmt.Errorf("phase 2a init failed: %w", err) } + if err := runInitErrorStep("app_state_migration", func() error { return 
a.migrateAppState(ctx) }); err != nil { + return err + } return nil } diff --git a/internal/app/app_scan_scheduler.go b/internal/app/app_scan_scheduler.go index 11a44e608..d00664f0f 100644 --- a/internal/app/app_scan_scheduler.go +++ b/internal/app/app_scan_scheduler.go @@ -488,7 +488,7 @@ func (a *App) runScheduledScan(ctx context.Context, tables []string) error { } sqlToxicRiskSets := make(map[string][]map[string]bool) - if a.Warehouse != nil { + if scanner.SupportsRelationshipToxicDetection(a.Warehouse) { var toxicCursor *scanner.ToxicScanCursor if a.ScanWatermarks != nil { if wm := a.ScanWatermarks.GetWatermark("_toxic_relationships"); wm != nil { @@ -560,10 +560,9 @@ func (a *App) runScheduledScan(ctx context.Context, tables []string) error { } } - // Sync to Snowflake if available - if a.SnowflakeFindings != nil { - if err := a.SnowflakeFindings.Sync(ctx); err != nil { - a.Logger.Warn("failed to sync findings to snowflake", "error", err) + if syncer, ok := a.Findings.(interface{ Sync(context.Context) error }); ok { + if err := syncer.Sync(ctx); err != nil { + a.Logger.Warn("failed to sync findings", "error", err) } } diff --git a/internal/app/app_scan_scheduler_test.go b/internal/app/app_scan_scheduler_test.go index 51e0d1b84..65928f09b 100644 --- a/internal/app/app_scan_scheduler_test.go +++ b/internal/app/app_scan_scheduler_test.go @@ -14,21 +14,20 @@ import ( "github.com/writer/cerebro/internal/scanner" ) -func TestCurrentOrStoredScheduledScanGraphView_UsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestCurrentOrStoredScheduledScanGraphView_UsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { g := orgTopologyTestGraph(time.Now().UTC()) - app := &App{ - GraphSnapshots: mustPersistToolGraph(t, g), - } + app := &App{} + setConfiguredSnapshotGraphFromGraph(t, app, g) got := app.currentOrStoredScheduledScanGraphView(context.Background(), ScanTuning{}) if got == nil { - t.Fatal("expected persisted snapshot graph view") + 
t.Fatal("expected configured graph view") } if got.NodeCount() != g.NodeCount() { t.Fatalf("expected %d nodes, got %d", g.NodeCount(), got.NodeCount()) } if _, ok := got.GetNode("svc:core"); !ok { - t.Fatal("expected persisted graph view to include svc:core") + t.Fatal("expected configured graph view to include svc:core") } } @@ -38,10 +37,10 @@ func TestCurrentOrStoredScheduledScanGraphView_PreservesLiveGraphWaitWhenPresent live.BuildIndex() app := &App{ - SecurityGraph: live, - GraphSnapshots: mustPersistToolGraph(t, orgTopologyTestGraph(time.Now().UTC())), - graphReady: make(chan struct{}), + SecurityGraph: live, + graphReady: make(chan struct{}), } + setConfiguredSnapshotGraphFromGraph(t, app, orgTopologyTestGraph(time.Now().UTC())) got := app.currentOrStoredScheduledScanGraphView(context.Background(), ScanTuning{ GraphWaitTimeout: 5 * time.Millisecond, @@ -51,7 +50,7 @@ func TestCurrentOrStoredScheduledScanGraphView_PreservesLiveGraphWaitWhenPresent } } -func TestRunScheduledGraphAnalyses_UsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestRunScheduledGraphAnalyses_UsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) engine := policy.NewEngine() addOrgTestPolicy(t, engine, &policy.Policy{ @@ -68,25 +67,25 @@ func TestRunScheduledGraphAnalyses_UsesPersistedSnapshotWhenLiveGraphUnavailable findingStore := findings.NewStore() app := &App{ - Logger: logger, - Policy: engine, - Scanner: scanner.NewScanner(engine, scanner.ScanConfig{}, logger), - Findings: findingStore, - GraphSnapshots: mustPersistToolGraph(t, orgTopologyTestGraph(time.Now().UTC())), + Logger: logger, + Policy: engine, + Scanner: scanner.NewScanner(engine, scanner.ScanConfig{}, logger), + Findings: findingStore, } + setConfiguredSnapshotGraphFromGraph(t, app, orgTopologyTestGraph(time.Now().UTC())) summary := app.runScheduledGraphAnalyses(context.Background(), ScanTuning{}, nil) if summary.orgTopologyErrorCount != 0 
{ t.Fatalf("expected no org topology errors, got %d", summary.orgTopologyErrorCount) } if summary.orgTopologyFindingCount == 0 { - t.Fatal("expected org topology findings from persisted snapshot") + t.Fatal("expected org topology findings from configured graph") } stored := findingStore.List(findings.FindingFilter{}) if !slices.ContainsFunc(stored, func(f *findings.Finding) bool { return f != nil && f.PolicyID == "org-bus-factor-critical" }) { - t.Fatalf("expected stored finding for persisted snapshot org topology policy, got %v", stored) + t.Fatalf("expected stored finding for configured graph org topology policy, got %v", stored) } } diff --git a/internal/app/app_secrets.go b/internal/app/app_secrets.go index c7aca54da..a54afc2fe 100644 --- a/internal/app/app_secrets.go +++ b/internal/app/app_secrets.go @@ -241,70 +241,71 @@ func (a *App) rotateSnowflakeClient(ctx context.Context, cfg *Config) error { ctx = context.Background() } - if strings.TrimSpace(cfg.SnowflakePrivateKey) == "" || - strings.TrimSpace(cfg.SnowflakeAccount) == "" || - strings.TrimSpace(cfg.SnowflakeUser) == "" { + if !hasSnowflakeCredentials(cfg) { return fmt.Errorf("snowflake rotation requires SNOWFLAKE_PRIVATE_KEY, SNOWFLAKE_ACCOUNT, and SNOWFLAKE_USER") } - newClient, err := snowflake.NewClient(snowflake.ClientConfig{ - Account: cfg.SnowflakeAccount, - User: cfg.SnowflakeUser, - PrivateKey: cfg.SnowflakePrivateKey, - Database: cfg.SnowflakeDatabase, - Schema: cfg.SnowflakeSchema, - Warehouse: cfg.SnowflakeWarehouse, - Role: cfg.SnowflakeRole, - }) + newClient, err := openConfiguredSnowflakeClient(ctx, cfg) if err != nil { return err } - if err := newClient.Ping(ctx); err != nil { - _ = newClient.Close() - return err - } oldClient := a.Snowflake - a.Snowflake = newClient - a.Warehouse = newClient - a.initRepositories() - - if a.ScanWatermarks != nil { - a.ScanWatermarks.SetDB(newClient.DB()) - if err := a.ScanWatermarks.LoadWatermarks(ctx); err != nil && a.Logger != nil { - 
a.Logger.Warn("failed to reload scan watermarks after snowflake rotation", "error", err) + oldLegacyClient := a.LegacySnowflake + if strings.EqualFold(strings.TrimSpace(cfg.WarehouseBackend), "snowflake") { + a.Snowflake = newClient + a.LegacySnowflake = nil + a.Warehouse = newClient + a.initRepositories() + + if a.ScanWatermarks != nil { + a.ScanWatermarks.SetDB(newClient.DB()) + if err := a.ScanWatermarks.LoadWatermarks(ctx); err != nil && a.Logger != nil { + a.Logger.Warn("failed to reload scan watermarks after snowflake rotation", "error", err) + } } - } - if a.SnowflakeFindings != nil { - a.SnowflakeFindings.SetConnection(newClient.DB(), newClient.Database(), newClient.Schema()) - if err := a.SnowflakeFindings.Load(ctx); err != nil && a.Logger != nil { - a.Logger.Warn("failed to reload findings after snowflake rotation", "error", err) + if a.SnowflakeFindings != nil { + a.SnowflakeFindings.SetConnection(newClient.DB(), newClient.Database(), newClient.Schema()) + if err := a.SnowflakeFindings.Load(ctx); err != nil && a.Logger != nil { + a.Logger.Warn("failed to reload findings after snowflake rotation", "error", err) + } } - } - if a.Agents != nil { - store, err := agents.NewSnowflakeSessionStore(newClient) - if err != nil { - if a.Logger != nil { - a.Logger.Warn("failed to rotate agent session store", "error", err) + if a.Agents != nil { + if a.appStateDB != nil { + a.Agents.SetSessionStore(agents.NewPostgresSessionStore(a.appStateDB)) + } else { + store, err := agents.NewSnowflakeSessionStore(newClient) + if err != nil { + if a.Logger != nil { + a.Logger.Warn("failed to rotate agent session store", "error", err) + } + } else { + a.Agents.SetSessionStore(store) + } } - } else { - a.Agents.SetSessionStore(store) } - } - if a.graphCancel != nil { - a.graphCancel() - a.graphCancel = nil + if a.graphCancel != nil { + a.graphCancel() + a.graphCancel = nil + } + a.initSecurityGraph(ctx) + } else { + a.LegacySnowflake = newClient } - a.initSecurityGraph(ctx) - if 
oldClient != nil { + if oldClient != nil && oldClient != newClient { if err := oldClient.Close(); err != nil && a.Logger != nil { a.Logger.Warn("failed to close previous snowflake client after rotation", "error", err) } } + if oldLegacyClient != nil && oldLegacyClient != newClient && oldLegacyClient != oldClient { + if err := oldLegacyClient.Close(); err != nil && a.Logger != nil { + a.Logger.Warn("failed to close previous legacy snowflake client after rotation", "error", err) + } + } return nil } diff --git a/internal/app/app_security_services.go b/internal/app/app_security_services.go index 17232ded8..79e295b35 100644 --- a/internal/app/app_security_services.go +++ b/internal/app/app_security_services.go @@ -109,11 +109,14 @@ func (a *App) initHealth() { a.Health = health.NewRegistry() // Register health checks for all services - a.Health.Register("snowflake", health.PingCheck("snowflake", func(ctx context.Context) error { - if a.Snowflake == nil { + a.Health.Register("warehouse", health.PingCheck("warehouse", func(ctx context.Context) error { + if a.Warehouse == nil { return fmt.Errorf("not configured") } - return a.Snowflake.Ping(ctx) + if db := a.Warehouse.DB(); db != nil { + return db.PingContext(ctx) + } + return nil })) a.Health.Register("policy_engine", health.PingCheck("policy_engine", func(ctx context.Context) error { @@ -528,27 +531,8 @@ func (a *App) initSecurityGraph(ctx context.Context) { securityGraph := a.SecurityGraphBuilder.Graph() a.configureGraphRuntimeBehavior(securityGraph) a.publishSecurityGraphRuntimeView(securityGraph) - if a.GraphSnapshots != nil { - recovered, record, recoverySource, err := a.GraphSnapshots.LoadLatestSnapshot() - if err != nil { - a.Logger.Warn("failed to recover persisted security graph snapshot", "error", err) - } else if recovered != nil { - recoveredGraph := graph.RestoreFromSnapshot(recovered) - a.configureGraphRuntimeBehavior(recoveredGraph) - a.publishSecurityGraphRuntimeView(recoveredGraph) - if record != nil && 
record.BuiltAt != nil { - a.setGraphBuildState(GraphBuildSuccess, record.BuiltAt.UTC(), nil) - } - a.Logger.Info("recovered persisted security graph snapshot", - "source", recoverySource, - "snapshot_id", recordID(record), - "nodes", recoveredGraph.NodeCount(), - "edges", recoveredGraph.EdgeCount(), - ) - } - } if !a.graphWriterLeaseAllowsWrites() { - a.Logger.Info("security graph initialized from local snapshots while waiting for graph writer lease", + a.Logger.Info("security graph waiting for graph writer lease", "lease", a.Config.GraphWriterLeaseName, "holder", a.GraphWriterLeaseStatusSnapshot().LeaseHolderID, ) @@ -770,13 +754,6 @@ func (a *App) activateBuiltSecurityGraph(ctx context.Context, securityGraph *gra return meta, nil } -func recordID(record *graph.GraphSnapshotRecord) string { - if record == nil { - return "" - } - return record.ID -} - func (a *App) rematerializeEventCorrelations(securityGraph *graph.Graph, reason string) { if securityGraph == nil { return diff --git a/internal/app/app_security_services_test.go b/internal/app/app_security_services_test.go index ccd2a8c63..b6f0ba5b5 100644 --- a/internal/app/app_security_services_test.go +++ b/internal/app/app_security_services_test.go @@ -119,7 +119,7 @@ func TestGraphOntologySLOHealthCheckWithoutGraph(t *testing.T) { } } -func TestGraphOntologySLOHealthCheckUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestGraphOntologySLOHealthCheckUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { g := graph.New() now := time.Date(2026, 3, 9, 10, 0, 0, 0, time.UTC) g.AddNode(&graph.Node{ @@ -141,12 +141,12 @@ func TestGraphOntologySLOHealthCheckUsesPersistedSnapshotWhenLiveGraphUnavailabl GraphOntologySchemaValidWarnPct: 98, GraphOntologySchemaValidCriticalPct: 92, }, - GraphSnapshots: mustPersistToolGraph(t, g), } + setConfiguredSnapshotGraphFromGraph(t, application, g) result := application.graphOntologySLOHealthCheck()(context.Background()) if result.Status != 
health.StatusUnhealthy { - t.Fatalf("expected unhealthy status from persisted snapshot fallback activity, got %s (%s)", result.Status, result.Message) + t.Fatalf("expected unhealthy status from configured graph fallback activity, got %s (%s)", result.Status, result.Message) } if !strings.Contains(result.Message, "fallback_activity_percent") { t.Fatalf("expected fallback issue in message, got %q", result.Message) @@ -332,16 +332,15 @@ func TestGraphBuildSnapshotIncludesNodeCountWithoutHoldingBuildLock(t *testing.T } } -func TestGraphBuildSnapshotUsesPersistedSnapshotNodeCountWhenLiveGraphUnavailable(t *testing.T) { - persisted := graph.New() - persisted.AddNode(&graph.Node{ID: "service:payments", Kind: graph.NodeKindService, Name: "payments"}) - persisted.AddNode(&graph.Node{ID: "service:billing", Kind: graph.NodeKindService, Name: "billing"}) - store := mustPersistToolGraph(t, persisted) +func TestGraphBuildSnapshotUsesConfiguredStoreNodeCountWhenLiveGraphUnavailable(t *testing.T) { + configured := graph.New() + configured.AddNode(&graph.Node{ID: "service:payments", Kind: graph.NodeKindService, Name: "payments"}) + configured.AddNode(&graph.Node{ID: "service:billing", Kind: graph.NodeKindService, Name: "billing"}) application := &App{ - Config: &Config{}, - GraphSnapshots: store, + Config: &Config{}, } + setConfiguredSnapshotGraphFromGraph(t, application, configured) application.setGraphBuildState(GraphBuildSuccess, time.Now().UTC(), nil) snapshot := application.GraphBuildSnapshot() @@ -349,17 +348,14 @@ func TestGraphBuildSnapshotUsesPersistedSnapshotNodeCountWhenLiveGraphUnavailabl t.Fatalf("expected graph build state success, got %#v", snapshot) } if snapshot.NodeCount != 2 { - t.Fatalf("expected persisted graph node count 2, got %d", snapshot.NodeCount) - } - if status := store.Status(); status.LastRecoveredAt != nil { - t.Fatalf("expected build snapshot read to avoid recovery bookkeeping, got %#v", status) + t.Fatalf("expected configured graph node count 2, 
got %d", snapshot.NodeCount) } } -func TestGraphFreshnessStatusSnapshotUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestGraphFreshnessStatusSnapshotUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC) - persisted := graph.New() - persisted.AddNode(&graph.Node{ + configured := graph.New() + configured.AddNode(&graph.Node{ ID: "service:payments", Kind: graph.NodeKindService, Name: "payments", @@ -368,18 +364,16 @@ func TestGraphFreshnessStatusSnapshotUsesPersistedSnapshotWhenLiveGraphUnavailab "observed_at": now.Add(-12 * time.Hour).Format(time.RFC3339), }, }) - store := mustPersistToolGraph(t, persisted) - application := &App{ Config: &Config{ GraphFreshnessDefaultSLA: 6 * time.Hour, }, - GraphSnapshots: store, } + setConfiguredSnapshotGraphFromGraph(t, application, configured) status := application.GraphFreshnessStatusSnapshot(now) if status.Healthy { - t.Fatalf("expected persisted snapshot freshness breach, got %#v", status) + t.Fatalf("expected configured graph freshness breach, got %#v", status) } if len(status.Breaches) != 1 { t.Fatalf("expected one freshness breach, got %#v", status.Breaches) @@ -390,15 +384,12 @@ func TestGraphFreshnessStatusSnapshotUsesPersistedSnapshotWhenLiveGraphUnavailab if len(status.Breakdown.Providers) != 1 { t.Fatalf("expected one provider freshness scope, got %#v", status.Breakdown.Providers) } - if persistence := store.Status(); persistence.LastRecoveredAt != nil { - t.Fatalf("expected freshness status read to avoid recovery bookkeeping, got %#v", persistence) - } } -func TestInitHealthRegistersGraphFreshnessCheckUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestInitHealthRegistersGraphFreshnessCheckUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC) - persisted := graph.New() - persisted.AddNode(&graph.Node{ + configured := graph.New() + 
configured.AddNode(&graph.Node{ ID: "service:payments", Kind: graph.NodeKindService, Name: "payments", @@ -407,16 +398,14 @@ func TestInitHealthRegistersGraphFreshnessCheckUsesPersistedSnapshotWhenLiveGrap "observed_at": now.Add(-12 * time.Hour).Format(time.RFC3339), }, }) - store := mustPersistToolGraph(t, persisted) - application := &App{ Config: &Config{ GraphFreshnessDefaultSLA: 6 * time.Hour, }, - Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), - Warehouse: &warehouse.MemoryWarehouse{}, - GraphSnapshots: store, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + Warehouse: &warehouse.MemoryWarehouse{}, } + setConfiguredSnapshotGraphFromGraph(t, application, configured) application.initHealth() results := application.Health.RunAll(context.Background()) @@ -425,13 +414,10 @@ func TestInitHealthRegistersGraphFreshnessCheckUsesPersistedSnapshotWhenLiveGrap t.Fatal("expected graph_freshness health check to be registered") } if check.Status != health.StatusUnhealthy { - t.Fatalf("expected unhealthy persisted graph freshness check, got %s (%s)", check.Status, check.Message) + t.Fatalf("expected unhealthy configured graph freshness check, got %s (%s)", check.Status, check.Message) } if !strings.Contains(check.Message, "aws") { - t.Fatalf("expected persisted provider breach in message, got %q", check.Message) - } - if persistence := store.Status(); persistence.LastRecoveredAt != nil { - t.Fatalf("expected graph freshness health check to avoid recovery bookkeeping, got %#v", persistence) + t.Fatalf("expected configured provider breach in message, got %q", check.Message) } } @@ -486,24 +472,22 @@ func TestMutateSecurityGraphSwapsCloneAfterMutationCompletes(t *testing.T) { } } -func TestMutateSecurityGraphUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestMutateSecurityGraphUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { base := graph.New() base.AddNode(&graph.Node{ID: "service:payments", Kind: graph.NodeKindService, Name: 
"payments"}) base.AddNode(&graph.Node{ID: "bucket:prod", Kind: graph.NodeKindBucket, Name: "prod"}) base.AddEdge(&graph.Edge{ID: "payments-prod", Source: "service:payments", Target: "bucket:prod", Kind: graph.EdgeKindOwns, Effect: graph.EdgeEffectAllow}) base.BuildIndex() - application := &App{ - Config: &Config{}, - GraphSnapshots: mustPersistToolGraph(t, base), - } + application := &App{Config: &Config{}} + setConfiguredGraphFromGraph(t, application, base) mutated, err := application.MutateSecurityGraph(context.Background(), func(candidate *graph.Graph) error { if _, ok := candidate.GetNode("service:payments"); !ok { - return fmt.Errorf("persisted base node missing") + return fmt.Errorf("configured base node missing") } if _, ok := candidate.GetNode("bucket:prod"); !ok { - return fmt.Errorf("persisted base resource missing") + return fmt.Errorf("configured base resource missing") } candidate.AddNode(&graph.Node{ID: "service:billing", Kind: graph.NodeKindService, Name: "billing"}) return nil @@ -512,13 +496,13 @@ func TestMutateSecurityGraphUsesPersistedSnapshotWhenLiveGraphUnavailable(t *tes t.Fatalf("MutateSecurityGraph() error = %v", err) } if _, ok := mutated.GetNode("service:payments"); !ok { - t.Fatal("expected persisted base node to be preserved") + t.Fatal("expected configured base node to be preserved") } if _, ok := mutated.GetNode("bucket:prod"); !ok { - t.Fatal("expected persisted base resource to be preserved") + t.Fatal("expected configured base resource to be preserved") } if _, ok := mutated.GetNode("service:billing"); !ok { - t.Fatal("expected new node to be added on top of persisted base") + t.Fatal("expected new node to be added on top of configured base") } if got := application.CurrentSecurityGraph(); got != mutated { t.Fatal("expected mutated graph to become the live graph") diff --git a/internal/app/app_service_groups.go b/internal/app/app_service_groups.go index b94660290..735cc42b1 100644 --- a/internal/app/app_service_groups.go +++ 
b/internal/app/app_service_groups.go @@ -74,13 +74,11 @@ type SecurityServices struct { // StorageServices groups data repositories and durable stores. type StorageServices struct { - FindingsRepo *snowflake.FindingRepository - TicketsRepo *snowflake.TicketRepository - AuditRepo *snowflake.AuditRepository - PolicyHistoryRepo *snowflake.PolicyHistoryRepository - RiskEngineStateRepo *snowflake.RiskEngineStateRepository + Findings findings.FindingStore + AuditRepo auditRepository + PolicyHistoryRepo policyHistoryRepository + RiskEngineStateRepo riskEngineStateRepository RetentionRepo retentionCleaner - SnowflakeFindings *findings.SnowflakeStore } func (a *App) CoreServices() CoreServices { @@ -132,12 +130,10 @@ func (a *App) SecurityServices() SecurityServices { func (a *App) StorageServices() StorageServices { return StorageServices{ - FindingsRepo: a.FindingsRepo, - TicketsRepo: a.TicketsRepo, + Findings: a.Findings, AuditRepo: a.AuditRepo, PolicyHistoryRepo: a.PolicyHistoryRepo, RiskEngineStateRepo: a.RiskEngineStateRepo, RetentionRepo: a.RetentionRepo, - SnowflakeFindings: a.SnowflakeFindings, } } diff --git a/internal/app/app_service_groups_test.go b/internal/app/app_service_groups_test.go index 0274fa109..ab4c94e87 100644 --- a/internal/app/app_service_groups_test.go +++ b/internal/app/app_service_groups_test.go @@ -38,9 +38,8 @@ func TestAppServiceGroupAccessors(t *testing.T) { schedulerSvc := &scheduler.Scheduler{} rbac := auth.NewRBAC() securityGraph := graph.New() - findingsRepo := &snowflake.FindingRepository{} + auditRepo := &snowflake.AuditRepository{} riskEngineStateRepo := &snowflake.RiskEngineStateRepository{} - snowflakeStore := &findings.SnowflakeStore{} retention := noopRetentionCleaner{} application := &App{ @@ -51,9 +50,8 @@ func TestAppServiceGroupAccessors(t *testing.T) { Scheduler: schedulerSvc, RBAC: rbac, SecurityGraph: securityGraph, - FindingsRepo: findingsRepo, + AuditRepo: auditRepo, RiskEngineStateRepo: riskEngineStateRepo, - 
SnowflakeFindings: snowflakeStore, RetentionRepo: retention, } @@ -85,15 +83,15 @@ func TestAppServiceGroupAccessors(t *testing.T) { } storage := application.StorageServices() - if storage.FindingsRepo != findingsRepo { - t.Fatal("storage services should expose findings repository") + if storage.Findings != store { + t.Fatal("storage services should expose findings store") + } + if storage.AuditRepo != auditRepo { + t.Fatal("storage services should expose audit repository") } if storage.RiskEngineStateRepo != riskEngineStateRepo { t.Fatal("storage services should expose risk engine state repository") } - if storage.SnowflakeFindings != snowflakeStore { - t.Fatal("storage services should expose Snowflake findings store") - } if storage.RetentionRepo != retention { t.Fatal("storage services should expose retention cleaner") } diff --git a/internal/app/app_startup_test.go b/internal/app/app_startup_test.go index 9e4542079..e968bc326 100644 --- a/internal/app/app_startup_test.go +++ b/internal/app/app_startup_test.go @@ -142,21 +142,21 @@ func TestWaitForGraph_ContextCanceled(t *testing.T) { } } -func TestWaitForReadableSecurityGraphUsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { - persisted := graph.New() - persisted.AddNode(&graph.Node{ID: "service:payments", Kind: graph.NodeKindService, Name: "payments"}) +func TestWaitForReadableSecurityGraphUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { + configured := graph.New() + configured.AddNode(&graph.Node{ID: "service:payments", Kind: graph.NodeKindService, Name: "payments"}) a := &App{ - GraphSnapshots: mustPersistToolGraph(t, persisted), - Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } + setConfiguredSnapshotGraphFromGraph(t, a, configured) resolved := a.WaitForReadableSecurityGraph(context.Background()) if resolved == nil { - t.Fatal("expected persisted snapshot graph") + t.Fatal("expected configured graph") } if _, ok := 
resolved.GetNode("service:payments"); !ok { - t.Fatal("expected persisted snapshot node in readable graph") + t.Fatal("expected configured graph node in readable graph") } } diff --git a/internal/app/app_state_postgres.go b/internal/app/app_state_postgres.go new file mode 100644 index 000000000..76f32be61 --- /dev/null +++ b/internal/app/app_state_postgres.go @@ -0,0 +1,210 @@ +package app + +import ( + "context" + "database/sql" + "fmt" + "strings" + + _ "github.com/jackc/pgx/v5/stdlib" + + "github.com/writer/cerebro/internal/agents" + "github.com/writer/cerebro/internal/appstate" + "github.com/writer/cerebro/internal/findings" + "github.com/writer/cerebro/internal/snowflake" +) + +const appStateRiskEngineGraphID = "security-graph" + +func (a *App) appStateMigrationSnowflake() *snowflake.Client { + if a == nil { + return nil + } + if a.LegacySnowflake != nil { + return a.LegacySnowflake + } + return a.Snowflake +} + +func (a *App) appStateDatabaseURL() string { + if a == nil || a.Config == nil { + return "" + } + if dsn := strings.TrimSpace(a.Config.JobDatabaseURL); dsn != "" { + return dsn + } + if strings.EqualFold(strings.TrimSpace(a.Config.WarehouseBackend), "postgres") { + return strings.TrimSpace(a.Config.WarehousePostgresDSN) + } + return "" +} + +func (a *App) initAppStateDB(ctx context.Context) error { + dsn := a.appStateDatabaseURL() + if dsn == "" { + return nil + } + + db, err := sql.Open("pgx", dsn) + if err != nil { + return fmt.Errorf("open app-state database: %w", err) + } + db.SetMaxOpenConns(4) + db.SetMaxIdleConns(4) + if err := db.PingContext(ctx); err != nil { + _ = db.Close() + return fmt.Errorf("ping app-state database: %w", err) + } + + ensure := []func(context.Context) error{ + findings.NewPostgresStore(db).EnsureSchema, + agents.NewPostgresSessionStore(db).EnsureSchema, + appstate.NewAuditRepository(db).EnsureSchema, + appstate.NewPolicyHistoryRepository(db).EnsureSchema, + appstate.NewRiskEngineStateRepository(db).EnsureSchema, + } + for 
_, ensureFn := range ensure { + if err := ensureFn(ctx); err != nil { + _ = db.Close() + return err + } + } + + a.appStateDB = db + return nil +} + +func (a *App) migrateAppState(ctx context.Context) error { + if a == nil || a.appStateDB == nil || a.appStateMigrationSnowflake() == nil { + return nil + } + if err := a.migrateFindings(ctx); err != nil { + return err + } + if err := a.migrateAgentSessions(ctx); err != nil { + return err + } + if err := a.migrateAuditLogs(ctx); err != nil { + return err + } + if err := a.migratePolicyHistory(ctx); err != nil { + return err + } + if err := a.migrateRiskEngineState(ctx); err != nil { + return err + } + return nil +} + +func (a *App) migrateFindings(ctx context.Context) error { + source := a.appStateMigrationSnowflake() + store, ok := a.Findings.(*findings.PostgresStore) + if !ok || source == nil { + return nil + } + records, err := snowflake.NewFindingRepository(source).ListAll(ctx) + if err != nil { + if isMissingSnowflakeTableErr(err) { + return nil + } + return fmt.Errorf("migrate findings from snowflake: %w", err) + } + return store.ImportRecords(ctx, records) +} + +func (a *App) migrateAgentSessions(ctx context.Context) error { + sourceClient := a.appStateMigrationSnowflake() + if a.appStateDB == nil || sourceClient == nil { + return nil + } + source, err := agents.NewSnowflakeSessionStore(sourceClient) + if err != nil { + return fmt.Errorf("initialize snowflake session store: %w", err) + } + sessions, err := source.ListAll(ctx) + if err != nil { + return fmt.Errorf("list snowflake agent sessions: %w", err) + } + destination := agents.NewPostgresSessionStore(a.appStateDB) + for _, session := range sessions { + if err := destination.Save(ctx, session); err != nil { + return fmt.Errorf("persist postgres agent session %s: %w", session.ID, err) + } + } + return nil +} + +func (a *App) migrateAuditLogs(ctx context.Context) error { + source := a.appStateMigrationSnowflake() + if a.AuditRepo == nil || source == nil { + 
return nil + } + entries, err := snowflake.NewAuditRepository(source).ListAll(ctx) + if err != nil { + if isMissingSnowflakeTableErr(err) { + return nil + } + return fmt.Errorf("list snowflake audit logs: %w", err) + } + for _, entry := range entries { + if err := a.AuditRepo.Log(ctx, entry); err != nil { + return fmt.Errorf("persist audit log %s: %w", entry.ID, err) + } + } + return nil +} + +func (a *App) migratePolicyHistory(ctx context.Context) error { + source := a.appStateMigrationSnowflake() + if a.PolicyHistoryRepo == nil || source == nil { + return nil + } + records, err := snowflake.NewPolicyHistoryRepository(source).ListAll(ctx) + if err != nil { + if isMissingSnowflakeTableErr(err) { + return nil + } + return fmt.Errorf("list snowflake policy history: %w", err) + } + for _, record := range records { + if err := a.PolicyHistoryRepo.Upsert(ctx, record); err != nil { + return fmt.Errorf("persist policy history %s@%d: %w", record.PolicyID, record.Version, err) + } + } + return nil +} + +func (a *App) migrateRiskEngineState(ctx context.Context) error { + source := a.appStateMigrationSnowflake() + if a.RiskEngineStateRepo == nil || source == nil { + return nil + } + existing, err := a.RiskEngineStateRepo.LoadSnapshot(ctx, appStateRiskEngineGraphID) + if err != nil { + return fmt.Errorf("load postgres risk engine state: %w", err) + } + if len(existing) > 0 { + return nil + } + payload, err := snowflake.NewRiskEngineStateRepository(source).LoadSnapshot(ctx, appStateRiskEngineGraphID) + if err != nil { + return fmt.Errorf("load snowflake risk engine state: %w", err) + } + if len(payload) == 0 { + return nil + } + if err := a.RiskEngineStateRepo.SaveSnapshot(ctx, appStateRiskEngineGraphID, payload); err != nil { + return fmt.Errorf("persist postgres risk engine state: %w", err) + } + return nil +} + +func isMissingSnowflakeTableErr(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(err.Error()) + return strings.Contains(message, 
"does not exist") || + strings.Contains(message, "unknown table") || + strings.Contains(message, "not exist") +} diff --git a/internal/app/app_storage_graph.go b/internal/app/app_storage_graph.go index b66a1aaaf..e1aa04cf0 100644 --- a/internal/app/app_storage_graph.go +++ b/internal/app/app_storage_graph.go @@ -6,24 +6,28 @@ import ( "sort" "time" + "github.com/writer/cerebro/internal/appstate" "github.com/writer/cerebro/internal/policy" "github.com/writer/cerebro/internal/scanner" "github.com/writer/cerebro/internal/snowflake" ) func (a *App) initRepositories() { - a.FindingsRepo = nil - a.TicketsRepo = nil a.AuditRepo = nil a.PolicyHistoryRepo = nil a.RiskEngineStateRepo = nil a.RetentionRepo = nil + if a.appStateDB != nil { + a.AuditRepo = appstate.NewAuditRepository(a.appStateDB) + a.PolicyHistoryRepo = appstate.NewPolicyHistoryRepository(a.appStateDB) + a.RiskEngineStateRepo = appstate.NewRiskEngineStateRepository(a.appStateDB) + a.RetentionRepo = appstate.NewRetentionRepository(a.appStateDB) + return + } if a.Snowflake == nil { return } - a.FindingsRepo = snowflake.NewFindingRepository(a.Snowflake) - a.TicketsRepo = snowflake.NewTicketRepository(a.Snowflake) a.AuditRepo = snowflake.NewAuditRepository(a.Snowflake) a.PolicyHistoryRepo = snowflake.NewPolicyHistoryRepository(a.Snowflake) a.RiskEngineStateRepo = snowflake.NewRiskEngineStateRepository(a.Snowflake) @@ -275,6 +279,11 @@ func (a *App) Close() error { errs = append(errs, fmt.Errorf("snowflake: %w", err)) } } + if a.LegacySnowflake != nil && a.LegacySnowflake != a.Snowflake { + if err := a.LegacySnowflake.Close(); err != nil { + errs = append(errs, fmt.Errorf("legacy snowflake: %w", err)) + } + } if closer, ok := a.Warehouse.(interface{ Close() error }); ok { if a.Snowflake == nil || any(a.Warehouse) != any(a.Snowflake) { if err := closer.Close(); err != nil { @@ -315,6 +324,11 @@ func (a *App) Close() error { errs = append(errs, fmt.Errorf("findings store: %w", err)) } } + if a.appStateDB != nil { + 
if err := a.appStateDB.Close(); err != nil { + errs = append(errs, fmt.Errorf("app-state database: %w", err)) + } + } if a.Webhooks != nil { if err := a.Webhooks.Close(); err != nil { diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 55e98b466..10c928c7e 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -761,6 +761,20 @@ func TestLoadConfig_DefaultsToSnowflakeBackendWhenAuthPresent(t *testing.T) { } } +func TestLoadConfig_DoesNotDefaultWarehouseToPostgresWhenOnlyJobDatabaseConfigured(t *testing.T) { + t.Setenv("WAREHOUSE_BACKEND", "") + t.Setenv("WAREHOUSE_POSTGRES_DSN", "") + t.Setenv("JOB_DATABASE_URL", "postgres://jobs") + t.Setenv("SNOWFLAKE_ACCOUNT", "") + t.Setenv("SNOWFLAKE_USER", "") + t.Setenv("SNOWFLAKE_PRIVATE_KEY", "") + + cfg := LoadConfig() + if cfg.WarehouseBackend != "sqlite" { + t.Fatalf("expected warehouse backend to remain sqlite when only JOB_DATABASE_URL is set, got %q", cfg.WarehouseBackend) + } +} + func TestNew_APIAuthEnabledWithoutKeys(t *testing.T) { t.Setenv("API_AUTH_ENABLED", "true") t.Setenv("API_KEYS", "") diff --git a/internal/app/health_test.go b/internal/app/health_test.go index 073a8af60..e24c9172f 100644 --- a/internal/app/health_test.go +++ b/internal/app/health_test.go @@ -78,32 +78,32 @@ func TestGraphHealthSnapshotAggregatesLiveGraphPersistenceAndTiers(t *testing.T) } } -func TestGraphHealthSnapshotFallsBackToPersistedSnapshot(t *testing.T) { +func TestGraphHealthSnapshotUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { now := time.Date(2026, time.March, 20, 18, 0, 0, 0, time.UTC) - persisted := graph.New() - persisted.AddNode(&graph.Node{ID: "service:persisted", Kind: graph.NodeKindService}) - persisted.SetMetadata(graph.Metadata{ + configured := graph.New() + configured.AddNode(&graph.Node{ID: "service:configured", Kind: graph.NodeKindService}) + configured.SetMetadata(graph.Metadata{ BuiltAt: now.Add(-time.Hour), - NodeCount: persisted.NodeCount(), - EdgeCount: 
persisted.EdgeCount(), + NodeCount: configured.NodeCount(), + EdgeCount: configured.EdgeCount(), }) - store := mustPersistToolGraph(t, persisted) - application := &App{GraphSnapshots: store} + application := &App{} + setConfiguredGraphFromGraph(t, application, configured) snapshot := application.GraphHealthSnapshot(now) if snapshot.NodeCount != 1 || snapshot.EdgeCount != 0 { t.Fatalf("snapshot counts = (%d,%d), want (1,0)", snapshot.NodeCount, snapshot.EdgeCount) } - if snapshot.SnapshotCount != 1 { - t.Fatalf("SnapshotCount = %d, want 1", snapshot.SnapshotCount) + if snapshot.SnapshotCount != 0 { + t.Fatalf("SnapshotCount = %d, want 0", snapshot.SnapshotCount) } - if snapshot.TierDistribution.Hot != 0 || snapshot.TierDistribution.Warm != 0 || snapshot.TierDistribution.Cold != 1 { - t.Fatalf("TierDistribution = %+v, want hot=0 warm=0 cold=1", snapshot.TierDistribution) + if snapshot.TierDistribution.Hot != 0 || snapshot.TierDistribution.Warm != 0 || snapshot.TierDistribution.Cold != 0 { + t.Fatalf("TierDistribution = %+v, want hot=0 warm=0 cold=0", snapshot.TierDistribution) } if snapshot.LastMutationAt.IsZero() { - t.Fatal("expected last mutation timestamp from persisted snapshot") + t.Fatal("expected last mutation timestamp from configured store") } } diff --git a/internal/app/org_topology_policy_scan_test.go b/internal/app/org_topology_policy_scan_test.go index 2f99e8965..a817a9e5c 100644 --- a/internal/app/org_topology_policy_scan_test.go +++ b/internal/app/org_topology_policy_scan_test.go @@ -128,7 +128,7 @@ func TestScanOrgTopologyPolicies_EmptyWhenGraphUnavailable(t *testing.T) { } } -func TestScanOrgTopologyPolicies_UsesPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestScanOrgTopologyPolicies_UsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { engine := policy.NewEngine() addOrgTestPolicy(t, engine, &policy.Policy{ ID: "org-bus-factor-critical", @@ -142,31 +142,18 @@ func 
TestScanOrgTopologyPolicies_UsesPersistedSnapshotWhenLiveGraphUnavailable(t }, }) - store, err := graph.NewGraphPersistenceStore(graph.GraphPersistenceOptions{ - LocalPath: filepath.Join(t.TempDir(), "graph-snapshots"), - MaxSnapshots: 4, - }) - if err != nil { - t.Fatalf("NewGraphPersistenceStore() error = %v", err) - } - if _, err := store.SaveGraph(orgTopologyTestGraph(time.Now().UTC())); err != nil { - t.Fatalf("SaveGraph() error = %v", err) - } - - app := &App{ - Policy: engine, - GraphSnapshots: store, - } + app := &App{Policy: engine} + setConfiguredSnapshotGraphFromGraph(t, app, orgTopologyTestGraph(time.Now().UTC())) result := app.ScanOrgTopologyPolicies(context.Background()) if len(result.Errors) != 0 { t.Fatalf("expected no scan errors, got %v", result.Errors) } if result.Assets == 0 { - t.Fatal("expected synthesized org-topology assets from persisted snapshot") + t.Fatal("expected synthesized org-topology assets from configured graph") } if !slices.ContainsFunc(result.Findings, func(f policy.Finding) bool { return f.PolicyID == "org-bus-factor-critical" }) { - t.Fatalf("expected persisted snapshot finding, got %v", result.Findings) + t.Fatalf("expected configured graph finding, got %v", result.Findings) } } diff --git a/internal/app/tenant_shard_manager_test.go b/internal/app/tenant_shard_manager_test.go index 6f98090b2..58f5f26d4 100644 --- a/internal/app/tenant_shard_manager_test.go +++ b/internal/app/tenant_shard_manager_test.go @@ -156,18 +156,9 @@ func TestEnsureTenantSecurityGraphShardsDoesNotWaitOnSecurityGraphLock(t *testin application.securityGraphInitMu.Unlock() } -func TestCurrentSecurityGraphForTenantHydratesFromPersistedSnapshotWhenLiveGraphUnavailable(t *testing.T) { +func TestCurrentSecurityGraphForTenantUsesConfiguredStoreWhenLiveGraphUnavailable(t *testing.T) { basePath := filepath.Join(t.TempDir(), "graph-snapshots") - store, err := graph.NewGraphPersistenceStore(graph.GraphPersistenceOptions{LocalPath: basePath, MaxSnapshots: 4}) - if 
err != nil { - t.Fatalf("NewGraphPersistenceStore() error = %v", err) - } - live := buildTenantShardTestGraph(time.Date(2026, time.March, 17, 22, 0, 0, 0, time.UTC)) - if _, err := store.SaveGraph(live); err != nil { - t.Fatalf("SaveGraph() error = %v", err) - } - application := &App{ Config: &Config{ GraphSnapshotPath: basePath, @@ -175,12 +166,12 @@ func TestCurrentSecurityGraphForTenantHydratesFromPersistedSnapshotWhenLiveGraph GraphTenantWarmShardTTL: time.Hour, GraphTenantWarmShardMaxRetained: 1, }, - GraphSnapshots: store, } + setConfiguredSnapshotGraphFromGraph(t, application, live) scoped := application.CurrentSecurityGraphForTenant("tenant-a") if scoped == nil { - t.Fatal("expected tenant shard recovered from persisted snapshot") + t.Fatal("expected tenant shard resolved from configured graph") } if _, ok := scoped.GetNode("service:tenant-a"); !ok { t.Fatal("expected tenant shard to include tenant-a node") diff --git a/internal/appstate/postgres.go b/internal/appstate/postgres.go new file mode 100644 index 000000000..249abb99d --- /dev/null +++ b/internal/appstate/postgres.go @@ -0,0 +1,483 @@ +package appstate + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + + "github.com/writer/cerebro/internal/snowflake" +) + +const ( + auditTable = "cerebro_audit_log" + policyHistoryTable = "cerebro_policy_history" + riskEngineStateTable = "cerebro_risk_engine_state" + sessionTable = "cerebro_agent_sessions" +) + +type AuditRepository struct { + db *sql.DB + rewriteSQL func(string) string +} + +func NewAuditRepository(db *sql.DB) *AuditRepository { + return &AuditRepository{db: db} +} + +func (r *AuditRepository) EnsureSchema(ctx context.Context) error { + if r == nil || r.db == nil { + return fmt.Errorf("audit repository is not initialized") + } + _, err := r.db.ExecContext(ctx, r.q(` +CREATE TABLE IF NOT EXISTS `+auditTable+` ( + id TEXT PRIMARY KEY, + created_at TIMESTAMP NOT NULL, + 
action TEXT NOT NULL, + actor_id TEXT, + actor_type TEXT, + resource_type TEXT, + resource_id TEXT, + details TEXT NOT NULL DEFAULT '{}', + ip_address TEXT, + user_agent TEXT +); +CREATE INDEX IF NOT EXISTS idx_`+auditTable+`_resource ON `+auditTable+` (resource_type, resource_id, created_at); +CREATE INDEX IF NOT EXISTS idx_`+auditTable+`_created_at ON `+auditTable+` (created_at); +`)) + return err +} + +func (r *AuditRepository) Log(ctx context.Context, entry *snowflake.AuditEntry) error { + if r == nil || r.db == nil { + return fmt.Errorf("audit repository is not initialized") + } + if entry == nil { + return fmt.Errorf("audit entry is required") + } + if err := r.EnsureSchema(ctx); err != nil { + return err + } + if strings.TrimSpace(entry.ID) == "" { + entry.ID = uuid.NewString() + } + + detailsJSON, err := marshalJSONText(entry.Details, "{}") + if err != nil { + return err + } + + createdAt := entry.Timestamp.UTC() + if createdAt.IsZero() { + createdAt = time.Now().UTC() + entry.Timestamp = createdAt + } + + _, err = r.db.ExecContext(ctx, r.q(` +INSERT INTO `+auditTable+` ( + id, created_at, action, actor_id, actor_type, resource_type, resource_id, details, ip_address, user_agent +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) +ON CONFLICT (id) DO UPDATE SET + created_at = EXCLUDED.created_at, + action = EXCLUDED.action, + actor_id = EXCLUDED.actor_id, + actor_type = EXCLUDED.actor_type, + resource_type = EXCLUDED.resource_type, + resource_id = EXCLUDED.resource_id, + details = EXCLUDED.details, + ip_address = EXCLUDED.ip_address, + user_agent = EXCLUDED.user_agent +`), + entry.ID, + createdAt, + entry.Action, + entry.ActorID, + entry.ActorType, + entry.ResourceType, + entry.ResourceID, + detailsJSON, + entry.IPAddress, + entry.UserAgent, + ) + return err +} + +func (r *AuditRepository) List(ctx context.Context, resourceType, resourceID string, limit int) ([]*snowflake.AuditEntry, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("audit 
repository is not initialized") + } + if err := r.EnsureSchema(ctx); err != nil { + return nil, err + } + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + query := ` +SELECT id, created_at, action, actor_id, actor_type, resource_type, resource_id, details, ip_address, user_agent +FROM ` + auditTable + ` +WHERE 1=1` + args := make([]any, 0, 3) + if trimmed := strings.TrimSpace(resourceType); trimmed != "" { + args = append(args, trimmed) + query += fmt.Sprintf(" AND resource_type = %s", placeholder(len(args))) + } + if trimmed := strings.TrimSpace(resourceID); trimmed != "" { + args = append(args, trimmed) + query += fmt.Sprintf(" AND resource_id = %s", placeholder(len(args))) + } + args = append(args, limit) + query += fmt.Sprintf(" ORDER BY created_at DESC LIMIT %s", placeholder(len(args))) + + rows, err := r.db.QueryContext(ctx, r.q(query), args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + entries := make([]*snowflake.AuditEntry, 0, limit) + for rows.Next() { + entry := &snowflake.AuditEntry{} + var detailsRaw string + if err := rows.Scan( + &entry.ID, + &entry.Timestamp, + &entry.Action, + &entry.ActorID, + &entry.ActorType, + &entry.ResourceType, + &entry.ResourceID, + &detailsRaw, + &entry.IPAddress, + &entry.UserAgent, + ); err != nil { + return nil, err + } + if strings.TrimSpace(detailsRaw) != "" { + if err := json.Unmarshal([]byte(detailsRaw), &entry.Details); err != nil { + return nil, err + } + } + entries = append(entries, entry) + } + return entries, rows.Err() +} + +func (r *AuditRepository) q(query string) string { + if r != nil && r.rewriteSQL != nil { + return r.rewriteSQL(query) + } + return query +} + +type PolicyHistoryRepository struct { + db *sql.DB + rewriteSQL func(string) string +} + +func NewPolicyHistoryRepository(db *sql.DB) *PolicyHistoryRepository { + return &PolicyHistoryRepository{db: db} +} + +func (r *PolicyHistoryRepository) EnsureSchema(ctx context.Context) 
error { + if r == nil || r.db == nil { + return fmt.Errorf("policy history repository is not initialized") + } + _, err := r.db.ExecContext(ctx, r.q(` +CREATE TABLE IF NOT EXISTS `+policyHistoryTable+` ( + policy_id TEXT NOT NULL, + version INTEGER NOT NULL, + content TEXT NOT NULL DEFAULT '{}', + change_type TEXT, + pinned_version INTEGER, + effective_from TIMESTAMP NOT NULL, + effective_to TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (policy_id, version) +); +CREATE INDEX IF NOT EXISTS idx_`+policyHistoryTable+`_policy ON `+policyHistoryTable+` (policy_id, version DESC); +`)) + return err +} + +func (r *PolicyHistoryRepository) Upsert(ctx context.Context, record *snowflake.PolicyHistoryRecord) error { + if r == nil || r.db == nil { + return fmt.Errorf("policy history repository is not initialized") + } + if record == nil { + return fmt.Errorf("policy history record is required") + } + if strings.TrimSpace(record.PolicyID) == "" { + return fmt.Errorf("policy id is required") + } + if record.Version <= 0 { + return fmt.Errorf("policy version must be positive") + } + if err := r.EnsureSchema(ctx); err != nil { + return err + } + + content := string(record.Content) + if strings.TrimSpace(content) == "" { + content = "{}" + } + effectiveFrom := record.EffectiveFrom.UTC() + if effectiveFrom.IsZero() { + effectiveFrom = time.Now().UTC() + } + + var pinnedVersion any + if record.PinnedVersion != nil { + pinnedVersion = *record.PinnedVersion + } + var effectiveTo any + if record.EffectiveTo != nil { + effectiveTo = record.EffectiveTo.UTC() + } + + _, err := r.db.ExecContext(ctx, r.q(` +INSERT INTO `+policyHistoryTable+` ( + policy_id, version, content, change_type, pinned_version, effective_from, effective_to +) VALUES ($1, $2, $3, $4, $5, $6, $7) +ON CONFLICT (policy_id, version) DO UPDATE SET + content = EXCLUDED.content, + change_type = EXCLUDED.change_type, + pinned_version = EXCLUDED.pinned_version, + effective_from = 
EXCLUDED.effective_from, + effective_to = EXCLUDED.effective_to +`), + record.PolicyID, + record.Version, + content, + record.ChangeType, + pinnedVersion, + effectiveFrom, + effectiveTo, + ) + return err +} + +func (r *PolicyHistoryRepository) List(ctx context.Context, policyID string, limit int) ([]*snowflake.PolicyHistoryRecord, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("policy history repository is not initialized") + } + policyID = strings.TrimSpace(policyID) + if policyID == "" { + return nil, fmt.Errorf("policy id is required") + } + if err := r.EnsureSchema(ctx); err != nil { + return nil, err + } + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + rows, err := r.db.QueryContext(ctx, r.q(` +SELECT policy_id, version, content, change_type, pinned_version, effective_from, effective_to +FROM `+policyHistoryTable+` +WHERE policy_id = $1 +ORDER BY version DESC +LIMIT $2 +`), policyID, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + records := make([]*snowflake.PolicyHistoryRecord, 0, limit) + for rows.Next() { + record := &snowflake.PolicyHistoryRecord{} + var content string + var changeType sql.NullString + var pinned sql.NullInt64 + var effectiveTo sql.NullTime + if err := rows.Scan( + &record.PolicyID, + &record.Version, + &content, + &changeType, + &pinned, + &record.EffectiveFrom, + &effectiveTo, + ); err != nil { + return nil, err + } + record.Content = json.RawMessage(content) + if changeType.Valid { + record.ChangeType = changeType.String + } + if pinned.Valid { + pinnedValue := int(pinned.Int64) + record.PinnedVersion = &pinnedValue + } + if effectiveTo.Valid { + ts := effectiveTo.Time.UTC() + record.EffectiveTo = &ts + } + records = append(records, record) + } + return records, rows.Err() +} + +func (r *PolicyHistoryRepository) q(query string) string { + if r != nil && r.rewriteSQL != nil { + return r.rewriteSQL(query) + } + return query +} + +type 
RiskEngineStateRepository struct { + db *sql.DB + rewriteSQL func(string) string +} + +func NewRiskEngineStateRepository(db *sql.DB) *RiskEngineStateRepository { + return &RiskEngineStateRepository{db: db} +} + +func (r *RiskEngineStateRepository) EnsureSchema(ctx context.Context) error { + if r == nil || r.db == nil { + return fmt.Errorf("risk engine state repository is not initialized") + } + _, err := r.db.ExecContext(ctx, r.q(` +CREATE TABLE IF NOT EXISTS `+riskEngineStateTable+` ( + graph_id TEXT PRIMARY KEY, + snapshot TEXT NOT NULL DEFAULT '{}', + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +`)) + return err +} + +func (r *RiskEngineStateRepository) SaveSnapshot(ctx context.Context, graphID string, snapshot []byte) error { + if r == nil || r.db == nil { + return fmt.Errorf("risk engine state repository is not initialized") + } + graphID = strings.TrimSpace(graphID) + if graphID == "" { + return fmt.Errorf("graph id is required") + } + if len(snapshot) == 0 { + snapshot = []byte("{}") + } + if err := r.EnsureSchema(ctx); err != nil { + return err + } + + _, err := r.db.ExecContext(ctx, r.q(` +INSERT INTO `+riskEngineStateTable+` (graph_id, snapshot, updated_at) +VALUES ($1, $2, $3) +ON CONFLICT (graph_id) DO UPDATE SET + snapshot = EXCLUDED.snapshot, + updated_at = EXCLUDED.updated_at +`), graphID, string(snapshot), time.Now().UTC()) + return err +} + +func (r *RiskEngineStateRepository) LoadSnapshot(ctx context.Context, graphID string) ([]byte, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("risk engine state repository is not initialized") + } + graphID = strings.TrimSpace(graphID) + if graphID == "" { + return nil, fmt.Errorf("graph id is required") + } + if err := r.EnsureSchema(ctx); err != nil { + return nil, err + } + + var payload string + err := r.db.QueryRowContext(ctx, r.q(` +SELECT snapshot FROM `+riskEngineStateTable+` +WHERE graph_id = $1 +`), graphID).Scan(&payload) + if errors.Is(err, sql.ErrNoRows) { + return 
nil, nil
+	}
+	if err != nil {
+		return nil, err
+	}
+	if strings.TrimSpace(payload) == "" || strings.TrimSpace(payload) == "null" {
+		return nil, nil
+	}
+	return []byte(payload), nil
+}
+
+func (r *RiskEngineStateRepository) q(query string) string {
+	if r != nil && r.rewriteSQL != nil {
+		return r.rewriteSQL(query)
+	}
+	return query
+}
+
+type RetentionRepository struct {
+	db *sql.DB
+}
+
+func NewRetentionRepository(db *sql.DB) *RetentionRepository {
+	return &RetentionRepository{db: db}
+}
+
+func (r *RetentionRepository) CleanupAuditLogs(ctx context.Context, olderThan time.Time) (int64, error) {
+	return r.deleteBefore(ctx, auditTable, "created_at", olderThan)
+}
+
+func (r *RetentionRepository) CleanupAgentData(ctx context.Context, olderThan time.Time) (sessionsDeleted, messagesDeleted int64, err error) {
+	sessionsDeleted, err = r.deleteBefore(ctx, sessionTable, "updated_at", olderThan)
+	if err != nil {
+		return 0, 0, err
+	}
+	return sessionsDeleted, 0, nil
+}
+
+func (r *RetentionRepository) CleanupGraphData(context.Context, time.Time) (int64, int64, int64, error) {
+	return 0, 0, 0, nil
+}
+
+func (r *RetentionRepository) CleanupAccessReviewData(context.Context, time.Time) (int64, int64, error) {
+	return 0, 0, nil
+}
+
+func (r *RetentionRepository) deleteBefore(ctx context.Context, table, timeColumn string, olderThan time.Time) (int64, error) {
+	if r == nil || r.db == nil {
+		return 0, fmt.Errorf("retention repository is not initialized")
+	}
+	if olderThan.IsZero() {
+		return 0, fmt.Errorf("retention cutoff is required")
+	}
+	result, err := r.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE %s < $1`, table, timeColumn), olderThan.UTC())
+	if err != nil {
+		return 0, err
+	}
+	rowsAffected, err := result.RowsAffected()
+	if err != nil {
+		return 0, err
+	}
+	return rowsAffected, nil
+}
+
+func marshalJSONText(value any, fallback string) (string, error) {
+	encoded, err := json.Marshal(value)
+	if err != nil {
+		return "", err
+	}
+	if trimmed 
:= strings.TrimSpace(string(encoded)); trimmed != "" && trimmed != "null" { + return trimmed, nil + } + return fallback, nil +} + +func placeholder(idx int) string { + return fmt.Sprintf("$%d", idx) +} diff --git a/internal/appstate/postgres_test.go b/internal/appstate/postgres_test.go new file mode 100644 index 000000000..5944527b5 --- /dev/null +++ b/internal/appstate/postgres_test.go @@ -0,0 +1,135 @@ +package appstate + +import ( + "context" + "database/sql" + "encoding/json" + "regexp" + "testing" + "time" + + _ "modernc.org/sqlite" + + "github.com/writer/cerebro/internal/snowflake" +) + +var appStateDollarPlaceholderRe = regexp.MustCompile(`\$\d+`) + +func appStateSQLiteRewrite(query string) string { + return appStateDollarPlaceholderRe.ReplaceAllString(query, "?") +} + +func newTestAuditRepository(t *testing.T) *AuditRepository { + t.Helper() + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + repo := NewAuditRepository(db) + repo.rewriteSQL = appStateSQLiteRewrite + if err := repo.EnsureSchema(context.Background()); err != nil { + _ = db.Close() + t.Fatalf("EnsureSchema() error = %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return repo +} + +func TestAuditRepositoryLogAndList(t *testing.T) { + repo := newTestAuditRepository(t) + entry := &snowflake.AuditEntry{ + ID: "audit-1", + Timestamp: time.Now().UTC().Truncate(time.Second), + Action: "policy.evaluate", + ActorID: "user-1", + ActorType: "user", + ResourceType: "policy", + ResourceID: "policy-1", + Details: map[string]interface{}{"decision": "allow"}, + IPAddress: "127.0.0.1", + UserAgent: "test", + } + if err := repo.Log(context.Background(), entry); err != nil { + t.Fatalf("Log() error = %v", err) + } + + entries, err := repo.List(context.Background(), "policy", "policy-1", 10) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(entries) != 1 { + t.Fatalf("len(entries) = %d, want 1", len(entries)) + } + if entries[0].Action != 
entry.Action { + t.Fatalf("Action = %q, want %q", entries[0].Action, entry.Action) + } + if got := entries[0].Details["decision"]; got != "allow" { + t.Fatalf("decision detail = %#v, want allow", got) + } +} + +func TestPolicyHistoryRepositoryUpsertAndList(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + repo := NewPolicyHistoryRepository(db) + repo.rewriteSQL = appStateSQLiteRewrite + if err := repo.EnsureSchema(context.Background()); err != nil { + _ = db.Close() + t.Fatalf("EnsureSchema() error = %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + content, err := json.Marshal(map[string]any{"id": "policy-1"}) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + record := &snowflake.PolicyHistoryRecord{ + PolicyID: "policy-1", + Version: 2, + Content: content, + ChangeType: "updated", + EffectiveFrom: time.Now().UTC(), + } + if err := repo.Upsert(context.Background(), record); err != nil { + t.Fatalf("Upsert() error = %v", err) + } + + records, err := repo.List(context.Background(), "policy-1", 10) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("len(records) = %d, want 1", len(records)) + } + if records[0].Version != 2 { + t.Fatalf("Version = %d, want 2", records[0].Version) + } +} + +func TestRiskEngineStateRepositorySaveAndLoad(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + repo := NewRiskEngineStateRepository(db) + repo.rewriteSQL = appStateSQLiteRewrite + if err := repo.EnsureSchema(context.Background()); err != nil { + _ = db.Close() + t.Fatalf("EnsureSchema() error = %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + payload := []byte(`{"score":42}`) + if err := repo.SaveSnapshot(context.Background(), "security-graph", payload); err != nil { + t.Fatalf("SaveSnapshot() error = %v", err) + } + got, err := repo.LoadSnapshot(context.Background(), 
"security-graph") + if err != nil { + t.Fatalf("LoadSnapshot() error = %v", err) + } + if string(got) != string(payload) { + t.Fatalf("payload = %s, want %s", got, payload) + } +} diff --git a/internal/cli/aws_org.go b/internal/cli/aws_org.go index 44f53ce40..8d1b21038 100644 --- a/internal/cli/aws_org.go +++ b/internal/cli/aws_org.go @@ -250,15 +250,15 @@ func runAWSOrgSyncDirect(ctx context.Context, start time.Time) error { accountCfg = assumedCfg } - sfClient, err := createSnowflakeClient() + sfClient, err := createSyncWarehouse() if err != nil { mu.Lock() - errs = append(errs, fmt.Errorf("account %s: create snowflake client: %w", account.ID, err)) + errs = append(errs, fmt.Errorf("account %s: create warehouse client: %w", account.ID, err)) mu.Unlock() - Warning("Failed to create Snowflake client for account %s: %v", account.ID, err) + Warning("Failed to create warehouse client for account %s: %v", account.ID, err) return nil } - defer func() { _ = sfClient.Close() }() + defer closeSyncWarehouse(sfClient) syncer := nativesync.NewSyncEngine(sfClient, slog.Default(), opts...) accountResults, syncErr := syncer.SyncAllWithConfig(ctx, accountCfg) @@ -307,11 +307,11 @@ func runAWSOrgValidation(ctx context.Context, start time.Time, cfg aws.Config, r Info("Filtering AWS tables: %s", strings.Join(tableFilter, ", ")) } - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) opts := buildAWSEngineOptions(region, tableFilter) syncer := nativesync.NewSyncEngine(client, slog.Default(), opts...) 
diff --git a/internal/cli/serve.go b/internal/cli/serve.go index a6bd908a7..26db98490 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -99,12 +99,11 @@ func runServe(cmd *cobra.Command, args []string) error { return nil }, func() error { - // Sync any dirty findings to Snowflake before shutdown - if application.SnowflakeFindings != nil { - application.Logger.Info("syncing findings to snowflake before shutdown") + if syncer, ok := application.Findings.(interface{ Sync(context.Context) error }); ok { + application.Logger.Info("syncing findings before shutdown") syncCtx, syncCancel := context.WithTimeout(context.Background(), 10*time.Second) defer syncCancel() - return application.SnowflakeFindings.Sync(syncCtx) + return syncer.Sync(syncCtx) } return nil }, diff --git a/internal/cli/status.go b/internal/cli/status.go index 344f7d723..344b07cea 100644 --- a/internal/cli/status.go +++ b/internal/cli/status.go @@ -17,7 +17,7 @@ var statusCmd = &cobra.Command{ Use: "status", Short: "Show system status and health", Long: `Display the current status of Cerebro, including: -- Snowflake connection status +- Warehouse connection status - Loaded policies count - Findings summary - Registered agents and providers`, @@ -78,20 +78,20 @@ func runStatusDirect(cmd *cobra.Command, args []string) error { "timestamp": time.Now().UTC(), } - // Snowflake status - sfStatus := map[string]interface{}{"configured": false} - if application.Snowflake != nil { - sfStatus["configured"] = true + // Warehouse status + warehouseStatus := map[string]interface{}{"configured": false} + if application.Warehouse != nil && application.Warehouse.DB() != nil { + warehouseStatus["configured"] = true start := time.Now() - if err := application.Snowflake.Ping(ctx); err != nil { - sfStatus["status"] = "unhealthy" - sfStatus["error"] = err.Error() + if err := application.Warehouse.DB().PingContext(ctx); err != nil { + warehouseStatus["status"] = "unhealthy" + warehouseStatus["error"] = 
err.Error() } else { - sfStatus["status"] = "healthy" - sfStatus["latency_ms"] = time.Since(start).Milliseconds() + warehouseStatus["status"] = "healthy" + warehouseStatus["latency_ms"] = time.Since(start).Milliseconds() } } - status["snowflake"] = sfStatus + status["warehouse"] = warehouseStatus // Policies policies := application.Policy.ListPolicies() @@ -134,8 +134,8 @@ func renderStatus(status map[string]interface{}) error { fmt.Println(strings.Repeat("=", 50)) fmt.Println() - fmt.Println(bold("Snowflake")) - sf := statusSection(status, "snowflake") + fmt.Println(bold("Warehouse")) + sf := statusSection(status, "warehouse") sfStatus := strings.ToLower(strings.TrimSpace(statusString(sf["status"]))) configured := statusBool(sf["configured"]) || (sfStatus != "" && sfStatus != "not_configured") if configured { diff --git a/internal/cli/sync.go b/internal/cli/sync.go index 1915807b6..67c5d0e93 100644 --- a/internal/cli/sync.go +++ b/internal/cli/sync.go @@ -22,12 +22,13 @@ import ( "github.com/writer/cerebro/internal/scanner" "github.com/writer/cerebro/internal/snowflake" nativesync "github.com/writer/cerebro/internal/sync" + "github.com/writer/cerebro/internal/warehouse" ) var syncCmd = &cobra.Command{ Use: "sync", - Short: "Sync cloud assets to Snowflake", - Long: `Sync cloud assets from AWS, GCP, Azure, or Kubernetes to Snowflake using Cerebro's native scanners. + Short: "Sync cloud assets to the warehouse", + Long: `Sync cloud assets from AWS, GCP, Azure, or Kubernetes to the configured warehouse using Cerebro's native scanners. Examples: cerebro sync # Sync AWS (default) @@ -41,7 +42,7 @@ Examples: var syncBackfillRelationshipsCmd = &cobra.Command{ Use: "backfill-relationships", - Short: "Normalize relationship IDs in Snowflake", + Short: "Normalize relationship IDs in the warehouse", Long: `Normalize existing relationship IDs to remove JSON/map wrappers. 
This command re-computes relationship IDs from normalized source/target IDs and @@ -681,11 +682,11 @@ func runBackfillRelationshipsDirect(cmd *cobra.Command, args []string) error { ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancel() - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) extractor := nativesync.NewRelationshipExtractor(client, slog.Default()) stats, err := extractor.BackfillNormalizedRelationshipIDs(ctx, syncBackfillBatchSize) @@ -960,21 +961,50 @@ func isScannableTable(table string) bool { return true } -func createSnowflakeClient() (*snowflake.Client, error) { - cfg := snowflake.DSNConfigFromEnv() - if missing := cfg.MissingFields(); len(missing) > 0 { - return nil, fmt.Errorf("snowflake not configured: set %s", strings.Join(missing, ", ")) +func createSyncWarehouse() (warehouse.DataWarehouse, error) { + cfg := app.LoadConfig() + backend := strings.ToLower(strings.TrimSpace(cfg.WarehouseBackend)) + switch backend { + case "", "snowflake": + snowflakeCfg := snowflake.DSNConfigFromEnv() + if missing := snowflakeCfg.MissingFields(); len(missing) > 0 { + return nil, fmt.Errorf("warehouse not configured: set %s", strings.Join(missing, ", ")) + } + + return snowflake.NewClient(snowflake.ClientConfig{ + Account: snowflakeCfg.Account, + User: snowflakeCfg.User, + PrivateKey: snowflakeCfg.PrivateKey, + Database: snowflakeCfg.Database, + Schema: snowflakeCfg.Schema, + Warehouse: snowflakeCfg.Warehouse, + Role: snowflakeCfg.Role, + }) + case "postgres": + dsn := strings.TrimSpace(cfg.WarehousePostgresDSN) + if dsn == "" { + dsn = strings.TrimSpace(cfg.JobDatabaseURL) + } + return warehouse.NewPostgresWarehouse(warehouse.PostgresWarehouseConfig{ + DSN: dsn, + AppSchema: "cerebro", + 
}) + case "sqlite": + return warehouse.NewSQLiteWarehouse(warehouse.SQLiteWarehouseConfig{ + Path: strings.TrimSpace(cfg.WarehouseSQLitePath), + Database: "sqlite", + Schema: "RAW", + AppSchema: "CEREBRO", + }) + default: + return nil, fmt.Errorf("unsupported warehouse backend %q", cfg.WarehouseBackend) } +} - return snowflake.NewClient(snowflake.ClientConfig{ - Account: cfg.Account, - User: cfg.User, - PrivateKey: cfg.PrivateKey, - Database: cfg.Database, - Schema: cfg.Schema, - Warehouse: cfg.Warehouse, - Role: cfg.Role, - }) +func closeSyncWarehouse(store warehouse.DataWarehouse) { + if closer, ok := any(store).(interface{ Close() error }); ok { + _ = closer.Close() + } } func runPostSyncScan(ctx context.Context, tableFilter []string) error { @@ -986,13 +1016,13 @@ func runPostSyncScan(ctx context.Context, tableFilter []string) error { } defer func() { _ = application.Close() }() - if application.Snowflake == nil { - return fmt.Errorf("snowflake not configured: set SNOWFLAKE_PRIVATE_KEY, SNOWFLAKE_ACCOUNT, and SNOWFLAKE_USER") + if application.Warehouse == nil { + return fmt.Errorf("warehouse not configured") } availableTables := application.AvailableTables - if application.Snowflake != nil { - if refreshed, err := application.Snowflake.ListAvailableTables(ctx); err == nil { + if application.Warehouse != nil { + if refreshed, err := application.Warehouse.ListAvailableTables(ctx); err == nil { application.AvailableTables = refreshed availableTables = refreshed } else { @@ -1022,7 +1052,7 @@ func runPostSyncScan(ctx context.Context, tableFilter []string) error { tables, skipped := filterAvailableTables(tables, availableTables) if skipped > 0 { - Info("Skipped %d tables not present in Snowflake", skipped) + Info("Skipped %d tables not present in the warehouse", skipped) } fmt.Println("Scanning synced assets...") @@ -1077,7 +1107,7 @@ func runPostSyncScan(ctx context.Context, tableFilter []string) error { filter.Offset = offset } assets, attempts, err := 
scanner.WithRetryValue(tableCtx, tuning.RetryOptions, func() ([]map[string]interface{}, error) { - return application.Snowflake.GetAssets(tableCtx, table, filter) + return application.Warehouse.GetAssets(tableCtx, table, filter) }) tableProfile.RetryAttempts += retryCount(attempts) if err != nil { @@ -1169,14 +1199,14 @@ func runPostSyncScan(ctx context.Context, tableFilter []string) error { sqlToxicRiskSets := make(map[string][]map[string]bool) relationshipCount := 0 - if application.Snowflake != nil { + if scanner.SupportsRelationshipToxicDetection(application.Warehouse) { var toxicCursor *scanner.ToxicScanCursor if application.ScanWatermarks != nil { if wm := application.ScanWatermarks.GetWatermark("_toxic_relationships"); wm != nil { toxicCursor = &scanner.ToxicScanCursor{SinceTime: wm.LastScanTime, SinceID: wm.LastScanID} } } - toxicResult, err := scanner.DetectRelationshipToxicCombinations(ctx, application.Snowflake, toxicCursor) + toxicResult, err := scanner.DetectRelationshipToxicCombinations(ctx, application.Warehouse, toxicCursor) if err != nil { Warning("Failed to detect toxic combinations from relationships: %v", err) } else { diff --git a/internal/cli/sync_aws.go b/internal/cli/sync_aws.go index 63c85e61e..e6c68bcc9 100644 --- a/internal/cli/sync_aws.go +++ b/internal/cli/sync_aws.go @@ -255,10 +255,10 @@ func runMultiAccountAWSSyncDirect(ctx context.Context, start time.Time, profiles region = "us-east-1" } - sfClient, err := createSnowflakeClient() + sfClient, err := createSyncWarehouse() if err != nil { - Warning("Failed to create Snowflake client for profile %s: %v", profile, err) - syncErrs = append(syncErrs, fmt.Errorf("profile %s: create snowflake client: %w", profile, err)) + Warning("Failed to create warehouse client for profile %s: %v", profile, err) + syncErrs = append(syncErrs, fmt.Errorf("profile %s: create warehouse client: %w", profile, err)) continue } @@ -278,7 +278,7 @@ func runMultiAccountAWSSyncDirect(ctx context.Context, start 
time.Time, profiles syncer := nativesync.NewSyncEngine(sfClient, slog.Default(), opts...) results, err := syncer.SyncAllWithConfig(ctx, awsCfg) - _ = sfClient.Close() + closeSyncWarehouse(sfClient) totalResults = append(totalResults, results...) if err != nil { @@ -445,11 +445,11 @@ func runNativeSyncDirect(ctx context.Context, start time.Time) error { Info("Filtering AWS tables: %s", strings.Join(tableFilter, ", ")) } - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) opts := []nativesync.EngineOption{} if syncConcurrency > 0 { diff --git a/internal/cli/sync_azure.go b/internal/cli/sync_azure.go index fd02c0578..ad30a4769 100644 --- a/internal/cli/sync_azure.go +++ b/internal/cli/sync_azure.go @@ -90,11 +90,11 @@ func runAzureSyncDirect(ctx context.Context, start time.Time, tableFilter []stri if err != nil { return err } - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) opts := []nativesync.AzureEngineOption{} switch len(explicitSubscriptions) { diff --git a/internal/cli/sync_gcp.go b/internal/cli/sync_gcp.go index c791977e7..365cd31af 100644 --- a/internal/cli/sync_gcp.go +++ b/internal/cli/sync_gcp.go @@ -290,11 +290,11 @@ func runGCPSyncDirect( runSecuritySync bool, tableFilterSet map[string]struct{}, ) error { - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) if 
runNativeSync { if err := preflightGCPProjectAccessFn(ctx, gcpProjectPreflightSpec{ @@ -474,11 +474,11 @@ func runGCPMultiProjectSync(ctx context.Context, start time.Time, projects []str return fmt.Errorf("validation for GCP security-only table filters is not supported; include at least one native table") } - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) if syncValidate { if len(projects) == 0 { @@ -725,11 +725,11 @@ func runGCPAssetAPISyncDirect( runNativeSync bool, runSecuritySync bool, ) error { - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) var syncErrs []error diff --git a/internal/cli/sync_gcp_filter_test.go b/internal/cli/sync_gcp_filter_test.go index dac38c543..271a9578e 100644 --- a/internal/cli/sync_gcp_filter_test.go +++ b/internal/cli/sync_gcp_filter_test.go @@ -243,8 +243,8 @@ func TestRunGCPOrgSync_SecurityOnlySkipsProjectDiscovery(t *testing.T) { if called { t.Fatalf("expected organization project discovery to be skipped") } - if err == nil || !strings.Contains(err.Error(), "snowflake not configured") { - t.Fatalf("expected snowflake configuration error, got %v", err) + if err == nil { + t.Fatalf("expected security-only org sync to require additional runtime configuration") } } diff --git a/internal/cli/sync_k8s.go b/internal/cli/sync_k8s.go index 6fc79d18d..ccd4f6c44 100644 --- a/internal/cli/sync_k8s.go +++ b/internal/cli/sync_k8s.go @@ -72,11 +72,11 @@ func runK8sSync(ctx context.Context, start time.Time) error { var runK8sSyncDirectFn = runK8sSyncDirect func runK8sSyncDirect(ctx 
context.Context, start time.Time, tableFilter []string) error { - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) opts := []nativesync.K8sEngineOption{} if syncK8sKubeconfig != "" { diff --git a/internal/cli/sync_schedule.go b/internal/cli/sync_schedule.go index e092ade29..b333cb326 100644 --- a/internal/cli/sync_schedule.go +++ b/internal/cli/sync_schedule.go @@ -22,8 +22,8 @@ import ( apiclient "github.com/writer/cerebro/internal/client" "github.com/writer/cerebro/internal/jobs" providerregistry "github.com/writer/cerebro/internal/providers" - "github.com/writer/cerebro/internal/snowflake" nativesync "github.com/writer/cerebro/internal/sync" + "github.com/writer/cerebro/internal/warehouse" ) var syncScheduleCmd = &cobra.Command{ @@ -170,11 +170,11 @@ type SyncSchedule struct { func runScheduleList(cmd *cobra.Command, args []string) error { ctx := context.Background() - client, err := createSnowflakeClientForSchedule() + client, err := createScheduleStore() if err != nil { - return fmt.Errorf("failed to connect to Snowflake: %w", err) + return fmt.Errorf("failed to connect to warehouse: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) schedules, err := listSchedules(ctx, client) if err != nil { @@ -247,11 +247,11 @@ func runScheduleCreate(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid provider %q; valid providers: %s", scheduleProvider, strings.Join(validProviders, ", ")) } - client, err := createSnowflakeClientForSchedule() + client, err := createScheduleStore() if err != nil { - return fmt.Errorf("failed to connect to Snowflake: %w", err) + return fmt.Errorf("failed to connect to warehouse: %w", err) } - defer func() { _ = client.Close() }() + defer 
closeSyncWarehouse(client) // Check if schedule already exists existing, _ := getSchedule(ctx, client, scheduleName) @@ -288,11 +288,11 @@ func runScheduleDelete(cmd *cobra.Command, args []string) error { ctx := context.Background() name := args[0] - client, err := createSnowflakeClientForSchedule() + client, err := createScheduleStore() if err != nil { - return fmt.Errorf("failed to connect to Snowflake: %w", err) + return fmt.Errorf("failed to connect to warehouse: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) // Check if schedule exists existing, err := getSchedule(ctx, client, name) @@ -315,11 +315,11 @@ func runScheduleShow(cmd *cobra.Command, args []string) error { ctx := context.Background() name := args[0] - client, err := createSnowflakeClientForSchedule() + client, err := createScheduleStore() if err != nil { - return fmt.Errorf("failed to connect to Snowflake: %w", err) + return fmt.Errorf("failed to connect to warehouse: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) schedule, err := getSchedule(ctx, client, name) if err != nil { @@ -359,11 +359,11 @@ func runScheduleDaemon(cmd *cobra.Command, args []string) error { ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancel() - client, err := createSnowflakeClientForSchedule() + client, err := createScheduleStore() if err != nil { - return fmt.Errorf("failed to connect to Snowflake: %w", err) + return fmt.Errorf("failed to connect to warehouse: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) Info("Starting sync schedule daemon...") @@ -456,7 +456,7 @@ shutdown: return nil } -func runScheduledSync(client *snowflake.Client, schedule *SyncSchedule) { +func runScheduledSync(client warehouse.DataWarehouse, schedule *SyncSchedule) { start := scheduleNowFn() persistCtx := context.Background() scheduleKey := 
strings.ToLower(strings.TrimSpace(schedule.Name)) @@ -560,7 +560,7 @@ func runScheduledSync(client *snowflake.Client, schedule *SyncSchedule) { slog.Default().Info("scheduled_sync_audit", attrs...) } -func executeScheduledSync(ctx context.Context, client *snowflake.Client, schedule *SyncSchedule) error { +func executeScheduledSync(ctx context.Context, client warehouse.SyncWarehouse, schedule *SyncSchedule) error { provider := strings.ToLower(strings.TrimSpace(schedule.Provider)) if isNativeScheduleProvider(provider) && nativeSyncWorkerConfigured() { return enqueueScheduledNativeSyncFn(ctx, schedule) @@ -713,7 +713,7 @@ func parseBoundedPositiveIntDirective(raw, name string, min, max int) (int, erro return value, nil } -func executeProviderSync(ctx context.Context, _ *snowflake.Client, schedule *SyncSchedule) error { +func executeProviderSync(ctx context.Context, _ warehouse.SyncWarehouse, schedule *SyncSchedule) error { providerName := strings.ToLower(strings.TrimSpace(schedule.Provider)) Info("[%s] Executing provider sync for %s...", schedule.Name, providerName) @@ -1139,8 +1139,8 @@ func schedulesEqual(a, b []SyncSchedule) bool { return true } -func ensureScheduleTable(ctx context.Context, client *snowflake.Client) error { - query := `CREATE TABLE IF NOT EXISTS sync_schedules ( +func ensureScheduleTable(ctx context.Context, client warehouse.SyncWarehouse) error { + query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS sync_schedules ( name VARCHAR PRIMARY KEY, cron VARCHAR NOT NULL, provider VARCHAR NOT NULL, @@ -1148,78 +1148,48 @@ func ensureScheduleTable(ctx context.Context, client *snowflake.Client) error { enabled BOOLEAN DEFAULT TRUE, scan_after BOOLEAN DEFAULT FALSE, retry INTEGER DEFAULT 3, - created_at TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(), - updated_at TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(), - last_run TIMESTAMP_NTZ, + created_at %s DEFAULT CURRENT_TIMESTAMP(), + updated_at %s DEFAULT CURRENT_TIMESTAMP(), + last_run %s, last_status VARCHAR, - next_run 
TIMESTAMP_NTZ - )` + next_run %s + )`, warehouse.LocalTimestampColumnType(client), warehouse.LocalTimestampColumnType(client), warehouse.LocalTimestampColumnType(client), warehouse.LocalTimestampColumnType(client)) _, err := client.Exec(ctx, query) return err } -func listSchedules(ctx context.Context, client *snowflake.Client) ([]SyncSchedule, error) { +func listSchedules(ctx context.Context, client warehouse.SyncWarehouse) ([]SyncSchedule, error) { if err := ensureScheduleTable(ctx, client); err != nil { return nil, err } - query := `SELECT name, cron, provider, COALESCE(table_filter, ''), enabled, - scan_after, retry, created_at, updated_at, - COALESCE(last_run, '1970-01-01'::TIMESTAMP_NTZ), - COALESCE(last_status, ''), - COALESCE(next_run, '1970-01-01'::TIMESTAMP_NTZ) - FROM sync_schedules ORDER BY name` + query := `SELECT name, cron, provider, COALESCE(table_filter, '') AS table_filter, enabled, + scan_after, retry, created_at, updated_at, last_run, COALESCE(last_status, '') AS last_status, next_run + FROM sync_schedules ORDER BY name` result, err := client.Query(ctx, query) if err != nil { return nil, err } - var schedules []SyncSchedule + schedules := make([]SyncSchedule, 0, len(result.Rows)) for _, row := range result.Rows { - s := SyncSchedule{ - Name: getString(row, "NAME"), - Cron: getString(row, "CRON"), - Provider: getString(row, "PROVIDER"), - Table: getString(row, "COALESCE(TABLE_FILTER, '')"), - Enabled: getBool(row, "ENABLED"), - ScanAfter: getBool(row, "SCAN_AFTER"), - Retry: getInt(row, "RETRY"), - CreatedAt: getTime(row, "CREATED_AT"), - UpdatedAt: getTime(row, "UPDATED_AT"), - LastRun: getTime(row, "COALESCE(LAST_RUN, '1970-01-01'::TIMESTAMP_NTZ)"), - LastStatus: getString(row, "COALESCE(LAST_STATUS, '')"), - NextRun: getTime(row, "COALESCE(NEXT_RUN, '1970-01-01'::TIMESTAMP_NTZ)"), - } - // Reset zero times - if s.LastRun.Year() == 1970 { - s.LastRun = time.Time{} - } - if s.NextRun.Year() == 1970 { - s.NextRun = time.Time{} - } - schedules = 
append(schedules, s) - } - - // Sort by name + schedules = append(schedules, decodeScheduleRow(row)) + } sort.Slice(schedules, func(i, j int) bool { return schedules[i].Name < schedules[j].Name }) - return schedules, nil } -func getSchedule(ctx context.Context, client *snowflake.Client, name string) (*SyncSchedule, error) { +func getSchedule(ctx context.Context, client warehouse.SyncWarehouse, name string) (*SyncSchedule, error) { if err := ensureScheduleTable(ctx, client); err != nil { return nil, err } - query := `SELECT name, cron, provider, COALESCE(table_filter, ''), enabled, - scan_after, retry, created_at, updated_at, - COALESCE(last_run, '1970-01-01'::TIMESTAMP_NTZ), - COALESCE(last_status, ''), - COALESCE(next_run, '1970-01-01'::TIMESTAMP_NTZ) - FROM sync_schedules WHERE name = ?` + query := `SELECT name, cron, provider, COALESCE(table_filter, '') AS table_filter, enabled, + scan_after, retry, created_at, updated_at, last_run, COALESCE(last_status, '') AS last_status, next_run + FROM sync_schedules WHERE name = ` + warehouse.Placeholder(client, 1) result, err := client.Query(ctx, query, name) if err != nil { @@ -1230,138 +1200,207 @@ func getSchedule(ctx context.Context, client *snowflake.Client, name string) (*S return nil, nil } - row := result.Rows[0] - s := &SyncSchedule{ - Name: getString(row, "NAME"), - Cron: getString(row, "CRON"), - Provider: getString(row, "PROVIDER"), - Table: getString(row, "COALESCE(TABLE_FILTER, '')"), - Enabled: getBool(row, "ENABLED"), - ScanAfter: getBool(row, "SCAN_AFTER"), - Retry: getInt(row, "RETRY"), - CreatedAt: getTime(row, "CREATED_AT"), - UpdatedAt: getTime(row, "UPDATED_AT"), - LastRun: getTime(row, "COALESCE(LAST_RUN, '1970-01-01'::TIMESTAMP_NTZ)"), - LastStatus: getString(row, "COALESCE(LAST_STATUS, '')"), - NextRun: getTime(row, "COALESCE(NEXT_RUN, '1970-01-01'::TIMESTAMP_NTZ)"), - } - if s.LastRun.Year() == 1970 { - s.LastRun = time.Time{} - } - if s.NextRun.Year() == 1970 { - s.NextRun = time.Time{} - } - - 
return s, nil + schedule := decodeScheduleRow(result.Rows[0]) + return &schedule, nil } -func saveSchedule(ctx context.Context, client *snowflake.Client, schedule *SyncSchedule) error { +func saveSchedule(ctx context.Context, client warehouse.SyncWarehouse, schedule *SyncSchedule) error { if err := ensureScheduleTable(ctx, client); err != nil { return err } - query := `MERGE INTO sync_schedules t - USING (SELECT ? as name) s - ON t.name = s.name - WHEN MATCHED THEN UPDATE SET - cron = ?, - provider = ?, - table_filter = ?, - enabled = ?, - scan_after = ?, - retry = ?, - updated_at = ?, - last_run = ?, - last_status = ?, - next_run = ? - WHEN NOT MATCHED THEN INSERT - (name, cron, provider, table_filter, enabled, scan_after, retry, created_at, updated_at, last_run, last_status, next_run) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` - - var lastRun, nextRun interface{} - if !schedule.LastRun.IsZero() { - lastRun = schedule.LastRun - } - if !schedule.NextRun.IsZero() { - nextRun = schedule.NextRun + lastRun := nullableTime(schedule.LastRun) + nextRun := nullableTime(schedule.NextRun) + if warehouse.Dialect(client) == warehouse.SQLDialectSnowflake { + query := `MERGE INTO sync_schedules t + USING (SELECT ? as name) s + ON t.name = s.name + WHEN MATCHED THEN UPDATE SET + cron = ?, + provider = ?, + table_filter = ?, + enabled = ?, + scan_after = ?, + retry = ?, + updated_at = ?, + last_run = ?, + last_status = ?, + next_run = ? 
+ WHEN NOT MATCHED THEN INSERT + (name, cron, provider, table_filter, enabled, scan_after, retry, created_at, updated_at, last_run, last_status, next_run) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` + _, err := client.Exec(ctx, query, + schedule.Name, + schedule.Cron, schedule.Provider, schedule.Table, schedule.Enabled, schedule.ScanAfter, schedule.Retry, + schedule.UpdatedAt, lastRun, schedule.LastStatus, nextRun, + schedule.Name, schedule.Cron, schedule.Provider, schedule.Table, schedule.Enabled, schedule.ScanAfter, schedule.Retry, + schedule.CreatedAt, schedule.UpdatedAt, lastRun, schedule.LastStatus, nextRun, + ) + return err } + query := fmt.Sprintf( + "INSERT INTO sync_schedules (name, cron, provider, table_filter, enabled, scan_after, retry, created_at, updated_at, last_run, last_status, next_run) VALUES (%s) "+ + "ON CONFLICT (name) DO UPDATE SET cron = EXCLUDED.cron, provider = EXCLUDED.provider, table_filter = EXCLUDED.table_filter, enabled = EXCLUDED.enabled, scan_after = EXCLUDED.scan_after, retry = EXCLUDED.retry, updated_at = EXCLUDED.updated_at, last_run = EXCLUDED.last_run, last_status = EXCLUDED.last_status, next_run = EXCLUDED.next_run", + strings.Join(warehouse.Placeholders(client, 1, 12), ", "), + ) + _, err := client.Exec(ctx, query, schedule.Name, - schedule.Cron, schedule.Provider, schedule.Table, schedule.Enabled, schedule.ScanAfter, schedule.Retry, - schedule.UpdatedAt, lastRun, schedule.LastStatus, nextRun, - schedule.Name, schedule.Cron, schedule.Provider, schedule.Table, schedule.Enabled, schedule.ScanAfter, schedule.Retry, - schedule.CreatedAt, schedule.UpdatedAt, lastRun, schedule.LastStatus, nextRun, + schedule.Cron, + schedule.Provider, + schedule.Table, + schedule.Enabled, + schedule.ScanAfter, + schedule.Retry, + schedule.CreatedAt, + schedule.UpdatedAt, + lastRun, + schedule.LastStatus, + nextRun, ) return err } -func deleteSchedule(ctx context.Context, client *snowflake.Client, name string) error { - query := `DELETE FROM 
sync_schedules WHERE name = ?` +func deleteSchedule(ctx context.Context, client warehouse.SyncWarehouse, name string) error { + query := `DELETE FROM sync_schedules WHERE name = ` + warehouse.Placeholder(client, 1) _, err := client.Exec(ctx, query, name) return err } -func createSnowflakeClientForSchedule() (*snowflake.Client, error) { - cfg := snowflake.DSNConfigFromEnv() - if missing := cfg.MissingFields(); len(missing) > 0 { - return nil, fmt.Errorf("snowflake not configured: set %s", strings.Join(missing, ", ")) +func createScheduleStore() (warehouse.DataWarehouse, error) { + return createSyncWarehouse() +} + +func decodeScheduleRow(row map[string]interface{}) SyncSchedule { + return SyncSchedule{ + Name: getString(row, "name"), + Cron: getString(row, "cron"), + Provider: getString(row, "provider"), + Table: getString(row, "table_filter"), + Enabled: getBool(row, "enabled"), + ScanAfter: getBool(row, "scan_after"), + Retry: getInt(row, "retry"), + CreatedAt: getTime(row, "created_at"), + UpdatedAt: getTime(row, "updated_at"), + LastRun: getTime(row, "last_run"), + LastStatus: getString(row, "last_status"), + NextRun: getTime(row, "next_run"), } +} - return snowflake.NewClient(snowflake.ClientConfig{ - Account: cfg.Account, - User: cfg.User, - PrivateKey: cfg.PrivateKey, - Database: cfg.Database, - Schema: cfg.Schema, - Warehouse: cfg.Warehouse, - Role: cfg.Role, - }) +func nullableTime(value time.Time) interface{} { + if value.IsZero() { + return nil + } + return value.UTC() } -// Helper functions for extracting values from query results +func scheduleRowValue(row map[string]interface{}, key string) interface{} { + for candidate, value := range row { + if strings.EqualFold(candidate, key) { + return value + } + } + return nil +} func getString(row map[string]interface{}, key string) string { - if v, ok := row[key]; ok { - if s, ok := v.(string); ok { - return s + switch value := scheduleRowValue(row, key).(type) { + case string: + return value + case []byte: + 
return string(value) + default: + if value == nil { + return "" } + return fmt.Sprintf("%v", value) } - return "" } func getBool(row map[string]interface{}, key string) bool { - if v, ok := row[key]; ok { - if b, ok := v.(bool); ok { - return b - } + switch value := scheduleRowValue(row, key).(type) { + case bool: + return value + case int64: + return value != 0 + case int: + return value != 0 + case string: + return strings.EqualFold(strings.TrimSpace(value), "true") || strings.TrimSpace(value) == "1" + default: + return false } - return false } func getInt(row map[string]interface{}, key string) int { - if v, ok := row[key]; ok { - switch n := v.(type) { - case int: - return n - case int64: - return int(n) - case float64: - return int(n) - } + switch value := scheduleRowValue(row, key).(type) { + case int: + return value + case int64: + return int(value) + case int32: + return int(value) + case float64: + return int(value) + case []byte: + parsed, _ := strconv.Atoi(string(value)) + return parsed + case string: + parsed, _ := strconv.Atoi(value) + return parsed + default: + return 0 } - return 0 } func getTime(row map[string]interface{}, key string) time.Time { - if v, ok := row[key]; ok { - if t, ok := v.(time.Time); ok { - return t + switch value := scheduleRowValue(row, key).(type) { + case time.Time: + return value.UTC() + case *time.Time: + if value != nil { + return value.UTC() } - if s, ok := v.(string); ok { - if t, err := time.Parse(time.RFC3339, s); err == nil { - return t - } + case string: + if value == "" { + return time.Time{} + } + if ts, err := time.Parse(time.RFC3339Nano, value); err == nil { + return ts.UTC() + } + if ts, err := time.Parse(time.RFC3339, value); err == nil { + return ts.UTC() + } + if ts, err := time.Parse("2006-01-02 15:04:05.999999999-07:00", value); err == nil { + return ts.UTC() + } + if ts, err := time.Parse("2006-01-02 15:04:05.999999999", value); err == nil { + return ts.UTC() + } + if ts, err := time.Parse("2006-01-02 
15:04:05", value); err == nil { + return ts.UTC() + } + case []byte: + text := string(value) + if text == "" { + return time.Time{} + } + if ts, err := time.Parse(time.RFC3339Nano, text); err == nil { + return ts.UTC() + } + if ts, err := time.Parse(time.RFC3339, text); err == nil { + return ts.UTC() + } + if ts, err := time.Parse("2006-01-02 15:04:05.999999999-07:00", text); err == nil { + return ts.UTC() + } + if ts, err := time.Parse("2006-01-02 15:04:05.999999999", text); err == nil { + return ts.UTC() + } + if ts, err := time.Parse("2006-01-02 15:04:05", text); err == nil { + return ts.UTC() } } return time.Time{} diff --git a/internal/cli/sync_schedule_aws.go b/internal/cli/sync_schedule_aws.go index 39e01a0b3..2d7903eae 100644 --- a/internal/cli/sync_schedule_aws.go +++ b/internal/cli/sync_schedule_aws.go @@ -16,12 +16,12 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/writer/cerebro/internal/metrics" - "github.com/writer/cerebro/internal/snowflake" nativesync "github.com/writer/cerebro/internal/sync" + "github.com/writer/cerebro/internal/warehouse" "golang.org/x/sync/errgroup" ) -func executeAWSSync(ctx context.Context, client *snowflake.Client, schedule *SyncSchedule) error { +func executeAWSSync(ctx context.Context, client warehouse.SyncWarehouse, schedule *SyncSchedule) error { Info("[%s] Executing AWS sync...", schedule.Name) spec := parseScheduledSyncSpec(schedule.Table) @@ -72,7 +72,7 @@ func executeAWSSync(ctx context.Context, client *snowflake.Client, schedule *Syn return runScheduledAWSNativeSyncFn(ctx, client, awsCfg, spec.TableFilter) } -func runScheduledAWSNativeSync(ctx context.Context, client *snowflake.Client, awsCfg aws.Config, tableFilter []string) error { +func runScheduledAWSNativeSync(ctx context.Context, client warehouse.SyncWarehouse, awsCfg aws.Config, tableFilter []string) error { var opts []nativesync.EngineOption if len(tableFilter) > 0 { opts = 
append(opts, nativesync.WithTableFilter(tableFilter)) @@ -83,7 +83,7 @@ func runScheduledAWSNativeSync(ctx context.Context, client *snowflake.Client, aw return err } -func runScheduledAWSOrgSync(ctx context.Context, client *snowflake.Client, awsCfg aws.Config, spec scheduledSyncSpec) error { +func runScheduledAWSOrgSync(ctx context.Context, client warehouse.SyncWarehouse, awsCfg aws.Config, spec scheduledSyncSpec) error { orgCfg := awsCfg.Copy() if strings.TrimSpace(orgCfg.Region) == "" { orgCfg.Region = "us-east-1" diff --git a/internal/cli/sync_schedule_azure.go b/internal/cli/sync_schedule_azure.go index 302cc9d3c..907c5adf5 100644 --- a/internal/cli/sync_schedule_azure.go +++ b/internal/cli/sync_schedule_azure.go @@ -7,11 +7,11 @@ import ( "strconv" "strings" - "github.com/writer/cerebro/internal/snowflake" nativesync "github.com/writer/cerebro/internal/sync" + "github.com/writer/cerebro/internal/warehouse" ) -func executeAzureSync(ctx context.Context, client *snowflake.Client, schedule *SyncSchedule) error { +func executeAzureSync(ctx context.Context, client warehouse.SyncWarehouse, schedule *SyncSchedule) error { spec := parseScheduledSyncSpec(schedule.Table) subscriptions := uniqueNonEmpty(append(append([]string{}, spec.AzureSubscriptions...), spec.AzureSubscription)) if len(subscriptions) == 0 { diff --git a/internal/cli/sync_schedule_gcp.go b/internal/cli/sync_schedule_gcp.go index e309eb38e..aad6df432 100644 --- a/internal/cli/sync_schedule_gcp.go +++ b/internal/cli/sync_schedule_gcp.go @@ -16,8 +16,8 @@ import ( securitycenter "cloud.google.com/go/securitycenter/apiv1" "cloud.google.com/go/securitycenter/apiv1/securitycenterpb" "github.com/writer/cerebro/internal/metrics" - "github.com/writer/cerebro/internal/snowflake" nativesync "github.com/writer/cerebro/internal/sync" + "github.com/writer/cerebro/internal/warehouse" "golang.org/x/oauth2/google" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -40,7 +40,7 @@ type 
gcpProjectPreflightSpec struct { ClientOptions []option.ClientOption } -func executeGCPSync(ctx context.Context, client *snowflake.Client, schedule *SyncSchedule) error { +func executeGCPSync(ctx context.Context, client warehouse.SyncWarehouse, schedule *SyncSchedule) error { spec := parseScheduledSyncSpec(schedule.Table) authConfig, err := applyScheduledGCPAuthFn(spec) if err != nil { @@ -576,7 +576,7 @@ func detectGCPCredentialsType(raw []byte, source string) (option.CredentialsType } } -func runScheduledGCPNativeSync(ctx context.Context, client *snowflake.Client, projectID string, tableFilter []string) error { +func runScheduledGCPNativeSync(ctx context.Context, client warehouse.SyncWarehouse, projectID string, tableFilter []string) error { opts := []nativesync.GCPEngineOption{nativesync.WithGCPProject(projectID)} if len(tableFilter) > 0 { opts = append(opts, nativesync.WithGCPTableFilter(tableFilter)) @@ -586,7 +586,7 @@ func runScheduledGCPNativeSync(ctx context.Context, client *snowflake.Client, pr return err } -func runScheduledGCPSecuritySync(ctx context.Context, client *snowflake.Client, projectID, orgID string, tableFilter []string) error { +func runScheduledGCPSecuritySync(ctx context.Context, client warehouse.SyncWarehouse, projectID, orgID string, tableFilter []string) error { secOpts := []nativesync.GCPSecurityOption{} if len(tableFilter) > 0 { secOpts = append(secOpts, nativesync.WithGCPSecurityTableFilter(tableFilter)) diff --git a/internal/cli/sync_schedule_test.go b/internal/cli/sync_schedule_test.go index bd061f594..eb54d5ebd 100644 --- a/internal/cli/sync_schedule_test.go +++ b/internal/cli/sync_schedule_test.go @@ -16,7 +16,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/writer/cerebro/internal/app" providerregistry "github.com/writer/cerebro/internal/providers" - "github.com/writer/cerebro/internal/snowflake" + "github.com/writer/cerebro/internal/warehouse" "google.golang.org/api/option" ) @@ -222,19 +222,19 @@ func 
TestExecuteScheduledSync_RoutesByProvider(t *testing.T) { }) called := "" - executeAWSSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeAWSSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { called = "aws" return nil } - executeGCPSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeGCPSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { called = "gcp" return nil } - executeAzureSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeAzureSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { called = "azure" return nil } - executeProviderSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeProviderSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { called = "provider" return nil } @@ -283,7 +283,7 @@ func TestExecuteScheduledSync_UsesWorkerForNativeProviders(t *testing.T) { }) directCalled := false - executeAWSSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeAWSSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { directCalled = true return nil } @@ -308,7 +308,7 @@ func TestExecuteScheduledSync_UsesWorkerForNativeProviders(t *testing.T) { } providerCalled := 0 - executeProviderSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeProviderSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { providerCalled++ return nil } @@ -531,7 +531,7 @@ func TestExecuteAWSSync_UsesScheduledAuthDirectives(t *testing.T) { preflightCalled = true return nil } - runScheduledAWSNativeSyncFn = func(_ context.Context, _ *snowflake.Client, _ aws.Config, tableFilter []string) error { + runScheduledAWSNativeSyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, _ aws.Config, tableFilter []string) error { runCalled = true if len(tableFilter) != 1 || 
tableFilter[0] != "aws_iam_roles" { return fmt.Errorf("unexpected aws table filter: %v", tableFilter) @@ -581,13 +581,13 @@ func TestExecuteAWSSync_UsesAWSOrgDirectives(t *testing.T) { preflightScheduledAWSAuthFn = func(context.Context, *SyncSchedule, scheduledSyncSpec, aws.Config) error { return nil } - runScheduledAWSNativeSyncFn = func(_ context.Context, _ *snowflake.Client, _ aws.Config, _ []string) error { + runScheduledAWSNativeSyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, _ aws.Config, _ []string) error { t.Fatal("did not expect single-account scheduled sync to run") return nil } orgRunCalled := false - runScheduledAWSOrgSyncFn = func(_ context.Context, _ *snowflake.Client, _ aws.Config, spec scheduledSyncSpec) error { + runScheduledAWSOrgSyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, _ aws.Config, spec scheduledSyncSpec) error { orgRunCalled = true if len(spec.AWSOrgIncludeAccounts) != 1 || spec.AWSOrgIncludeAccounts[0] != "111111111111" { t.Fatalf("unexpected include accounts: %v", spec.AWSOrgIncludeAccounts) @@ -686,14 +686,14 @@ func TestRunScheduledSync_RetryAndStatus(t *testing.T) { t.Run("succeeds after retry", func(t *testing.T) { attempts := 0 saves := 0 - executeScheduledSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeScheduledSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { attempts++ if attempts < 2 { return errors.New("temporary failure") } return nil } - saveScheduleFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + saveScheduleFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { saves++ return nil } @@ -719,11 +719,11 @@ func TestRunScheduledSync_RetryAndStatus(t *testing.T) { t.Run("fails after all retries", func(t *testing.T) { attempts := 0 - executeScheduledSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeScheduledSyncFn = func(context.Context, warehouse.SyncWarehouse, 
*SyncSchedule) error { attempts++ return errors.New("hard failure") } - saveScheduleFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { return nil } + saveScheduleFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { return nil } scheduleSleepFn = func(time.Duration) {} now := time.Date(2026, 2, 24, 12, 0, 0, 0, time.UTC) scheduleNowFn = func() time.Time { @@ -751,11 +751,11 @@ func TestRunScheduledSync_RejectsInvalidTimeoutDirective(t *testing.T) { }) executeCalls := 0 - executeScheduledSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeScheduledSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { executeCalls++ return nil } - saveScheduleFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { return nil } + saveScheduleFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { return nil } schedule := &SyncSchedule{Name: "invalid-timeout", Provider: "aws", Retry: 1, Table: "sync_timeout_seconds=5"} runScheduledSync(nil, schedule) @@ -783,7 +783,7 @@ func TestRunScheduledSync_SkipsOverlappingRuns(t *testing.T) { release := make(chan struct{}) finished := make(chan struct{}) - executeScheduledSyncFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { + executeScheduledSyncFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { select { case <-started: default: @@ -792,7 +792,7 @@ func TestRunScheduledSync_SkipsOverlappingRuns(t *testing.T) { <-release return nil } - saveScheduleFn = func(context.Context, *snowflake.Client, *SyncSchedule) error { return nil } + saveScheduleFn = func(context.Context, warehouse.SyncWarehouse, *SyncSchedule) error { return nil } scheduleSleepFn = func(time.Duration) {} first := &SyncSchedule{Name: "overlap-test", Provider: "aws", Retry: 1} @@ -837,8 +837,8 @@ func TestExecuteGCPSync_InvalidProjectTimeoutDirective(t *testing.T) { preflightGCPProjectAccessFn = func(context.Context, 
gcpProjectPreflightSpec) error { return nil } - runScheduledGCPNativeSyncFn = func(context.Context, *snowflake.Client, string, []string) error { return nil } - runScheduledGCPSecuritySyncFn = func(context.Context, *snowflake.Client, string, string, []string) error { return nil } + runScheduledGCPNativeSyncFn = func(context.Context, warehouse.SyncWarehouse, string, []string) error { return nil } + runScheduledGCPSecuritySyncFn = func(context.Context, warehouse.SyncWarehouse, string, string, []string) error { return nil } err := executeGCPSync(context.Background(), nil, &SyncSchedule{ Name: "invalid-project-timeout", @@ -875,11 +875,11 @@ func TestExecuteGCPSync_SkipsSecurityWhenNativeProjectTimesOut(t *testing.T) { preflightGCPProjectAccessFn = func(context.Context, gcpProjectPreflightSpec) error { return nil } - runScheduledGCPNativeSyncFn = func(context.Context, *snowflake.Client, string, []string) error { + runScheduledGCPNativeSyncFn = func(context.Context, warehouse.SyncWarehouse, string, []string) error { return context.DeadlineExceeded } securityCalls := 0 - runScheduledGCPSecuritySyncFn = func(context.Context, *snowflake.Client, string, string, []string) error { + runScheduledGCPSecuritySyncFn = func(context.Context, warehouse.SyncWarehouse, string, string, []string) error { securityCalls++ return nil } @@ -925,11 +925,11 @@ func TestExecuteGCPSync_PreflightFailureSkipsNativeAndSecurity(t *testing.T) { nativeCalls := 0 securityCalls := 0 - runScheduledGCPNativeSyncFn = func(context.Context, *snowflake.Client, string, []string) error { + runScheduledGCPNativeSyncFn = func(context.Context, warehouse.SyncWarehouse, string, []string) error { nativeCalls++ return nil } - runScheduledGCPSecuritySyncFn = func(context.Context, *snowflake.Client, string, string, []string) error { + runScheduledGCPSecuritySyncFn = func(context.Context, warehouse.SyncWarehouse, string, string, []string) error { securityCalls++ return nil } @@ -1122,10 +1122,10 @@ func 
TestExecuteGCPSync_AppliesScheduledAuthDirectives(t *testing.T) { preflightGCPProjectAccessFn = func(context.Context, gcpProjectPreflightSpec) error { return nil } - runScheduledGCPNativeSyncFn = func(context.Context, *snowflake.Client, string, []string) error { + runScheduledGCPNativeSyncFn = func(context.Context, warehouse.SyncWarehouse, string, []string) error { return nil } - runScheduledGCPSecuritySyncFn = func(context.Context, *snowflake.Client, string, string, []string) error { + runScheduledGCPSecuritySyncFn = func(context.Context, warehouse.SyncWarehouse, string, string, []string) error { return fmt.Errorf("security sync should not run") } @@ -1450,11 +1450,11 @@ func TestExecuteGCPSync_FilterRouting(t *testing.T) { securityCalls := 0 var securityFilters []string - runScheduledGCPNativeSyncFn = func(context.Context, *snowflake.Client, string, []string) error { + runScheduledGCPNativeSyncFn = func(context.Context, warehouse.SyncWarehouse, string, []string) error { nativeCalls++ return nil } - runScheduledGCPSecuritySyncFn = func(_ context.Context, _ *snowflake.Client, projectID, orgID string, tableFilter []string) error { + runScheduledGCPSecuritySyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, projectID, orgID string, tableFilter []string) error { securityCalls++ if projectID != "proj-1" { return fmt.Errorf("unexpected project id %q", projectID) @@ -1497,11 +1497,11 @@ func TestExecuteGCPSync_FilterRouting(t *testing.T) { securityCalls := 0 listCalls := 0 - runScheduledGCPNativeSyncFn = func(context.Context, *snowflake.Client, string, []string) error { + runScheduledGCPNativeSyncFn = func(context.Context, warehouse.SyncWarehouse, string, []string) error { nativeCalls++ return nil } - runScheduledGCPSecuritySyncFn = func(_ context.Context, _ *snowflake.Client, projectID, orgID string, tableFilter []string) error { + runScheduledGCPSecuritySyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, projectID, orgID string, tableFilter []string) 
error { securityCalls++ if projectID != "" { return fmt.Errorf("expected empty project id, got %q", projectID) @@ -1551,12 +1551,12 @@ func TestExecuteGCPSync_FilterRouting(t *testing.T) { var nativeFilters []string var securityFilters []string - runScheduledGCPNativeSyncFn = func(_ context.Context, _ *snowflake.Client, _ string, tableFilter []string) error { + runScheduledGCPNativeSyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, _ string, tableFilter []string) error { nativeCalls++ nativeFilters = append([]string(nil), tableFilter...) return nil } - runScheduledGCPSecuritySyncFn = func(_ context.Context, _ *snowflake.Client, _ string, _ string, tableFilter []string) error { + runScheduledGCPSecuritySyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, _ string, _ string, tableFilter []string) error { securityCalls++ securityFilters = append([]string(nil), tableFilter...) return nil @@ -1618,11 +1618,11 @@ func TestExecuteGCPSync_AppliesWIFAuth(t *testing.T) { t.Setenv("AWS_SECRET_ACCESS_KEY", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY") var observedGAC string - runScheduledGCPNativeSyncFn = func(_ context.Context, _ *snowflake.Client, _ string, _ []string) error { + runScheduledGCPNativeSyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, _ string, _ []string) error { observedGAC = os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") return nil } - runScheduledGCPSecuritySyncFn = func(_ context.Context, _ *snowflake.Client, _, _ string, _ []string) error { + runScheduledGCPSecuritySyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, _, _ string, _ []string) error { return nil } @@ -1677,7 +1677,7 @@ func TestExecuteGCPSync_WIFCredsContent(t *testing.T) { t.Setenv("AWS_SECRET_ACCESS_KEY", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY") var capturedPayload map[string]interface{} - runScheduledGCPNativeSyncFn = func(_ context.Context, _ *snowflake.Client, _ string, _ []string) error { + runScheduledGCPNativeSyncFn = func(_ context.Context, _ 
warehouse.SyncWarehouse, _ string, _ []string) error { gac := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") data, err := os.ReadFile(gac) if err != nil { @@ -1685,7 +1685,7 @@ func TestExecuteGCPSync_WIFCredsContent(t *testing.T) { } return json.Unmarshal(data, &capturedPayload) } - runScheduledGCPSecuritySyncFn = func(_ context.Context, _ *snowflake.Client, _, _ string, _ []string) error { + runScheduledGCPSecuritySyncFn = func(_ context.Context, _ warehouse.SyncWarehouse, _, _ string, _ []string) error { return nil } @@ -1756,11 +1756,11 @@ func TestExecuteGCPSync_OrgDiscoveryUsesScheduledAuthContext(t *testing.T) { } nativeCalls := 0 - runScheduledGCPNativeSyncFn = func(context.Context, *snowflake.Client, string, []string) error { + runScheduledGCPNativeSyncFn = func(context.Context, warehouse.SyncWarehouse, string, []string) error { nativeCalls++ return nil } - runScheduledGCPSecuritySyncFn = func(context.Context, *snowflake.Client, string, string, []string) error { + runScheduledGCPSecuritySyncFn = func(context.Context, warehouse.SyncWarehouse, string, string, []string) error { return nil } diff --git a/internal/cli/worker.go b/internal/cli/worker.go index 480e30114..d472357d4 100644 --- a/internal/cli/worker.go +++ b/internal/cli/worker.go @@ -237,11 +237,11 @@ func newNativeSyncJobHandler(application *app.App) jobs.JobHandler { } func runNativeSyncForJob(ctx context.Context, provider string, schedule *SyncSchedule) error { - client, err := createSnowflakeClient() + client, err := createSyncWarehouse() if err != nil { - return fmt.Errorf("create snowflake client: %w", err) + return fmt.Errorf("create warehouse client: %w", err) } - defer func() { _ = client.Close() }() + defer closeSyncWarehouse(client) switch provider { case "aws": diff --git a/internal/findings/postgres_store.go b/internal/findings/postgres_store.go new file mode 100644 index 000000000..a553bd2d1 --- /dev/null +++ b/internal/findings/postgres_store.go @@ -0,0 +1,611 @@ +package findings + 
+import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/writer/cerebro/internal/policy" + "github.com/writer/cerebro/internal/snowflake" +) + +const postgresFindingsTable = "cerebro_findings" + +type PostgresStore struct { + db *sql.DB + cache map[string]*Finding + semanticIndex map[string]string + dirty map[string]bool + attestor FindingAttestor + attestReobserved bool + semanticDedup bool + rewriteSQL func(string) string + mu sync.RWMutex + syncedAt time.Time +} + +func NewPostgresStore(db *sql.DB) *PostgresStore { + return &PostgresStore{ + db: db, + cache: make(map[string]*Finding), + semanticIndex: make(map[string]string), + dirty: make(map[string]bool), + semanticDedup: DefaultSemanticDedupEnabled, + } +} + +func (s *PostgresStore) EnsureSchema(ctx context.Context) error { + if s == nil || s.db == nil { + return fmt.Errorf("postgres findings store is not initialized") + } + _, err := s.db.ExecContext(ctx, s.q(` +CREATE TABLE IF NOT EXISTS `+postgresFindingsTable+` ( + id TEXT PRIMARY KEY, + policy_id TEXT NOT NULL, + policy_name TEXT NOT NULL, + severity TEXT NOT NULL, + status TEXT NOT NULL, + resource_id TEXT, + resource_type TEXT, + resource_data TEXT, + description TEXT, + remediation TEXT, + metadata TEXT NOT NULL DEFAULT '{}', + first_seen TIMESTAMP NOT NULL, + last_seen TIMESTAMP NOT NULL, + resolved_at TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_`+postgresFindingsTable+`_status ON `+postgresFindingsTable+` (status); +CREATE INDEX IF NOT EXISTS idx_`+postgresFindingsTable+`_severity ON `+postgresFindingsTable+` (severity); +CREATE INDEX IF NOT EXISTS idx_`+postgresFindingsTable+`_policy_id ON `+postgresFindingsTable+` (policy_id); +`)) + return err +} + +func (s *PostgresStore) SetAttestor(attestor FindingAttestor, attestReobserved bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.attestor = attestor + s.attestReobserved = attestReobserved +} + +func (s *PostgresStore) SetSemanticDedup(enabled 
bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.semanticDedup = enabled + s.rebuildSemanticIndexLocked() +} + +func (s *PostgresStore) Load(ctx context.Context) error { + if err := s.EnsureSchema(ctx); err != nil { + return err + } + + cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour) + rows, err := s.db.QueryContext(ctx, s.q(` +SELECT id, policy_id, policy_name, severity, status, + resource_id, resource_type, resource_data, description, + remediation, metadata, first_seen, last_seen, resolved_at +FROM `+postgresFindingsTable+` +WHERE UPPER(status) != 'RESOLVED' OR resolved_at > $1 +ORDER BY last_seen DESC +LIMIT 10000 +`), cutoff) + if err != nil { + return fmt.Errorf("load findings: %w", err) + } + defer func() { _ = rows.Close() }() + + s.mu.Lock() + defer s.mu.Unlock() + + s.cache = make(map[string]*Finding) + s.semanticIndex = make(map[string]string) + s.dirty = make(map[string]bool) + + for rows.Next() { + var finding Finding + var resourceData sql.NullString + var remediation sql.NullString + var metadataData string + var resolvedAt sql.NullTime + + if err := rows.Scan( + &finding.ID, + &finding.PolicyID, + &finding.PolicyName, + &finding.Severity, + &finding.Status, + &finding.ResourceID, + &finding.ResourceType, + &resourceData, + &finding.Description, + &remediation, + &metadataData, + &finding.FirstSeen, + &finding.LastSeen, + &resolvedAt, + ); err != nil { + return err + } + + if resolvedAt.Valid { + ts := resolvedAt.Time.UTC() + finding.ResolvedAt = &ts + } + if resourceData.Valid && strings.TrimSpace(resourceData.String) != "" { + if err := parseResourceData(&finding, []byte(resourceData.String)); err != nil { + return fmt.Errorf("parse resource data for finding %s: %w", finding.ID, err) + } + } + if remediation.Valid { + finding.Remediation = remediation.String + } + applyFindingMetadata(&finding, []byte(metadataData)) + finding.Status = normalizeStatus(finding.Status) + EnrichFinding(&finding) + + s.cache[finding.ID] = &finding + 
s.indexSemanticFindingLocked(&finding) + } + + s.syncedAt = time.Now().UTC() + return rows.Err() +} + +func (s *PostgresStore) ImportRecords(ctx context.Context, records []*snowflake.FindingRecord) error { + if len(records) == 0 { + return nil + } + + s.mu.Lock() + for _, record := range records { + if record == nil || strings.TrimSpace(record.ID) == "" { + continue + } + oldKey := "" + if existing, ok := s.cache[record.ID]; ok { + oldKey = existing.SemanticKey + } + finding := &Finding{ + ID: record.ID, + PolicyID: record.PolicyID, + PolicyName: record.PolicyName, + Severity: record.Severity, + Status: normalizeStatus(record.Status), + ResourceID: record.ResourceID, + ResourceType: record.ResourceType, + Resource: record.ResourceData, + Description: record.Description, + Remediation: record.Remediation, + FirstSeen: record.FirstSeen.UTC(), + LastSeen: record.LastSeen.UTC(), + } + if record.ResolvedAt != nil { + ts := record.ResolvedAt.UTC() + finding.ResolvedAt = &ts + } + if len(record.Metadata) > 0 { + applyFindingMetadata(finding, record.Metadata) + } + if len(record.ResourceData) > 0 { + resourceJSON, err := json.Marshal(record.ResourceData) + if err == nil { + finding.resourceJSONRaw = cloneBytes(resourceJSON) + } + } + EnrichFinding(finding) + s.cache[finding.ID] = finding + s.syncSemanticIndexLocked(finding, oldKey) + s.dirty[finding.ID] = true + } + s.mu.Unlock() + + return s.Sync(ctx) +} + +func (s *PostgresStore) Upsert(ctx context.Context, pf policy.Finding) *Finding { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + semanticKey := semanticKeyForPolicyFinding(pf) + + if existing, ok := s.cache[pf.ID]; ok { + oldKey := existing.SemanticKey + previousStatus := applyPolicyFindingUpdate(existing, pf, now) + applySemanticObservation(existing, pf, semanticKey) + s.syncSemanticIndexLocked(existing, oldKey) + EnrichFinding(existing) + eventType := upsertAttestationEvent(true, previousStatus, s.attestReobserved) + if eventType != "" { + _ = 
attestFindingEvent(ctx, s.attestor, existing, eventType, now) + } + s.dirty[existing.ID] = true + return existing + } + if match := s.findSemanticMatchLocked(semanticKey); match != nil { + oldKey := match.SemanticKey + previousStatus := applyPolicyFindingUpdate(match, pf, now) + applySemanticObservation(match, pf, semanticKey) + s.syncSemanticIndexLocked(match, oldKey) + EnrichFinding(match) + eventType := upsertAttestationEvent(true, previousStatus, s.attestReobserved) + if eventType != "" { + _ = attestFindingEvent(ctx, s.attestor, match, eventType, now) + } + s.dirty[match.ID] = true + return match + } + + finding := newFindingFromPolicyFinding(pf, now) + applySemanticObservation(finding, pf, semanticKey) + EnrichFinding(finding) + _ = attestFindingEvent(ctx, s.attestor, finding, upsertAttestationEvent(false, "", s.attestReobserved), now) + + s.cache[pf.ID] = finding + s.indexSemanticFindingLocked(finding) + s.dirty[pf.ID] = true + return finding +} + +func (s *PostgresStore) Get(id string) (*Finding, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + finding, ok := s.cache[id] + return finding, ok +} + +func (s *PostgresStore) Update(id string, mutate func(*Finding) error) error { + s.mu.Lock() + defer s.mu.Unlock() + + finding, ok := s.cache[id] + if !ok { + return ErrIssueNotFound + } + oldKey := finding.SemanticKey + if err := mutate(finding); err != nil { + return err + } + invalidateResourceJSONCache(finding) + finding.Status = normalizeStatus(finding.Status) + refreshFindingSemanticState(finding) + s.syncSemanticIndexLocked(finding, oldKey) + EnrichFinding(finding) + s.dirty[id] = true + return nil +} + +func (s *PostgresStore) List(filter FindingFilter) []*Finding { + s.mu.RLock() + defer s.mu.RUnlock() + + statusFilter := normalizeStatus(filter.Status) + result := make([]*Finding, 0) + for _, finding := range s.cache { + if filter.Severity != "" && finding.Severity != filter.Severity { + continue + } + if statusFilter != "" && 
normalizeStatus(finding.Status) != statusFilter { + continue + } + if filter.PolicyID != "" && finding.PolicyID != filter.PolicyID { + continue + } + if filter.TenantID != "" && !strings.EqualFold(strings.TrimSpace(finding.TenantID), strings.TrimSpace(filter.TenantID)) { + continue + } + if filter.SignalType != "" && !strings.EqualFold(finding.SignalType, filter.SignalType) { + continue + } + if filter.Domain != "" && !strings.EqualFold(finding.Domain, filter.Domain) { + continue + } + result = append(result, finding) + } + + if filter.Offset > 0 || filter.Limit > 0 { + if filter.Offset >= len(result) { + return []*Finding{} + } + end := len(result) + if filter.Limit > 0 && filter.Offset+filter.Limit < end { + end = filter.Offset + filter.Limit + } + result = result[filter.Offset:end] + } + + return result +} + +func (s *PostgresStore) Count(filter FindingFilter) int { + s.mu.RLock() + defer s.mu.RUnlock() + + statusFilter := normalizeStatus(filter.Status) + count := 0 + for _, finding := range s.cache { + if filter.Severity != "" && finding.Severity != filter.Severity { + continue + } + if statusFilter != "" && normalizeStatus(finding.Status) != statusFilter { + continue + } + if filter.PolicyID != "" && finding.PolicyID != filter.PolicyID { + continue + } + if filter.TenantID != "" && !strings.EqualFold(strings.TrimSpace(finding.TenantID), strings.TrimSpace(filter.TenantID)) { + continue + } + if filter.SignalType != "" && !strings.EqualFold(finding.SignalType, filter.SignalType) { + continue + } + if filter.Domain != "" && !strings.EqualFold(finding.Domain, filter.Domain) { + continue + } + count++ + } + return count +} + +func (s *PostgresStore) Resolve(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + finding, ok := s.cache[id] + if !ok { + return false + } + now := time.Now() + finding.Status = "RESOLVED" + finding.ResolvedAt = &now + finding.SnoozedUntil = nil + finding.StatusChangedAt = &now + finding.UpdatedAt = now + s.dirty[id] = true + return 
true +} + +func (s *PostgresStore) Suppress(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + finding, ok := s.cache[id] + if !ok { + return false + } + now := time.Now() + finding.Status = "SUPPRESSED" + finding.SnoozedUntil = nil + finding.StatusChangedAt = &now + finding.UpdatedAt = now + s.dirty[id] = true + return true +} + +func (s *PostgresStore) Stats() Stats { + s.mu.RLock() + defer s.mu.RUnlock() + + stats := Stats{ + BySeverity: make(map[string]int), + ByStatus: make(map[string]int), + ByPolicy: make(map[string]int), + BySignalType: make(map[string]int), + ByDomain: make(map[string]int), + } + for _, finding := range s.cache { + stats.Total++ + stats.BySeverity[finding.Severity]++ + stats.ByStatus[normalizeStatus(finding.Status)]++ + stats.ByPolicy[finding.PolicyID]++ + signalType := strings.ToLower(strings.TrimSpace(finding.SignalType)) + if signalType == "" { + signalType = SignalTypeSecurity + } + stats.BySignalType[signalType]++ + domain := strings.ToLower(strings.TrimSpace(finding.Domain)) + if domain == "" { + domain = DomainInfra + } + stats.ByDomain[domain]++ + } + return stats +} + +func (s *PostgresStore) Sync(ctx context.Context) error { + if err := s.EnsureSchema(ctx); err != nil { + return err + } + + s.mu.Lock() + dirtyIDs := make([]string, 0, len(s.dirty)) + for id := range s.dirty { + dirtyIDs = append(dirtyIDs, id) + } + s.mu.Unlock() + + if len(dirtyIDs) == 0 { + return nil + } + + findings := make([]*Finding, 0, len(dirtyIDs)) + for _, id := range dirtyIDs { + s.mu.RLock() + finding, ok := s.cache[id] + s.mu.RUnlock() + if ok { + findings = append(findings, finding) + } + } + if len(findings) == 0 { + return nil + } + + const batchSize = 100 + for start := 0; start < len(findings); start += batchSize { + end := start + batchSize + if end > len(findings) { + end = len(findings) + } + batch := findings[start:end] + if err := s.syncBatch(ctx, batch); err != nil { + return err + } + + s.mu.Lock() + for _, finding := range batch { + 
delete(s.dirty, finding.ID) + } + s.mu.Unlock() + } + + s.syncedAt = time.Now().UTC() + return nil +} + +func (s *PostgresStore) syncBatch(ctx context.Context, batch []*Finding) error { + args := make([]any, 0, len(batch)*14) + values := make([]string, 0, len(batch)) + for _, finding := range batch { + resourceJSON, err := resourceJSONForSync(finding) + if err != nil { + return fmt.Errorf("marshal resource data for finding %s: %w", finding.ID, err) + } + metadataJSON, err := buildFindingMetadata(finding) + if err != nil { + return err + } + if len(metadataJSON) == 0 { + metadataJSON = []byte("{}") + } + + var resolvedAt any + if finding.ResolvedAt != nil { + resolvedAt = finding.ResolvedAt.UTC() + } + + values = append(values, placeholderTuple(len(args)+1, 14)) + args = append(args, + finding.ID, + finding.PolicyID, + finding.PolicyName, + finding.Severity, + normalizeStatus(finding.Status), + finding.ResourceID, + finding.ResourceType, + string(resourceJSON), + finding.Description, + finding.Remediation, + string(metadataJSON), + finding.FirstSeen.UTC(), + finding.LastSeen.UTC(), + resolvedAt, + ) + } + + _, err := s.db.ExecContext(ctx, s.q(` +INSERT INTO `+postgresFindingsTable+` ( + id, policy_id, policy_name, severity, status, + resource_id, resource_type, resource_data, description, + remediation, metadata, first_seen, last_seen, resolved_at +) VALUES `+strings.Join(values, ",")+` +ON CONFLICT (id) DO UPDATE SET + policy_id = EXCLUDED.policy_id, + policy_name = EXCLUDED.policy_name, + severity = EXCLUDED.severity, + status = EXCLUDED.status, + resource_id = EXCLUDED.resource_id, + resource_type = EXCLUDED.resource_type, + resource_data = EXCLUDED.resource_data, + description = EXCLUDED.description, + remediation = EXCLUDED.remediation, + metadata = EXCLUDED.metadata, + first_seen = EXCLUDED.first_seen, + last_seen = EXCLUDED.last_seen, + resolved_at = EXCLUDED.resolved_at +`), args...) 
+ if err != nil { + return fmt.Errorf("sync findings batch: %w", err) + } + return nil +} + +func (s *PostgresStore) SyncedAt() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.syncedAt +} + +func (s *PostgresStore) DirtyCount() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.dirty) +} + +func (s *PostgresStore) q(query string) string { + if s != nil && s.rewriteSQL != nil { + return s.rewriteSQL(query) + } + return query +} + +func (s *PostgresStore) findSemanticMatchLocked(semanticKey string) *Finding { + if !findingNeedsSemanticMatch(s.semanticDedup, semanticKey) { + return nil + } + id, ok := s.semanticIndex[semanticKey] + if !ok { + return nil + } + return s.cache[id] +} + +func (s *PostgresStore) syncSemanticIndexLocked(finding *Finding, oldKey string) { + if !s.semanticDedup { + return + } + ensureFindingSemanticState(finding) + oldKey = strings.TrimSpace(oldKey) + if oldKey != "" && oldKey != finding.SemanticKey && s.semanticIndex[oldKey] == finding.ID { + delete(s.semanticIndex, oldKey) + } + if strings.TrimSpace(finding.SemanticKey) != "" { + s.semanticIndex[finding.SemanticKey] = finding.ID + } +} + +func (s *PostgresStore) indexSemanticFindingLocked(finding *Finding) { + if !s.semanticDedup { + return + } + ensureFindingSemanticState(finding) + if strings.TrimSpace(finding.SemanticKey) != "" { + s.semanticIndex[finding.SemanticKey] = finding.ID + } +} + +func (s *PostgresStore) rebuildSemanticIndexLocked() { + s.semanticIndex = make(map[string]string, len(s.cache)) + if !s.semanticDedup { + return + } + for _, finding := range s.cache { + s.indexSemanticFindingLocked(finding) + } +} + +func placeholderTuple(start, count int) string { + parts := make([]string, count) + for idx := 0; idx < count; idx++ { + parts[idx] = fmt.Sprintf("$%d", start+idx) + } + return "(" + strings.Join(parts, ", ") + ")" +} + +var _ FindingStore = (*PostgresStore)(nil) diff --git a/internal/findings/postgres_store_test.go 
b/internal/findings/postgres_store_test.go new file mode 100644 index 000000000..915029051 --- /dev/null +++ b/internal/findings/postgres_store_test.go @@ -0,0 +1,139 @@ +package findings + +import ( + "context" + "database/sql" + "encoding/json" + "regexp" + "testing" + "time" + + _ "modernc.org/sqlite" + + "github.com/writer/cerebro/internal/policy" + "github.com/writer/cerebro/internal/snowflake" +) + +var postgresDollarPlaceholderRe = regexp.MustCompile(`\$\d+`) + +func postgresSQLiteRewrite(query string) string { + return postgresDollarPlaceholderRe.ReplaceAllString(query, "?") +} + +func newTestPostgresFindingsStore(t *testing.T) (*PostgresStore, *sql.DB) { + t.Helper() + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + store := &PostgresStore{ + db: db, + cache: make(map[string]*Finding), + semanticIndex: make(map[string]string), + dirty: make(map[string]bool), + semanticDedup: DefaultSemanticDedupEnabled, + rewriteSQL: postgresSQLiteRewrite, + } + if err := store.EnsureSchema(context.Background()); err != nil { + _ = db.Close() + t.Fatalf("ensure schema: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return store, db +} + +func TestPostgresStoreImportRecordsPersistsAndLoads(t *testing.T) { + store, db := newTestPostgresFindingsStore(t) + + resolvedAt := time.Now().UTC().Add(-time.Hour) + if err := store.ImportRecords(context.Background(), []*snowflake.FindingRecord{ + { + ID: "finding-1", + PolicyID: "policy-1", + PolicyName: "Public bucket", + Severity: "high", + Status: "OPEN", + ResourceID: "bucket-1", + ResourceType: "s3_bucket", + ResourceData: map[string]interface{}{"name": "bucket-1"}, + Description: "bucket is public", + Metadata: json.RawMessage(`{"tenant_id":"tenant-a","signal_type":"security","domain":"infra"}`), + FirstSeen: time.Now().UTC().Add(-2 * time.Hour), + LastSeen: time.Now().UTC(), + ResolvedAt: &resolvedAt, + }, + }); err != nil { + t.Fatalf("ImportRecords() error = %v", err) 
+ } + + reloaded := &PostgresStore{ + db: db, + cache: make(map[string]*Finding), + semanticIndex: make(map[string]string), + dirty: make(map[string]bool), + semanticDedup: DefaultSemanticDedupEnabled, + rewriteSQL: postgresSQLiteRewrite, + } + if err := reloaded.Load(context.Background()); err != nil { + t.Fatalf("Load() error = %v", err) + } + + finding, ok := reloaded.Get("finding-1") + if !ok { + t.Fatal("expected migrated finding to be available after reload") + } + if finding.TenantID != "tenant-a" { + t.Fatalf("TenantID = %q, want tenant-a", finding.TenantID) + } + if finding.SignalType != SignalTypeSecurity { + t.Fatalf("SignalType = %q, want %q", finding.SignalType, SignalTypeSecurity) + } + if finding.Domain != DomainInfra { + t.Fatalf("Domain = %q, want %q", finding.Domain, DomainInfra) + } +} + +func TestPostgresStoreResolveSyncPersistsStatus(t *testing.T) { + store, db := newTestPostgresFindingsStore(t) + pf := policy.Finding{ + ID: "finding-2", + PolicyID: "policy-2", + PolicyName: "Encrypt storage", + Severity: "medium", + ResourceID: "bucket-2", + ResourceType: "s3_bucket", + Description: "bucket is not encrypted", + } + + if store.Upsert(context.Background(), pf) == nil { + t.Fatal("expected Upsert() to create a finding") + } + if err := store.Sync(context.Background()); err != nil { + t.Fatalf("initial Sync() error = %v", err) + } + if !store.Resolve("finding-2") { + t.Fatal("expected Resolve() to update existing finding") + } + if err := store.Sync(context.Background()); err != nil { + t.Fatalf("resolve Sync() error = %v", err) + } + + reloaded := &PostgresStore{ + db: db, + cache: make(map[string]*Finding), + semanticIndex: make(map[string]string), + dirty: make(map[string]bool), + semanticDedup: DefaultSemanticDedupEnabled, + rewriteSQL: postgresSQLiteRewrite, + } + if err := reloaded.Load(context.Background()); err != nil { + t.Fatalf("Load() error = %v", err) + } + finding, ok := reloaded.Get("finding-2") + if !ok { + t.Fatal("expected 
resolved finding to remain loadable") + } + if got := normalizeStatus(finding.Status); got != "RESOLVED" { + t.Fatalf("Status = %q, want RESOLVED", got) + } +} diff --git a/internal/graph/store_neptune.go b/internal/graph/store_neptune.go index fb49927a2..f9137418e 100644 --- a/internal/graph/store_neptune.go +++ b/internal/graph/store_neptune.go @@ -81,7 +81,7 @@ func (e *neptuneDataExecutor) ExecuteOpenCypher(ctx context.Context, query strin if output == nil { return nil, nil } - return output.Results, nil + return neptuneDecodeExecuteResults(output.Results) }) } @@ -120,6 +120,21 @@ func defaultNeptuneRetryOptions() neptuneRetryOptions { } } +func neptuneDecodeExecuteResults(results any) (any, error) { + if results == nil { + return nil, nil + } + unmarshaler, ok := results.(document.Unmarshaler) + if !ok { + return results, nil + } + var decoded any + if err := unmarshaler.UnmarshalSmithyDocument(&decoded); err != nil { + return nil, fmt.Errorf("unmarshal neptune results: %w", err) + } + return neptuneNormalizeValue(decoded), nil +} + func normalizeNeptuneRetryOptions(opts neptuneRetryOptions) neptuneRetryOptions { defaults := defaultNeptuneRetryOptions() if opts.Attempts <= 0 { diff --git a/internal/graph/store_neptune_pool.go b/internal/graph/store_neptune_pool.go index df6290643..a7f64ce0e 100644 --- a/internal/graph/store_neptune_pool.go +++ b/internal/graph/store_neptune_pool.go @@ -160,7 +160,7 @@ func (e *pooledNeptuneDataExecutor) ExecuteOpenCypher(ctx context.Context, query if output == nil { return nil, nil } - return output.Results, nil + return neptuneDecodeExecuteResults(output.Results) } func (e *pooledNeptuneDataExecutor) ExecuteOpenCypherExplain(ctx context.Context, query string, mode NeptuneExplainMode, params map[string]any) ([]byte, error) { diff --git a/internal/graph/store_neptune_pool_test.go b/internal/graph/store_neptune_pool_test.go index 0b89e27c8..9381676de 100644 --- a/internal/graph/store_neptune_pool_test.go +++ 
b/internal/graph/store_neptune_pool_test.go @@ -212,6 +212,34 @@ func TestPooledNeptuneDataExecutorRecyclesConnectionsAfterMaxUses(t *testing.T) } } +func TestPooledNeptuneGraphStoreCountNodesDecodesSmithyDocumentResults(t *testing.T) { + client, cleanup := newNeptuneDocumentTestClient(t, `{"results":[{"total":2}]}`) + defer cleanup() + + exec, err := NewPooledNeptuneDataExecutor(func() (neptuneDataClient, error) { + return client, nil + }, NeptuneDataExecutorPoolConfig{ + Size: 1, + }) + if err != nil { + t.Fatalf("NewPooledNeptuneDataExecutor() error = %v", err) + } + defer func() { + if err := exec.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + }() + + store := NewNeptuneGraphStore(exec) + got, err := store.CountNodes(context.Background()) + if err != nil { + t.Fatalf("CountNodes() error = %v", err) + } + if got != 2 { + t.Fatalf("CountNodes() = %d, want 2", got) + } +} + func TestPooledNeptuneDataExecutorDrainWaitsForInflightQueriesAndRejectsNewWork(t *testing.T) { started := make(chan struct{}, 1) release := make(chan struct{}) diff --git a/internal/graph/store_neptune_test.go b/internal/graph/store_neptune_test.go index c53349a63..d9539beff 100644 --- a/internal/graph/store_neptune_test.go +++ b/internal/graph/store_neptune_test.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "net" + "net/http" + "net/http/httptest" "reflect" "regexp" "sort" @@ -12,6 +14,7 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/neptunedata" "github.com/aws/smithy-go" ) @@ -102,6 +105,23 @@ func (f *fakeNeptuneDataClient) ExecuteOpenCypherExplainQuery(_ context.Context, return output, nil } +func newNeptuneDocumentTestClient(t *testing.T, responseBody string) (neptuneDataClient, func()) { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(responseBody)) + })) + client := 
neptunedata.NewFromConfig(aws.Config{ + Region: "us-east-1", + Credentials: aws.AnonymousCredentials{}, + HTTPClient: server.Client(), + }, func(options *neptunedata.Options) { + options.BaseEndpoint = aws.String(server.URL) + }) + return client, server.Close +} + type timeoutNetError struct{} func (timeoutNetError) Error() string { return "i/o timeout" } @@ -201,6 +221,20 @@ func TestNeptuneDataExecutorExecuteOpenCypherStopsAtMaxAttempts(t *testing.T) { } } +func TestNeptuneGraphStoreCountNodesDecodesSmithyDocumentResults(t *testing.T) { + client, cleanup := newNeptuneDocumentTestClient(t, `{"results":[{"total":2}]}`) + defer cleanup() + + store := NewNeptuneGraphStore(NewNeptuneDataExecutor(client)) + got, err := store.CountNodes(context.Background()) + if err != nil { + t.Fatalf("CountNodes() error = %v", err) + } + if got != 2 { + t.Fatalf("CountNodes() = %d, want 2", got) + } +} + func TestNeptuneDataExecutorExecuteOpenCypherExplainRetriesTransientErrors(t *testing.T) { client := &fakeNeptuneDataClient{ explainErrors: []error{ diff --git a/internal/remediation/executor.go b/internal/remediation/executor.go index 45e22933f..46f058377 100644 --- a/internal/remediation/executor.go +++ b/internal/remediation/executor.go @@ -44,6 +44,7 @@ var ( _ NotificationSender = (*notifications.Manager)(nil) _ FindingsWriter = (*findings.Store)(nil) _ FindingsWriter = (*findings.SQLiteStore)(nil) + _ FindingsWriter = (*findings.PostgresStore)(nil) _ FindingsWriter = (*findings.SnowflakeStore)(nil) _ EventPublisher = (*webhooks.Service)(nil) ) diff --git a/internal/scanner/relationship_toxic.go b/internal/scanner/relationship_toxic.go index d5668abe6..d523ac2e2 100644 --- a/internal/scanner/relationship_toxic.go +++ b/internal/scanner/relationship_toxic.go @@ -36,6 +36,10 @@ type ToxicScanCursor struct { SinceID string // keyset tiebreak: last resource_id at SinceTime } +func SupportsRelationshipToxicDetection(w warehouse.SchemaWarehouse) bool { + return w != nil && 
strings.EqualFold(strings.TrimSpace(w.Database()), "snowflake") +} + func DetectRelationshipToxicCombinations(ctx context.Context, sf warehouse.QueryWarehouse, cursor *ToxicScanCursor) (*ToxicDetectionResult, error) { if sf == nil { return &ToxicDetectionResult{}, nil diff --git a/internal/scanner/relationship_toxic_test.go b/internal/scanner/relationship_toxic_test.go index f9cfb011c..14220a9ee 100644 --- a/internal/scanner/relationship_toxic_test.go +++ b/internal/scanner/relationship_toxic_test.go @@ -5,10 +5,24 @@ import ( "strings" "testing" "time" + + "github.com/writer/cerebro/internal/warehouse" ) // --- toxicSinceFilter tests --- +func TestSupportsRelationshipToxicDetection(t *testing.T) { + if SupportsRelationshipToxicDetection(nil) { + t.Fatal("expected nil warehouse to be unsupported") + } + if SupportsRelationshipToxicDetection(&warehouse.MemoryWarehouse{DatabaseValue: "postgres"}) { + t.Fatal("expected postgres warehouse to skip relationship toxic detection") + } + if !SupportsRelationshipToxicDetection(&warehouse.MemoryWarehouse{DatabaseValue: "snowflake"}) { + t.Fatal("expected snowflake warehouse to support relationship toxic detection") + } +} + func TestToxicSinceFilter_NilCursor(t *testing.T) { if got := toxicSinceFilter("s", nil); got != "" { t.Errorf("expected empty, got %q", got) diff --git a/internal/snowflake/repository.go b/internal/snowflake/repository.go index 02f461938..4014a96e3 100644 --- a/internal/snowflake/repository.go +++ b/internal/snowflake/repository.go @@ -159,6 +159,63 @@ func (r *FindingRepository) List(ctx context.Context, filter FindingFilter) ([]* return findings, nil } +func (r *FindingRepository) ListAll(ctx context.Context) ([]*FindingRecord, error) { + findingsTable, err := SafeQualifiedTableRef(r.schema, "findings") + if err != nil { + return nil, fmt.Errorf("invalid findings table reference: %w", err) + } + + // #nosec G202 -- findingsTable is validated via SafeQualifiedTableRef. 
+ rows, err := r.client.db.QueryContext(ctx, ` + SELECT id, policy_id, policy_name, severity, status, + resource_id, resource_type, resource_data, description, + remediation, metadata, first_seen, last_seen, resolved_at + FROM `+findingsTable+` + ORDER BY last_seen DESC + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + records := make([]*FindingRecord, 0) + for rows.Next() { + record := &FindingRecord{} + var resourceData []byte + var metadataData []byte + var remediation sql.NullString + if err := rows.Scan( + &record.ID, + &record.PolicyID, + &record.PolicyName, + &record.Severity, + &record.Status, + &record.ResourceID, + &record.ResourceType, + &resourceData, + &record.Description, + &remediation, + &metadataData, + &record.FirstSeen, + &record.LastSeen, + &record.ResolvedAt, + ); err != nil { + return nil, err + } + if remediation.Valid { + record.Remediation = remediation.String + } + if len(metadataData) > 0 { + record.Metadata = metadataData + } + if len(resourceData) > 0 { + _ = json.Unmarshal(resourceData, &record.ResourceData) + } + records = append(records, record) + } + return records, rows.Err() +} + func (r *FindingRepository) UpdateStatus(ctx context.Context, id, status string) error { normalized := strings.ToUpper(status) findingsTable, err := SafeQualifiedTableRef(r.schema, "findings") @@ -294,6 +351,7 @@ func NewAuditRepository(client *Client) *AuditRepository { type AuditEntry struct { ID string `json:"id"` + Timestamp time.Time `json:"timestamp,omitempty"` Action string `json:"action"` ActorID string `json:"actor_id"` ActorType string `json:"actor_type"` @@ -308,6 +366,9 @@ func (r *AuditRepository) Log(ctx context.Context, entry *AuditEntry) error { if entry.ID == "" { entry.ID = uuid.New().String() } + if entry.Timestamp.IsZero() { + entry.Timestamp = time.Now().UTC() + } detailsJSON, _ := json.Marshal(entry.Details) auditTable, err := SafeQualifiedTableRef(r.schema, "audit_log") @@ -318,13 +379,13 @@ func 
(r *AuditRepository) Log(ctx context.Context, entry *AuditEntry) error { // #nosec G202 -- auditTable is validated via SafeQualifiedTableRef. query := ` INSERT INTO ` + auditTable + ` ( - id, action, actor_id, actor_type, resource_type, + id, timestamp, action, actor_id, actor_type, resource_type, resource_id, details, ip_address, user_agent - ) VALUES (?, ?, ?, ?, ?, ?, PARSE_JSON(?), ?, ?) + ) VALUES (?, ?, ?, ?, ?, ?, ?, PARSE_JSON(?), ?, ?) ` _, err = r.client.db.ExecContext(ctx, query, - entry.ID, entry.Action, entry.ActorID, entry.ActorType, entry.ResourceType, + entry.ID, entry.Timestamp.UTC(), entry.Action, entry.ActorID, entry.ActorType, entry.ResourceType, entry.ResourceID, string(detailsJSON), entry.IPAddress, entry.UserAgent, ) return err @@ -372,12 +433,58 @@ func (r *AuditRepository) List(ctx context.Context, resourceType, resourceID str &e.ResourceType, &e.ResourceID, &e.IPAddress, &ts); err != nil { continue } - _ = ts // timestamp available for future use + e.Timestamp = ts.UTC() entries = append(entries, &e) } return entries, nil } +func (r *AuditRepository) ListAll(ctx context.Context) ([]*AuditEntry, error) { + auditTable, err := SafeQualifiedTableRef(r.schema, "audit_log") + if err != nil { + return nil, fmt.Errorf("invalid audit_log table reference: %w", err) + } + + // #nosec G202 -- auditTable is validated via SafeQualifiedTableRef. 
+ rows, err := r.client.db.QueryContext(ctx, ` + SELECT id, timestamp, action, actor_id, actor_type, resource_type, resource_id, details, ip_address, user_agent + FROM `+auditTable+` + ORDER BY timestamp DESC + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + entries := make([]*AuditEntry, 0) + for rows.Next() { + entry := &AuditEntry{} + var detailsRaw any + if err := rows.Scan( + &entry.ID, + &entry.Timestamp, + &entry.Action, + &entry.ActorID, + &entry.ActorType, + &entry.ResourceType, + &entry.ResourceID, + &detailsRaw, + &entry.IPAddress, + &entry.UserAgent, + ); err != nil { + return nil, err + } + if detailsJSON := normalizeVariantJSONForState(detailsRaw); len(detailsJSON) > 0 { + if err := json.Unmarshal(detailsJSON, &entry.Details); err != nil { + return nil, err + } + } + entry.Timestamp = entry.Timestamp.UTC() + entries = append(entries, entry) + } + return entries, rows.Err() +} + // PolicyHistoryRepository handles policy version history persistence. type PolicyHistoryRepository struct { client *Client @@ -528,3 +635,55 @@ func (r *PolicyHistoryRepository) List(ctx context.Context, policyID string, lim return result, rows.Err() } + +func (r *PolicyHistoryRepository) ListAll(ctx context.Context) ([]*PolicyHistoryRecord, error) { + policyHistoryTable, err := SafeQualifiedTableRef(r.schema, "policy_history") + if err != nil { + return nil, fmt.Errorf("invalid policy_history table reference: %w", err) + } + + // #nosec G202 -- policyHistoryTable is validated via SafeQualifiedTableRef. 
+ rows, err := r.client.db.QueryContext(ctx, ` + SELECT policy_id, version, content, change_type, pinned_version, effective_from, effective_to + FROM `+policyHistoryTable+` + ORDER BY policy_id ASC, version DESC + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + records := make([]*PolicyHistoryRecord, 0) + for rows.Next() { + record := &PolicyHistoryRecord{} + var content []byte + var changeType sql.NullString + var pinned sql.NullInt64 + var effectiveTo sql.NullTime + if err := rows.Scan( + &record.PolicyID, + &record.Version, + &content, + &changeType, + &pinned, + &record.EffectiveFrom, + &effectiveTo, + ); err != nil { + return nil, err + } + record.Content = content + if changeType.Valid { + record.ChangeType = changeType.String + } + if pinned.Valid { + pinnedVal := int(pinned.Int64) + record.PinnedVersion = &pinnedVal + } + if effectiveTo.Valid { + ts := effectiveTo.Time + record.EffectiveTo = &ts + } + records = append(records, record) + } + return records, rows.Err() +} diff --git a/internal/snowflake/tableops/tableops.go b/internal/snowflake/tableops/tableops.go index 0659bbb53..a371160fe 100644 --- a/internal/snowflake/tableops/tableops.go +++ b/internal/snowflake/tableops/tableops.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/writer/cerebro/internal/snowflake" + "github.com/writer/cerebro/internal/warehouse" ) const DefaultInsertBatchSize = 200 @@ -43,11 +44,11 @@ func EnsureVariantTable(ctx context.Context, client QueryExecClient, table strin colDefs := make([]string, 0, len(filtered)+3) colDefs = append(colDefs, "_CQ_ID VARCHAR PRIMARY KEY", - "_CQ_SYNC_TIME TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP()", + fmt.Sprintf("_CQ_SYNC_TIME %s DEFAULT CURRENT_TIMESTAMP()", warehouse.TimestampColumnType(client)), ) colDefs = append(colDefs, "_CQ_HASH VARCHAR") for _, col := range filtered { - colDefs = append(colDefs, fmt.Sprintf("%s VARIANT", col)) + colDefs = append(colDefs, fmt.Sprintf("%s %s", col, 
warehouse.JSONColumnType(client))) } createQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n\t%s\n)", table, strings.Join(colDefs, ",\n\t")) @@ -72,7 +73,7 @@ func EnsureVariantTable(ctx context.Context, client QueryExecClient, table strin desired = append(desired, filtered...) for _, col := range columnsMissingFromSchema(existingCols, desired) { - columnType := "VARIANT" + columnType := warehouse.JSONColumnType(client) if col == "_CQ_HASH" { columnType = "VARCHAR" } @@ -128,7 +129,7 @@ func InsertVariantRowsBatch(ctx context.Context, client ExecClient, table string } batch := rows[start:end] - selects := make([]string, 0, len(batch)) + values := make([]string, 0, len(batch)) args := make([]interface{}, 0, len(batch)*len(allColumns)) for _, row := range batch { @@ -143,24 +144,24 @@ func InsertVariantRowsBatch(ctx context.Context, client ExecClient, table string } hash := stringValue(rowUpper["_CQ_HASH"]) - selectParts := make([]string, 0, len(allColumns)) - selectParts = append(selectParts, "?", "?") + valueParts := make([]string, 0, len(allColumns)) + valueParts = append(valueParts, warehouse.Placeholder(client, len(args)+1), warehouse.Placeholder(client, len(args)+2)) args = append(args, id, hash) for _, col := range columns { jsonVal, _ := json.Marshal(rowUpper[col]) - selectParts = append(selectParts, "PARSE_JSON(?)") + valueParts = append(valueParts, warehouse.JSONPlaceholder(client, len(args)+1)) args = append(args, string(jsonVal)) } - selects = append(selects, "SELECT "+strings.Join(selectParts, ", ")) + values = append(values, "("+strings.Join(valueParts, ", ")+")") } - if len(selects) == 0 { + if len(values) == 0 { continue } - query := fmt.Sprintf("INSERT INTO %s (%s) %s", table, strings.Join(allColumns, ", "), strings.Join(selects, " UNION ALL ")) + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", table, strings.Join(allColumns, ", "), strings.Join(values, ", ")) if _, err := client.Exec(ctx, query, args...); err != nil { return 
fmt.Errorf("insert rows: %w", err) } @@ -205,6 +206,7 @@ func MergeVariantRowsBatch(ctx context.Context, client ExecClient, table string, return err } + dialect := warehouse.Dialect(client) for start := 0; start < len(rows); start += batchSize { end := start + batchSize if end > len(rows) { @@ -212,79 +214,140 @@ func MergeVariantRowsBatch(ctx context.Context, client ExecClient, table string, } batch := rows[start:end] - selects := make([]string, 0, len(batch)) - args := make([]interface{}, 0, len(batch)*len(allColumns)) - firstSelect := true + var ( + query string + args []interface{} + ) + if dialect == warehouse.SQLDialectSnowflake { + query, args = buildSnowflakeMergeQuery(batch, columns, allColumns, client, table) + } else { + query, args = buildUpsertQuery(batch, columns, allColumns, client, table) + } + if strings.TrimSpace(query) == "" { + continue + } + if _, err := client.Exec(ctx, query, args...); err != nil { + return fmt.Errorf("merge rows: %w", err) + } + } - for _, row := range batch { - rowUpper := make(map[string]interface{}, len(row)) - for key, value := range row { - rowUpper[strings.ToUpper(key)] = value - } + return nil +} - id := strings.TrimSpace(stringValue(rowUpper["_CQ_ID"])) - if id == "" { - continue - } - hash := stringValue(rowUpper["_CQ_HASH"]) +func buildSnowflakeMergeQuery(batch []map[string]interface{}, columns, allColumns []string, client ExecClient, table string) (string, []interface{}) { + selects := make([]string, 0, len(batch)) + args := make([]interface{}, 0, len(batch)*len(allColumns)) + firstSelect := true - parts := make([]string, 0, len(allColumns)) - args = append(args, id, hash) + for _, row := range batch { + rowUpper := make(map[string]interface{}, len(row)) + for key, value := range row { + rowUpper[strings.ToUpper(key)] = value + } - // The first emitted SELECT needs column aliases for UNION ALL. - if firstSelect { - parts = append(parts, "? AS _CQ_ID", "? 
AS _CQ_HASH") - for _, col := range columns { - jsonVal, _ := json.Marshal(rowUpper[col]) - parts = append(parts, fmt.Sprintf("PARSE_JSON(?) AS %s", col)) - args = append(args, string(jsonVal)) - } - firstSelect = false - } else { - parts = append(parts, "?", "?") - for _, col := range columns { - jsonVal, _ := json.Marshal(rowUpper[col]) - parts = append(parts, "PARSE_JSON(?)") - args = append(args, string(jsonVal)) - } + id := strings.TrimSpace(stringValue(rowUpper["_CQ_ID"])) + if id == "" { + continue + } + hash := stringValue(rowUpper["_CQ_HASH"]) + + parts := make([]string, 0, len(allColumns)) + args = append(args, id, hash) + + if firstSelect { + parts = append(parts, "? AS _CQ_ID", "? AS _CQ_HASH") + for _, col := range columns { + jsonVal, _ := json.Marshal(rowUpper[col]) + parts = append(parts, fmt.Sprintf("PARSE_JSON(?) AS %s", col)) + args = append(args, string(jsonVal)) } + firstSelect = false + } else { + parts = append(parts, "?", "?") + for _, col := range columns { + jsonVal, _ := json.Marshal(rowUpper[col]) + parts = append(parts, "PARSE_JSON(?)") + args = append(args, string(jsonVal)) + } + } + + selects = append(selects, "SELECT "+strings.Join(parts, ", ")) + } - selects = append(selects, "SELECT "+strings.Join(parts, ", ")) + if len(selects) == 0 { + return "", nil + } + + usingClause := strings.Join(selects, " UNION ALL ") + updateParts := make([]string, 0, len(columns)+1) + updateParts = append(updateParts, "t._CQ_HASH = s._CQ_HASH") + for _, col := range columns { + updateParts = append(updateParts, fmt.Sprintf("t.%s = s.%s", col, col)) + } + + insertCols := strings.Join(allColumns, ", ") + insertVals := make([]string, 0, len(allColumns)) + for _, col := range allColumns { + insertVals = append(insertVals, "s."+col) + } + + query := fmt.Sprintf( + "MERGE INTO %s t USING (%s) s ON t._CQ_ID = s._CQ_ID "+ + "WHEN MATCHED THEN UPDATE SET %s "+ + "WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s)", + table, usingClause, strings.Join(updateParts, ", 
"), + insertCols, strings.Join(insertVals, ", "), + ) + return query, args +} + +func buildUpsertQuery(batch []map[string]interface{}, columns, allColumns []string, client ExecClient, table string) (string, []interface{}) { + values := make([]string, 0, len(batch)) + args := make([]interface{}, 0, len(batch)*len(allColumns)) + + for _, row := range batch { + rowUpper := make(map[string]interface{}, len(row)) + for key, value := range row { + rowUpper[strings.ToUpper(key)] = value } - if len(selects) == 0 { + id := strings.TrimSpace(stringValue(rowUpper["_CQ_ID"])) + if id == "" { continue } + hash := stringValue(rowUpper["_CQ_HASH"]) - usingClause := strings.Join(selects, " UNION ALL ") + valueParts := make([]string, 0, len(allColumns)) + valueParts = append(valueParts, warehouse.Placeholder(client, len(args)+1), warehouse.Placeholder(client, len(args)+2)) + args = append(args, id, hash) - // Build UPDATE SET clause - updateParts := make([]string, 0, len(columns)+1) - updateParts = append(updateParts, "t._CQ_HASH = s._CQ_HASH") for _, col := range columns { - updateParts = append(updateParts, fmt.Sprintf("t.%s = s.%s", col, col)) + jsonVal, _ := json.Marshal(rowUpper[col]) + valueParts = append(valueParts, warehouse.JSONPlaceholder(client, len(args)+1)) + args = append(args, string(jsonVal)) } - // Build INSERT column/value lists - insertCols := strings.Join(allColumns, ", ") - insertVals := make([]string, 0, len(allColumns)) - for _, col := range allColumns { - insertVals = append(insertVals, "s."+col) - } + values = append(values, "("+strings.Join(valueParts, ", ")+")") + } - query := fmt.Sprintf( - "MERGE INTO %s t USING (%s) s ON t._CQ_ID = s._CQ_ID "+ - "WHEN MATCHED THEN UPDATE SET %s "+ - "WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s)", - table, usingClause, strings.Join(updateParts, ", "), - insertCols, strings.Join(insertVals, ", ")) + if len(values) == 0 { + return "", nil + } - if _, err := client.Exec(ctx, query, args...); err != nil { - return 
fmt.Errorf("merge rows: %w", err) - } + updateParts := make([]string, 0, len(columns)+1) + updateParts = append(updateParts, "_CQ_HASH = EXCLUDED._CQ_HASH") + for _, col := range columns { + updateParts = append(updateParts, fmt.Sprintf("%s = EXCLUDED.%s", col, col)) } - return nil + query := fmt.Sprintf( + "INSERT INTO %s (%s) VALUES %s ON CONFLICT (_CQ_ID) DO UPDATE SET %s", + table, + strings.Join(allColumns, ", "), + strings.Join(values, ", "), + strings.Join(updateParts, ", "), + ) + return query, args } func normalizeReserved(custom map[string]struct{}) map[string]struct{} { @@ -329,10 +392,19 @@ func validateColumns(columns []string) error { } func tableColumns(ctx context.Context, client QueryExecClient, table string) ([]string, error) { + if memory, ok := any(client).(*warehouse.MemoryWarehouse); ok && memory.DescribeColumnsFunc == nil { + // Preserve query-based behavior for tests that stub INFORMATION_SCHEMA reads + // via QueryFunc without wiring DescribeColumnsFunc explicitly. + } else if describer, ok := any(client).(interface { + DescribeColumns(context.Context, string) ([]string, error) + }); ok { + return describer.DescribeColumns(ctx, table) + } + query := ` SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_NAME = ? 
+ WHERE TABLE_NAME = ` + warehouse.Placeholder(client, 1) + ` AND TABLE_SCHEMA = CURRENT_SCHEMA() ` diff --git a/internal/sync/engine.go b/internal/sync/engine.go index 67e900402..177c780ce 100644 --- a/internal/sync/engine.go +++ b/internal/sync/engine.go @@ -645,7 +645,7 @@ func (e *SyncEngine) deleteRowsByID(ctx context.Context, table string, ids map[s end = len(keys) } batch := keys[start:end] - placeholders := strings.TrimRight(strings.Repeat("?,", len(batch)), ",") + placeholders := strings.Join(warehouse.Placeholders(e.sf, 1, len(batch)), ",") args := make([]interface{}, len(batch)) for i, id := range batch { args[i] = id @@ -702,12 +702,12 @@ func (e *SyncEngine) scopeWhereClause(region string, hasRegion bool, hasAccount args := make([]interface{}, 0, 2) if hasAccount && e.accountID != "" { - clauses = append(clauses, "ACCOUNT_ID = ?") + clauses = append(clauses, "ACCOUNT_ID = "+warehouse.Placeholder(e.sf, len(args)+1)) args = append(args, e.accountID) } if hasRegion && !globalScope { - clauses = append(clauses, "REGION = ?") + clauses = append(clauses, "REGION = "+warehouse.Placeholder(e.sf, len(args)+1)) args = append(args, region) } @@ -732,8 +732,8 @@ func (e *SyncEngine) persistChangeHistory(ctx context.Context, results []SyncRes region VARCHAR, account_id VARCHAR, provider VARCHAR, - timestamp TIMESTAMP_TZ, - _cq_sync_time TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP() + timestamp ` + warehouse.TimestampColumnType(e.sf) + `, + _cq_sync_time ` + warehouse.TimestampColumnType(e.sf) + ` DEFAULT CURRENT_TIMESTAMP() )` if _, err := e.sf.Exec(ctx, createQuery); err != nil { @@ -744,7 +744,7 @@ func (e *SyncEngine) persistChangeHistory(ctx context.Context, results []SyncRes "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS region VARCHAR", "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS account_id VARCHAR", "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS provider VARCHAR", - "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS 
timestamp TIMESTAMP_TZ", + "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS timestamp " + warehouse.TimestampColumnType(e.sf), } { if _, err := e.sf.Exec(ctx, query); err != nil { e.logger.Debug("failed to ensure change history column", "query", query, "error", err) @@ -778,8 +778,10 @@ func (e *SyncEngine) persistChangeHistory(ctx context.Context, results []SyncRes func (e *SyncEngine) insertChangeRecord(ctx context.Context, table, resourceID, op, region string, ts time.Time) { id := fmt.Sprintf("%s-%s-%s-%d", table, resourceID, op, ts.UnixNano()) - query := `INSERT INTO _sync_change_history (id, table_name, resource_id, operation, region, account_id, provider, timestamp) - SELECT ?, ?, ?, ?, ?, ?, ?, ?` + query := fmt.Sprintf( + "INSERT INTO _sync_change_history (id, table_name, resource_id, operation, region, account_id, provider, timestamp) VALUES (%s)", + strings.Join(warehouse.Placeholders(e.sf, 1, 8), ", "), + ) if _, err := e.sf.Exec(ctx, query, id, table, resourceID, op, region, e.accountID, "aws", ts); err != nil { e.logger.Debug("failed to insert change record", "error", err) @@ -794,8 +796,8 @@ func (e *SyncEngine) ensureBackfillQueueTable(ctx context.Context) error { region VARCHAR, account_id VARCHAR, reason VARCHAR, - requested_at TIMESTAMP_TZ, - _cq_sync_time TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP() + requested_at ` + warehouse.TimestampColumnType(e.sf) + `, + _cq_sync_time ` + warehouse.TimestampColumnType(e.sf) + ` DEFAULT CURRENT_TIMESTAMP() )` if _, err := e.sf.Exec(ctx, createQuery); err != nil { return fmt.Errorf("create backfill queue: %w", err) @@ -843,7 +845,7 @@ func (e *SyncEngine) loadBackfillRequests(ctx context.Context) map[string]string } result, err := e.sf.Query(ctx, - "SELECT table_name, region, reason FROM _sync_backfill_queue WHERE provider = ? 
AND account_id = ?", + "SELECT table_name, region, reason FROM _sync_backfill_queue WHERE provider = "+warehouse.Placeholder(e.sf, 1)+" AND account_id = "+warehouse.Placeholder(e.sf, 2), "aws", e.accountID, ) @@ -868,17 +870,31 @@ func (e *SyncEngine) recordBackfillRequest(ctx context.Context, table, region, r } id := backfillQueueID(e.accountID, table, region) - mergeQuery := `MERGE INTO _sync_backfill_queue t - USING (SELECT ? AS id, ? AS provider, ? AS table_name, ? AS region, ? AS account_id, ? AS reason, CURRENT_TIMESTAMP() AS requested_at) s - ON t.id = s.id - WHEN MATCHED THEN UPDATE SET - reason = s.reason, - requested_at = s.requested_at, - _cq_sync_time = CURRENT_TIMESTAMP() - WHEN NOT MATCHED THEN INSERT (id, provider, table_name, region, account_id, reason, requested_at) - VALUES (s.id, s.provider, s.table_name, s.region, s.account_id, s.reason, s.requested_at)` - - if _, err := e.sf.Exec(ctx, mergeQuery, id, "aws", table, region, e.accountID, reason); err != nil { + var ( + upsertQuery string + upsertArgs []interface{} + ) + if warehouse.Dialect(e.sf) == warehouse.SQLDialectSnowflake { + upsertQuery = `MERGE INTO _sync_backfill_queue t + USING (SELECT ? AS id, ? AS provider, ? AS table_name, ? AS region, ? AS account_id, ? 
AS reason, CURRENT_TIMESTAMP() AS requested_at) s + ON t.id = s.id + WHEN MATCHED THEN UPDATE SET + reason = s.reason, + requested_at = s.requested_at, + _cq_sync_time = CURRENT_TIMESTAMP() + WHEN NOT MATCHED THEN INSERT (id, provider, table_name, region, account_id, reason, requested_at) + VALUES (s.id, s.provider, s.table_name, s.region, s.account_id, s.reason, s.requested_at)` + upsertArgs = []interface{}{id, "aws", table, region, e.accountID, reason} + } else { + upsertQuery = fmt.Sprintf( + "INSERT INTO _sync_backfill_queue (id, provider, table_name, region, account_id, reason, requested_at) VALUES (%s) "+ + "ON CONFLICT (id) DO UPDATE SET reason = EXCLUDED.reason, requested_at = EXCLUDED.requested_at, _cq_sync_time = CURRENT_TIMESTAMP()", + strings.Join(warehouse.Placeholders(e.sf, 1, 7), ", "), + ) + upsertArgs = []interface{}{id, "aws", table, region, e.accountID, reason, time.Now().UTC()} + } + + if _, err := e.sf.Exec(ctx, upsertQuery, upsertArgs...); err != nil { return fmt.Errorf("upsert backfill request: %w", err) } @@ -897,7 +913,7 @@ func (e *SyncEngine) clearBackfillRequest(ctx context.Context, table, region str } id := backfillQueueID(e.accountID, table, region) - _, err := e.sf.Exec(ctx, "DELETE FROM _sync_backfill_queue WHERE id = ?", id) + _, err := e.sf.Exec(ctx, "DELETE FROM _sync_backfill_queue WHERE id = "+warehouse.Placeholder(e.sf, 1), id) if err != nil { return fmt.Errorf("delete backfill request: %w", err) } diff --git a/internal/sync/gcp_asset_inventory.go b/internal/sync/gcp_asset_inventory.go index b8bbbc592..925f9dcee 100644 --- a/internal/sync/gcp_asset_inventory.go +++ b/internal/sync/gcp_asset_inventory.go @@ -15,6 +15,7 @@ import ( "github.com/writer/cerebro/internal/metrics" "github.com/writer/cerebro/internal/snowflake" + "github.com/writer/cerebro/internal/snowflake/tableops" "github.com/writer/cerebro/internal/warehouse" "golang.org/x/sync/errgroup" ) @@ -467,29 +468,9 @@ func (e *GCPAssetInventoryEngine) 
getColumnsForAssetType() []string { } func (e *GCPAssetInventoryEngine) ensureTable(ctx context.Context, table string, columns []string) error { - if err := snowflake.ValidateTableName(table); err != nil { - return fmt.Errorf("invalid table name: %w", err) - } - - for _, col := range columns { - if err := snowflake.ValidateColumnName(col); err != nil { - return fmt.Errorf("invalid column name %q: %w", col, err) - } - } - colDefs := make([]string, len(columns)) - for i, col := range columns { - colDefs[i] = fmt.Sprintf("%s VARIANT", strings.ToUpper(col)) - } - - createQuery := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( - _CQ_ID VARCHAR PRIMARY KEY, - _CQ_SYNC_TIME TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP(), - _CQ_HASH VARCHAR, - %s - )`, table, strings.Join(colDefs, ", ")) - - _, err := e.sf.Exec(ctx, createQuery) - return err + return tableops.EnsureVariantTable(ctx, e.sf, table, columns, tableops.EnsureVariantTableOptions{ + AddMissingColumns: true, + }) } func gcpProjectIDFromScope(scope string) string { diff --git a/internal/sync/gcp_security.go b/internal/sync/gcp_security.go index 5676f0ee8..d599879f6 100644 --- a/internal/sync/gcp_security.go +++ b/internal/sync/gcp_security.go @@ -702,10 +702,10 @@ func (s *GCPSecuritySync) syncSCCFindings(ctx context.Context) error { // upsertVulnerabilities saves vulnerability data to Snowflake func (s *GCPSecuritySync) upsertVulnerabilities(ctx context.Context, vulns []map[string]interface{}) error { // Create table if not exists - createSQL := ` + createSQL := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS GCP_CONTAINER_VULNERABILITIES ( _CQ_ID VARCHAR PRIMARY KEY, - _CQ_SYNC_TIME TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP(), + _CQ_SYNC_TIME %s DEFAULT CURRENT_TIMESTAMP(), PROJECT_ID VARCHAR, NAME VARCHAR, RESOURCE_URI VARCHAR, @@ -722,25 +722,25 @@ func (s *GCPSecuritySync) upsertVulnerabilities(ctx context.Context, vulns []map SHORT_DESCRIPTION VARCHAR, CVE_ID VARCHAR, PACKAGE_ISSUE VARCHAR - )` + )`, 
warehouse.TimestampColumnType(s.sf)) - if _, err := s.sf.Query(ctx, createSQL); err != nil { + if _, err := s.sf.Exec(ctx, createSQL); err != nil { return fmt.Errorf("failed to create vulnerabilities table: %w", err) } // Delete existing and insert new - deleteSQL := "DELETE FROM GCP_CONTAINER_VULNERABILITIES WHERE PROJECT_ID = ?" + deleteSQL := "DELETE FROM GCP_CONTAINER_VULNERABILITIES WHERE PROJECT_ID = " + warehouse.Placeholder(s.sf, 1) if _, err := s.sf.Exec(ctx, deleteSQL, s.projectID); err != nil { return fmt.Errorf("delete existing vulnerabilities: %w", err) } // Insert records - insertSQL := ` + insertSQL := fmt.Sprintf(` INSERT INTO GCP_CONTAINER_VULNERABILITIES (_CQ_ID, PROJECT_ID, NAME, RESOURCE_URI, NOTE_NAME, KIND, CREATE_TIME, UPDATE_TIME, SEVERITY, CVSS_SCORE, CVSS_V3_SCORE, EFFECTIVE_SEVERITY, FIX_AVAILABLE, LONG_DESCRIPTION, SHORT_DESCRIPTION, CVE_ID, PACKAGE_ISSUE) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` + VALUES (%s)`, strings.Join(warehouse.Placeholders(s.sf, 1, 17), ", ")) insertErrs := make([]error, 0) for _, v := range vulns { if _, err := s.sf.Exec(ctx, insertSQL, @@ -774,15 +774,15 @@ func (s *GCPSecuritySync) upsertVulnerabilities(ctx context.Context, vulns []map // upsertDockerImages saves docker image data to Snowflake func (s *GCPSecuritySync) upsertDockerImages(ctx context.Context, images []map[string]interface{}) error { - createSQL := ` + createSQL := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS GCP_ARTIFACT_REGISTRY_IMAGES ( _CQ_ID VARCHAR PRIMARY KEY, - _CQ_SYNC_TIME TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP(), + _CQ_SYNC_TIME %s DEFAULT CURRENT_TIMESTAMP(), PROJECT_ID VARCHAR, NAME VARCHAR, URI VARCHAR, TAGS VARCHAR, - IMAGE_SIZE NUMBER, + IMAGE_SIZE %s, UPLOAD_TIME VARCHAR, MEDIA_TYPE VARCHAR, BUILD_TIME VARCHAR, @@ -798,9 +798,9 @@ func (s *GCPSecuritySync) upsertDockerImages(ctx context.Context, images []map[s HAS_CLOUD_KEYS BOOLEAN, HAS_HIGH_PRIVILEGE_CLOUD_KEYS BOOLEAN, HAS_CROSS_ACCOUNT_CLOUD_KEYS BOOLEAN - 
)` + )`, warehouse.TimestampColumnType(s.sf), warehouse.IntegerColumnType(s.sf)) - if _, err := s.sf.Query(ctx, createSQL); err != nil { + if _, err := s.sf.Exec(ctx, createSQL); err != nil { return fmt.Errorf("failed to create images table: %w", err) } @@ -821,15 +821,15 @@ func (s *GCPSecuritySync) upsertDockerImages(ctx context.Context, images []map[s } } - deleteSQL := "DELETE FROM GCP_ARTIFACT_REGISTRY_IMAGES WHERE PROJECT_ID = ?" + deleteSQL := "DELETE FROM GCP_ARTIFACT_REGISTRY_IMAGES WHERE PROJECT_ID = " + warehouse.Placeholder(s.sf, 1) if _, err := s.sf.Exec(ctx, deleteSQL, s.projectID); err != nil { return fmt.Errorf("delete existing images: %w", err) } - insertSQL := ` + insertSQL := fmt.Sprintf(` INSERT INTO GCP_ARTIFACT_REGISTRY_IMAGES (_CQ_ID, PROJECT_ID, NAME, URI, TAGS, IMAGE_SIZE, UPLOAD_TIME, MEDIA_TYPE, BUILD_TIME, UPDATE_TIME, REPOSITORY, REGISTRY_TYPE, SCANNED, SCAN_STATUS, VULNERABILITIES, HAS_VULNERABILITIES, HAS_OPENSSL_VULNERABILITY, SECRETS, HAS_CLOUD_KEYS, HAS_HIGH_PRIVILEGE_CLOUD_KEYS, HAS_CROSS_ACCOUNT_CLOUD_KEYS) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` + VALUES (%s)`, strings.Join(warehouse.Placeholders(s.sf, 1, 21), ", ")) insertErrs := make([]error, 0) for _, img := range images { if _, err := s.sf.Exec(ctx, insertSQL, @@ -867,10 +867,10 @@ func (s *GCPSecuritySync) upsertDockerImages(ctx context.Context, images []map[s // upsertSCCFindings saves SCC findings to Snowflake func (s *GCPSecuritySync) upsertSCCFindings(ctx context.Context, findings []map[string]interface{}) error { - createSQL := ` + createSQL := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS GCP_SCC_FINDINGS ( _CQ_ID VARCHAR PRIMARY KEY, - _CQ_SYNC_TIME TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP(), + _CQ_SYNC_TIME %s DEFAULT CURRENT_TIMESTAMP(), PROJECT_ID VARCHAR, NAME VARCHAR, PARENT VARCHAR, @@ -886,22 +886,22 @@ func (s *GCPSecuritySync) upsertSCCFindings(ctx context.Context, findings []map[ DESCRIPTION VARCHAR, INDICATOR VARCHAR, 
VULNERABILITY VARCHAR - )` + )`, warehouse.TimestampColumnType(s.sf)) - if _, err := s.sf.Query(ctx, createSQL); err != nil { + if _, err := s.sf.Exec(ctx, createSQL); err != nil { return fmt.Errorf("failed to create SCC findings table: %w", err) } - deleteSQL := "DELETE FROM GCP_SCC_FINDINGS WHERE PROJECT_ID = ?" + deleteSQL := "DELETE FROM GCP_SCC_FINDINGS WHERE PROJECT_ID = " + warehouse.Placeholder(s.sf, 1) if _, err := s.sf.Exec(ctx, deleteSQL, s.projectID); err != nil { return fmt.Errorf("delete existing scc findings: %w", err) } - insertSQL := ` + insertSQL := fmt.Sprintf(` INSERT INTO GCP_SCC_FINDINGS (_CQ_ID, PROJECT_ID, NAME, PARENT, RESOURCE_NAME, STATE, CATEGORY, EXTERNAL_URI, SEVERITY, FINDING_CLASS, MUTE, CREATE_TIME, EVENT_TIME, DESCRIPTION, INDICATOR, VULNERABILITY) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` + VALUES (%s)`, strings.Join(warehouse.Placeholders(s.sf, 1, 16), ", ")) insertErrs := make([]error, 0) for _, f := range findings { if _, err := s.sf.Exec(ctx, insertSQL, diff --git a/internal/sync/generation.go b/internal/sync/generation.go index a571d50dd..6a5305b8c 100644 --- a/internal/sync/generation.go +++ b/internal/sync/generation.go @@ -6,6 +6,8 @@ import ( "sort" "strings" "time" + + "github.com/writer/cerebro/internal/warehouse" ) const ( @@ -252,11 +254,11 @@ func (e *SyncEngine) ensureGenerationTrackingTable(ctx context.Context) error { region VARCHAR, status VARCHAR, backfill_pending BOOLEAN, - synced_rows NUMBER, + synced_rows %s, error_message VARCHAR, - sync_time TIMESTAMP_TZ, - _cq_sync_time TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP() - )`, syncTableGenerationsTable) + sync_time %s, + _cq_sync_time %s DEFAULT CURRENT_TIMESTAMP() + )`, syncTableGenerationsTable, warehouse.IntegerColumnType(e.sf), warehouse.TimestampColumnType(e.sf), warehouse.TimestampColumnType(e.sf)) if _, err := e.sf.Exec(ctx, createQuery); err != nil { return fmt.Errorf("create sync generation table: %w", err) } @@ -269,17 +271,17 @@ func (e 
*SyncEngine) ensureGenerationAlertsTable(ctx context.Context) error { provider VARCHAR, account_id VARCHAR, generation_id VARCHAR, - drift_seconds NUMBER, - threshold_seconds NUMBER, + drift_seconds %s, + threshold_seconds %s, min_table_name VARCHAR, min_region VARCHAR, - min_sync_time TIMESTAMP_TZ, + min_sync_time %s, max_table_name VARCHAR, max_region VARCHAR, - max_sync_time TIMESTAMP_TZ, + max_sync_time %s, message VARCHAR, - _cq_sync_time TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP() - )`, syncGenerationAlertsTable) + _cq_sync_time %s DEFAULT CURRENT_TIMESTAMP() + )`, syncGenerationAlertsTable, warehouse.IntegerColumnType(e.sf), warehouse.IntegerColumnType(e.sf), warehouse.TimestampColumnType(e.sf), warehouse.TimestampColumnType(e.sf), warehouse.TimestampColumnType(e.sf)) if _, err := e.sf.Exec(ctx, createQuery); err != nil { return fmt.Errorf("create sync generation alerts table: %w", err) } @@ -301,42 +303,69 @@ func (e *SyncEngine) recordGenerationResult(ctx context.Context, generationID st } recordID := generationRecordID("aws", e.accountID, generationID, result.Table, result.Region) - mergeQuery := fmt.Sprintf(`MERGE INTO %s t - USING ( - SELECT ? AS id, ? AS provider, ? AS account_id, ? AS generation_id, ? AS table_name, - ? AS region, ? AS status, ? AS backfill_pending, ? AS synced_rows, ? AS error_message, ? 
AS sync_time - ) s - ON t.id = s.id - WHEN MATCHED THEN UPDATE SET - provider = s.provider, - account_id = s.account_id, - generation_id = s.generation_id, - table_name = s.table_name, - region = s.region, - status = s.status, - backfill_pending = s.backfill_pending, - synced_rows = s.synced_rows, - error_message = s.error_message, - sync_time = s.sync_time, - _cq_sync_time = CURRENT_TIMESTAMP() - WHEN NOT MATCHED THEN INSERT - (id, provider, account_id, generation_id, table_name, region, status, backfill_pending, synced_rows, error_message, sync_time) - VALUES - (s.id, s.provider, s.account_id, s.generation_id, s.table_name, s.region, s.status, s.backfill_pending, s.synced_rows, s.error_message, s.sync_time)`, syncTableGenerationsTable) - - if _, err := e.sf.Exec(ctx, mergeQuery, - recordID, - "aws", - e.accountID, - generationID, - result.Table, - result.Region, - status, - result.BackfillPending, - result.Synced, - errorMessage, - syncTime, - ); err != nil { + var ( + query string + args []interface{} + ) + if warehouse.Dialect(e.sf) == warehouse.SQLDialectSnowflake { + query = fmt.Sprintf(`MERGE INTO %s t + USING ( + SELECT ? AS id, ? AS provider, ? AS account_id, ? AS generation_id, ? AS table_name, + ? AS region, ? AS status, ? AS backfill_pending, ? AS synced_rows, ? AS error_message, ? 
AS sync_time + ) s + ON t.id = s.id + WHEN MATCHED THEN UPDATE SET + provider = s.provider, + account_id = s.account_id, + generation_id = s.generation_id, + table_name = s.table_name, + region = s.region, + status = s.status, + backfill_pending = s.backfill_pending, + synced_rows = s.synced_rows, + error_message = s.error_message, + sync_time = s.sync_time, + _cq_sync_time = CURRENT_TIMESTAMP() + WHEN NOT MATCHED THEN INSERT + (id, provider, account_id, generation_id, table_name, region, status, backfill_pending, synced_rows, error_message, sync_time) + VALUES + (s.id, s.provider, s.account_id, s.generation_id, s.table_name, s.region, s.status, s.backfill_pending, s.synced_rows, s.error_message, s.sync_time)`, syncTableGenerationsTable) + args = []interface{}{ + recordID, + "aws", + e.accountID, + generationID, + result.Table, + result.Region, + status, + result.BackfillPending, + result.Synced, + errorMessage, + syncTime, + } + } else { + query = fmt.Sprintf( + "INSERT INTO %s (id, provider, account_id, generation_id, table_name, region, status, backfill_pending, synced_rows, error_message, sync_time) VALUES (%s) "+ + "ON CONFLICT (id) DO UPDATE SET provider = EXCLUDED.provider, account_id = EXCLUDED.account_id, generation_id = EXCLUDED.generation_id, table_name = EXCLUDED.table_name, region = EXCLUDED.region, status = EXCLUDED.status, backfill_pending = EXCLUDED.backfill_pending, synced_rows = EXCLUDED.synced_rows, error_message = EXCLUDED.error_message, sync_time = EXCLUDED.sync_time, _cq_sync_time = CURRENT_TIMESTAMP()", + syncTableGenerationsTable, + strings.Join(warehouse.Placeholders(e.sf, 1, 11), ", "), + ) + args = []interface{}{ + recordID, + "aws", + e.accountID, + generationID, + result.Table, + result.Region, + status, + result.BackfillPending, + result.Synced, + errorMessage, + syncTime, + } + } + + if _, err := e.sf.Exec(ctx, query, args...); err != nil { return fmt.Errorf("upsert sync generation result: %w", err) } @@ -362,9 +391,11 @@ func (e 
*SyncEngine) recordGenerationAlert(ctx context.Context, generationID str alert.Max.Region, ) - query := fmt.Sprintf(`INSERT INTO %s - (id, provider, account_id, generation_id, drift_seconds, threshold_seconds, min_table_name, min_region, min_sync_time, max_table_name, max_region, max_sync_time, message) - SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?`, syncGenerationAlertsTable) + query := fmt.Sprintf( + "INSERT INTO %s (id, provider, account_id, generation_id, drift_seconds, threshold_seconds, min_table_name, min_region, min_sync_time, max_table_name, max_region, max_sync_time, message) VALUES (%s)", + syncGenerationAlertsTable, + strings.Join(warehouse.Placeholders(e.sf, 1, 13), ", "), + ) if _, err := e.sf.Exec(ctx, query, recordID, diff --git a/internal/sync/permission_usage_state.go b/internal/sync/permission_usage_state.go index b286cd9f3..c59e87240 100644 --- a/internal/sync/permission_usage_state.go +++ b/internal/sync/permission_usage_state.go @@ -2,7 +2,6 @@ package sync import ( "context" - "database/sql" "fmt" "strings" "time" @@ -11,15 +10,7 @@ import ( ) const ( - permissionUsageStateTable = "cerebro_permission_usage_state" - permissionUsageStateSchema = ` - CREATE TABLE IF NOT EXISTS cerebro_permission_usage_state ( - state_key VARCHAR PRIMARY KEY, - last_cursor_time TIMESTAMP_NTZ, - last_cursor_id VARCHAR, - updated_at TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP() - ) - ` + permissionUsageStateTable = "cerebro_permission_usage_state" permissionUsageCursorOverlap = 5 * time.Minute ) @@ -51,65 +42,88 @@ func (e *GCPSyncEngine) savePermissionUsageCursor(ctx context.Context, key strin } func loadPermissionUsageCursor(ctx context.Context, sf warehouse.SyncWarehouse, key string) (permissionUsageCursor, error) { - if sf == nil || sf.DB() == nil || key == "" { + if sf == nil || key == "" { return permissionUsageCursor{}, nil } if err := ensurePermissionUsageStateTable(ctx, sf); err != nil { return permissionUsageCursor{}, err } - row := 
sf.DB().QueryRowContext(ctx, - "SELECT last_cursor_time, last_cursor_id FROM "+permissionUsageStateTable+" WHERE state_key = ?", + result, err := sf.Query(ctx, + "SELECT last_cursor_time, last_cursor_id FROM "+permissionUsageStateTable+" WHERE state_key = "+warehouse.Placeholder(sf, 1), key, ) - - var t sql.NullTime - var id sql.NullString - if err := row.Scan(&t, &id); err != nil { - if err == sql.ErrNoRows { - return permissionUsageCursor{}, nil - } + if err != nil { return permissionUsageCursor{}, fmt.Errorf("read permission usage cursor %q: %w", key, err) } + if len(result.Rows) == 0 { + return permissionUsageCursor{}, nil + } cursor := permissionUsageCursor{} - if t.Valid { - cursor.Time = t.Time.UTC() - } - if id.Valid { - cursor.ID = id.String + if ts, ok := parseAnyTime(queryRow(result.Rows[0], "last_cursor_time")); ok { + cursor.Time = ts.UTC() } + cursor.ID = strings.TrimSpace(queryRowString(result.Rows[0], "last_cursor_id")) return cursor, nil } func savePermissionUsageCursor(ctx context.Context, sf warehouse.SyncWarehouse, key string, cursor permissionUsageCursor) error { - if sf == nil || sf.DB() == nil || key == "" || cursor.Time.IsZero() { + if sf == nil || key == "" || cursor.Time.IsZero() { return nil } if err := ensurePermissionUsageStateTable(ctx, sf); err != nil { return err } - _, err := sf.DB().ExecContext(ctx, ` - MERGE INTO `+permissionUsageStateTable+` t - USING (SELECT ? AS state_key, ? AS last_cursor_time, ? 
AS last_cursor_id) s - ON t.state_key = s.state_key - WHEN MATCHED THEN UPDATE SET - last_cursor_time = CASE - WHEN t.last_cursor_time IS NULL THEN s.last_cursor_time - WHEN s.last_cursor_time > t.last_cursor_time THEN s.last_cursor_time - ELSE t.last_cursor_time - END, - last_cursor_id = CASE - WHEN t.last_cursor_time IS NULL THEN s.last_cursor_id - WHEN s.last_cursor_time > t.last_cursor_time THEN s.last_cursor_id - WHEN s.last_cursor_time = t.last_cursor_time AND COALESCE(s.last_cursor_id, '') > COALESCE(t.last_cursor_id, '') THEN s.last_cursor_id - ELSE t.last_cursor_id - END, - updated_at = CURRENT_TIMESTAMP() - WHEN NOT MATCHED THEN INSERT (state_key, last_cursor_time, last_cursor_id, updated_at) - VALUES (s.state_key, s.last_cursor_time, s.last_cursor_id, CURRENT_TIMESTAMP()) - `, key, cursor.Time.UTC(), cursor.ID) + var ( + query string + args []interface{} + ) + if warehouse.Dialect(sf) == warehouse.SQLDialectSnowflake { + query = ` + MERGE INTO ` + permissionUsageStateTable + ` t + USING (SELECT ? AS state_key, ? AS last_cursor_time, ? 
AS last_cursor_id) s + ON t.state_key = s.state_key + WHEN MATCHED THEN UPDATE SET + last_cursor_time = CASE + WHEN t.last_cursor_time IS NULL THEN s.last_cursor_time + WHEN s.last_cursor_time > t.last_cursor_time THEN s.last_cursor_time + ELSE t.last_cursor_time + END, + last_cursor_id = CASE + WHEN t.last_cursor_time IS NULL THEN s.last_cursor_id + WHEN s.last_cursor_time > t.last_cursor_time THEN s.last_cursor_id + WHEN s.last_cursor_time = t.last_cursor_time AND COALESCE(s.last_cursor_id, '') > COALESCE(t.last_cursor_id, '') THEN s.last_cursor_id + ELSE t.last_cursor_id + END, + updated_at = CURRENT_TIMESTAMP() + WHEN NOT MATCHED THEN INSERT (state_key, last_cursor_time, last_cursor_id, updated_at) + VALUES (s.state_key, s.last_cursor_time, s.last_cursor_id, CURRENT_TIMESTAMP()) + ` + args = []interface{}{key, cursor.Time.UTC(), cursor.ID} + } else { + query = fmt.Sprintf(` + INSERT INTO %[1]s (state_key, last_cursor_time, last_cursor_id, updated_at) + VALUES (%[2]s) + ON CONFLICT (state_key) DO UPDATE SET + last_cursor_time = CASE + WHEN %[1]s.last_cursor_time IS NULL THEN EXCLUDED.last_cursor_time + WHEN EXCLUDED.last_cursor_time > %[1]s.last_cursor_time THEN EXCLUDED.last_cursor_time + ELSE %[1]s.last_cursor_time + END, + last_cursor_id = CASE + WHEN %[1]s.last_cursor_time IS NULL THEN EXCLUDED.last_cursor_id + WHEN EXCLUDED.last_cursor_time > %[1]s.last_cursor_time THEN EXCLUDED.last_cursor_id + WHEN EXCLUDED.last_cursor_time = %[1]s.last_cursor_time AND COALESCE(EXCLUDED.last_cursor_id, '') > COALESCE(%[1]s.last_cursor_id, '') THEN EXCLUDED.last_cursor_id + ELSE %[1]s.last_cursor_id + END, + updated_at = CURRENT_TIMESTAMP() + `, permissionUsageStateTable, strings.Join(warehouse.Placeholders(sf, 1, 4), ", ")) + args = []interface{}{key, cursor.Time.UTC(), cursor.ID, time.Now().UTC()} + } + + _, err := sf.Exec(ctx, query, args...) 
if err != nil { return fmt.Errorf("upsert permission usage cursor %q: %w", key, err) } @@ -117,13 +131,21 @@ func savePermissionUsageCursor(ctx context.Context, sf warehouse.SyncWarehouse, } func ensurePermissionUsageStateTable(ctx context.Context, sf warehouse.SyncWarehouse) error { - if sf == nil || sf.DB() == nil { + if sf == nil { return nil } - if _, err := sf.DB().ExecContext(ctx, permissionUsageStateSchema); err != nil { + schemaSQL := fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + state_key VARCHAR PRIMARY KEY, + last_cursor_time %s, + last_cursor_id VARCHAR, + updated_at %s DEFAULT CURRENT_TIMESTAMP() + ) + `, permissionUsageStateTable, warehouse.LocalTimestampColumnType(sf), warehouse.LocalTimestampColumnType(sf)) + if _, err := sf.Exec(ctx, schemaSQL); err != nil { return fmt.Errorf("ensure permission usage state table: %w", err) } - if _, err := sf.DB().ExecContext(ctx, `ALTER TABLE `+permissionUsageStateTable+` ADD COLUMN IF NOT EXISTS last_cursor_id VARCHAR`); err != nil { + if _, err := sf.Exec(ctx, `ALTER TABLE `+permissionUsageStateTable+` ADD COLUMN IF NOT EXISTS last_cursor_id VARCHAR`); err != nil { return fmt.Errorf("ensure permission usage state cursor id column: %w", err) } return nil diff --git a/internal/sync/relationships.go b/internal/sync/relationships.go index 8afd8352c..87d2cc72c 100644 --- a/internal/sync/relationships.go +++ b/internal/sync/relationships.go @@ -109,7 +109,11 @@ var relationshipSchemaName = func(sf warehouse.SyncWarehouse) string { var relationshipQueryBatch = func(ctx context.Context, sf warehouse.SyncWarehouse, query string, args ...interface{}) error { if sf == nil { - return fmt.Errorf("snowflake client is nil") + return fmt.Errorf("warehouse is nil") + } + if warehouse.Dialect(sf) != warehouse.SQLDialectSnowflake { + _, err := sf.Exec(ctx, query, args...) + return err } _, err := sf.Query(ctx, query, args...) 
return err @@ -281,22 +285,27 @@ type relationshipBackfillUpdate struct { func (r *RelationshipExtractor) applyRelationshipBackfillBatch(ctx context.Context, tableName string, updates []relationshipBackfillUpdate, deleteIDs []string) error { if len(updates) > 0 { - values := make([]string, 0, len(updates)) - args := make([]interface{}, 0, len(updates)*8) - for _, update := range updates { - values = append(values, "(?, ?, ?, ?, ?, ?, ?, ?)") - args = append(args, - update.NewID, - update.SourceID, - update.SourceType, - update.TargetID, - update.TargetType, - update.RelType, - update.Properties, - update.SyncTime, - ) - } - merge := fmt.Sprintf(`MERGE INTO %s AS t + var ( + query string + args []interface{} + ) + if warehouse.Dialect(r.sf) == warehouse.SQLDialectSnowflake { + values := make([]string, 0, len(updates)) + args = make([]interface{}, 0, len(updates)*8) + for _, update := range updates { + values = append(values, "(?, ?, ?, ?, ?, ?, ?, ?)") + args = append(args, + update.NewID, + update.SourceID, + update.SourceType, + update.TargetID, + update.TargetType, + update.RelType, + update.Properties, + update.SyncTime, + ) + } + query = fmt.Sprintf(`MERGE INTO %s AS t USING (SELECT column1 AS id, column2 AS source_id, column3 AS source_type, @@ -317,7 +326,41 @@ WHEN MATCHED THEN UPDATE SET SYNC_TIME = COALESCE(s.sync_time, t.SYNC_TIME) WHEN NOT MATCHED THEN INSERT (ID, SOURCE_ID, SOURCE_TYPE, TARGET_ID, TARGET_TYPE, REL_TYPE, PROPERTIES, SYNC_TIME) VALUES (s.id, s.source_id, s.source_type, s.target_id, s.target_type, s.rel_type, TRY_PARSE_JSON(s.properties), s.sync_time)`, tableName, strings.Join(values, ",")) - if _, err := r.sf.Exec(ctx, merge, args...); err != nil { + } else { + values := make([]string, 0, len(updates)) + args = make([]interface{}, 0, len(updates)*8) + for _, update := range updates { + row := []string{ + warehouse.Placeholder(r.sf, len(args)+1), + warehouse.Placeholder(r.sf, len(args)+2), + warehouse.Placeholder(r.sf, len(args)+3), + 
warehouse.Placeholder(r.sf, len(args)+4), + warehouse.Placeholder(r.sf, len(args)+5), + warehouse.Placeholder(r.sf, len(args)+6), + warehouse.JSONPlaceholder(r.sf, len(args)+7), + warehouse.Placeholder(r.sf, len(args)+8), + } + values = append(values, "("+strings.Join(row, ", ")+")") + args = append(args, + update.NewID, + update.SourceID, + update.SourceType, + update.TargetID, + update.TargetType, + update.RelType, + update.Properties, + update.SyncTime, + ) + } + query = fmt.Sprintf( + "INSERT INTO %s (ID, SOURCE_ID, SOURCE_TYPE, TARGET_ID, TARGET_TYPE, REL_TYPE, PROPERTIES, SYNC_TIME) VALUES %s "+ + "ON CONFLICT (ID) DO UPDATE SET SOURCE_ID = EXCLUDED.SOURCE_ID, SOURCE_TYPE = EXCLUDED.SOURCE_TYPE, TARGET_ID = EXCLUDED.TARGET_ID, TARGET_TYPE = EXCLUDED.TARGET_TYPE, REL_TYPE = EXCLUDED.REL_TYPE, PROPERTIES = EXCLUDED.PROPERTIES, SYNC_TIME = COALESCE(EXCLUDED.SYNC_TIME, %s.SYNC_TIME)", + tableName, + strings.Join(values, ", "), + tableName, + ) + } + if _, err := r.sf.Exec(ctx, query, args...); err != nil { return err } } @@ -342,8 +385,8 @@ VALUES (s.id, s.source_id, s.source_type, s.target_id, s.target_type, s.rel_type } placeholders := make([]string, 0, len(unique)) args := make([]interface{}, 0, len(unique)) - for _, id := range unique { - placeholders = append(placeholders, "?") + for i, id := range unique { + placeholders = append(placeholders, warehouse.Placeholder(r.sf, i+1)) args = append(args, id) } deleteQuery := fmt.Sprintf("DELETE FROM %s WHERE ID IN (%s)", tableName, strings.Join(placeholders, ",")) @@ -384,9 +427,9 @@ func (r *RelationshipExtractor) ensureTable(ctx context.Context) error { TARGET_ID VARCHAR NOT NULL, TARGET_TYPE VARCHAR NOT NULL, REL_TYPE VARCHAR NOT NULL, - PROPERTIES VARIANT, - SYNC_TIME TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP() - )`, schema) + PROPERTIES %s, + SYNC_TIME %s DEFAULT CURRENT_TIMESTAMP() + )`, schema, warehouse.JSONColumnType(r.sf), warehouse.TimestampColumnType(r.sf)) _, err := r.sf.Exec(ctx, query) return err } @@ 
-399,7 +442,7 @@ func (r *RelationshipExtractor) cleanupStaleRelationships(ctx context.Context, c if err := snowflake.ValidateTableName(schema); err != nil { return fmt.Errorf("invalid schema name: %w", err) } - query := fmt.Sprintf(`DELETE FROM %s.RESOURCE_RELATIONSHIPS WHERE SYNC_TIME < ?`, schema) + query := fmt.Sprintf(`DELETE FROM %s.RESOURCE_RELATIONSHIPS WHERE SYNC_TIME < %s`, schema, warehouse.Placeholder(r.sf, 1)) _, err := r.sf.Exec(ctx, query, cutoff.UTC()) return err } @@ -449,18 +492,30 @@ func (r *RelationshipExtractor) persistRelationships(ctx context.Context, rels [ props = "{}" } id := buildRelationshipID(sourceID, rel.RelType, targetID, props) - values = append(values, "(?, ?, ?, ?, ?, ?, ?, ?)") + values = append(values, fmt.Sprintf("(%s, %s, %s, %s, %s, %s, %s, %s)", + warehouse.Placeholder(r.sf, len(args)+1), + warehouse.Placeholder(r.sf, len(args)+2), + warehouse.Placeholder(r.sf, len(args)+3), + warehouse.Placeholder(r.sf, len(args)+4), + warehouse.Placeholder(r.sf, len(args)+5), + warehouse.Placeholder(r.sf, len(args)+6), + warehouse.JSONPlaceholder(r.sf, len(args)+7), + warehouse.Placeholder(r.sf, len(args)+8), + )) args = append(args, id, sourceID, rel.SourceType, targetID, rel.TargetType, rel.RelType, props, syncTime) } if len(values) == 0 { continue } - // Use simple INSERT with fully qualified table name - query := fmt.Sprintf(`INSERT INTO %s (ID, SOURCE_ID, SOURCE_TYPE, TARGET_ID, TARGET_TYPE, REL_TYPE, PROPERTIES, SYNC_TIME) - SELECT column1, column2, column3, column4, column5, column6, TRY_PARSE_JSON(column7), column8::TIMESTAMP_TZ - FROM VALUES %s`, - tableName, strings.Join(values, ", ")) + query := fmt.Sprintf( + "INSERT INTO %s (ID, SOURCE_ID, SOURCE_TYPE, TARGET_ID, TARGET_TYPE, REL_TYPE, PROPERTIES, SYNC_TIME) VALUES %s", + tableName, + strings.Join(values, ", "), + ) + if warehouse.Dialect(r.sf) != warehouse.SQLDialectSnowflake { + query += " ON CONFLICT (ID) DO UPDATE SET SOURCE_ID = EXCLUDED.SOURCE_ID, SOURCE_TYPE = 
EXCLUDED.SOURCE_TYPE, TARGET_ID = EXCLUDED.TARGET_ID, TARGET_TYPE = EXCLUDED.TARGET_TYPE, REL_TYPE = EXCLUDED.REL_TYPE, PROPERTIES = EXCLUDED.PROPERTIES, SYNC_TIME = EXCLUDED.SYNC_TIME" + } // Use Query instead of Exec - Exec has issues with Snowflake commit behavior err := relationshipQueryBatch(ctx, r.sf, query, args...) @@ -487,11 +542,30 @@ func (r *RelationshipExtractor) getTableColumnSet(ctx context.Context, table str return nil, fmt.Errorf("invalid table name %s: %w", table, err) } + if memory, ok := any(r.sf).(*warehouse.MemoryWarehouse); ok && memory.DescribeColumnsFunc == nil { + // Preserve query-driven column discovery in tests that only stub QueryFunc. + } else if describer, ok := any(r.sf).(interface { + DescribeColumns(context.Context, string) ([]string, error) + }); ok { + names, err := describer.DescribeColumns(ctx, table) + if err != nil { + return nil, err + } + columns := make(map[string]struct{}, len(names)) + for _, name := range names { + name = strings.ToUpper(strings.TrimSpace(name)) + if name != "" { + columns[name] = struct{}{} + } + } + return columns, nil + } + query := ` SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = CURRENT_SCHEMA() - AND TABLE_NAME = ? 
+ AND TABLE_NAME = ` + warehouse.Placeholder(r.sf, 1) + ` ` result, err := r.sf.Query(ctx, query, strings.ToUpper(table)) diff --git a/internal/sync/relationships_cleanup_static_test.go b/internal/sync/relationships_cleanup_static_test.go index a20c2a298..c87fd2d5c 100644 --- a/internal/sync/relationships_cleanup_static_test.go +++ b/internal/sync/relationships_cleanup_static_test.go @@ -22,7 +22,6 @@ func TestRelationshipCleanupUsesRunSyncTimeGuardrails(t *testing.T) { checks := []string{ "cleanupStaleRelationships(ctx, runSyncTime)", - "column8::TIMESTAMP_TZ", "r.sf.Exec(ctx, query, cutoff.UTC())", } for _, check := range checks { @@ -30,4 +29,7 @@ func TestRelationshipCleanupUsesRunSyncTimeGuardrails(t *testing.T) { t.Fatalf("expected relationships.go to contain %q", check) } } + if !strings.Contains(text, "column8::TIMESTAMP_TZ") && !strings.Contains(text, "ON CONFLICT (ID) DO UPDATE") { + t.Fatal("expected relationships.go to contain a snowflake or postgres/sqlite batch upsert path") + } } diff --git a/internal/sync/scoped_table_ops.go b/internal/sync/scoped_table_ops.go index a800c5509..7661fc8a0 100644 --- a/internal/sync/scoped_table_ops.go +++ b/internal/sync/scoped_table_ops.go @@ -95,7 +95,7 @@ func getExistingHashesByScope(ctx context.Context, sf warehouse.SyncWarehouse, t return result, err } - whereClause, args := scopedWhereClause(scopeColumn, scopeValues) + whereClause, args := scopedWhereClauseForWarehouse(sf, scopeColumn, scopeValues) query := fmt.Sprintf("SELECT _CQ_ID, _CQ_HASH FROM %s%s", table, whereClause) rows, err := sf.Query(ctx, query, args...) 
if err != nil { @@ -122,7 +122,7 @@ func deleteRowsByIDByScope(ctx context.Context, sf warehouse.SyncWarehouse, tabl return nil } - scopeWhere, scopeArgs := scopedWhereClause(scopeColumn, scopeValues) + scopeWhere, scopeArgs := scopedWhereClauseForWarehouse(sf, scopeColumn, scopeValues) scopeCondition := strings.TrimPrefix(scopeWhere, " WHERE ") for start := 0; start < len(keys); start += insertBatchSize { @@ -132,7 +132,7 @@ func deleteRowsByIDByScope(ctx context.Context, sf warehouse.SyncWarehouse, tabl } batch := keys[start:end] - placeholders := strings.TrimRight(strings.Repeat("?,", len(batch)), ",") + placeholders := strings.Join(warehouse.Placeholders(sf, 1, len(batch)), ",") args := make([]interface{}, 0, len(batch)+len(scopeArgs)) for _, id := range batch { args = append(args, id) @@ -153,7 +153,7 @@ func deleteRowsByIDByScope(ctx context.Context, sf warehouse.SyncWarehouse, tabl } func deleteScopedRowsByScope(ctx context.Context, sf warehouse.SyncWarehouse, table, scopeColumn string, scopeValues []string) error { - whereClause, args := scopedWhereClause(scopeColumn, scopeValues) + whereClause, args := scopedWhereClauseForWarehouse(sf, scopeColumn, scopeValues) if whereClause == "" { if _, err := sf.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil { if _, err := sf.Exec(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil { @@ -169,11 +169,19 @@ func deleteScopedRowsByScope(ctx context.Context, sf warehouse.SyncWarehouse, ta } func scopedWhereClause(column string, values []string) (string, []interface{}) { + return scopedWhereClauseWithTarget(nil, column, values) +} + +func scopedWhereClauseForWarehouse(sf warehouse.SyncWarehouse, column string, values []string) (string, []interface{}) { + return scopedWhereClauseWithTarget(sf, column, values) +} + +func scopedWhereClauseWithTarget(target any, column string, values []string) (string, []interface{}) { if column == "" || len(values) == 0 { return "", nil } - placeholders := 
strings.TrimRight(strings.Repeat("?,", len(values)), ",") + placeholders := strings.Join(warehouse.Placeholders(target, 1, len(values)), ",") args := make([]interface{}, len(values)) for i, value := range values { args[i] = value @@ -191,8 +199,8 @@ func persistProviderChangeHistory(ctx context.Context, sf warehouse.SyncWarehous region VARCHAR, account_id VARCHAR, provider VARCHAR, - timestamp TIMESTAMP_TZ, - _cq_sync_time TIMESTAMP_TZ DEFAULT CURRENT_TIMESTAMP() + timestamp ` + warehouse.TimestampColumnType(sf) + `, + _cq_sync_time ` + warehouse.TimestampColumnType(sf) + ` DEFAULT CURRENT_TIMESTAMP() )` if _, err := sf.Exec(ctx, createQuery); err != nil { @@ -204,7 +212,7 @@ func persistProviderChangeHistory(ctx context.Context, sf warehouse.SyncWarehous "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS region VARCHAR", "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS account_id VARCHAR", "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS provider VARCHAR", - "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS timestamp TIMESTAMP_TZ", + "ALTER TABLE _sync_change_history ADD COLUMN IF NOT EXISTS timestamp " + warehouse.TimestampColumnType(sf), } for _, query := range alterQueries { if _, err := sf.Exec(ctx, query); err != nil { @@ -233,8 +241,10 @@ func persistProviderChangeHistory(ctx context.Context, sf warehouse.SyncWarehous func insertProviderChangeRecord(ctx context.Context, sf warehouse.SyncWarehouse, logger *slog.Logger, provider, table, operation, region string, resourceIDs []string, syncTime time.Time) { for _, resourceID := range resourceIDs { id := fmt.Sprintf("%s-%s-%s-%d", table, operation, resourceID, syncTime.UnixNano()) - query := `INSERT INTO _sync_change_history (id, table_name, resource_id, operation, region, account_id, provider, timestamp) - SELECT ?, ?, ?, ?, ?, ?, ?, ?` + query := fmt.Sprintf( + "INSERT INTO _sync_change_history (id, table_name, resource_id, operation, region, account_id, provider, timestamp) 
VALUES (%s)", + strings.Join(warehouse.Placeholders(sf, 1, 8), ", "), + ) if _, err := sf.Exec(ctx, query, id, table, resourceID, operation, region, "", provider, syncTime); err != nil { logger.Debug("failed to insert change record", "provider", provider, "table", table, "error", err) } diff --git a/internal/sync/sync_time.go b/internal/sync/sync_time.go index 6e93b3b39..d0dff40bf 100644 --- a/internal/sync/sync_time.go +++ b/internal/sync/sync_time.go @@ -13,7 +13,7 @@ var queryLatestTableSyncTime = func(ctx context.Context, sf warehouse.SyncWareho query := fmt.Sprintf("SELECT MAX(_CQ_SYNC_TIME) AS SYNC_TIME FROM %s", table) args := []interface{}{} if hasRegion { - query += " WHERE REGION = ?" + query += " WHERE REGION = " + warehouse.Placeholder(sf, 1) args = append(args, region) } diff --git a/internal/sync/tables_gcp_iam_group_usage.go b/internal/sync/tables_gcp_iam_group_usage.go index 15ff14ab3..207fc249a 100644 --- a/internal/sync/tables_gcp_iam_group_usage.go +++ b/internal/sync/tables_gcp_iam_group_usage.go @@ -13,6 +13,7 @@ import ( "cloud.google.com/go/iam/apiv1/iampb" "cloud.google.com/go/logging" "cloud.google.com/go/logging/logadmin" + "github.com/writer/cerebro/internal/warehouse" "google.golang.org/api/iterator" auditpb "google.golang.org/genproto/googleapis/cloud/audit" ) @@ -619,7 +620,7 @@ func (e *GCPSyncEngine) loadWorkspaceGroupMemberships(ctx context.Context, group return result, false } - placeholders := strings.TrimSuffix(strings.Repeat("?,", len(groups)), ",") + placeholders := strings.Join(warehouse.Placeholders(e.sf, 1, len(groups)), ",") args := make([]interface{}, 0, len(groups)) for _, group := range groups { args = append(args, strings.ToLower(strings.TrimSpace(group))) @@ -667,7 +668,7 @@ func (e *GCPSyncEngine) loadExistingGCPGroupPermissionState(ctx context.Context, rows, err := e.sf.Query(ctx, ` SELECT permission, permission_last_used, unused_since, usage_status FROM `+gcpIAMGroupPermissionUsageTable+` - WHERE project_id = ? 
AND LOWER("group") = ? + WHERE project_id = `+warehouse.Placeholder(e.sf, 1)+` AND LOWER("group") = `+warehouse.Placeholder(e.sf, 2)+` `, projectID, strings.ToLower(group)) if err != nil { if strings.Contains(strings.ToLower(err.Error()), "does not exist") { diff --git a/internal/sync/tables_identitycenter_usage.go b/internal/sync/tables_identitycenter_usage.go index f6d0e5407..0ed722f20 100644 --- a/internal/sync/tables_identitycenter_usage.go +++ b/internal/sync/tables_identitycenter_usage.go @@ -15,6 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssoadmin" ssoadmintypes "github.com/aws/aws-sdk-go-v2/service/ssoadmin/types" "github.com/writer/cerebro/internal/graph" + "github.com/writer/cerebro/internal/warehouse" ) const ( @@ -405,15 +406,19 @@ func (e *SyncEngine) deleteIdentityCenterUsageRowsNotInInstance(ctx context.Cont if len(filtered) == 0 { query = fmt.Sprintf( - "DELETE FROM %s WHERE identity_center_instance_arn = ? AND account_id = ?", + "DELETE FROM %s WHERE identity_center_instance_arn = %s AND account_id = %s", awsIdentityCenterPermissionUsageTable, + warehouse.Placeholder(e.sf, 1), + warehouse.Placeholder(e.sf, 2), ) args = []interface{}{instanceArn, accountID} } else { - placeholders := strings.TrimRight(strings.Repeat("?,", len(filtered)), ",") + placeholders := strings.Join(warehouse.Placeholders(e.sf, 3, len(filtered)), ",") query = fmt.Sprintf( - "DELETE FROM %s WHERE identity_center_instance_arn = ? AND account_id = ? AND permission_set_arn NOT IN (%s)", + "DELETE FROM %s WHERE identity_center_instance_arn = %s AND account_id = %s AND permission_set_arn NOT IN (%s)", awsIdentityCenterPermissionUsageTable, + warehouse.Placeholder(e.sf, 1), + warehouse.Placeholder(e.sf, 2), placeholders, ) args = make([]interface{}, 0, len(filtered)+2) @@ -444,20 +449,22 @@ func (e *SyncEngine) deleteStaleIdentityCenterUsageRows(ctx context.Context, per if len(currentActions) == 0 { query = fmt.Sprintf( - "DELETE FROM %s WHERE permission_set_arn = ? 
AND sso_role_arn = ?", + "DELETE FROM %s WHERE permission_set_arn = %s AND sso_role_arn = %s", awsIdentityCenterPermissionUsageTable, + warehouse.Placeholder(e.sf, 1), + warehouse.Placeholder(e.sf, 2), ) args = []interface{}{permissionSetArn, roleArn} } else { - placeholders := strings.TrimRight(strings.Repeat("?,", len(currentActions)), ",") + placeholders := strings.Join(warehouse.Placeholders(e.sf, 3, len(currentActions)), ",") args = make([]interface{}, 0, len(currentActions)+2) args = append(args, permissionSetArn, roleArn) for _, action := range currentActions { args = append(args, strings.ToLower(action)) } query = fmt.Sprintf( - "DELETE FROM %s WHERE permission_set_arn = ? AND sso_role_arn = ? AND LOWER(action) NOT IN (%s)", - awsIdentityCenterPermissionUsageTable, placeholders, + "DELETE FROM %s WHERE permission_set_arn = %s AND sso_role_arn = %s AND LOWER(action) NOT IN (%s)", + awsIdentityCenterPermissionUsageTable, warehouse.Placeholder(e.sf, 1), warehouse.Placeholder(e.sf, 2), placeholders, ) } @@ -474,8 +481,8 @@ func (e *SyncEngine) deleteIdentityCenterUsageRowsByPermissionSet(ctx context.Co } query := fmt.Sprintf( - "DELETE FROM %s WHERE permission_set_arn = ? AND account_id = ?", - awsIdentityCenterPermissionUsageTable, + "DELETE FROM %s WHERE permission_set_arn = %s AND account_id = %s", + awsIdentityCenterPermissionUsageTable, warehouse.Placeholder(e.sf, 1), warehouse.Placeholder(e.sf, 2), ) if _, err := e.sf.Exec(ctx, query, permissionSetArn, accountID); err != nil { if !strings.Contains(strings.ToLower(err.Error()), "does not exist") { @@ -582,7 +589,7 @@ func (e *SyncEngine) loadExistingAWSPermissionActionState(ctx context.Context, p query := ` SELECT action, action_last_accessed, unused_since, usage_status FROM ` + awsIdentityCenterPermissionUsageTable + ` - WHERE permission_set_arn = ? AND sso_role_arn = ? 
+ WHERE permission_set_arn = ` + warehouse.Placeholder(e.sf, 1) + ` AND sso_role_arn = ` + warehouse.Placeholder(e.sf, 2) + ` ` rows, err := e.sf.Query(ctx, query, permissionSetArn, roleArn) if err != nil { diff --git a/internal/warehouse/sql_dialect.go b/internal/warehouse/sql_dialect.go new file mode 100644 index 000000000..4408c082d --- /dev/null +++ b/internal/warehouse/sql_dialect.go @@ -0,0 +1,125 @@ +package warehouse + +import ( + "fmt" + "strings" + + "github.com/writer/cerebro/internal/snowflake" +) + +type SQLDialect string + +const ( + SQLDialectSnowflake SQLDialect = "snowflake" + SQLDialectPostgres SQLDialect = "postgres" + SQLDialectSQLite SQLDialect = "sqlite" +) + +func (w *PostgresWarehouse) SQLDialect() SQLDialect { + return SQLDialectPostgres +} + +func (w *SQLiteWarehouse) SQLDialect() SQLDialect { + return SQLDialectSQLite +} + +func (m *MemoryWarehouse) SQLDialect() SQLDialect { + if m == nil || strings.TrimSpace(string(m.DialectValue)) == "" { + return SQLDialectSnowflake + } + return m.DialectValue +} + +func Dialect(target any) SQLDialect { + switch typed := target.(type) { + case nil: + return SQLDialectSnowflake + case interface{ SQLDialect() SQLDialect }: + if dialect := typed.SQLDialect(); dialect != "" { + return dialect + } + case *snowflake.Client: + return SQLDialectSnowflake + } + return SQLDialectSnowflake +} + +func Placeholder(target any, position int) string { + if position < 1 { + position = 1 + } + if Dialect(target) == SQLDialectPostgres { + return fmt.Sprintf("$%d", position) + } + return "?" 
+} + +func Placeholders(target any, start, count int) []string { + if count <= 0 { + return nil + } + if start < 1 { + start = 1 + } + values := make([]string, 0, count) + for i := 0; i < count; i++ { + values = append(values, Placeholder(target, start+i)) + } + return values +} + +func JSONColumnType(target any) string { + switch Dialect(target) { + case SQLDialectPostgres: + return "JSONB" + case SQLDialectSQLite: + return "TEXT" + default: + return "VARIANT" + } +} + +func TimestampColumnType(target any) string { + switch Dialect(target) { + case SQLDialectPostgres: + return "TIMESTAMPTZ" + case SQLDialectSQLite: + return "TEXT" + default: + return "TIMESTAMP_TZ" + } +} + +func LocalTimestampColumnType(target any) string { + switch Dialect(target) { + case SQLDialectPostgres: + return "TIMESTAMPTZ" + case SQLDialectSQLite: + return "TEXT" + default: + return "TIMESTAMP_NTZ" + } +} + +func IntegerColumnType(target any) string { + switch Dialect(target) { + case SQLDialectPostgres: + return "BIGINT" + case SQLDialectSQLite: + return "INTEGER" + default: + return "NUMBER" + } +} + +func JSONPlaceholder(target any, position int) string { + placeholder := Placeholder(target, position) + switch Dialect(target) { + case SQLDialectPostgres: + return placeholder + "::jsonb" + case SQLDialectSnowflake: + return "PARSE_JSON(" + placeholder + ")" + default: + return placeholder + } +} diff --git a/internal/warehouse/warehouse.go b/internal/warehouse/warehouse.go index f5ba0d3b4..b649e79e0 100644 --- a/internal/warehouse/warehouse.go +++ b/internal/warehouse/warehouse.go @@ -78,6 +78,7 @@ type MemoryWarehouse struct { GetAssetsFunc func(ctx context.Context, table string, filter snowflake.AssetFilter) ([]map[string]interface{}, error) GetAssetByIDFunc func(ctx context.Context, table, id string) (map[string]interface{}, error) DBFunc func() *sql.DB + DialectValue SQLDialect DatabaseValue string SchemaValue string AppSchemaValue string diff --git 
a/scripts/generate_agent_sdk_packages/main.go b/scripts/generate_agent_sdk_packages/main.go index daddcb1a2..c5ce4a566 100644 --- a/scripts/generate_agent_sdk_packages/main.go +++ b/scripts/generate_agent_sdk_packages/main.go @@ -3,6 +3,7 @@ package main import ( "bytes" "fmt" + "go/format" "os" "path/filepath" "sort" @@ -244,6 +245,13 @@ func mustWriteTemplate(path, tmpl string, data templateData) { } func mustWrite(path, content string) { + if filepath.Ext(path) == ".go" { + formatted, err := format.Source([]byte(content)) + if err != nil { + fatalf("format %s: %v", path, err) + } + content = string(formatted) + } // #nosec G301 -- generated SDK/docs directories are checked into the repo and should remain readable. if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { fatalf("create dir for %s: %v", path, err)