diff --git a/src/Gemstone.Web/APIController/ModelController.cs b/src/Gemstone.Web/APIController/ModelController.cs index b2e096a..901a618 100644 --- a/src/Gemstone.Web/APIController/ModelController.cs +++ b/src/Gemstone.Web/APIController/ModelController.cs @@ -23,9 +23,7 @@ // ReSharper disable StaticMemberInGenericType using System; -using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.Reflection; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Gemstone.Data; @@ -59,8 +57,8 @@ public ModelController() { } public virtual async Task Patch([FromBody] T record, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); - await tableOperations.UpdateRecordAsync(record, cancellationToken); + SecureTableOperations tableOperations = new(connection); + await tableOperations.UpdateRecordAsync(HttpContext.User, record, cancellationToken); return Ok(record); } @@ -92,8 +90,9 @@ public virtual async Task Post([FromBody]T record, CancellationTo public virtual async Task Delete([FromBody] T record, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); - await tableOperations.DeleteRecordAsync(record, cancellationToken); + SecureTableOperations tableOperations = new(connection); + object primaryKey = tableOperations.BaseOperations.GetPrimaryKeys(record).First(); + await tableOperations.DeleteRecordWhereAsync(HttpContext.User, $"{PrimaryKeyField} = {{0}}", cancellationToken, primaryKey); return Ok(1); } @@ -108,8 +107,8 @@ public virtual async Task Delete([FromBody] T record, Cancellatio public virtual async Task Delete(string id, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); - await tableOperations.DeleteRecordWhereAsync($"{PrimaryKeyField} = {{0}}", cancellationToken, id); + SecureTableOperations tableOperations = new(connection); + await tableOperations.DeleteRecordWhereAsync(HttpContext.User, $"{PrimaryKeyField} = {{0}}", cancellationToken, id); return Ok(1); } diff --git a/src/Gemstone.Web/APIController/ReadOnlyModelController.cs b/src/Gemstone.Web/APIController/ReadOnlyModelController.cs index b186af9..a5cccdf 100644 --- a/src/Gemstone.Web/APIController/ReadOnlyModelController.cs +++ b/src/Gemstone.Web/APIController/ReadOnlyModelController.cs @@ -49,7 +49,7 @@ private class ConnectionCache : IDisposable { public string Token { get; } = Guid.NewGuid().ToString(); - public TableOperations Table { get; } + public SecureTableOperations Table { get; } public IAsyncEnumerator? Records { get; set; } @@ -58,7 +58,7 @@ private class ConnectionCache : IDisposable private ConnectionCache() { m_connection = new AdoDataConnection(Settings.Default); - Table = new TableOperations(m_connection); + Table = new SecureTableOperations(m_connection); } public void Dispose() @@ -177,7 +177,7 @@ public Task Open(string? filterExpression, object?[] parameters, { ConnectionCache cache = ConnectionCache.Create(expiration ?? 1.0D); - cache.Records = cache.Table.QueryRecordsWhereAsync(filterExpression, cancellationToken, parameters).GetAsyncEnumerator(cancellationToken); + cache.Records = cache.Table.QueryRecordsWhereAsync(HttpContext.User, filterExpression, cancellationToken, parameters).GetAsyncEnumerator(cancellationToken); return Task.FromResult(Ok(cache.Token)); } @@ -237,7 +237,7 @@ public IActionResult Close(string token) public virtual async Task Get(string? parentID, int page, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); + SecureTableOperations tableOperations = new(connection); RecordFilter? filter = null; if (ParentKey != string.Empty && parentID is not null) @@ -250,7 +250,7 @@ public virtual async Task Get(string? parentID, int page, Cancell }; } - IAsyncEnumerable result = tableOperations.QueryRecordsAsync(DefaultSort, DefaultSortDirection, page, PageSize, cancellationToken, filter); + IAsyncEnumerable result = tableOperations.QueryRecordsAsync(HttpContext.User, DefaultSort, DefaultSortDirection, page, PageSize, cancellationToken, filter); return Ok(await result.ToArrayAsync(cancellationToken).ConfigureAwait(false)); } @@ -267,10 +267,10 @@ public virtual async Task Get(string? parentID, int page, Cancell public virtual async Task Get(string sort, bool ascending, int page, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); + SecureTableOperations tableOperations = new(connection); RecordFilter? filter = null; - IAsyncEnumerable result = tableOperations.QueryRecordsAsync(sort, ascending, page, PageSize, cancellationToken, filter); + IAsyncEnumerable result = tableOperations.QueryRecordsAsync(HttpContext.User, sort, ascending, page, PageSize, cancellationToken, filter); return Ok(await result.ToArrayAsync(cancellationToken).ConfigureAwait(false)); } @@ -288,7 +288,7 @@ public virtual async Task Get(string sort, bool ascending, int pa public virtual async Task Get(string parentID, string sort, bool ascending, int page, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); + SecureTableOperations tableOperations = new(connection); RecordFilter filter = new() { FieldName = ParentKey, @@ -296,7 +296,7 @@ public virtual async Task Get(string parentID, string sort, bool SearchParameter = parentID }; - IAsyncEnumerable result = tableOperations.QueryRecordsAsync(sort, ascending, page, PageSize, cancellationToken, filter); + IAsyncEnumerable result = tableOperations.QueryRecordsAsync(HttpContext.User, sort, ascending, page, PageSize, cancellationToken, filter); return Ok(await result.ToArrayAsync(cancellationToken).ConfigureAwait(false)); } @@ -311,8 +311,8 @@ public virtual async Task Get(string parentID, string sort, bool public virtual async Task GetOne(string id, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); - T? result = await tableOperations.QueryRecordAsync(new RecordRestriction($"{PrimaryKeyField} = {{0}}", id), cancellationToken).ConfigureAwait(false); + SecureTableOperations tableOperations = new(connection); + T? result = await tableOperations.QueryRecordAsync(HttpContext.User, new RecordRestriction($"{PrimaryKeyField} = {{0}}", id), cancellationToken).ConfigureAwait(false); return result is null ? NotFound() : @@ -332,7 +332,7 @@ public virtual async Task GetOne(string id, CancellationToken can public virtual async Task Search([FromBody] SearchPost postData, int page, string? parentID, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); + SecureTableOperations tableOperations = new(connection); RecordFilter[] filters = postData.Searches.ToArray(); if (ParentKey != string.Empty && parentID is not null) @@ -345,7 +345,7 @@ public virtual async Task Search([FromBody] SearchPost postDat }); } - IAsyncEnumerable result = tableOperations.QueryRecordsAsync(postData.OrderBy, postData.Ascending, page, PageSize, cancellationToken, filters); + IAsyncEnumerable result = tableOperations.QueryRecordsAsync(HttpContext.User, postData.OrderBy, postData.Ascending, page, PageSize, cancellationToken, filters); return Ok(await result.ToArrayAsync(cancellationToken).ConfigureAwait(false)); } @@ -362,7 +362,7 @@ public virtual async Task Search([FromBody] SearchPost postDat public virtual async Task GetPageInfo([FromBody] SearchPost postData, string? parentID, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); + SecureTableOperations tableOperations = new(connection); RecordFilter[] filters = postData.Searches.ToArray(); if (ParentKey != string.Empty && parentID is not null) @@ -375,7 +375,7 @@ public virtual async Task GetPageInfo([FromBody] SearchPost po }); } - int recordCount = await tableOperations.QueryRecordCountAsync(cancellationToken, filters).ConfigureAwait(false); + int recordCount = await tableOperations.QueryRecordCountAsync(HttpContext.User, cancellationToken, filters).ConfigureAwait(false); return Ok(new PageInfo() { @@ -396,7 +396,7 @@ public virtual async Task GetPageInfo([FromBody] SearchPost po public virtual async Task GetPageInfo(string? parentID, CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); + SecureTableOperations tableOperations = new(connection); RecordFilter[] filters = []; if (ParentKey != string.Empty && parentID is not null) @@ -409,7 +409,7 @@ public virtual async Task GetPageInfo(string? parentID, Cancellat }); } - int recordCount = await tableOperations.QueryRecordCountAsync(cancellationToken, filters).ConfigureAwait(false); + int recordCount = await tableOperations.QueryRecordCountAsync(HttpContext.User, cancellationToken, filters).ConfigureAwait(false); return Ok(new PageInfo() { @@ -428,9 +428,9 @@ public virtual async Task GetPageInfo(string? parentID, Cancellat public virtual async Task New(CancellationToken cancellationToken) { await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); + SecureTableOperations tableOperations = new(connection); - T? result = tableOperations.NewRecord(); + T? result = tableOperations.BaseOperations.NewRecord(); return Ok(result); } @@ -450,8 +450,8 @@ public virtual async Task GetMaxValue(string fieldName, Cancellat // Create a connection and table operations instance await using AdoDataConnection connection = CreateConnection(); - TableOperations tableOperations = new(connection); - string tableName = tableOperations.TableName; + SecureTableOperations tableOperations = new(connection); + string tableName = tableOperations.BaseOperations.TableName; string sql = $"SELECT MAX([{fieldName}]) FROM [{tableName}]"; object? maxValue = await connection.ExecuteScalarAsync(sql, cancellationToken).ConfigureAwait(false);