Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 283 additions & 4 deletions cmd/internal/server/handlers/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"vitess.io/vitess/go/sqltypes"
)

func TestIntegrationBasicInsertUpdateDelete(t *testing.T) {
Expand Down Expand Up @@ -354,6 +355,73 @@ func TestIntegrationIntermittentEmptySyncs(t *testing.T) {
assertIntegrationRecordCount(t, idle.recordsForTable(tableName), 0)
}

func TestIntegrationInterruptedCopyResume(t *testing.T) {
restoreCheckpointPolicy := lib.SetCursorCheckpointPolicyForIntegrationTest(1, 0)
defer restoreCheckpointPolicy()

psc := loadIntegrationSource(t)
ctx := context.Background()
tableName := integrationTableName(t)

db := openIntegrationSQL(t, psc)
t.Cleanup(func() {
if _, err := db.ExecContext(context.Background(), "drop table if exists "+quoteIdent(tableName)); err != nil {
t.Logf("drop integration table %s: %v", tableName, err)
}
_ = db.Close()
})

mustExec(t, ctx, db, fmt.Sprintf(`
create table %s (
id bigint primary key,
version int not null,
payload varchar(2048) not null,
active tinyint(1) not null
)`, quoteIdent(tableName)))

columns := []string{"id", "version", "payload", "active"}
expected := map[int64]map[string]any{}
const rowCount = int64(2500)
insertIntegrationResumeRows(t, ctx, db, tableName, 1, rowCount, 0, expected)

interruptCtx, cancelInterrupt := context.WithCancel(context.Background())
interruptSender := &interruptAfterCopyCheckpointSender{
t: t,
keyspaceName: psc.Database,
tableName: tableName,
cancel: cancelInterrupt,
}
_, err := runIntegrationSyncWithSender(t, interruptCtx, psc, tableName, columns, nil, interruptSender)
cancelInterrupt()
if err == nil {
t.Fatalf("expected interrupted copy sync to return an error")
}
if interruptSender.checkpointState == nil {
t.Fatalf("interrupted copy did not checkpoint a LastKnownPk cursor")
}
if interruptSender.checkpointRecordCount == 0 {
t.Fatalf("interrupted copy checkpointed before emitting table records")
}

checkpointedRecords := interruptSender.recordsForTable(tableName)[:interruptSender.checkpointRecordCount]
model := map[int64]map[string]any{}
applyIntegrationRecords(t, model, checkpointedRecords)

lastCopiedID := integrationLastKnownPKID(t, interruptSender.checkpointState, psc.Database, tableName)
if lastCopiedID < 1 || lastCopiedID > rowCount {
t.Fatalf("unexpected LastKnownPk id %d for row count %d", lastCopiedID, rowCount)
}
updateIntegrationResumeRow(t, ctx, db, tableName, lastCopiedID, 1, true, expected)

resumed, finalState := runIntegrationSync(t, psc, tableName, columns, interruptSender.checkpointState)
applyIntegrationRecords(t, model, resumed.recordsForTable(tableName))
assertIntegrationRowsExactly(t, model, expected)
assertIntegrationStateHasNoLastKnownPK(t, finalState, psc.Database, tableName)

idle, _ := runIntegrationSync(t, psc, tableName, columns, finalState)
assertIntegrationRecordCount(t, idle.recordsForTable(tableName), 0)
}

func TestIntegrationRepeatedSyncBursts(t *testing.T) {
psc := loadIntegrationSource(t)
ctx := context.Background()
Expand Down Expand Up @@ -684,6 +752,19 @@ func runIntegrationSync(t *testing.T, psc lib.PlanetScaleSource, tableName strin
func runIntegrationSyncWithError(t *testing.T, psc lib.PlanetScaleSource, tableName string, columns []string, state *lib.SyncState) (*integrationSender, *lib.SyncState, error) {
t.Helper()

sender := &integrationSender{}
state, err := runIntegrationSyncWithSender(t, context.Background(), psc, tableName, columns, state, sender)
return sender, state, err
}

type integrationStatefulSender interface {
Send(*fivetransdk.UpdateResponse) error
latestState(*testing.T) (*lib.SyncState, bool)
}

func runIntegrationSyncWithSender(t *testing.T, ctx context.Context, psc lib.PlanetScaleSource, tableName string, columns []string, state *lib.SyncState, sender integrationStatefulSender) (*lib.SyncState, error) {
t.Helper()

mysqlClient, err := lib.NewMySQL(&psc)
if err != nil {
t.Fatalf("create mysql client: %v", err)
Expand Down Expand Up @@ -721,16 +802,15 @@ func runIntegrationSyncWithError(t *testing.T, psc lib.PlanetScaleSource, tableN
}

selection := integrationSelection(psc.Database, tableName, columns)
sender := &integrationSender{}
logger := NewSchemaAwareSerializer(sender, "integration", psc.TreatTinyIntAsBoolean, sourceSchema.SchemaList, sourceSchema.EnumsAndSets)
syncer := &Sync{}
if err := syncer.Handle(&psc, &connectClient, logger, state, selection); err != nil {
return sender, state, err
if err := syncer.Handle(ctx, &psc, &connectClient, logger, state, selection); err != nil {
return state, err
}
if checkpoint, ok := sender.latestState(t); ok {
state = checkpoint
}
return sender, state, nil
return state, nil
}

func integrationSelection(schemaName, tableName string, columns []string) *fivetransdk.Selection_WithSchema {
Expand Down Expand Up @@ -783,6 +863,17 @@ func (s *integrationSender) latestState(t *testing.T) (*lib.SyncState, bool) {
return nil, false
}

func (s *integrationSender) recordCountForTable(tableName string) int {
count := 0
for _, response := range s.responses {
record := response.GetRecord()
if record != nil && record.TableName == tableName {
count++
}
}
return count
}

func (s *integrationSender) recordsForTable(tableName string) []*fivetransdk.Record {
records := []*fivetransdk.Record{}
for _, response := range s.responses {
Expand All @@ -795,6 +886,44 @@ func (s *integrationSender) recordsForTable(tableName string) []*fivetransdk.Rec
return records
}

type interruptAfterCopyCheckpointSender struct {
integrationSender

t *testing.T
keyspaceName string
tableName string
cancel context.CancelFunc
checkpointState *lib.SyncState
checkpointRecordCount int
}

func (s *interruptAfterCopyCheckpointSender) Send(response *fivetransdk.UpdateResponse) error {
if err := s.integrationSender.Send(response); err != nil {
return err
}
if s.checkpointState != nil {
return nil
}

checkpoint := response.GetCheckpoint()
if checkpoint == nil {
return nil
}

var state lib.SyncState
if err := json.Unmarshal([]byte(checkpoint.StateJson), &state); err != nil {
s.t.Fatalf("parse interrupted checkpoint state: %v", err)
}
if !integrationStateHasLastKnownPK(s.t, &state, s.keyspaceName, s.tableName) {
return nil
}

s.checkpointState = &state
s.checkpointRecordCount = s.integrationSender.recordCountForTable(s.tableName)
s.cancel()
return nil
}

func applyIntegrationRecords(t *testing.T, rows map[int64]map[string]any, records []*fivetransdk.Record) {
t.Helper()

Expand Down Expand Up @@ -904,6 +1033,27 @@ func assertIntegrationRows(t *testing.T, got, want map[int64]map[string]any) {
}
}

func assertIntegrationRowsExactly(t *testing.T, got, want map[int64]map[string]any) {
t.Helper()

if len(got) != len(want) {
t.Fatalf("unexpected row count: want %d, got %d", len(want), len(got))
}
for id, wantRow := range want {
if !reflect.DeepEqual(got[id], wantRow) {
gotJSON, _ := json.Marshal(got[id])
wantJSON, _ := json.Marshal(wantRow)
t.Fatalf("unexpected row %d\nwant: %s\ngot: %s", id, wantJSON, gotJSON)
}
}
for id := range got {
if _, ok := want[id]; !ok {
gotJSON, _ := json.Marshal(got[id])
t.Fatalf("unexpected extra row %d: %s", id, gotJSON)
}
}
}

func assertIntegrationStrings(t *testing.T, got, want []string) {
t.Helper()

Expand Down Expand Up @@ -983,6 +1133,16 @@ func assertIntegrationStateShards(t *testing.T, state *lib.SyncState, keyspaceNa
}
}

func assertIntegrationStateHasNoLastKnownPK(t *testing.T, state *lib.SyncState, keyspaceName, tableName string) {
t.Helper()

for shard, cursor := range integrationStateTableCursors(t, state, keyspaceName, tableName) {
if cursor.LastKnownPk != nil {
t.Fatalf("state still has LastKnownPk for shard %s in table %s", shard, tableName)
}
}
}

func assertIntegrationRowsOnEveryShard(t *testing.T, psc lib.PlanetScaleSource, tableName string, shards []string) {
t.Helper()

Expand All @@ -1007,6 +1167,68 @@ func assertIntegrationRowsOnEveryShard(t *testing.T, psc lib.PlanetScaleSource,
}
}

func integrationStateHasLastKnownPK(t *testing.T, state *lib.SyncState, keyspaceName, tableName string) bool {
t.Helper()

for _, cursor := range integrationStateTableCursors(t, state, keyspaceName, tableName) {
if cursor.LastKnownPk != nil {
return true
}
}
return false
}

func integrationLastKnownPKID(t *testing.T, state *lib.SyncState, keyspaceName, tableName string) int64 {
t.Helper()

for shard, cursor := range integrationStateTableCursors(t, state, keyspaceName, tableName) {
if cursor.LastKnownPk == nil {
continue
}
result := sqltypes.Proto3ToResult(cursor.LastKnownPk)
if len(result.Rows) != 1 || len(result.Rows[0]) != 1 {
t.Fatalf("unexpected LastKnownPk shape for shard %s in table %s: %+v", shard, tableName, cursor.LastKnownPk)
}
id, err := result.Rows[0][0].ToInt64()
if err != nil {
t.Fatalf("parse LastKnownPk id for shard %s in table %s: %v", shard, tableName, err)
}
return id
}
t.Fatalf("state missing LastKnownPk for table %s", tableName)
return 0
}

func integrationStateTableCursors(t *testing.T, state *lib.SyncState, keyspaceName, tableName string) map[string]*psdbconnect.TableCursor {
t.Helper()

if state == nil {
t.Fatalf("sync state is nil")
}
keyspace, ok := state.Keyspaces[keyspaceName]
if !ok {
t.Fatalf("state missing keyspace %s: %+v", keyspaceName, state.Keyspaces)
}
streamName := keyspaceName + ":" + tableName
stream, ok := keyspace.Streams[streamName]
if !ok {
t.Fatalf("state missing stream %s: %+v", streamName, keyspace.Streams)
}

cursors := map[string]*psdbconnect.TableCursor{}
for shard, serializedCursor := range stream.Shards {
if serializedCursor == nil {
t.Fatalf("state shard %s in stream %s has nil cursor", shard, streamName)
}
cursor, err := serializedCursor.SerializedCursorToTableCursor()
if err != nil {
t.Fatalf("deserialize cursor for shard %s in stream %s: %v", shard, streamName, err)
}
cursors[shard] = cursor
}
return cursors
}

func insertIntegrationLoadRows(t *testing.T, ctx context.Context, db *sql.DB, tableName string, startID, count int64, version int, rows map[int64]map[string]any) {
t.Helper()

Expand Down Expand Up @@ -1043,6 +1265,42 @@ func insertIntegrationLoadRows(t *testing.T, ctx context.Context, db *sql.DB, ta
}
}

func insertIntegrationResumeRows(t *testing.T, ctx context.Context, db *sql.DB, tableName string, startID, count int64, version int, rows map[int64]map[string]any) {
t.Helper()

const batchSize = int64(50)
for offset := int64(0); offset < count; {
n := batchSize
if remaining := count - offset; remaining < n {
n = remaining
}

placeholders := make([]string, 0, int(n))
args := make([]any, 0, int(n)*4)
for i := int64(0); i < n; i++ {
id := startID + offset + i
payload := integrationResumePayload(version, id)
active := id%2 == 0

placeholders = append(placeholders, "(?, ?, ?, ?)")
args = append(args, id, version, payload, integrationBoolInt(active))
rows[id] = map[string]any{
"id": id,
"version": int64(version),
"payload": payload,
"active": active,
}
}

mustExec(t, ctx, db, fmt.Sprintf(
"insert into %s (id, version, payload, active) values %s",
quoteIdent(tableName),
strings.Join(placeholders, ", "),
), args...)
offset += n
}
}

func updateIntegrationLoadRow(t *testing.T, ctx context.Context, db *sql.DB, tableName string, id int64, version int, active bool, rows map[int64]map[string]any) {
t.Helper()

Expand All @@ -1060,6 +1318,23 @@ func updateIntegrationLoadRow(t *testing.T, ctx context.Context, db *sql.DB, tab
}
}

func updateIntegrationResumeRow(t *testing.T, ctx context.Context, db *sql.DB, tableName string, id int64, version int, active bool, rows map[int64]map[string]any) {
t.Helper()

payload := integrationResumePayload(version, id)
mustExec(t, ctx, db, fmt.Sprintf(
"update %s set version = ?, payload = ?, active = ? where id = ?",
quoteIdent(tableName),
), version, payload, integrationBoolInt(active), id)

rows[id] = map[string]any{
"id": id,
"version": int64(version),
"payload": payload,
"active": active,
}
}

func deleteIntegrationLoadRow(t *testing.T, ctx context.Context, db *sql.DB, tableName string, id int64, rows map[int64]map[string]any) {
t.Helper()

Expand Down Expand Up @@ -1098,6 +1373,10 @@ func integrationPayload(version int, id int64) string {
return fmt.Sprintf("payload-%03d-%06d", version, id)
}

func integrationResumePayload(version int, id int64) string {
return strings.Repeat(fmt.Sprintf("resume-payload-%03d-%06d-", version, id), 64)
}

func integrationBoolInt(value bool) int {
if value {
return 1
Expand Down
4 changes: 3 additions & 1 deletion cmd/internal/server/handlers/schema_aware_serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ func (l *schemaAwareSerializer) serializeResult(before *sqltypes.Result, after *
}
operationRecord.Record.Data = row
operationRecord.Record.Type = fivetranOpMap[opType]
l.sender.Send(l.recordResponse)
if err := l.sender.Send(l.recordResponse); err != nil {
return err
}
}

return nil
Expand Down
Loading