Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 4 additions & 22 deletions src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,6 @@ namespace Microsoft.EntityFrameworkCore;
/// </remarks>
public static class CosmosQueryableExtensions
{
internal static readonly MethodInfo WithPartitionKeyMethodInfo1
= typeof(CosmosQueryableExtensions).GetTypeInfo()
.GetDeclaredMethods(nameof(WithPartitionKey))
.Single(mi => mi.GetParameters().Length == 2);

internal static readonly MethodInfo WithPartitionKeyMethodInfo2
= typeof(CosmosQueryableExtensions).GetTypeInfo()
.GetDeclaredMethods(nameof(WithPartitionKey))
.Single(mi => mi.GetParameters().Length == 3);

internal static readonly MethodInfo WithPartitionKeyMethodInfo3
= typeof(CosmosQueryableExtensions).GetTypeInfo()
.GetDeclaredMethods(nameof(WithPartitionKey))
.Single(mi => mi.GetParameters().Length == 4);

/// <summary>
/// Specify the partition key for partition used for the query.
/// Required when using a resource token that provides permission based on a partition key for authentication,
Expand All @@ -55,7 +40,7 @@ source.Provider is EntityQueryProvider
? source.Provider.CreateQuery<TEntity>(
Expression.Call(
instance: null,
method: WithPartitionKeyMethodInfo1.MakeGenericMethod(typeof(TEntity)),
method: new Func<IQueryable<TEntity>, object, IQueryable<TEntity>>(WithPartitionKey).Method,
source.Expression,
Expression.Constant(partitionKeyValue, typeof(object))))
: source;
Expand Down Expand Up @@ -88,7 +73,7 @@ source.Provider is EntityQueryProvider
? source.Provider.CreateQuery<TEntity>(
Expression.Call(
instance: null,
method: WithPartitionKeyMethodInfo2.MakeGenericMethod(typeof(TEntity)),
method: new Func<IQueryable<TEntity>, object, object, IQueryable<TEntity>>(WithPartitionKey).Method,
source.Expression,
Expression.Constant(partitionKeyValue1, typeof(object)),
Expression.Constant(partitionKeyValue2, typeof(object))))
Expand Down Expand Up @@ -125,7 +110,7 @@ source.Provider is EntityQueryProvider
? source.Provider.CreateQuery<TEntity>(
Expression.Call(
instance: null,
method: WithPartitionKeyMethodInfo3.MakeGenericMethod(typeof(TEntity)),
method: new Func<IQueryable<TEntity>, object, object, object, IQueryable<TEntity>>(WithPartitionKey).Method,
source.Expression,
Expression.Constant(partitionKeyValue1, typeof(object)),
Expression.Constant(partitionKeyValue2, typeof(object)),
Expand Down Expand Up @@ -237,9 +222,6 @@ private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
Expression.Constant(arguments));
}

internal static readonly MethodInfo ToPageAsyncMethodInfo
= typeof(CosmosQueryableExtensions).GetMethod(nameof(ToPageAsync))!;

/// <summary>
/// Allows paginating through query results by repeatedly executing the same query, passing continuation tokens to retrieve
/// successive pages of the result set, and specifying the maximum number of results per page.
Expand Down Expand Up @@ -272,7 +254,7 @@ public static Task<CosmosPage<TSource>> ToPageAsync<TSource>(
return provider.ExecuteAsync<Task<CosmosPage<TSource>>>(
Expression.Call(
instance: null,
method: ToPageAsyncMethodInfo.MakeGenericMethod(typeof(TSource)),
method: new Func<IQueryable<TSource>, int, string?, int?, CancellationToken, Task<CosmosPage<TSource>>>(ToPageAsync).Method,
arguments:
[
source.Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,9 @@ public static class RelationalQueryableExtensions
/// <param name="source">The query source.</param>
/// <returns>The query string for debugging.</returns>
public static DbCommand CreateDbCommand(this IQueryable source)
{
if (source.Provider.Execute<IEnumerable>(source.Expression) is IRelationalQueryingEnumerable queryingEnumerable)
{
return queryingEnumerable.CreateDbCommand();
}

throw new NotSupportedException(RelationalStrings.NoDbCommand);
}
=> source.Provider.Execute<IEnumerable>(source.Expression) is IRelationalQueryingEnumerable queryingEnumerable
? queryingEnumerable.CreateDbCommand()
: throw new NotSupportedException(RelationalStrings.NoDbCommand);

#region FromSql

Expand Down Expand Up @@ -233,12 +228,11 @@ public static IQueryable<TEntity> AsSingleQuery<TEntity>(
where TEntity : class
=> source.Provider is EntityQueryProvider
? source.Provider.CreateQuery<TEntity>(
Expression.Call(AsSingleQueryMethodInfo.MakeGenericMethod(typeof(TEntity)), source.Expression))
Expression.Call(
method: new Func<IQueryable<TEntity>, IQueryable<TEntity>>(AsSingleQuery).Method,
source.Expression))
: source;

internal static readonly MethodInfo AsSingleQueryMethodInfo
= typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(AsSingleQuery))!;

/// <summary>
/// Returns a new query which is configured to load the collections in the query results through separate database queries.
/// </summary>
Expand All @@ -265,11 +259,10 @@ public static IQueryable<TEntity> AsSplitQuery<TEntity>(
where TEntity : class
=> source.Provider is EntityQueryProvider
? source.Provider.CreateQuery<TEntity>(
Expression.Call(AsSplitQueryMethodInfo.MakeGenericMethod(typeof(TEntity)), source.Expression))
Expression.Call(
new Func<IQueryable<TEntity>, IQueryable<TEntity>>(AsSplitQuery).Method,
source.Expression))
: source;

internal static readonly MethodInfo AsSplitQueryMethodInfo
= typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(AsSplitQuery))!;

#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,24 @@ public RelationalQueryMetadataExtractingExpressionVisitor(
/// </summary>
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() == RelationalQueryableExtensions.AsSplitQueryMethodInfo)
if (methodCallExpression.Method.DeclaringType == typeof(RelationalQueryableExtensions))
{
var innerQueryable = Visit(methodCallExpression.Arguments[0]);

_relationalQueryCompilationContext.QuerySplittingBehavior = QuerySplittingBehavior.SplitQuery;

return innerQueryable;
}

if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() == RelationalQueryableExtensions.AsSingleQueryMethodInfo)
{
var innerQueryable = Visit(methodCallExpression.Arguments[0]);

_relationalQueryCompilationContext.QuerySplittingBehavior = QuerySplittingBehavior.SingleQuery;

return innerQueryable;
switch (methodCallExpression.Method.Name)
{
case nameof(RelationalQueryableExtensions.AsSplitQuery):
{
var innerQueryable = Visit(methodCallExpression.Arguments[0]);
_relationalQueryCompilationContext.QuerySplittingBehavior = QuerySplittingBehavior.SplitQuery;
return innerQueryable;
}

case nameof(RelationalQueryableExtensions.AsSingleQuery):
{
var innerQueryable = Visit(methodCallExpression.Arguments[0]);
_relationalQueryCompilationContext.QuerySplittingBehavior = QuerySplittingBehavior.SingleQuery;
return innerQueryable;
}
}
}

return base.VisitMethodCall(methodCallExpression);
Expand Down
Loading