From 0b09196d3a869ca684b7a7f8f4555a24408e38dc Mon Sep 17 00:00:00 2001 From: Nick Van Wiggeren Date: Tue, 16 Jun 2026 20:00:02 +0000 Subject: [PATCH 1/7] Improve large sync resumability --- .../server/handlers/integration_test.go | 2 +- cmd/internal/server/handlers/sync.go | 17 +- cmd/internal/server/handlers/sync_test.go | 161 ++++++++++- cmd/internal/server/handlers/test_types.go | 2 + cmd/internal/server/server.go | 11 +- cmd/internal/server/types.go | 2 +- lib/connect_client.go | 62 +++- lib/connect_client_test.go | 270 ++++++++++++++++++ lib/test_types.go | 4 + 9 files changed, 497 insertions(+), 34 deletions(-) diff --git a/cmd/internal/server/handlers/integration_test.go b/cmd/internal/server/handlers/integration_test.go index fea7ea9..905da74 100644 --- a/cmd/internal/server/handlers/integration_test.go +++ b/cmd/internal/server/handlers/integration_test.go @@ -724,7 +724,7 @@ func runIntegrationSyncWithError(t *testing.T, psc lib.PlanetScaleSource, tableN sender := &integrationSender{} logger := NewSchemaAwareSerializer(sender, "integration", psc.TreatTinyIntAsBoolean, sourceSchema.SchemaList, sourceSchema.EnumsAndSets) syncer := &Sync{} - if err := syncer.Handle(&psc, &connectClient, logger, state, selection); err != nil { + if err := syncer.Handle(context.Background(), &psc, &connectClient, logger, state, selection); err != nil { return sender, state, err } if checkpoint, ok := sender.latestState(t); ok { diff --git a/cmd/internal/server/handlers/sync.go b/cmd/internal/server/handlers/sync.go index 73a7b86..f96c707 100644 --- a/cmd/internal/server/handlers/sync.go +++ b/cmd/internal/server/handlers/sync.go @@ -15,7 +15,7 @@ import ( type Sync struct{} -func (s *Sync) Handle(psc *lib.PlanetScaleSource, db *lib.ConnectClient, logger Serializer, state *lib.SyncState, schema *fivetransdk.Selection_WithSchema) error { +func (s *Sync) Handle(ctx context.Context, psc *lib.PlanetScaleSource, db *lib.ConnectClient, logger Serializer, state *lib.SyncState, schema *fivetransdk.Selection_WithSchema) error { if state == nil { return status.Error(codes.Internal, "syncState cannot be nil") } @@ -23,7 +23,6 @@ func (s *Sync) Handle(psc *lib.PlanetScaleSource, db *lib.ConnectClient, logger if db == nil { return status.Error(codes.Internal, "database accessor has not been initialized") } - ctx := context.Background() for _, ks := range includedKeyspaces(schema) { for _, table := range includedTables(ks) { stateKey := ks.SchemaName + ":" + table.TableName @@ -39,8 +38,8 @@ func (s *Sync) Handle(psc *lib.PlanetScaleSource, db *lib.ConnectClient, logger return logger.Update(upd, ks, table) } - // First pass: check if all shards have empty positions - allShardsHaveEmptyPosition := true + // First pass: check if all shards are at the initial cursor. + allShardsHaveNoProgress := true hasShards := false for _, shardState := range streamState.Shards { @@ -50,15 +49,15 @@ func (s *Sync) Handle(psc *lib.PlanetScaleSource, db *lib.ConnectClient, logger return status.Error(codes.Internal, fmt.Sprintf("invalid cursor for stream %v, failed with [%v]", stateKey, err)) } - // Check if this shard has data (non-empty position) - if tc.Position != "" { - allShardsHaveEmptyPosition = false + // LastKnownPk is durable historical-copy progress even when Position is empty. + if tc.Position != "" || tc.LastKnownPk != nil { + allShardsHaveNoProgress = false break // No need to check remaining shards } } - // Call truncate before any Read operations if all shards are empty - if hasShards && allShardsHaveEmptyPosition { + // Call truncate before any Read operations only when no shard has durable progress. + if hasShards && allShardsHaveNoProgress { logger.Truncate(ks, table) } diff --git a/cmd/internal/server/handlers/sync_test.go b/cmd/internal/server/handlers/sync_test.go index 9678dc2..b1ae1e5 100644 --- a/cmd/internal/server/handlers/sync_test.go +++ b/cmd/internal/server/handlers/sync_test.go @@ -11,9 +11,13 @@ import ( "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" ) func TestCallsReadWithSelectedSchema(t *testing.T) { + type contextKey struct{} + ctx := context.WithValue(context.Background(), contextKey{}, "sync-request") psc := &lib.PlanetScaleSource{} tl := &testLogger{} schema := fivetransdk.SchemaSelection{ @@ -42,6 +46,7 @@ func TestCallsReadWithSelectedSchema(t *testing.T) { readFn := func(ctx context.Context, logger lib.DatabaseLogger, ps lib.PlanetScaleSource, tableName string, columns []string, tc *psdbconnect.TableCursor, onResult lib.OnResult, onCursor lib.OnCursor, onUpdate lib.OnUpdate, ) (*lib.SerializedCursor, error) { + assert.Equal(t, "sync-request", ctx.Value(contextKey{})) assert.Equal(t, "customers", tableName) return nil, nil } @@ -49,7 +54,7 @@ func TestCallsReadWithSelectedSchema(t *testing.T) { state, err := psc.GetInitialState("SalesDB", []string{"-"}) assert.NoError(t, err) db := lib.NewTestConnectClient(readFn) - err = sync.Handle(psc, &db, tl, &lib.SyncState{ + err = sync.Handle(ctx, psc, &db, tl, &lib.SyncState{ Keyspaces: map[string]lib.KeyspaceState{ "SalesDB": { Streams: map[string]lib.ShardStates{ @@ -94,7 +99,7 @@ func TestCallsTruncateOnInitialSync(t *testing.T) { assert.NoError(t, err) db := lib.NewTestConnectClient(readFn) - err = sync.Handle(psc, &db, tl, &lib.SyncState{ + err = sync.Handle(context.Background(), psc, &db, tl, &lib.SyncState{ Keyspaces: map[string]lib.KeyspaceState{ "SalesDB": { Streams: map[string]lib.ShardStates{ @@ -163,7 +168,7 @@ func TestCallsReadWithStartingGtids(t *testing.T) { assert.Equal(t, state.Shards["80-"], cursor) db := lib.NewTestConnectClient(readFn) - err = sync.Handle(psc, &db, tl, &lib.SyncState{ + err = sync.Handle(context.Background(), psc, &db, tl, &lib.SyncState{ Keyspaces: map[string]lib.KeyspaceState{ "SalesDB": { Streams: map[string]lib.ShardStates{ @@ -207,7 +212,7 @@ func TestVStreamSchemaIncompatibilityReturnsFailedPrecondition(t *testing.T) { state, err := psc.GetInitialState("SalesDB", []string{"-"}) assert.NoError(t, err) db := lib.NewTestConnectClient(readFn) - err = (&Sync{}).Handle(psc, &db, tl, &lib.SyncState{ + err = (&Sync{}).Handle(context.Background(), psc, &db, tl, &lib.SyncState{ Keyspaces: map[string]lib.KeyspaceState{ "SalesDB": { Streams: map[string]lib.ShardStates{ @@ -219,3 +224,151 @@ func TestVStreamSchemaIncompatibilityReturnsFailedPrecondition(t *testing.T) { assert.Error(t, err) assert.Equal(t, codes.FailedPrecondition, status.Code(err)) } + +func TestCheckpointsHistoricalCopyCursorFromRead(t *testing.T) { + psc := &lib.PlanetScaleSource{} + tl := &testLogger{} + schemaSelection := &fivetransdk.Selection_WithSchema{ + WithSchema: &fivetransdk.TablesWithSchema{ + Schemas: []*fivetransdk.SchemaSelection{ + { + SchemaName: "SalesDB", + Included: true, + Tables: []*fivetransdk.TableSelection{ + { + Included: true, + TableName: "customers", + }, + }, + }, + }, + }, + } + + copyCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "SalesDB", + LastKnownPk: testLastKnownPK("42"), + } + readFn := func(ctx context.Context, logger lib.DatabaseLogger, ps lib.PlanetScaleSource, tableName string, columns []string, + tc *psdbconnect.TableCursor, onResult lib.OnResult, onCursor lib.OnCursor, onUpdate lib.OnUpdate, + ) (*lib.SerializedCursor, error) { + return nil, onCursor(copyCursor) + } + + state, err := psc.GetInitialState("SalesDB", []string{"-"}) + assert.NoError(t, err) + db := lib.NewTestConnectClient(readFn) + err = (&Sync{}).Handle(context.Background(), psc, &db, tl, &lib.SyncState{ + Keyspaces: map[string]lib.KeyspaceState{ + "SalesDB": { + Streams: map[string]lib.ShardStates{ + "SalesDB:customers": state, + }, + }, + }, + }, schemaSelection) + assert.NoError(t, err) + + assert.True(t, stateLogContainsLastKnownPK(t, tl.states, "SalesDB", "SalesDB:customers", "-")) +} + +func TestDoesNotTruncateWhenStateHasHistoricalCopyProgress(t *testing.T) { + psc := &lib.PlanetScaleSource{} + tl := &testLogger{} + schemaSelection := &fivetransdk.Selection_WithSchema{ + WithSchema: &fivetransdk.TablesWithSchema{ + Schemas: []*fivetransdk.SchemaSelection{ + { + SchemaName: "SalesDB", + Included: true, + Tables: []*fivetransdk.TableSelection{ + { + Included: true, + TableName: "customers", + }, + }, + }, + }, + }, + } + + cursor, err := lib.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "SalesDB", + LastKnownPk: testLastKnownPK("42"), + }) + assert.NoError(t, err) + + readCalled := false + readFn := func(ctx context.Context, logger lib.DatabaseLogger, ps lib.PlanetScaleSource, tableName string, columns []string, + tc *psdbconnect.TableCursor, onResult lib.OnResult, onCursor lib.OnCursor, onUpdate lib.OnUpdate, + ) (*lib.SerializedCursor, error) { + readCalled = true + assert.Empty(t, tc.Position) + assert.NotNil(t, tc.LastKnownPk) + return nil, nil + } + + db := lib.NewTestConnectClient(readFn) + err = (&Sync{}).Handle(context.Background(), psc, &db, tl, &lib.SyncState{ + Keyspaces: map[string]lib.KeyspaceState{ + "SalesDB": { + Streams: map[string]lib.ShardStates{ + "SalesDB:customers": { + Shards: map[string]*lib.SerializedCursor{ + "-": cursor, + }, + }, + }, + }, + }, + }, schemaSelection) + assert.NoError(t, err) + assert.True(t, readCalled) + assert.False(t, tl.truncateCalled) +} + +func testLastKnownPK(value string) *query.QueryResult { + return &query.QueryResult{ + Fields: []*query.Field{ + { + Type: sqltypes.Int64, + Name: "id", + }, + }, + Rows: []*query.Row{ + { + Lengths: []int64{int64(len(value))}, + Values: []byte(value), + }, + }, + } +} + +func stateLogContainsLastKnownPK(t *testing.T, states []lib.SyncState, keyspace string, stream string, shard string) bool { + t.Helper() + + for _, state := range states { + keyspaceState, ok := state.Keyspaces[keyspace] + if !ok { + continue + } + streamState, ok := keyspaceState.Streams[stream] + if !ok { + continue + } + serializedCursor, ok := streamState.Shards[shard] + if !ok { + continue + } + cursor, err := serializedCursor.SerializedCursorToTableCursor() + if err != nil { + t.Fatalf("deserialize cursor: %v", err) + } + if cursor.LastKnownPk != nil { + return true + } + } + return false +} diff --git a/cmd/internal/server/handlers/test_types.go b/cmd/internal/server/handlers/test_types.go index 3245984..b2c0977 100644 --- a/cmd/internal/server/handlers/test_types.go +++ b/cmd/internal/server/handlers/test_types.go @@ -20,6 +20,7 @@ func (l *testLogSender) Send(response *fivetransdk.UpdateResponse) error { type testLogger struct { truncateCalled bool + states []lib.SyncState } func (testLogger) Info(s string) error { @@ -62,6 +63,7 @@ func (tl *testLogger) Record(result *sqltypes.Result, selection *fivetransdk.Sch } func (tl *testLogger) State(state lib.SyncState) error { + tl.states = append(tl.states, state) return nil } diff --git a/cmd/internal/server/server.go b/cmd/internal/server/server.go index 683d7a4..d0ad712 100644 --- a/cmd/internal/server/server.go +++ b/cmd/internal/server/server.go @@ -111,7 +111,7 @@ func (c *connectorServer) Schema(ctx context.Context, request *fivetran_sdk_v2.S } // check if credentials are still valid. - checkConn, err := c.checkConnection.Handle(context.Background(), db, handlers.CheckConnectionTestName, psc) + checkConn, err := c.checkConnection.Handle(ctx, db, handlers.CheckConnectionTestName, psc) if err != nil { return nil, status.Error(codes.InvalidArgument, "unable to open connection to PlanetScale database") } @@ -124,6 +124,7 @@ func (c *connectorServer) Schema(ctx context.Context, request *fivetran_sdk_v2.S } func (c *connectorServer) Update(request *fivetran_sdk_v2.UpdateRequest, server fivetran_sdk_v2.SourceConnector_UpdateServer) error { + ctx := server.Context() requestId := newRequestID() rLogger := newRequestLogger(requestId) @@ -164,7 +165,7 @@ func (c *connectorServer) Update(request *fivetran_sdk_v2.UpdateRequest, server } schemaBuilder := handlers.NewSchemaBuilder(psc.TreatTinyIntAsBoolean) - if err := mysqlClient.BuildSchema(context.Background(), *psc, schemaBuilder); err != nil { + if err := mysqlClient.BuildSchema(ctx, *psc, schemaBuilder); err != nil { status.Error(codes.InvalidArgument, "unable to build schema from PlanetScale database") return nil } @@ -176,7 +177,7 @@ func (c *connectorServer) Update(request *fivetran_sdk_v2.UpdateRequest, server logger := handlers.NewSchemaAwareSerializer(server, requestId, psc.TreatTinyIntAsBoolean, sourceSchema.SchemaList, sourceSchema.EnumsAndSets) - shards, err := db.ListShards(context.Background(), *psc) + shards, err := db.ListShards(ctx, *psc) if err != nil { return status.Errorf(codes.InvalidArgument, "unable to list shards for this database : %q", err) } @@ -187,7 +188,7 @@ func (c *connectorServer) Update(request *fivetran_sdk_v2.UpdateRequest, server } // check if credentials are still valid. - checkConn, err := c.checkConnection.Handle(context.Background(), db, handlers.CheckConnectionTestName, psc) + checkConn, err := c.checkConnection.Handle(ctx, db, handlers.CheckConnectionTestName, psc) if err != nil { msg := fmt.Sprintf("unable to connect to PlanetScale database, failed with : %q", err) logger.Severe(msg) @@ -200,7 +201,7 @@ func (c *connectorServer) Update(request *fivetran_sdk_v2.UpdateRequest, server return status.Error(codes.NotFound, msg) } - return c.sync.Handle(psc, &db, logger, state, schemaSelection) + return c.sync.Handle(ctx, psc, &db, logger, state, schemaSelection) } func newRequestLogger(requestID string) *log.Logger { diff --git a/cmd/internal/server/types.go b/cmd/internal/server/types.go index 7411d60..087dde8 100644 --- a/cmd/internal/server/types.go +++ b/cmd/internal/server/types.go @@ -25,7 +25,7 @@ type SchemaHandler interface { } type SyncHandler interface { - Handle(*lib.PlanetScaleSource, *lib.ConnectClient, handlers.Serializer, *lib.SyncState, *fivetransdk.Selection_WithSchema) error + Handle(context.Context, *lib.PlanetScaleSource, *lib.ConnectClient, handlers.Serializer, *lib.SyncState, *fivetransdk.Selection_WithSchema) error } func NewConfigurationFormHandler() ConfigurationFormHandler { diff --git a/lib/connect_client.go b/lib/connect_client.go index 3f1e9b2..85726cf 100644 --- a/lib/connect_client.go +++ b/lib/connect_client.go @@ -37,6 +37,12 @@ type DatabaseLogger interface { Warning(string) error } +var ( + cursorCheckpointRows = 1000 + cursorCheckpointInterval = 30 * time.Second + maxConsecutiveSyncTimeouts = 5 +) + // ConnectClient is a general purpose interface // that defines all the data access methods needed for the PlanetScale Fivetran source to function. type ConnectClient interface { @@ -115,7 +121,6 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane // Timeout tracking variables consecutiveTimeouts := 0 - const maxConsecutiveTimeouts = 5 const maxTimeout = 1 * time.Hour const timeoutMultiplier = 2.8 backoffDuration := 10 * time.Second @@ -144,7 +149,7 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane logger.Info(fmt.Sprintf(preamble+"syncing rows with cursor [%v]", currentPosition)) currentPosition, err = p.sync(ctx, logger, tableName, existingColumns, currentPosition, latestCursorPosition, ps, tabletType, readDuration, onResult, onCursor, onUpdate) - if currentPosition.Position != "" { + if tableCursorHasProgress(currentPosition) { currentSerializedCursor, sErr = TableCursorToSerializedCursor(currentPosition) if sErr != nil { // if we failed to serialize here, we should bail. @@ -164,13 +169,11 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane currentPosition.Position = "" currentPosition.LastKnownPk = nil - // Mark the error as binlog expiration in the state - if currentSerializedCursor != nil { - currentSerializedCursor.SetBinlogExpirationError(fmt.Sprintf("Binlogs have expired. Cursor reset to trigger historical sync. Original error: %v", err.Error())) - } else { - currentSerializedCursor, _ = TableCursorToSerializedCursor(currentPosition) - currentSerializedCursor.SetBinlogExpirationError(fmt.Sprintf("Binlogs have expired. Cursor reset to trigger historical sync. Original error: %v", err.Error())) + currentSerializedCursor, sErr = TableCursorToSerializedCursor(currentPosition) + if sErr != nil { + return currentSerializedCursor, errors.Wrap(sErr, "unable to serialize reset cursor after binlog expiration") } + currentSerializedCursor.SetBinlogExpirationError(fmt.Sprintf("Binlogs have expired. Cursor reset to trigger historical sync. Original error: %v", err.Error())) // Continue with historical sync instead of returning error continue @@ -183,10 +186,10 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane return currentSerializedCursor, err } else { consecutiveTimeouts++ - logger.Info(fmt.Sprintf("%sTimeout occurred (%d/%d consecutive timeouts)", preamble, consecutiveTimeouts, maxConsecutiveTimeouts)) + logger.Info(fmt.Sprintf("%sTimeout occurred (%d/%d consecutive timeouts)", preamble, consecutiveTimeouts, maxConsecutiveSyncTimeouts)) - if consecutiveTimeouts >= maxConsecutiveTimeouts { - logger.Info(fmt.Sprintf("%sReached maximum consecutive timeouts (%d), stopping sync", preamble, maxConsecutiveTimeouts)) + if consecutiveTimeouts >= maxConsecutiveSyncTimeouts { + logger.Info(fmt.Sprintf("%sReached maximum consecutive timeouts (%d), stopping sync", preamble, maxConsecutiveSyncTimeouts)) warningMessage := fmt.Sprintf("Timeout occurred while reading table %s after %d consecutive attempts. Stopping sync.", tableName, consecutiveTimeouts) @@ -218,7 +221,7 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane }() logger.Info(fmt.Sprintf("%sIncreased read timeout to: %v", preamble, readDuration)) - logger.Info(fmt.Sprintf("%sContinuing with cursor after server timeout (attempt %d/%d)", preamble, consecutiveTimeouts, maxConsecutiveTimeouts)) + logger.Info(fmt.Sprintf("%sContinuing with cursor after server timeout (attempt %d/%d)", preamble, consecutiveTimeouts, maxConsecutiveSyncTimeouts)) } } else if errors.Is(err, io.EOF) { logger.Info(fmt.Sprintf("%vFinished reading all rows for table [%v]", preamble, tableName)) @@ -298,6 +301,8 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam // stop when we've reached the well known stop position for this sync session. watchForVgGtidChange := false + recordsSinceCheckpoint := 0 + lastCheckpoint := time.Now() for { res, err := c.Recv() @@ -316,6 +321,7 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam if err := onResult(sqlResult, OpType_Insert); err != nil { return syncStartCursor, status.Error(codes.Internal, "unable to serialize row") } + recordsSinceCheckpoint++ } } @@ -329,6 +335,7 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam if err := onResult(sqlResult, OpType_Delete); err != nil { return syncStartCursor, status.Error(codes.Internal, "unable to serialize row") } + recordsSinceCheckpoint++ } } } @@ -342,11 +349,19 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam if err := onUpdate(updatedRow); err != nil { return syncStartCursor, status.Error(codes.Internal, "unable to serialize update") } + recordsSinceCheckpoint++ } } if res.Cursor != nil { tc = res.Cursor + if shouldCheckpointCursor(tc, recordsSinceCheckpoint, lastCheckpoint) && onCursor != nil { + if err := onCursor(tc); err != nil { + return tc, status.Error(codes.Internal, "unable to serialize cursor") + } + recordsSinceCheckpoint = 0 + lastCheckpoint = time.Now() + } } // A single VGTID can appear in multiple ordered responses. Once we reach @@ -355,14 +370,33 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam watchForVgGtidChange = watchForVgGtidChange || tc.Position == stopPosition if watchForVgGtidChange && tc.Position != stopPosition { - if err := onCursor(tc); err != nil { - return tc, status.Error(codes.Internal, "unable to serialize cursor") + if onCursor != nil { + if err := onCursor(tc); err != nil { + return tc, status.Error(codes.Internal, "unable to serialize cursor") + } } return tc, io.EOF } } } +func tableCursorHasProgress(tc *psdbconnect.TableCursor) bool { + return tc != nil && (tc.Position != "" || tc.LastKnownPk != nil) +} + +func shouldCheckpointCursor(tc *psdbconnect.TableCursor, recordsSinceCheckpoint int, lastCheckpoint time.Time) bool { + if !tableCursorHasProgress(tc) || recordsSinceCheckpoint == 0 { + return false + } + if cursorCheckpointRows > 0 && recordsSinceCheckpoint >= cursorCheckpointRows { + return true + } + if cursorCheckpointInterval > 0 && time.Since(lastCheckpoint) >= cursorCheckpointInterval { + return true + } + return false +} + func cloneTableCursor(tc *psdbconnect.TableCursor) *psdbconnect.TableCursor { if tc == nil { return nil diff --git a/lib/connect_client_test.go b/lib/connect_client_test.go index ffdc12b..7dcb96d 100644 --- a/lib/connect_client_test.go +++ b/lib/connect_client_test.go @@ -3,7 +3,9 @@ package lib import ( "context" "fmt" + "io" "testing" + "time" "vitess.io/vitess/go/vt/proto/query" @@ -14,6 +16,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "vitess.io/vitess/go/sqltypes" ) @@ -755,6 +758,256 @@ func TestRead_FiltersNonExistentColumns(t *testing.T) { } } +func TestRead_ReturnsLastKnownPKCursorAfterMaxTimeout(t *testing.T) { + originalMaxTimeouts := maxConsecutiveSyncTimeouts + maxConsecutiveSyncTimeouts = 1 + t.Cleanup(func() { + maxConsecutiveSyncTimeouts = originalMaxTimeouts + }) + + dbl := &dbLogger{} + ped := connectClient{} + getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) { + return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}}, nil + } + mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc) + ped.Mysql = &mysqlClient + + initialCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + } + stopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "STOP_GTID", + } + copyCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + LastKnownPk: testLastKnownPK("42"), + } + testFields := sqltypes.MakeTestFields("id|name", "int64|varbinary") + copyResponseSent := false + + getCurrentVGtidClient := &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{{Cursor: stopCursor}}, + } + syncClient := &connectSyncClientMock{ + recvFn: func() (*psdbconnect.SyncResponse, error) { + if !copyResponseSent { + copyResponseSent = true + return &psdbconnect.SyncResponse{ + Result: []*query.QueryResult{ + sqltypes.ResultToProto3(sqltypes.MakeTestResult(testFields, "42|copied")), + }, + Cursor: copyCursor, + }, nil + } + return nil, status.Error(codes.DeadlineExceeded, "server deadline") + }, + } + + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + if in.Cursor.Position == "current" { + return getCurrentVGtidClient, nil + } + assert.Empty(t, in.Cursor.Position) + return syncClient, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + sc, err := ped.Read(context.Background(), dbl, PlanetScaleSource{Database: "connect-test"}, "customers", []string{"id"}, initialCursor, nil, nil, nil) + assert.NoError(t, err) + if assert.NotNil(t, sc) { + cursor, err := sc.SerializedCursorToTableCursor() + assert.NoError(t, err) + assert.Empty(t, cursor.Position) + assert.NotNil(t, cursor.LastKnownPk) + } + assert.Equal(t, 2, cc.syncFnInvokedCount) +} + +func TestRead_BinlogExpirationReturnsResetCursor(t *testing.T) { + dbl := &dbLogger{} + ped := connectClient{} + getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) { + return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}}, nil + } + mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc) + ped.Mysql = &mysqlClient + + initialCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "OLD_GTID", + } + stopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "STOP_GTID", + } + copyCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + LastKnownPk: testLastKnownPK("42"), + } + copyResponseSent := false + currentCursorRequests := 0 + + syncClient := &connectSyncClientMock{ + recvFn: func() (*psdbconnect.SyncResponse, error) { + if !copyResponseSent { + copyResponseSent = true + return &psdbconnect.SyncResponse{Cursor: copyCursor}, nil + } + return nil, status.Error(codes.Unknown, "Cannot replicate because the source purged required binary logs") + }, + } + + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + if in.Cursor.Position == "current" { + currentCursorRequests++ + if currentCursorRequests == 1 { + return &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{{Cursor: stopCursor}}, + }, nil + } + return nil, errors.New("peek failed after reset") + } + return syncClient, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + sc, err := ped.Read(context.Background(), dbl, PlanetScaleSource{Database: "connect-test"}, "customers", []string{"id"}, initialCursor, nil, nil, nil) + assert.ErrorContains(t, err, "peek failed after reset") + if assert.NotNil(t, sc) { + cursor, err := sc.SerializedCursorToTableCursor() + assert.NoError(t, err) + assert.Empty(t, cursor.Position) + assert.Nil(t, cursor.LastKnownPk) + if assert.NotNil(t, sc.ErrorCode) { + assert.Equal(t, "BINLOG_EXPIRATION_ERROR", *sc.ErrorCode) + } + if assert.NotNil(t, sc.ErrorMessage) { + assert.Contains(t, *sc.ErrorMessage, "Binlogs have expired") + } + } + assert.Equal(t, 3, cc.syncFnInvokedCount) +} + +func TestSync_CheckpointsHistoricalCopyProgress(t *testing.T) { + originalCheckpointRows := cursorCheckpointRows + originalCheckpointInterval := cursorCheckpointInterval + cursorCheckpointRows = 1 + cursorCheckpointInterval = time.Hour + t.Cleanup(func() { + cursorCheckpointRows = originalCheckpointRows + cursorCheckpointInterval = originalCheckpointInterval + }) + + dbl := &dbLogger{} + ped := connectClient{} + testFields := sqltypes.MakeTestFields("id|name", "int64|varbinary") + startCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + } + copyCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + LastKnownPk: testLastKnownPK("2"), + } + stopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "STOP_GTID", + } + afterStopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "AFTER_STOP_GTID", + } + + syncClient := &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{ + { + Result: []*query.QueryResult{ + sqltypes.ResultToProto3(sqltypes.MakeTestResult(testFields, "1|first", "2|second")), + }, + Cursor: copyCursor, + }, + {Cursor: stopCursor}, + {Cursor: afterStopCursor}, + }, + } + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + assert.Empty(t, in.Cursor.Position) + return syncClient, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + rows := 0 + checkpoints := []*psdbconnect.TableCursor{} + onResult := func(*sqltypes.Result, Operation) error { + rows++ + return nil + } + onCursor := func(cursor *psdbconnect.TableCursor) error { + checkpoints = append(checkpoints, cloneTableCursor(cursor)) + return nil + } + + returnedCursor, err := ped.sync(context.Background(), dbl, "customers", []string{"id", "name"}, startCursor, stopCursor.Position, PlanetScaleSource{Database: "connect-test"}, psdbconnect.TabletType_primary, time.Second, onResult, onCursor, nil) + assert.True(t, errors.Is(err, io.EOF)) + assert.True(t, proto.Equal(afterStopCursor, returnedCursor)) + assert.Equal(t, 2, rows) + if assert.Len(t, checkpoints, 2) { + assert.True(t, proto.Equal(copyCursor, checkpoints[0])) + assert.True(t, proto.Equal(afterStopCursor, checkpoints[1])) + } +} + +func TestSync_ResumesHistoricalCopyWithLastKnownPK(t *testing.T) { + dbl := &dbLogger{} + ped := connectClient{} + resumeCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "SHOULD_BE_CLEARED", + LastKnownPk: testLastKnownPK("42"), + } + + requestChecked := false + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + requestChecked = true + assert.Empty(t, in.Cursor.Position) + assert.NotNil(t, in.Cursor.LastKnownPk) + return &connectSyncClientMock{}, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + _, err := ped.sync(context.Background(), dbl, "customers", []string{"id"}, resumeCursor, "STOP_GTID", PlanetScaleSource{Database: "connect-test"}, psdbconnect.TabletType_primary, time.Second, nil, nil, nil) + assert.True(t, errors.Is(err, io.EOF)) + assert.True(t, requestChecked) +} + func TestIsVStreamSchemaIncompatibilityError(t *testing.T) { tests := []struct { name string @@ -792,6 +1045,23 @@ func TestIsVStreamSchemaIncompatibilityError(t *testing.T) { } } +func testLastKnownPK(value string) *query.QueryResult { + return &query.QueryResult{ + Fields: []*query.Field{ + { + Type: sqltypes.Int64, + Name: "id", + }, + }, + Rows: []*query.Row{ + { + Lengths: []int64{int64(len(value))}, + Values: []byte(value), + }, + }, + } +} + const vstreamColumnNotFoundErrorMessage = "error starting stream from shard GTID keyspace:\"fivetran\" shard:\"-\": persistent error in vstream: " + "stream (at source tablet) error @ (including the GTID we failed to process): Code: FAILED_PRECONDITION\n" + "column after_col not found in table customers\n\n" + diff --git a/lib/test_types.go b/lib/test_types.go index ace3114..6cefc70 100644 --- a/lib/test_types.go +++ b/lib/test_types.go @@ -42,10 +42,14 @@ type clientConnectionMock struct { type connectSyncClientMock struct { lastResponseSent int syncResponses []*psdbconnect.SyncResponse + recvFn func() (*psdbconnect.SyncResponse, error) grpc.ClientStream } func (x *connectSyncClientMock) Recv() (*psdbconnect.SyncResponse, error) { + if x.recvFn != nil { + return x.recvFn() + } if x.lastResponseSent >= len(x.syncResponses) { return nil, io.EOF } From 868304e69da0a184af0d9b75afad8ceab7bd66a9 Mon Sep 17 00:00:00 2001 From: Nick Van Wiggeren Date: Tue, 16 Jun 2026 20:19:58 +0000 Subject: [PATCH 2/7] Restrict periodic checkpoints to copy cursors --- lib/connect_client.go | 4 +-- lib/connect_client_test.go | 71 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/lib/connect_client.go b/lib/connect_client.go index 85726cf..de054be 100644 --- a/lib/connect_client.go +++ b/lib/connect_client.go @@ -39,7 +39,7 @@ type DatabaseLogger interface { var ( cursorCheckpointRows = 1000 - cursorCheckpointInterval = 30 * time.Second + cursorCheckpointInterval = 10 * time.Minute maxConsecutiveSyncTimeouts = 5 ) @@ -385,7 +385,7 @@ func tableCursorHasProgress(tc *psdbconnect.TableCursor) bool { } func shouldCheckpointCursor(tc *psdbconnect.TableCursor, recordsSinceCheckpoint int, lastCheckpoint time.Time) bool { - if !tableCursorHasProgress(tc) || recordsSinceCheckpoint == 0 { + if tc == nil || tc.LastKnownPk == nil || recordsSinceCheckpoint == 0 { return false } if cursorCheckpointRows > 0 && recordsSinceCheckpoint >= cursorCheckpointRows { diff --git a/lib/connect_client_test.go b/lib/connect_client_test.go index 7dcb96d..b317c78 100644 --- a/lib/connect_client_test.go +++ b/lib/connect_client_test.go @@ -980,6 +980,77 @@ func TestSync_CheckpointsHistoricalCopyProgress(t *testing.T) { } } +func TestSync_DoesNotPeriodicallyCheckpointVGTIDProgress(t *testing.T) { + originalCheckpointRows := cursorCheckpointRows + originalCheckpointInterval := cursorCheckpointInterval + cursorCheckpointRows = 1 + cursorCheckpointInterval = time.Hour + t.Cleanup(func() { + cursorCheckpointRows = originalCheckpointRows + cursorCheckpointInterval = originalCheckpointInterval + }) + + dbl := &dbLogger{} + ped := connectClient{} + testFields := sqltypes.MakeTestFields("id|name", "int64|varbinary") + startCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "START_GTID", + } + vgtidCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "MID_GTID", + } + stopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "STOP_GTID", + } + afterStopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "AFTER_STOP_GTID", + } + + syncClient := &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{ + { + Result: []*query.QueryResult{ + sqltypes.ResultToProto3(sqltypes.MakeTestResult(testFields, "1|first")), + }, + Cursor: vgtidCursor, + }, + {Cursor: stopCursor}, + {Cursor: afterStopCursor}, + }, + } + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + return syncClient, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + checkpoints := []*psdbconnect.TableCursor{} + onCursor := func(cursor *psdbconnect.TableCursor) error { + checkpoints = append(checkpoints, cloneTableCursor(cursor)) + return nil + } + + returnedCursor, err := ped.sync(context.Background(), dbl, "customers", []string{"id", "name"}, startCursor, stopCursor.Position, PlanetScaleSource{Database: "connect-test"}, psdbconnect.TabletType_primary, time.Second, func(*sqltypes.Result, Operation) error { + return nil + }, onCursor, nil) + assert.True(t, errors.Is(err, io.EOF)) + assert.True(t, proto.Equal(afterStopCursor, returnedCursor)) + if assert.Len(t, checkpoints, 1) { + assert.True(t, proto.Equal(afterStopCursor, checkpoints[0])) + } +} + func TestSync_ResumesHistoricalCopyWithLastKnownPK(t *testing.T) { dbl := &dbLogger{} ped := connectClient{} From 6df3ac0a617032e1a361af50d691fffbdba9ddfd Mon Sep 17 00:00:00 2001 From: Nick Van Wiggeren Date: Tue, 16 Jun 2026 20:44:57 +0000 Subject: [PATCH 3/7] Fix copy cursor resume and fail fast on send errors --- .../handlers/schema_aware_serializer.go | 4 +- .../handlers/schema_aware_serializer_test.go | 23 +++ cmd/internal/server/handlers/sync.go | 4 +- cmd/internal/server/handlers/sync_test.go | 49 ++++++ cmd/internal/server/handlers/test_types.go | 3 +- cmd/internal/server/server.go | 3 +- cmd/internal/server/server_test.go | 45 ++++++ lib/connect_client.go | 19 ++- lib/connect_client_test.go | 147 +++++++++++++++++- 9 files changed, 284 insertions(+), 13 deletions(-) diff --git a/cmd/internal/server/handlers/schema_aware_serializer.go b/cmd/internal/server/handlers/schema_aware_serializer.go index 7726771..d6cbc57 100644 --- a/cmd/internal/server/handlers/schema_aware_serializer.go +++ b/cmd/internal/server/handlers/schema_aware_serializer.go @@ -275,7 +275,9 @@ func (l *schemaAwareSerializer) serializeResult(before *sqltypes.Result, after * } operationRecord.Record.Data = row operationRecord.Record.Type = fivetranOpMap[opType] - l.sender.Send(l.recordResponse) + if err := l.sender.Send(l.recordResponse); err != nil { + return err + } } return nil diff --git a/cmd/internal/server/handlers/schema_aware_serializer_test.go b/cmd/internal/server/handlers/schema_aware_serializer_test.go index bf33fed..9e7dc29 100644 --- a/cmd/internal/server/handlers/schema_aware_serializer_test.go +++ b/cmd/internal/server/handlers/schema_aware_serializer_test.go @@ -76,6 +76,29 @@ func TestCanSerializeInsert(t *testing.T) { assert.Equal(t, "string:\"enum_value\"", data["enum_value"].String()) } +func TestRecordReturnsSenderError(t *testing.T) { + row, s, err := generateTestRecord("PhaniRaj") + require.NoError(t, err) + tl := &testLogSender{sendError: assert.AnError} + l := NewSchemaAwareSerializer(tl, "", true, &fivetransdk.SchemaList{Schemas: []*fivetransdk.Schema{s}}, map[string]map[string]map[string]ValueMap{}) + + schema := &fivetransdk.SchemaSelection{ + Included: true, + SchemaName: s.Name, + } + table := &fivetransdk.TableSelection{ + TableName: "Customers", + Included: true, + Columns: map[string]bool{}, + } + for _, f := range row.Fields { + table.Columns[f.Name] = true + } + + err = l.Record(row, schema, table, lib.OpType_Insert) + assert.ErrorIs(t, err, assert.AnError) +} + func TestConvertRowToMapRequiresMatchingColumnCount(t *testing.T) { tests := []struct { name string diff --git a/cmd/internal/server/handlers/sync.go b/cmd/internal/server/handlers/sync.go index f96c707..737a903 100644 --- a/cmd/internal/server/handlers/sync.go +++ b/cmd/internal/server/handlers/sync.go @@ -58,7 +58,9 @@ func (s *Sync) Handle(ctx context.Context, psc *lib.PlanetScaleSource, db *lib.C // Call truncate before any Read operations only when no shard has durable progress. if hasShards && allShardsHaveNoProgress { - logger.Truncate(ks, table) + if err := logger.Truncate(ks, table); err != nil { + return status.Error(codes.Internal, fmt.Sprintf("failed to truncate table %s, error: %s", table.TableName, err.Error())) + } } // Second pass: perform the actual sync for each shard diff --git a/cmd/internal/server/handlers/sync_test.go b/cmd/internal/server/handlers/sync_test.go index b1ae1e5..6d49cd6 100644 --- a/cmd/internal/server/handlers/sync_test.go +++ b/cmd/internal/server/handlers/sync_test.go @@ -112,6 +112,55 @@ func TestCallsTruncateOnInitialSync(t *testing.T) { assert.True(t, tl.truncateCalled) } +func TestInitialSyncReturnsErrorWhenTruncateFails(t *testing.T) { + psc := &lib.PlanetScaleSource{} + tl := &testLogger{truncateErr: assert.AnError} + schema := fivetransdk.SchemaSelection{ + SchemaName: "SalesDB", + Included: true, + Tables: []*fivetransdk.TableSelection{ + { + Included: true, + TableName: "customers", + }, + }, + } + schemaSelection := &fivetransdk.Selection_WithSchema{ + WithSchema: &fivetransdk.TablesWithSchema{ + Schemas: []*fivetransdk.SchemaSelection{ + &schema, + }, + }, + } + + readCalled := false + readFn := func(ctx context.Context, logger lib.DatabaseLogger, ps lib.PlanetScaleSource, tableName string, columns []string, + tc *psdbconnect.TableCursor, onResult lib.OnResult, onCursor lib.OnCursor, onUpdate lib.OnUpdate, + ) (*lib.SerializedCursor, error) { + readCalled = true + return nil, nil + } + + state, err := psc.GetInitialState("SalesDB", []string{"-"}) + assert.NoError(t, err) + + db := lib.NewTestConnectClient(readFn) + err = (&Sync{}).Handle(context.Background(), psc, &db, tl, &lib.SyncState{ + Keyspaces: map[string]lib.KeyspaceState{ + "SalesDB": { + Streams: map[string]lib.ShardStates{ + "SalesDB:customers": state, + }, + }, + }, + }, schemaSelection) + assert.Error(t, err) + assert.Equal(t, codes.Internal, status.Code(err)) + assert.Contains(t, err.Error(), "failed to truncate table customers") + assert.True(t, tl.truncateCalled) + assert.False(t, readCalled) +} + func TestCallsReadWithStartingGtids(t *testing.T) { psc := &lib.PlanetScaleSource{ StartingGtids: "{\"SalesDB\":{\"-80\":\"MySQL56/MYGTID:1-3\",\"80-\":\"MySQL56/MYOTHERGTID:1-3\"}}", diff --git a/cmd/internal/server/handlers/test_types.go b/cmd/internal/server/handlers/test_types.go index b2c0977..e30e2d3 100644 --- a/cmd/internal/server/handlers/test_types.go +++ b/cmd/internal/server/handlers/test_types.go @@ -20,6 +20,7 @@ func (l *testLogSender) Send(response *fivetransdk.UpdateResponse) error { type testLogger struct { truncateCalled bool + truncateErr error states []lib.SyncState } @@ -54,7 +55,7 @@ func (tl *testLogger) Update(*lib.UpdatedRow, *fivetransdk.SchemaSelection, *fiv func (tl *testLogger) Truncate(*fivetransdk.SchemaSelection, *fivetransdk.TableSelection) error { tl.truncateCalled = true - return nil + return tl.truncateErr } func (tl *testLogger) Record(result *sqltypes.Result, selection *fivetransdk.SchemaSelection, selection2 *fivetransdk.TableSelection, operation lib.Operation) error { diff --git a/cmd/internal/server/server.go b/cmd/internal/server/server.go index d0ad712..0d4d19c 100644 --- a/cmd/internal/server/server.go +++ b/cmd/internal/server/server.go @@ -166,8 +166,7 @@ func (c *connectorServer) Update(request *fivetran_sdk_v2.UpdateRequest, server schemaBuilder := handlers.NewSchemaBuilder(psc.TreatTinyIntAsBoolean) if err := mysqlClient.BuildSchema(ctx, *psc, schemaBuilder); err != nil { - status.Error(codes.InvalidArgument, "unable to build schema from PlanetScale database") - return nil + return status.Error(codes.InvalidArgument, "unable to build schema from PlanetScale database") } sourceSchema, err := schemaBuilder.(*handlers.FiveTranSchemaBuilder).BuildUpdateResponse() diff --git a/cmd/internal/server/server_test.go b/cmd/internal/server/server_test.go index 584702d..51b7745 100644 --- a/cmd/internal/server/server_test.go +++ b/cmd/internal/server/server_test.go @@ -160,6 +160,51 @@ func TestUpdateValidatesState(t *testing.T) { assert.ErrorContains(t, err, "request did not contain a valid stateJson") } +func TestUpdateReturnsSchemaBuildError(t *testing.T) { + ctx := context.Background() + mysqlClientConstructor := func() lib.MysqlClient { + return &lib.TestMysqlClient{ + BuildSchemaFn: func(ctx context.Context, psc lib.PlanetScaleSource, schemaBuilder lib.SchemaBuilder) error { + return assert.AnError + }, + } + } + client, closer := server(ctx, nil, mysqlClientConstructor) + defer closer() + + out, err := client.Update(ctx, &fivetransdk.UpdateRequest{ + Configuration: map[string]string{ + "host": "earth.psdb", + "username": "phanatic", + "password": "password", + "database": "employees", + }, + Selection: &fivetransdk.Selection{ + Selection: &fivetransdk.Selection_WithSchema{ + WithSchema: &fivetransdk.TablesWithSchema{ + Schemas: []*fivetransdk.SchemaSelection{ + { + SchemaName: "SalesDB", + Included: true, + Tables: []*fivetransdk.TableSelection{ + { + Included: true, + TableName: "customers", + }, + }, + }, + }, + }, + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, out) + + _, err = out.Recv() + assert.ErrorContains(t, err, "unable to build schema from PlanetScale database") +} + func TestCanSerializeGeometryTypes(t *testing.T) { tests := []struct { Type string diff --git a/lib/connect_client.go b/lib/connect_client.go index de054be..d0c8c96 100644 --- a/lib/connect_client.go +++ b/lib/connect_client.go @@ -140,8 +140,10 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane return currentSerializedCursor, errors.Wrap(lcErr, "Unable to get latest cursor position") } - // the current vgtid is the same as the last synced vgtid, no new rows. - if latestCursorPosition == currentPosition.Position { + // The current vgtid is the same as the last synced vgtid, no new rows. + // A LastKnownPk cursor is still in the historical copy phase and must + // resume even when the binlog position has not advanced. + if currentPosition.LastKnownPk == nil && latestCursorPosition == currentPosition.Position { logger.Info(preamble + "no new rows found, exiting") return TableCursorToSerializedCursor(currentPosition) } @@ -209,7 +211,15 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane // Apply exponential backoff logger.Info(fmt.Sprintf("%sApplying backoff delay: %v", preamble, backoffDuration)) - time.Sleep(backoffDuration) + backoffTimer := time.NewTimer(backoffDuration) + select { + case <-ctx.Done(): + if !backoffTimer.Stop() { + <-backoffTimer.C + } + return currentSerializedCursor, ctx.Err() + case <-backoffTimer.C: + } backoffDuration = time.Duration(math.Min(float64(backoffDuration)*2, float64(5*time.Minute))) // Cap at 5 minutes newReadDuration := time.Duration(float64(readDuration) * timeoutMultiplier) @@ -276,9 +286,6 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam } } - if tc.LastKnownPk != nil { - tc.Position = "" - } syncStartCursor := cloneTableCursor(tc) logger.Info(fmt.Sprintf("%sSyncing with cursor position : [%v], using last known PK : %v, stop cursor is : [%v]", preamble, tc.Position, tc.LastKnownPk != nil, stopPosition)) diff --git a/lib/connect_client_test.go b/lib/connect_client_test.go index b317c78..ba89c31 100644 --- a/lib/connect_client_test.go +++ b/lib/connect_client_test.go @@ -832,6 +832,122 @@ func TestRead_ReturnsLastKnownPKCursorAfterMaxTimeout(t *testing.T) { assert.Equal(t, 2, cc.syncFnInvokedCount) } +func TestRead_CancelDuringTimeoutBackoffReturnsImmediately(t *testing.T) { + originalMaxTimeouts := maxConsecutiveSyncTimeouts + maxConsecutiveSyncTimeouts = 2 + t.Cleanup(func() { + maxConsecutiveSyncTimeouts = originalMaxTimeouts + }) + + dbl := &dbLogger{} + ped := connectClient{} + getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) { + return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}}, nil + } + mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc) + ped.Mysql = &mysqlClient + + startCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "START_GTID", + } + stopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "STOP_GTID", + } + + ctx, cancel := context.WithCancel(context.Background()) + getCurrentVGtidClient := &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{{Cursor: stopCursor}}, + } + syncClient := &connectSyncClientMock{ + recvFn: func() (*psdbconnect.SyncResponse, error) { + cancel() + return nil, status.Error(codes.DeadlineExceeded, "server deadline") + }, + } + + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + if in.Cursor.Position == "current" { + return getCurrentVGtidClient, nil + } + return syncClient, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + start := time.Now() + _, err := ped.Read(ctx, dbl, PlanetScaleSource{Database: "connect-test"}, "customers", []string{"id"}, startCursor, nil, nil, nil) + assert.ErrorIs(t, err, context.Canceled) + assert.Less(t, time.Since(start), time.Second) +} + +func TestRead_ResumesHistoricalCopyEvenWhenPeekMatchesCopyCursorPosition(t *testing.T) { + dbl := &dbLogger{} + ped := connectClient{} + getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) { + return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}}, nil + } + mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc) + ped.Mysql = &mysqlClient + + copyCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "COPY_SNAPSHOT_GTID", + LastKnownPk: testLastKnownPK("42"), + } + stopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "COPY_SNAPSHOT_GTID", + } + doneCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "AFTER_COPY_GTID", + } + + getCurrentVGtidClient := &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{{Cursor: stopCursor}}, + } + syncClient := &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{ + {Cursor: stopCursor}, + {Cursor: doneCursor}, + }, + } + + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + if in.Cursor.Position == "current" { + return getCurrentVGtidClient, nil + } + assert.Equal(t, "COPY_SNAPSHOT_GTID", in.Cursor.Position) + assert.NotNil(t, in.Cursor.LastKnownPk) + return syncClient, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + sc, err := ped.Read(context.Background(), dbl, PlanetScaleSource{Database: "connect-test"}, "customers", []string{"id"}, copyCursor, nil, nil, nil) + assert.NoError(t, err) + if assert.NotNil(t, sc) { + cursor, err := sc.SerializedCursorToTableCursor() + assert.NoError(t, err) + assert.Equal(t, "AFTER_COPY_GTID", cursor.Position) + assert.Nil(t, cursor.LastKnownPk) + } + assert.Equal(t, 2, cc.syncFnInvokedCount) +} + func TestRead_BinlogExpirationReturnsResetCursor(t *testing.T) { dbl := &dbLogger{} ped := connectClient{} @@ -1051,13 +1167,12 @@ func TestSync_DoesNotPeriodicallyCheckpointVGTIDProgress(t *testing.T) { } } -func TestSync_ResumesHistoricalCopyWithLastKnownPK(t *testing.T) { +func TestSync_ResumesHistoricalCopyWithLastKnownPKOnly(t *testing.T) { dbl := &dbLogger{} ped := connectClient{} resumeCursor := &psdbconnect.TableCursor{ Shard: "-", Keyspace: "connect-test", - Position: "SHOULD_BE_CLEARED", LastKnownPk: testLastKnownPK("42"), } @@ -1079,6 +1194,34 @@ func TestSync_ResumesHistoricalCopyWithLastKnownPK(t *testing.T) { assert.True(t, requestChecked) } +func TestSync_ResumesHistoricalCopyWithPositionAndLastKnownPK(t *testing.T) { + dbl := &dbLogger{} + ped := connectClient{} + resumeCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "COPY_SNAPSHOT_GTID", + LastKnownPk: testLastKnownPK("42"), + } + + requestChecked := false + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + requestChecked = true + assert.Equal(t, "COPY_SNAPSHOT_GTID", in.Cursor.Position) + assert.NotNil(t, in.Cursor.LastKnownPk) + return &connectSyncClientMock{}, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + _, err := ped.sync(context.Background(), dbl, "customers", []string{"id"}, resumeCursor, "STOP_GTID", PlanetScaleSource{Database: "connect-test"}, psdbconnect.TabletType_primary, time.Second, nil, nil, nil) + assert.True(t, errors.Is(err, io.EOF)) + assert.True(t, requestChecked) +} + func TestIsVStreamSchemaIncompatibilityError(t *testing.T) { tests := []struct { name string From a10f7ec5833563488395ade6ac294fc1d5cd9844 Mon Sep 17 00:00:00 2001 From: Nick Van Wiggeren Date: Tue, 16 Jun 2026 22:11:59 +0000 Subject: [PATCH 4/7] Reset timeout cap after copy progress --- lib/connect_client.go | 111 ++++++++++++++++++++--------------- lib/connect_client_test.go | 117 +++++++++++++++++++++++++++++++------ 2 files changed, 163 insertions(+), 65 deletions(-) diff --git a/lib/connect_client.go b/lib/connect_client.go index d0c8c96..270e0fc 100644 --- a/lib/connect_client.go +++ b/lib/connect_client.go @@ -150,7 +150,9 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane logger.Info(fmt.Sprintf(preamble+"new rows found, syncing rows for %v", readDuration)) logger.Info(fmt.Sprintf(preamble+"syncing rows with cursor [%v]", currentPosition)) + previousPosition := cloneTableCursor(currentPosition) currentPosition, err = p.sync(ctx, logger, tableName, existingColumns, currentPosition, latestCursorPosition, ps, tabletType, readDuration, onResult, onCursor, onUpdate) + madeProgress := tableCursorMadeProgress(previousPosition, currentPosition) if tableCursorHasProgress(currentPosition) { currentSerializedCursor, sErr = TableCursorToSerializedCursor(currentPosition) if sErr != nil { @@ -186,53 +188,65 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane } return currentSerializedCursor, err - } else { - consecutiveTimeouts++ - logger.Info(fmt.Sprintf("%sTimeout occurred (%d/%d consecutive timeouts)", preamble, consecutiveTimeouts, maxConsecutiveSyncTimeouts)) - - if consecutiveTimeouts >= maxConsecutiveSyncTimeouts { - logger.Info(fmt.Sprintf("%sReached maximum consecutive timeouts (%d), stopping sync", preamble, maxConsecutiveSyncTimeouts)) - - warningMessage := fmt.Sprintf("Timeout occurred while reading table %s after %d consecutive attempts. Stopping sync.", tableName, consecutiveTimeouts) - - if serializer, ok := logger.(interface { - SendWarningAlert(string) error - }); ok { - if err := serializer.SendWarningAlert(warningMessage); err != nil { - logger.Warning(fmt.Sprintf("Failed to send warning message: %v", err)) - } - } else { - // Fallback to regular warning log if serializer doesn't support SendWarningAlert - logger.Warning(warningMessage) - } + } - return currentSerializedCursor, nil + newReadDuration := time.Duration(float64(readDuration) * timeoutMultiplier) + readDuration = func() time.Duration { + if newReadDuration > maxTimeout { + return maxTimeout + } + return newReadDuration + }() + + if madeProgress { + if consecutiveTimeouts > 0 { + logger.Info(fmt.Sprintf("%sTimeout occurred after cursor progress; resetting no-progress timeout counter (was %d)", preamble, consecutiveTimeouts)) + } else { + logger.Info(fmt.Sprintf("%sTimeout occurred after cursor progress; continuing without incrementing no-progress timeout counter", preamble)) } + consecutiveTimeouts = 0 + backoffDuration = 10 * time.Second + logger.Info(fmt.Sprintf("%sIncreased read timeout to: %v", preamble, readDuration)) + continue + } - // Apply exponential backoff - logger.Info(fmt.Sprintf("%sApplying backoff delay: %v", preamble, backoffDuration)) - backoffTimer := time.NewTimer(backoffDuration) - select { - case <-ctx.Done(): - if !backoffTimer.Stop() { - <-backoffTimer.C + consecutiveTimeouts++ + logger.Info(fmt.Sprintf("%sTimeout occurred without cursor progress (%d/%d consecutive no-progress timeouts)", preamble, consecutiveTimeouts, maxConsecutiveSyncTimeouts)) + + if consecutiveTimeouts >= maxConsecutiveSyncTimeouts { + logger.Info(fmt.Sprintf("%sReached maximum consecutive no-progress timeouts (%d), stopping sync", preamble, maxConsecutiveSyncTimeouts)) + + warningMessage := fmt.Sprintf("Timeout occurred while reading table %s after %d consecutive attempts without cursor progress. Stopping sync.", tableName, consecutiveTimeouts) + + if serializer, ok := logger.(interface { + SendWarningAlert(string) error + }); ok { + if err := serializer.SendWarningAlert(warningMessage); err != nil { + logger.Warning(fmt.Sprintf("Failed to send warning message: %v", err)) } - return currentSerializedCursor, ctx.Err() - case <-backoffTimer.C: + } else { + // Fallback to regular warning log if serializer doesn't support SendWarningAlert + logger.Warning(warningMessage) } - backoffDuration = time.Duration(math.Min(float64(backoffDuration)*2, float64(5*time.Minute))) // Cap at 5 minutes - newReadDuration := time.Duration(float64(readDuration) * timeoutMultiplier) - readDuration = func() time.Duration { - if newReadDuration > maxTimeout { - return maxTimeout - } - return newReadDuration - }() + return currentSerializedCursor, nil + } - logger.Info(fmt.Sprintf("%sIncreased read timeout to: %v", preamble, readDuration)) - logger.Info(fmt.Sprintf("%sContinuing with cursor after server timeout (attempt %d/%d)", preamble, consecutiveTimeouts, maxConsecutiveSyncTimeouts)) + // Apply exponential backoff only for timeouts that did not move the cursor. + logger.Info(fmt.Sprintf("%sApplying backoff delay: %v", preamble, backoffDuration)) + backoffTimer := time.NewTimer(backoffDuration) + select { + case <-ctx.Done(): + if !backoffTimer.Stop() { + <-backoffTimer.C + } + return currentSerializedCursor, ctx.Err() + case <-backoffTimer.C: } + backoffDuration = time.Duration(math.Min(float64(backoffDuration)*2, float64(5*time.Minute))) // Cap at 5 minutes + + logger.Info(fmt.Sprintf("%sIncreased read timeout to: %v", preamble, readDuration)) + logger.Info(fmt.Sprintf("%sContinuing with cursor after server timeout without progress (attempt %d/%d)", preamble, consecutiveTimeouts, maxConsecutiveSyncTimeouts)) } else if errors.Is(err, io.EOF) { logger.Info(fmt.Sprintf("%vFinished reading all rows for table [%v]", preamble, tableName)) return currentSerializedCursor, nil @@ -240,15 +254,6 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane logger.Warning(fmt.Sprintf(preamble+"non-grpc error [%v]]", err)) return currentSerializedCursor, err } - } else { - // Reset timeout counter on successful sync - if consecutiveTimeouts > 0 { - logger.Info(fmt.Sprintf("%sSync successful, resetting timeout counter (was %d)", preamble, consecutiveTimeouts)) - consecutiveTimeouts = 0 - backoffDuration = 10 * time.Second // Reset backoff - readDuration = 1 * time.Minute // Reset read duration to original value - logger.Info(fmt.Sprintf("%sReset read timeout to: %v", preamble, readDuration)) - } } } } @@ -391,6 +396,16 @@ func tableCursorHasProgress(tc *psdbconnect.TableCursor) bool { return tc != nil && (tc.Position != "" || tc.LastKnownPk != nil) } +func tableCursorMadeProgress(before, after *psdbconnect.TableCursor) bool { + if after == nil { + return false + } + if before == nil { + return tableCursorHasProgress(after) + } + return before.Position != after.Position || !proto.Equal(before.LastKnownPk, after.LastKnownPk) +} + func shouldCheckpointCursor(tc *psdbconnect.TableCursor, recordsSinceCheckpoint int, lastCheckpoint time.Time) bool { if tc == nil || tc.LastKnownPk == nil || recordsSinceCheckpoint == 0 { return false diff --git a/lib/connect_client_test.go b/lib/connect_client_test.go index ba89c31..218fd76 100644 --- a/lib/connect_client_test.go +++ b/lib/connect_client_test.go @@ -758,7 +758,7 @@ func TestRead_FiltersNonExistentColumns(t *testing.T) { } } -func TestRead_ReturnsLastKnownPKCursorAfterMaxTimeout(t *testing.T) { +func TestRead_ReturnsLastKnownPKCursorAfterMaxNoProgressTimeout(t *testing.T) { originalMaxTimeouts := maxConsecutiveSyncTimeouts maxConsecutiveSyncTimeouts = 1 t.Cleanup(func() { @@ -773,10 +773,6 @@ func TestRead_ReturnsLastKnownPKCursorAfterMaxTimeout(t *testing.T) { mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc) ped.Mysql = &mysqlClient - initialCursor := &psdbconnect.TableCursor{ - Shard: "-", - Keyspace: "connect-test", - } stopCursor := &psdbconnect.TableCursor{ Shard: "-", Keyspace: "connect-test", @@ -787,23 +783,12 @@ func TestRead_ReturnsLastKnownPKCursorAfterMaxTimeout(t *testing.T) { Keyspace: "connect-test", LastKnownPk: testLastKnownPK("42"), } - testFields := sqltypes.MakeTestFields("id|name", "int64|varbinary") - copyResponseSent := false getCurrentVGtidClient := &connectSyncClientMock{ syncResponses: []*psdbconnect.SyncResponse{{Cursor: stopCursor}}, } syncClient := &connectSyncClientMock{ recvFn: func() (*psdbconnect.SyncResponse, error) { - if !copyResponseSent { - copyResponseSent = true - return &psdbconnect.SyncResponse{ - Result: []*query.QueryResult{ - sqltypes.ResultToProto3(sqltypes.MakeTestResult(testFields, "42|copied")), - }, - Cursor: copyCursor, - }, nil - } return nil, status.Error(codes.DeadlineExceeded, "server deadline") }, } @@ -814,6 +799,7 @@ func TestRead_ReturnsLastKnownPKCursorAfterMaxTimeout(t *testing.T) { return getCurrentVGtidClient, nil } assert.Empty(t, in.Cursor.Position) + assert.NotNil(t, in.Cursor.LastKnownPk) return syncClient, nil }, } @@ -821,7 +807,7 @@ func TestRead_ReturnsLastKnownPKCursorAfterMaxTimeout(t *testing.T) { return &cc, nil } - sc, err := ped.Read(context.Background(), dbl, PlanetScaleSource{Database: "connect-test"}, "customers", []string{"id"}, initialCursor, nil, nil, nil) + sc, err := ped.Read(context.Background(), dbl, PlanetScaleSource{Database: "connect-test"}, "customers", []string{"id"}, copyCursor, nil, nil, nil) assert.NoError(t, err) if assert.NotNil(t, sc) { cursor, err := sc.SerializedCursorToTableCursor() @@ -832,6 +818,103 @@ func TestRead_ReturnsLastKnownPKCursorAfterMaxTimeout(t *testing.T) { assert.Equal(t, 2, cc.syncFnInvokedCount) } +func TestRead_DoesNotStopAfterProgressingTimeoutWindows(t *testing.T) { + originalMaxTimeouts := maxConsecutiveSyncTimeouts + maxConsecutiveSyncTimeouts = 3 + t.Cleanup(func() { + maxConsecutiveSyncTimeouts = originalMaxTimeouts + }) + + dbl := &dbLogger{} + ped := connectClient{} + getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) { + return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}}, nil + } + mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc) + ped.Mysql = &mysqlClient + + initialCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + } + stopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "STOP_GTID", + } + afterStopCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "AFTER_STOP_GTID", + } + testFields := sqltypes.MakeTestFields("id|name", "int64|varbinary") + progressingWindows := maxConsecutiveSyncTimeouts + 2 + syncAttempts := 0 + + cc := clientConnectionMock{ + syncFn: func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) { + if in.Cursor.Position == "current" { + return &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{{Cursor: stopCursor}}, + }, nil + } + + syncAttempts++ + if syncAttempts <= progressingWindows { + lastPK := fmt.Sprintf("%d", syncAttempts) + sentResponse := false + return &connectSyncClientMock{ + recvFn: func() (*psdbconnect.SyncResponse, error) { + if !sentResponse { + sentResponse = true + return &psdbconnect.SyncResponse{ + Result: []*query.QueryResult{ + sqltypes.ResultToProto3(sqltypes.MakeTestResult(testFields, fmt.Sprintf("%s|copied", lastPK))), + }, + Cursor: &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + LastKnownPk: testLastKnownPK(lastPK), + }, + }, nil + } + return nil, status.Error(codes.DeadlineExceeded, "server deadline") + }, + }, nil + } + + return &connectSyncClientMock{ + syncResponses: []*psdbconnect.SyncResponse{ + {Cursor: stopCursor}, + {Cursor: afterStopCursor}, + }, + }, nil + }, + } + ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) { + return &cc, nil + } + + rows := 0 + sc, err := ped.Read(context.Background(), dbl, PlanetScaleSource{Database: "connect-test"}, "customers", []string{"id"}, initialCursor, func(*sqltypes.Result, Operation) error { + rows++ + return nil + }, nil, nil) + assert.NoError(t, err) + if assert.NotNil(t, sc) { + cursor, err := sc.SerializedCursorToTableCursor() + assert.NoError(t, err) + assert.Equal(t, afterStopCursor.Position, cursor.Position) + assert.Nil(t, cursor.LastKnownPk) + } + assert.Equal(t, progressingWindows, rows) + assert.Equal(t, progressingWindows+1, syncAttempts) + for _, msg := range dbl.messages { + assert.NotContains(t, msg.message, "Reached maximum consecutive no-progress timeouts") + assert.NotContains(t, msg.message, "Stopping sync.") + } +} + func TestRead_CancelDuringTimeoutBackoffReturnsImmediately(t *testing.T) { originalMaxTimeouts := maxConsecutiveSyncTimeouts maxConsecutiveSyncTimeouts = 2 From 19c0718eeebb3aedf1e8b8af22c265cbc34872a4 Mon Sep 17 00:00:00 2001 From: Nick Van Wiggeren Date: Tue, 16 Jun 2026 22:30:47 +0000 Subject: [PATCH 5/7] Add interrupted copy resume integration test --- .../server/handlers/integration_test.go | 287 +++++++++++++++++- lib/integration_test_hooks.go | 18 ++ 2 files changed, 301 insertions(+), 4 deletions(-) create mode 100644 lib/integration_test_hooks.go diff --git a/cmd/internal/server/handlers/integration_test.go b/cmd/internal/server/handlers/integration_test.go index 905da74..8a6b596 100644 --- a/cmd/internal/server/handlers/integration_test.go +++ b/cmd/internal/server/handlers/integration_test.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" + "vitess.io/vitess/go/sqltypes" ) func TestIntegrationBasicInsertUpdateDelete(t *testing.T) { @@ -354,6 +355,73 @@ func TestIntegrationIntermittentEmptySyncs(t *testing.T) { assertIntegrationRecordCount(t, idle.recordsForTable(tableName), 0) } +func TestIntegrationInterruptedCopyResume(t *testing.T) { + restoreCheckpointPolicy := lib.SetCursorCheckpointPolicyForIntegrationTest(1, 0) + defer restoreCheckpointPolicy() + + psc := loadIntegrationSource(t) + ctx := context.Background() + tableName := integrationTableName(t) + + db := openIntegrationSQL(t, psc) + t.Cleanup(func() { + if _, err := db.ExecContext(context.Background(), "drop table if exists "+quoteIdent(tableName)); err != nil { + t.Logf("drop integration table %s: %v", tableName, err) + } + _ = db.Close() + }) + + mustExec(t, ctx, db, fmt.Sprintf(` + create table %s ( + id bigint primary key, + version int not null, + payload varchar(2048) not null, + active tinyint(1) not null + )`, quoteIdent(tableName))) + + columns := []string{"id", "version", "payload", "active"} + expected := map[int64]map[string]any{} + const rowCount = int64(2500) + insertIntegrationResumeRows(t, ctx, db, tableName, 1, rowCount, 0, expected) + + interruptCtx, cancelInterrupt := context.WithCancel(context.Background()) + interruptSender := &interruptAfterCopyCheckpointSender{ + t: t, + keyspaceName: psc.Database, + tableName: tableName, + cancel: cancelInterrupt, + } + _, err := runIntegrationSyncWithSender(t, interruptCtx, psc, tableName, columns, nil, interruptSender) + cancelInterrupt() + if err == nil { + t.Fatalf("expected interrupted copy sync to return an error") + } + if interruptSender.checkpointState == nil { + t.Fatalf("interrupted copy did not checkpoint a LastKnownPk cursor") + } + if interruptSender.checkpointRecordCount == 0 { + t.Fatalf("interrupted copy checkpointed before emitting table records") + } + + checkpointedRecords := interruptSender.recordsForTable(tableName)[:interruptSender.checkpointRecordCount] + model := map[int64]map[string]any{} + applyIntegrationRecords(t, model, checkpointedRecords) + + lastCopiedID := integrationLastKnownPKID(t, interruptSender.checkpointState, psc.Database, tableName) + if lastCopiedID < 1 || lastCopiedID > rowCount { + t.Fatalf("unexpected LastKnownPk id %d for row count %d", lastCopiedID, rowCount) + } + updateIntegrationResumeRow(t, ctx, db, tableName, lastCopiedID, 1, true, expected) + + resumed, finalState := runIntegrationSync(t, psc, tableName, columns, interruptSender.checkpointState) + applyIntegrationRecords(t, model, resumed.recordsForTable(tableName)) + assertIntegrationRowsExactly(t, model, expected) + assertIntegrationStateHasNoLastKnownPK(t, finalState, psc.Database, tableName) + + idle, _ := runIntegrationSync(t, psc, tableName, columns, finalState) + assertIntegrationRecordCount(t, idle.recordsForTable(tableName), 0) +} + func TestIntegrationRepeatedSyncBursts(t *testing.T) { psc := loadIntegrationSource(t) ctx := context.Background() @@ -684,6 +752,19 @@ func runIntegrationSync(t *testing.T, psc lib.PlanetScaleSource, tableName strin func runIntegrationSyncWithError(t *testing.T, psc lib.PlanetScaleSource, tableName string, columns []string, state *lib.SyncState) (*integrationSender, *lib.SyncState, error) { t.Helper() + sender := &integrationSender{} + state, err := runIntegrationSyncWithSender(t, context.Background(), psc, tableName, columns, state, sender) + return sender, state, err +} + +type integrationStatefulSender interface { + Send(*fivetransdk.UpdateResponse) error + latestState(*testing.T) (*lib.SyncState, bool) +} + +func runIntegrationSyncWithSender(t *testing.T, ctx context.Context, psc lib.PlanetScaleSource, tableName string, columns []string, state *lib.SyncState, sender integrationStatefulSender) (*lib.SyncState, error) { + t.Helper() + mysqlClient, err := lib.NewMySQL(&psc) if err != nil { t.Fatalf("create mysql client: %v", err) @@ -721,16 +802,15 @@ func runIntegrationSyncWithError(t *testing.T, psc lib.PlanetScaleSource, tableN } selection := integrationSelection(psc.Database, tableName, columns) - sender := &integrationSender{} logger := NewSchemaAwareSerializer(sender, "integration", psc.TreatTinyIntAsBoolean, sourceSchema.SchemaList, sourceSchema.EnumsAndSets) syncer := &Sync{} - if err := syncer.Handle(context.Background(), &psc, &connectClient, logger, state, selection); err != nil { - return sender, state, err + if err := syncer.Handle(ctx, &psc, &connectClient, logger, state, selection); err != nil { + return state, err } if checkpoint, ok := sender.latestState(t); ok { state = checkpoint } - return sender, state, nil + return state, nil } func integrationSelection(schemaName, tableName string, columns []string) *fivetransdk.Selection_WithSchema { @@ -783,6 +863,17 @@ func (s *integrationSender) latestState(t *testing.T) (*lib.SyncState, bool) { return nil, false } +func (s *integrationSender) recordCountForTable(tableName string) int { + count := 0 + for _, response := range s.responses { + record := response.GetRecord() + if record != nil && record.TableName == tableName { + count++ + } + } + return count +} + func (s *integrationSender) recordsForTable(tableName string) []*fivetransdk.Record { records := []*fivetransdk.Record{} for _, response := range s.responses { @@ -795,6 +886,44 @@ func (s *integrationSender) recordsForTable(tableName string) []*fivetransdk.Rec return records } +type interruptAfterCopyCheckpointSender struct { + integrationSender + + t *testing.T + keyspaceName string + tableName string + cancel context.CancelFunc + checkpointState *lib.SyncState + checkpointRecordCount int +} + +func (s *interruptAfterCopyCheckpointSender) Send(response *fivetransdk.UpdateResponse) error { + if err := s.integrationSender.Send(response); err != nil { + return err + } + if s.checkpointState != nil { + return nil + } + + checkpoint := response.GetCheckpoint() + if checkpoint == nil { + return nil + } + + var state lib.SyncState + if err := json.Unmarshal([]byte(checkpoint.StateJson), &state); err != nil { + s.t.Fatalf("parse interrupted checkpoint state: %v", err) + } + if !integrationStateHasLastKnownPK(s.t, &state, s.keyspaceName, s.tableName) { + return nil + } + + s.checkpointState = &state + s.checkpointRecordCount = s.integrationSender.recordCountForTable(s.tableName) + s.cancel() + return nil +} + func applyIntegrationRecords(t *testing.T, rows map[int64]map[string]any, records []*fivetransdk.Record) { t.Helper() @@ -904,6 +1033,27 @@ func assertIntegrationRows(t *testing.T, got, want map[int64]map[string]any) { } } +func assertIntegrationRowsExactly(t *testing.T, got, want map[int64]map[string]any) { + t.Helper() + + if len(got) != len(want) { + t.Fatalf("unexpected row count: want %d, got %d", len(want), len(got)) + } + for id, wantRow := range want { + if !reflect.DeepEqual(got[id], wantRow) { + gotJSON, _ := json.Marshal(got[id]) + wantJSON, _ := json.Marshal(wantRow) + t.Fatalf("unexpected row %d\nwant: %s\ngot: %s", id, wantJSON, gotJSON) + } + } + for id := range got { + if _, ok := want[id]; !ok { + gotJSON, _ := json.Marshal(got[id]) + t.Fatalf("unexpected extra row %d: %s", id, gotJSON) + } + } +} + func assertIntegrationStrings(t *testing.T, got, want []string) { t.Helper() @@ -983,6 +1133,16 @@ func assertIntegrationStateShards(t *testing.T, state *lib.SyncState, keyspaceNa } } +func assertIntegrationStateHasNoLastKnownPK(t *testing.T, state *lib.SyncState, keyspaceName, tableName string) { + t.Helper() + + for shard, cursor := range integrationStateTableCursors(t, state, keyspaceName, tableName) { + if cursor.LastKnownPk != nil { + t.Fatalf("state still has LastKnownPk for shard %s in table %s", shard, tableName) + } + } +} + func assertIntegrationRowsOnEveryShard(t *testing.T, psc lib.PlanetScaleSource, tableName string, shards []string) { t.Helper() @@ -1007,6 +1167,68 @@ func assertIntegrationRowsOnEveryShard(t *testing.T, psc lib.PlanetScaleSource, } } +func integrationStateHasLastKnownPK(t *testing.T, state *lib.SyncState, keyspaceName, tableName string) bool { + t.Helper() + + for _, cursor := range integrationStateTableCursors(t, state, keyspaceName, tableName) { + if cursor.LastKnownPk != nil { + return true + } + } + return false +} + +func integrationLastKnownPKID(t *testing.T, state *lib.SyncState, keyspaceName, tableName string) int64 { + t.Helper() + + for shard, cursor := range integrationStateTableCursors(t, state, keyspaceName, tableName) { + if cursor.LastKnownPk == nil { + continue + } + result := sqltypes.Proto3ToResult(cursor.LastKnownPk) + if len(result.Rows) != 1 || len(result.Rows[0]) != 1 { + t.Fatalf("unexpected LastKnownPk shape for shard %s in table %s: %+v", shard, tableName, cursor.LastKnownPk) + } + id, err := result.Rows[0][0].ToInt64() + if err != nil { + t.Fatalf("parse LastKnownPk id for shard %s in table %s: %v", shard, tableName, err) + } + return id + } + t.Fatalf("state missing LastKnownPk for table %s", tableName) + return 0 +} + +func integrationStateTableCursors(t *testing.T, state *lib.SyncState, keyspaceName, tableName string) map[string]*psdbconnect.TableCursor { + t.Helper() + + if state == nil { + t.Fatalf("sync state is nil") + } + keyspace, ok := state.Keyspaces[keyspaceName] + if !ok { + t.Fatalf("state missing keyspace %s: %+v", keyspaceName, state.Keyspaces) + } + streamName := keyspaceName + ":" + tableName + stream, ok := keyspace.Streams[streamName] + if !ok { + t.Fatalf("state missing stream %s: %+v", streamName, keyspace.Streams) + } + + cursors := map[string]*psdbconnect.TableCursor{} + for shard, serializedCursor := range stream.Shards { + if serializedCursor == nil { + t.Fatalf("state shard %s in stream %s has nil cursor", shard, streamName) + } + cursor, err := serializedCursor.SerializedCursorToTableCursor() + if err != nil { + t.Fatalf("deserialize cursor for shard %s in stream %s: %v", shard, streamName, err) + } + cursors[shard] = cursor + } + return cursors +} + func insertIntegrationLoadRows(t *testing.T, ctx context.Context, db *sql.DB, tableName string, startID, count int64, version int, rows map[int64]map[string]any) { t.Helper() @@ -1043,6 +1265,42 @@ func insertIntegrationLoadRows(t *testing.T, ctx context.Context, db *sql.DB, ta } } +func insertIntegrationResumeRows(t *testing.T, ctx context.Context, db *sql.DB, tableName string, startID, count int64, version int, rows map[int64]map[string]any) { + t.Helper() + + const batchSize = int64(50) + for offset := int64(0); offset < count; { + n := batchSize + if remaining := count - offset; remaining < n { + n = remaining + } + + placeholders := make([]string, 0, int(n)) + args := make([]any, 0, int(n)*4) + for i := int64(0); i < n; i++ { + id := startID + offset + i + payload := integrationResumePayload(version, id) + active := id%2 == 0 + + placeholders = append(placeholders, "(?, ?, ?, ?)") + args = append(args, id, version, payload, integrationBoolInt(active)) + rows[id] = map[string]any{ + "id": id, + "version": int64(version), + "payload": payload, + "active": active, + } + } + + mustExec(t, ctx, db, fmt.Sprintf( + "insert into %s (id, version, payload, active) values %s", + quoteIdent(tableName), + strings.Join(placeholders, ", "), + ), args...) + offset += n + } +} + func updateIntegrationLoadRow(t *testing.T, ctx context.Context, db *sql.DB, tableName string, id int64, version int, active bool, rows map[int64]map[string]any) { t.Helper() @@ -1060,6 +1318,23 @@ func updateIntegrationLoadRow(t *testing.T, ctx context.Context, db *sql.DB, tab } } +func updateIntegrationResumeRow(t *testing.T, ctx context.Context, db *sql.DB, tableName string, id int64, version int, active bool, rows map[int64]map[string]any) { + t.Helper() + + payload := integrationResumePayload(version, id) + mustExec(t, ctx, db, fmt.Sprintf( + "update %s set version = ?, payload = ?, active = ? where id = ?", + quoteIdent(tableName), + ), version, payload, integrationBoolInt(active), id) + + rows[id] = map[string]any{ + "id": id, + "version": int64(version), + "payload": payload, + "active": active, + } +} + func deleteIntegrationLoadRow(t *testing.T, ctx context.Context, db *sql.DB, tableName string, id int64, rows map[int64]map[string]any) { t.Helper() @@ -1098,6 +1373,10 @@ func integrationPayload(version int, id int64) string { return fmt.Sprintf("payload-%03d-%06d", version, id) } +func integrationResumePayload(version int, id int64) string { + return strings.Repeat(fmt.Sprintf("resume-payload-%03d-%06d-", version, id), 64) +} + func integrationBoolInt(value bool) int { if value { return 1 diff --git a/lib/integration_test_hooks.go b/lib/integration_test_hooks.go new file mode 100644 index 0000000..5c04dbe --- /dev/null +++ b/lib/integration_test_hooks.go @@ -0,0 +1,18 @@ +//go:build integration + +package lib + +import "time" + +func SetCursorCheckpointPolicyForIntegrationTest(rows int, interval time.Duration) func() { + previousRows := cursorCheckpointRows + previousInterval := cursorCheckpointInterval + + cursorCheckpointRows = rows + cursorCheckpointInterval = interval + + return func() { + cursorCheckpointRows = previousRows + cursorCheckpointInterval = previousInterval + } +} From 2748e7d706f2b317e1e8e85ed0671660c63488c3 Mon Sep 17 00:00:00 2001 From: Nick Van Wiggeren Date: Tue, 16 Jun 2026 23:07:09 +0000 Subject: [PATCH 6/7] Use direct Vitess VStream --- lib/connect_client.go | 684 +++++++++++++++++++++++++++++-------- lib/connect_client_test.go | 106 ++++++ lib/test_types.go | 32 ++ 3 files changed, 688 insertions(+), 134 deletions(-) diff --git a/lib/connect_client.go b/lib/connect_client.go index 270e0fc..973cb16 100644 --- a/lib/connect_client.go +++ b/lib/connect_client.go @@ -9,13 +9,18 @@ import ( "strings" "time" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" "vitess.io/vitess/go/vt/proto/query" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + vtgateservicepb "vitess.io/vitess/go/vt/proto/vtgateservice" "github.com/pkg/errors" psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1" "github.com/planetscale/psdb/auth" grpcclient "github.com/planetscale/psdb/core/pool" clientoptions "github.com/planetscale/psdb/core/pool/options" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -61,8 +66,13 @@ func NewConnectClient(mysqlAccess *MysqlClient) ConnectClient { // It uses the mysql interface provided by PlanetScale for all schema/shard/tablet discovery and // the grpc API for incrementally syncing rows from PlanetScale. type connectClient struct { - clientFn func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) - Mysql *MysqlClient + clientFn func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) + vstreamClientFn func(ctx context.Context, ps PlanetScaleSource) (vstreamClient, error) + Mysql *MysqlClient +} + +type vstreamClient interface { + VStream(ctx context.Context, in *vtgatepb.VStreamRequest, opts ...grpc.CallOption) (vtgateservicepb.Vitess_VStreamClient, error) } func (p connectClient) ListShards(ctx context.Context, ps PlanetScaleSource) ([]string, error) { @@ -264,49 +274,23 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam var ( err error - client psdbconnect.ConnectClient + client vstreamClient + close func() ) preamble := fmt.Sprintf("[%v:%v shard:%v tabletType:%s] ", ps.Database, tableName, tc.Shard, tabletType) - if p.clientFn == nil { - conn, err := grpcclient.Dial( - ctx, ps.Host, - clientoptions.WithDefaultTLSConfig(), - clientoptions.WithCompression(true), - clientoptions.WithConnectionPool(1), - clientoptions.WithExtraCallOption( - auth.NewBasicAuth(ps.Username, ps.Password).CallOption(), - ), - ) - if err != nil { - return tc, err - } - defer conn.Close() - client = psdbconnect.NewConnectClient(conn) - } else { - client, err = p.clientFn(ctx, ps) - if err != nil { - return tc, err - } + client, close, err = p.newVStreamClient(ctx, ps) + if err != nil { + return tc, err } + defer close() syncStartCursor := cloneTableCursor(tc) logger.Info(fmt.Sprintf("%sSyncing with cursor position : [%v], using last known PK : %v, stop cursor is : [%v]", preamble, tc.Position, tc.LastKnownPk != nil, stopPosition)) - sReq := &psdbconnect.SyncRequest{ - TableName: tableName, - Cursor: tc, - TabletType: tabletType, - Columns: columns, - IncludeUpdates: true, - IncludeInserts: true, - IncludeDeletes: true, - Cells: []string{"planetscale_operator_default"}, - } - - c, err := client.Sync(ctx, sReq) + c, err := client.VStream(ctx, buildVStreamRequest(tableName, columns, tc, tabletType)) if err != nil { return tc, err } @@ -315,81 +299,532 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam watchForVgGtidChange := false recordsSinceCheckpoint := 0 lastCheckpoint := time.Now() + fieldsByTable := map[string][]*query.Field{} for { - res, err := c.Recv() if err != nil { return tc, err } - if onResult != nil { - for _, insertedRow := range res.Result { - qr := sqltypes.Proto3ToResult(insertedRow) - for _, row := range qr.Rows { - sqlResult := &sqltypes.Result{ - Fields: insertedRow.Fields, - } - sqlResult.Rows = append(sqlResult.Rows, row) - if err := onResult(sqlResult, OpType_Insert); err != nil { - return syncStartCursor, status.Error(codes.Internal, "unable to serialize row") - } - recordsSinceCheckpoint++ - } - } - - for _, deletedRow := range res.Deletes { - qr := sqltypes.Proto3ToResult(deletedRow.Result) - for _, row := range qr.Rows { - sqlResult := &sqltypes.Result{ - Fields: deletedRow.Result.Fields, - } - sqlResult.Rows = append(sqlResult.Rows, row) - if err := onResult(sqlResult, OpType_Delete); err != nil { - return syncStartCursor, status.Error(codes.Internal, "unable to serialize row") - } - recordsSinceCheckpoint++ - } + copyCompleted := false + for _, event := range res.Events { + var ( + eventRecords int + eventCopyCompleted bool + ) + tc, eventRecords, eventCopyCompleted, err = handleVStreamEvent(tableName, tc, event, fieldsByTable, onResult, onUpdate) + if err != nil { + return syncStartCursor, err } + recordsSinceCheckpoint += eventRecords + copyCompleted = copyCompleted || eventCopyCompleted } - if onUpdate != nil { - for _, update := range res.Updates { - updatedRow := &UpdatedRow{ - Before: serializeQueryResult(update.Before), - After: serializeQueryResult(update.After), - } - if err := onUpdate(updatedRow); err != nil { - return syncStartCursor, status.Error(codes.Internal, "unable to serialize update") - } - recordsSinceCheckpoint++ + // A single VGTID can appear in multiple ordered responses. Once we reach + // the desired stop VGTID, keep reading while the cursor stays there so all + // rows for that position are processed, then stop when a newer VGTID arrives. + watchForVgGtidChange = watchForVgGtidChange || tc.Position == stopPosition + + if shouldCheckpointCursor(tc, recordsSinceCheckpoint, lastCheckpoint) && onCursor != nil { + if err := onCursor(tc); err != nil { + return tc, status.Error(codes.Internal, "unable to serialize cursor") } + recordsSinceCheckpoint = 0 + lastCheckpoint = time.Now() } - if res.Cursor != nil { - tc = res.Cursor - if shouldCheckpointCursor(tc, recordsSinceCheckpoint, lastCheckpoint) && onCursor != nil { + if copyCompleted || (watchForVgGtidChange && tc.LastKnownPk == nil && tc.Position != stopPosition) { + if onCursor != nil { if err := onCursor(tc); err != nil { return tc, status.Error(codes.Internal, "unable to serialize cursor") } - recordsSinceCheckpoint = 0 - lastCheckpoint = time.Now() } + return tc, io.EOF } + } +} - // A single VGTID can appear in multiple ordered responses. Once we reach - // the desired stop VGTID, keep reading while the cursor stays there so all - // rows for that position are processed, then stop when a newer VGTID arrives. - watchForVgGtidChange = watchForVgGtidChange || tc.Position == stopPosition +func (p connectClient) newVStreamClient(ctx context.Context, ps PlanetScaleSource) (vstreamClient, func(), error) { + if p.vstreamClientFn != nil { + client, err := p.vstreamClientFn(ctx, ps) + return client, func() {}, err + } + if p.clientFn != nil { + client, err := p.clientFn(ctx, ps) + if err != nil { + return nil, func() {}, err + } + return connectSyncCompatVStreamClient{client: client}, func() {}, nil + } - if watchForVgGtidChange && tc.Position != stopPosition { - if onCursor != nil { - if err := onCursor(tc); err != nil { - return tc, status.Error(codes.Internal, "unable to serialize cursor") + conn, err := grpcclient.Dial( + ctx, ps.Host, + clientoptions.WithDefaultTLSConfig(), + clientoptions.WithCompression(true), + clientoptions.WithConnectionPool(1), + clientoptions.WithExtraCallOption( + auth.NewBasicAuth(ps.Username, ps.Password).CallOption(), + ), + ) + if err != nil { + return nil, func() {}, err + } + return vtgateservicepb.NewVitessClient(conn), func() { _ = conn.Close() }, nil +} + +type connectSyncCompatVStreamClient struct { + client psdbconnect.ConnectClient +} + +func (c connectSyncCompatVStreamClient) VStream(ctx context.Context, in *vtgatepb.VStreamRequest, opts ...grpc.CallOption) (vtgateservicepb.Vitess_VStreamClient, error) { + syncReq := syncRequestFromVStreamRequest(in) + stream, err := c.client.Sync(ctx, syncReq, opts...) + if err != nil { + return nil, err + } + return &connectSyncCompatVStream{ + tableName: syncReq.TableName, + stream: stream, + }, nil +} + +type connectSyncCompatVStream struct { + tableName string + stream psdbconnect.Connect_SyncClient + grpc.ClientStream +} + +func (s *connectSyncCompatVStream) Recv() (*vtgatepb.VStreamResponse, error) { + res, err := s.stream.Recv() + if err != nil { + return nil, err + } + return vstreamResponseFromSyncResponse(s.tableName, res), nil +} + +func syncRequestFromVStreamRequest(req *vtgatepb.VStreamRequest) *psdbconnect.SyncRequest { + tableName := tableNameFromVStreamRequest(req) + return &psdbconnect.SyncRequest{ + TableName: tableName, + Cursor: tableCursorFromVStreamRequest(req, tableName), + TabletType: fromVStreamTabletType(req.GetTabletType()), + IncludeUpdates: true, + IncludeInserts: true, + IncludeDeletes: true, + Columns: columnsFromVStreamRequest(req), + Cells: cellsFromVStreamRequest(req), + } +} + +func tableNameFromVStreamRequest(req *vtgatepb.VStreamRequest) string { + for _, rule := range req.GetFilter().GetRules() { + if rule.GetMatch() != "" { + return rule.GetMatch() + } + } + return "" +} + +func tableCursorFromVStreamRequest(req *vtgatepb.VStreamRequest, tableName string) *psdbconnect.TableCursor { + cursor := &psdbconnect.TableCursor{} + shardGtids := req.GetVgtid().GetShardGtids() + if len(shardGtids) == 0 { + return cursor + } + shardGtid := shardGtids[0] + cursor.Shard = shardGtid.Shard + cursor.Keyspace = shardGtid.Keyspace + cursor.Position = shardGtid.Gtid + cursor.LastKnownPk = lastKnownPKForTable(shardGtid.TablePKs, tableName) + return cursor +} + +func columnsFromVStreamRequest(req *vtgatepb.VStreamRequest) []string { + for _, rule := range req.GetFilter().GetRules() { + columns := columnsFromVStreamFilter(rule.GetFilter()) + if columns != nil { + return columns + } + } + return nil +} + +func columnsFromVStreamFilter(filter string) []string { + upper := strings.ToUpper(filter) + fromIdx := strings.LastIndex(upper, " FROM ") + if !strings.HasPrefix(upper, "SELECT ") || fromIdx < len("SELECT ") { + return nil + } + columnList := strings.TrimSpace(filter[len("SELECT "):fromIdx]) + if columnList == "" || columnList == "*" { + return nil + } + parts := strings.Split(columnList, ",") + columns := make([]string, 0, len(parts)) + for _, part := range parts { + column := unquoteVStreamIdentifier(strings.TrimSpace(part)) + if column != "" { + columns = append(columns, column) + } + } + if len(columns) == 0 { + return nil + } + return columns +} + +func unquoteVStreamIdentifier(identifier string) string { + if len(identifier) >= 2 && identifier[0] == '`' && identifier[len(identifier)-1] == '`' { + return strings.ReplaceAll(identifier[1:len(identifier)-1], "``", "`") + } + return identifier +} + +func cellsFromVStreamRequest(req *vtgatepb.VStreamRequest) []string { + cells := req.GetFlags().GetCells() + if cells == "" { + return nil + } + return strings.Split(cells, ",") +} + +func vstreamResponseFromSyncResponse(tableName string, res *psdbconnect.SyncResponse) *vtgatepb.VStreamResponse { + events := []*binlogdatapb.VEvent{} + for _, result := range res.GetResult() { + events = append(events, vstreamRowEventsFromQueryResult(tableName, result, func(row *query.Row) *binlogdatapb.RowChange { + return &binlogdatapb.RowChange{After: row} + })...) + } + for _, deleted := range res.GetDeletes() { + events = append(events, vstreamRowEventsFromQueryResult(tableName, deleted.GetResult(), func(row *query.Row) *binlogdatapb.RowChange { + return &binlogdatapb.RowChange{Before: row} + })...) + } + for _, updated := range res.GetUpdates() { + events = append(events, vstreamRowEventsFromUpdate(tableName, updated)...) + } + if res.GetCursor() != nil { + events = append(events, vstreamVGtidEventFromCursor(tableName, res.GetCursor())) + } + return &vtgatepb.VStreamResponse{Events: events} +} + +func vstreamRowEventsFromQueryResult(tableName string, result *query.QueryResult, rowChange func(*query.Row) *binlogdatapb.RowChange) []*binlogdatapb.VEvent { + if result == nil { + return nil + } + events := []*binlogdatapb.VEvent{} + if len(result.Fields) > 0 { + events = append(events, &binlogdatapb.VEvent{ + Type: binlogdatapb.VEventType_FIELD, + FieldEvent: &binlogdatapb.FieldEvent{ + TableName: tableName, + Fields: result.Fields, + }, + }) + } + changes := make([]*binlogdatapb.RowChange, 0, len(result.Rows)) + for _, row := range result.Rows { + changes = append(changes, rowChange(row)) + } + if len(changes) > 0 { + events = append(events, &binlogdatapb.VEvent{ + Type: binlogdatapb.VEventType_ROW, + RowEvent: &binlogdatapb.RowEvent{ + TableName: tableName, + RowChanges: changes, + }, + }) + } + return events +} + +func vstreamRowEventsFromUpdate(tableName string, updated *psdbconnect.UpdatedRow) []*binlogdatapb.VEvent { + if updated == nil || (updated.Before == nil && updated.After == nil) { + return nil + } + fields := updated.GetAfter().GetFields() + if len(fields) == 0 { + fields = updated.GetBefore().GetFields() + } + events := []*binlogdatapb.VEvent{} + if len(fields) > 0 { + events = append(events, &binlogdatapb.VEvent{ + Type: binlogdatapb.VEventType_FIELD, + FieldEvent: &binlogdatapb.FieldEvent{ + TableName: tableName, + Fields: fields, + }, + }) + } + + beforeRows := updated.GetBefore().GetRows() + afterRows := updated.GetAfter().GetRows() + rowCount := len(beforeRows) + if len(afterRows) > rowCount { + rowCount = len(afterRows) + } + changes := make([]*binlogdatapb.RowChange, 0, rowCount) + for i := 0; i < rowCount; i++ { + change := &binlogdatapb.RowChange{} + if i < len(beforeRows) { + change.Before = beforeRows[i] + } + if i < len(afterRows) { + change.After = afterRows[i] + } + changes = append(changes, change) + } + if len(changes) > 0 { + events = append(events, &binlogdatapb.VEvent{ + Type: binlogdatapb.VEventType_ROW, + RowEvent: &binlogdatapb.RowEvent{ + TableName: tableName, + RowChanges: changes, + }, + }) + } + return events +} + +func vstreamVGtidEventFromCursor(tableName string, cursor *psdbconnect.TableCursor) *binlogdatapb.VEvent { + shardGtid := &binlogdatapb.ShardGtid{ + Keyspace: cursor.Keyspace, + Shard: cursor.Shard, + Gtid: cursor.Position, + } + if cursor.LastKnownPk != nil { + shardGtid.TablePKs = []*binlogdatapb.TableLastPK{{ + TableName: tableName, + Lastpk: cursor.LastKnownPk, + }} + } + return &binlogdatapb.VEvent{ + Type: binlogdatapb.VEventType_VGTID, + Vgtid: &binlogdatapb.VGtid{ + ShardGtids: []*binlogdatapb.ShardGtid{shardGtid}, + }, + } +} + +func buildVStreamRequest(tableName string, columns []string, tc *psdbconnect.TableCursor, tabletType psdbconnect.TabletType) *vtgatepb.VStreamRequest { + shardGtid := &binlogdatapb.ShardGtid{ + Keyspace: tc.Keyspace, + Shard: tc.Shard, + Gtid: tc.Position, + } + if tc.LastKnownPk != nil { + shardGtid.TablePKs = []*binlogdatapb.TableLastPK{{ + TableName: tableName, + Lastpk: tc.LastKnownPk, + }} + } + + return &vtgatepb.VStreamRequest{ + TabletType: toVStreamTabletType(tabletType), + Vgtid: &binlogdatapb.VGtid{ + ShardGtids: []*binlogdatapb.ShardGtid{shardGtid}, + }, + Filter: &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{{ + Match: tableName, + Filter: "SELECT " + joinVStreamColumns(columns) + " FROM " + quoteVStreamIdentifier(tableName), + }}, + }, + Flags: &vtgatepb.VStreamFlags{ + MinimizeSkew: true, + Cells: "planetscale_operator_default", + }, + } +} + +func toVStreamTabletType(tabletType psdbconnect.TabletType) topodatapb.TabletType { + switch tabletType { + case psdbconnect.TabletType_replica: + return topodatapb.TabletType_REPLICA + case psdbconnect.TabletType_batch: + return topodatapb.TabletType_RDONLY + default: + return topodatapb.TabletType_PRIMARY + } +} + +func fromVStreamTabletType(tabletType topodatapb.TabletType) psdbconnect.TabletType { + switch tabletType { + case topodatapb.TabletType_REPLICA: + return psdbconnect.TabletType_replica + case topodatapb.TabletType_RDONLY: + return psdbconnect.TabletType_batch + default: + return psdbconnect.TabletType_primary + } +} + +func joinVStreamColumns(columns []string) string { + if len(columns) == 0 { + return "*" + } + quoted := make([]string, 0, len(columns)) + for _, column := range columns { + column = strings.TrimSpace(column) + if column == "" { + continue + } + quoted = append(quoted, quoteVStreamIdentifier(column)) + } + if len(quoted) == 0 { + return "*" + } + return strings.Join(quoted, ",") +} + +func quoteVStreamIdentifier(identifier string) string { + return "`" + strings.ReplaceAll(identifier, "`", "``") + "`" +} + +func handleVStreamEvent(tableName string, cursor *psdbconnect.TableCursor, event *binlogdatapb.VEvent, fieldsByTable map[string][]*query.Field, onResult OnResult, onUpdate OnUpdate) (*psdbconnect.TableCursor, int, bool, error) { + switch event.Type { + case binlogdatapb.VEventType_FIELD: + if event.FieldEvent != nil { + cacheVStreamFields(fieldsByTable, event.FieldEvent.TableName, event.FieldEvent.Fields) + } + case binlogdatapb.VEventType_ROW: + count, err := handleVStreamRows(tableName, event.RowEvent, fieldsByTable, onResult, onUpdate) + if err != nil { + return cursor, 0, false, err + } + return cursor, count, false, nil + case binlogdatapb.VEventType_VGTID: + return tableCursorFromVGtid(cursor, event.Vgtid, tableName, fieldsByTable), 0, false, nil + case binlogdatapb.VEventType_LASTPK: + return tableCursorFromLastPK(cursor, event.LastPKEvent, tableName, fieldsByTable), 0, false, nil + case binlogdatapb.VEventType_COPY_COMPLETED: + return cursor, 0, true, nil + } + return cursor, 0, false, nil +} + +func handleVStreamRows(tableName string, rowEvent *binlogdatapb.RowEvent, fieldsByTable map[string][]*query.Field, onResult OnResult, onUpdate OnUpdate) (int, error) { + if rowEvent == nil || normalizeVStreamTableName(rowEvent.TableName) != tableName { + return 0, nil + } + fields := vstreamFieldsForTable(fieldsByTable, rowEvent.TableName) + if len(fields) == 0 { + return 0, status.Error(codes.Internal, fmt.Sprintf("missing VStream fields for table %s", rowEvent.TableName)) + } + + records := 0 + for _, change := range rowEvent.RowChanges { + switch { + case change.After != nil && change.Before == nil: + if onResult != nil { + if err := onResult(vstreamRowResult(fields, change.After), OpType_Insert); err != nil { + return records, status.Error(codes.Internal, "unable to serialize row") } } - return tc, io.EOF + records++ + case change.After == nil && change.Before != nil: + if onResult != nil { + if err := onResult(vstreamRowResult(fields, change.Before), OpType_Delete); err != nil { + return records, status.Error(codes.Internal, "unable to serialize row") + } + } + records++ + case change.After != nil && change.Before != nil: + if onUpdate != nil { + if err := onUpdate(&UpdatedRow{ + Before: vstreamRowResult(fields, change.Before), + After: vstreamRowResult(fields, change.After), + }); err != nil { + return records, status.Error(codes.Internal, "unable to serialize update") + } + } + records++ + } + } + return records, nil +} + +func vstreamRowResult(fields []*query.Field, row *query.Row) *sqltypes.Result { + return sqltypes.Proto3ToResult(&query.QueryResult{ + Fields: fields, + Rows: []*query.Row{row}, + }) +} + +func cacheVStreamFields(fieldsByTable map[string][]*query.Field, tableName string, fields []*query.Field) { + fieldsByTable[tableName] = fields + fieldsByTable[normalizeVStreamTableName(tableName)] = fields +} + +func vstreamFieldsForTable(fieldsByTable map[string][]*query.Field, tableName string) []*query.Field { + if fields := fieldsByTable[tableName]; len(fields) > 0 { + return fields + } + return fieldsByTable[normalizeVStreamTableName(tableName)] +} + +func normalizeVStreamTableName(tableName string) string { + if idx := strings.LastIndex(tableName, "."); idx >= 0 { + return tableName[idx+1:] + } + return tableName +} + +func tableCursorFromVGtid(cursor *psdbconnect.TableCursor, vgtid *binlogdatapb.VGtid, tableName string, fieldsByTable map[string][]*query.Field) *psdbconnect.TableCursor { + if vgtid == nil { + return cursor + } + for _, shardGtid := range vgtid.ShardGtids { + if shardGtid.Keyspace != cursor.Keyspace || shardGtid.Shard != cursor.Shard { + continue + } + next := cloneTableCursor(cursor) + if shardGtid.Gtid != "" { + next.Position = shardGtid.Gtid + } + next.LastKnownPk = completeLastKnownPKFields( + lastKnownPKForTable(shardGtid.TablePKs, tableName), + vstreamFieldsForTable(fieldsByTable, tableName), + ) + return next + } + return cursor +} + +func tableCursorFromLastPK(cursor *psdbconnect.TableCursor, event *binlogdatapb.LastPKEvent, tableName string, fieldsByTable map[string][]*query.Field) *psdbconnect.TableCursor { + if event == nil || event.TableLastPK == nil || normalizeVStreamTableName(event.TableLastPK.TableName) != tableName { + return cursor + } + next := cloneTableCursor(cursor) + if event.Completed { + next.LastKnownPk = nil + return next + } + next.LastKnownPk = completeLastKnownPKFields(event.TableLastPK.Lastpk, vstreamFieldsForTable(fieldsByTable, event.TableLastPK.TableName)) + return next +} + +func lastKnownPKForTable(tablePKs []*binlogdatapb.TableLastPK, tableName string) *query.QueryResult { + for _, tablePK := range tablePKs { + if tablePK != nil && normalizeVStreamTableName(tablePK.TableName) == tableName { + return tablePK.Lastpk } } + return nil +} + +func completeLastKnownPKFields(lastPK *query.QueryResult, fields []*query.Field) *query.QueryResult { + if lastPK == nil || len(lastPK.Fields) > 0 || len(fields) == 0 || len(lastPK.Rows) == 0 { + return lastPK + } + fieldCount := len(lastPK.Rows[0].Lengths) + if fieldCount > len(fields) { + return lastPK + } + next := proto.Clone(lastPK).(*query.QueryResult) + next.Fields = fields[:fieldCount] + return next } func tableCursorHasProgress(tc *psdbconnect.TableCursor) bool { @@ -447,61 +882,27 @@ func (p connectClient) filterExistingColumns(ctx context.Context, ps PlanetScale return existingColumns, err } -func serializeQueryResult(result *query.QueryResult) *sqltypes.Result { - qr := sqltypes.Proto3ToResult(result) - var sqlResult *sqltypes.Result - for _, row := range qr.Rows { - sqlResult = &sqltypes.Result{ - Fields: result.Fields, - } - sqlResult.Rows = append(sqlResult.Rows, row) - } - return sqlResult -} - func (p connectClient) getLatestCursorPosition(ctx context.Context, shard, keyspace string, tableName string, ps PlanetScaleSource, tabletType psdbconnect.TabletType) (string, error) { timeout := 45 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() var ( err error - client psdbconnect.ConnectClient + client vstreamClient + close func() ) - if p.clientFn == nil { - conn, err := grpcclient.Dial( - ctx, ps.Host, - clientoptions.WithDefaultTLSConfig(), - clientoptions.WithCompression(true), - clientoptions.WithConnectionPool(1), - clientoptions.WithExtraCallOption( - auth.NewBasicAuth(ps.Username, ps.Password).CallOption(), - ), - ) - if err != nil { - return "", err - } - defer conn.Close() - client = psdbconnect.NewConnectClient(conn) - } else { - client, err = p.clientFn(ctx, ps) - if err != nil { - return "", err - } - } - - sReq := &psdbconnect.SyncRequest{ - TableName: tableName, - Cursor: &psdbconnect.TableCursor{ - Shard: shard, - Keyspace: keyspace, - Position: "current", - }, - TabletType: tabletType, - Cells: []string{"planetscale_operator_default"}, + client, close, err = p.newVStreamClient(ctx, ps) + if err != nil { + return "", err } + defer close() - c, err := client.Sync(ctx, sReq) + c, err := client.VStream(ctx, buildVStreamRequest(tableName, nil, &psdbconnect.TableCursor{ + Shard: shard, + Keyspace: keyspace, + Position: "current", + }, tabletType)) if err != nil { return "", err } @@ -512,10 +913,25 @@ func (p connectClient) getLatestCursorPosition(ctx context.Context, shard, keysp return "", err } - if res.Cursor != nil { - return res.Cursor.Position, nil + position := vgtidPositionForShard(res.GetEvents(), keyspace, shard) + if position != "" { + return position, nil + } + } +} + +func vgtidPositionForShard(events []*binlogdatapb.VEvent, keyspace, shard string) string { + for _, event := range events { + if event.GetType() != binlogdatapb.VEventType_VGTID { + continue + } + for _, shardGtid := range event.GetVgtid().GetShardGtids() { + if shardGtid.Keyspace == keyspace && shardGtid.Shard == shard && shardGtid.Gtid != "" { + return shardGtid.Gtid + } } } + return "" } func IsBinlogsExpirationError(err error) bool { diff --git a/lib/connect_client_test.go b/lib/connect_client_test.go index 218fd76..71595f7 100644 --- a/lib/connect_client_test.go +++ b/lib/connect_client_test.go @@ -7,7 +7,11 @@ import ( "testing" "time" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" "vitess.io/vitess/go/vt/proto/query" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + vtgateservicepb "vitess.io/vitess/go/vt/proto/vtgateservice" "github.com/pkg/errors" psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1" @@ -1103,6 +1107,108 @@ func TestRead_BinlogExpirationReturnsResetCursor(t *testing.T) { assert.Equal(t, 3, cc.syncFnInvokedCount) } +func TestSync_DirectVStreamHandlesRawEventsAndCopyCompleted(t *testing.T) { + originalCheckpointRows := cursorCheckpointRows + originalCheckpointInterval := cursorCheckpointInterval + cursorCheckpointRows = 1 + cursorCheckpointInterval = time.Hour + t.Cleanup(func() { + cursorCheckpointRows = originalCheckpointRows + cursorCheckpointInterval = originalCheckpointInterval + }) + + dbl := &dbLogger{} + ped := connectClient{} + testFields := sqltypes.MakeTestFields("id|name", "int64|varbinary") + rows := sqltypes.ResultToProto3(sqltypes.MakeTestResult(testFields, "1|first", "2|second")).Rows + lastPK := testLastKnownPK("2") + lastPK.Fields = nil + startCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + } + copyCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "COPY_GTID", + LastKnownPk: lastPK, + } + doneCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "AFTER_COPY_GTID", + } + + rawStream := &vstreamClientMock{ + responses: []*vtgatepb.VStreamResponse{ + { + Events: []*binlogdatapb.VEvent{ + { + Type: binlogdatapb.VEventType_FIELD, + FieldEvent: &binlogdatapb.FieldEvent{ + TableName: "connect-test.customers", + Fields: testFields, + }, + }, + { + Type: binlogdatapb.VEventType_ROW, + RowEvent: &binlogdatapb.RowEvent{ + TableName: "connect-test.customers", + RowChanges: []*binlogdatapb.RowChange{ + {After: rows[0]}, + {After: rows[1]}, + }, + }, + }, + vstreamVGtidEventFromCursor("connect-test.customers", copyCursor), + }, + }, + { + Events: []*binlogdatapb.VEvent{ + {Type: binlogdatapb.VEventType_COPY_COMPLETED}, + vstreamVGtidEventFromCursor("connect-test.customers", doneCursor), + }, + }, + }, + } + vc := &vstreamConnectionMock{ + vstreamFn: func(ctx context.Context, in *vtgatepb.VStreamRequest, opts ...grpc.CallOption) (vtgateservicepb.Vitess_VStreamClient, error) { + assert.Equal(t, topodatapb.TabletType_PRIMARY, in.TabletType) + assert.Equal(t, "connect-test", in.Vgtid.ShardGtids[0].Keyspace) + assert.Equal(t, "-", in.Vgtid.ShardGtids[0].Shard) + assert.Empty(t, in.Vgtid.ShardGtids[0].Gtid) + assert.Equal(t, "customers", in.Filter.Rules[0].Match) + assert.Equal(t, "SELECT `id`,`name` FROM `customers`", in.Filter.Rules[0].Filter) + assert.Equal(t, "planetscale_operator_default", in.Flags.Cells) + assert.True(t, in.Flags.MinimizeSkew) + return rawStream, nil + }, + } + ped.vstreamClientFn = func(ctx context.Context, ps PlanetScaleSource) (vstreamClient, error) { + return vc, nil + } + + records := 0 + checkpoints := []*psdbconnect.TableCursor{} + returnedCursor, err := ped.sync(context.Background(), dbl, "customers", []string{"id", "name"}, startCursor, "STOP_GTID", PlanetScaleSource{Database: "connect-test"}, psdbconnect.TabletType_primary, time.Second, func(*sqltypes.Result, Operation) error { + records++ + return nil + }, func(cursor *psdbconnect.TableCursor) error { + checkpoints = append(checkpoints, cloneTableCursor(cursor)) + return nil + }, nil) + + assert.True(t, errors.Is(err, io.EOF)) + assert.True(t, proto.Equal(doneCursor, returnedCursor)) + assert.Equal(t, 2, records) + assert.Equal(t, 1, vc.vstreamFnInvokedCount) + if assert.Len(t, checkpoints, 2) { + assert.True(t, proto.Equal(doneCursor, checkpoints[1])) + assert.NotNil(t, checkpoints[0].LastKnownPk) + assert.Equal(t, testFields[:1], checkpoints[0].LastKnownPk.Fields) + } +} + func TestSync_CheckpointsHistoricalCopyProgress(t *testing.T) { originalCheckpointRows := cursorCheckpointRows originalCheckpointInterval := cursorCheckpointInterval diff --git a/lib/test_types.go b/lib/test_types.go index 6cefc70..06b5355 100644 --- a/lib/test_types.go +++ b/lib/test_types.go @@ -8,6 +8,8 @@ import ( psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1" "google.golang.org/grpc" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + vtgateservicepb "vitess.io/vitess/go/vt/proto/vtgateservice" ) type dbLogMessage struct { @@ -63,6 +65,36 @@ func (c *clientConnectionMock) Sync(ctx context.Context, in *psdbconnect.SyncReq return c.syncFn(ctx, in, opts...) } +type vstreamConnectionMock struct { + vstreamFn func(ctx context.Context, in *vtgatepb.VStreamRequest, opts ...grpc.CallOption) (vtgateservicepb.Vitess_VStreamClient, error) + vstreamFnInvoked bool + vstreamFnInvokedCount int +} + +func (c *vstreamConnectionMock) VStream(ctx context.Context, in *vtgatepb.VStreamRequest, opts ...grpc.CallOption) (vtgateservicepb.Vitess_VStreamClient, error) { + c.vstreamFnInvoked = true + c.vstreamFnInvokedCount += 1 + return c.vstreamFn(ctx, in, opts...) +} + +type vstreamClientMock struct { + lastResponseSent int + responses []*vtgatepb.VStreamResponse + recvFn func() (*vtgatepb.VStreamResponse, error) + grpc.ClientStream +} + +func (x *vstreamClientMock) Recv() (*vtgatepb.VStreamResponse, error) { + if x.recvFn != nil { + return x.recvFn() + } + if x.lastResponseSent >= len(x.responses) { + return nil, io.EOF + } + x.lastResponseSent += 1 + return x.responses[x.lastResponseSent-1], nil +} + type ( BuildSchemaFunc func(ctx context.Context, psc PlanetScaleSource, schemaBuilder SchemaBuilder) error PingContextFunc func(context.Context, PlanetScaleSource) error From 4556bb42de57d850e6832f9512d0ee7ea162709c Mon Sep 17 00:00:00 2001 From: Nick Van Wiggeren Date: Tue, 16 Jun 2026 23:34:02 +0000 Subject: [PATCH 7/7] Reject fieldless VStream copy cursors --- lib/connect_client.go | 49 +++++++++++++++++++------------------- lib/connect_client_test.go | 47 +++++++++++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/lib/connect_client.go b/lib/connect_client.go index 973cb16..6190e12 100644 --- a/lib/connect_client.go +++ b/lib/connect_client.go @@ -695,9 +695,11 @@ func handleVStreamEvent(tableName string, cursor *psdbconnect.TableCursor, event } return cursor, count, false, nil case binlogdatapb.VEventType_VGTID: - return tableCursorFromVGtid(cursor, event.Vgtid, tableName, fieldsByTable), 0, false, nil + next, err := tableCursorFromVGtid(cursor, event.Vgtid, tableName) + return next, 0, false, err case binlogdatapb.VEventType_LASTPK: - return tableCursorFromLastPK(cursor, event.LastPKEvent, tableName, fieldsByTable), 0, false, nil + next, err := tableCursorFromLastPK(cursor, event.LastPKEvent, tableName) + return next, 0, false, err case binlogdatapb.VEventType_COPY_COMPLETED: return cursor, 0, true, nil } @@ -771,9 +773,9 @@ func normalizeVStreamTableName(tableName string) string { return tableName } -func tableCursorFromVGtid(cursor *psdbconnect.TableCursor, vgtid *binlogdatapb.VGtid, tableName string, fieldsByTable map[string][]*query.Field) *psdbconnect.TableCursor { +func tableCursorFromVGtid(cursor *psdbconnect.TableCursor, vgtid *binlogdatapb.VGtid, tableName string) (*psdbconnect.TableCursor, error) { if vgtid == nil { - return cursor + return cursor, nil } for _, shardGtid := range vgtid.ShardGtids { if shardGtid.Keyspace != cursor.Keyspace || shardGtid.Shard != cursor.Shard { @@ -783,26 +785,29 @@ func tableCursorFromVGtid(cursor *psdbconnect.TableCursor, vgtid *binlogdatapb.V if shardGtid.Gtid != "" { next.Position = shardGtid.Gtid } - next.LastKnownPk = completeLastKnownPKFields( - lastKnownPKForTable(shardGtid.TablePKs, tableName), - vstreamFieldsForTable(fieldsByTable, tableName), - ) - return next + next.LastKnownPk = lastKnownPKForTable(shardGtid.TablePKs, tableName) + if err := validateLastKnownPKFields(next.LastKnownPk, tableName); err != nil { + return cursor, err + } + return next, nil } - return cursor + return cursor, nil } -func tableCursorFromLastPK(cursor *psdbconnect.TableCursor, event *binlogdatapb.LastPKEvent, tableName string, fieldsByTable map[string][]*query.Field) *psdbconnect.TableCursor { +func tableCursorFromLastPK(cursor *psdbconnect.TableCursor, event *binlogdatapb.LastPKEvent, tableName string) (*psdbconnect.TableCursor, error) { if event == nil || event.TableLastPK == nil || normalizeVStreamTableName(event.TableLastPK.TableName) != tableName { - return cursor + return cursor, nil } next := cloneTableCursor(cursor) if event.Completed { next.LastKnownPk = nil - return next + return next, nil } - next.LastKnownPk = completeLastKnownPKFields(event.TableLastPK.Lastpk, vstreamFieldsForTable(fieldsByTable, event.TableLastPK.TableName)) - return next + next.LastKnownPk = event.TableLastPK.Lastpk + if err := validateLastKnownPKFields(next.LastKnownPk, tableName); err != nil { + return cursor, err + } + return next, nil } func lastKnownPKForTable(tablePKs []*binlogdatapb.TableLastPK, tableName string) *query.QueryResult { @@ -814,17 +819,11 @@ func lastKnownPKForTable(tablePKs []*binlogdatapb.TableLastPK, tableName string) return nil } -func completeLastKnownPKFields(lastPK *query.QueryResult, fields []*query.Field) *query.QueryResult { - if lastPK == nil || len(lastPK.Fields) > 0 || len(fields) == 0 || len(lastPK.Rows) == 0 { - return lastPK - } - fieldCount := len(lastPK.Rows[0].Lengths) - if fieldCount > len(fields) { - return lastPK +func validateLastKnownPKFields(lastPK *query.QueryResult, tableName string) error { + if lastPK == nil || len(lastPK.Fields) > 0 { + return nil } - next := proto.Clone(lastPK).(*query.QueryResult) - next.Fields = fields[:fieldCount] - return next + return status.Error(codes.Internal, fmt.Sprintf("VStream copy cursor for table %s is missing LastKnownPk field metadata", tableName)) } func tableCursorHasProgress(tc *psdbconnect.TableCursor) bool { diff --git a/lib/connect_client_test.go b/lib/connect_client_test.go index 71595f7..e6a5252 100644 --- a/lib/connect_client_test.go +++ b/lib/connect_client_test.go @@ -1122,7 +1122,6 @@ func TestSync_DirectVStreamHandlesRawEventsAndCopyCompleted(t *testing.T) { testFields := sqltypes.MakeTestFields("id|name", "int64|varbinary") rows := sqltypes.ResultToProto3(sqltypes.MakeTestResult(testFields, "1|first", "2|second")).Rows lastPK := testLastKnownPK("2") - lastPK.Fields = nil startCursor := &psdbconnect.TableCursor{ Shard: "-", Keyspace: "connect-test", @@ -1203,12 +1202,54 @@ func TestSync_DirectVStreamHandlesRawEventsAndCopyCompleted(t *testing.T) { assert.Equal(t, 2, records) assert.Equal(t, 1, vc.vstreamFnInvokedCount) if assert.Len(t, checkpoints, 2) { + assert.True(t, proto.Equal(copyCursor, checkpoints[0])) assert.True(t, proto.Equal(doneCursor, checkpoints[1])) - assert.NotNil(t, checkpoints[0].LastKnownPk) - assert.Equal(t, testFields[:1], checkpoints[0].LastKnownPk.Fields) } } +func TestSync_DirectVStreamRejectsFieldlessLastKnownPK(t *testing.T) { + dbl := &dbLogger{} + ped := connectClient{} + lastPK := testLastKnownPK("2") + lastPK.Fields = nil + startCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + } + copyCursor := &psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "connect-test", + Position: "COPY_GTID", + LastKnownPk: lastPK, + } + + rawStream := &vstreamClientMock{ + responses: []*vtgatepb.VStreamResponse{ + { + Events: []*binlogdatapb.VEvent{ + vstreamVGtidEventFromCursor("customers", copyCursor), + }, + }, + }, + } + vc := &vstreamConnectionMock{ + vstreamFn: func(ctx context.Context, in *vtgatepb.VStreamRequest, opts ...grpc.CallOption) (vtgateservicepb.Vitess_VStreamClient, error) { + return rawStream, nil + }, + } + ped.vstreamClientFn = func(ctx context.Context, ps PlanetScaleSource) (vstreamClient, error) { + return vc, nil + } + + returnedCursor, err := ped.sync(context.Background(), dbl, "customers", []string{"id", "name"}, startCursor, "STOP_GTID", PlanetScaleSource{Database: "connect-test"}, psdbconnect.TabletType_primary, time.Second, nil, nil, nil) + + assert.Error(t, err) + assert.Equal(t, codes.Internal, status.Code(err)) + assert.ErrorContains(t, err, "missing LastKnownPk field metadata") + assert.True(t, proto.Equal(startCursor, returnedCursor)) + assert.Equal(t, 1, vc.vstreamFnInvokedCount) +} + func TestSync_CheckpointsHistoricalCopyProgress(t *testing.T) { originalCheckpointRows := cursorCheckpointRows originalCheckpointInterval := cursorCheckpointInterval