Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cypher/models/pgsql/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgsql
import (
"bytes"
"encoding/json"
"strings"

"reflect"

Expand Down Expand Up @@ -55,6 +56,10 @@ func PropertiesToJSONB(properties *graph.Properties) (pgtype.JSONB, error) {
return MapStringAnyToJSONB(properties.MapOrEmpty())
}

// DeletedPropertiesToString renders the keys of a node's deleted properties
// as a Postgres-style array literal, e.g. "{key1,key2}".
func DeletedPropertiesToString(properties *graph.Properties) string {
	var builder strings.Builder

	builder.WriteString("{")

	for idx, deletedKey := range properties.DeletedProperties() {
		if idx > 0 {
			builder.WriteString(",")
		}

		builder.WriteString(deletedKey)
	}

	builder.WriteString("}")
	return builder.String()
}

func JSONBToProperties(jsonb pgtype.JSONB) (*graph.Properties, error) {
propertiesMap := make(map[string]any)

Expand Down
80 changes: 80 additions & 0 deletions cypher/models/pgsql/type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package pgsql_test

import (
"testing"

"github.com/specterops/dawgs/cypher/models/pgsql"
"github.com/specterops/dawgs/graph"
"github.com/stretchr/testify/assert"
)

// TestDeletedPropertiesToString exercises DeletedPropertiesToString across
// empty, single, multiple, and mixed deleted/non-deleted property sets.
func TestDeletedPropertiesToString(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name  string
		setup func() *graph.Properties
		check func(t *testing.T, result string)
	}

	testCases := []testCase{
		{
			name: "no deleted properties returns empty braces",
			setup: func() *graph.Properties {
				return graph.NewProperties()
			},
			check: func(t *testing.T, result string) {
				assert.Equal(t, "{}", result)
			},
		},
		{
			name: "single deleted property is wrapped in braces",
			setup: func() *graph.Properties {
				properties := graph.NewProperties()
				properties.Set("mykey", "myvalue")
				properties.Delete("mykey")

				return properties
			},
			check: func(t *testing.T, result string) {
				assert.Equal(t, "{mykey}", result)
			},
		},
		{
			name: "multiple deleted properties are all present in output",
			setup: func() *graph.Properties {
				properties := graph.NewProperties()
				properties.Set("alpha", 1)
				properties.Set("beta", 2)
				properties.Delete("alpha")
				properties.Delete("beta")

				return properties
			},
			check: func(t *testing.T, result string) {
				// Map iteration order is non-deterministic; accept either ordering.
				assert.True(t, result == "{alpha,beta}" || result == "{beta,alpha}",
					"unexpected result: %s", result)
			},
		},
		{
			name: "non-deleted properties are not included",
			setup: func() *graph.Properties {
				properties := graph.NewProperties()
				properties.Set("active", "yes")
				properties.Set("removed", "no")
				properties.Delete("removed")

				return properties
			},
			check: func(t *testing.T, result string) {
				assert.Equal(t, "{removed}", result)
				assert.NotContains(t, result, "active")
			},
		},
	}

	for _, nextCase := range testCases {
		t.Run(nextCase.name, func(t *testing.T) {
			t.Parallel()

			nextCase.check(t, pgsql.DeletedPropertiesToString(nextCase.setup()))
		})
	}
}
34 changes: 30 additions & 4 deletions drivers/neo4j/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type batchTransaction struct {
innerTx *neo4jTransaction
nodeDeletionBuffer []graph.ID
relationshipDeletionBuffer []graph.ID
nodeUpdateBuffer []*graph.Node
nodeUpdateByBuffer []graph.NodeUpdate
relationshipCreateBuffer []createRelationshipByIDs
relationshipUpdateByBuffer []graph.RelationshipUpdate
Expand Down Expand Up @@ -48,8 +49,20 @@ func (s *batchTransaction) Relationships() graph.RelationshipQuery {
}

func (s *batchTransaction) UpdateNodeBy(update graph.NodeUpdate) error {
if s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update); len(s.nodeUpdateByBuffer) >= s.batchWriteSize {
return s.flushNodeUpdates()
s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update)

if len(s.nodeUpdateByBuffer) >= s.batchWriteSize {
return s.flushNodeUpdateByBuffer()
}

return nil
}

func (s *batchTransaction) UpdateNodes(nodes []*graph.Node) error {
s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, nodes...)

if len(s.nodeUpdateBuffer) > s.batchWriteSize {
return s.flushNodeUpdateBuffer()
}

return nil
Expand All @@ -73,7 +86,13 @@ func (s *batchTransaction) DeleteRelationships(ids []graph.ID) error {

func (s *batchTransaction) Commit() error {
if len(s.nodeUpdateByBuffer) > 0 {
if err := s.flushNodeUpdates(); err != nil {
if err := s.flushNodeUpdateByBuffer(); err != nil {
return err
}
}

if len(s.nodeUpdateBuffer) > 0 {
if err := s.flushNodeUpdateBuffer(); err != nil {
return err
}
}
Expand Down Expand Up @@ -221,13 +240,20 @@ func (s *batchTransaction) flushRelationshipDeletions() error {
return s.DeleteRelationships(buffer)
}

func (s *batchTransaction) flushNodeUpdates() error {
// flushNodeUpdateByBuffer drains the buffered NodeUpdate entries and applies
// them through the inner transaction, retaining the buffer's capacity.
func (s *batchTransaction) flushNodeUpdateByBuffer() error {
	pendingUpdates := s.nodeUpdateByBuffer
	s.nodeUpdateByBuffer = s.nodeUpdateByBuffer[:0]

	return s.innerTx.updateNodesBy(pendingUpdates...)
}

// flushNodeUpdateBuffer drains the buffered *graph.Node updates and applies
// them through the inner transaction, retaining the buffer's capacity.
func (s *batchTransaction) flushNodeUpdateBuffer() error {
	pendingNodes := s.nodeUpdateBuffer
	s.nodeUpdateBuffer = s.nodeUpdateBuffer[:0]

	return s.innerTx.updateNodeBatch(pendingNodes)
}

func (s *batchTransaction) flushRelationshipUpdates() error {
buffer := s.relationshipUpdateByBuffer
s.relationshipUpdateByBuffer = s.relationshipUpdateByBuffer[:0]
Expand Down
115 changes: 114 additions & 1 deletion drivers/neo4j/cypher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"bytes"
"fmt"
"log/slog"
"maps"
"slices"
"sort"
"strings"

"github.com/cespare/xxhash/v2"
"github.com/specterops/dawgs/cypher/frontend"
"github.com/specterops/dawgs/cypher/models/cypher/format"
"github.com/specterops/dawgs/graph"
Expand Down Expand Up @@ -228,7 +231,117 @@ func (s nodeUpdateByMap) add(update graph.NodeUpdate) {
}
}

func cypherBuildNodeUpdateQueryBatch(updates []graph.NodeUpdate) ([]string, []map[string]any) {
// nodeToNodeUpdateKey computes a stable hash over a node's added and removed
// kind sets. Nodes that hash to the same key share identical kind changes and
// can therefore be applied with a single batched cypher statement, so the hash
// must be order-insensitive within each set and must distinguish the added set
// from the removed set.
func nodeToNodeUpdateKey(digester *xxhash.Digest, node *graph.Node) uint64 {
	digester.Reset()

	// digestSortedKinds folds a kind set into the digest in sorted order so
	// that insertion order does not affect the resulting key. The slice is
	// cloned before sorting in case Strings() returns shared storage.
	digestSortedKinds := func(kinds graph.Kinds) {
		sortedKindStrs := slices.Clone(kinds.Strings())
		slices.Sort(sortedKindStrs)

		for _, kindStr := range sortedKindStrs {
			digester.WriteString(kindStr)
		}
	}

	digestSortedKinds(node.AddedKinds)

	// Separate the two sets so that, e.g., adding kind "A" does not hash
	// identically to removing kind "A".
	digester.Write([]byte{0})

	// Fix: the original ranged over node.AddedKinds a second time here (the
	// loop variable was even named removedKindStr), meaning DeletedKinds never
	// contributed to the key and nodes with different kind removals were
	// incorrectly batched together.
	digestSortedKinds(node.DeletedKinds)

	return digester.Sum64()
}

// nodeUpdateBatch groups node updates that share the same added/removed kind
// sets so the whole group can be applied with one cypher statement.
type nodeUpdateBatch struct {
	nodeKindsToAdd    graph.Kinds      // kinds added to every node in this batch
	nodeKindsToRemove graph.Kinds      // kinds removed from every node in this batch
	Parameters        []map[string]any // per-node {"id", "properties"} parameter maps for the unwind clause
}

// cypherBuildNodeUpdateQueryBatch builds one cypher statement (and its
// matching parameter map) per distinct combination of added/removed kinds in
// the given updates. Each statement unwinds a list of {id, properties}
// entries, matches each node by ID, merges its properties, and applies the
// batch's shared label additions and removals.
func cypherBuildNodeUpdateQueryBatch(updates []*graph.Node) ([]string, []map[string]any) {
	var (
		queries         []string
		queryParameters []map[string]any

		output         = strings.Builder{}
		batchedUpdates = map[uint64]*nodeUpdateBatch{}
		digester       = xxhash.New()
	)

	// Group updates by their kind-change signature.
	for _, nodeToUpdate := range updates {
		// Pass the raw property map as the query parameter; the driver
		// serializes plain values, not the *graph.Properties wrapper.
		nextParameters := map[string]any{
			"id":         nodeToUpdate.ID,
			"properties": nodeToUpdate.Properties.MapOrEmpty(),
		}

		updateKey := nodeToNodeUpdateKey(digester, nodeToUpdate)

		if existingBatch, hasBatch := batchedUpdates[updateKey]; hasBatch {
			existingBatch.Parameters = append(existingBatch.Parameters, nextParameters)
		} else {
			batchedUpdates[updateKey] = &nodeUpdateBatch{
				nodeKindsToAdd:    nodeToUpdate.AddedKinds,
				nodeKindsToRemove: nodeToUpdate.DeletedKinds,
				Parameters:        []map[string]any{nextParameters},
			}
		}
	}

	// Render one statement per batch.
	for _, batch := range batchedUpdates {
		// Fix: the original emitted "update (n)" — "update" is not a cypher
		// keyword; matching the node by its ID requires a MATCH clause.
		output.WriteString("unwind $p as p match (n) where id(n) = p.id set n += p.properties")

		// Label additions continue the same SET clause: "set n += ..., n:Kind".
		for _, kindToAdd := range batch.nodeKindsToAdd {
			output.WriteString(", n:")
			output.WriteString(kindToAdd.String())
		}

		if len(batch.nodeKindsToRemove) > 0 {
			output.WriteString(" remove ")

			for idx, kindToRemove := range batch.nodeKindsToRemove {
				if idx > 0 {
					output.WriteString(",")
				}

				output.WriteString("n:")
				output.WriteString(kindToRemove.String())
			}
		}

		output.WriteString(";")

		// Write out the query to be run
		queries = append(queries, output.String())
		queryParameters = append(queryParameters, map[string]any{
			"p": batch.Parameters,
		})

		output.Reset()
	}

	return queries, queryParameters
}

func cypherBuildNodeUpdateQueryByBatch(updates []graph.NodeUpdate) ([]string, []map[string]any) {
var (
queries []string
queryParameters []map[string]any
Expand Down
12 changes: 10 additions & 2 deletions drivers/neo4j/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,29 @@ func (s *driver) SetWriteFlushSize(size int) {
s.writeFlushSize = size
}

func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error {
func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate, options ...graph.BatchOption) error {
// Attempt to acquire a connection slot or wait for a bit until one becomes available
if !s.limiter.Acquire(ctx) {
return graph.ErrContextTimedOut
} else {
defer s.limiter.Release()
}

config := &graph.BatchConfig{
BatchSize: s.batchWriteSize,
}

for _, opt := range options {
opt(config)
}

var (
cfg = graph.TransactionConfig{
Timeout: s.defaultTransactionTimeout,
}

session = s.driver.NewSession(writeCfg())
batch = newBatchOperation(ctx, session, cfg, s.writeFlushSize, s.batchWriteSize, s.graphQueryMemoryLimit)
batch = newBatchOperation(ctx, session, cfg, s.writeFlushSize, config.BatchSize, s.graphQueryMemoryLimit)
)

defer session.Close()
Expand Down
17 changes: 16 additions & 1 deletion drivers/neo4j/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,25 @@ func (s *neo4jTransaction) UpdateRelationshipBy(update graph.RelationshipUpdate)
return s.updateRelationshipsBy(update)
}

// updateNodeBatch applies a batch of node updates, issuing one statement per
// distinct added/removed kind combination, and records the write count.
func (s *neo4jTransaction) updateNodeBatch(batch []*graph.Node) error {
	statements, queryParameterMaps := cypherBuildNodeUpdateQueryBatch(batch)

	for parameterIdx, stmt := range statements {
		if result := s.Raw(stmt, queryParameterMaps[parameterIdx]); result.Error() != nil {
			// Fix: the original message read "update nodes by error" — copied
			// from updateNodesBy — and used %s, which breaks errors.Is/As
			// unwrapping of the underlying driver error.
			return fmt.Errorf("update node batch error on statement (%s): %w", stmt, result.Error())
		}
	}

	return s.logWrites(len(batch))
}

func (s *neo4jTransaction) updateNodesBy(updates ...graph.NodeUpdate) error {
var (
numUpdates = len(updates)
statements, queryParameterMaps = cypherBuildNodeUpdateQueryBatch(updates)
statements, queryParameterMaps = cypherBuildNodeUpdateQueryByBatch(updates)
)

for parameterIdx, stmt := range statements {
Expand Down
Loading
Loading