diff --git a/cypher/models/pgsql/type.go b/cypher/models/pgsql/type.go index da00462..4fc96e4 100644 --- a/cypher/models/pgsql/type.go +++ b/cypher/models/pgsql/type.go @@ -3,6 +3,7 @@ package pgsql import ( "bytes" "encoding/json" + "strings" "reflect" @@ -55,6 +56,10 @@ func PropertiesToJSONB(properties *graph.Properties) (pgtype.JSONB, error) { return MapStringAnyToJSONB(properties.MapOrEmpty()) } +func DeletedPropertiesToString(properties *graph.Properties) string { + return "{" + strings.Join(properties.DeletedProperties(), ",") + "}" +} + func JSONBToProperties(jsonb pgtype.JSONB) (*graph.Properties, error) { propertiesMap := make(map[string]any) diff --git a/cypher/models/pgsql/type_test.go b/cypher/models/pgsql/type_test.go new file mode 100644 index 0000000..0b82bde --- /dev/null +++ b/cypher/models/pgsql/type_test.go @@ -0,0 +1,80 @@ +package pgsql_test + +import ( + "testing" + + "github.com/specterops/dawgs/cypher/models/pgsql" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" +) + +func TestDeletedPropertiesToString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func() *graph.Properties + check func(t *testing.T, result string) + }{ + { + name: "no deleted properties returns empty braces", + setup: func() *graph.Properties { + return graph.NewProperties() + }, + check: func(t *testing.T, result string) { + assert.Equal(t, "{}", result) + }, + }, + { + name: "single deleted property is wrapped in braces", + setup: func() *graph.Properties { + props := graph.NewProperties() + props.Set("mykey", "myvalue") + props.Delete("mykey") + return props + }, + check: func(t *testing.T, result string) { + assert.Equal(t, "{mykey}", result) + }, + }, + { + name: "multiple deleted properties are all present in output", + setup: func() *graph.Properties { + props := graph.NewProperties() + props.Set("alpha", 1) + props.Set("beta", 2) + props.Delete("alpha") + props.Delete("beta") + return props + }, + check: 
func(t *testing.T, result string) { + // Map iteration order is non-deterministic; accept either ordering. + assert.True(t, result == "{alpha,beta}" || result == "{beta,alpha}", + "unexpected result: %s", result) + }, + }, + { + name: "non-deleted properties are not included", + setup: func() *graph.Properties { + props := graph.NewProperties() + props.Set("active", "yes") + props.Set("removed", "no") + props.Delete("removed") + return props + }, + check: func(t *testing.T, result string) { + assert.Equal(t, "{removed}", result) + assert.NotContains(t, result, "active") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + props := tc.setup() + result := pgsql.DeletedPropertiesToString(props) + tc.check(t, result) + }) + } +} diff --git a/drivers/neo4j/batch.go b/drivers/neo4j/batch.go index 3e8586c..420cb91 100644 --- a/drivers/neo4j/batch.go +++ b/drivers/neo4j/batch.go @@ -20,6 +20,7 @@ type batchTransaction struct { innerTx *neo4jTransaction nodeDeletionBuffer []graph.ID relationshipDeletionBuffer []graph.ID + nodeUpdateBuffer []*graph.Node nodeUpdateByBuffer []graph.NodeUpdate relationshipCreateBuffer []createRelationshipByIDs relationshipUpdateByBuffer []graph.RelationshipUpdate @@ -48,8 +49,20 @@ func (s *batchTransaction) Relationships() graph.RelationshipQuery { } func (s *batchTransaction) UpdateNodeBy(update graph.NodeUpdate) error { - if s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update); len(s.nodeUpdateByBuffer) >= s.batchWriteSize { - return s.flushNodeUpdates() + s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update) + + if len(s.nodeUpdateByBuffer) >= s.batchWriteSize { + return s.flushNodeUpdateByBuffer() + } + + return nil +} + +func (s *batchTransaction) UpdateNodes(nodes []*graph.Node) error { + s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, nodes...) 
+ + if len(s.nodeUpdateBuffer) > s.batchWriteSize { + return s.flushNodeUpdateBuffer() } return nil @@ -73,7 +86,13 @@ func (s *batchTransaction) DeleteRelationships(ids []graph.ID) error { func (s *batchTransaction) Commit() error { if len(s.nodeUpdateByBuffer) > 0 { - if err := s.flushNodeUpdates(); err != nil { + if err := s.flushNodeUpdateByBuffer(); err != nil { + return err + } + } + + if len(s.nodeUpdateBuffer) > 0 { + if err := s.flushNodeUpdateBuffer(); err != nil { return err } } @@ -221,13 +240,20 @@ func (s *batchTransaction) flushRelationshipDeletions() error { return s.DeleteRelationships(buffer) } -func (s *batchTransaction) flushNodeUpdates() error { +func (s *batchTransaction) flushNodeUpdateByBuffer() error { buffer := s.nodeUpdateByBuffer s.nodeUpdateByBuffer = s.nodeUpdateByBuffer[:0] return s.innerTx.updateNodesBy(buffer...) } +func (s *batchTransaction) flushNodeUpdateBuffer() error { + buffer := s.nodeUpdateBuffer + s.nodeUpdateBuffer = s.nodeUpdateBuffer[:0] + + return s.innerTx.updateNodeBatch(buffer) +} + func (s *batchTransaction) flushRelationshipUpdates() error { buffer := s.relationshipUpdateByBuffer s.relationshipUpdateByBuffer = s.relationshipUpdateByBuffer[:0] diff --git a/drivers/neo4j/cypher.go b/drivers/neo4j/cypher.go index 378fc24..158ad24 100644 --- a/drivers/neo4j/cypher.go +++ b/drivers/neo4j/cypher.go @@ -4,9 +4,12 @@ import ( "bytes" "fmt" "log/slog" + "maps" + "slices" "sort" "strings" + "github.com/cespare/xxhash/v2" "github.com/specterops/dawgs/cypher/frontend" "github.com/specterops/dawgs/cypher/models/cypher/format" "github.com/specterops/dawgs/graph" @@ -228,7 +231,117 @@ func (s nodeUpdateByMap) add(update graph.NodeUpdate) { } } -func cypherBuildNodeUpdateQueryBatch(updates []graph.NodeUpdate) ([]string, []map[string]any) { +func nodeToNodeUpdateKey(digester *xxhash.Digest, node *graph.Node) uint64 { + digester.Reset() + + var ( + kindSet = map[string]struct{}{} + digestKinds = func() { + sortedKinds := 
make([]string, 0, len(kindSet)) + + for nextKind := range maps.Keys(kindSet) { + sortedKinds = append(sortedKinds, nextKind) + } + + slices.Sort(sortedKinds) + + for _, nextKind := range sortedKinds { + digester.WriteString(nextKind) + } + + clear(kindSet) + } + ) + + for _, addedKindStr := range node.AddedKinds.Strings() { + kindSet[addedKindStr] = struct{}{} + } + + digestKinds() + + for _, removedKindStr := range node.DeletedKinds.Strings() { + kindSet[removedKindStr] = struct{}{} + } + + digestKinds() + + return digester.Sum64() +} + +type nodeUpdateBatch struct { + nodeKindsToAdd graph.Kinds + nodeKindsToRemove graph.Kinds + Parameters []map[string]any +} + +func cypherBuildNodeUpdateQueryBatch(updates []*graph.Node) ([]string, []map[string]any) { + var ( + queries []string + queryParameters []map[string]any + + output = strings.Builder{} + batchedUpdates = map[uint64]*nodeUpdateBatch{} + digester = xxhash.New() + ) + + for _, nodeToUpdate := range updates { + updateKey := nodeToNodeUpdateKey(digester, nodeToUpdate) + + if existingBatch, hasBatch := batchedUpdates[updateKey]; hasBatch { + existingBatch.Parameters = append(existingBatch.Parameters, map[string]any{ + "id": nodeToUpdate.ID, + "properties": nodeToUpdate.Properties, + }) + } else { + batchedUpdates[updateKey] = &nodeUpdateBatch{ + nodeKindsToAdd: nodeToUpdate.AddedKinds, + nodeKindsToRemove: nodeToUpdate.DeletedKinds, + Parameters: []map[string]any{{ + "id": nodeToUpdate.ID, + "properties": nodeToUpdate.Properties, + }}, + } + } + } + + for _, batch := range batchedUpdates { + output.WriteString("unwind $p as p match (n) where id(n) = p.id set n += p.properties") + + if len(batch.nodeKindsToAdd) > 0 { + for _, kindToAdd := range batch.nodeKindsToAdd { + output.WriteString(", n:") + output.WriteString(kindToAdd.String()) + } + } + + if len(batch.nodeKindsToRemove) > 0 { + output.WriteString(" remove ") + + for idx, kindToRemove := range batch.nodeKindsToRemove { + if idx > 0 { + 
output.WriteString(",") + } + + output.WriteString("n:") + output.WriteString(kindToRemove.String()) + } + } + + output.WriteString(";") + + // Write out the query to be run + queries = append(queries, output.String()) + queryParameters = append(queryParameters, map[string]any{ + "p": batch.Parameters, + }) + + output.Reset() + } + + return queries, queryParameters +} + +func cypherBuildNodeUpdateQueryByBatch(updates []graph.NodeUpdate) ([]string, []map[string]any) { var ( queries []string queryParameters []map[string]any diff --git a/drivers/neo4j/driver.go b/drivers/neo4j/driver.go index 7498d9d..731b647 100644 --- a/drivers/neo4j/driver.go +++ b/drivers/neo4j/driver.go @@ -43,7 +43,7 @@ func (s *driver) SetWriteFlushSize(size int) { s.writeFlushSize = size } -func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error { +func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate, options ...graph.BatchOption) error { // Attempt to acquire a connection slot or wait for a bit until one becomes available if !s.limiter.Acquire(ctx) { return graph.ErrContextTimedOut @@ -51,13 +51,21 @@ func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDe defer s.limiter.Release() } + config := &graph.BatchConfig{ + BatchSize: s.batchWriteSize, + } + + for _, opt := range options { + opt(config) + } + var ( cfg = graph.TransactionConfig{ Timeout: s.defaultTransactionTimeout, } session = s.driver.NewSession(writeCfg()) - batch = newBatchOperation(ctx, session, cfg, s.writeFlushSize, s.batchWriteSize, s.graphQueryMemoryLimit) + batch = newBatchOperation(ctx, session, cfg, s.writeFlushSize, config.BatchSize, s.graphQueryMemoryLimit) ) defer session.Close() diff --git a/drivers/neo4j/transaction.go b/drivers/neo4j/transaction.go index dd2f1a9..c5307c3 100644 --- a/drivers/neo4j/transaction.go +++ b/drivers/neo4j/transaction.go @@ -93,10 +93,25 @@ func (s *neo4jTransaction) UpdateRelationshipBy(update 
graph.RelationshipUpdate) return s.updateRelationshipsBy(update) } +func (s *neo4jTransaction) updateNodeBatch(batch []*graph.Node) error { + var ( + numUpdates = len(batch) + statements, queryParameterMaps = cypherBuildNodeUpdateQueryBatch(batch) + ) + + for parameterIdx, stmt := range statements { + if result := s.Raw(stmt, queryParameterMaps[parameterIdx]); result.Error() != nil { + return fmt.Errorf("update nodes by error on statement (%s): %s", stmt, result.Error()) + } + } + + return s.logWrites(numUpdates) +} + func (s *neo4jTransaction) updateNodesBy(updates ...graph.NodeUpdate) error { var ( numUpdates = len(updates) - statements, queryParameterMaps = cypherBuildNodeUpdateQueryBatch(updates) + statements, queryParameterMaps = cypherBuildNodeUpdateQueryByBatch(updates) ) for parameterIdx, stmt := range statements { diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index a46a520..c37b110 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgtype" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/specterops/dawgs/cypher/models/pgsql" "github.com/specterops/dawgs/drivers/pg/model" @@ -16,6 +17,10 @@ import ( "github.com/specterops/dawgs/graph" ) +const ( + LargeNodeUpdateThreshold = 1_000_000 +) + type Int2ArrayEncoder struct { buffer *bytes.Buffer } @@ -43,6 +48,7 @@ type batch struct { nodeDeletionBuffer []graph.ID relationshipDeletionBuffer []graph.ID nodeCreateBuffer []*graph.Node + nodeUpdateBuffer []*graph.Node nodeUpdateByBuffer []graph.NodeUpdate relationshipCreateBuffer []*graph.Relationship relationshipUpdateByBuffer []graph.RelationshipUpdate @@ -89,6 +95,121 @@ func (s *batch) UpdateNodeBy(update graph.NodeUpdate) error { return s.tryFlush(s.batchWriteSize) } +// largeUpdate performs a bulk node update using PostgreSQL's COPY FROM to stream +// nodes into a temporary staging table and then MERGE INTO the live node partition. 
+// This path is more efficient than a parameterised UPDATE for very large batches +// (see LargeNodeUpdateThreshold). +func (s *batch) largeUpdate(nodes []*graph.Node) error { + tx, err := s.innerTransaction.conn.Begin(s.ctx) + if err != nil { + return err + } + defer tx.Rollback(s.ctx) + if _, err := tx.Exec(s.ctx, sql.FormatCreateNodeUpdateStagingTable(sql.NodeUpdateStagingTable)); err != nil { + return fmt.Errorf("creating node update staging table: %w", err) + } + + nodeRows := NewLargeNodeUpdateRows(len(nodes)) + if err := nodeRows.AppendAll(s.ctx, nodes, s.schemaManager, s.kindIDEncoder); err != nil { + return err + } + + // Stream the rows into the staging table via COPY FROM. + if _, err := tx.Conn().CopyFrom( + s.ctx, + pgx.Identifier{sql.NodeUpdateStagingTable}, + sql.NodeUpdateStagingColumns, + pgx.CopyFromRows(nodeRows.Rows()), + ); err != nil { + return fmt.Errorf("copying nodes into staging table: %w", err) + } + + graphTarget, err := s.innerTransaction.getTargetGraph() + if err != nil { + return err + } + + if _, err := tx.Exec(s.ctx, sql.FormatMergeNodeLargeUpdate(graphTarget, sql.NodeUpdateStagingTable)); err != nil { + return fmt.Errorf("merging node updates from staging table: %w", err) + } + + if err := tx.Commit(s.ctx); err != nil { + return err + } + + return nil +} + +// LargeNodeUpdateRows accumulates encoded node rows for bulk loading via COPY FROM. +// The column order matches sql.NodeUpdateStagingColumns. 
+type LargeNodeUpdateRows struct { + rows [][]any +} + +func NewLargeNodeUpdateRows(size int) *LargeNodeUpdateRows { + return &LargeNodeUpdateRows{ + rows: make([][]any, 0, size), + } +} + +func (s *LargeNodeUpdateRows) Rows() [][]any { + return s.rows +} + +func (s *LargeNodeUpdateRows) Append(ctx context.Context, node *graph.Node, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + addedKindIDs, err := schemaManager.AssertKinds(ctx, node.Kinds) + if err != nil { + return fmt.Errorf("mapping added kinds for node %d: %w", node.ID, err) + } + + deletedKindIDs, err := schemaManager.AssertKinds(ctx, node.DeletedKinds) + if err != nil { + return fmt.Errorf("mapping deleted kinds for node %d: %w", node.ID, err) + } + + propertiesJSONB, err := pgsql.PropertiesToJSONB(node.Properties) + if err != nil { + return fmt.Errorf("encoding properties for node %d: %w", node.ID, err) + } + + s.rows = append(s.rows, []any{ + node.ID.Int64(), + kindIDEncoder.Encode(addedKindIDs), + kindIDEncoder.Encode(deletedKindIDs), + string(propertiesJSONB.Bytes), + pgsql.DeletedPropertiesToString(node.Properties), + }) + + return nil +} + +// AppendAll encodes every node in the slice and appends its row to the accumulator. 
+func (s *LargeNodeUpdateRows) AppendAll(ctx context.Context, nodes []*graph.Node, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + for _, node := range nodes { + if err := s.Append(ctx, node, schemaManager, kindIDEncoder); err != nil { + return err + } + } + + return nil +} + +func (s *batch) UpdateNodes(nodes []*graph.Node) error { + if len(nodes) > LargeNodeUpdateThreshold { + return s.largeUpdate(nodes) + } + + for _, node := range nodes { + + s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, node) + if err := s.tryFlush(s.batchWriteSize); err != nil { + return err + } + } + + return nil +} + func (s *batch) flushNodeDeleteBuffer() error { if _, err := s.innerTransaction.conn.Exec(s.ctx, deleteNodeWithIDStatement, s.nodeDeletionBuffer); err != nil { return err @@ -229,6 +350,10 @@ func (s *batch) flushNodeUpsertBatch(updates *sql.NodeUpdateBatch) error { idFutureIndex++ } + + if err := rows.Err(); err != nil { + return err + } } } @@ -246,6 +371,101 @@ func (s *batch) tryFlushNodeUpdateByBuffer() error { return nil } +func (s *batch) flushNodeUpdateBatch(nodes []*graph.Node) error { + parameters := NewNodeUpdateParameters(len(nodes)) + + if err := parameters.AppendAll(s.ctx, nodes, s.schemaManager, s.kindIDEncoder); err != nil { + return err + } + + if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { + return err + } else { + query := sql.FormatNodesUpdate(graphTarget) + + if rows, err := s.innerTransaction.conn.Query(s.ctx, query, parameters.Format()...); err != nil { + return err + } else { + rows.Close() + + return rows.Err() + } + } +} + +func (s *batch) tryFlushNodeUpdateBuffer() error { + if err := s.flushNodeUpdateBatch(s.nodeUpdateBuffer); err != nil { + return err + } + + s.nodeUpdateBuffer = s.nodeUpdateBuffer[:0] + return nil +} + +type NodeUpdateParameters struct { + NodeIDs []graph.ID + KindSlices []string + DeletedKindSlices []string + Properties []pgtype.JSONB + DeletedProperties []string +} + +func 
NewNodeUpdateParameters(size int) *NodeUpdateParameters { + return &NodeUpdateParameters{ + NodeIDs: make([]graph.ID, 0, size), + KindSlices: make([]string, 0, size), + DeletedKindSlices: make([]string, 0, size), + Properties: make([]pgtype.JSONB, 0, size), + DeletedProperties: make([]string, 0, size), + } +} + +func (s *NodeUpdateParameters) Format() []any { + return []any{ + s.NodeIDs, + s.KindSlices, + s.DeletedKindSlices, + s.Properties, + s.DeletedProperties, + } +} + +func (s *NodeUpdateParameters) Append(ctx context.Context, node *graph.Node, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + s.NodeIDs = append(s.NodeIDs, node.ID) + + if mappedKindIDs, err := schemaManager.AssertKinds(ctx, node.Kinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) + } else { + s.KindSlices = append(s.KindSlices, kindIDEncoder.Encode(mappedKindIDs)) + } + + if mappedKindIDs, err := schemaManager.AssertKinds(ctx, node.DeletedKinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) + } else { + s.DeletedKindSlices = append(s.DeletedKindSlices, kindIDEncoder.Encode(mappedKindIDs)) + } + + if propertiesJSONB, err := pgsql.PropertiesToJSONB(node.Properties); err != nil { + return err + } else { + s.Properties = append(s.Properties, propertiesJSONB) + } + + s.DeletedProperties = append(s.DeletedProperties, pgsql.DeletedPropertiesToString(node.Properties)) + + return nil +} + +func (s *NodeUpdateParameters) AppendAll(ctx context.Context, nodes []*graph.Node, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + for _, node := range nodes { + if err := s.Append(ctx, node, schemaManager, kindIDEncoder); err != nil { + return err + } + } + + return nil +} + type NodeUpsertParameters struct { IDFutures []*sql.Future[graph.ID] KindIDSlices []string @@ -501,6 +721,12 @@ func (s *batch) tryFlush(batchWriteSize int) error { } } + if len(s.nodeUpdateBuffer) > batchWriteSize { + if err := s.tryFlushNodeUpdateBuffer(); err 
!= nil { + return err + } + } + if len(s.relationshipUpdateByBuffer) > batchWriteSize { if err := s.tryFlushRelationshipUpdateByBuffer(); err != nil { return err diff --git a/drivers/pg/driver.go b/drivers/pg/driver.go index fcc05f2..bbee76b 100644 --- a/drivers/pg/driver.go +++ b/drivers/pg/driver.go @@ -64,8 +64,16 @@ func (s *Driver) SetWriteFlushSize(size int) { // THis is a no-op function since PostgreSQL does not require transaction rotation like Neo4j does } -func (s *Driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error { - if cfg, err := renderConfig(batchWriteSize, readWriteTxOptions, nil); err != nil { +func (s *Driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate, options ...graph.BatchOption) error { + batchConfig := &graph.BatchConfig{ + BatchSize: batchWriteSize, + } + + for _, opt := range options { + opt(batchConfig) + } + + if cfg, err := renderConfig(batchConfig.BatchSize, readWriteTxOptions, nil); err != nil { return err } else if conn, err := s.pool.Acquire(ctx); err != nil { return err diff --git a/drivers/pg/query/format.go b/drivers/pg/query/format.go index db1ed24..bc181e9 100644 --- a/drivers/pg/query/format.go +++ b/drivers/pg/query/format.go @@ -101,6 +101,51 @@ func formatConflictMatcher(propertyNames []string, defaultOnConflict string) str return builder.String() } +func FormatNodesUpdate(graphTarget model.Graph) string { + return join( + "update ", graphTarget.Partitions.Node.Name, " as n ", + "set ", + " kind_ids = uniq(sort(kind_ids - u.deleted_kinds || u.added_kinds)), ", + " properties = n.properties - u.deleted_properties || u.properties ", + "from ", + " (select ", + " unnest($1::text[])::int8 as id, ", + " unnest($2::text[])::int2[] as added_kinds, ", + " unnest($3::text[])::int2[] as deleted_kinds, ", + " unnest($4::jsonb[]) as properties, ", + " unnest($5::text[])::text[] as deleted_properties) as u ", + "where n.id = u.id; ", + ) +} + +// NodeUpdateStagingTable is the 
name of the temporary staging table used by largeUpdate. +const NodeUpdateStagingTable = "node_update_staging" + +// NodeUpdateStagingColumns lists the columns (in order) written by a COPY FROM during largeUpdate. +var NodeUpdateStagingColumns = []string{"id", "added_kinds", "deleted_kinds", "properties", "deleted_props"} + +func FormatCreateNodeUpdateStagingTable(stagingTable string) string { + return join( + "create temp table if not exists ", stagingTable, " (", + "id bigint, ", + "added_kinds text, ", + "deleted_kinds text, ", + "properties text, ", + "deleted_props text", + ") on commit drop;", + ) +} + +func FormatMergeNodeLargeUpdate(graphTarget model.Graph, stagingTable string) string { + return join( + "merge into ", graphTarget.Partitions.Node.Name, " as n ", + "using ", stagingTable, " as u on n.id = u.id ", + "when matched then update set ", + "kind_ids = uniq(sort(n.kind_ids - u.deleted_kinds::int2[] || u.added_kinds::int2[])), ", + "properties = n.properties - u.deleted_props::text[] || u.properties::jsonb;", + ) +} + func FormatNodeUpsert(graphTarget model.Graph, identityProperties []string) string { return join( "insert into ", graphTarget.Partitions.Node.Name, " as n ", diff --git a/drivers/pg/query/format_test.go b/drivers/pg/query/format_test.go new file mode 100644 index 0000000..2092182 --- /dev/null +++ b/drivers/pg/query/format_test.go @@ -0,0 +1,93 @@ +package query_test + +import ( + "strings" + "testing" + + "github.com/specterops/dawgs/drivers/pg/model" + query "github.com/specterops/dawgs/drivers/pg/query" + "github.com/stretchr/testify/assert" +) + +func generateTestGraphTarget(nodePartitionName string) model.Graph { + return model.Graph{ + Partitions: model.GraphPartitions{ + Node: model.NewGraphPartition(nodePartitionName), + }, + } +} + +func TestFormatNodesUpdate(t *testing.T) { + t.Parallel() + + var ( + partitionName = "node_1" + expected = strings.Join([]string{ + "update node_1 as n ", + "set ", + " kind_ids = uniq(sort(kind_ids 
- u.deleted_kinds || u.added_kinds)), ", + " properties = n.properties - u.deleted_properties || u.properties ", + "from ", + " (select ", + " unnest($1::text[])::int8 as id, ", + " unnest($2::text[])::int2[] as added_kinds, ", + " unnest($3::text[])::int2[] as deleted_kinds, ", + " unnest($4::jsonb[]) as properties, ", + " unnest($5::text[])::text[] as deleted_properties) as u ", + "where n.id = u.id; ", + }, "") + result = query.FormatNodesUpdate(generateTestGraphTarget(partitionName)) + ) + + assert.Equal(t, expected, result) +} + +func TestFormatCreateNodeUpdateStagingTable(t *testing.T) { + t.Parallel() + + var ( + tableName = "my_staging_table" + expected = strings.Join([]string{ + "create temp table if not exists my_staging_table (", + "id bigint, ", + "added_kinds text, ", + "deleted_kinds text, ", + "properties text, ", + "deleted_props text", + ") on commit drop;", + }, "") + result = query.FormatCreateNodeUpdateStagingTable(tableName) + ) + + assert.Equal(t, expected, result) +} + +func TestFormatMergeNodeLargeUpdate(t *testing.T) { + t.Parallel() + + var ( + partitionName = "node_part_1" + stagingTable = "my_staging" + expected = strings.Join([]string{ + "merge into node_part_1 as n ", + "using my_staging as u on n.id = u.id ", + "when matched then update set ", + "kind_ids = uniq(sort(n.kind_ids - u.deleted_kinds::int2[] || u.added_kinds::int2[])), ", + "properties = n.properties - u.deleted_props::text[] || u.properties::jsonb;", + }, "") + result = query.FormatMergeNodeLargeUpdate(generateTestGraphTarget(partitionName), stagingTable) + ) + + assert.Equal(t, expected, result) +} + +func TestNodeUpdateStagingColumns(t *testing.T) { + t.Parallel() + + var ( + // Note: order is important + expected = []string{"id", "added_kinds", "deleted_kinds", "properties", "deleted_props"} + ) + + assert.Equal(t, expected, query.NodeUpdateStagingColumns) +} diff --git a/graph/graph.go b/graph/graph.go index 08b9850..5855409 100644 --- a/graph/graph.go +++ 
b/graph/graph.go @@ -283,6 +283,9 @@ type Batch interface { // exist, created. UpdateNodeBy(update NodeUpdate) error + // Updates nodes by ID + UpdateNodes(nodes []*Node) error + // TODO: Existing batch logic expects this to perform an upsert on conficts with (start_id, end_id, kind). This is incorrect and should be refactored CreateRelationship(relationship *Relationship) error @@ -363,6 +366,18 @@ type TransactionConfig struct { // TransactionOption is a function that represents a configuration setting for the underlying database transaction. type TransactionOption func(config *TransactionConfig) +func WithBatchSize(size int) BatchOption { + return func(config *BatchConfig) { + config.BatchSize = size + } +} + +type BatchOption func(config *BatchConfig) + +type BatchConfig struct { + BatchSize int +} + // Database is a high-level interface representing transactional entry-points into DAWGS driver implementations. type Database interface { // SetWriteFlushSize sets a new write flush interval on the current driver @@ -383,7 +398,7 @@ type Database interface { // given logic function. Batch operations are fundamentally different between databases supported by DAWGS, // necessitating a different interface that lacks many of the convenience features of a regular read or write // transaction. - BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error + BatchOperation(ctx context.Context, batchDelegate BatchDelegate, options ...BatchOption) error // AssertSchema will apply the given schema to the underlying database. 
AssertSchema(ctx context.Context, dbSchema Schema) error diff --git a/graph/graph_test.go b/graph/graph_test.go new file mode 100644 index 0000000..5408400 --- /dev/null +++ b/graph/graph_test.go @@ -0,0 +1,44 @@ +package graph_test + +import ( + "testing" + + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" +) + +func TestWithBatchSize(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + batchSize int + expectedSize int + }{ + { + name: "sets batch size to provided value", + batchSize: 500, + expectedSize: 500, + }, + { + name: "sets batch size to zero", + batchSize: 0, + expectedSize: 0, + }, + { + name: "sets batch size to large value", + batchSize: 1_000_000, + expectedSize: 1_000_000, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + config := &graph.BatchConfig{} + opt := graph.WithBatchSize(tc.batchSize) + opt(config) + assert.Equal(t, tc.expectedSize, config.BatchSize) + }) + } +} diff --git a/graph/node.go b/graph/node.go index e6db3c8..de40c8a 100644 --- a/graph/node.go +++ b/graph/node.go @@ -112,6 +112,28 @@ func (s *Node) MarshalJSON() ([]byte, error) { return json.Marshal(jsonNode) } +/* +StripAllPropertiesExcept removes all properties from the node except for the ones specified in the except list. +Deleted properties are also removed from the node, except for the ones specified in the except list. +The use case for this function is if you have fully hydrated nodes in memory, +but only want to update a few properties it is most efficient to strip all properties except for the ones you want to update. 
+*/ +func (s *Node) StripAllPropertiesExcept(except ...string) { + newProperties := NewProperties() + + for _, exclusion := range except { + if s.Properties.Exists(exclusion) { + newProperties.Set(exclusion, s.Properties.Get(exclusion).Any()) + } + + if _, present := s.Properties.Deleted[exclusion]; present { + newProperties.Delete(exclusion) + } + } + + s.Properties = newProperties +} + // NodeSet is a mapped index of Node instances and their ID fields. type NodeSet map[ID]*Node diff --git a/graph/node_test.go b/graph/node_test.go index 5fccad3..a617364 100644 --- a/graph/node_test.go +++ b/graph/node_test.go @@ -1,12 +1,114 @@ package graph_test import ( + "slices" "testing" "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func Test_StripAllPropertiesExcept(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func() *graph.Node + except []string + assertions func(t *testing.T, node *graph.Node) + }{ + { + name: "keeps only specified properties that exist", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Set("keep1", "value1") + node.Properties.Set("keep2", "value2") + node.Properties.Set("remove", "value3") + return node + }, + except: []string{"keep1", "keep2"}, + assertions: func(t *testing.T, node *graph.Node) { + assert.True(t, node.Properties.Exists("keep1")) + assert.Equal(t, "value1", node.Properties.Get("keep1").Any()) + assert.True(t, node.Properties.Exists("keep2")) + assert.Equal(t, "value2", node.Properties.Get("keep2").Any()) + assert.False(t, node.Properties.Exists("remove")) + }, + }, + { + name: "non-existent except keys are silently skipped", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Set("keep", "value") + return node + }, + except: []string{"keep", "nonexistent"}, + assertions: func(t *testing.T, node *graph.Node) { + assert.True(t, 
node.Properties.Exists("keep")) + assert.Equal(t, 1, node.Properties.Len()) + }, + }, + { + name: "no args clears all properties", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Set("key", "value") + return node + }, + except: nil, + assertions: func(t *testing.T, node *graph.Node) { + assert.Equal(t, 0, node.Properties.Len()) + }, + }, + { + name: "respects deleted properties", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Delete("key") + return node + }, + except: []string{"key"}, + assertions: func(t *testing.T, node *graph.Node) { + deleted := node.Properties.DeletedProperties() + assert.Equal(t, 1, len(deleted)) + assert.True(t, slices.Contains(deleted, "key")) + }, + }, + { + name: "respects deleted properties 2", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Set("hello", "value1") + node.Properties.Set("loremipsum", "value2") + node.Properties.Delete("key1") + node.Properties.Delete("key2") + node.Properties.Delete("key3") + return node + }, + except: []string{"hello", "world", "key1", "key2"}, + assertions: func(t *testing.T, node *graph.Node) { + deleted := node.Properties.DeletedProperties() + assert.Equal(t, 2, len(deleted)) + assert.True(t, slices.Contains(deleted, "key1")) + assert.True(t, slices.Contains(deleted, "key2")) + assert.False(t, node.Properties.Exists("loremipsum")) + assert.True(t, node.Properties.Exists("hello")) + assert.Equal(t, "value1", node.Properties.Get("hello").Any()) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + node := tc.setup() + node.StripAllPropertiesExcept(tc.except...) 
+ tc.assertions(t, node) + }) + } +} + func Test_NodeSizeOf(t *testing.T) { node := graph.Node{ID: graph.ID(1)} oldSize := int64(node.SizeOf()) diff --git a/graph/switch.go b/graph/switch.go index 2ecfb5a..3c40272 100644 --- a/graph/switch.go +++ b/graph/switch.go @@ -141,7 +141,7 @@ func (s *DatabaseSwitch) WriteTransaction(ctx context.Context, txDelegate Transa } } -func (s *DatabaseSwitch) BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error { +func (s *DatabaseSwitch) BatchOperation(ctx context.Context, batchDelegate BatchDelegate, options ...BatchOption) error { if internalCtx, err := s.newInternalContext(ctx); err != nil { return err } else { @@ -150,7 +150,7 @@ func (s *DatabaseSwitch) BatchOperation(ctx context.Context, batchDelegate Batch s.currentDBLock.RLock() defer s.currentDBLock.RUnlock() - return s.currentDB.BatchOperation(internalCtx, batchDelegate) + return s.currentDB.BatchOperation(internalCtx, batchDelegate, options...) } } diff --git a/ops/ops.go b/ops/ops.go index 9123ecb..a71902e 100644 --- a/ops/ops.go +++ b/ops/ops.go @@ -598,3 +598,23 @@ func ParallelFetchNodes(ctx context.Context, db graph.Database, criteria graph.C return parallelFetchNodes(ctx, db, largestNodeID, criteria, numWorkers) } } + +/* +UpdateNodes batch updates nodes by node ID. +This should be the default tool for updating a collection of nodes from memory. +*/ +func UpdateNodes(ctx context.Context, graphDB graph.Database, nodes []*graph.Node, batchSize ...int) error { + options := func(_ *graph.BatchConfig) {} + + if len(batchSize) > 0 { + options = graph.WithBatchSize(batchSize[0]) + } + + if err := graphDB.BatchOperation(ctx, func(batch graph.Batch) error { + return batch.UpdateNodes(nodes) + }, options); err != nil { + return err + } + + return nil +}