From 483a59e743f5684ee7cbb531eddd6d32f1f400c3 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Mon, 2 Mar 2026 17:33:26 -0600 Subject: [PATCH 01/14] initial implementation --- cypher/models/pgsql/type.go | 5 ++ drivers/neo4j/batch.go | 5 ++ drivers/pg/batch.go | 127 ++++++++++++++++++++++++++++++++++++ drivers/pg/query/format.go | 39 +++++++++++ graph/graph.go | 3 + graph/node.go | 18 +++++ ops/ops.go | 14 ++++ 7 files changed, 211 insertions(+) diff --git a/cypher/models/pgsql/type.go b/cypher/models/pgsql/type.go index da00462..4fc96e4 100644 --- a/cypher/models/pgsql/type.go +++ b/cypher/models/pgsql/type.go @@ -3,6 +3,7 @@ package pgsql import ( "bytes" "encoding/json" + "strings" "reflect" @@ -55,6 +56,10 @@ func PropertiesToJSONB(properties *graph.Properties) (pgtype.JSONB, error) { return MapStringAnyToJSONB(properties.MapOrEmpty()) } +func DeletedPropertiesToString(properties *graph.Properties) string { + return "{" + strings.Join(properties.DeletedProperties(), ",") + "}" +} + func JSONBToProperties(jsonb pgtype.JSONB) (*graph.Properties, error) { propertiesMap := make(map[string]any) diff --git a/drivers/neo4j/batch.go b/drivers/neo4j/batch.go index 3e8586c..276eecf 100644 --- a/drivers/neo4j/batch.go +++ b/drivers/neo4j/batch.go @@ -55,6 +55,11 @@ func (s *batchTransaction) UpdateNodeBy(update graph.NodeUpdate) error { return nil } +func (s *batchTransaction) UpdateNodes(nodes []*graph.Node) error { + panic("unimplemented") + return nil +} + func (s *batchTransaction) UpdateRelationshipBy(update graph.RelationshipUpdate) error { if s.relationshipUpdateByBuffer = append(s.relationshipUpdateByBuffer, update); len(s.relationshipUpdateByBuffer) >= s.batchWriteSize { return s.flushRelationshipUpdates() diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index a46a520..ef509a4 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -16,6 +16,10 @@ import ( "github.com/specterops/dawgs/graph" ) +const ( + LargeUpdateThreshold = 1_000_000 +) + type Int2ArrayEncoder struct { buffer *bytes.Buffer } @@ -43,6 +47,7 @@ type batch struct { nodeDeletionBuffer []graph.ID relationshipDeletionBuffer []graph.ID nodeCreateBuffer []*graph.Node + nodeUpdateBuffer []*graph.Node nodeUpdateByBuffer []graph.NodeUpdate relationshipCreateBuffer []*graph.Relationship relationshipUpdateByBuffer []graph.RelationshipUpdate @@ -89,6 +94,27 @@ func (s *batch) UpdateNodeBy(update graph.NodeUpdate) error { return s.tryFlush(s.batchWriteSize) } +// TODO: test COPY ... FROM ... with MERGE INTO ... +func (s *batch) largeUpdate(_ []*graph.Node) error { + return nil +} + +func (s *batch) UpdateNodes(nodes []*graph.Node) error { + if len(nodes) > LargeUpdateThreshold { + return s.largeUpdate(nodes) + } else { + for _, node := range nodes { + + s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, node) + if err := s.tryFlush(s.batchWriteSize); err != nil { + return err + } + } + } + + return nil +} + func (s *batch) flushNodeDeleteBuffer() error { if _, err := s.innerTransaction.conn.Exec(s.ctx, deleteNodeWithIDStatement, s.nodeDeletionBuffer); err != nil { return err @@ -246,6 +272,101 @@ func (s *batch) tryFlushNodeUpdateByBuffer() error { return nil } +func (s *batch) flushNodeUpdateBatch(nodes []*graph.Node) error { + parameters := NewNodeUpdateParameters(len(nodes)) + + if err := parameters.AppendAll(s.ctx, nodes, s.schemaManager, s.kindIDEncoder); err != nil { + return err + } + + if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { + return err + } else { + query := sql.FormatNodesUpdate(graphTarget) + + if rows, err := s.innerTransaction.conn.Query(s.ctx, query, parameters.Format()...); err != nil { + return err + } else { + rows.Close() + + return rows.Err() + } + } +} + +func (s *batch) tryFlushNodeUpdateBuffer() error { + if err := s.flushNodeUpdateBatch(s.nodeUpdateBuffer); err != nil { + return err + } + + s.nodeUpdateBuffer = s.nodeUpdateBuffer[:0] + return nil +} + +type NodeUpdateParameters struct { + NodeIDs []graph.ID + KindSlices []string + DeletedKindSlices []string + Properties []pgtype.JSONB + DeletedProperties []string +} + +func NewNodeUpdateParameters(size int) *NodeUpdateParameters { + return &NodeUpdateParameters{ + NodeIDs: make([]graph.ID, 0, size), + KindSlices: make([]string, 0, size), + DeletedKindSlices: make([]string, 0, size), + Properties: make([]pgtype.JSONB, 0, size), + DeletedProperties: make([]string, 0, size), + } +} + +func (s *NodeUpdateParameters) Format() []any { + return []any{ + s.NodeIDs, + s.KindSlices, + s.DeletedKindSlices, + s.Properties, + s.DeletedProperties, + } +} + +func (s *NodeUpdateParameters) Append(ctx context.Context, node *graph.Node, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + s.NodeIDs = append(s.NodeIDs, node.ID) + + if mappedKindIDs, err := schemaManager.AssertKinds(ctx, node.Kinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) + } else { + s.KindSlices = append(s.KindSlices, kindIDEncoder.Encode(mappedKindIDs)) + } + + if mappedKindIDs, err := schemaManager.AssertKinds(ctx, node.DeletedKinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) + } else { + s.DeletedKindSlices = append(s.DeletedKindSlices, kindIDEncoder.Encode(mappedKindIDs)) + } + + if propertiesJSONB, err := pgsql.PropertiesToJSONB(node.Properties); err != nil { + return err + } else { + s.Properties = append(s.Properties, propertiesJSONB) + } + + s.DeletedProperties = append(s.DeletedProperties, pgsql.DeletedPropertiesToString(node.Properties)) + + return nil +} + +func (s *NodeUpdateParameters) AppendAll(ctx context.Context, nodes []*graph.Node, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + for _, node := range nodes { + if err := s.Append(ctx, node, schemaManager, kindIDEncoder); err != nil { + return err + } + } + + return nil +} + type NodeUpsertParameters struct { IDFutures []*sql.Future[graph.ID] KindIDSlices []string @@ -501,6 +622,12 @@ func (s *batch) tryFlush(batchWriteSize int) error { } } + if len(s.nodeUpdateBuffer) > batchWriteSize { + if err := s.tryFlushNodeUpdateBuffer(); err != nil { + return err + } + } + if len(s.relationshipUpdateByBuffer) > batchWriteSize { if err := s.tryFlushRelationshipUpdateByBuffer(); err != nil { return err diff --git a/drivers/pg/query/format.go b/drivers/pg/query/format.go index db1ed24..4fdc2c8 100644 --- a/drivers/pg/query/format.go +++ b/drivers/pg/query/format.go @@ -101,6 +101,45 @@ func formatConflictMatcher(propertyNames []string, defaultOnConflict string) str return builder.String() } +func FormatNodesUpdate(graphTarget model.Graph) string { + + /* + TODO: clean up + + update node_1 as n + set + kind_ids = sort (uniq (kind_ids - u.deleted_kinds || u.added_kinds)), + properties = n.properties - u.deleted_properties || u.properties + from + ( + select + unnest(:IDS::text[])::int8 as id, + unnest(:KINDS::text[])::int2[] as added_kinds, + unnest(:DELETED_KINDS::text[])::int2[] as deleted_kinds, + unnest(:PROPERTIES::jsonb[]) as properties, + unnest(:DELETED_PROPERTIES::text[])::text[] as deleted_properties + ) as u + where + n.id = u.id; + + */ + + return join( + "update ", graphTarget.Partitions.Node.Name, " as n ", + "set ", + " kind_ids = sort (uniq (kind_ids - u.deleted_kinds || u.added_kinds)), ", + " properties = n.properties - u.deleted_properties || u.properties ", + "from ", + " (select ", + " unnest($1::text[])::int8 as id, ", + " unnest($2::text[])::int2[] as added_kinds, ", + " unnest($3::text[])::int2[] as deleted_kinds, ", + " unnest($4::jsonb[]) as properties, ", + " unnest($5::text[])::text[] as deleted_properties) as u ", + "where n.id = u.id ", + ) +} + func FormatNodeUpsert(graphTarget model.Graph, identityProperties []string) string { return join( "insert into ", graphTarget.Partitions.Node.Name, " as n ", diff --git a/graph/graph.go b/graph/graph.go index 08b9850..6ce460b 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -283,6 +283,9 @@ type Batch interface { // exist, created. UpdateNodeBy(update NodeUpdate) error + // Updates nodes by ID + UpdateNodes(nodes []*Node) error + // TODO: Existing batch logic expects this to perform an upsert on conficts with (start_id, end_id, kind). This is incorrect and should be refactored CreateRelationship(relationship *Relationship) error diff --git a/graph/node.go b/graph/node.go index e6db3c8..61f8ca9 100644 --- a/graph/node.go +++ b/graph/node.go @@ -112,6 +112,24 @@ func (s *Node) MarshalJSON() ([]byte, error) { return json.Marshal(jsonNode) } +func (s *Node) StripAllPropertiesExcept(except ...string) { + tmp := make([]any, 0, len(except)) + found := make([]string, 0, len(except)) + + for _, exclusion := range except { + if s.Properties.Exists(exclusion) { + found = append(found, exclusion) + tmp = append(tmp, s.Properties.Get(exclusion).Any()) + } + } + + s.Properties = NewProperties() + + for i, key := range found { + s.Properties.Set(key, tmp[i]) + } +} + // NodeSet is a mapped index of Node instances and their ID fields. type NodeSet map[ID]*Node diff --git a/ops/ops.go b/ops/ops.go index 9123ecb..9ef08c2 100644 --- a/ops/ops.go +++ b/ops/ops.go @@ -598,3 +598,17 @@ func ParallelFetchNodes(ctx context.Context, db graph.Database, criteria graph.C return parallelFetchNodes(ctx, db, largestNodeID, criteria, numWorkers) } } + +/* +UpdateNodes batch updates nodes by node ID. +This should be the default tool for updating a collection of nodes from memory. +*/ +func UpdateNodes(ctx context.Context, graphDB graph.Database, nodes []*graph.Node) error { + if err := graphDB.BatchOperation(ctx, func(batch graph.Batch) error { + return batch.UpdateNodes(nodes) + }); err != nil { + return err + } + + return nil +} From 488df1a6f5964b699cee6d9ef0d55d383728bbeb Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Mon, 2 Mar 2026 18:33:51 -0600 Subject: [PATCH 02/14] fix sql --- drivers/pg/query/format.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/drivers/pg/query/format.go b/drivers/pg/query/format.go index 4fdc2c8..a722871 100644 --- a/drivers/pg/query/format.go +++ b/drivers/pg/query/format.go @@ -127,7 +127,7 @@ func FormatNodesUpdate(graphTarget model.Graph) string { return join( "update ", graphTarget.Partitions.Node.Name, " as n ", "set ", - " kind_ids = sort (uniq (kind_ids - u.deleted_kinds || u.added_kinds)), ", + " kind_ids = uniq(sort(kind_ids - u.deleted_kinds || u.added_kinds)), ", " properties = n.properties - u.deleted_properties || u.properties ", "from ", " (select ", From 54ea6a9cca1bbe0a100f2ae99de292ad3f451d93 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Mon, 2 Mar 2026 20:23:38 -0600 Subject: [PATCH 03/14] add batch size option to batchoperation --- drivers/neo4j/driver.go | 12 ++++++++++-- drivers/pg/batch.go | 1 + drivers/pg/driver.go | 12 ++++++++++-- graph/graph.go | 14 +++++++++++++- graph/switch.go | 4 ++-- ops/ops.go | 10 ++++++++-- 6 files changed, 44 insertions(+), 9 deletions(-) diff --git a/drivers/neo4j/driver.go b/drivers/neo4j/driver.go index 7498d9d..731b647 100644 --- a/drivers/neo4j/driver.go +++ b/drivers/neo4j/driver.go @@ -43,7 +43,7 @@ func (s *driver) SetWriteFlushSize(size int) { s.writeFlushSize = size } -func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error { +func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate, options ...graph.BatchOption) error { // Attempt to acquire a connection slot or wait for a bit until one becomes available if !s.limiter.Acquire(ctx) { return graph.ErrContextTimedOut @@ -51,13 +51,21 @@ func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDe defer s.limiter.Release() } + config := &graph.BatchConfig{ + BatchSize: s.batchWriteSize, + } + + for _, opt := range options { + opt(config) + } + var ( cfg = graph.TransactionConfig{ Timeout: s.defaultTransactionTimeout, } session = s.driver.NewSession(writeCfg()) - batch = newBatchOperation(ctx, session, cfg, s.writeFlushSize, s.batchWriteSize, s.graphQueryMemoryLimit) + batch = newBatchOperation(ctx, session, cfg, s.writeFlushSize, config.BatchSize, s.graphQueryMemoryLimit) ) defer session.Close() diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index ef509a4..c750c74 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -244,6 +244,7 @@ func (s *batch) flushNodeUpsertBatch(updates *sql.NodeUpdateBatch) error { if rows, err := s.innerTransaction.conn.Query(s.ctx, query, parameters.Format(graphTarget)...); err != nil { return err } else { + // TODO: rows.Err() is never called, silently swallowing errors defer rows.Close() idFutureIndex := 0 diff --git a/drivers/pg/driver.go b/drivers/pg/driver.go index fcc05f2..bbee76b 100644 --- a/drivers/pg/driver.go +++ b/drivers/pg/driver.go @@ -64,8 +64,16 @@ func (s *Driver) SetWriteFlushSize(size int) { // THis is a no-op function since PostgreSQL does not require transaction rotation like Neo4j does } -func (s *Driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error { - if cfg, err := renderConfig(batchWriteSize, readWriteTxOptions, nil); err != nil { +func (s *Driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate, options ...graph.BatchOption) error { + batchConfig := &graph.BatchConfig{ + BatchSize: batchWriteSize, + } + + for _, opt := range options { + opt(batchConfig) + } + + if cfg, err := renderConfig(batchConfig.BatchSize, readWriteTxOptions, nil); err != nil { return err } else if conn, err := s.pool.Acquire(ctx); err != nil { return err diff --git a/graph/graph.go b/graph/graph.go index 6ce460b..5855409 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -366,6 +366,18 @@ type TransactionConfig struct { // TransactionOption is a function that represents a configuration setting for the underlying database transaction. type TransactionOption func(config *TransactionConfig) +func WithBatchSize(size int) BatchOption { + return func(config *BatchConfig) { + config.BatchSize = size + } +} + +type BatchOption func(config *BatchConfig) + +type BatchConfig struct { + BatchSize int +} + // Database is a high-level interface representing transactional entry-points into DAWGS driver implementations. type Database interface { // SetWriteFlushSize sets a new write flush interval on the current driver @@ -386,7 +398,7 @@ type Database interface { // given logic function. Batch operations are fundamentally different between databases supported by DAWGS, // necessitating a different interface that lacks many of the convenience features of a regular read or write // transaction. - BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error + BatchOperation(ctx context.Context, batchDelegate BatchDelegate, options ...BatchOption) error // AssertSchema will apply the given schema to the underlying database. AssertSchema(ctx context.Context, dbSchema Schema) error diff --git a/graph/switch.go b/graph/switch.go index 2ecfb5a..3c40272 100644 --- a/graph/switch.go +++ b/graph/switch.go @@ -141,7 +141,7 @@ func (s *DatabaseSwitch) WriteTransaction(ctx context.Context, txDelegate Transa } } -func (s *DatabaseSwitch) BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error { +func (s *DatabaseSwitch) BatchOperation(ctx context.Context, batchDelegate BatchDelegate, options ...BatchOption) error { if internalCtx, err := s.newInternalContext(ctx); err != nil { return err } else { @@ -150,7 +150,7 @@ func (s *DatabaseSwitch) BatchOperation(ctx context.Context, batchDelegate Batch s.currentDBLock.RLock() defer s.currentDBLock.RUnlock() - return s.currentDB.BatchOperation(internalCtx, batchDelegate) + return s.currentDB.BatchOperation(internalCtx, batchDelegate, options...) } } diff --git a/ops/ops.go b/ops/ops.go index 9ef08c2..1831e16 100644 --- a/ops/ops.go +++ b/ops/ops.go @@ -603,10 +603,16 @@ func ParallelFetchNodes(ctx context.Context, db graph.Database, criteria graph.C UpdateNodes batch updates nodes by node ID. This should be the default tool for updating a collection of nodes from memory. */ -func UpdateNodes(ctx context.Context, graphDB graph.Database, nodes []*graph.Node) error { +func UpdateNodes(ctx context.Context, graphDB graph.Database, nodes []*graph.Node, batchSize ...int) error { + options := func(_ *graph.BatchConfig) {} + + if batchSize != nil { + options = graph.WithBatchSize(batchSize[0]) + } + if err := graphDB.BatchOperation(ctx, func(batch graph.Batch) error { return batch.UpdateNodes(nodes) - }); err != nil { + }, options); err != nil { return err } From 5c177f735cf166e695d910818cda13863c53d6f7 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Mon, 2 Mar 2026 21:26:08 -0600 Subject: [PATCH 04/14] update --- drivers/pg/batch.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index c750c74..54e8394 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -100,15 +100,11 @@ func (s *batch) largeUpdate(_ []*graph.Node) error { } func (s *batch) UpdateNodes(nodes []*graph.Node) error { - if len(nodes) > LargeUpdateThreshold { - return s.largeUpdate(nodes) - } else { - for _, node := range nodes { + for _, node := range nodes { - s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, node) - if err := s.tryFlush(s.batchWriteSize); err != nil { - return err - } + s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, node) + if err := s.tryFlush(s.batchWriteSize); err != nil { + return err } } From dde02f3773dbb86bb51e0bbce7b2ef7cfaa69758 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Wed, 4 Mar 2026 11:04:47 -0600 Subject: [PATCH 05/14] update --- drivers/pg/batch.go | 105 +++++++++++++++++++++++++++++++++++-- drivers/pg/query/format.go | 52 ++++++++++-------- 2 files changed, 131 insertions(+), 26 deletions(-) diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index 54e8394..37e2848 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgtype" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/specterops/dawgs/cypher/models/pgsql" "github.com/specterops/dawgs/drivers/pg/model" @@ -17,7 +18,7 @@ import ( ) const ( - LargeUpdateThreshold = 1_000_000 + LargeNodeUpdateThreshold = 1_000_000 ) type Int2ArrayEncoder struct { @@ -94,12 +95,110 @@ func (s *batch) UpdateNodeBy(update graph.NodeUpdate) error { return s.tryFlush(s.batchWriteSize) } -// TODO: test COPY ... FROM ... with MERGE INTO ... -func (s *batch) largeUpdate(_ []*graph.Node) error { +// largeUpdate performs a bulk node update using PostgreSQL's COPY FROM to stream +// nodes into a temporary staging table and then MERGE INTO the live node partition. +// This path is more efficient than a parameterised UPDATE for very large batches +// (see LargeUpdateThreshold). +func (s *batch) largeUpdate(nodes []*graph.Node) error { + tx, err := s.innerTransaction.conn.Begin(s.ctx) + if err != nil { + return err + } + + if _, err := tx.Exec(s.ctx, sql.FormatCreateNodeUpdateStagingTable(sql.NodeUpdateStagingTable)); err != nil { + return fmt.Errorf("creating node update staging table: %w", err) + } + + nodeRows := NewLargeNodeUpdateRows(len(nodes)) + if err := nodeRows.AppendAll(s.ctx, nodes, s.schemaManager, s.kindIDEncoder); err != nil { + return err + } + + // Stream the rows into the staging table via COPY FROM. + if _, err := tx.Conn().CopyFrom( + s.ctx, + pgx.Identifier{sql.NodeUpdateStagingTable}, + sql.NodeUpdateStagingColumns, + pgx.CopyFromRows(nodeRows.Rows()), + ); err != nil { + return fmt.Errorf("copying nodes into staging table: %w", err) + } + + graphTarget, err := s.innerTransaction.getTargetGraph() + if err != nil { + return err + } + + if _, err := tx.Exec(s.ctx, sql.FormatMergeNodeLargeUpdate(graphTarget, sql.NodeUpdateStagingTable)); err != nil { + return fmt.Errorf("merging node updates from staging table: %w", err) + } + + if err := tx.Commit(s.ctx); err != nil { + return err + } + + return nil +} + +// LargeNodeUpdateRows accumulates encoded node rows for bulk loading via COPY FROM. +// The column order matches sql.NodeUpdateStagingColumns. +type LargeNodeUpdateRows struct { + rows [][]any +} + +func NewLargeNodeUpdateRows(size int) *LargeNodeUpdateRows { + return &LargeNodeUpdateRows{ + rows: make([][]any, 0, size), + } +} + +func (s *LargeNodeUpdateRows) Rows() [][]any { + return s.rows +} + +func (s *LargeNodeUpdateRows) Append(ctx context.Context, node *graph.Node, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + addedKindIDs, err := schemaManager.AssertKinds(ctx, node.Kinds) + if err != nil { + return fmt.Errorf("mapping added kinds for node %d: %w", node.ID, err) + } + + deletedKindIDs, err := schemaManager.AssertKinds(ctx, node.DeletedKinds) + if err != nil { + return fmt.Errorf("mapping deleted kinds for node %d: %w", node.ID, err) + } + + propertiesJSONB, err := pgsql.PropertiesToJSONB(node.Properties) + if err != nil { + return fmt.Errorf("encoding properties for node %d: %w", node.ID, err) + } + + s.rows = append(s.rows, []any{ + node.ID.Int64(), + kindIDEncoder.Encode(addedKindIDs), + kindIDEncoder.Encode(deletedKindIDs), + string(propertiesJSONB.Bytes), + pgsql.DeletedPropertiesToString(node.Properties), + }) + + return nil +} + +// AppendAll encodes every node in the slice and appends its row to the accumulator. +func (s *LargeNodeUpdateRows) AppendAll(ctx context.Context, nodes []*graph.Node, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + for _, node := range nodes { + if err := s.Append(ctx, node, schemaManager, kindIDEncoder); err != nil { + return err + } + } + return nil } func (s *batch) UpdateNodes(nodes []*graph.Node) error { + if len(nodes) > LargeNodeUpdateThreshold { + return s.largeUpdate(nodes) + } + for _, node := range nodes { s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, node) diff --git a/drivers/pg/query/format.go b/drivers/pg/query/format.go index a722871..bc181e9 100644 --- a/drivers/pg/query/format.go +++ b/drivers/pg/query/format.go @@ -102,28 +102,6 @@ func formatConflictMatcher(propertyNames []string, defaultOnConflict string) str } func FormatNodesUpdate(graphTarget model.Graph) string { - - /* - TODO: clean up - - update node_1 as n - set - kind_ids = sort (uniq (kind_ids - u.deleted_kinds || u.added_kinds)), - properties = n.properties - u.deleted_properties || u.properties - from - ( - select - unnest(:IDS::text[])::int8 as id, - unnest(:KINDS::text[])::int2[] as added_kinds, - unnest(:DELETED_KINDS::text[])::int2[] as deleted_kinds, - unnest(:PROPERTIES::jsonb[]) as properties, - unnest(:DELETED_PROPERTIES::text[])::text[] as deleted_properties - ) as u - where - n.id = u.id; - - */ - return join( "update ", graphTarget.Partitions.Node.Name, " as n ", "set ", @@ -136,7 +114,35 @@ func FormatNodesUpdate(graphTarget model.Graph) string { " unnest($3::text[])::int2[] as deleted_kinds, ", " unnest($4::jsonb[]) as properties, ", " unnest($5::text[])::text[] as deleted_properties) as u ", - "where n.id = u.id ", + "where n.id = u.id; ", + ) +} + +// NodeUpdateStagingTable is the name of the temporary staging table used by largeUpdate. +const NodeUpdateStagingTable = "node_update_staging" + +// NodeUpdateStagingColumns lists the columns (in order) written by a COPY FROM during largeUpdate. +var NodeUpdateStagingColumns = []string{"id", "added_kinds", "deleted_kinds", "properties", "deleted_props"} + +func FormatCreateNodeUpdateStagingTable(stagingTable string) string { + return join( + "create temp table if not exists ", stagingTable, " (", + "id bigint, ", + "added_kinds text, ", + "deleted_kinds text, ", + "properties text, ", + "deleted_props text", + ") on commit drop;", + ) +} + +func FormatMergeNodeLargeUpdate(graphTarget model.Graph, stagingTable string) string { + return join( + "merge into ", graphTarget.Partitions.Node.Name, " as n ", + "using ", stagingTable, " as u on n.id = u.id ", + "when matched then update set ", + "kind_ids = uniq(sort(n.kind_ids - u.deleted_kinds::int2[] || u.added_kinds::int2[])), ", + "properties = n.properties - u.deleted_props::text[] || u.properties::jsonb;", ) } From 444edaf779e3546d7e831fda64500c1145b4c33d Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Wed, 4 Mar 2026 11:49:35 -0600 Subject: [PATCH 06/14] update --- ops/ops.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ops/ops.go b/ops/ops.go index 1831e16..a71902e 100644 --- a/ops/ops.go +++ b/ops/ops.go @@ -606,7 +606,7 @@ This should be the default tool for updating a collection of nodes from memory. func UpdateNodes(ctx context.Context, graphDB graph.Database, nodes []*graph.Node, batchSize ...int) error { options := func(_ *graph.BatchConfig) {} - if batchSize != nil { + if len(batchSize) > 0 { options = graph.WithBatchSize(batchSize[0]) } From 130d20aa012c7bdca208a48cb989c512d6d48f31 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Wed, 4 Mar 2026 16:52:13 -0600 Subject: [PATCH 07/14] added unit tests --- cypher/models/pgsql/type_test.go | 80 +++++++++++++++++++++++++++ drivers/pg/query/format_test.go | 93 ++++++++++++++++++++++++++++++++ graph/graph_test.go | 43 +++++++++++++++ graph/node_test.go | 65 ++++++++++++++++++++++ 4 files changed, 281 insertions(+) create mode 100644 cypher/models/pgsql/type_test.go create mode 100644 drivers/pg/query/format_test.go create mode 100644 graph/graph_test.go diff --git a/cypher/models/pgsql/type_test.go b/cypher/models/pgsql/type_test.go new file mode 100644 index 0000000..0b82bde --- /dev/null +++ b/cypher/models/pgsql/type_test.go @@ -0,0 +1,80 @@ +package pgsql_test + +import ( + "testing" + + "github.com/specterops/dawgs/cypher/models/pgsql" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" +) + +func TestDeletedPropertiesToString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func() *graph.Properties + check func(t *testing.T, result string) + }{ + { + name: "no deleted properties returns empty braces", + setup: func() *graph.Properties { + return graph.NewProperties() + }, + check: func(t *testing.T, result string) { + assert.Equal(t, "{}", result) + }, + }, + { + name: "single deleted property is wrapped in braces", + setup: func() *graph.Properties { + props := graph.NewProperties() + props.Set("mykey", "myvalue") + props.Delete("mykey") + return props + }, + check: func(t *testing.T, result string) { + assert.Equal(t, "{mykey}", result) + }, + }, + { + name: "multiple deleted properties are all present in output", + setup: func() *graph.Properties { + props := graph.NewProperties() + props.Set("alpha", 1) + props.Set("beta", 2) + props.Delete("alpha") + props.Delete("beta") + return props + }, + check: func(t *testing.T, result string) { + // Map iteration order is non-deterministic; accept either ordering. + assert.True(t, result == "{alpha,beta}" || result == "{beta,alpha}", + "unexpected result: %s", result) + }, + }, + { + name: "non-deleted properties are not included", + setup: func() *graph.Properties { + props := graph.NewProperties() + props.Set("active", "yes") + props.Set("removed", "no") + props.Delete("removed") + return props + }, + check: func(t *testing.T, result string) { + assert.Equal(t, "{removed}", result) + assert.NotContains(t, result, "active") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + props := tc.setup() + result := pgsql.DeletedPropertiesToString(props) + tc.check(t, result) + }) + } +} diff --git a/drivers/pg/query/format_test.go b/drivers/pg/query/format_test.go new file mode 100644 index 0000000..2092182 --- /dev/null +++ b/drivers/pg/query/format_test.go @@ -0,0 +1,93 @@ +package query_test + +import ( + "strings" + "testing" + + "github.com/specterops/dawgs/drivers/pg/model" + query "github.com/specterops/dawgs/drivers/pg/query" + "github.com/stretchr/testify/assert" +) + +func generateTestGraphTarget(nodePartitionName string) model.Graph { + return model.Graph{ + Partitions: model.GraphPartitions{ + Node: model.NewGraphPartition(nodePartitionName), + }, + } +} + +func TestFormatNodesUpdate(t *testing.T) { + t.Parallel() + + var ( + partitionName = "node_1" + expected = strings.Join([]string{ + "update node_1 as n ", + "set ", + " kind_ids = uniq(sort(kind_ids - u.deleted_kinds || u.added_kinds)), ", + " properties = n.properties - u.deleted_properties || u.properties ", + "from ", + " (select ", + " unnest($1::text[])::int8 as id, ", + " unnest($2::text[])::int2[] as added_kinds, ", + " unnest($3::text[])::int2[] as deleted_kinds, ", + " unnest($4::jsonb[]) as properties, ", + " unnest($5::text[])::text[] as deleted_properties) as u ", + "where n.id = u.id; ", + }, "") + result = query.FormatNodesUpdate(generateTestGraphTarget(partitionName)) + ) + + assert.Equal(t, expected, result) +} + +func TestFormatCreateNodeUpdateStagingTable(t *testing.T) { + t.Parallel() + + var ( + tableName = "my_staging_table" + expected = strings.Join([]string{ + "create temp table if not exists my_staging_table (", + "id bigint, ", + "added_kinds text, ", + "deleted_kinds text, ", + "properties text, ", + "deleted_props text", + ") on commit drop;", + }, "") + result = query.FormatCreateNodeUpdateStagingTable(tableName) + ) + + assert.Equal(t, expected, result) +} + +func TestFormatMergeNodeLargeUpdate(t *testing.T) { + t.Parallel() + + var ( + partitionName = "node_part_1" + stagingTable = "my_staging" + expected = strings.Join([]string{ + "merge into node_part_1 as n ", + "using my_staging as u on n.id = u.id ", + "when matched then update set ", + "kind_ids = uniq(sort(n.kind_ids - u.deleted_kinds::int2[] || u.added_kinds::int2[])), ", + "properties = n.properties - u.deleted_props::text[] || u.properties::jsonb;", + }, "") + result = query.FormatMergeNodeLargeUpdate(generateTestGraphTarget(partitionName), stagingTable) + ) + + assert.Equal(t, expected, result) +} + +func TestNodeUpdateStagingColumns(t *testing.T) { + t.Parallel() + + var ( + // Note: order is important + expected = []string{"id", "added_kinds", "deleted_kinds", "properties", "deleted_props"} + ) + + assert.Equal(t, expected, query.NodeUpdateStagingColumns) +} diff --git a/graph/graph_test.go b/graph/graph_test.go new file mode 100644 index 0000000..b9691f0 --- /dev/null +++ b/graph/graph_test.go @@ -0,0 +1,43 @@ +package graph + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithBatchSize(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + batchSize int + expectedSize int + }{ + { + name: "sets batch size to provided value", + batchSize: 500, + expectedSize: 500, + }, + { + name: "sets batch size to zero", + batchSize: 0, + expectedSize: 0, + }, + { + name: "sets batch size to large value", + batchSize: 1_000_000, + expectedSize: 1_000_000, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + config := &BatchConfig{} + opt := WithBatchSize(tc.batchSize) + opt(config) + assert.Equal(t, tc.expectedSize, config.BatchSize) + }) + } +} diff --git a/graph/node_test.go b/graph/node_test.go index 5fccad3..73c25f5 100644 --- a/graph/node_test.go +++ b/graph/node_test.go @@ -4,9 +4,74 @@ import ( "testing" "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func Test_StripAllPropertiesExcept(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func() *graph.Node + except []string + assertions func(t *testing.T, node *graph.Node) + }{ + { + name: "keeps only specified properties that exist", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Set("keep1", "value1") + node.Properties.Set("keep2", "value2") + node.Properties.Set("remove", "value3") + return node + }, + except: []string{"keep1", "keep2"}, + assertions: func(t *testing.T, node *graph.Node) { + assert.True(t, node.Properties.Exists("keep1")) + assert.Equal(t, "value1", node.Properties.Get("keep1").Any()) + assert.True(t, node.Properties.Exists("keep2")) + assert.Equal(t, "value2", node.Properties.Get("keep2").Any()) + assert.False(t, node.Properties.Exists("remove")) + }, + }, + { + name: "non-existent except keys are silently skipped", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Set("keep", "value") + return node + }, + except: []string{"keep", "nonexistent"}, + assertions: func(t *testing.T, node *graph.Node) { + assert.True(t, node.Properties.Exists("keep")) + assert.Equal(t, 1, node.Properties.Len()) + }, + }, + { + name: "no args clears all properties", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Set("key", "value") + return node + }, + except: nil, + assertions: func(t *testing.T, node *graph.Node) { + assert.Equal(t, 0, node.Properties.Len()) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + node := tc.setup() + node.StripAllPropertiesExcept(tc.except...) + tc.assertions(t, node) + }) + } +} + func Test_NodeSizeOf(t *testing.T) { node := graph.Node{ID: graph.ID(1)} oldSize := int64(node.SizeOf()) From 482e1c45c71bcbc175e6ecf17707520ae1d1845b Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Wed, 4 Mar 2026 17:17:57 -0600 Subject: [PATCH 08/14] change StripAllPropertiesExcept to respect deleted properties; simplify implementation --- graph/node.go | 16 +++++++--------- graph/node_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/graph/node.go b/graph/node.go index 61f8ca9..ae44ded 100644 --- a/graph/node.go +++ b/graph/node.go @@ -113,21 +113,19 @@ func (s *Node) MarshalJSON() ([]byte, error) { } func (s *Node) StripAllPropertiesExcept(except ...string) { - tmp := make([]any, 0, len(except)) - found := make([]string, 0, len(except)) + newProperties := NewProperties() for _, exclusion := range except { if s.Properties.Exists(exclusion) { - found = append(found, exclusion) - tmp = append(tmp, s.Properties.Get(exclusion).Any()) + newProperties.Set(exclusion, s.Properties.Get(exclusion).Any()) } - } - - s.Properties = NewProperties() - for i, key := range found { - s.Properties.Set(key, tmp[i]) + if _, present := s.Properties.Deleted[exclusion]; present { + newProperties.Delete(exclusion) + } } + + s.Properties = newProperties } // NodeSet is a mapped index of Node instances and their ID fields. diff --git a/graph/node_test.go b/graph/node_test.go index 73c25f5..a617364 100644 --- a/graph/node_test.go +++ b/graph/node_test.go @@ -1,6 +1,7 @@ package graph_test import ( + "slices" "testing" "github.com/specterops/dawgs/graph" @@ -60,6 +61,42 @@ func Test_StripAllPropertiesExcept(t *testing.T) { assert.Equal(t, 0, node.Properties.Len()) }, }, + { + name: "respects deleted properties", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Delete("key") + return node + }, + except: []string{"key"}, + assertions: func(t *testing.T, node *graph.Node) { + deleted := node.Properties.DeletedProperties() + assert.Equal(t, 1, len(deleted)) + assert.True(t, slices.Contains(deleted, "key")) + }, + }, + { + name: "respects deleted properties 2", + setup: func() *graph.Node { + node := graph.NewNode(graph.ID(1), graph.NewProperties()) + node.Properties.Set("hello", "value1") + node.Properties.Set("loremipsum", "value2") + node.Properties.Delete("key1") + node.Properties.Delete("key2") + node.Properties.Delete("key3") + return node + }, + except: []string{"hello", "world", "key1", "key2"}, + assertions: func(t *testing.T, node *graph.Node) { + deleted := node.Properties.DeletedProperties() + assert.Equal(t, 2, len(deleted)) + assert.True(t, slices.Contains(deleted, "key1")) + assert.True(t, slices.Contains(deleted, "key2")) + assert.False(t, node.Properties.Exists("loremipsum")) + assert.True(t, node.Properties.Exists("hello")) + assert.Equal(t, "value1", node.Properties.Get("hello").Any()) + }, + }, } for _, tc := range tests { From 2da2db990320710c551ee42a5a022d99548a59ce Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Wed, 4 Mar 2026 17:19:05 -0600 Subject: [PATCH 09/14] update comment --- graph/node.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/graph/node.go b/graph/node.go index ae44ded..1f9f6e3 100644 --- a/graph/node.go +++ b/graph/node.go @@ -112,6 +112,10 @@ func (s *Node) MarshalJSON() ([]byte, error) { return json.Marshal(jsonNode) } +/* +StripAllPropertiesExcept removes all properties from the node except for the ones specified in the except list. +Deleted properties are also removed from the node, except for the ones specified in the except list. +*/ func (s *Node) StripAllPropertiesExcept(except ...string) { newProperties := NewProperties() From ea70c59c5460035ea7f41b7048da85c7def1bd13 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Wed, 4 Mar 2026 17:26:44 -0600 Subject: [PATCH 10/14] update comment --- graph/node.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/graph/node.go b/graph/node.go index 1f9f6e3..de40c8a 100644 --- a/graph/node.go +++ b/graph/node.go @@ -115,6 +115,8 @@ func (s *Node) MarshalJSON() ([]byte, error) { /* StripAllPropertiesExcept removes all properties from the node except for the ones specified in the except list. Deleted properties are also removed from the node, except for the ones specified in the except list. +The use case for this function is if you have fully hydrated nodes in memory, +but only want to update a few properties it is most efficient to strip all properties except for the ones you want to update. */ func (s *Node) StripAllPropertiesExcept(except ...string) { newProperties := NewProperties() From 14010b8a42a6a36079973003bcdab51f23817024 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 12 Mar 2026 11:36:10 -0700 Subject: [PATCH 11/14] feat: neo4j batched node updates --- drivers/neo4j/batch.go | 31 ++++++++-- drivers/neo4j/cypher.go | 115 ++++++++++++++++++++++++++++++++++- drivers/neo4j/transaction.go | 17 +++++- 3 files changed, 156 insertions(+), 7 deletions(-) diff --git a/drivers/neo4j/batch.go b/drivers/neo4j/batch.go index 276eecf..420cb91 100644 --- a/drivers/neo4j/batch.go +++ b/drivers/neo4j/batch.go @@ -20,6 +20,7 @@ type batchTransaction struct { innerTx *neo4jTransaction nodeDeletionBuffer []graph.ID relationshipDeletionBuffer []graph.ID + nodeUpdateBuffer []*graph.Node nodeUpdateByBuffer []graph.NodeUpdate relationshipCreateBuffer []createRelationshipByIDs relationshipUpdateByBuffer []graph.RelationshipUpdate @@ -48,15 +49,22 @@ func (s *batchTransaction) Relationships() graph.RelationshipQuery { } func (s *batchTransaction) UpdateNodeBy(update graph.NodeUpdate) error { - if s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update); len(s.nodeUpdateByBuffer) >= s.batchWriteSize { - return s.flushNodeUpdates() + s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update) + + if len(s.nodeUpdateByBuffer) >= s.batchWriteSize { + return s.flushNodeUpdateByBuffer() } return nil } func (s *batchTransaction) UpdateNodes(nodes []*graph.Node) error { - panic("unimplemented") + s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, nodes...) + + if len(s.nodeUpdateBuffer) > s.batchWriteSize { + return s.flushNodeUpdateBuffer() + } + return nil } @@ -78,7 +86,13 @@ func (s *batchTransaction) DeleteRelationships(ids []graph.ID) error { func (s *batchTransaction) Commit() error { if len(s.nodeUpdateByBuffer) > 0 { - if err := s.flushNodeUpdates(); err != nil { + if err := s.flushNodeUpdateByBuffer(); err != nil { + return err + } + } + + if len(s.nodeUpdateBuffer) > 0 { + if err := s.flushNodeUpdateBuffer(); err != nil { return err } } @@ -226,13 +240,20 @@ func (s *batchTransaction) flushRelationshipDeletions() error { return s.DeleteRelationships(buffer) } -func (s *batchTransaction) flushNodeUpdates() error { +func (s *batchTransaction) flushNodeUpdateByBuffer() error { buffer := s.nodeUpdateByBuffer s.nodeUpdateByBuffer = s.nodeUpdateByBuffer[:0] return s.innerTx.updateNodesBy(buffer...) } +func (s *batchTransaction) flushNodeUpdateBuffer() error { + buffer := s.nodeUpdateBuffer + s.nodeUpdateBuffer = s.nodeUpdateBuffer[:0] + + return s.innerTx.updateNodeBatch(buffer) +} + func (s *batchTransaction) flushRelationshipUpdates() error { buffer := s.relationshipUpdateByBuffer s.relationshipUpdateByBuffer = s.relationshipUpdateByBuffer[:0] diff --git a/drivers/neo4j/cypher.go b/drivers/neo4j/cypher.go index 378fc24..158ad24 100644 --- a/drivers/neo4j/cypher.go +++ b/drivers/neo4j/cypher.go @@ -4,9 +4,12 @@ import ( "bytes" "fmt" "log/slog" + "maps" + "slices" "sort" "strings" + "github.com/cespare/xxhash/v2" "github.com/specterops/dawgs/cypher/frontend" "github.com/specterops/dawgs/cypher/models/cypher/format" "github.com/specterops/dawgs/graph" @@ -228,7 +231,117 @@ func (s nodeUpdateByMap) add(update graph.NodeUpdate) { } } -func cypherBuildNodeUpdateQueryBatch(updates []graph.NodeUpdate) ([]string, []map[string]any) { +func nodeToNodeUpdateKey(digester *xxhash.Digest, node *graph.Node) uint64 { + digester.Reset() + + var ( + kindSet = map[string]struct{}{} + digestKinds = func() { + sortedKinds := make([]string, 0, len(kindSet)) + + for nextKind := range maps.Keys(kindSet) { + sortedKinds = append(sortedKinds, nextKind) + } + + slices.Sort(sortedKinds) + + for _, nextKind := range sortedKinds { + digester.WriteString(nextKind) + } + + clear(kindSet) + } + ) + + for _, addedKindStr := range node.AddedKinds.Strings() { + kindSet[addedKindStr] = struct{}{} + } + + digestKinds() + + for _, removedKindStr := range node.AddedKinds.Strings() { + kindSet[removedKindStr] = struct{}{} + } + + digestKinds() + + return digester.Sum64() +} + +type nodeUpdateBatch struct { + nodeKindsToAdd graph.Kinds + nodeKindsToRemove graph.Kinds + Parameters []map[string]any +} + +func cypherBuildNodeUpdateQueryBatch(updates []*graph.Node) ([]string, []map[string]any) { + var ( + queries []string + queryParameters []map[string]any + + output = strings.Builder{} + batchedUpdates = map[uint64]*nodeUpdateBatch{} + digester = xxhash.New() + ) + + for _, nodeToUpdate := range updates { + updateKey := nodeToNodeUpdateKey(digester, nodeToUpdate) + + if existingBatch, hasBatch := batchedUpdates[updateKey]; hasBatch { + existingBatch.Parameters = append(existingBatch.Parameters, map[string]any{ + "id": nodeToUpdate.ID, + "properties": nodeToUpdate.Properties, + }) + } else { + batchedUpdates[updateKey] = &nodeUpdateBatch{ + nodeKindsToAdd: nodeToUpdate.AddedKinds, + nodeKindsToRemove: nodeToUpdate.DeletedKinds, + Parameters: []map[string]any{{ + "id": nodeToUpdate.ID, + "properties": nodeToUpdate.Properties, + }}, + } + } + } + + for _, batch := range batchedUpdates { + output.WriteString("unwind $p as p update (n) where id(n) = p.id set n += p.properties") + + if len(batch.nodeKindsToAdd) > 0 { + for _, kindToAdd := range batch.nodeKindsToAdd { + output.WriteString(", n:") + output.WriteString(kindToAdd.String()) + } + } + + if len(batch.nodeKindsToRemove) > 0 { + output.WriteString(" remove ") + + for idx, kindToRemove := range batch.nodeKindsToRemove { + if idx > 0 { + output.WriteString(",") + } + + output.WriteString("n:") + output.WriteString(kindToRemove.String()) + } + } + + output.WriteString(";") + + // Write out the query to be run + queries = append(queries, output.String()) + queryParameters = append(queryParameters, map[string]any{ + "p": batch.Parameters, + }) + + output.Reset() + } + + return queries, queryParameters +} + +func cypherBuildNodeUpdateQueryByBatch(updates []graph.NodeUpdate) ([]string, []map[string]any) { var ( queries []string queryParameters []map[string]any diff --git a/drivers/neo4j/transaction.go b/drivers/neo4j/transaction.go index dd2f1a9..c5307c3 100644 --- a/drivers/neo4j/transaction.go +++ b/drivers/neo4j/transaction.go @@ -93,10 +93,25 @@ func (s *neo4jTransaction) UpdateRelationshipBy(update graph.RelationshipUpdate) return s.updateRelationshipsBy(update) } +func (s *neo4jTransaction) updateNodeBatch(batch []*graph.Node) error { + var ( + numUpdates = len(batch) + statements, queryParameterMaps = cypherBuildNodeUpdateQueryBatch(batch) + ) + + for parameterIdx, stmt := range statements { + if result := s.Raw(stmt, queryParameterMaps[parameterIdx]); result.Error() != nil { + return fmt.Errorf("update nodes by error on statement (%s): %s", stmt, result.Error()) + } + } + + return s.logWrites(numUpdates) +} + func (s *neo4jTransaction) updateNodesBy(updates ...graph.NodeUpdate) error { var ( numUpdates = len(updates) - statements, queryParameterMaps = cypherBuildNodeUpdateQueryBatch(updates) + statements, queryParameterMaps = cypherBuildNodeUpdateQueryByBatch(updates) ) for parameterIdx, stmt := range statements { From 4840ed8817866c38dc4867819623ca7be4694fc3 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Thu, 12 Mar 2026 15:35:21 -0500 Subject: [PATCH 12/14] check rows.Err() --- drivers/pg/batch.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index 37e2848..4eaeed2 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -339,7 +339,6 @@ func (s *batch) flushNodeUpsertBatch(updates *sql.NodeUpdateBatch) error { if rows, err := s.innerTransaction.conn.Query(s.ctx, query, parameters.Format(graphTarget)...); err != nil { return err } else { - // TODO: rows.Err() is never called, silently swallowing errors defer rows.Close() idFutureIndex := 0 @@ -351,6 +350,10 @@ func (s *batch) flushNodeUpsertBatch(updates *sql.NodeUpdateBatch) error { idFutureIndex++ } + + if rows.Err() != nil { + return rows.Err() + } } } From 60f7db708ab0cf89c6e25ffae524c10c8bf9d172 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Thu, 12 Mar 2026 15:37:25 -0500 Subject: [PATCH 13/14] update graph test package --- graph/graph_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/graph/graph_test.go b/graph/graph_test.go index b9691f0..5408400 100644 --- a/graph/graph_test.go +++ b/graph/graph_test.go @@ -1,8 +1,9 @@ -package graph +package graph_test import ( "testing" + "github.com/specterops/dawgs/graph" "github.com/stretchr/testify/assert" ) @@ -34,8 +35,8 @@ func TestWithBatchSize(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - config := &BatchConfig{} - opt := WithBatchSize(tc.batchSize) + config := &graph.BatchConfig{} + opt := graph.WithBatchSize(tc.batchSize) opt(config) assert.Equal(t, tc.expectedSize, config.BatchSize) }) From 9cb32252e2bebd3435f39914fd0ebaf0346409c0 Mon Sep 17 00:00:00 2001 From: Benjamin Sheth Date: Thu, 12 Mar 2026 15:44:51 -0500 Subject: [PATCH 14/14] update --- drivers/pg/batch.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index 4eaeed2..c37b110 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -351,8 +351,8 @@ func (s *batch) flushNodeUpsertBatch(updates *sql.NodeUpdateBatch) error { idFutureIndex++ } - if rows.Err() != nil { - return rows.Err() + if err := rows.Err(); err != nil { + return err } } }