Skip to content

Commit 55359c9

Browse files
committed
feat: add mutations support
1 parent 6667833 commit 55359c9

File tree

10 files changed

+6336
-6
lines changed

10 files changed

+6336
-6
lines changed

aborted_transactions_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,51 @@ func TestCommitAborted(t *testing.T) {
6060
}
6161
}
6262

63+
func TestCommitWithMutationsAborted(t *testing.T) {
64+
t.Parallel()
65+
66+
db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=1;maxSessions=1")
67+
defer teardown()
68+
ctx := context.Background()
69+
70+
conn, err := db.Conn(ctx)
71+
if err != nil {
72+
t.Fatalf("failed to open connection: %v", err)
73+
}
74+
defer func() { _ = conn.Close() }()
75+
76+
tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
77+
if err != nil {
78+
t.Fatalf("begin failed: %v", err)
79+
}
80+
if err := conn.Raw(func(driverConn interface{}) error {
81+
spannerConn, _ := driverConn.(SpannerConn)
82+
mutation := spanner.Insert("foo", []string{}, []interface{}{})
83+
return spannerConn.BufferWrite([]*spanner.Mutation{mutation})
84+
}); err != nil {
85+
t.Fatalf("failed to buffer mutations: %v", err)
86+
}
87+
// Abort the transaction on the first commit attempt.
88+
server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{
89+
Errors: []error{status.Error(codes.Aborted, "Aborted")},
90+
})
91+
err = tx.Commit()
92+
if err != nil {
93+
t.Fatalf("commit failed: %v", err)
94+
}
95+
reqs := drainRequestsFromServer(server.TestSpanner)
96+
commitReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.CommitRequest{}))
97+
if g, w := len(commitReqs), 2; g != w {
98+
t.Fatalf("commit request count mismatch\nGot: %v\nWant: %v", g, w)
99+
}
100+
for _, req := range commitReqs {
101+
commitReq := req.(*sppb.CommitRequest)
102+
if g, w := len(commitReq.Mutations), 1; g != w {
103+
t.Fatalf("mutation count mismatch\n Got: %v\nWant: %v", g, w)
104+
}
105+
}
106+
}
107+
63108
func TestCommitAbortedWithInternalRetriesDisabled(t *testing.T) {
64109
t.Parallel()
65110

go.mod

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ go 1.24
44

55
toolchain go1.24.4
66

7+
replace cloud.google.com/go/spanner => /Users/loite/GolandProjects/google-cloud-go/spanner
8+
79
require (
810
cloud.google.com/go v0.121.2
911
cloud.google.com/go/longrunning v0.6.7
@@ -26,7 +28,7 @@ require (
2628
cloud.google.com/go/compute/metadata v0.7.0 // indirect
2729
cloud.google.com/go/iam v1.5.2 // indirect
2830
cloud.google.com/go/monitoring v1.24.2 // indirect
29-
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.2 // indirect
31+
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.3 // indirect
3032
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 // indirect
3133
github.com/cespare/xxhash/v2 v2.3.0 // indirect
3234
github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f // indirect
@@ -44,7 +46,7 @@ require (
4446
github.com/zeebo/errs v1.4.0 // indirect
4547
go.opencensus.io v0.24.0 // indirect
4648
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
47-
go.opentelemetry.io/contrib/detectors/gcp v1.35.0 // indirect
49+
go.opentelemetry.io/contrib/detectors/gcp v1.36.0 // indirect
4850
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect
4951
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
5052
go.opentelemetry.io/otel v1.36.0 // indirect
@@ -60,5 +62,5 @@ require (
6062
golang.org/x/text v0.26.0 // indirect
6163
golang.org/x/time v0.12.0 // indirect
6264
google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2 // indirect
63-
google.golang.org/genproto/googleapis/api v0.0.0-20250505200425-f936aa4a68b2 // indirect
65+
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect
6466
)

go.sum

Lines changed: 2971 additions & 0 deletions
Large diffs are not rendered by default.

spannerlib/exported/connection.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ import (
77
"fmt"
88
"sync"
99
"sync/atomic"
10+
"time"
1011

1112
"cloud.google.com/go/spanner"
1213
"cloud.google.com/go/spanner/apiv1/spannerpb"
1314
spannerdriver "github.com/googleapis/go-sql-spanner"
1415
"google.golang.org/grpc/codes"
1516
"google.golang.org/grpc/status"
1617
"google.golang.org/protobuf/proto"
18+
"google.golang.org/protobuf/types/known/timestamppb"
1719
"spannerlib/backend"
1820
)
1921

@@ -25,6 +27,18 @@ func CloseConnection(poolId, connId int64) *Message {
2527
return conn.close()
2628
}
2729

30+
func Apply(poolId, connId int64, mutationBytes []byte) *Message {
31+
mutations := spannerpb.BatchWriteRequest_MutationGroup{}
32+
if err := proto.Unmarshal(mutationBytes, &mutations); err != nil {
33+
return errMessage(err)
34+
}
35+
conn, err := findConnection(poolId, connId)
36+
if err != nil {
37+
return errMessage(err)
38+
}
39+
return conn.apply(&mutations)
40+
}
41+
2842
func BeginTransaction(poolId, connId int64, txOptsBytes []byte) *Message {
2943
txOpts := spannerpb.TransactionOptions{}
3044
if err := proto.Unmarshal(txOptsBytes, &txOpts); err != nil {
@@ -94,6 +108,34 @@ func (conn *Connection) close() *Message {
94108
return &Message{}
95109
}
96110

111+
func (conn *Connection) apply(mutation *spannerpb.BatchWriteRequest_MutationGroup) *Message {
112+
ctx := context.Background()
113+
mutations := make([]*spanner.Mutation, 0, len(mutation.Mutations))
114+
for _, m := range mutation.Mutations {
115+
spannerMutation, err := spanner.WrapMutation(m)
116+
if err != nil {
117+
return errMessage(err)
118+
}
119+
mutations = append(mutations, spannerMutation)
120+
}
121+
var commitTimestamp time.Time
122+
if err := conn.backend.Conn.Raw(func(driverConn any) (err error) {
123+
spannerConn, _ := driverConn.(spannerdriver.SpannerConn)
124+
commitTimestamp, err = spannerConn.Apply(ctx, mutations)
125+
return err
126+
}); err != nil {
127+
return errMessage(err)
128+
}
129+
response := spannerpb.CommitResponse{
130+
CommitTimestamp: timestamppb.New(commitTimestamp),
131+
}
132+
res, err := proto.Marshal(&response)
133+
if err != nil {
134+
return errMessage(err)
135+
}
136+
return &Message{Res: res}
137+
}
138+
97139
func (conn *Connection) BeginTransaction(txOpts *spannerpb.TransactionOptions) *Message {
98140
var tx *sql.Tx
99141
var err error

spannerlib/exported/mock_server_test.go

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,243 @@ func TestDisableInternalRetries(t *testing.T) {
177177
CloseRows(pool.ObjectId, conn.ObjectId, results.ObjectId)
178178
}
179179

180+
func TestApply(t *testing.T) {
181+
t.Parallel()
182+
183+
dsn, server, teardown := setupTestDBConnection(t)
184+
defer teardown()
185+
186+
pool := CreatePool(dsn)
187+
defer ClosePool(pool.ObjectId)
188+
conn := CreateConnection(pool.ObjectId)
189+
defer CloseConnection(pool.ObjectId, conn.ObjectId)
190+
191+
mutations := sppb.BatchWriteRequest_MutationGroup{
192+
Mutations: []*sppb.Mutation{
193+
{Operation: &sppb.Mutation_Insert{Insert: &sppb.Mutation_Write{
194+
Table: "foo",
195+
Columns: []string{"id", "value"},
196+
Values: []*structpb.ListValue{
197+
{Values: []*structpb.Value{
198+
{Kind: &structpb.Value_StringValue{StringValue: "1"}},
199+
{Kind: &structpb.Value_StringValue{StringValue: "One"}},
200+
}},
201+
{Values: []*structpb.Value{
202+
{Kind: &structpb.Value_StringValue{StringValue: "2"}},
203+
{Kind: &structpb.Value_StringValue{StringValue: "Two"}},
204+
}},
205+
},
206+
}}},
207+
},
208+
}
209+
mutationsBytes, _ := proto.Marshal(&mutations)
210+
response := Apply(pool.ObjectId, conn.ObjectId, mutationsBytes)
211+
if response.Code != 0 {
212+
t.Fatalf("failed to apply mutations: %v", response.Code)
213+
}
214+
commitResponse := sppb.CommitResponse{}
215+
_ = proto.Unmarshal(response.Res, &commitResponse)
216+
if commitResponse.CommitTimestamp == nil {
217+
t.Fatal("commit timestamp missing")
218+
}
219+
220+
requests := drainRequestsFromServer(server.TestSpanner)
221+
beginRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{}))
222+
if g, w := len(beginRequests), 1; g != w {
223+
t.Fatalf("begin requests count mismatch\nGot: %v\nWant: %v", g, w)
224+
}
225+
req := beginRequests[0].(*sppb.BeginTransactionRequest)
226+
if req.Options == nil {
227+
t.Fatalf("missing tx opts")
228+
}
229+
if req.Options.GetReadWrite() == nil {
230+
t.Fatalf("missing tx read write")
231+
}
232+
commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
233+
if g, w := len(commitRequests), 1; g != w {
234+
t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
235+
}
236+
commitReq := commitRequests[0].(*sppb.CommitRequest)
237+
if g, w := len(commitReq.Mutations), 1; g != w {
238+
t.Fatalf("mutation count mismatch\n Got: %v\nWant: %v", g, w)
239+
}
240+
if g, w := len(commitReq.Mutations[0].GetInsert().Values), 2; g != w {
241+
t.Fatalf("mutation values count mismatch\n Got: %v\nWant: %v", g, w)
242+
}
243+
}
244+
245+
func TestBufferWrite(t *testing.T) {
246+
t.Parallel()
247+
248+
dsn, server, teardown := setupTestDBConnection(t)
249+
defer teardown()
250+
251+
pool := CreatePool(dsn)
252+
defer ClosePool(pool.ObjectId)
253+
conn := CreateConnection(pool.ObjectId)
254+
defer CloseConnection(pool.ObjectId, conn.ObjectId)
255+
256+
txOpts := &sppb.TransactionOptions{}
257+
txOptsBytes, _ := proto.Marshal(txOpts)
258+
tx := BeginTransaction(pool.ObjectId, conn.ObjectId, txOptsBytes)
259+
260+
mutations := sppb.BatchWriteRequest_MutationGroup{
261+
Mutations: []*sppb.Mutation{
262+
{Operation: &sppb.Mutation_Insert{Insert: &sppb.Mutation_Write{
263+
Table: "foo",
264+
Columns: []string{"id", "value"},
265+
Values: []*structpb.ListValue{
266+
{Values: []*structpb.Value{
267+
{Kind: &structpb.Value_StringValue{StringValue: "1"}},
268+
{Kind: &structpb.Value_StringValue{StringValue: "One"}},
269+
}},
270+
{Values: []*structpb.Value{
271+
{Kind: &structpb.Value_StringValue{StringValue: "2"}},
272+
{Kind: &structpb.Value_StringValue{StringValue: "Two"}},
273+
}},
274+
},
275+
}}},
276+
},
277+
}
278+
mutationsBytes, _ := proto.Marshal(&mutations)
279+
response := BufferWrite(pool.ObjectId, conn.ObjectId, tx.ObjectId, mutationsBytes)
280+
if response.Code != 0 {
281+
t.Fatalf("failed to apply mutations: %v", response.Code)
282+
}
283+
if response.Length() > 0 {
284+
t.Fatal("response length mismatch")
285+
}
286+
287+
requests := drainRequestsFromServer(server.TestSpanner)
288+
beginRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{}))
289+
if g, w := len(beginRequests), 1; g != w {
290+
t.Fatalf("begin requests count mismatch\nGot: %v\nWant: %v", g, w)
291+
}
292+
req := beginRequests[0].(*sppb.BeginTransactionRequest)
293+
if req.Options == nil {
294+
t.Fatalf("missing tx opts")
295+
}
296+
if req.Options.GetReadWrite() == nil {
297+
t.Fatalf("missing tx read write")
298+
}
299+
300+
// There should not be any commit requests yet.
301+
commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
302+
if g, w := len(commitRequests), 0; g != w {
303+
t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
304+
}
305+
306+
// Commit the transaction with the mutation.
307+
res := Commit(pool.ObjectId, conn.ObjectId, tx.ObjectId)
308+
if res.Code != 0 {
309+
t.Fatalf("failed to commit: %v", res.Code)
310+
}
311+
312+
// Verify that we have a commit request on the server.
313+
requests = drainRequestsFromServer(server.TestSpanner)
314+
commitRequests = requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
315+
if g, w := len(commitRequests), 1; g != w {
316+
t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
317+
}
318+
commitReq := commitRequests[0].(*sppb.CommitRequest)
319+
if g, w := len(commitReq.Mutations), 1; g != w {
320+
t.Fatalf("mutation count mismatch\n Got: %v\nWant: %v", g, w)
321+
}
322+
if g, w := len(commitReq.Mutations[0].GetInsert().Values), 2; g != w {
323+
t.Fatalf("mutation values count mismatch\n Got: %v\nWant: %v", g, w)
324+
}
325+
326+
}
327+
328+
func TestBufferWrite_RetryAborted(t *testing.T) {
329+
t.Parallel()
330+
331+
dsn, server, teardown := setupTestDBConnection(t)
332+
defer teardown()
333+
334+
pool := CreatePool(dsn)
335+
defer ClosePool(pool.ObjectId)
336+
conn := CreateConnection(pool.ObjectId)
337+
defer CloseConnection(pool.ObjectId, conn.ObjectId)
338+
339+
txOpts := &sppb.TransactionOptions{}
340+
txOptsBytes, _ := proto.Marshal(txOpts)
341+
tx := BeginTransaction(pool.ObjectId, conn.ObjectId, txOptsBytes)
342+
343+
mutations := sppb.BatchWriteRequest_MutationGroup{
344+
Mutations: []*sppb.Mutation{
345+
{Operation: &sppb.Mutation_Insert{Insert: &sppb.Mutation_Write{
346+
Table: "foo",
347+
Columns: []string{"id", "value"},
348+
Values: []*structpb.ListValue{
349+
{Values: []*structpb.Value{
350+
{Kind: &structpb.Value_StringValue{StringValue: "1"}},
351+
{Kind: &structpb.Value_StringValue{StringValue: "One"}},
352+
}},
353+
{Values: []*structpb.Value{
354+
{Kind: &structpb.Value_StringValue{StringValue: "2"}},
355+
{Kind: &structpb.Value_StringValue{StringValue: "Two"}},
356+
}},
357+
},
358+
}}},
359+
},
360+
}
361+
mutationsBytes, _ := proto.Marshal(&mutations)
362+
response := BufferWrite(pool.ObjectId, conn.ObjectId, tx.ObjectId, mutationsBytes)
363+
if response.Code != 0 {
364+
t.Fatalf("failed to apply mutations: %v", response.Code)
365+
}
366+
if response.Length() > 0 {
367+
t.Fatal("response length mismatch")
368+
}
369+
370+
requests := drainRequestsFromServer(server.TestSpanner)
371+
beginRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{}))
372+
if g, w := len(beginRequests), 1; g != w {
373+
t.Fatalf("begin requests count mismatch\nGot: %v\nWant: %v", g, w)
374+
}
375+
req := beginRequests[0].(*sppb.BeginTransactionRequest)
376+
if req.Options == nil {
377+
t.Fatalf("missing tx opts")
378+
}
379+
if req.Options.GetReadWrite() == nil {
380+
t.Fatalf("missing tx read write")
381+
}
382+
383+
// There should not be any commit requests yet.
384+
commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
385+
if g, w := len(commitRequests), 0; g != w {
386+
t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
387+
}
388+
389+
// Instruct the mock server to abort the transaction.
390+
server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{
391+
Errors: []error{status.Error(codes.Aborted, "Aborted")},
392+
})
393+
394+
// Commit the transaction with the mutation.
395+
res := Commit(pool.ObjectId, conn.ObjectId, tx.ObjectId)
396+
if res.Code != 0 {
397+
t.Fatalf("failed to commit: %v", res.Code)
398+
}
399+
400+
// Verify that we have a commit request on the server.
401+
requests = drainRequestsFromServer(server.TestSpanner)
402+
commitRequests = requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
403+
if g, w := len(commitRequests), 2; g != w {
404+
t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
405+
}
406+
for _, req := range commitRequests {
407+
commitReq := req.(*sppb.CommitRequest)
408+
if g, w := len(commitReq.Mutations), 1; g != w {
409+
t.Fatalf("mutation count mismatch\n Got: %v\nWant: %v", g, w)
410+
}
411+
if g, w := len(commitReq.Mutations[0].GetInsert().Values), 2; g != w {
412+
t.Fatalf("mutation values count mismatch\n Got: %v\nWant: %v", g, w)
413+
}
414+
}
415+
}
416+
180417
func setupTestDBConnection(t *testing.T) (dsn string, server *testutil.MockedSpannerInMemTestServer, teardown func()) {
181418
return setupTestDBConnectionWithParams(t, "")
182419
}

0 commit comments

Comments
 (0)