diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cb2f6d2c..97ebd9be3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Added support for `Result.RowsAffected()` for YDB `database/sql` driver +* Upgraded minimal version of Go to 1.23.9 * Fixed race in `readerReconnector` ## v3.117.1 diff --git a/examples/go.mod b/examples/go.mod index 7da0a46f2..0919e48bb 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -1,8 +1,6 @@ module examples -go 1.23.0 - -toolchain go1.23.6 +go 1.23.9 require ( github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 diff --git a/go.mod b/go.mod index 31afd9bc5..8f98978d4 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ydb-platform/ydb-go-sdk/v3 -go 1.22.5 +go 1.23.9 require ( github.com/golang-jwt/jwt/v4 v4.5.2 diff --git a/internal/query/result.go b/internal/query/result.go index c21f8fc14..9fb6c5bb5 100644 --- a/internal/query/result.go +++ b/internal/query/result.go @@ -159,10 +159,6 @@ func newResult( r.lastPart = part - if part.GetExecStats() != nil && r.statsCallback != nil { - r.statsCallback(stats.FromQueryStats(part.GetExecStats())) - } - return &r, nil } } @@ -211,6 +207,10 @@ func (r *streamResult) nextPart(ctx context.Context) ( } } + if part.GetExecStats() != nil && r.statsCallback != nil { + r.statsCallback(stats.FromQueryStats(part.GetExecStats())) + } + return part, nil } } @@ -286,9 +286,6 @@ func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err err if err != nil { return nil, xerrors.WithStackTrace(err) } - if part.GetExecStats() != nil && r.statsCallback != nil { - r.statsCallback(stats.FromQueryStats(part.GetExecStats())) - } if part.GetResultSetIndex() < r.resultSetIndex { r.closer.Close(nil) @@ -326,9 +323,6 @@ func (r *streamResult) nextPartFunc( return nil, xerrors.WithStackTrace(err) } r.lastPart = part - if part.GetExecStats() != nil && r.statsCallback != nil { - r.statsCallback(stats.FromQueryStats(part.GetExecStats())) - } if part.GetResultSetIndex() > nextResultSetIndex { return nil, xerrors.WithStackTrace(fmt.Errorf( "result set (index=%d) receive part (index=%d) for next result set: %w (%w)", diff --git a/internal/xsql/xquery/conn.go b/internal/xsql/xquery/conn.go index b100bf220..76e9a980c 100644 --- a/internal/xsql/xquery/conn.go +++ b/internal/xsql/xquery/conn.go @@ -15,13 +15,6 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/common" ) -type resultNoRows struct{} - -func (resultNoRows) LastInsertId() (int64, error) { return 0, ErrUnsupported } -func (resultNoRows) RowsAffected() (int64, error) { return 0, ErrUnsupported } - -var _ driver.Result = resultNoRows{} - type Parent interface { Query() *query.Client } @@ -39,7 +32,7 @@ func (c *Conn) NodeID() uint32 { } func (c *Conn) Exec(ctx context.Context, sql string, params *params.Params) ( - result driver.Result, finalErr error, + driver.Result, error, ) { if !c.IsValid() { return nil, xerrors.WithStackTrace(xerrors.Retryable(errNotReadyConn, @@ -63,12 +56,15 @@ func (c *Conn) Exec(ctx context.Context, sql string, params *params.Params) ( opts = append(opts, options.WithTxControl(txControl)) } + r := &resultWithStats{} + opts = append(opts, options.WithStatsMode(options.StatsModeBasic, r.onQueryStats)) + err := c.session.Exec(ctx, sql, opts...) if err != nil { return nil, xerrors.WithStackTrace(err) } - return resultNoRows{}, nil + return r, nil } func (c *Conn) Query(ctx context.Context, sql string, params *params.Params) ( diff --git a/internal/xsql/xquery/results_with_stats.go b/internal/xsql/xquery/results_with_stats.go new file mode 100644 index 000000000..196ec005a --- /dev/null +++ b/internal/xsql/xquery/results_with_stats.go @@ -0,0 +1,36 @@ +package xquery + +import ( + "database/sql/driver" + + "github.com/ydb-platform/ydb-go-sdk/v3/query" +) + +type ( + resultWithStats struct { + rowsAffected *uint64 + } +) + +var _ driver.Result = &resultWithStats{} + +func (r *resultWithStats) onQueryStats(qs query.Stats) { + var rowsAffected uint64 + for queryPhase := range qs.QueryPhases() { + for tableAccess := range queryPhase.TableAccess() { + rowsAffected += tableAccess.Deletes.Rows + tableAccess.Updates.Rows + } + } + // last stats always contains the full stats of query + r.rowsAffected = &rowsAffected +} + +func (r *resultWithStats) RowsAffected() (int64, error) { + if r.rowsAffected == nil { + return 0, ErrUnsupported + } + + return int64(*r.rowsAffected), nil +} + +func (r *resultWithStats) LastInsertId() (int64, error) { return 0, ErrUnsupported } diff --git a/internal/xsql/xquery/results_with_stats_test.go b/internal/xsql/xquery/results_with_stats_test.go new file mode 100644 index 000000000..9676094c6 --- /dev/null +++ b/internal/xsql/xquery/results_with_stats_test.go @@ -0,0 +1,363 @@ +package xquery + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_TableStats" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/stats" +) + +func TestResultWithStats_RowsAffected(t *testing.T) { + t.Run("NilStats", func(t *testing.T) { + r := &resultWithStats{ + rowsAffected: nil, + } + + rows, err := r.RowsAffected() + + require.Equal(t, int64(0), rows) + require.ErrorIs(t, err, ErrUnsupported) + }) + + t.Run("EmptyStats", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{})) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(0), rows) + }) + + t.Run("SinglePhaseWithDeletes", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 5, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(5), rows) + }) + + t.Run("SinglePhaseWithUpdates", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Updates: &Ydb_TableStats.OperationStats{ + Rows: 3, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(3), rows) + }) + + t.Run("SinglePhaseWithDeletesAndUpdates", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 5, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 3, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(8), rows) + }) + + t.Run("MultipleTableAccesses", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 5, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 3, + }, + }, + { + Name: "table2", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 2, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 7, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(17), rows) // 5 + 3 + 2 + 7 + }) + + t.Run("MultiplePhases", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 5, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 3, + }, + }, + }, + }, + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table2", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 2, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 7, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(17), rows) // 5 + 3 + 2 + 7 + }) + + t.Run("OnlyReadsIgnored", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Reads: &Ydb_TableStats.OperationStats{ + Rows: 100, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(0), rows) + }) + + t.Run("MixedOperations", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Reads: &Ydb_TableStats.OperationStats{ + Rows: 100, + }, + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 10, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 5, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(15), rows) // Only deletes + updates, reads ignored + }) + + t.Run("ZeroRowsAffected", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 0, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 0, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(0), rows) + }) + + t.Run("LargeNumberOfRows", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 1000000, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 2000000, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(3000000), rows) + }) + + t.Run("ComplexMultiPhaseMultiTable", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table1", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 10, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 20, + }, + }, + { + Name: "table2", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 5, + }, + }, + }, + }, + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table3", + Updates: &Ydb_TableStats.OperationStats{ + Rows: 15, + }, + }, + }, + }, + { + TableAccess: []*Ydb_TableStats.TableAccessStats{ + { + Name: "table4", + Deletes: &Ydb_TableStats.OperationStats{ + Rows: 3, + }, + Updates: &Ydb_TableStats.OperationStats{ + Rows: 7, + }, + }, + }, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(60), rows) // 10 + 20 + 5 + 15 + 3 + 7 + }) + + t.Run("EmptyTableAccessInPhase", func(t *testing.T) { + r := &resultWithStats{} + r.onQueryStats(stats.FromQueryStats(&Ydb_TableStats.QueryStats{ + QueryPhases: []*Ydb_TableStats.QueryPhaseStats{ + { + TableAccess: []*Ydb_TableStats.TableAccessStats{}, + }, + }, + })) + + rows, err := r.RowsAffected() + + require.NoError(t, err) + require.Equal(t, int64(0), rows) + }) +} + +func TestResultWithStats_LastInsertId(t *testing.T) { + r := &resultWithStats{} + + id, err := r.LastInsertId() + + require.Equal(t, int64(0), id) + require.ErrorIs(t, err, ErrUnsupported) +} diff --git a/internal/xsql/xquery/tx.go b/internal/xsql/xquery/tx.go index ea78cdd2c..cde718546 100644 --- a/internal/xsql/xquery/tx.go +++ b/internal/xsql/xquery/tx.go @@ -30,12 +30,15 @@ func (t *transaction) Exec(ctx context.Context, sql string, params *params.Param opts = append(opts, options.WithTxControl(txControl)) } + r := &resultWithStats{} + opts = append(opts, options.WithStatsMode(options.StatsModeBasic, r.onQueryStats)) + err := t.tx.Exec(ctx, sql, opts...) if err != nil { return nil, xerrors.WithStackTrace(err) } - return resultNoRows{}, nil + return r, nil } func (t *transaction) Query(ctx context.Context, sql string, params *params.Params) (driver.RowsNextResultSet, error) { diff --git a/tests/integration/database_sql_rows_affected_test.go b/tests/integration/database_sql_rows_affected_test.go new file mode 100644 index 000000000..3d5cac0c6 --- /dev/null +++ b/tests/integration/database_sql_rows_affected_test.go @@ -0,0 +1,133 @@ +//go:build integration +// +build integration + +package integration + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3" +) + +func TestDatabaseSQLRowsAffected(t *testing.T) { + tests := []struct { + sql string + rows int64 + }{ + { + sql: "INSERT INTO %s (id) values (1),(2),(3)", + rows: 3, + }, + { + sql: "UPDATE %s SET val = 'test' where id > 1", + rows: 2, + }, + { + sql: "DELETE FROM %s", + rows: 3, + }, + { + sql: "INSERT INTO %s (id) values (1),(2),(3); INSERT INTO %[1]s (id) values (4),(5); ", + rows: 5, + }, + { + sql: "UPDATE %s SET val = 'test' where id > 1; DELETE FROM %[1]s WHERE id < 4", // 4+3 + rows: 7, + }, + // Single row operations + { + sql: "INSERT INTO %s (id) VALUES (6)", + rows: 1, + }, + { + sql: "UPDATE %s SET val = 'single' WHERE id = 6", + rows: 1, + }, + { + sql: "DELETE FROM %s WHERE id = 6", + rows: 1, + }, + // Operations affecting 0 rows + { + sql: "UPDATE %s SET val = 'none' WHERE id = 999", + rows: 0, + }, + // More complex multi-statement scenarios + { + sql: "INSERT INTO %s (id) VALUES (10),(11),(12); UPDATE %[1]s SET val = 'multi' WHERE id IN (10,11,12); DELETE FROM %[1]s WHERE id = 10", + rows: 7, // 3 inserted + 3 updated + 1 deleted + }, + // Multiple statements with some affecting 0 rows + { + sql: "UPDATE %s SET val = 'zero' WHERE id = 999; INSERT INTO %[1]s (id) VALUES (13)", + rows: 1, // 0 updated + 1 inserted + }, + // UPSERT operations + { + sql: "UPSERT INTO %s (id, val) VALUES (14, 'upsert1')", + rows: 1, + }, + { + sql: "UPSERT INTO %s (id, val) VALUES (15, 'upsert2'), (16, 'upsert3')", + rows: 2, + }, + // UPSERT that updates existing rows (should still count as affecting rows) + { + sql: "UPSERT INTO %s (id, val) VALUES (1, 'updated')", + rows: 1, + }, + // Complex multi-statement with UPSERT + { + sql: "INSERT INTO %s (id, val) VALUES (17, 'insert'); UPSERT INTO %[1]s (id, val) VALUES (17, 'upserted')", + rows: 2, // 1 inserted + 1 upserted + }, + // Test case with mixed operations + { + sql: "INSERT INTO %s (id, val) VALUES (18, 'mixed1'); UPDATE %[1]s SET val = 'updated' WHERE id = 18; DELETE FROM %[1]s WHERE id = 18", + rows: 3, // 1 inserted + 1 updated + 1 deleted + }, + // Additional test cases + { + sql: "INSERT INTO %s (id, val) VALUES (19, 'test1'), (20, 'test2'); DELETE FROM %[1]s WHERE id IN (19, 20)", + rows: 4, // 2 inserted + 2 deleted + }, + { + sql: "UPDATE %s SET val = 'updated' WHERE id BETWEEN 1 AND 3", + rows: 1, + }, + { + sql: "INSERT INTO %s (id, val) VALUES (21, 'test1'); INSERT INTO %[1]s (id, val) VALUES (22, 'test2'); UPDATE %[1]s SET val = 'bulk_update' WHERE id IN (21, 22)", + rows: 4, // 2 inserted + 2 updated + }, + { + sql: "UPSERT INTO %s (id, val) VALUES (23, 'upsert1'); UPSERT INTO %[1]s (id, val) VALUES (24, 'upsert2'); DELETE FROM %[1]s WHERE id IN (23, 24)", + rows: 4, // 2 upserted + 2 deleted + }, + } + + var ( + scope = newScope(t) + db = scope.SQLDriverWithFolder(ydb.WithQueryService(true)) + ) + + defer func() { + _ = db.Close() + }() + + for _, test := range tests { + sql := fmt.Sprintf(test.sql, scope.TableName()) + t.Run(sql, func(t *testing.T) { + result, err := db.Exec(sql) + require.NoError(t, err) + + got, err := result.RowsAffected() + require.NoError(t, err) + + assert.Equal(t, test.rows, got) + }) + } +} diff --git a/tests/slo/go.mod b/tests/slo/go.mod index 099d7774f..a562b4328 100644 --- a/tests/slo/go.mod +++ b/tests/slo/go.mod @@ -1,8 +1,6 @@ module slo -go 1.23.0 - -toolchain go1.23.6 +go 1.23.9 require ( github.com/prometheus/client_golang v1.14.0