diff --git a/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs b/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs index 8f4579bdce5..a3ba49f079a 100644 --- a/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs +++ b/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs @@ -18,21 +18,6 @@ namespace Microsoft.EntityFrameworkCore; /// public static class CosmosQueryableExtensions { - internal static readonly MethodInfo WithPartitionKeyMethodInfo1 - = typeof(CosmosQueryableExtensions).GetTypeInfo() - .GetDeclaredMethods(nameof(WithPartitionKey)) - .Single(mi => mi.GetParameters().Length == 2); - - internal static readonly MethodInfo WithPartitionKeyMethodInfo2 - = typeof(CosmosQueryableExtensions).GetTypeInfo() - .GetDeclaredMethods(nameof(WithPartitionKey)) - .Single(mi => mi.GetParameters().Length == 3); - - internal static readonly MethodInfo WithPartitionKeyMethodInfo3 - = typeof(CosmosQueryableExtensions).GetTypeInfo() - .GetDeclaredMethods(nameof(WithPartitionKey)) - .Single(mi => mi.GetParameters().Length == 4); - /// /// Specify the partition key for partition used for the query. /// Required when using a resource token that provides permission based on a partition key for authentication, @@ -55,7 +40,7 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: WithPartitionKeyMethodInfo1.MakeGenericMethod(typeof(TEntity)), + method: new Func, object, IQueryable>(WithPartitionKey).Method, source.Expression, Expression.Constant(partitionKeyValue, typeof(object)))) : source; @@ -88,7 +73,7 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: WithPartitionKeyMethodInfo2.MakeGenericMethod(typeof(TEntity)), + method: new Func, object, object, IQueryable>(WithPartitionKey).Method, source.Expression, Expression.Constant(partitionKeyValue1, typeof(object)), Expression.Constant(partitionKeyValue2, typeof(object)))) @@ -125,7 +110,7 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: WithPartitionKeyMethodInfo3.MakeGenericMethod(typeof(TEntity)), + method: new Func, object, object, object, IQueryable>(WithPartitionKey).Method, source.Expression, Expression.Constant(partitionKeyValue1, typeof(object)), Expression.Constant(partitionKeyValue2, typeof(object)), @@ -237,9 +222,6 @@ private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot( Expression.Constant(arguments)); } - internal static readonly MethodInfo ToPageAsyncMethodInfo - = typeof(CosmosQueryableExtensions).GetMethod(nameof(ToPageAsync))!; - /// /// Allows paginating through query results by repeatedly executing the same query, passing continuation tokens to retrieve /// successive pages of the result set, and specifying the maximum number of results per page. @@ -272,7 +254,7 @@ public static Task> ToPageAsync( return provider.ExecuteAsync>>( Expression.Call( instance: null, - method: ToPageAsyncMethodInfo.MakeGenericMethod(typeof(TSource)), + method: new Func, int, string?, int?, CancellationToken, Task>>(ToPageAsync).Method, arguments: [ source.Expression, diff --git a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs index 0005cae2693..03e14652357 100644 --- a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs +++ b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs @@ -39,14 +39,9 @@ public static class RelationalQueryableExtensions /// The query source. /// The query string for debugging. public static DbCommand CreateDbCommand(this IQueryable source) - { - if (source.Provider.Execute(source.Expression) is IRelationalQueryingEnumerable queryingEnumerable) - { - return queryingEnumerable.CreateDbCommand(); - } - - throw new NotSupportedException(RelationalStrings.NoDbCommand); - } + => source.Provider.Execute(source.Expression) is IRelationalQueryingEnumerable queryingEnumerable + ? queryingEnumerable.CreateDbCommand() + : throw new NotSupportedException(RelationalStrings.NoDbCommand); #region FromSql @@ -233,12 +228,11 @@ public static IQueryable AsSingleQuery( where TEntity : class => source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( - Expression.Call(AsSingleQueryMethodInfo.MakeGenericMethod(typeof(TEntity)), source.Expression)) + Expression.Call( + method: new Func, IQueryable>(AsSingleQuery).Method, + source.Expression)) : source; - internal static readonly MethodInfo AsSingleQueryMethodInfo - = typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(AsSingleQuery))!; - /// /// Returns a new query which is configured to load the collections in the query results through separate database queries. /// @@ -265,11 +259,10 @@ public static IQueryable AsSplitQuery( where TEntity : class => source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( - Expression.Call(AsSplitQueryMethodInfo.MakeGenericMethod(typeof(TEntity)), source.Expression)) + Expression.Call( + new Func, IQueryable>(AsSplitQuery).Method, + source.Expression)) : source; - internal static readonly MethodInfo AsSplitQueryMethodInfo - = typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(AsSplitQuery))!; - #endregion } diff --git a/src/EFCore.Relational/Query/Internal/RelationalQueryMetadataExtractingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/RelationalQueryMetadataExtractingExpressionVisitor.cs index a4d8c46e020..dd72bb3a0dc 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalQueryMetadataExtractingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalQueryMetadataExtractingExpressionVisitor.cs @@ -31,24 +31,24 @@ public RelationalQueryMetadataExtractingExpressionVisitor( /// protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == RelationalQueryableExtensions.AsSplitQueryMethodInfo) + if (methodCallExpression.Method.DeclaringType == typeof(RelationalQueryableExtensions)) { - var innerQueryable = Visit(methodCallExpression.Arguments[0]); - - _relationalQueryCompilationContext.QuerySplittingBehavior = QuerySplittingBehavior.SplitQuery; - - return innerQueryable; - } - - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == RelationalQueryableExtensions.AsSingleQueryMethodInfo) - { - var innerQueryable = Visit(methodCallExpression.Arguments[0]); - - _relationalQueryCompilationContext.QuerySplittingBehavior = QuerySplittingBehavior.SingleQuery; - - return innerQueryable; + switch (methodCallExpression.Method.Name) + { + case nameof(RelationalQueryableExtensions.AsSplitQuery): + { + var innerQueryable = Visit(methodCallExpression.Arguments[0]); + _relationalQueryCompilationContext.QuerySplittingBehavior = QuerySplittingBehavior.SplitQuery; + return innerQueryable; + } + + case nameof(RelationalQueryableExtensions.AsSingleQuery): + { + var innerQueryable = Visit(methodCallExpression.Arguments[0]); + _relationalQueryCompilationContext.QuerySplittingBehavior = QuerySplittingBehavior.SingleQuery; + return innerQueryable; + } + } } return base.VisitMethodCall(methodCallExpression); diff --git a/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs b/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs index 5745e5ae638..cc727db4143 100644 --- a/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs +++ b/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs @@ -11,18 +11,6 @@ namespace Microsoft.EntityFrameworkCore; /// /// Entity Framework LINQ related extension methods. /// -[UnconditionalSuppressMessage( - "ReflectionAnalysis", - "IL2060", - Justification = - "MakeGenericMethod is used in this class to create MethodCallExpression nodes, but only if the method in question is called " - + "from user code - so it's never trimmed. After https://github.com/dotnet/linker/issues/2482 is fixed, the suppression will no " - + "longer be necessary."), UnconditionalSuppressMessage( - "AOT", - "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", - Justification = - "MakeGenericMethod is used in this class to create MethodCallExpression nodes, but only if the method in question is called " - + "from user code - so it's never trimmed.")] public static class EntityFrameworkQueryableExtensions { /// @@ -72,7 +60,10 @@ public static string ToQueryString(this IQueryable source) public static Task AnyAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.AnyWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, bool>(Queryable.Any).Method, + source, + cancellationToken); /// /// Asynchronously determines whether any element of a sequence satisfies a condition. @@ -107,7 +98,11 @@ public static Task AnyAsync( { Check.NotNull(predicate); - return ExecuteAsync>(QueryableMethods.AnyWithPredicate, source, predicate, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, bool>(Queryable.Any).Method, + source, + predicate, + cancellationToken); } /// @@ -143,7 +138,11 @@ public static Task AllAsync( { Check.NotNull(predicate); - return ExecuteAsync>(QueryableMethods.All, source, predicate, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, bool>(Queryable.All).Method, + source, + predicate, + cancellationToken); } #endregion @@ -175,7 +174,10 @@ public static Task AllAsync( public static Task CountAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.CountWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, int>(Queryable.Count).Method, + source, + cancellationToken); /// /// Asynchronously returns the number of elements in a sequence that satisfy a condition. @@ -210,7 +212,11 @@ public static Task CountAsync( { Check.NotNull(predicate); - return ExecuteAsync>(QueryableMethods.CountWithPredicate, source, predicate, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, int>(Queryable.Count).Method, + source, + predicate, + cancellationToken); } /// @@ -238,7 +244,10 @@ public static Task CountAsync( public static Task LongCountAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.LongCountWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, long>(Queryable.LongCount).Method, + source, + cancellationToken); /// /// Asynchronously returns a that represents the number of elements in a sequence @@ -274,7 +283,11 @@ public static Task LongCountAsync( { Check.NotNull(predicate); - return ExecuteAsync>(QueryableMethods.LongCountWithPredicate, source, predicate, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, long>(Queryable.LongCount).Method, + source, + predicate, + cancellationToken); } #endregion @@ -319,7 +332,10 @@ public static Task ElementAtAsync( Check.NotNull(index); return ExecuteAsync>( - QueryableMethods.ElementAt, source, Expression.Constant(index), cancellationToken); + method: new Func, int, TSource>(Queryable.ElementAt).Method, + source, + Expression.Constant(index), + cancellationToken); } /// @@ -355,7 +371,10 @@ public static Task ElementAtAsync( Check.NotNull(index); return ExecuteAsync>( - QueryableMethods.ElementAtOrDefault, source, Expression.Constant(index), cancellationToken); + method: new Func, int, TSource?>(Queryable.ElementAtOrDefault).Method, + source, + Expression.Constant(index), + cancellationToken); } #endregion @@ -388,7 +407,10 @@ public static Task ElementAtAsync( public static Task FirstAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.FirstWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, TSource>(Queryable.First).Method, + source, + cancellationToken); /// /// Asynchronously returns the first element of a sequence that satisfies a specified condition. @@ -434,7 +456,11 @@ public static Task FirstAsync( { Check.NotNull(predicate); - return ExecuteAsync>(QueryableMethods.FirstWithPredicate, source, predicate, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, TSource>(Queryable.First).Method, + source, + predicate, + cancellationToken); } /// @@ -463,7 +489,10 @@ public static Task FirstAsync( public static Task FirstOrDefaultAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.FirstOrDefaultWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, TSource?>(Queryable.FirstOrDefault).Method, + source, + cancellationToken); /// /// Asynchronously returns the first element of a sequence that satisfies a specified condition @@ -501,7 +530,10 @@ public static Task FirstAsync( Check.NotNull(predicate); return ExecuteAsync>( - QueryableMethods.FirstOrDefaultWithPredicate, source, predicate, cancellationToken); + method: new Func, Expression>, TSource?>(Queryable.FirstOrDefault).Method, + source, + predicate, + cancellationToken); } #endregion @@ -534,7 +566,10 @@ public static Task FirstAsync( public static Task LastAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.LastWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, TSource>(Queryable.Last).Method, + source, + cancellationToken); /// /// Asynchronously returns the last element of a sequence that satisfies a specified condition. @@ -580,7 +615,11 @@ public static Task LastAsync( { Check.NotNull(predicate); - return ExecuteAsync>(QueryableMethods.LastWithPredicate, source, predicate, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, TSource>(Queryable.Last).Method, + source, + predicate, + cancellationToken); } /// @@ -609,7 +648,10 @@ public static Task LastAsync( public static Task LastOrDefaultAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.LastOrDefaultWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, TSource?>(Queryable.LastOrDefault).Method, + source, + cancellationToken); /// /// Asynchronously returns the last element of a sequence that satisfies a specified condition @@ -646,7 +688,11 @@ public static Task LastAsync( { Check.NotNull(predicate); - return ExecuteAsync>(QueryableMethods.LastOrDefaultWithPredicate, source, predicate, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, TSource?>(Queryable.LastOrDefault).Method, + source, + predicate, + cancellationToken); } #endregion @@ -690,7 +736,10 @@ public static Task LastAsync( public static Task SingleAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.SingleWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, TSource>(Queryable.Single).Method, + source, + cancellationToken); /// /// Asynchronously returns the only element of a sequence that satisfies a specified condition, @@ -743,7 +792,11 @@ public static Task SingleAsync( { Check.NotNull(predicate); - return ExecuteAsync>(QueryableMethods.SingleWithPredicate, source, predicate, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, TSource>(Queryable.Single).Method, + source, + predicate, + cancellationToken); } /// @@ -775,7 +828,10 @@ public static Task SingleAsync( public static Task SingleOrDefaultAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.SingleOrDefaultWithoutPredicate, source, cancellationToken); + => ExecuteAsync>( + method: new Func, TSource?>(Queryable.SingleOrDefault).Method, + source, + cancellationToken); /// /// Asynchronously returns the only element of a sequence that satisfies a specified condition or @@ -816,7 +872,10 @@ public static Task SingleAsync( Check.NotNull(predicate); return ExecuteAsync>( - QueryableMethods.SingleOrDefaultWithPredicate, source, predicate, cancellationToken); + method: new Func, Expression>, TSource?>(Queryable.SingleOrDefault).Method, + source, + predicate, + cancellationToken); } #endregion @@ -849,7 +908,10 @@ public static Task SingleAsync( public static Task MinAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.MinWithoutSelector, source, cancellationToken); + => ExecuteAsync>( + method: new Func, TSource?>(Queryable.Min).Method, + source, + cancellationToken); /// /// Asynchronously invokes a projection function on each element of a sequence and returns the minimum resulting value. @@ -887,7 +949,11 @@ public static Task MinAsync( { Check.NotNull(selector); - return ExecuteAsync>(QueryableMethods.MinWithSelector, source, selector, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, TResult?>(Queryable.Min).Method, + source, + selector, + cancellationToken); } #endregion @@ -920,7 +986,10 @@ public static Task MinAsync( public static Task MaxAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.MaxWithoutSelector, source, cancellationToken); + => ExecuteAsync>( + method: new Func, TSource?>(Queryable.Max).Method, + source, + cancellationToken); /// /// Asynchronously invokes a projection function on each element of a sequence and returns the maximum resulting value. @@ -958,7 +1027,11 @@ public static Task MaxAsync( { Check.NotNull(selector); - return ExecuteAsync>(QueryableMethods.MaxWithSelector, source, selector, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, TResult?>(Queryable.Max).Method, + source, + selector, + cancellationToken); } #endregion @@ -989,7 +1062,10 @@ public static Task MaxAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(decimal)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, decimal>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of a sequence of values. @@ -1016,7 +1092,9 @@ public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) => ExecuteAsync>( - QueryableMethods.GetSumWithoutSelector(typeof(decimal?)), source, cancellationToken); + method: new Func, decimal?>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on @@ -1051,7 +1129,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(decimal)), source, selector, cancellationToken); + method: new Func, Expression>, decimal>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1087,7 +1168,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(decimal?)), source, selector, cancellationToken); + method: new Func, Expression>, decimal?>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1114,7 +1198,10 @@ public static Task SumAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(int)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, int>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of a sequence of values. @@ -1140,7 +1227,10 @@ public static Task SumAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(int?)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, int?>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on @@ -1174,7 +1264,11 @@ public static Task SumAsync( { Check.NotNull(selector); - return ExecuteAsync>(QueryableMethods.GetSumWithSelector(typeof(int)), source, selector, cancellationToken); + return ExecuteAsync>( + method: new Func, Expression>, int>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1210,7 +1304,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(int?)), source, selector, cancellationToken); + method: new Func, Expression>, int?>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1237,7 +1334,10 @@ public static Task SumAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(long)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, long>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of a sequence of values. @@ -1263,7 +1363,10 @@ public static Task SumAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(long?)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, long?>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on @@ -1298,7 +1401,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(long)), source, selector, cancellationToken); + method: new Func, Expression>, long>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1334,7 +1440,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(long?)), source, selector, cancellationToken); + method: new Func, Expression>, long?>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1361,7 +1470,10 @@ public static Task SumAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(double)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, double>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of a sequence of values. @@ -1387,7 +1499,10 @@ public static Task SumAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(double?)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, double?>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on @@ -1422,7 +1537,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(double)), source, selector, cancellationToken); + method: new Func, Expression>, double>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1458,7 +1576,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(double?)), source, selector, cancellationToken); + method: new Func, Expression>, double?>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1485,7 +1606,10 @@ public static Task SumAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(float)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, float>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of a sequence of values. @@ -1511,7 +1635,10 @@ public static Task SumAsync( public static Task SumAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetSumWithoutSelector(typeof(float?)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, float?>(Queryable.Sum).Method, + source, + cancellationToken); /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on @@ -1546,7 +1673,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(float)), source, selector, cancellationToken); + method: new Func, Expression>, float>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } /// @@ -1582,7 +1712,10 @@ public static Task SumAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetSumWithSelector(typeof(float?)), source, selector, cancellationToken); + method: new Func, Expression>, float?>(Queryable.Sum).Method, + source, + selector, + cancellationToken); } #endregion @@ -1615,7 +1748,9 @@ public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) => ExecuteAsync>( - QueryableMethods.GetAverageWithoutSelector(typeof(decimal)), source, cancellationToken); + method: new Func, decimal>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values. @@ -1642,7 +1777,9 @@ public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) => ExecuteAsync>( - QueryableMethods.GetAverageWithoutSelector(typeof(decimal?)), source, cancellationToken); + method: new Func, decimal?>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values that is obtained @@ -1679,7 +1816,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(decimal)), source, selector, cancellationToken); + method: new Func, Expression>, decimal>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -1716,7 +1856,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(decimal?)), source, selector, cancellationToken); + method: new Func, Expression>, decimal?>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -1744,7 +1887,10 @@ public static Task AverageAsync( public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetAverageWithoutSelector(typeof(int)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, double>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values. @@ -1770,7 +1916,10 @@ public static Task AverageAsync( public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetAverageWithoutSelector(typeof(int?)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, double?>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values that is obtained @@ -1807,7 +1956,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(int)), source, selector, cancellationToken); + method: new Func, Expression>, double>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -1844,7 +1996,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(int?)), source, selector, cancellationToken); + method: new Func, Expression>, double?>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -1872,7 +2027,10 @@ public static Task AverageAsync( public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetAverageWithoutSelector(typeof(long)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, double>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values. @@ -1898,7 +2056,10 @@ public static Task AverageAsync( public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetAverageWithoutSelector(typeof(long?)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, double?>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values that is obtained @@ -1935,7 +2096,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(long)), source, selector, cancellationToken); + method: new Func, Expression>, double>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -1972,7 +2136,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(long?)), source, selector, cancellationToken); + method: new Func, Expression>, double?>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -2001,7 +2168,9 @@ public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) => ExecuteAsync>( - QueryableMethods.GetAverageWithoutSelector(typeof(double)), source, cancellationToken); + method: new Func, double>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values. @@ -2028,7 +2197,9 @@ public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) => ExecuteAsync>( - QueryableMethods.GetAverageWithoutSelector(typeof(double?)), source, cancellationToken); + method: new Func, double?>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values that is obtained @@ -2065,7 +2236,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(double)), source, selector, cancellationToken); + method: new Func, Expression>, double>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -2102,7 +2276,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(double?)), source, selector, cancellationToken); + method: new Func, Expression>, double?>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -2130,7 +2307,10 @@ public static Task AverageAsync( public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync>(QueryableMethods.GetAverageWithoutSelector(typeof(float)), source, cancellationToken); + => ExecuteAsync>( + method: new Func, float>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values. @@ -2157,7 +2337,9 @@ public static Task AverageAsync( this IQueryable source, CancellationToken cancellationToken = default) => ExecuteAsync>( - QueryableMethods.GetAverageWithoutSelector(typeof(float?)), source, cancellationToken); + method: new Func, float?>(Queryable.Average).Method, + source, + cancellationToken); /// /// Asynchronously computes the average of a sequence of values that is obtained @@ -2194,7 +2376,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(float)), source, selector, cancellationToken); + method: new Func, Expression>, float>(Queryable.Average).Method, + source, + selector, + cancellationToken); } /// @@ -2231,7 +2416,10 @@ public static Task AverageAsync( Check.NotNull(selector); return ExecuteAsync>( - QueryableMethods.GetAverageWithSelector(typeof(float?)), source, selector, cancellationToken); + method: new Func, Expression>, float?>(Queryable.Average).Method, + source, + selector, + cancellationToken); } #endregion @@ -2267,7 +2455,7 @@ public static Task ContainsAsync( TSource item, CancellationToken cancellationToken = default) => ExecuteAsync>( - QueryableMethods.Contains, + method: new Func, TSource, bool>(Queryable.Contains).Method, source, Expression.Constant(item, typeof(TSource)), cancellationToken); @@ -2424,20 +2612,6 @@ public static async Task> ToHashSetAsync( #region Include - internal static readonly MethodInfo IncludeMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(Include)) - .Single(mi => - mi.GetGenericArguments().Length == 2 - && mi.GetParameters().Any(pi => pi.Name == "navigationPropertyPath" && pi.ParameterType != typeof(string))); - - internal static readonly MethodInfo NotQuiteIncludeMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(NotQuiteInclude)) - .Single(mi => - mi.GetGenericArguments().Length == 2 - && mi.GetParameters().Any(pi => pi.Name == "navigationPropertyPath" && pi.ParameterType != typeof(string))); - /// /// Specifies related entities to include in the query results. The navigation property to be included is specified starting with the /// type of entity being queried (). If you wish to include additional types based on the navigation @@ -2472,7 +2646,7 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: IncludeMethodInfo.MakeGenericMethod(typeof(TEntity), typeof(TProperty)), + method: new Func, Expression>, IQueryable>(Include).Method, arguments: [source.Expression, Expression.Quote(navigationPropertyPath)])) : source); } @@ -2487,27 +2661,10 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: NotQuiteIncludeMethodInfo.MakeGenericMethod(typeof(TEntity), typeof(TProperty)), + method: new Func, Expression>, IQueryable>(NotQuiteInclude).Method, arguments: [source.Expression, Expression.Quote(navigationPropertyPath)])) : source); - internal static readonly MethodInfo ThenIncludeAfterEnumerableMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(ThenInclude)) - .Where(mi => mi.GetGenericArguments().Length == 3) - .Single(mi => - { - var typeInfo = mi.GetParameters()[0].ParameterType.GenericTypeArguments[1]; - return typeInfo.IsGenericType - && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>); - }); - - internal static readonly MethodInfo ThenIncludeAfterReferenceMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(ThenInclude)) - .Single(mi => mi.GetGenericArguments().Length == 3 - && mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter); - /// /// Specifies additional related data to be further included based on a related type that was just included. /// @@ -2532,8 +2689,10 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: ThenIncludeAfterEnumerableMethodInfo.MakeGenericMethod( - typeof(TEntity), typeof(TPreviousProperty), typeof(TProperty)), + method: new Func< + IIncludableQueryable>, // source + Expression>, // navigationPropertyPath + IIncludableQueryable>(ThenInclude).Method, arguments: [source.Expression, Expression.Quote(navigationPropertyPath)])) : source); @@ -2561,8 +2720,10 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: ThenIncludeAfterReferenceMethodInfo.MakeGenericMethod( - typeof(TEntity), typeof(TPreviousProperty), typeof(TProperty)), + method: new Func< + IIncludableQueryable, // source + Expression>, // navigationPropertyPath + IIncludableQueryable>(ThenInclude).Method, arguments: [source.Expression, Expression.Quote(navigationPropertyPath)])) : source); @@ -2588,11 +2749,6 @@ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); } - internal static readonly MethodInfo StringIncludeMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(Include)) - .Single(mi => mi.GetParameters().Any(pi => pi.Name == "navigationPropertyPath" && pi.ParameterType == typeof(string))); - /// /// Specifies related entities to include in the query results. The navigation property to be included is /// specified starting with the type of entity being queried (). Further @@ -2622,7 +2778,7 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: StringIncludeMethodInfo.MakeGenericMethod(typeof(TEntity)), + method: new Func, string, IQueryable>(Include).Method, arg0: source.Expression, arg1: Expression.Constant(navigationPropertyPath))) : source; @@ -2632,9 +2788,6 @@ source.Provider is EntityQueryProvider #region Auto included navigations - internal static readonly MethodInfo IgnoreAutoIncludesMethodInfo - = typeof(EntityFrameworkQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(IgnoreAutoIncludes))!; - /// /// Specifies that the current Entity Framework LINQ query should not have any model-level eager loaded navigations applied. /// @@ -2651,7 +2804,7 @@ public static IQueryable IgnoreAutoIncludes( ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: IgnoreAutoIncludesMethodInfo.MakeGenericMethod(typeof(TEntity)), + method: new Func, IQueryable>(IgnoreAutoIncludes).Method, arguments: source.Expression)) : source; @@ -2659,16 +2812,6 @@ public static IQueryable IgnoreAutoIncludes( #region Query Filters - internal static readonly MethodInfo IgnoreQueryFiltersMethodInfo - = typeof(EntityFrameworkQueryableExtensions).GetTypeInfo().GetDeclaredMethods(nameof(IgnoreQueryFilters)) - .Where(info => info.GetParameters().Length == 1) - .First(); - - internal static readonly MethodInfo IgnoreNamedQueryFiltersMethodInfo - = typeof(EntityFrameworkQueryableExtensions).GetTypeInfo().GetDeclaredMethods(nameof(IgnoreQueryFilters)) - .Where(info => info.GetParameters().Length == 2) - .First(); - /// /// Specifies that the current Entity Framework LINQ query should not have any model-level entity query filters applied. /// @@ -2686,7 +2829,7 @@ public static IQueryable IgnoreQueryFilters( ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: IgnoreQueryFiltersMethodInfo.MakeGenericMethod(typeof(TEntity)), + method: new Func, IQueryable>(IgnoreQueryFilters).Method, arguments: source.Expression)) : source; @@ -2709,7 +2852,7 @@ public static IQueryable IgnoreQueryFilters( ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: IgnoreNamedQueryFiltersMethodInfo.MakeGenericMethod(typeof(TEntity)), + method: new Func, IReadOnlyCollection, IQueryable>(IgnoreQueryFilters).Method, // converting the collection to an array if it isn't already one to ensure consistent caching. Fixes #37112. // #37212 may be a possible future solution providing broader capabilities around parameterizing collections. arguments: [source.Expression, Expression.Constant(filterKeys is string[] ? filterKeys : filterKeys.ToArray())])) @@ -2719,9 +2862,6 @@ public static IQueryable IgnoreQueryFilters( #region Tracking - internal static readonly MethodInfo AsNoTrackingMethodInfo - = typeof(EntityFrameworkQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(AsNoTracking))!; - /// /// The change tracker will not track any of the entities that are returned from a LINQ query. If the /// entity instances are modified, this will not be detected by the change tracker and @@ -2756,13 +2896,10 @@ public static IQueryable AsNoTracking( ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: AsNoTrackingMethodInfo.MakeGenericMethod(typeof(TEntity)), + method: new Func, IQueryable>(AsNoTracking).Method, arguments: source.Expression)) : source; - internal static readonly MethodInfo AsNoTrackingWithIdentityResolutionMethodInfo - = typeof(EntityFrameworkQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(AsNoTrackingWithIdentityResolution))!; - /// /// The change tracker will not track any of the entities that are returned from a LINQ query. If the /// entity instances are modified, this will not be detected by the change tracker and @@ -2797,16 +2934,10 @@ public static IQueryable AsNoTrackingWithIdentityResolution( ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: AsNoTrackingWithIdentityResolutionMethodInfo.MakeGenericMethod(typeof(TEntity)), + method: new Func, IQueryable>(AsNoTrackingWithIdentityResolution).Method, arguments: source.Expression)) : source; - internal static readonly MethodInfo AsTrackingMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo() - .GetDeclaredMethods(nameof(AsTracking)) - .Single(m => m.GetParameters().Length == 1); - /// /// Returns a new query where the change tracker will keep track of changes for all entities that are returned. /// Any modification to the entity instances will be detected and persisted to the database during @@ -2831,7 +2962,7 @@ public static IQueryable AsTracking( ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: AsTrackingMethodInfo.MakeGenericMethod(typeof(TEntity)), + method: new Func, IQueryable>(AsTracking).Method, arguments: source.Expression)) : source; @@ -2878,16 +3009,6 @@ public static IQueryable AsTracking( #region Tagging - internal static readonly MethodInfo TagWithMethodInfo - = typeof(EntityFrameworkQueryableExtensions).GetMethod( - nameof(TagWith), [typeof(IQueryable<>).MakeGenericType(Type.MakeGenericMethodParameter(0)), typeof(string)])!; - - internal static readonly MethodInfo TagWithCallSiteMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetMethod( - nameof(TagWithCallSite), - [typeof(IQueryable<>).MakeGenericType(Type.MakeGenericMethodParameter(0)), typeof(string), typeof(int)])!; - /// /// Adds a tag to the collection of tags associated with an EF LINQ query. Tags are query annotations /// that can provide contextual tracing information at different points in the query pipeline. @@ -2914,7 +3035,7 @@ source.Provider is EntityQueryProvider ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: TagWithMethodInfo.MakeGenericMethod(typeof(T)), + method: new Func, string, IQueryable>(TagWith).Method, arg0: source.Expression, arg1: Expression.Constant(tag))) : source; @@ -2943,7 +3064,7 @@ public static IQueryable TagWithCallSite( ? source.Provider.CreateQuery( Expression.Call( instance: null, - method: TagWithCallSiteMethodInfo.MakeGenericMethod(typeof(T)), + method: new Func, string, int, IQueryable>(TagWithCallSite).Method, arg0: source.Expression, arg1: Expression.Constant(filePath), arg2: Expression.Constant(lineNumber))) @@ -3247,48 +3368,33 @@ public static IAsyncEnumerable AsAsyncEnumerable( #region Impl. private static TResult ExecuteAsync( - MethodInfo operatorMethodInfo, + MethodInfo method, IQueryable source, - Expression? expression, + Expression? additionalArgument, CancellationToken cancellationToken = default) - { - if (source.Provider is IAsyncQueryProvider provider) - { - if (operatorMethodInfo.IsGenericMethod) - { - operatorMethodInfo - = operatorMethodInfo.GetGenericArguments().Length == 2 - ? operatorMethodInfo.MakeGenericMethod(typeof(TSource), typeof(TResult).GetGenericArguments().Single()) - : operatorMethodInfo.MakeGenericMethod(typeof(TSource)); - } - - return provider.ExecuteAsync( - Expression.Call( - instance: null, - method: operatorMethodInfo, - arguments: expression == null - ? [source.Expression] - : [source.Expression, expression]), - cancellationToken); - } - - throw new InvalidOperationException(CoreStrings.IQueryableProviderNotAsync); - } + => source.Provider is IAsyncQueryProvider provider + ? provider.ExecuteAsync( + Expression.Call( + instance: null, + method: method, + arguments: additionalArgument is null + ? [source.Expression] + : [source.Expression, additionalArgument]), + cancellationToken) + : throw new InvalidOperationException(CoreStrings.IQueryableProviderNotAsync); private static TResult ExecuteAsync( - MethodInfo operatorMethodInfo, + MethodInfo method, IQueryable source, - LambdaExpression expression, + LambdaExpression additionalArgument, CancellationToken cancellationToken = default) - => ExecuteAsync( - operatorMethodInfo, source, Expression.Quote(expression), cancellationToken); + => ExecuteAsync(method, source, Expression.Quote(additionalArgument), cancellationToken); private static TResult ExecuteAsync( - MethodInfo operatorMethodInfo, + MethodInfo method, IQueryable source, CancellationToken cancellationToken = default) - => ExecuteAsync( - operatorMethodInfo, source, (Expression?)null, cancellationToken); + => ExecuteAsync(method, source, (Expression?)null, cancellationToken); #endregion diff --git a/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs b/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs index e15dbc5b400..b121b85687a 100644 --- a/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs @@ -178,41 +178,36 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } if (method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) - && method.IsGenericMethod - && method.GetGenericMethodDefinition() is var genericMethod - && (genericMethod == EntityFrameworkQueryableExtensions.IncludeMethodInfo - || genericMethod == EntityFrameworkQueryableExtensions.ThenIncludeAfterEnumerableMethodInfo - || genericMethod == EntityFrameworkQueryableExtensions.ThenIncludeAfterReferenceMethodInfo - || genericMethod == EntityFrameworkQueryableExtensions.NotQuiteIncludeMethodInfo)) - { - var includeLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - if (includeLambda.ReturnType.IsGenericType - && includeLambda.ReturnType.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>)) + && method.Name is nameof(EntityFrameworkQueryableExtensions.Include) + or nameof(EntityFrameworkQueryableExtensions.ThenInclude) + or nameof(EntityFrameworkQueryableExtensions.NotQuiteInclude) + && methodCallExpression.Arguments[1].TryGetLambdaExpression(out var includeLambda) + && includeLambda.ReturnType.IsGenericType + && includeLambda.ReturnType.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>)) + { + var source = Visit(methodCallExpression.Arguments[0]); + var body = Visit(includeLambda.Body); + + // we have to rewrite the lambda to accommodate for IOrderedEnumerable<> into IOrderedQueryable<> conversion + var lambda = (Expression)Expression.Lambda(body, includeLambda.Parameters); + if (methodCallExpression.Arguments[1].NodeType == ExpressionType.Quote) { - var source = Visit(methodCallExpression.Arguments[0]); - var body = Visit(includeLambda.Body); - - // we have to rewrite the lambda to accommodate for IOrderedEnumerable<> into IOrderedQueryable<> conversion - var lambda = (Expression)Expression.Lambda(body, includeLambda.Parameters); - if (methodCallExpression.Arguments[1].NodeType == ExpressionType.Quote) - { - lambda = Expression.Quote(lambda); - } + lambda = Expression.Quote(lambda); + } - var genericArguments = methodCallExpression.Method.GetGenericArguments(); + var genericArguments = methodCallExpression.Method.GetGenericArguments(); - if (body.Type.IsGenericType - && body.Type.GetGenericTypeDefinition() == typeof(IOrderedQueryable<>)) - { - genericArguments[^1] = body.Type; - var newIncludeMethod = methodCallExpression.Method.GetGenericMethodDefinition() - .MakeGenericMethod(genericArguments); - - return Expression.Call(newIncludeMethod, source, lambda); - } + if (body.Type.IsGenericType + && body.Type.GetGenericTypeDefinition() == typeof(IOrderedQueryable<>)) + { + genericArguments[^1] = body.Type; + var newIncludeMethod = methodCallExpression.Method.GetGenericMethodDefinition() + .MakeGenericMethod(genericArguments); - return methodCallExpression.Update(null, [source, lambda]); + return Expression.Call(newIncludeMethod, source, lambda); } + + return methodCallExpression.Update(null, [source, lambda]); } if (visitedExpression == null) @@ -289,79 +284,81 @@ private static void VerifyReturnType(Expression expression, ParameterExpression private Expression? ExtractQueryMetadata(MethodCallExpression methodCallExpression) { // We visit innerQueryable first so that we can get information in the same order operators are applied. - var genericMethodDefinition = methodCallExpression.Method.GetGenericMethodDefinition(); - - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.AsTrackingMethodInfo) - { - var visitedExpression = Visit(methodCallExpression.Arguments[0]); - _queryCompilationContext.QueryTrackingBehavior = QueryTrackingBehavior.TrackAll; - - return visitedExpression; - } - - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.AsNoTrackingMethodInfo) - { - var visitedExpression = Visit(methodCallExpression.Arguments[0]); - _queryCompilationContext.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking; - - return visitedExpression; - } - - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.AsNoTrackingWithIdentityResolutionMethodInfo) - { - var visitedExpression = Visit(methodCallExpression.Arguments[0]); - _queryCompilationContext.QueryTrackingBehavior = QueryTrackingBehavior.NoTrackingWithIdentityResolution; - - return visitedExpression; - } - - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.TagWithMethodInfo) + var method = methodCallExpression.Method; + if (method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)) { - var visitedExpression = Visit(methodCallExpression.Arguments[0]); - _queryCompilationContext.AddTag(methodCallExpression.Arguments[1].GetConstantValue()); + switch (method.Name) + { + case nameof(EntityFrameworkQueryableExtensions.AsTracking): + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.QueryTrackingBehavior = QueryTrackingBehavior.TrackAll; + return visitedExpression; + } - return visitedExpression; - } + case nameof(EntityFrameworkQueryableExtensions.AsNoTracking): + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking; + return visitedExpression; + } - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.TagWithCallSiteMethodInfo) - { - var visitedExpression = Visit(methodCallExpression.Arguments[0]); + case nameof(EntityFrameworkQueryableExtensions.AsNoTrackingWithIdentityResolution): + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.QueryTrackingBehavior = QueryTrackingBehavior.NoTrackingWithIdentityResolution; + return visitedExpression; + } - var filePath = methodCallExpression.Arguments[1].GetConstantValue(); - var lineNumber = methodCallExpression.Arguments[2].GetConstantValue(); - _queryCompilationContext.AddTag($"File: {filePath}:{lineNumber}"); + case nameof(EntityFrameworkQueryableExtensions.TagWith): + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.AddTag(methodCallExpression.Arguments[1].GetConstantValue()); + return visitedExpression; + } - return visitedExpression; - } + case nameof(EntityFrameworkQueryableExtensions.TagWithCallSite): + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + var filePath = methodCallExpression.Arguments[1].GetConstantValue(); + var lineNumber = methodCallExpression.Arguments[2].GetConstantValue(); + _queryCompilationContext.AddTag($"File: {filePath}:{lineNumber}"); + return visitedExpression; + } - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.IgnoreQueryFiltersMethodInfo) - { - var visitedExpression = Visit(methodCallExpression.Arguments[0]); - _queryCompilationContext.IgnoredQueryFilters = null; - _queryCompilationContext.IgnoreQueryFilters = true; + // For named query filters + case nameof(EntityFrameworkQueryableExtensions.IgnoreQueryFilters) + when methodCallExpression.Arguments is + [ + var source, + ConstantExpression { Value: IReadOnlyCollection filterKeys } + ]: + { + var visitedExpression = Visit(source); + if (filterKeys?.Count > 0) + { + _queryCompilationContext.IgnoredQueryFilters ??= []; + _queryCompilationContext.IgnoredQueryFilters.UnionWith(filterKeys); + } + return visitedExpression; + } - return visitedExpression; - } + // For unnamed query filters + case nameof(EntityFrameworkQueryableExtensions.IgnoreQueryFilters): + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.IgnoredQueryFilters = null; + _queryCompilationContext.IgnoreQueryFilters = true; + return visitedExpression; + } - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.IgnoreNamedQueryFiltersMethodInfo) - { - var visitedExpression = Visit(methodCallExpression.Arguments[0]); - var filterKeys = methodCallExpression.Arguments[1].GetConstantValue>(); - if (filterKeys?.Count > 0) - { - _queryCompilationContext.IgnoredQueryFilters ??= []; - _queryCompilationContext.IgnoredQueryFilters.UnionWith(filterKeys); + case nameof(EntityFrameworkQueryableExtensions.IgnoreAutoIncludes): + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.IgnoreAutoIncludes = true; + return visitedExpression; + } } - - return visitedExpression; - } - - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.IgnoreAutoIncludesMethodInfo) - { - var visitedExpression = Visit(methodCallExpression.Arguments[0]); - _queryCompilationContext.IgnoreAutoIncludes = true; - - return visitedExpression; } return null; diff --git a/src/Shared/ExpressionExtensions.cs b/src/Shared/ExpressionExtensions.cs index eeb6f0f9ad8..84b2788e275 100644 --- a/src/Shared/ExpressionExtensions.cs +++ b/src/Shared/ExpressionExtensions.cs @@ -15,6 +15,18 @@ internal static class ExpressionExtensions public static bool IsNullConstantExpression(this Expression expression) => RemoveConvert(expression) is ConstantExpression { Value: null }; + public static bool TryGetLambdaExpression(this Expression expression, [NotNullWhen(true)] out LambdaExpression? lambdaExpression) + { + lambdaExpression = expression switch + { + UnaryExpression { NodeType: ExpressionType.Quote, Operand: LambdaExpression lambda } => lambda, + LambdaExpression lambda => lambda, + _ => null + }; + + return lambdaExpression is not null; + } + public static LambdaExpression UnwrapLambdaFromQuote(this Expression expression) => (LambdaExpression)(expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote ? unary.Operand