Skip to content

Commit 8638178

Browse files
authored
chore: add GetTotalUpdateCount method to .NET (#701)
* chore: add GetTotalUpdateCount method to .NET Add a GetTotalUpdateCount method to the .NET wrapper that can be used to return the total update count of a set of SQL statements. This will be used by the ADO.NET ExecuteNonQuery method. See https://learn.microsoft.com/en-us/dotnet/api/system.data.sqlclient.sqlcommand.executenonquery * fix: only return -1 if there are no DML statements
1 parent e8f0b7a commit 8638178

File tree

2 files changed

+119
-0
lines changed
  • spannerlib/wrappers/spannerlib-dotnet

2 files changed

+119
-0
lines changed

spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-tests/RowsTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
using System.Text;
1516
using Google.Cloud.Spanner.Admin.Database.V1;
1617
using Google.Cloud.Spanner.V1;
1718
using Google.Cloud.SpannerLib.MockServer;
@@ -429,6 +430,51 @@ public async Task TestMultipleMixedStatements([Values] LibType libType, [Values(
429430
Assert.That(numResultSets, Is.EqualTo(5));
430431
}
431432

433+
[Test]
434+
public async Task TestGetAllUpdateCounts([Values] LibType libType, [Values(0, 1, 2, 5)] int numDmlStatements, [Values(2, 10)] int numRows, [Values(0, 1, 2, 3, 5, 9, 10, 11)] int prefetchRows, [Values] bool async)
435+
{
436+
var updateCount = 3L;
437+
var dml = "update my_table set value=1 where id in (1,2,3)";
438+
Fixture.SpannerMock.AddOrUpdateStatementResult(dml, StatementResult.CreateUpdateCount(updateCount));
439+
Fixture.SpannerMock.AddOrUpdateStatementResult(dml + ";", StatementResult.CreateUpdateCount(updateCount));
440+
441+
var rowType = RandomResultSetGenerator.GenerateAllTypesRowType();
442+
var results = RandomResultSetGenerator.Generate(rowType, numRows);
443+
var query = "select * from random";
444+
Fixture.SpannerMock.AddOrUpdateStatementResult(query, StatementResult.CreateQuery(results));
445+
446+
// Create a SQL string containing a mix of DML and queries.
447+
var builder = new StringBuilder();
448+
for (var i = 0; i < numDmlStatements; i++)
449+
{
450+
while (Random.Shared.Next(2) == 0)
451+
{
452+
builder.Append(query).Append(';');
453+
}
454+
builder.Append(dml).Append(';');
455+
while (Random.Shared.Next(5) == 0)
456+
{
457+
builder.Append(query).Append(';');
458+
}
459+
}
460+
var sql = builder.ToString();
461+
if (string.IsNullOrEmpty(sql))
462+
{
463+
sql = query;
464+
}
465+
466+
await using var pool = Pool.Create(SpannerLibDictionary[libType], ConnectionString);
467+
await using var connection = pool.CreateConnection();
468+
await using var rows = async
469+
? await connection.ExecuteAsync(new ExecuteSqlRequest { Sql = sql }, prefetchRows)
470+
// ReSharper disable once MethodHasAsyncOverload
471+
: connection.Execute(new ExecuteSqlRequest { Sql = sql }, prefetchRows);
472+
473+
// ReSharper disable once MethodHasAsyncOverload
474+
var totalUpdateCount = async ? await rows.GetTotalUpdateCountAsync() : rows.GetTotalUpdateCount();
475+
Assert.That(totalUpdateCount, Is.EqualTo(numDmlStatements == 0 ? -1 : updateCount * numDmlStatements));
476+
}
477+
432478
[Test]
433479
public async Task TestMultipleMixedStatementsWithErrors(
434480
[Values] LibType libType,

spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet/Rows.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ public long UpdateCount
6262
return -1L;
6363
}
6464
}
65+
66+
private bool _hasReadAllResults;
6567

6668
public Rows(Connection connection, long id, bool initMetadata = true) : base(connection.Spanner, id)
6769
{
@@ -97,12 +99,78 @@ public Rows(Connection connection, long id, bool initMetadata = true) : base(con
9799
return await Spanner.NextAsync(this, 1, ISpannerLib.RowEncoding.Proto, cancellationToken).ConfigureAwait(false);
98100
}
99101

102+
/// <summary>
103+
/// Gets the total update count in this Rows object. This consumes all data in all the result sets.
104+
/// This method should only be called if the caller is only interested in the update count, and not in any of the
105+
/// rows in the result sets.
106+
/// </summary>
107+
/// <returns>
108+
/// The total update count of all the result sets in this Rows object.
109+
/// If the SQL string only contained non-DML statements, then the update count is -1. If the SQL string contained at
110+
/// least one DML statement, then the returned value is the sum of all update counts of the DML statements in the
111+
/// SQL string that was executed.
112+
/// </returns>
113+
public long GetTotalUpdateCount()
114+
{
115+
long result = -1;
116+
var hasUpdateCount = false;
117+
do
118+
{
119+
var updateCount = UpdateCount;
120+
if (updateCount > -1)
121+
{
122+
if (!hasUpdateCount)
123+
{
124+
hasUpdateCount = true;
125+
result = 0;
126+
}
127+
result += updateCount;
128+
}
129+
} while (NextResultSet());
130+
return result;
131+
}
132+
133+
/// <summary>
134+
/// Gets the total update count in this Rows object. This consumes all data in all the result sets.
135+
/// This method should only be called if the caller is only interested in the update count, and not in any of the
136+
/// rows in the result sets.
137+
/// </summary>
138+
/// <returns>
139+
/// The total update count of all the result sets in this Rows object.
140+
/// If the SQL string only contained non-DML statements, then the update count is -1. If the SQL string contained at
141+
/// least one DML statement, then the returned value is the sum of all update counts of the DML statements in the
142+
/// SQL string that was executed.
143+
/// </returns>
144+
public async Task<long> GetTotalUpdateCountAsync(CancellationToken cancellationToken = default)
145+
{
146+
long result = -1;
147+
var hasUpdateCount = false;
148+
do
149+
{
150+
var updateCount = UpdateCount;
151+
if (updateCount > -1)
152+
{
153+
if (!hasUpdateCount)
154+
{
155+
hasUpdateCount = true;
156+
result = 0;
157+
}
158+
result += updateCount;
159+
}
160+
} while (await NextResultSetAsync(cancellationToken).ConfigureAwait(false));
161+
return result;
162+
}
163+
100164
/// <summary>
101165
/// Moves the cursor to the next result set in this Rows object.
102166
/// </summary>
103167
/// <returns>True if there was another result set, and false otherwise</returns>
104168
public virtual bool NextResultSet()
105169
{
170+
if (_hasReadAllResults)
171+
{
172+
return false;
173+
}
106174
return NextResultSet(Spanner.NextResultSet(this));
107175
}
108176

@@ -112,13 +180,18 @@ public virtual bool NextResultSet()
112180
/// <returns>True if there was another result set, and false otherwise</returns>
113181
public virtual async Task<bool> NextResultSetAsync(CancellationToken cancellationToken = default)
114182
{
183+
if (_hasReadAllResults)
184+
{
185+
return false;
186+
}
115187
return NextResultSet(await Spanner.NextResultSetAsync(this, cancellationToken).ConfigureAwait(false));
116188
}
117189

118190
private bool NextResultSet(ResultSetMetadata? metadata)
119191
{
120192
if (metadata == null)
121193
{
194+
_hasReadAllResults = true;
122195
return false;
123196
}
124197
_metadata = metadata;

0 commit comments

Comments
 (0)