diff --git a/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs b/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs index fd5998c93b3..9d7097aaf94 100644 --- a/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs +++ b/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs @@ -50,6 +50,48 @@ public static TValue Deserialize(this IBsonSerializer serializer return serializer.Deserialize(context, args); } + /// + /// Gets the serializer for a base type starting from a serializer for a derived type. + /// + /// The serializer for the derived type. + /// The base type. + /// The serializer for the base type. + public static IBsonSerializer GetBaseTypeSerializer(this IBsonSerializer derivedTypeSerializer, Type baseType) + { + if (derivedTypeSerializer.ValueType == baseType) + { + return derivedTypeSerializer; + } + + if (!baseType.IsAssignableFrom(derivedTypeSerializer.ValueType)) + { + throw new ArgumentException($"{baseType} is not assignable from {derivedTypeSerializer.ValueType}."); + } + + return BsonSerializer.LookupSerializer(baseType); // TODO: should be able to navigate from serializer + } + + /// + /// Gets the serializer for a derived type starting from a serializer for a base type. + /// + /// The serializer for the base type. + /// The derived type. + /// The serializer for the derived type. + public static IBsonSerializer GetDerivedTypeSerializer(this IBsonSerializer baseTypeSerializer, Type derivedType) + { + if (baseTypeSerializer.ValueType == derivedType) + { + return baseTypeSerializer; + } + + if (!baseTypeSerializer.ValueType.IsAssignableFrom(derivedType)) + { + throw new ArgumentException($"{baseTypeSerializer.ValueType} is not assignable from {derivedType}."); + } + + return BsonSerializer.LookupSerializer(derivedType); // TODO: should be able to navigate from serializer + } + /// /// Gets the discriminator convention for a serializer. /// diff --git a/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs index f10cb541d16..e90210fbc14 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs @@ -13,10 +13,29 @@ * limitations under the License. */ +using System; using System.Collections.Generic; namespace MongoDB.Bson.Serialization.Serializers { + /// + /// A static factory class for ArraySerializers. + /// + public static class ArraySerializer + { + /// + /// Creates an ArraySerializer. + /// + /// The item serializer. + /// An ArraySerializer. + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var itemType = itemSerializer.ValueType; + var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); + return (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); + } + } + /// /// Represents a serializer for one-dimensional arrays. /// diff --git a/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs index 8740bdd3a9b..423b9500bed 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs @@ -33,6 +33,73 @@ public interface INullableSerializer /// public static class NullableSerializer { + private readonly static IBsonSerializer __nullableBooleanInstance = new NullableSerializer(BooleanSerializer.Instance); + private readonly static IBsonSerializer __nullableDecimalInstance = new NullableSerializer(DecimalSerializer.Instance); + private readonly static IBsonSerializer __nullableDecimal128Instance = new NullableSerializer(Decimal128Serializer.Instance); + private readonly static IBsonSerializer __nullableDoubleInstance = new NullableSerializer(DoubleSerializer.Instance); + private readonly static IBsonSerializer __nullableInt32Instance = new NullableSerializer(Int32Serializer.Instance); + private readonly static IBsonSerializer __nullableInt64Instance = new NullableSerializer(Int64Serializer.Instance); + private readonly static IBsonSerializer __nullableLocalDateTimeInstance = new NullableSerializer(DateTimeSerializer.LocalInstance); + private readonly static IBsonSerializer __nullableObjectIdInstance = new NullableSerializer(ObjectIdSerializer.Instance); + private readonly static IBsonSerializer __nullableSingleInstance = new NullableSerializer(SingleSerializer.Instance); + private readonly static IBsonSerializer __nullableStandardGuidInstance = new NullableSerializer(GuidSerializer.StandardInstance); + private readonly static IBsonSerializer __nullableUtcDateTimeInstance = new NullableSerializer(DateTimeSerializer.UtcInstance); + + /// + /// Gets a serializer for nullable bools. + /// + public static IBsonSerializer NullableBooleanInstance => __nullableBooleanInstance; + + /// + /// Gets a serializer for nullable decimals. + /// + public static IBsonSerializer NullableDecimalInstance => __nullableDecimalInstance; + + /// + /// Gets a serializer for nullable Decimal128s. + /// + public static IBsonSerializer NullableDecimal128Instance => __nullableDecimal128Instance; + + /// + /// Gets a serializer for nullable doubles. + /// + public static IBsonSerializer NullableDoubleInstance => __nullableDoubleInstance; + + /// + /// Gets a serializer for nullable ints. + /// + public static IBsonSerializer NullableInt32Instance => __nullableInt32Instance; + + /// + /// Gets a serializer for nullable longs. + /// + public static IBsonSerializer NullableInt64Instance => __nullableInt64Instance; + + /// + /// Gets a serializer for local DateTime. + /// + public static IBsonSerializer NullableLocalDateTimeInstance => __nullableLocalDateTimeInstance; + + /// + /// Gets a serializer for nullable floats. + /// + public static IBsonSerializer NullableSingleInstance => __nullableSingleInstance; + + /// + /// Gets a serializer for nullable ObjectIds. + /// + public static IBsonSerializer NullableObjectIdInstance => __nullableObjectIdInstance; + + /// + /// Gets a serializer for nullable Guids with standard representation. + /// + public static IBsonSerializer NullableStandardGuidInstance => __nullableStandardGuidInstance; + + /// + /// Gets a serializer for UTC DateTime. + /// + public static IBsonSerializer NullableUtcDateTimeInstance => __nullableUtcDateTimeInstance; + /// /// Creates a NullableSerializer. /// diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs index db7618ce677..e80842589d8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs @@ -60,5 +60,17 @@ public static TValue GetConstantValue(this Expression expression, Expres var message = $"Expression must be a constant: {expression} in {containingExpression}."; throw new ExpressionNotSupportedException(message); } + + public static bool IsConvert(this Expression expression, out Expression operand) + { + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression) + { + operand = unaryExpression.Operand; + return true; + } + + operand = null; + return false; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 4d62eaea95c..38e93428c37 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -64,7 +64,8 @@ private AstStage RenderProjectStage( out IBsonSerializer outputSerializer) { var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(_output); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedOutput.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedOutput, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); outputSerializer = (IBsonSerializer)projectSerializer; @@ -106,7 +107,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -150,7 +152,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer, TInput>> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -188,7 +191,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar); var groupBySerializer = (IBsonSerializer)groupByTranslation.Serializer; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinder.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinder.cs new file mode 100644 index 00000000000..fd8f0fa76f9 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinder.cs @@ -0,0 +1,81 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal static class KnownSerializerFinder +{ + public static KnownSerializerMap FindKnownSerializers( + Expression expression, + ExpressionTranslationOptions translationOptions) + { + var knownSerializers = new KnownSerializerMap(); + return FindKnownSerializers(expression, translationOptions, knownSerializers); + } + + public static KnownSerializerMap FindKnownSerializers( + Expression expression, + ExpressionTranslationOptions translationOptions, + Expression initialNode, + IBsonSerializer knownSerializer) + { + var knownSerializers = new KnownSerializerMap(); + knownSerializers.AddSerializer(initialNode, knownSerializer); + return FindKnownSerializers(expression, translationOptions, knownSerializers); + } + + public static KnownSerializerMap FindKnownSerializers( + Expression expression, + ExpressionTranslationOptions translationOptions, + (Expression Node, IBsonSerializer KnownSerializer)[] initialNodes) + { + var knownSerializers = new KnownSerializerMap(); + foreach (var (initialNode, knownSerializer) in initialNodes) + { + knownSerializers.AddSerializer(initialNode, knownSerializer); + + } + return FindKnownSerializers(expression, translationOptions, knownSerializers); + } + + public static KnownSerializerMap FindKnownSerializers( + Expression expression, + ExpressionTranslationOptions translationOptions, + KnownSerializerMap knownSerializers) + { + var visitor = new KnownSerializerFinderVisitor(translationOptions, knownSerializers); + + do + { + visitor.StartPass(); + visitor.Visit(expression); + visitor.EndPass(); + } + while (visitor.IsMakingProgress); + + //#if DEBUG + var expressionWithUnknownSerializer = UnknownSerializerFinder.FindExpressionWithUnknownSerializer(expression, knownSerializers); + if (expressionWithUnknownSerializer != null) + { + throw new ExpressionNotSupportedException(expressionWithUnknownSerializer, because: "we were unable to determine which serializer to use for the result"); + } + //#endif + + return knownSerializers; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderHelperMethods.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderHelperMethods.cs new file mode 100644 index 00000000000..99c316f8563 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderHelperMethods.cs @@ -0,0 +1,243 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using IOrderedEnumerableSerializer=MongoDB.Driver.Linq.Linq3Implementation.Serializers.IOrderedEnumerableSerializer; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + private void AddKnownSerializer(Expression node, IBsonSerializer serializer) => _knownSerializers.AddSerializer(node, serializer); + + private bool AllAreKnown(IEnumerable nodes, out IReadOnlyList knownSerializers) + { + var knownSerializersList = new List(); + foreach (var node in nodes) + { + if (IsKnown(node, out var nodeSerializer)) + { + knownSerializersList.Add(nodeSerializer); + } + else + { + knownSerializers = null; + return false; + } + } + + knownSerializers = knownSerializersList; + return true; + } + + private bool AnyIsKnown(IEnumerable nodes, out IBsonSerializer knownSerializer) + { + foreach (var node in nodes) + { + if (IsKnown(node, out var nodeSerializer)) + { + knownSerializer = nodeSerializer; + return true; + } + } + + knownSerializer = null; + return false; + } + + private bool AnyIsNotKnown(IEnumerable nodes) + { + return nodes.Any(IsNotKnown); + } + + private bool CanDeduceSerializer(Expression node1, Expression node2, out Expression unknownNode, out IBsonSerializer knownSerializer) + { + if (IsNotKnown(node1) && IsKnown(node2, out var node2Serializer)) + { + unknownNode = node1; + knownSerializer = node2Serializer; + return true; + } + + if (IsNotKnown(node2) && IsKnown(node1, out var node1Serializer)) + { + unknownNode = node2; + knownSerializer = node1Serializer; + return true; + } + + unknownNode = null; + knownSerializer = null; + return false; + } + + IBsonSerializer CreateCollectionSerializerFromCollectionSerializer(Type collectionType, IBsonSerializer collectionSerializer) + { + if (collectionSerializer.ValueType == collectionType) + { + return collectionSerializer; + } + + if (collectionSerializer is IUnknowableSerializer) + { + return UnknowableSerializer.Create(collectionType); + } + + var itemSerializer = collectionSerializer.GetItemSerializer(); + return CreateCollectionSerializerFromItemSerializer(collectionType, itemSerializer); + } + + IBsonSerializer CreateCollectionSerializerFromItemSerializer(Type collectionType, IBsonSerializer itemSerializer) + { + if (itemSerializer is IUnknowableSerializer) + { + return UnknowableSerializer.Create(collectionType); + } + + return collectionType switch + { + _ when collectionType.IsArray => ArraySerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IEnumerable<>) => IEnumerableSerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>) => IOrderedEnumerableSerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IQueryable<>) => IQueryableSerializer.Create(itemSerializer), + _ => (BsonSerializer.LookupSerializer(collectionType) as IChildSerializerConfigurable)?.WithChildSerializer(itemSerializer) + }; + } + + private void DeduceBaseTypeAndDerivedTypeSerializers(Expression baseTypeExpression, Expression derivedTypeExpression) + { + if (IsNotKnown(baseTypeExpression) && IsKnown(derivedTypeExpression, out var knownDerivedTypeSerializer)) + { + var baseTypeSerializer = knownDerivedTypeSerializer.GetBaseTypeSerializer(baseTypeExpression.Type); + AddKnownSerializer(baseTypeExpression, baseTypeSerializer); + } + + if (IsNotKnown(derivedTypeExpression) && IsKnown(baseTypeExpression, out var knownBaseTypeSerializer)) + { + var derivedTypeSerializer = knownBaseTypeSerializer.GetDerivedTypeSerializer(baseTypeExpression.Type); + AddKnownSerializer(derivedTypeExpression, derivedTypeSerializer); + } + } + + private void DeduceBooleanSerializer(Expression node) + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, BooleanSerializer.Instance); + } + } + + private void DeduceCollectionAndCollectionSerializers(Expression collectionExpression1, Expression collectionExpression2) + { + if (IsNotKnown(collectionExpression1) && IsKnown(collectionExpression2, out var knownCollectionSerializer2)) + { + var collectionSerializer1 = CreateCollectionSerializerFromCollectionSerializer(collectionExpression1.Type, knownCollectionSerializer2); + AddKnownSerializer(collectionExpression1, collectionSerializer1); + } + + if (IsNotKnown(collectionExpression2) && IsKnown(collectionExpression1, out var knownCollectionSerializer1)) + { + var collectionSerializer2 = CreateCollectionSerializerFromCollectionSerializer(collectionExpression2.Type, knownCollectionSerializer1); + AddKnownSerializer(collectionExpression2, collectionSerializer2); + } + } + + private void DeduceCollectionAndItemSerializers(Expression collectionExpression, Expression itemExpression) + { + DeduceItemAndCollectionSerializers(itemExpression, collectionExpression); + } + + private void DeduceItemAndCollectionSerializers(Expression itemExpression, Expression collectionExpression) + { + if (IsNotKnown(itemExpression) && IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + AddKnownSerializer(itemExpression, itemSerializer); + } + + if (IsNotKnown(collectionExpression) && IsKnown(itemExpression, out itemSerializer)) + { + var collectionSerializer = CreateCollectionSerializerFromItemSerializer(collectionExpression.Type, itemSerializer); + if (collectionSerializer != null) + { + AddKnownSerializer(collectionExpression, collectionSerializer); + } + } + } + + private void DeduceSerializer(Expression node, IBsonSerializer serializer) + { + if (IsNotKnown(node) && serializer != null) + { + AddKnownSerializer(node, serializer); + } + } + + private void DeduceSerializers(Expression expression1, Expression expression2) + { + if (IsNotKnown(expression1) && IsKnown(expression2, out var expression2Serializer) && expression2Serializer.ValueType == expression1.Type) + { + AddKnownSerializer(expression1, expression2Serializer); + } + + if (IsNotKnown(expression2) && IsKnown(expression1, out var expression1Serializer)&& expression1Serializer.ValueType == expression2.Type) + { + AddKnownSerializer(expression2, expression1Serializer); + } + } + + private void DeduceStringSerializer(Expression node) + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, StringSerializer.Instance); + } + } + + private void DeduceUnknowableSerializer(Expression node) + { + if (IsNotKnown(node)) + { + var unknowableSerializer = UnknowableSerializer.Create(node.Type); + AddKnownSerializer(node, unknowableSerializer); + } + } + + private bool IsItemSerializerKnown(Expression node, out IBsonSerializer itemSerializer) + { + if (IsKnown(node, out var nodeSerializer) && + nodeSerializer is IBsonArraySerializer arraySerializer && + arraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) + { + itemSerializer = itemSerializationInfo.Serializer; + return true; + } + + itemSerializer = null; + return false; + } + + private bool IsKnown(Expression node) => _knownSerializers.IsKnown(node); + + private bool IsKnown(Expression node, out IBsonSerializer knownSerializer) => _knownSerializers.IsKnown(node, out knownSerializer); + + private bool IsNotKnown(Expression node) => _knownSerializers.IsNotKnown(node); +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderNewExpressionSerializerCreator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderNewExpressionSerializerCreator.cs new file mode 100644 index 00000000000..607389a4898 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderNewExpressionSerializerCreator.cs @@ -0,0 +1,201 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + public IBsonSerializer CreateNewExpressionSerializer( + Expression expression, + NewExpression newExpression, + IReadOnlyList bindings) + { + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct + var constructorArguments = newExpression.Arguments; + var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); + + if (constructorInfo != null && creatorMap != null) + { + var constructorParameters = constructorInfo.GetParameters(); + var creatorMapParameters = creatorMap.Arguments?.ToArray(); + if (constructorParameters.Length > 0) + { + if (creatorMapParameters == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); + } + + if (creatorMapParameters.Length != constructorParameters.Length) + { + throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); + } + + for (var i = 0; i < creatorMapParameters.Length; i++) + { + var creatorMapParameter = creatorMapParameters[i]; + var constructorArgumentExpression = constructorArguments[i]; + if (!IsKnown(constructorArgumentExpression, out var constructorArgumentSerializer)) + { + return null; + } + var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); + EnsureDefaultValue(memberMap); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); + memberMap.SetSerializer(memberSerializer); + } + } + } + + if (bindings != null) + { + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; + var memberMap = FindMemberMap(expression, classMap, member.Name); + var valueExpression = memberAssignment.Expression; + if (!IsKnown(valueExpression, out var valueSerializer)) + { + return null; + } + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueSerializer); + memberMap.SetSerializer(memberSerializer); + } + } + + classMap.Freeze(); + + var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); + return (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); + } + + private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) + { + BsonClassMap baseClassMap = null; + if (classType.BaseType != null) + { + baseClassMap = CreateClassMap(classType.BaseType, null, out _); + } + + var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); + var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); + var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); + if (constructorInfo != null) + { + creatorMap = classMap.MapConstructor(constructorInfo); + } + else + { + creatorMap = null; + } + + classMap.AutoMap(); + classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here + + return classMap; + } + + private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) + { + var memberType = memberMap.MemberType; + var memberSerializer = memberMap.GetSerializer(); + var sourceType = sourceSerializer.ValueType; + + if (memberType != sourceType && + memberType.ImplementsIEnumerable(out var memberItemType) && + sourceType.ImplementsIEnumerable(out var sourceItemType) && + sourceItemType == memberItemType && + sourceSerializer is IBsonArraySerializer sourceArraySerializer && + sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && + memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) + { + var sourceItemSerializer = sourceItemSerializationInfo.Serializer; + return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); + } + + return sourceSerializer; + } + + private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) + { + var declaringClassMap = classMap; + while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) + { + declaringClassMap = declaringClassMap.BaseClassMap; + + if (declaringClassMap == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + } + } + + foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) + { + if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) + { + return memberMap; + } + } + + return declaringClassMap.MapMember(creatorMapParameter); + + static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) + { + var memberInfo = memberMap.MemberInfo; + return + memberInfo.MemberType == creatorMapParameter.MemberType && + memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); + } + } + + private static void EnsureDefaultValue(BsonMemberMap memberMap) + { + if (memberMap.IsDefaultValueSpecified) + { + return; + } + + var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; + memberMap.SetDefaultValue(defaultValue); + } + + private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) + { + foreach (var memberMap in classMap.DeclaredMemberMaps) + { + if (memberMap.MemberName == memberName) + { + return memberMap; + } + } + + if (classMap.BaseClassMap != null) + { + return FindMemberMap(expression, classMap.BaseClassMap, memberName); + } + + throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitBinary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitBinary.cs new file mode 100644 index 00000000000..cfda8538b0b --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitBinary.cs @@ -0,0 +1,182 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitBinary(BinaryExpression node) + { + base.VisitBinary(node); + + var @operator = node.NodeType; + var leftExpression = node.Left; + var rightExpression = node.Right; + + if (node.NodeType == ExpressionType.Add && node.Type == typeof(string)) + { + DeduceStringSerializer(node); + return node; + } + + if (IsSymmetricalBinaryOperator(@operator) && + CanDeduceSerializer(leftExpression, rightExpression, out var unknownNode, out var knownSerializer)) + { + // expr1 op expr2 => expr1: expr2Serializer or expr2: expr1Serializer + if (knownSerializer.ValueType == unknownNode.Type) + { + AddKnownSerializer(unknownNode, knownSerializer); + } + } + + if (@operator == ExpressionType.ArrayIndex) + { + if (IsNotKnown(node) && + IsKnown(leftExpression, out var leftSerializer)) + { + IBsonSerializer itemSerializer; + if (leftSerializer is IFixedSizeArraySerializer fixedSizeArraySerializer) + { + var index = rightExpression.GetConstantValue(node); + itemSerializer = fixedSizeArraySerializer.GetItemSerializer(index); + } + else + { + itemSerializer = leftSerializer.GetItemSerializer(); + } + + // expr[index] => node: itemSerializer + AddKnownSerializer(node, itemSerializer); + } + } + + if (@operator == ExpressionType.Coalesce) + { + if (IsNotKnown(node) && + IsKnown(leftExpression, out var leftSerializer)) + { + if (leftSerializer.ValueType == node.Type) + { + AddKnownSerializer(node, leftSerializer); + } + else if ( + leftSerializer is INullableSerializer nullableSerializer && + nullableSerializer.ValueSerializer is var nullableSerializerValueSerializer && + nullableSerializerValueSerializer.ValueType == node.Type) + { + AddKnownSerializer(node, nullableSerializerValueSerializer); + } + else + { + DeduceUnknowableSerializer(node); + } + } + } + + if (leftExpression.IsConvert(out var leftConvertOperand) && + rightExpression.IsConvert(out var rightConvertOperand) && + leftConvertOperand.Type == rightConvertOperand.Type) + { + // TODO: verify left and right operands are same type + if (CanDeduceSerializer(leftConvertOperand, rightConvertOperand, out unknownNode, out knownSerializer)) + { + // Convert(expr1, T) op Convert(expr2, T) => expr1: expr2Serializer or expr2: expr1Serializer + AddKnownSerializer(unknownNode, knownSerializer); + } + } + + if (IsNotKnown(node)) + { + var resultSerializer = GetResultSerializer(node, @operator); + if (resultSerializer != null) + { + AddKnownSerializer(node, resultSerializer); + } + } + + return node; + + static IBsonSerializer GetResultSerializer(Expression node, ExpressionType @operator) + { + switch (@operator) + { + case ExpressionType.And: + case ExpressionType.ExclusiveOr: + case ExpressionType.Or: + switch (node.Type) + { + case Type t when t == typeof(bool): return BooleanSerializer.Instance; + case Type t when t == typeof(int): return Int32Serializer.Instance; + } + goto default; + + case ExpressionType.AndAlso: + case ExpressionType.Equal: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.NotEqual: + case ExpressionType.OrElse: + case ExpressionType.TypeEqual: + return BooleanSerializer.Instance; + + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + if (StandardSerializers.TryGetSerializer(node.Type, out var resultSerializer)) + { + return resultSerializer; + } + goto default; + + default: + return null; + } + } + + static bool IsSymmetricalBinaryOperator(ExpressionType @operator) + => @operator is + ExpressionType.Add or + ExpressionType.AddChecked or + ExpressionType.And or + ExpressionType.AndAlso or + ExpressionType.Coalesce or + ExpressionType.Divide or + ExpressionType.Equal or + ExpressionType.GreaterThan or + ExpressionType.GreaterThanOrEqual or + ExpressionType.Modulo or + ExpressionType.Multiply or + ExpressionType.MultiplyChecked or + ExpressionType.Or or + ExpressionType.OrElse or + ExpressionType.Subtract or + ExpressionType.SubtractChecked; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitConditional.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitConditional.cs new file mode 100644 index 00000000000..86eaff201e2 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitConditional.cs @@ -0,0 +1,40 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitConditional(ConditionalExpression node) + { + var ifTrueExpression = node.IfTrue; + var ifFalseExpression = node.IfFalse; + + DeduceConditionaSerializers(); + base.VisitConditional(node); + DeduceConditionaSerializers(); + + return node; + + void DeduceConditionaSerializers() + { + DeduceBaseTypeAndDerivedTypeSerializers(node, ifTrueExpression); + DeduceBaseTypeAndDerivedTypeSerializers(node, ifFalseExpression); + DeduceBaseTypeAndDerivedTypeSerializers(node, ifTrueExpression); // call a second time in case ifFalse is the only known serializer + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitConstant.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitConstant.cs new file mode 100644 index 00000000000..425aa27bfac --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitConstant.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitConstant(ConstantExpression node) + { + if (IsNotKnown(node) && _useDefaultSerializerForConstants) + { + if (StandardSerializers.TryGetSerializer(node.Type, out var standardSerializer)) + { + AddKnownSerializer(node, standardSerializer); + } + else + { + var registeredSerializer = BsonSerializer.LookupSerializer(node.Type); // TODO: don't use static registry + AddKnownSerializer(node, registeredSerializer); + } + } + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitIndex.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitIndex.cs new file mode 100644 index 00000000000..464848d4fd5 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitIndex.cs @@ -0,0 +1,86 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitIndex(IndexExpression node) + { + base.VisitIndex(node); + + var collectionExpression = node.Object; + var indexer = node.Indexer; + var arguments = node.Arguments; + + if (IsBsonValueIndexer()) + { + DeduceSerializer(node, BsonValueSerializer.Instance); + } + else if (IsDictionaryIndexer()) + { + if (IsKnown(collectionExpression, out var collectionSerializer) && + collectionSerializer is IBsonDictionarySerializer dictionarySerializer) + { + var valueSerializer = dictionarySerializer.ValueSerializer; + DeduceSerializer(node, valueSerializer); + } + } + // check array indexer AFTER dictionary indexer + else if (IsCollectionIndexer()) + { + if (IsKnown(collectionExpression, out var collectionSerializer) && + collectionSerializer is IBsonArraySerializer arraySerializer) + { + var itemSerializer = arraySerializer.GetItemSerializer(); + DeduceSerializer(node, itemSerializer); + } + } + // handle generic cases? + + return node; + + bool IsCollectionIndexer() + { + return + arguments.Count == 1 && + arguments[0] is var index && + index.Type == typeof(int); + } + + bool IsBsonValueIndexer() + { + var declaringType = indexer.DeclaringType; + return + (declaringType == typeof(BsonValue) || declaringType.IsSubclassOf(typeof(BsonValue))) && + arguments.Count == 1 && + arguments[0] is var index && + (index.Type == typeof(int) || index.Type == typeof(string)); + } + + bool IsDictionaryIndexer() + { + return + indexer.DeclaringType.Name.Contains("Dictionary") && + arguments.Count == 1; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitLambda.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitLambda.cs new file mode 100644 index 00000000000..db565dd897a --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitLambda.cs @@ -0,0 +1,33 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitLambda(Expression node) + { + if (IsNotKnown(node)) + { + var ignoreNodeSerializer = IgnoreNodeSerializer.Create(node.Type); + AddKnownSerializer(node, ignoreNodeSerializer); + } + + return base.VisitLambda(node); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitListInit.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitListInit.cs new file mode 100644 index 00000000000..e8000018e59 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitListInit.cs @@ -0,0 +1,38 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitListInit(ListInitExpression node) + { + var newExpression = node.NewExpression; + var initializers = node.Initializers; + + DeduceSerialiers(); + base.VisitListInit(node); + DeduceSerialiers(); + + return node; + + void DeduceSerialiers() + { + DeduceSerializers(node, newExpression); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMember.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMember.cs new file mode 100644 index 00000000000..c8483b59969 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMember.cs @@ -0,0 +1,206 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Support; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitMember(MemberExpression node) + { + IBsonSerializer containingSerializer; + var member = node.Member; + var declaringType = member.DeclaringType; + var memberName = member.Name; + + base.VisitMember(node); + + if (IsNotKnown(node)) + { + var containingExpression = node.Expression; + if (IsKnown(containingExpression, out containingSerializer)) + { + // TODO: handle special cases like DateTime.Year etc. + + var resultSerializer = node.Member switch + { + _ when declaringType == typeof(BsonValue) => GetBsonValuePropertySerializer(), + _ when IsCollectionCountOrLengthProperty() => GetCollectionCountOrLengthPropertySerializer(), + _ when declaringType == typeof(DateTime) => GetDateTimePropertySerializer(), + _ when declaringType.IsNullable() => GetNullablePropertySerializer(), + _ when IsTupleOrValueTuple(declaringType) => GetTupleOrValueTuplePropertySerializer(), + _ => GetPropertySerializer() + }; + + AddKnownSerializer(node, resultSerializer); + } + } + + return node; + + IBsonSerializer GetBsonValuePropertySerializer() + { + return memberName switch + { + "AsBoolean" => BooleanSerializer.Instance, + "AsBsonArray" => BsonArraySerializer.Instance, + "AsBsonBinaryData" => BsonBinaryDataSerializer.Instance, + "AsBsonDateTime" => BsonDateTimeSerializer.Instance, + "AsBsonDocument" => BsonDocumentSerializer.Instance, + "AsBsonJavaScript" => BsonJavaScriptSerializer.Instance, + "AsBsonJavaScriptWithScope" => BsonJavaScriptWithScopeSerializer.Instance, + "AsBsonMaxKey" => BsonMaxKeySerializer.Instance, + "AsBsonMinKey" => BsonMinKeySerializer.Instance, + "AsBsonNull" => BsonNullSerializer.Instance, + "AsBsonRegularExpression" => BsonRegularExpressionSerializer.Instance, + "AsBsonSymbol" => BsonSymbolSerializer.Instance, + "AsBsonTimestamp" => BsonTimestampSerializer.Instance, + "AsBsonUndefined" => BsonUndefinedSerializer.Instance, + "AsBsonValue" => BsonValueSerializer.Instance, + "AsByteArray" => ByteArraySerializer.Instance, + "AsDecimal128" => Decimal128Serializer.Instance, + "AsDecimal" => DecimalSerializer.Instance, + "AsDouble" => DoubleSerializer.Instance, + "AsGuid" => GuidSerializer.StandardInstance, + "AsInt32" => Int32Serializer.Instance, + "AsInt64" => Int64Serializer.Instance, + "AsLocalTime" => DateTimeSerializer.LocalInstance, + "AsNullableBoolean" => NullableSerializer.NullableBooleanInstance, + "AsNullableDecimal128" => NullableSerializer.NullableDecimal128Instance, + "AsNullableDecimal" => NullableSerializer.NullableDecimalInstance, + "AsNullableDouble" => NullableSerializer.NullableDoubleInstance, + "AsNullableGuid" => NullableSerializer.NullableStandardGuidInstance, + "AsNullableInt32" => NullableSerializer.NullableInt32Instance, + "AsNullableInt64" => NullableSerializer.NullableInt64Instance, + "AsNullableLocalTime" => NullableSerializer.NullableLocalDateTimeInstance, + "AsNullableObjectId" => NullableSerializer.NullableObjectIdInstance, + "AsNullableUniversalTime" => NullableSerializer.NullableUtcDateTimeInstance, + "AsObjectId" => ObjectIdSerializer.Instance, + "AsRegex" => RegexSerializer.RegularExpressionInstance, + "AsString" => StringSerializer.Instance, + "AsUniversalTime" => DateTimeSerializer.UtcInstance, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetCollectionCountOrLengthPropertySerializer() + { + return Int32Serializer.Instance; + } + + IBsonSerializer GetDateTimePropertySerializer() + { + return memberName switch + { + "Date" => DateTimeSerializer.Instance, + "Day" => Int32Serializer.Instance, + "DayOfWeek" => new EnumSerializer(BsonType.Int32), + "DayOfYear" => Int32Serializer.Instance, + "Hour" => Int32Serializer.Instance, + "Millisecond" => Int32Serializer.Instance, + "Minute" => Int32Serializer.Instance, + "Month" => Int32Serializer.Instance, + "Now" => DateTimeSerializer.Instance, + "Second" => Int32Serializer.Instance, + "Ticks" => Int64Serializer.Instance, + "TimeOfDay" => new TimeSpanSerializer(BsonType.Int64, TimeSpanUnits.Milliseconds), + "Today" => DateTimeSerializer.Instance, + "UtcNow" => DateTimeSerializer.Instance, + "Year" => Int32Serializer.Instance, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetNullablePropertySerializer() + { + return memberName switch + { + "HasValue" => BooleanSerializer.Instance, + "Value" => (containingSerializer as INullableSerializer)?.ValueSerializer, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetPropertySerializer() + { + if (containingSerializer is not IBsonDocumentSerializer documentSerializer) + { + // TODO: return UnknowableSerializer??? + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not implement the {nameof(IBsonDocumentSerializer)} interface"); + } + + if (!documentSerializer.TryGetMemberSerializationInfo(memberName, out var memberSerializationInfo)) + { + // TODO: return UnknowableSerializer??? + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not support a member named: {memberName}"); + } + + return memberSerializationInfo.Serializer; + } + + IBsonSerializer GetTupleOrValueTuplePropertySerializer() + { + if (containingSerializer is not IBsonTupleSerializer tupleSerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not implement the {nameof(IBsonTupleSerializer)} interface"); + } + + return memberName switch + { + "Item1" => tupleSerializer.GetItemSerializer(1), + "Item2" => tupleSerializer.GetItemSerializer(2), + "Item3" => tupleSerializer.GetItemSerializer(3), + "Item4" => tupleSerializer.GetItemSerializer(4), + "Item5" => tupleSerializer.GetItemSerializer(5), + "Item6" => tupleSerializer.GetItemSerializer(6), + "Item7" => tupleSerializer.GetItemSerializer(7), + "Rest" => tupleSerializer.GetItemSerializer(8), + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + bool IsCollectionCountOrLengthProperty() + { + return + (declaringType.ImplementsInterface(typeof(IEnumerable)) || declaringType == typeof(BitArray)) && + node.Type == typeof(int) && + (member.Name == "Count" || member.Name == "Length"); + } + + bool IsTupleOrValueTuple(Type type) + { + return + type.Namespace == "System" && + (type.Name.StartsWith("Tuple") || type.Name.StartsWith("ValueTuple")) && + type.IsPublic && + type.IsConstructedGenericType && + type.GetGenericArguments().Length is >= 1 and <= 8; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMemberInit.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMemberInit.cs new file mode 100644 index 00000000000..8d6bd0d404b --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMemberInit.cs @@ -0,0 +1,97 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitMemberInit(MemberInitExpression node) + { + if (IsKnown(node, out var knownSerializer)) + { + var newExpression = node.NewExpression; + if (newExpression != null) + { + if (IsNotKnown(newExpression)) + { + AddKnownSerializer(newExpression, knownSerializer); + } + } + + if (node.Bindings.Count > 0) + { + if (knownSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {knownSerializer.GetType()} does not implement IBsonDocumentSerializer interface"); + } + + foreach (var binding in node.Bindings) + { + if (binding is MemberAssignment memberAssignment) + { + if (IsNotKnown(memberAssignment.Expression)) + { + var member = memberAssignment.Member; + var memberName = member.Name; + if (!documentSerializer.TryGetMemberSerializationInfo(memberName, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(node, because: $"type {member.DeclaringType} does not have a member named: {memberName}"); + } + var expressionSerializer = memberSerializationInfo.Serializer; + + if (expressionSerializer.ValueType != memberAssignment.Expression.Type && + expressionSerializer.ValueType.IsAssignableFrom(memberAssignment.Expression.Type)) + { + expressionSerializer = expressionSerializer.GetDerivedTypeSerializer(memberAssignment.Expression.Type); + } + + // member = expression => expression: memberSerializer (or derivedTypeSerializer) + AddKnownSerializer(memberAssignment.Expression, expressionSerializer); + } + } + } + } + } + + base.VisitMemberInit(node); + + if (IsNotKnown(node)) + { + var resultSerializer = GetResultSerializer(); + if (resultSerializer != null) + { + AddKnownSerializer(node, resultSerializer); + } + } + + return node; + + IBsonSerializer GetResultSerializer() + { + if (node.Type == typeof(BsonDocument)) + { + return BsonDocumentSerializer.Instance; + } + var newExpression = node.NewExpression; + var bindings = node.Bindings; + return CreateNewExpressionSerializer(node, newExpression, bindings); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMethodCall.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMethodCall.cs new file mode 100644 index 00000000000..ce7d4f691a2 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitMethodCall.cs @@ -0,0 +1,3420 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.CodeDom; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + private static readonly HashSet __absMethods = + [ + MathMethod.AbsDecimal, + MathMethod.AbsDouble, + MathMethod.AbsInt16, + MathMethod.AbsInt32, + MathMethod.AbsInt64, + MathMethod.AbsSByte, + MathMethod.AbsSingle + ]; + + private static readonly HashSet __aggregateMethods = + [ + EnumerableMethod.AggregateWithFunc, + EnumerableMethod.AggregateWithSeedAndFunc, + EnumerableMethod.AggregateWithSeedFuncAndResultSelector, + QueryableMethod.AggregateWithFunc, + QueryableMethod.AggregateWithSeedAndFunc, + QueryableMethod.AggregateWithSeedFuncAndResultSelector + ]; + + private static readonly HashSet __aggregateWithFuncMethods = + [ + EnumerableMethod.AggregateWithFunc, + QueryableMethod.AggregateWithFunc + ]; + + private static readonly HashSet __aggregateWithSeedAndFuncMethods = + [ + EnumerableMethod.AggregateWithSeedAndFunc, + QueryableMethod.AggregateWithSeedAndFunc + ]; + + private static readonly HashSet __aggregateWithSeedFuncAdResultSelectorMethods = + [ + EnumerableMethod.AggregateWithSeedFuncAndResultSelector, + QueryableMethod.AggregateWithSeedFuncAndResultSelector + ]; + + private static readonly HashSet __anyMethods = + [ + EnumerableMethod.Any, + EnumerableMethod.AnyWithPredicate, + QueryableMethod.Any, + QueryableMethod.AnyWithPredicate + ]; + + private static readonly HashSet __anyWithPredicateMethods = + [ + EnumerableMethod.AnyWithPredicate, + QueryableMethod.AnyWithPredicate + ]; + + private static readonly HashSet __appendOrPrependMethods = + [ + EnumerableMethod.Append, + EnumerableMethod.Prepend, + QueryableMethod.Append, + QueryableMethod.Prepend + ]; + + private static readonly HashSet __averageOrMedianOrPercentileMethods = + [ + EnumerableMethod.AverageDecimal, + EnumerableMethod.AverageDecimalWithSelector, + EnumerableMethod.AverageDouble, + EnumerableMethod.AverageDoubleWithSelector, + EnumerableMethod.AverageInt32, + EnumerableMethod.AverageInt32WithSelector, + EnumerableMethod.AverageInt64, + EnumerableMethod.AverageInt64WithSelector, + EnumerableMethod.AverageNullableDecimal, + EnumerableMethod.AverageNullableDecimalWithSelector, + EnumerableMethod.AverageNullableDouble, + EnumerableMethod.AverageNullableDoubleWithSelector, + EnumerableMethod.AverageNullableInt32, + EnumerableMethod.AverageNullableInt32WithSelector, + EnumerableMethod.AverageNullableInt64, + EnumerableMethod.AverageNullableInt64WithSelector, + EnumerableMethod.AverageNullableSingle, + EnumerableMethod.AverageNullableSingleWithSelector, + EnumerableMethod.AverageSingle, + EnumerableMethod.AverageSingleWithSelector, + QueryableMethod.AverageDecimal, + QueryableMethod.AverageDecimalWithSelector, + QueryableMethod.AverageDouble, + QueryableMethod.AverageDoubleWithSelector, + QueryableMethod.AverageInt32, + QueryableMethod.AverageInt32WithSelector, + QueryableMethod.AverageInt64, + QueryableMethod.AverageInt64WithSelector, + QueryableMethod.AverageNullableDecimal, + QueryableMethod.AverageNullableDecimalWithSelector, + QueryableMethod.AverageNullableDouble, + QueryableMethod.AverageNullableDoubleWithSelector, + QueryableMethod.AverageNullableInt32, + QueryableMethod.AverageNullableInt32WithSelector, + QueryableMethod.AverageNullableInt64, + QueryableMethod.AverageNullableInt64WithSelector, + QueryableMethod.AverageNullableSingle, + QueryableMethod.AverageNullableSingleWithSelector, + QueryableMethod.AverageSingle, + QueryableMethod.AverageSingleWithSelector, + MongoEnumerableMethod.MedianDecimal, + MongoEnumerableMethod.MedianDecimalWithSelector, + MongoEnumerableMethod.MedianDouble, + MongoEnumerableMethod.MedianDoubleWithSelector, + MongoEnumerableMethod.MedianInt32, + MongoEnumerableMethod.MedianInt32WithSelector, + MongoEnumerableMethod.MedianInt64, + MongoEnumerableMethod.MedianInt64WithSelector, + MongoEnumerableMethod.MedianNullableDecimal, + MongoEnumerableMethod.MedianNullableDecimalWithSelector, + MongoEnumerableMethod.MedianNullableDouble, + MongoEnumerableMethod.MedianNullableDoubleWithSelector, + MongoEnumerableMethod.MedianNullableInt32, + MongoEnumerableMethod.MedianNullableInt32WithSelector, + MongoEnumerableMethod.MedianNullableInt64, + MongoEnumerableMethod.MedianNullableInt64WithSelector, + MongoEnumerableMethod.MedianNullableSingle, + MongoEnumerableMethod.MedianNullableSingleWithSelector, + MongoEnumerableMethod.MedianSingle, + MongoEnumerableMethod.MedianSingleWithSelector, + MongoEnumerableMethod.PercentileDecimal, + MongoEnumerableMethod.PercentileDecimalWithSelector, + MongoEnumerableMethod.PercentileDouble, + MongoEnumerableMethod.PercentileDoubleWithSelector, + MongoEnumerableMethod.PercentileInt32, + MongoEnumerableMethod.PercentileInt32WithSelector, + MongoEnumerableMethod.PercentileInt64, + MongoEnumerableMethod.PercentileInt64WithSelector, + MongoEnumerableMethod.PercentileNullableDecimal, + MongoEnumerableMethod.PercentileNullableDecimalWithSelector, + MongoEnumerableMethod.PercentileNullableDouble, + MongoEnumerableMethod.PercentileNullableDoubleWithSelector, + MongoEnumerableMethod.PercentileNullableInt32, + MongoEnumerableMethod.PercentileNullableInt32WithSelector, + MongoEnumerableMethod.PercentileNullableInt64, + MongoEnumerableMethod.PercentileNullableInt64WithSelector, + MongoEnumerableMethod.PercentileNullableSingle, + MongoEnumerableMethod.PercentileNullableSingleWithSelector, + MongoEnumerableMethod.PercentileSingle, + MongoEnumerableMethod.PercentileSingleWithSelector, + WindowMethod.PercentileWithDecimal, + WindowMethod.PercentileWithDouble, + WindowMethod.PercentileWithInt32, + WindowMethod.PercentileWithInt64, + WindowMethod.PercentileWithNullableDecimal, + WindowMethod.PercentileWithNullableDouble, + WindowMethod.PercentileWithNullableInt32, + WindowMethod.PercentileWithNullableInt64, + WindowMethod.PercentileWithNullableSingle, + WindowMethod.PercentileWithSingle + ]; + + private static readonly HashSet __averageOrMedianOrPercentileWithSelectorMethods = + [ + EnumerableMethod.AverageDecimalWithSelector, + EnumerableMethod.AverageDoubleWithSelector, + EnumerableMethod.AverageInt32WithSelector, + EnumerableMethod.AverageInt64WithSelector, + EnumerableMethod.AverageNullableDecimalWithSelector, + EnumerableMethod.AverageNullableDoubleWithSelector, + EnumerableMethod.AverageNullableInt32WithSelector, + EnumerableMethod.AverageNullableInt64WithSelector, + EnumerableMethod.AverageNullableSingleWithSelector, + EnumerableMethod.AverageSingleWithSelector, + QueryableMethod.AverageDecimalWithSelector, + QueryableMethod.AverageDoubleWithSelector, + QueryableMethod.AverageInt32WithSelector, + QueryableMethod.AverageInt64WithSelector, + QueryableMethod.AverageNullableDecimalWithSelector, + QueryableMethod.AverageNullableDoubleWithSelector, + QueryableMethod.AverageNullableInt32WithSelector, + QueryableMethod.AverageNullableInt64WithSelector, + QueryableMethod.AverageNullableSingleWithSelector, + QueryableMethod.AverageSingleWithSelector, + MongoEnumerableMethod.MedianDecimalWithSelector, + MongoEnumerableMethod.MedianDoubleWithSelector, + MongoEnumerableMethod.MedianInt32WithSelector, + MongoEnumerableMethod.MedianInt64WithSelector, + MongoEnumerableMethod.MedianNullableDecimalWithSelector, + MongoEnumerableMethod.MedianNullableDoubleWithSelector, + MongoEnumerableMethod.MedianNullableInt32WithSelector, + MongoEnumerableMethod.MedianNullableInt64WithSelector, + MongoEnumerableMethod.MedianNullableSingleWithSelector, + MongoEnumerableMethod.MedianSingleWithSelector, + MongoEnumerableMethod.PercentileDecimalWithSelector, + MongoEnumerableMethod.PercentileDoubleWithSelector, + MongoEnumerableMethod.PercentileInt32WithSelector, + MongoEnumerableMethod.PercentileInt64WithSelector, + MongoEnumerableMethod.PercentileNullableDecimalWithSelector, + MongoEnumerableMethod.PercentileNullableDoubleWithSelector, + MongoEnumerableMethod.PercentileNullableInt32WithSelector, + MongoEnumerableMethod.PercentileNullableInt64WithSelector, + MongoEnumerableMethod.PercentileNullableSingleWithSelector, + MongoEnumerableMethod.PercentileSingleWithSelector, + WindowMethod.PercentileWithDecimal, + WindowMethod.PercentileWithDouble, + WindowMethod.PercentileWithInt32, + WindowMethod.PercentileWithInt64, + WindowMethod.PercentileWithNullableDecimal, + WindowMethod.PercentileWithNullableDouble, + WindowMethod.PercentileWithNullableInt32, + WindowMethod.PercentileWithNullableInt64, + WindowMethod.PercentileWithNullableSingle, + WindowMethod.PercentileWithSingle + ]; + + private static readonly HashSet __countMethods = + [ + EnumerableMethod.Count, + EnumerableMethod.CountWithPredicate, + EnumerableMethod.LongCount, + EnumerableMethod.LongCountWithPredicate, + QueryableMethod.Count, + QueryableMethod.CountWithPredicate, + QueryableMethod.LongCount, + QueryableMethod.LongCountWithPredicate + ]; + + private static readonly HashSet __countWithPredicateMethods = + [ + EnumerableMethod.CountWithPredicate, + EnumerableMethod.LongCountWithPredicate, + QueryableMethod.CountWithPredicate, + QueryableMethod.LongCountWithPredicate + ]; + + private static readonly HashSet __groupByMethods = + [ + EnumerableMethod.GroupByWithKeySelector, + EnumerableMethod.GroupByWithKeySelectorAndElementSelector, + EnumerableMethod.GroupByWithKeySelectorAndResultSelector, + EnumerableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector, + QueryableMethod.GroupByWithKeySelector, + QueryableMethod.GroupByWithKeySelectorAndElementSelector, + QueryableMethod.GroupByWithKeySelectorAndResultSelector, + QueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector + ]; + + private static readonly HashSet __indexOfMethods = + [ + StringMethod.IndexOfAny, + StringMethod.IndexOfAnyWithStartIndex, + StringMethod.IndexOfAnyWithStartIndexAndCount, + StringMethod.IndexOfBytesWithValue, + StringMethod.IndexOfBytesWithValueAndStartIndex, + StringMethod.IndexOfBytesWithValueAndStartIndexAndCount, + StringMethod.IndexOfWithChar, + StringMethod.IndexOfWithCharAndStartIndex, + StringMethod.IndexOfWithCharAndStartIndexAndCount, + StringMethod.IndexOfWithString, + StringMethod.IndexOfWithStringAndComparisonType, + StringMethod.IndexOfWithStringAndStartIndex, + StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, + StringMethod.IndexOfWithStringAndStartIndexAndCount, + StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType, + ]; + + private static readonly HashSet __tupleOrValueTupleCreateMethods = + [ + TupleMethod.Create1, + TupleMethod.Create2, + TupleMethod.Create3, + TupleMethod.Create4, + TupleMethod.Create5, + TupleMethod.Create6, + TupleMethod.Create7, + TupleMethod.Create8, + ValueTupleMethod.Create1, + ValueTupleMethod.Create2, + ValueTupleMethod.Create3, + ValueTupleMethod.Create4, + ValueTupleMethod.Create5, + ValueTupleMethod.Create6, + ValueTupleMethod.Create7, + ValueTupleMethod.Create8 + ]; + + private static readonly HashSet __firstOrLastMethods = + [ + EnumerableMethod.First, + EnumerableMethod.FirstOrDefault, + EnumerableMethod.FirstOrDefaultWithPredicate, + EnumerableMethod.FirstWithPredicate, + EnumerableMethod.Last, + EnumerableMethod.LastOrDefault, + EnumerableMethod.LastOrDefaultWithPredicate, + EnumerableMethod.LastWithPredicate, + EnumerableMethod.Single, + EnumerableMethod.SingleOrDefault, + EnumerableMethod.SingleOrDefaultWithPredicate, + EnumerableMethod.SingleWithPredicate, + QueryableMethod.First, + QueryableMethod.FirstOrDefault, + QueryableMethod.FirstOrDefaultWithPredicate, + QueryableMethod.FirstWithPredicate, + QueryableMethod.Last, + QueryableMethod.LastOrDefault, + QueryableMethod.LastOrDefaultWithPredicate, + QueryableMethod.LastWithPredicate, + QueryableMethod.Single, + QueryableMethod.SingleOrDefault, + QueryableMethod.SingleOrDefaultWithPredicate, + QueryableMethod.SingleWithPredicate + ]; + + private static readonly HashSet __firstOrLastWithPredicateMethods = + [ + EnumerableMethod.FirstOrDefaultWithPredicate, + EnumerableMethod.FirstWithPredicate, + EnumerableMethod.LastOrDefaultWithPredicate, + EnumerableMethod.LastWithPredicate, + EnumerableMethod.SingleOrDefaultWithPredicate, + EnumerableMethod.SingleWithPredicate, + QueryableMethod.LastOrDefaultWithPredicate, + QueryableMethod.LastWithPredicate, + QueryableMethod.LastOrDefaultWithPredicate, + QueryableMethod.LastWithPredicate, + QueryableMethod.SingleOrDefaultWithPredicate, + QueryableMethod.SingleWithPredicate + ]; + + private static readonly HashSet __logMethods = + [ + MathMethod.Log, + MathMethod.Log10, + MathMethod.LogWithNewBase + ]; + + private static readonly HashSet __lookupMethods = + [ + MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignField, + MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, + MongoQueryableMethod.LookupWithDocumentsAndPipeline, + MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignField, + MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignFieldAndPipeline, + MongoQueryableMethod.LookupWithFromAndPipeline + ]; + + private static readonly HashSet __maxOrMinMethods = + [ + EnumerableMethod.Max, + EnumerableMethod.MaxDecimal, + EnumerableMethod.MaxDecimalWithSelector, + EnumerableMethod.MaxDouble, + EnumerableMethod.MaxDoubleWithSelector, + EnumerableMethod.MaxInt32, + EnumerableMethod.MaxInt32WithSelector, + EnumerableMethod.MaxInt64, + EnumerableMethod.MaxInt64WithSelector, + EnumerableMethod.MaxNullableDecimal, + EnumerableMethod.MaxNullableDecimalWithSelector, + EnumerableMethod.MaxNullableDouble, + EnumerableMethod.MaxNullableDoubleWithSelector, + EnumerableMethod.MaxNullableInt32, + EnumerableMethod.MaxNullableInt32WithSelector, + EnumerableMethod.MaxNullableInt64, + EnumerableMethod.MaxNullableInt64WithSelector, + EnumerableMethod.MaxNullableSingle, + EnumerableMethod.MaxNullableSingleWithSelector, + EnumerableMethod.MaxSingle, + EnumerableMethod.MaxSingleWithSelector, + EnumerableMethod.MaxWithSelector, + EnumerableMethod.Min, + EnumerableMethod.MinDecimal, + EnumerableMethod.MinDecimalWithSelector, + EnumerableMethod.MinDouble, + EnumerableMethod.MinDoubleWithSelector, + EnumerableMethod.MinInt32, + EnumerableMethod.MinInt32WithSelector, + EnumerableMethod.MinInt64, + EnumerableMethod.MinInt64WithSelector, + EnumerableMethod.MinNullableDecimal, + EnumerableMethod.MinNullableDecimalWithSelector, + EnumerableMethod.MinNullableDouble, + EnumerableMethod.MinNullableDoubleWithSelector, + EnumerableMethod.MinNullableInt32, + EnumerableMethod.MinNullableInt32WithSelector, + EnumerableMethod.MinNullableInt64, + EnumerableMethod.MinNullableInt64WithSelector, + EnumerableMethod.MinNullableSingle, + EnumerableMethod.MinNullableSingleWithSelector, + EnumerableMethod.MinSingle, + EnumerableMethod.MinSingleWithSelector, + EnumerableMethod.MinWithSelector, + QueryableMethod.Max, + QueryableMethod.MaxWithSelector, + QueryableMethod.Min, + QueryableMethod.MinWithSelector, + ]; + + private static readonly HashSet __maxOrMinWithSelectorMethods = + [ + EnumerableMethod.MaxDecimalWithSelector, + EnumerableMethod.MaxDoubleWithSelector, + EnumerableMethod.MaxInt32WithSelector, + EnumerableMethod.MaxInt64WithSelector, + EnumerableMethod.MaxNullableDecimalWithSelector, + EnumerableMethod.MaxNullableDoubleWithSelector, + EnumerableMethod.MaxNullableInt32WithSelector, + EnumerableMethod.MaxNullableInt64WithSelector, + EnumerableMethod.MaxNullableSingleWithSelector, + EnumerableMethod.MaxSingleWithSelector, + EnumerableMethod.MaxWithSelector, + EnumerableMethod.MinDecimalWithSelector, + EnumerableMethod.MinDoubleWithSelector, + EnumerableMethod.MinInt32WithSelector, + EnumerableMethod.MinInt64WithSelector, + EnumerableMethod.MinNullableDecimalWithSelector, + EnumerableMethod.MinNullableDoubleWithSelector, + EnumerableMethod.MinNullableInt32WithSelector, + EnumerableMethod.MinNullableInt64WithSelector, + EnumerableMethod.MinNullableSingleWithSelector, + EnumerableMethod.MinSingleWithSelector, + EnumerableMethod.MinWithSelector, + QueryableMethod.MaxWithSelector, + QueryableMethod.MinWithSelector, + ]; + + private static readonly MethodInfo[] __pickMethods = new[] + { + EnumerableMethod.Bottom, + EnumerableMethod.BottomN, + EnumerableMethod.BottomNWithComputedN, + EnumerableMethod.FirstN, + EnumerableMethod.FirstNWithComputedN, + EnumerableMethod.LastN, + EnumerableMethod.LastNWithComputedN, + EnumerableMethod.MaxN, + EnumerableMethod.MaxNWithComputedN, + EnumerableMethod.MinN, + EnumerableMethod.MinNWithComputedN, + EnumerableMethod.Top, + EnumerableMethod.TopN, + EnumerableMethod.TopNWithComputedN + }; + + private static readonly MethodInfo[] __pickWithComputedNMethods = new[] + { + EnumerableMethod.BottomNWithComputedN, + EnumerableMethod.FirstNWithComputedN, + EnumerableMethod.LastNWithComputedN, + EnumerableMethod.MaxNWithComputedN, + EnumerableMethod.MinNWithComputedN, + EnumerableMethod.TopNWithComputedN + }; + + private static readonly MethodInfo[] __pickWithSortDefinitionMethods = new[] + { + EnumerableMethod.Bottom, + EnumerableMethod.BottomN, + EnumerableMethod.BottomNWithComputedN, + EnumerableMethod.Top, + EnumerableMethod.TopN, + EnumerableMethod.TopNWithComputedN + }; + + private static readonly HashSet __selectManyMethods = + [ + EnumerableMethod.SelectManyWithSelector, + EnumerableMethod.SelectManyWithCollectionSelectorAndResultSelector, + QueryableMethod.SelectManyWithSelector, + QueryableMethod.SelectManyWithCollectionSelectorAndResultSelector + ]; + + private static readonly HashSet __selectManyWithCollectionSelectorAndResultSelectorMethods = + [ + EnumerableMethod.SelectManyWithCollectionSelectorAndResultSelector, + QueryableMethod.SelectManyWithCollectionSelectorAndResultSelector + ]; + + private static readonly HashSet __selectManyWithResultSelectorMethods = + [ + EnumerableMethod.SelectManyWithSelector, + QueryableMethod.SelectManyWithSelector + ]; + + private static readonly HashSet __skipOrTakeMethods = + [ + EnumerableMethod.Skip, + EnumerableMethod.SkipWhile, + EnumerableMethod.Take, + EnumerableMethod.TakeWhile, + QueryableMethod.Skip, + QueryableMethod.SkipWhile, + QueryableMethod.Take, + QueryableMethod.TakeWhile, + MongoQueryableMethod.SkipWithLong, + MongoQueryableMethod.TakeWithLong + ]; + + private static readonly HashSet __skipOrTakeWhileMethods = + [ + EnumerableMethod.SkipWhile, + EnumerableMethod.TakeWhile, + QueryableMethod.SkipWhile, + QueryableMethod.TakeWhile + ]; + + private static readonly HashSet __splitMethods = + [ + StringMethod.SplitWithChars, + StringMethod.SplitWithCharsAndCount, + StringMethod.SplitWithCharsAndCountAndOptions, + StringMethod.SplitWithCharsAndOptions, + StringMethod.SplitWithStringsAndCountAndOptions, + StringMethod.SplitWithStringsAndOptions + ]; + + private static readonly HashSet __standardDeviationMethods = + [ + MongoEnumerableMethod.StandardDeviationPopulationDecimal, + MongoEnumerableMethod.StandardDeviationPopulationDecimalWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationDouble, + MongoEnumerableMethod.StandardDeviationPopulationDoubleWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationInt32, + MongoEnumerableMethod.StandardDeviationPopulationInt32WithSelector, + MongoEnumerableMethod.StandardDeviationPopulationInt64, + MongoEnumerableMethod.StandardDeviationPopulationInt64WithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableDecimal, + MongoEnumerableMethod.StandardDeviationPopulationNullableDecimalWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableDouble, + MongoEnumerableMethod.StandardDeviationPopulationNullableDoubleWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableInt32, + MongoEnumerableMethod.StandardDeviationPopulationNullableInt32WithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableInt64, + MongoEnumerableMethod.StandardDeviationPopulationNullableInt64WithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableSingle, + MongoEnumerableMethod.StandardDeviationPopulationNullableSingleWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationSingle, + MongoEnumerableMethod.StandardDeviationPopulationSingleWithSelector, + MongoEnumerableMethod.StandardDeviationSampleDecimal, + MongoEnumerableMethod.StandardDeviationSampleDecimalWithSelector, + MongoEnumerableMethod.StandardDeviationSampleDouble, + MongoEnumerableMethod.StandardDeviationSampleDoubleWithSelector, + MongoEnumerableMethod.StandardDeviationSampleInt32, + MongoEnumerableMethod.StandardDeviationSampleInt32WithSelector, + MongoEnumerableMethod.StandardDeviationSampleInt64, + MongoEnumerableMethod.StandardDeviationSampleInt64WithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableDecimal, + MongoEnumerableMethod.StandardDeviationSampleNullableDecimalWithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableDouble, + MongoEnumerableMethod.StandardDeviationSampleNullableDoubleWithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableInt32, + MongoEnumerableMethod.StandardDeviationSampleNullableInt32WithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableInt64, + MongoEnumerableMethod.StandardDeviationSampleNullableInt64WithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableSingle, + MongoEnumerableMethod.StandardDeviationSampleNullableSingleWithSelector, + MongoEnumerableMethod.StandardDeviationSampleSingle, + MongoEnumerableMethod.StandardDeviationSampleSingleWithSelector, + ]; + + private static readonly HashSet __standardDeviationWithSelectorMethods = + [ + MongoEnumerableMethod.StandardDeviationPopulationDecimalWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationDoubleWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationInt32WithSelector, + MongoEnumerableMethod.StandardDeviationPopulationInt64WithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableDecimalWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableDoubleWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableInt32WithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableInt64WithSelector, + MongoEnumerableMethod.StandardDeviationPopulationNullableSingleWithSelector, + MongoEnumerableMethod.StandardDeviationPopulationSingleWithSelector, + MongoEnumerableMethod.StandardDeviationSampleDecimalWithSelector, + MongoEnumerableMethod.StandardDeviationSampleDoubleWithSelector, + MongoEnumerableMethod.StandardDeviationSampleInt32WithSelector, + MongoEnumerableMethod.StandardDeviationSampleInt64WithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableDecimalWithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableDoubleWithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableInt32WithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableInt64WithSelector, + MongoEnumerableMethod.StandardDeviationSampleNullableSingleWithSelector, + MongoEnumerableMethod.StandardDeviationSampleSingleWithSelector, + ]; + + private static readonly HashSet __stringConcatMethods = + [ + StringMethod.ConcatWith1Object, + StringMethod.ConcatWith2Objects, + StringMethod.ConcatWith2Strings, + StringMethod.ConcatWith3Objects, + StringMethod.ConcatWith3Strings, + StringMethod.ConcatWith4Strings, + StringMethod.ConcatWithObjectArray, + StringMethod.ConcatWithStringArray + ]; + + private static readonly HashSet __stringContainsMethods = + [ + StringMethod.ContainsWithChar, + StringMethod.ContainsWithCharAndComparisonType, + StringMethod.ContainsWithString, + StringMethod.ContainsWithStringAndComparisonType + ]; + + private static readonly HashSet __stringEndsWithOrStartsWithMethods = + [ + StringMethod.EndsWithWithChar, + StringMethod.EndsWithWithString, + StringMethod.EndsWithWithStringAndComparisonType, + StringMethod.EndsWithWithStringAndIgnoreCaseAndCulture, + StringMethod.StartsWithWithChar, + StringMethod.StartsWithWithString, + StringMethod.StartsWithWithStringAndComparisonType, + StringMethod.StartsWithWithStringAndIgnoreCaseAndCulture + ]; + + private static readonly HashSet __subtractReturningDateTimeMethods = + [ + DateTimeMethod.SubtractWithTimeSpan, + DateTimeMethod.SubtractWithTimeSpanAndTimezone, + DateTimeMethod.SubtractWithUnit, + DateTimeMethod.SubtractWithUnitAndTimezone + ]; + + private static readonly HashSet __subtractReturningInt64Methods = + [ + DateTimeMethod.SubtractWithDateTimeAndUnit, + DateTimeMethod.SubtractWithDateTimeAndUnitAndTimezone + ]; + + private static readonly HashSet __subtractReturningTimeSpanWithMillisecondsUnitsMethods = + [ + DateTimeMethod.SubtractWithDateTime, + DateTimeMethod.SubtractWithDateTimeAndTimezone + ]; + + private static readonly HashSet __sumMethods = + [ + EnumerableMethod.SumDecimal, + EnumerableMethod.SumDecimalWithSelector, + EnumerableMethod.SumDouble, + EnumerableMethod.SumDoubleWithSelector, + EnumerableMethod.SumInt32, + EnumerableMethod.SumInt32WithSelector, + EnumerableMethod.SumInt64, + EnumerableMethod.SumInt64WithSelector, + EnumerableMethod.SumNullableDecimal, + EnumerableMethod.SumNullableDecimalWithSelector, + EnumerableMethod.SumNullableDouble, + EnumerableMethod.SumNullableDoubleWithSelector, + EnumerableMethod.SumNullableInt32, + EnumerableMethod.SumNullableInt32WithSelector, + EnumerableMethod.SumNullableInt64, + EnumerableMethod.SumNullableInt64WithSelector, + EnumerableMethod.SumNullableSingle, + EnumerableMethod.SumNullableSingleWithSelector, + EnumerableMethod.SumSingle, + EnumerableMethod.SumSingleWithSelector, + QueryableMethod.SumDecimal, + QueryableMethod.SumDecimalWithSelector, + QueryableMethod.SumDouble, + QueryableMethod.SumDoubleWithSelector, + QueryableMethod.SumInt32, + QueryableMethod.SumInt32WithSelector, + QueryableMethod.SumInt64, + QueryableMethod.SumInt64WithSelector, + QueryableMethod.SumNullableDecimal, + QueryableMethod.SumNullableDecimalWithSelector, + QueryableMethod.SumNullableDouble, + QueryableMethod.SumNullableDoubleWithSelector, + QueryableMethod.SumNullableInt32, + QueryableMethod.SumNullableInt32WithSelector, + QueryableMethod.SumNullableInt64, + QueryableMethod.SumNullableInt64WithSelector, + QueryableMethod.SumNullableSingle, + QueryableMethod.SumNullableSingleWithSelector, + QueryableMethod.SumSingle, + QueryableMethod.SumSingleWithSelector + ]; + + private static readonly HashSet __sumWithSelectorMethods = + [ + EnumerableMethod.SumDecimalWithSelector, + EnumerableMethod.SumDoubleWithSelector, + EnumerableMethod.SumInt32WithSelector, + EnumerableMethod.SumInt64WithSelector, + EnumerableMethod.SumNullableDecimalWithSelector, + EnumerableMethod.SumNullableDoubleWithSelector, + EnumerableMethod.SumNullableInt32WithSelector, + EnumerableMethod.SumNullableInt64WithSelector, + EnumerableMethod.SumNullableSingleWithSelector, + EnumerableMethod.SumSingleWithSelector, + QueryableMethod.SumDecimalWithSelector, + QueryableMethod.SumDoubleWithSelector, + QueryableMethod.SumInt32WithSelector, + QueryableMethod.SumInt64WithSelector, + QueryableMethod.SumNullableDecimalWithSelector, + QueryableMethod.SumNullableDoubleWithSelector, + QueryableMethod.SumNullableInt32WithSelector, + QueryableMethod.SumNullableInt64WithSelector, + QueryableMethod.SumNullableSingleWithSelector, + QueryableMethod.SumSingleWithSelector, + ]; + + private static readonly HashSet __toLowerOrToUpperMethods = + [ + StringMethod.ToLower, + StringMethod.ToLowerInvariant, + StringMethod.ToLowerWithCulture, + StringMethod.ToUpper, + StringMethod.ToUpperInvariant, + StringMethod.ToUpperWithCulture, + ]; + + private static readonly HashSet __whereMethods = + [ + EnumerableMethod.Where, + MongoEnumerableMethod.WhereWithLimit, + QueryableMethod.Where, + ]; + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + var method = node.Method; + var arguments = node.Arguments; + + DeduceMethodCallSerializers(); + // if (IsKnown(node, out var knownSerializer) && knownSerializer is IUnknowableSerializer) + // { + // return node; // don't visit node any further + // } + base.VisitMethodCall(node); + DeduceMethodCallSerializers(); + + return node; + + void DeduceMethodCallSerializers() + { + switch (node.Method.Name) + { + case "Abs": DeduceAbsMethodSerializers(); break; + case "Acos": DeduceAcosMethodSerializers(); break; + case "Acosh": DeduceAcoshMethodSerializers(); break; + case "Add": DeduceAddMethodSerializers(); break; + case "AddDays": DeduceAddDaysMethodSerializers(); break; + case "AddHours": DeduceAddHoursMethodSerializers(); break; + case "AddMilliseconds": DeduceAddMillisecondsMethodSerializers(); break; + case "AddMinutes": DeduceAddMinutesMethodSerializers(); break; + case "AddMonths": DeduceAddMonthsMethodSerializers(); break; + case "AddQuarters": DeduceAddQuartersMethodSerializers(); break; + case "AddSeconds": DeduceAddSecondsMethodSerializers(); break; + case "AddTicks": DeduceAddTicksMethodSerializers(); break; + case "AddWeeks": DeduceAddWeeksMethodSerializers(); break; + case "AddYears": DeduceAddYearsMethodSerializers(); break; + case "Aggregate": DeduceAggregateMethodSerializers(); break; + case "All": DeduceAllMethodSerializers(); break; + case "Any": DeduceAnyMethodSerializers(); break; + case "AppendStage": DeduceAppendStageMethodSerializers(); break; + case "As": DeduceAsMethodSerializers(); break; + case "Asin": DeduceAsinMethodSerializers(); break; + case "Asinh": DeduceAsinhMethodSerializers(); break; + case "AsQueryable": DeduceAsQueryableMethodSerializers(); break; + case "Atan": DeduceAtanMethodSerializers(); break; + case "Atanh": DeduceAtanhMethodSerializers(); break; + case "Atan2": DeduceAtan2MethodSerializers(); break; + case "CompareTo": DeduceCompareToMethodSerializers(); break; + case "Concat": DeduceConcatMethodSerializers(); break; + case "Constant": DeduceConstantMethodSerializers(); break; + case "Contains": DeduceContainsMethodSerializers(); break; + case "ContainsKey": DeduceContainsKeyMethodSerializers(); break; + case "ContainsValue": DeduceContainsValueMethodSerializers(); break; + case "Convert": DeduceConvertMethodSerializers(); break; + case "Cos": DeduceCosMethodSerializers(); break; + case "Cosh": DeduceCoshMethodSerializers(); break; + case "Create": DeduceCreateMethodSerializers(); break; + case "DefaultIfEmpty": DeduceDefaultIfEmptyMethodSerializers(); break; + case "DegreesToRadians": DeduceDegreesToRadiansMethodSerializers(); break; + case "Distinct": DeduceDistinctMethodSerializers(); break; + case "Documents": DeduceDocumentsMethodSerializers(); break; + case "Equals": DeduceEqualsMethodSerializers(); break; + case "Except": DeduceExceptMethodSerializers(); break; + case "Exists": DeduceExistsMethodSerializers(); break; + case "Exp": DeduceExpMethodSerializers(); break; + case "Field": DeduceFieldMethodSerializers(); break; + case "get_Item": DeduceGetItemMethodSerializers(); break; + case "GroupBy": DeduceGroupByMethodSerializers(); break; + case "GroupJoin": DeduceGroupJoinMethodSerializers(); break; + case "Inject": DeduceInjectMethodSerializers(); break; + case "Intersect": DeduceIntersectMethodSerializers(); break; + case "IsMatch": DeduceIsMatchMethodSerializers(); break; + case "IsSubsetOf": DeduceIsSubsetOfMethodSerializers(); break; + case "Join": DeduceJoinMethodSerializers(); break; + case "Lookup": DeduceLookupMethodSerializers(); break; + case "OfType": DeduceOfTypeMethodSerializers(); break; + case "Parse": DeduceParseMethodSerializers(); break; + case "Pow": DeducePowMethodSerializers(); break; + case "RadiansToDegrees": DeduceRadiansToDegreesMethodSerializers(); break; + case "Range": DeduceRangeMethodSerializers(); break; + case "Repeat": DeduceRepeatMethodSerializers(); break; + case "Reverse": DeduceReverseMethodSerializers(); break; + case "Round": DeduceRoundMethodSerializers(); break; + case "Select": DeduceSelectMethodSerializers(); break; + case "SelectMany": DeduceSelectManySerializers(); break; + case "SequenceEqual": DeduceSequenceEqualMethodSerializers(); break; + case "SetEquals": DeduceSetEqualsMethodSerializers(); break; + case "SetWindowFields": DeduceSetWindowFieldsMethodSerializers(); break; + case "Shift": DeduceShiftMethodSerializers(); break; + case "Sin": DeduceSinMethodSerializers(); break; + case "Sinh": DeduceSinhMethodSerializers(); break; + case "Split": DeduceSplitMethodSerializers(); break; + case "Sqrt": DeduceSqrtMethodSerializers(); break; + case "StringIn": DeduceStringInMethodSerializers(); break; + case "StrLenBytes": DeduceStrLenBytesMethodSerializers(); break; + case "Subtract": DeduceSubtractMethodSerializers(); break; + case "Sum": DeduceSumMethodSerializers(); break; + case "Tan": DeduceTanMethodSerializers(); break; + case "Tanh": DeduceTanhMethodSerializers(); break; + case "ToArray": DeduceToArrayMethodSerializers(); break; + case "ToList": DeduceToListSerializers(); break; + case "ToString": DeduceToStringSerializers(); break; + case "Truncate": DeduceTruncateSerializers(); break; + case "Union": DeduceUnionSerializers(); break; + case "Week": DeduceWeekSerializers(); break; + case "Where": DeduceWhereSerializers(); break; + case "Zip": DeduceZipSerializers(); break; + + case "AllElements": + case "AllMatchingElements": + case "FirstMatchingElement": + DeduceMatchingElementsMethodSerializers(); + break; + + case "Append": + case "Prepend": + DeduceAppendOrPrependMethodSerializers(); + break; + + case "Average": + case "Median": + case "Percentile": + DeduceAverageOrMedianOrPercentileMethodSerializers(); + break; + + case "Bottom": + case "BottomN": + case "FirstN": + case "LastN": + case "MaxN": + case "MinN": + case "Top": + case "TopN": + DeducePickMethodSerializers(); + break; + + case "Ceiling": + case "Floor": + DeduceCeilingOrFloorMethodSerializers(); + break; + + case "Count": + case "LongCount": + DeduceCountMethodSerializers(); + break; + + case "ElementAt": + case "ElementAtOrDefault": + DeduceElementAtMethodSerializers(); + break; + + case "EndsWith": + case "StartsWith": + DeduceEndsWithOrStartsWithMethodSerializers(); + break; + + case "First": + case "FirstOrDefault": + case "Last": + case "LastOrDefault": + case "Single": + case "SingleOrDefault": + DeduceFirstOrLastMethodsSerializers(); + break; + + case "IndexOf": + case "IndexOfBytes": + DeduceIndexOfMethodSerializers(); + break; + + case "IsMissing": + case "IsNullOrMissing": + DeduceIsMissingOrIsNullOrMissingMethodSerializers(); + break; + + case "IsNullOrEmpty": + case "IsNullOrWhiteSpace": + DeduceIsNullOrEmptyOrIsNullOrWhiteSpaceMethodSerializers(); + break; + + case "Ln": + case "Log": + case "Log10": + DeduceLogMethodSerializers(); + break; + + case "Max": + case "Min": + DeduceMaxOrMinMethodSerializers(); + break; + + case "OrderBy": + case "OrderByDescending": + case "ThenBy": + case "ThenByDescending": + DeduceOrderByMethodSerializers(); + break; + + case "Skip": + case "SkipWhile": + case "Take": + case "TakeWhile": + DeduceSkipOrTakeMethodSerializers(); + break; + + case "StandardDeviationPopulation": + case "StandardDeviationSample": + DeduceStandardDeviationMethodSerializers(); + break; + + case "Substring": + case "SubstrBytes": + DeduceSubstringMethodSerializers(); + break; + + case "ToLower": + case "ToLowerInvariant": + case "ToUpper": + case "ToUpperInvariant": + DeduceToLowerOrToUpperSerializers(); + break; + + default: + DeduceUnknownMethodSerializer(); + break; + } + } + + void DeduceAbsMethodSerializers() + { + if (method.IsOneOf(__absMethods)) + { + var valueExpression = arguments[0]; + DeduceSerializers(node, valueExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAcosMethodSerializers() + { + if (method.Is(MathMethod.Acos)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAcoshMethodSerializers() + { + if (method.Is(MathMethod.Acosh)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.Add, DateTimeMethod.AddWithTimezone, DateTimeMethod.AddWithUnit, DateTimeMethod.AddWithUnitAndTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddDaysMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddDays, DateTimeMethod.AddDaysWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddHoursMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddHours, DateTimeMethod.AddHoursWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMillisecondsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMilliseconds, DateTimeMethod.AddMillisecondsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMinutesMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMinutes, DateTimeMethod.AddMinutesWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMonthsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMonths, DateTimeMethod.AddMonthsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddQuartersMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddQuarters, DateTimeMethod.AddQuartersWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddSecondsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddSeconds, DateTimeMethod.AddSecondsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddTicksMethodSerializers() + { + if (method.Is(DateTimeMethod.AddTicks)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddWeeksMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddWeeks, DateTimeMethod.AddWeeksWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddYearsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddYears, DateTimeMethod.AddYearsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAggregateMethodSerializers() + { + if (method.IsOneOf(__aggregateMethods)) + { + var sourceExpression = arguments[0]; + _ = IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer); + + if (method.IsOneOf(__aggregateWithFuncMethods)) + { + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(funcAccumulatorParameter, sourceExpression); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(funcLambda.Body, sourceExpression); + DeduceSerializers(node, funcLambda.Body); + } + + if (method.IsOneOf(__aggregateWithSeedAndFuncMethods)) + { + var seedExpression = arguments[1]; + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + + DeduceSerializers(seedExpression, funcLambda.Body); + DeduceSerializers(funcAccumulatorParameter, funcLambda.Body); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceSerializers(node, funcLambda.Body); + } + + if (method.IsOneOf(__aggregateWithSeedFuncAdResultSelectorMethods)) + { + var seedExpression = arguments[1]; + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var resultSelectorAccumulatorParameter = resultSelectorLambda.Parameters[0]; + + DeduceSerializers(seedExpression, funcLambda.Body); + DeduceSerializers(funcAccumulatorParameter, funcLambda.Body); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceSerializers(resultSelectorAccumulatorParameter, funcLambda.Body); + DeduceSerializers(node, resultSelectorLambda.Body); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAllMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.AllWithPredicate, QueryableMethod.AllWithPredicate)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAnyMethodSerializers() + { + if (method.IsOneOf(__anyMethods)) + { + if (method.IsOneOf(__anyWithPredicateMethods)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAppendOrPrependMethodSerializers() + { + if (method.IsOneOf(__appendOrPrependMethods)) + { + var sourceExpression = arguments[0]; + var elementExpression = arguments[1]; + + DeduceItemAndCollectionSerializers(elementExpression, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAsMethodSerializers() + { + if (method.Is(MongoQueryableMethod.As)) + { + if (IsNotKnown(node)) + { + var resultSerializerExpression = arguments[1]; + if (resultSerializerExpression is not ConstantExpression resultSerializerConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "resultSerializer argument must be a constant"); + } + + var resultItemSerializer = (IBsonSerializer)resultSerializerConstantExpression.Value; + if (resultItemSerializer == null) + { + var resultItemType = method.GetGenericArguments()[1]; + resultItemSerializer = BsonSerializer.LookupSerializer(resultItemType); + } + + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddKnownSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAppendStageMethodSerializers() + { + if (method.Is(MongoQueryableMethod.AppendStage)) + { + if (IsNotKnown(node)) + { + var sourceExpression = arguments[0]; + var stageExpression = arguments[1]; + var resultSerializerExpression = arguments[2]; + + if (stageExpression is not ConstantExpression stageConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "stage argument must be a constant"); + } + var stageDefinition = (IPipelineStageDefinition)stageConstantExpression.Value; + + if (resultSerializerExpression is not ConstantExpression resultSerializerConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "resultSerializer argument must be a constant"); + } + var resultItemSerializer = (IBsonSerializer)resultSerializerConstantExpression.Value; + + if (resultItemSerializer == null && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var serializerRegistry = BsonSerializer.SerializerRegistry; // TODO: get correct registry + var translationOptions = new ExpressionTranslationOptions(); // TODO: get correct translation options + var renderedStage = stageDefinition.Render(sourceItemSerializer, serializerRegistry, translationOptions); + resultItemSerializer = renderedStage.OutputSerializer; + } + + if (resultItemSerializer != null) + { + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddKnownSerializer(node, resultSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAsinMethodSerializers() + { + if (method.Is(MathMethod.Asin)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAsQueryableMethodSerializers() + { + if (method.Is(QueryableMethod.AsQueryable)) + { + var sourceExpression = arguments[0]; + + if (IsNotKnown(node) && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var resultSerializer = NestedAsQueryableSerializer.Create(sourceItemSerializer); + AddKnownSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAsinhMethodSerializers() + { + if (method.Is(MathMethod.Asinh)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAtanMethodSerializers() + { + if (method.Is(MathMethod.Atan)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAtanhMethodSerializers() + { + if (method.Is(MathMethod.Atanh)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAtan2MethodSerializers() + { + if (method.Is(MathMethod.Atan2)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAverageOrMedianOrPercentileMethodSerializers() + { + if (method.IsOneOf(__averageOrMedianOrPercentileMethods)) + { + if (method.IsOneOf(__averageOrMedianOrPercentileWithSelectorMethods)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceItemParameter = selectorLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(selectorSourceItemParameter, sourceExpression); + } + + if (IsNotKnown(node)) + { + var nodeSerializer = StandardSerializers.GetSerializer(node.Type); + AddKnownSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCeilingOrFloorMethodSerializers() + { + if (method.IsOneOf(MathMethod.CeilingWithDecimal, MathMethod.CeilingWithDouble, MathMethod.FloorWithDecimal, MathMethod.FloorWithDouble)) + { + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCompareToMethodSerializers() + { + if (IsCompareToMethod()) + { + var objectExpression = node.Object; + var valueExpression = arguments[0]; + + DeduceSerializers(objectExpression, valueExpression); + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsCompareToMethod() + { + return + method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(int) && + method.Name == "CompareTo" && + arguments.Count == 1 && + arguments[0].Type == node.Object.Type; + } + } + + void DeduceConcatMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Concat, QueryableMethod.Concat)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + + DeduceCollectionAndCollectionSerializers(firstExpression, secondExpression); + DeduceCollectionAndCollectionSerializers(node, firstExpression); + } + else if (method.IsOneOf(__stringConcatMethods)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceConstantMethodSerializers() + { + if (method.IsOneOf(MqlMethod.ConstantWithRepresentation, MqlMethod.ConstantWithSerializer)) + { + var valueExpression = arguments[0]; + IBsonSerializer serializer = null; + + if (IsNotKnown(node) || IsNotKnown(valueExpression)) + { + if (method.Is(MqlMethod.ConstantWithRepresentation)) + { + var representationExpression = arguments[1]; + + var representation = representationExpression.GetConstantValue(node); + var defaultSerializer = BsonSerializer.LookupSerializer(valueExpression.Type); // TODO: don't use BsonSerializer + if (defaultSerializer is IRepresentationConfigurable representationConfigurableSerializer) + { + serializer = representationConfigurableSerializer.WithRepresentation(representation); + } + } + else if (method.Is(MqlMethod.ConstantWithSerializer)) + { + var serializerExpression = arguments[1]; + serializer = serializerExpression.GetConstantValue(node); + } + } + + DeduceSerializer(valueExpression, serializer); + DeduceSerializer(node, serializer); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceContainsKeyMethodSerializers() + { + if (IsDictionaryContainsKeyMethod(out var keyExpression)) + { + var dictionaryExpression = node.Object; + if (IsNotKnown(keyExpression) && IsKnown(dictionaryExpression, out var dictionarySerializer)) + { + var keySerializer = (dictionarySerializer as IBsonDictionarySerializer)?.KeySerializer; + AddKnownSerializer(keyExpression, keySerializer); + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceContainsMethodSerializers() + { + if (method.IsOneOf(__stringContainsMethods)) + { + DeduceReturnsBooleanSerializer(); + } + else if (IsCollectionContainsMethod(out var collectionExpression, out var itemExpression)) + { + DeduceCollectionAndItemSerializers(collectionExpression, itemExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsCollectionContainsMethod(out Expression collectionExpression, out Expression itemExpression) + { + if (method.IsPublic && + method.ReturnType == typeof(bool) && + method.Name == "Contains" && + method.GetParameters().Length == (method.IsStatic ? 2 : 1)) + { + collectionExpression = method.IsStatic ? arguments[0] : node.Object; + itemExpression = method.IsStatic ? arguments[1] : arguments[0]; + return true; + } + + collectionExpression = null; + itemExpression = null; + return false; + } + } + + void DeduceContainsValueMethodSerializers() + { + if (IsContainsValueInstanceMethod(out var collectionExpression, out var valueExpression)) + { + if (IsNotKnown(valueExpression) && + IsKnown(collectionExpression, out var collectionSerializer)) + { + if (collectionSerializer is IBsonDictionarySerializer dictionarySerializer) + { + var valueSerializer = dictionarySerializer.ValueSerializer; + AddKnownSerializer(valueExpression, valueSerializer); + } + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsContainsValueInstanceMethod(out Expression collectionExpression, out Expression valueExpression) + { + if (method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(bool) && + method.Name == "ContainsValue" && + method.GetParameters() is var parameters && + parameters.Length == 1) + { + collectionExpression = node.Object; + valueExpression = arguments[0]; + return true; + } + + collectionExpression = null; + valueExpression = null; + return false; + } + } + + void DeduceConvertMethodSerializers() + { + if (method.Is(MqlMethod.Convert)) + { + if (IsNotKnown(node)) + { + var toType = method.GetGenericArguments()[1]; + var resultSerializer = GetResultSerializer(node, toType); + AddKnownSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + + static IBsonSerializer GetResultSerializer(Expression expression, Type toType) + { + var isNullable = toType.IsNullable(); + var valueType = isNullable ? Nullable.GetUnderlyingType(toType) : toType; + + var valueSerializer = (IBsonSerializer)(Type.GetTypeCode(valueType) switch + { + TypeCode.Boolean => BooleanSerializer.Instance, + TypeCode.Byte => ByteSerializer.Instance, + TypeCode.Char => StringSerializer.Instance, + TypeCode.DateTime => DateTimeSerializer.Instance, + TypeCode.Decimal => DecimalSerializer.Instance, + TypeCode.Double => DoubleSerializer.Instance, + TypeCode.Int16 => Int16Serializer.Instance, + TypeCode.Int32 => Int32Serializer.Instance, + TypeCode.Int64 => Int64Serializer.Instance, + TypeCode.SByte => SByteSerializer.Instance, + TypeCode.Single => SingleSerializer.Instance, + TypeCode.String => StringSerializer.Instance, + TypeCode.UInt16 => UInt16Serializer.Instance, + TypeCode.UInt32 => Int32Serializer.Instance, + TypeCode.UInt64 => UInt64Serializer.Instance, + + _ when valueType == typeof(byte[]) => ByteArraySerializer.Instance, + _ when valueType == typeof(BsonBinaryData) => BsonBinaryDataSerializer.Instance, + _ when valueType == typeof(Decimal128) => Decimal128Serializer.Instance, + _ when valueType == typeof(Guid) => GuidSerializer.StandardInstance, + _ when valueType == typeof(ObjectId) => ObjectIdSerializer.Instance, + + _ => throw new ExpressionNotSupportedException(expression, because: $"{toType} is not a valid TTo for Convert") + }); + + return isNullable ? NullableSerializer.Create(valueSerializer) : valueSerializer; + } + } + + void DeduceCosMethodSerializers() + { + if (method.Is(MathMethod.Cos)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCoshMethodSerializers() + { + if (method.Is(MathMethod.Cosh)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCreateMethodSerializers() + { +#if NET6_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER + if (method.Is(KeyValuePairMethod.Create)) + { + if (AnyIsNotKnown(arguments) && IsKnown(node, out var nodeSerializer)) + { + var keyExpression = arguments[0]; + var valueExpression = arguments[1]; + + if (nodeSerializer.IsKeyValuePairSerializer(out _, out _, out var keySerializer, out var valueSerializer)) + { + DeduceSerializer(keyExpression, keySerializer); + DeduceSerializer(valueExpression, valueSerializer); + } + } + + if (IsNotKnown(node) && AllAreKnown(arguments, out var argumentSerializers)) + { + var keySerializer = argumentSerializers[0]; + var valueSerializer = argumentSerializers[1]; + var keyValuePairSerializer = KeyValuePairSerializer.Create(BsonType.Document, keySerializer, valueSerializer); + AddKnownSerializer(node, keyValuePairSerializer); + } + } + else + #endif + if (method.IsOneOf(__tupleOrValueTupleCreateMethods)) + { + if (AnyIsNotKnown(arguments) && IsKnown(node, out var nodeSerializer)) + { + if (nodeSerializer is IBsonTupleSerializer tupleSerializer) + { + for (var i = 1; i <= arguments.Count; i++) + { + var argumentExpression = arguments[i]; + if (IsNotKnown(argumentExpression)) + { + var itemSerializer = tupleSerializer.GetItemSerializer(i); + if (i == 8) + { + itemSerializer = (itemSerializer as IBsonTupleSerializer)?.GetItemSerializer(1); + } + AddKnownSerializer(argumentExpression, itemSerializer); + } + } + } + } + + if (IsNotKnown(node) && AllAreKnown(arguments, out var argumentSerializers)) + { + if (arguments.Count == 8) + { + var tempList = new List(argumentSerializers); + tempList[7] = method.ReturnType.Name.StartsWith("ValueTuple") ? + ValueTupleSerializer.Create([argumentSerializers[7]]) : + TupleSerializer.Create([argumentSerializers[7]]); + argumentSerializers = tempList; + } + + var resultSerializer = method.ReturnType.Name.StartsWith("ValueTuple") ? + ValueTupleSerializer.Create(argumentSerializers) : + TupleSerializer.Create(argumentSerializers); + AddKnownSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCountMethodSerializers() + { + if (method.IsOneOf(__countMethods)) + { + if (method.IsOneOf(__countWithPredicateMethods)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDefaultIfEmptyMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.DefaultIfEmpty, EnumerableMethod.DefaultIfEmptyWithDefaultValue, QueryableMethod.DefaultIfEmpty, QueryableMethod.DefaultIfEmptyWithDefaultValue)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(EnumerableMethod.DefaultIfEmptyWithDefaultValue, QueryableMethod.DefaultIfEmptyWithDefaultValue)) + { + var defaultValueExpression = arguments[1]; + DeduceItemAndCollectionSerializers(defaultValueExpression, sourceExpression); + } + + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDegreesToRadiansMethodSerializers() + { + if (method.Is(MongoDBMathMethod.DegreesToRadians)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDistinctMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Distinct, QueryableMethod.Distinct)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDocumentsMethodSerializers() + { + if (method.IsOneOf(MongoQueryableMethod.Documents, MongoQueryableMethod.DocumentsWithSerializer)) + { + if (IsNotKnown(node)) + { + IBsonSerializer documentSerializer; + if (method.Is(MongoQueryableMethod.DocumentsWithSerializer)) + { + var documentSerializerExpression = arguments[2]; + documentSerializer = documentSerializerExpression.GetConstantValue(node); + } + else + { + var documentsParameter = method.GetParameters()[1]; + var documentType = documentsParameter.ParameterType.GetElementType(); + documentSerializer = BsonSerializer.LookupSerializer(documentType); // TODO: don't use static registry + } + + var nodeSerializer = IQueryableSerializer.Create(documentSerializer); + AddKnownSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceElementAtMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.ElementAt, EnumerableMethod.ElementAtOrDefault, QueryableMethod.ElementAt, QueryableMethod.ElementAtOrDefault, QueryableMethod.ElementAtOrDefault)) + { + var sourceExpression = arguments[0]; + DeduceItemAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceEqualsMethodSerializers() + { + if (IsEqualsReturningBooleanMethod(out var expression1, out var expression2)) + { + DeduceSerializers(expression1, expression2); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsEqualsReturningBooleanMethod(out Expression expression1, out Expression expression2) + { + if (method.Name == "Equals" && + method.ReturnType == typeof(bool) && + method.IsPublic) + { + if (method.IsStatic && + arguments.Count == 2) + { + expression1 = arguments[0]; + expression2 = arguments[1]; + return true; + } + + if (!method.IsStatic && + arguments.Count == 1) + { + expression1 = node.Object; + expression2 = arguments[0]; + return true; + } + + if (method.Is(StringMethod.EqualsWithComparisonType)) + { + expression1 = node.Object; + expression2 = arguments[0]; + return true; + } + + if (method.Is(StringMethod.StaticEqualsWithComparisonType)) + { + expression1 = arguments[0]; + expression2 = arguments[1]; + return true; + } + } + + expression1 = null; + expression2 = null; + return false; + } + } + + void DeduceExceptMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Except, QueryableMethod.Except)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + DeduceCollectionAndCollectionSerializers(secondExpression, firstExpression); + DeduceCollectionAndCollectionSerializers(node, firstExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceExistsMethodSerializers() + { + if (method.Is(ArrayMethod.Exists) || ListMethod.IsExistsMethod(method)) + { + var collectionExpression = method.IsStatic ? arguments[0] : node.Object; + var predicateExpression = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, method.IsStatic ? arguments[1] : arguments[0]); + var predicateParameter = predicateExpression.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, collectionExpression); + DeduceReturnsBooleanSerializer(); + } + else if (method.Is(MqlMethod.Exists)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceExpMethodSerializers() + { + if (method.IsOneOf(MathMethod.Exp)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceFieldMethodSerializers() + { + if (method.Is(MqlMethod.Field)) + { + if (IsNotKnown(node)) + { + var fieldSerializerExpression = arguments[2]; + var fieldSerializer = fieldSerializerExpression.GetConstantValue(node); + if (fieldSerializer == null) + { + throw new ExpressionNotSupportedException(node, because: "fieldSerializer is null"); + } + + AddKnownSerializer(node, fieldSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceFirstOrLastMethodsSerializers() + { + if (method.IsOneOf(__firstOrLastMethods)) + { + if (method.IsOneOf(__firstOrLastWithPredicateMethods)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsOneSourceItemSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceGetItemMethodSerializers() + { + if (IsNotKnown(node)) + { + if (BsonValueMethod.IsGetItemWithIntMethod(method) || BsonValueMethod.IsGetItemWithStringMethod(method)) + { + AddKnownSerializer(node, BsonValueSerializer.Instance); + } + else if (IsInstanceGetItemMethod(out var collectionExpression, out var indexExpression)) + { + if (IsKnown(collectionExpression, out var collectionSerializer)) + { + if (collectionSerializer is IBsonArraySerializer arraySerializer && + indexExpression.Type == typeof(int) && + arraySerializer.GetItemSerializer() is var itemSerializer && + itemSerializer.ValueType == method.ReturnType) + { + AddKnownSerializer(node, itemSerializer); + } + else if ( + collectionSerializer is IBsonDictionarySerializer dictionarySerializer && + dictionarySerializer.KeySerializer is var keySerializer && + dictionarySerializer.ValueSerializer is var valueSerializer && + keySerializer.ValueType == indexExpression.Type && + valueSerializer.ValueType == method.ReturnType) + { + AddKnownSerializer(node, valueSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + bool IsInstanceGetItemMethod(out Expression collectionExpression, out Expression indexExpression) + { + if (method.IsStatic == false && + method.Name == "get_Item") + { + collectionExpression = node.Object; + indexExpression = arguments[0]; + return true; + } + + collectionExpression = null; + indexExpression = null; + return false; + } + } + + void DeduceGroupByMethodSerializers() + { + if (method.IsOneOf(__groupByMethods)) + { + var sourceExpression = arguments[0]; + var keySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var keySelectorParameter = keySelectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + + if (method.IsOneOf(EnumerableMethod.GroupByWithKeySelector, QueryableMethod.GroupByWithKeySelector)) + { + if (IsNotKnown(node) && IsKnown(keySelectorLambda.Body, out var keySerializer) && IsItemSerializerKnown(sourceExpression, out var elementSerializer)) + { + var groupingSerializer = IGroupingSerializer.Create(keySerializer, elementSerializer); + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, groupingSerializer); + AddKnownSerializer(node, nodeSerializer); + } + } + else if (method.IsOneOf(EnumerableMethod.GroupByWithKeySelectorAndElementSelector, QueryableMethod.GroupByWithKeySelectorAndElementSelector)) + { + var elementSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var elementSelectorParameter = elementSelectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(elementSelectorParameter, sourceExpression); + if (IsNotKnown(node) && IsKnown(keySelectorLambda.Body, out var keySerializer) && IsKnown(elementSelectorLambda.Body, out var elementSerializer)) + { + var groupingSerializer = IGroupingSerializer.Create(keySerializer, elementSerializer); + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, groupingSerializer); + AddKnownSerializer(node, nodeSerializer); + } + } + else if (method.IsOneOf(EnumerableMethod.GroupByWithKeySelectorAndResultSelector, QueryableMethod.GroupByWithKeySelectorAndResultSelector)) + { + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var resultSelectorKeyParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorElementsParameter = resultSelectorLambda.Parameters[1]; + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceSerializers(resultSelectorKeyParameter, keySelectorLambda.Body); + DeduceCollectionAndCollectionSerializers(resultSelectorElementsParameter, sourceExpression); + DeduceResultSerializer(resultSelectorLambda.Body); + } + else if (method.IsOneOf(EnumerableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector, QueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector)) + { + var elementSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var elementSelectorParameter = elementSelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var resultSelectorKeyParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorElementsParameter = resultSelectorLambda.Parameters[1]; + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceItemAndCollectionSerializers(elementSelectorParameter, sourceExpression); + DeduceSerializers(resultSelectorKeyParameter, keySelectorLambda.Body); + DeduceCollectionAndItemSerializers(resultSelectorElementsParameter, elementSelectorLambda.Body); + DeduceResultSerializer(resultSelectorLambda.Body); + } + + void DeduceResultSerializer(Expression resultExpression) + { + if (IsNotKnown(node) && IsKnown(resultExpression, out var resultSerializer)) + { + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultSerializer); + AddKnownSerializer(node, nodeSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceGroupJoinMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.GroupJoin, QueryableMethod.GroupJoin)) + { + var outerExpression = arguments[0]; + var innerExpression = arguments[1]; + var outerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var outerKeySelectorItemParameter = outerKeySelectorLambda.Parameters.Single(); + var innerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var innerKeySelectorItemParameter = innerKeySelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var resultSelectorOuterItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorInnerItemsParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(outerKeySelectorItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(innerKeySelectorItemParameter, innerExpression); + DeduceItemAndCollectionSerializers(resultSelectorOuterItemParameter, outerExpression); + DeduceCollectionAndCollectionSerializers(resultSelectorInnerItemsParameter, innerExpression); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIndexOfMethodSerializers() + { + if (method.IsOneOf(__indexOfMethods)) + { + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceInjectMethodSerializers() + { + if (method.Is(LinqExtensionsMethod.Inject)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIntersectMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Intersect, QueryableMethod.Intersect)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsMatchMethodSerializers() + { + if (method.Is(RegexMethod.StaticIsMatch)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsMissingOrIsNullOrMissingMethodSerializers() + { + if (method.IsOneOf(MqlMethod.IsMissing, MqlMethod.IsNullOrMissing)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsSubsetOfMethodSerializers() + { + if (IsSubsetOfMethod(method)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + static bool IsSubsetOfMethod(MethodInfo method) + { + var declaringType = method.DeclaringType; + var parameters = method.GetParameters(); + return + method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(bool) && + method.Name == "IsSubsetOf" && + parameters.Length == 1 && + parameters[0] is var otherParameter && + declaringType.ImplementsIEnumerable(out var declaringTypeItemType) && + otherParameter.ParameterType.ImplementsIEnumerable(out var otherTypeItemType) && + otherTypeItemType == declaringTypeItemType; + } + } + + void DeduceJoinMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Join, QueryableMethod.Join)) + { + var outerExpression = arguments[0]; + var innerExpression = arguments[1]; + var outerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var outerKeySelectorItemParameter = outerKeySelectorLambda.Parameters.Single(); + var innerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var innerKeySelectorItemParameter = innerKeySelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var resultSelectorOuterItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorInnerItemsParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(outerKeySelectorItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(innerKeySelectorItemParameter, innerExpression); + DeduceItemAndCollectionSerializers(resultSelectorOuterItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(resultSelectorInnerItemsParameter, innerExpression); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsNullOrEmptyOrIsNullOrWhiteSpaceMethodSerializers() + { + if (method.IsOneOf(StringMethod.IsNullOrEmpty, StringMethod.IsNullOrWhiteSpace)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceLogMethodSerializers() + { + if (method.IsOneOf(__logMethods)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceLookupMethodSerializers() + { + if (method.IsOneOf(__lookupMethods)) + { + var sourceExpression = arguments[0]; + + if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignField)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(foreignFieldLambdaParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(documentsLambda.Body, out var documentSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, documentSerializer); + AddKnownSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLambdaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(foreignFieldLambdaParameter, documentsLambda.Body); + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(pipelineLambdaForeignQueryableParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineDocumentSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineDocumentSerializer); + AddKnownSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndPipeline)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var pipelineLambdaSourceParameter = pipelineLambda.Parameters[0]; + var pipelineLambdaQueryableDocumentParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(pipelineLambdaSourceParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(pipelineLambdaQueryableDocumentParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddKnownSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + + if (method.Is(MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignField)) + { + var fromExpression = arguments[1]; + var fromCollection = fromExpression.GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceSerializer(foreignFieldLambdaParameter, foreignDocumentSerializer); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, foreignDocumentSerializer); + AddKnownSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignFieldAndPipeline)) + { + var fromExpression = arguments[1]; + var fromCollection = fromExpression.GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLamdbaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceSerializer(foreignFieldLambdaParameter, foreignDocumentSerializer); + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + + if (IsNotKnown(pipelineLamdbaForeignQueryableParameter)) + { + var foreignQueryableSerializer = IQueryableSerializer.Create(foreignDocumentSerializer); + AddKnownSerializer(pipelineLamdbaForeignQueryableParameter, foreignQueryableSerializer); + } + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultsSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddKnownSerializer(node, IQueryableSerializer.Create(lookupResultsSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithFromAndPipeline)) + { + var fromCollection = arguments[1].GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLamdbaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + + if (IsNotKnown(pipelineLamdbaForeignQueryableParameter)) + { + var foreignQueryableSerializer = IQueryableSerializer.Create(foreignDocumentSerializer); + AddKnownSerializer(pipelineLamdbaForeignQueryableParameter, foreignQueryableSerializer); + } + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddKnownSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceMatchingElementsMethodSerializers() + { + if (method.IsOneOf(MongoEnumerableMethod.AllElements, MongoEnumerableMethod.AllMatchingElements, MongoEnumerableMethod.FirstMatchingElement)) + { + DeduceReturnsOneSourceItemSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceMaxOrMinMethodSerializers() + { + if (method.IsOneOf(__maxOrMinMethods)) + { + if (method.IsOneOf(__maxOrMinWithSelectorMethods)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorItemParameter = selectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(selectorItemParameter, sourceExpression); + DeduceSerializers(node, selectorLambda.Body); + } + else + { + DeduceReturnsOneSourceItemSerializer(); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceOfTypeMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.OfType, QueryableMethod.OfType)) + { + var sourceExpression = arguments[0]; + var resultType = method.GetGenericArguments()[0]; + + if (IsNotKnown(node) && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var resultItemSerializer = sourceItemSerializer.GetDerivedTypeSerializer(resultType); + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddKnownSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceOrderByMethodSerializers() + { + if (method.IsOneOf( + EnumerableMethod.OrderBy, + EnumerableMethod.OrderByDescending, + EnumerableMethod.ThenBy, + EnumerableMethod.ThenByDescending, + QueryableMethod.OrderBy, + QueryableMethod.OrderByDescending, + QueryableMethod.ThenBy, + QueryableMethod.ThenByDescending)) + { + var sourceExpression = arguments[0]; + var keySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var keySelectorParameter = keySelectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeducePickMethodSerializers() + { + if (method.IsOneOf(__pickMethods)) + { + if (method.IsOneOf(__pickWithSortDefinitionMethods)) + { + var sortByExpression = arguments[1]; + if (IsNotKnown(sortByExpression)) + { + var ignoreSubTreeSerializer = IgnoreSubtreeSerializer.Create(sortByExpression.Type); + AddKnownSerializer(sortByExpression, ignoreSubTreeSerializer); + } + } + + var sourceExpression = arguments[0]; + if (IsKnown(sourceExpression, out var sourceSerializer)) + { + var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + + var selectorExpression = arguments[method.IsOneOf(__pickWithSortDefinitionMethods) ? 2 : 1]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, selectorExpression); + var selectorSourceItemParameter = selectorLambda.Parameters.Single(); + if (IsNotKnown(selectorSourceItemParameter)) + { + AddKnownSerializer(selectorSourceItemParameter, sourceItemSerializer); + } + } + + if (method.IsOneOf(__pickWithComputedNMethods)) + { + var keyExpression = arguments[method.IsOneOf(__pickWithSortDefinitionMethods) ? 3 : 2]; + if (IsKnown(keyExpression, out var keySerializer)) + { + var nExpression = arguments[method.IsOneOf(__pickWithSortDefinitionMethods) ? 4 : 3]; + var nLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, nExpression); + var nLambdaKeyParameter = nLambda.Parameters.Single(); + + if (IsNotKnown(nLambdaKeyParameter)) + { + AddKnownSerializer(nLambdaKeyParameter, keySerializer); + } + } + } + + if (IsNotKnown(node)) + { + var selectorExpressionIndex = method switch + { + _ when method.Is(EnumerableMethod.Bottom) => 2, + _ when method.Is(EnumerableMethod.BottomN) => 2, + _ when method.Is(EnumerableMethod.BottomNWithComputedN) => 2, + _ when method.Is(EnumerableMethod.FirstN) => 1, + _ when method.Is(EnumerableMethod.FirstNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.LastN) => 1, + _ when method.Is(EnumerableMethod.LastNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.MaxN) => 1, + _ when method.Is(EnumerableMethod.MaxNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.MinN) => 1, + _ when method.Is(EnumerableMethod.MinNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.Top) => 2, + _ when method.Is(EnumerableMethod.TopN) => 2, + _ when method.Is(EnumerableMethod.TopNWithComputedN) => 2, + _ => throw new ArgumentException($"Unrecognized method: {method.Name}.") + }; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[selectorExpressionIndex]); + + if (IsKnown(selectorLambda.Body, out var selectorItemSerializer)) + { + var nodeSerializer = method.IsOneOf(EnumerableMethod.Bottom, EnumerableMethod.Top) ? + selectorItemSerializer : + IEnumerableSerializer.Create(selectorItemSerializer); + AddKnownSerializer(node, nodeSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceParseMethodSerializers() + { + if (IsNotKnown(node)) + { + if (IsParseMethod(method)) + { + var nodeSerializer = GetParseResultSerializer(method.DeclaringType); + AddKnownSerializer(node, nodeSerializer); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + static bool IsParseMethod(MethodInfo method) + { + var parameters = method.GetParameters(); + return + method.IsPublic && + method.IsStatic && + method.ReturnType == method.DeclaringType && + parameters.Length == 1 && + parameters[0].ParameterType == typeof(string); + } + + static IBsonSerializer GetParseResultSerializer(Type declaringType) + { + return declaringType switch + { + _ when declaringType == typeof(DateTime) => DateTimeSerializer.Instance, + _ when declaringType == typeof(decimal) => DecimalSerializer.Instance, + _ when declaringType == typeof(double) => DoubleSerializer.Instance, + _ when declaringType == typeof(int) => Int32Serializer.Instance, + _ when declaringType == typeof(short) => Int64Serializer.Instance, + _ => UnknowableSerializer.Create(declaringType) + }; + } + } + + void DeducePowMethodSerializers() + { + if (method.IsOneOf(MathMethod.Pow)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRadiansToDegreesMethodSerializers() + { + if (method.Is(MongoDBMathMethod.RadiansToDegrees)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceReturnsBooleanSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, BooleanSerializer.Instance); + } + } + + void DeduceReturnsDateTimeSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, DateTimeSerializer.UtcInstance); + } + } + + void DeduceReturnsDecimalSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, DecimalSerializer.Instance); + } + } + + void DeduceReturnsDoubleSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, DoubleSerializer.Instance); + } + } + + void DeduceReturnsInt32Serializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, Int32Serializer.Instance); + } + } + + void DeduceReturnsInt64Serializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, Int64Serializer.Instance); + } + } + + void DeduceReturnsNullableDecimalSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, NullableSerializer.NullableDecimalInstance); + } + } + + void DeduceReturnsNullableDoubleSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, NullableSerializer.NullableDoubleInstance); + } + } + + void DeduceReturnsNullableInt32Serializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, NullableSerializer.NullableInt32Instance); + } + } + + void DeduceReturnsNullableInt64Serializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, NullableSerializer.NullableInt64Instance); + } + } + + void DeduceReturnsNullableSingleSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, NullableSerializer.NullableSingleInstance); + } + } + + void DeduceReturnsNumericSerializer() + { + if (IsNotKnown(node) && node.Type.IsNumeric()) + { + var numericSerializer = StandardSerializers.GetSerializer(node.Type); + AddKnownSerializer(node, numericSerializer); + } + } + + void DeduceReturnsNumericOrNullableNumericSerializer() + { + if (IsNotKnown(node) && node.Type.IsNumericOrNullableNumeric()) + { + var numericSerializer = StandardSerializers.GetSerializer(node.Type); + AddKnownSerializer(node, numericSerializer); + } + } + + void DeduceReturnsOneSourceItemSerializer() + { + var sourceExpression = arguments[0]; + + if (IsNotKnown(node) && IsKnown(sourceExpression, out var sourceSerializer)) + { + var nodeSerializer = sourceSerializer is IUnknowableSerializer ? + UnknowableSerializer.Create(node.Type) : + ArraySerializerHelper.GetItemSerializer(sourceSerializer); + AddKnownSerializer(node, nodeSerializer); + } + } + + void DeduceReturnsSingleSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, SingleSerializer.Instance); + } + } + + void DeduceReturnsStringSerializer() + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, StringSerializer.Instance); + } + } + + void DeduceReturnsTimeSpanSerializer(TimeSpanUnits units) + { + if (IsNotKnown(node)) + { + var resultSerializer = new TimeSpanSerializer(BsonType.Int64, units); + AddKnownSerializer(node, resultSerializer); + } + } + + void DeduceRangeMethodSerializers() + { + if (method.Is(EnumerableMethod.Range)) + { + var elementExpression = arguments[0]; + DeduceCollectionAndItemSerializers(node, elementExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRepeatMethodSerializers() + { + if (method.Is(EnumerableMethod.Repeat)) + { + var elementExpression = arguments[0]; + DeduceCollectionAndItemSerializers(node, elementExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceReverseMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Reverse, QueryableMethod.Reverse)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRoundMethodSerializers() + { + if (method.IsOneOf(MathMethod.RoundWithDecimal, MathMethod.RoundWithDecimalAndDecimals, MathMethod.RoundWithDouble, MathMethod.RoundWithDoubleAndDigits)) + { + if (IsNotKnown(node)) + { + var resultSerializer = StandardSerializers.GetSerializer(node.Type); + AddKnownSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSelectMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Select, QueryableMethod.Select)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorParameter, sourceExpression); + DeduceCollectionAndItemSerializers(node, selectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSelectManySerializers() + { + if (method.IsOneOf(__selectManyMethods)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(__selectManyWithResultSelectorMethods)) + { + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceParameter = selectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(selectorSourceParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, selectorLambda.Body); + } + + if (method.IsOneOf(__selectManyWithCollectionSelectorAndResultSelectorMethods)) + { + var collectionSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + + var collectionSelectorSourceItemParameter = collectionSelectorLambda.Parameters.Single(); + var resultSelectorSourceItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorCollectionItemParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(collectionSelectorSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(resultSelectorSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(resultSelectorCollectionItemParameter, collectionSelectorLambda.Body); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSequenceEqualMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.SequenceEqual, QueryableMethod.SequenceEqual)) + { + var source1Expression = arguments[0]; + var source2Expression = arguments[1]; + + DeduceCollectionAndCollectionSerializers(source1Expression, source2Expression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSetEqualsMethodSerializers() + { + if (IsSetEqualsMethod(method)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + static bool IsSetEqualsMethod(MethodInfo method) + { + var declaringType = method.DeclaringType; + var parameters = method.GetParameters(); + return + method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(bool) && + method.Name == "SetEquals" && + parameters.Length == 1 && + parameters[0] is var otherParameter && + declaringType.ImplementsIEnumerable(out var declaringTypeItemType) && + otherParameter.ParameterType.ImplementsIEnumerable(out var otherTypeItemType) && + otherTypeItemType == declaringTypeItemType; + } + } + + void DeduceSetWindowFieldsMethodSerializers() + { + if (method.Is(EnumerableMethod.First)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceShiftMethodSerializers() + { + if (method.IsOneOf(WindowMethod.Shift, WindowMethod.ShiftWithDefaultValue)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceItemParameter = selectorLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(selectorSourceItemParameter, sourceExpression); + + if (IsNotKnown(node) && IsKnown(selectorLambda.Body, out var resultSerializer)) + { + AddKnownSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSinMethodSerializers() + { + if (method.Is(MathMethod.Sin)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSinhMethodSerializers() + { + if (method.Is(MathMethod.Sinh)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSplitMethodSerializers() + { + if (method.IsOneOf(__splitMethods)) + { + if (IsNotKnown(node)) + { + var nodeSerializer = ArraySerializer.Create(StringSerializer.Instance); + AddKnownSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSqrtMethodSerializers() + { + if (method.Is(MathMethod.Sqrt)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStandardDeviationMethodSerializers() + { + if (method.IsOneOf(__standardDeviationMethods)) + { + if (method.IsOneOf(__standardDeviationWithSelectorMethods)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorItemParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorItemParameter, sourceExpression); + } + + DeduceReturnsNumericOrNullableNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceEndsWithOrStartsWithMethodSerializers() + { + if (method.IsOneOf(__stringEndsWithOrStartsWithMethods)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStringInMethodSerializers() + { + if (method.IsOneOf(StringMethod.StringInWithEnumerable, StringMethod.StringInWithParams)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStrLenBytesMethodSerializers() + { + if (method.Is(StringMethod.StrLenBytes)) + { + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSubstringMethodSerializers() + { + if (method.IsOneOf(StringMethod.Substring, StringMethod.SubstringWithLength, StringMethod.SubstrBytes)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSubtractMethodSerializers() + { + if (method.IsOneOf(__subtractReturningDateTimeMethods)) + { + DeduceReturnsDateTimeSerializer(); + } + else if (method.IsOneOf(__subtractReturningInt64Methods)) + { + DeduceReturnsInt64Serializer(); + } + else if (method.IsOneOf(__subtractReturningTimeSpanWithMillisecondsUnitsMethods)) + { + var units = TimeSpanUnits.Milliseconds; + DeduceReturnsTimeSpanSerializer(units); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSumMethodSerializers() + { + if (method.IsOneOf(__sumMethods)) + { + if (method.IsOneOf(__sumWithSelectorMethods)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorParameter, sourceExpression); + } + + var returnType = node.Type; + switch (returnType) + { + case not null when returnType == typeof(decimal): DeduceReturnsDecimalSerializer(); break; + case not null when returnType == typeof(double): DeduceReturnsDoubleSerializer(); break; + case not null when returnType == typeof(int): DeduceReturnsInt32Serializer(); break; + case not null when returnType == typeof(long): DeduceReturnsInt64Serializer(); break; + case not null when returnType == typeof(float): DeduceReturnsSingleSerializer(); break; + case not null when returnType == typeof(decimal?): DeduceReturnsNullableDecimalSerializer(); break; + case not null when returnType == typeof(double?): DeduceReturnsNullableDoubleSerializer(); break; + case not null when returnType == typeof(int?): DeduceReturnsNullableInt32Serializer(); break; + case not null when returnType == typeof(long?): DeduceReturnsNullableInt64Serializer(); break; + case not null when returnType == typeof(float?): DeduceReturnsNullableSingleSerializer(); break; + + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSkipOrTakeMethodSerializers() + { + if (method.IsOneOf(__skipOrTakeMethods)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(__skipOrTakeWhileMethods)) + { + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceTanMethodSerializers() + { + if (method.Is(MathMethod.Tan)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceTanhMethodSerializers() + { + if (method.Is(MathMethod.Tanh)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceToArrayMethodSerializers() + { + if (IsToArrayMethod(out var sourceExpression)) + { + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsToArrayMethod(out Expression sourceExpression) + { + if (method.IsPublic && + method.Name == "ToArray" && + method.GetParameters().Length == (method.IsStatic ? 1 : 0)) + { + sourceExpression = method.IsStatic ? arguments[0] : node.Object; + return true; + } + + sourceExpression = null; + return false; + } + } + + void DeduceToListSerializers() + { + if (IsNotKnown(node)) + { + var source = method.IsStatic ? arguments[0] : node.Object; + if (IsKnown(source, out var sourceSerializer)) + { + var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + var resultSerializer = ListSerializer.Create(sourceItemSerializer); + AddKnownSerializer(node, resultSerializer); + } + } + } + + void DeduceToLowerOrToUpperSerializers() + { + if (method.IsOneOf(__toLowerOrToUpperMethods)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceToStringSerializers() + { + DeduceReturnsStringSerializer(); + } + + void DeduceTruncateSerializers() + { + if (method.IsOneOf(DateTimeMethod.Truncate, DateTimeMethod.TruncateWithBinSize, DateTimeMethod.TruncateWithBinSizeAndTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else if (method.IsOneOf(MathMethod.TruncateDecimal, MathMethod.TruncateDouble)) + { + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceUnionSerializers() + { + if (method.IsOneOf(EnumerableMethod.Union, QueryableMethod.Union)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceUnknownMethodSerializer() + { + DeduceUnknowableSerializer(node); + } + + void DeduceWeekSerializers() + { + if (method.IsOneOf(DateTimeMethod.Week, DateTimeMethod.WeekWithTimezone)) + { + if (IsNotKnown(node)) + { + AddKnownSerializer(node, Int32Serializer.Instance); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceWhereSerializers() + { + if (method.IsOneOf(__whereMethods)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceZipSerializers() + { + if (method.IsOneOf(EnumerableMethod.Zip, QueryableMethod.Zip)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var resultSelectorFirstParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorSecondParameter = resultSelectorLambda.Parameters[1]; + + if (IsNotKnown(resultSelectorFirstParameter) && IsKnown(firstExpression, out var firstSerializer)) + { + var firstItemSerializer = ArraySerializerHelper.GetItemSerializer(firstSerializer); + AddKnownSerializer(resultSelectorFirstParameter, firstItemSerializer); + } + + if (IsNotKnown(resultSelectorSecondParameter) && IsKnown(secondExpression, out var secondSerializer)) + { + var secondItemSerializer = ArraySerializerHelper.GetItemSerializer(secondSerializer); + AddKnownSerializer(resultSelectorSecondParameter, secondItemSerializer); + } + + if (IsNotKnown(node) && IsKnown(resultSelectorLambda.Body, out var resultItemSerializer)) + { + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddKnownSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + bool IsDictionaryContainsKeyMethod(out Expression keyExpression) + { + if (method.DeclaringType.Name.Contains("Dictionary") && + method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(bool) && + method.Name == "ContainsKey" && + method.GetParameters().Length == 1) + { + keyExpression = arguments[0]; + return true; + } + + keyExpression = null; + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitNew.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitNew.cs new file mode 100644 index 00000000000..f189ea07e2e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitNew.cs @@ -0,0 +1,139 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitNew(NewExpression node) + { + var constructor = node.Constructor; + var arguments = node.Arguments; + + if (IsKnown(node, out var nodeSerializer) && + arguments.Any(IsNotKnown)) + { + if (!typeof(BsonValue).IsAssignableFrom(node.Type) && + nodeSerializer is IBsonDocumentSerializer) + { + var matchingMemberSerializationInfos = nodeSerializer.GetMatchingMemberSerializationInfosForConstructorParameters(node, node.Constructor); + for (var i = 0; i < matchingMemberSerializationInfos.Count; i++) + { + var argument = arguments[i]; + var matchingMemberSerializationInfo = matchingMemberSerializationInfos[i]; + + if (IsNotKnown(argument)) + { + // arg => arg: matchingMemberSerializer + AddKnownSerializer(argument, matchingMemberSerializationInfo.Serializer); + } + } + } + } + + base.VisitNew(node); + + if (IsNotKnown(node)) + { + var knownSerializer = GetKnownSerializer(constructor); + if (knownSerializer != null) + { + AddKnownSerializer(node, knownSerializer); + } + } + + return node; + + IBsonSerializer GetKnownSerializer(ConstructorInfo constructor) + { + if (constructor == null) + { + return CreateNewExpressionSerializer(node, node, bindings: null); + } + else if (constructor.DeclaringType == typeof(BsonDocument)) + { + return BsonDocumentSerializer.Instance; + } + else if (constructor.DeclaringType == typeof(BsonValue)) + { + return BsonValueSerializer.Instance; + } + else if (constructor.DeclaringType == typeof(DateTime)) + { + return DateTimeSerializer.Instance; + } + else if (DictionaryConstructor.IsWithIEnumerableKeyValuePairConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer) && + itemSerializer.IsKeyValuePairSerializer(out _, out _, out var keySerializer, out var valueSerializer)) + { + return DictionarySerializer.Create(keySerializer, valueSerializer); + } + } + else if (HashSetConstructor.IsWithCollectionConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + return HashSetSerializer.Create(itemSerializer); + } + } + else if (ListConstructor.IsWithCollectionConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + return ListSerializer.Create(itemSerializer); + } + } + else if (KeyValuePairConstructor.IsWithKeyAndValueConstructor(constructor)) + { + var key = arguments[0]; + var value = arguments[1]; + if (IsKnown(key, out var keySerializer) && + IsKnown(value, out var valueSerializer)) + { + return KeyValuePairSerializer.Create(BsonType.Document, keySerializer, valueSerializer); + } + } + else if (TupleOrValueTupleConstructor.IsTupleOrValueTupleConstructor(constructor)) + { + if (AllAreKnown(arguments, out var argumentSerializers)) + { + return TupleOrValueTupleSerializer.Create(constructor.DeclaringType, argumentSerializers); + } + } + else + { + return CreateNewExpressionSerializer(node, node, bindings: null); + } + + return null; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitNewArray.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitNewArray.cs new file mode 100644 index 00000000000..30420ae5121 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitNewArray.cs @@ -0,0 +1,145 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitNewArray(NewArrayExpression node) + { + DeduceNewArraySerializers(); + base.VisitNewArray(node); + DeduceNewArraySerializers(); + + return node; + + void DeduceNewArraySerializers() + { + switch (node.NodeType) + { + case ExpressionType.NewArrayBounds: + DeduceNewArrayBoundsSerializers(); + break; + + case ExpressionType.NewArrayInit: + DeduceNewArrayInitSerializers(); + break; + } + } + + void DeduceNewArrayBoundsSerializers() + { + throw new NotImplementedException(); + } + + void DeduceNewArrayInitSerializers() + { + var itemExpressions = node.Expressions; + + if (AnyIsNotKnown(itemExpressions) && IsKnown(node, out var arraySerializer)) + { + if (arraySerializer is IFixedSizeArraySerializer fixedSizeArraySerializer) + { + for (var i = 0; i < itemExpressions.Count; i++) + { + var itemExpression = itemExpressions[i]; + if (IsNotKnown(itemExpression)) + { + var itemSerializer = fixedSizeArraySerializer.GetItemSerializer(i); + AddKnownSerializer(itemExpression, itemSerializer); + } + } + } + else + { + var itemSerializer = arraySerializer.GetItemSerializer(); + foreach (var itemExpression in itemExpressions) + { + if (IsNotKnown(itemExpression)) + { + AddKnownSerializer(itemExpression, itemSerializer); + } + } + } + } + + if (AnyIsNotKnown(itemExpressions) && AnyIsKnown(itemExpressions, out var knownItemSerializer)) + { + var firstItemType = itemExpressions.First().Type; + if (itemExpressions.All(e => e.Type == firstItemType)) + { + foreach (var itemExpression in itemExpressions) + { + if (IsNotKnown(itemExpression)) + { + AddKnownSerializer(itemExpression, knownItemSerializer); + } + } + } + } + + if (IsNotKnown(node)) + { + if (AllAreKnown(itemExpressions, out var itemSerializers)) + { + if (AllItemSerializersAreEqual(itemSerializers, out var itemSerializer)) + { + arraySerializer = ArraySerializer.Create(itemSerializer); + } + else + { + var itemType = node.Type.GetElementType(); + arraySerializer = FixedSizeArraySerializer.Create(itemType, itemSerializers); + } + AddKnownSerializer(node, arraySerializer); + } + } + + static bool AllItemSerializersAreEqual(IReadOnlyList itemSerializers, out IBsonSerializer itemSerializer) + { + switch (itemSerializers.Count) + { + case 0: + itemSerializer = null; + return false; + case 1: + itemSerializer = itemSerializers[0]; + return true; + default: + var firstItemSerializer = itemSerializers[0]; + if (itemSerializers.Skip(1).All(s => s.Equals(firstItemSerializer))) + { + itemSerializer = firstItemSerializer; + return true; + } + else + { + itemSerializer = null; + return false; + } + } + } + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitTypeBinary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitTypeBinary.cs new file mode 100644 index 00000000000..d2df211c5b1 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitTypeBinary.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitTypeBinary(TypeBinaryExpression node) + { + base.VisitTypeBinary(node); + + DeduceBooleanSerializer(node); + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitUnary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitUnary.cs new file mode 100644 index 00000000000..88762da3c66 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitUnary.cs @@ -0,0 +1,312 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor +{ + protected override Expression VisitUnary(UnaryExpression node) + { + var unaryOperator = node.NodeType; + var operand = node.Operand; + + base.VisitUnary(node); + + switch (unaryOperator) + { + case ExpressionType.Negate: + DeduceNegateSerializers(); // TODO: fold into general case? + break; + + default: + DeduceUnaryOperatorSerializers(); + break; + } + + return node; + + void DeduceNegateSerializers() + { + DeduceSerializers(node, operand); + } + + void DeduceUnaryOperatorSerializers() + { + if (IsNotKnown(node)) + { + var resultSerializer = unaryOperator switch + { + ExpressionType.ArrayLength => Int32Serializer.Instance, + ExpressionType.Convert or ExpressionType.TypeAs => GetConvertSerializer(), + ExpressionType.Not => StandardSerializers.GetSerializer(node.Type), + ExpressionType.Quote => IgnoreNodeSerializer.Create(node.Type), + _ => null + }; + + if (resultSerializer != null) + { + AddKnownSerializer(node, resultSerializer); + } + } + } + + IBsonSerializer GetConvertSerializer() + { + var sourceType = operand.Type; + var targetType = node.Type; + + // handle double conversion (BsonValue)(object)x + if (targetType == typeof(BsonValue) && + operand is UnaryExpression unarySourceExpression && + unarySourceExpression.NodeType == ExpressionType.Convert && + unarySourceExpression.Type == typeof(object)) + { + operand = unarySourceExpression.Operand; + } + + if (IsKnown(operand, out var sourceSerializer)) + { + return GetTargetSerializer(node, sourceType, targetType, sourceSerializer); + } + + return null; + + static IBsonSerializer GetTargetSerializer(UnaryExpression node, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (targetType == sourceType) + { + return sourceSerializer; + } + + // handle conversionsn to BsonValue before any others + if (targetType == typeof(BsonValue)) + { + return GetConvertToBsonValueSerializer(node, sourceSerializer); + } + + // from Nullable must be handled before to Nullable + if (IsConvertFromNullableType(sourceType)) + { + return GetConvertFromNullableTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToNullableType(targetType, out var valueType)) + { + var valueSerializer = valueType == targetType ? sourceSerializer : GetTargetSerializer(node, sourceType, valueType, sourceSerializer); + return valueSerializer != null ? GetConvertToNullableTypeSerializer(node, sourceType, targetType, valueSerializer) : null; + } + + // from here on we know there are no longer any Nullable types involved + + if (sourceType == typeof(BsonValue)) + { + return GetConvertFromBsonValueSerializer(node, targetType); + } + + if (IsConvertEnumToUnderlyingType(sourceType, targetType)) + { + return GetConvertEnumToUnderlyingTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertUnderlyingTypeToEnum(sourceType, targetType)) + { + return GetConvertUnderlyingTypeToEnumSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertEnumToEnum(sourceType, targetType)) + { + return GetConvertEnumToEnumSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToBaseType(sourceType, targetType)) + { + return GetConvertToBaseTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToDerivedType(sourceType, targetType)) + { + return GetConvertToDerivedTypeSerializer(node, targetType, sourceSerializer); + } + + if (IsNumericConversion(sourceType, targetType)) + { + return GetNumericConversionSerializer(node, sourceType, targetType, sourceSerializer); + } + + return null; + } + + static IBsonSerializer GetConvertFromBsonValueSerializer(UnaryExpression expression, Type targetType) + { + return targetType switch + { + _ when targetType == typeof(string) => StringSerializer.Instance, + _ => throw new ExpressionNotSupportedException(expression, because: $"conversion from BsonValue to {targetType} is not supported") + }; + } + + static IBsonSerializer GetConvertToBaseTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var derivedTypeSerializer = sourceSerializer; + return DowncastingSerializer.Create(targetType, sourceType, derivedTypeSerializer); + } + + static IBsonSerializer GetConvertToDerivedTypeSerializer(UnaryExpression expression, Type targetType, IBsonSerializer sourceSerializer) + { + var derivedTypeSerializer = sourceSerializer.GetDerivedTypeSerializer(targetType); + return derivedTypeSerializer; + } + + static IBsonSerializer GetConvertToBsonValueSerializer(UnaryExpression expression, IBsonSerializer sourceSerializer) + { + return BsonValueSerializer.Instance; + } + + static IBsonSerializer GetConvertEnumToEnumSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (!sourceType.IsEnum) + { + throw new ExpressionNotSupportedException(expression, because: "source type is not an enum"); + } + if (!targetType.IsEnum) + { + throw new ExpressionNotSupportedException(expression, because: "target type is not an enum"); + } + + // if (sourceSerializer is IHasRepresentationSerializer sourceHasRepresentationSerializer && + // !SerializationHelper.IsNumericRepresentation(sourceHasRepresentationSerializer.Representation)) + // { + // throw new ExpressionNotSupportedException(expression, because: "source enum is not represented as a number"); + // } + + return EnumSerializer.Create(targetType); + } + + static IBsonSerializer GetConvertEnumToUnderlyingTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var enumSerializer = sourceSerializer; + return AsEnumUnderlyingTypeSerializer.Create(enumSerializer); + } + + static IBsonSerializer GetConvertFromNullableTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (sourceSerializer is not INullableSerializer nullableSourceSerializer) + { + throw new ExpressionNotSupportedException(expression, because: $"sourceSerializer type {sourceSerializer.GetType()} does not implement nameof(INullableSerializer)"); + } + + var sourceValueSerializer = nullableSourceSerializer.ValueSerializer; + var sourceValueType = sourceValueSerializer.ValueType; + + if (targetType.IsNullable(out var targetValueType)) + { + var targetValueSerializer = GetTargetSerializer(expression, sourceValueType, targetValueType, sourceValueSerializer); + return NullableSerializer.Create(targetValueSerializer); + } + else + { + return GetTargetSerializer(expression, sourceValueType, targetType, sourceValueSerializer); + } + } + + static IBsonSerializer GetConvertToNullableTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (sourceType.IsNullable()) + { + throw new ExpressionNotSupportedException(expression, because: "sourceType is already nullable"); + } + + if (targetType.IsNullable()) + { + return NullableSerializer.Create(sourceSerializer); + } + + throw new ExpressionNotSupportedException(expression, because: "targetType is not nullable"); + } + + static IBsonSerializer GetConvertUnderlyingTypeToEnumSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + IBsonSerializer targetSerializer; + if (sourceSerializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + { + targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer; + } + else + { + targetSerializer = EnumSerializer.Create(targetType); + } + + return targetSerializer; + } + + static IBsonSerializer GetNumericConversionSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + return NumericConversionSerializer.Create(sourceType, targetType, sourceSerializer); + } + + static bool IsConvertEnumToEnum(Type sourceType, Type targetType) + { + return sourceType.IsEnum && targetType.IsEnum; + } + + static bool IsConvertEnumToUnderlyingType(Type sourceType, Type targetType) + { + return + sourceType.IsEnum(out var underlyingType) && + targetType == underlyingType; + } + + static bool IsConvertFromNullableType(Type sourceType) + { + return sourceType.IsNullable(); + } + + static bool IsConvertToBaseType(Type sourceType, Type targetType) + { + return sourceType.IsSubclassOf(targetType) || sourceType.Implements(targetType); + } + + static bool IsConvertToDerivedType(Type sourceType, Type targetType) + { + return sourceType.IsAssignableFrom(targetType); // targetType either derives from sourceType or implements sourceType interface + } + + static bool IsConvertToNullableType(Type targetType, out Type valueType) + { + return targetType.IsNullable(out valueType); + } + + static bool IsConvertUnderlyingTypeToEnum(Type sourceType, Type targetType) + { + return + targetType.IsEnum(out var underlyingType) && + sourceType == underlyingType; + } + + static bool IsNumericConversion(Type sourceType, Type targetType) + { + return sourceType.IsNumeric() && targetType.IsNumeric(); + } + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitor.cs new file mode 100644 index 00000000000..5b7e4dde971 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerFinderVisitor.cs @@ -0,0 +1,74 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal partial class KnownSerializerFinderVisitor : ExpressionVisitor +{ + private bool _isMakingProgress = true; + private readonly KnownSerializerMap _knownSerializers; + private int _oldKnownSerializersCount = 0; + private int _pass = 0; + private readonly ExpressionTranslationOptions _translationOptions; + private bool _useDefaultSerializerForConstants = false; // make as much progress as possible before setting this to true + + public KnownSerializerFinderVisitor(ExpressionTranslationOptions translationOptions, KnownSerializerMap knownSerializers) + { + _knownSerializers = knownSerializers; + _translationOptions = translationOptions; + } + + public int Pass => _pass; + + public bool IsMakingProgress => _isMakingProgress; + + public void EndPass() + { + var newKnownSerializersCount = _knownSerializers.Count; + if (newKnownSerializersCount == _oldKnownSerializersCount) + { + if (_useDefaultSerializerForConstants) + { + _isMakingProgress = false; + } + else + { + _useDefaultSerializerForConstants = true; + } + } + } + + public void StartPass() + { + _oldKnownSerializersCount = _knownSerializers.Count; + } + + public override Expression Visit(Expression node) + { + if (IsKnown(node, out var knownSerializer)) + { + if (knownSerializer is IIgnoreSubtreeSerializer or IUnknowableSerializer) + { + return node; // don't visit subtree + } + } + + return base.Visit(node); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerMap.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerMap.cs new file mode 100644 index 00000000000..062239db1e4 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/KnownSerializerMap.cs @@ -0,0 +1,105 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal class KnownSerializerMap +{ + private readonly Dictionary _map = new(); + + public int Count => _map.Count; + + public void AddSerializer(Expression node, IBsonSerializer serializer) + { + if (serializer.ValueType != node.Type && + node.Type.IsNullable(out var nodeNonNullableType) && + serializer.ValueType.IsNullable(out var serializerNonNullableType) && + serializer is INullableSerializer nullableSerializer) + { + if (nodeNonNullableType.IsEnum(out var targetEnumUnderlyingType) && targetEnumUnderlyingType == serializerNonNullableType) + { + var enumType = nodeNonNullableType; + var underlyingTypeSerializer = nullableSerializer.ValueSerializer; + var enumSerializer = AsUnderlyingTypeEnumSerializer.Create(enumType, underlyingTypeSerializer); + serializer = NullableSerializer.Create(enumSerializer); + } + else if (serializerNonNullableType.IsEnum(out var serializerUnderlyingType) && serializerUnderlyingType == nodeNonNullableType) + { + var enumSerializer = nullableSerializer.ValueSerializer; + var underlyingTypeSerializer = AsEnumUnderlyingTypeSerializer.Create(enumSerializer); + serializer = NullableSerializer.Create(underlyingTypeSerializer); + } + } + + if (serializer.ValueType != node.Type) + { + if (node.Type.IsAssignableFrom(serializer.ValueType)) + { + serializer = DowncastingSerializer.Create(baseType: node.Type, derivedType: serializer.ValueType, derivedTypeSerializer: serializer); + } + else if (serializer.ValueType.IsAssignableFrom(node.Type)) + { + serializer = UpcastingSerializer.Create(baseType: serializer.ValueType, derivedType: node.Type, baseTypeSerializer: serializer); + } + else + { + throw new ArgumentException($"Serializer value type {serializer.ValueType} does not match expression value type {node.Type}", nameof(serializer)); + } + } + + if (_map.TryGetValue(node, out var existingSerializer)) + { + throw new ExpressionNotSupportedException( + node, + because: $"there are duplicate known serializers for expression '{node}': {serializer.GetType()} and {existingSerializer.GetType()}"); + } + + _map.Add(node, serializer); + } + + public IBsonSerializer GetSerializer(Expression node) + { + if (_map.TryGetValue(node, out var knownSerializer)) + { + return knownSerializer; + } + + throw new ExpressionNotSupportedException(node, because: "unable to determine which serializer to use"); + } + + public bool IsNotKnown(Expression node) + { + return !IsKnown(node); + } + + public bool IsKnown(Expression node) + { + return node != null && _map.ContainsKey(node); + } + + public bool IsKnown(Expression node, out IBsonSerializer serializer) + { + serializer = null; + return node != null && _map.TryGetValue(node, out serializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/UnknownSerializerFinder.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/UnknownSerializerFinder.cs new file mode 100644 index 00000000000..61f6a4a1394 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/KnownSerializerFinders/UnknownSerializerFinder.cs @@ -0,0 +1,63 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor; + +namespace MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; + +internal class UnknownSerializerFinder : ExpressionVisitor +{ + public static Expression FindExpressionWithUnknownSerializer(Expression expression, KnownSerializerMap knownSerializers) + { + var visitor = new UnknownSerializerFinder(knownSerializers); + visitor.Visit(expression); + return visitor._expressionWithUnknownSerializer; + } + + private Expression _expressionWithUnknownSerializer = null; + private readonly KnownSerializerMap _knownSerializers; + + public UnknownSerializerFinder(KnownSerializerMap knownSerializers) + { + _knownSerializers = knownSerializers; + } + + public Expression ExpressionWithUnknownSerialier => _expressionWithUnknownSerializer; + + public override Expression Visit(Expression node) + { + if (_knownSerializers.IsKnown(node, out var knownSerializer)) + { + if (knownSerializer is IIgnoreSubtreeSerializer or IUnknowableSerializer) + { + return node; // don't visit subtree + } + } + + base.Visit(node); + + if (_expressionWithUnknownSerializer == null && + node != null && + _knownSerializers.IsNotKnown(node)) + { + _expressionWithUnknownSerializer = node; + } + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs new file mode 100644 index 00000000000..de23e6e9d5e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs @@ -0,0 +1,24 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +internal static class BsonTypeExtensions +{ + public static bool IsNumeric(this BsonType bsonType) + => bsonType is BsonType.Decimal128 or BsonType.Double or BsonType.Int32 or BsonType.Int64; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs index 2b5c4a3a012..a2eed8cafd8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs @@ -35,7 +35,8 @@ public static LambdaExpression UnquoteLambdaIfQueryableMethod(MethodInfo method, Ensure.IsNotNull(method, nameof(method)); Ensure.IsNotNull(expression, nameof(expression)); - if (method.DeclaringType == typeof(Queryable)) + var declaringType = method.DeclaringType; + if (declaringType == typeof(Queryable) || declaringType == typeof(MongoQueryable)) { return UnquoteLambda(expression); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs new file mode 100644 index 00000000000..8a0bd4a7465 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs @@ -0,0 +1,113 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +internal static class IBsonSerializerExtensions +{ + public static bool CanBeAssignedTo(this IBsonSerializer sourceSerializer, IBsonSerializer targetSerializer) + { + if (sourceSerializer.Equals(targetSerializer)) + { + return true; + } + + if (sourceSerializer.ValueType.IsNumeric() && + targetSerializer.ValueType.IsNumeric() && + sourceSerializer.HasNumericRepresentation() && + targetSerializer.HasNumericRepresentation()) + { + return true; + } + + if (targetSerializer.ValueType.IsAssignableFrom(sourceSerializer.ValueType)) + { + return true; + } + + return false; + } + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer) + => ArraySerializerHelper.GetItemSerializer(serializer); + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer, int index) + { + if (serializer is IFixedSizeArraySerializer fixedSizeArraySerializer) + { + return fixedSizeArraySerializer.GetItemSerializer(index); + } + else + { + return serializer.GetItemSerializer(); + } + } + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer, Expression indexExpression, Expression containingExpression) + { + if (serializer is IFixedSizeArraySerializer fixedSizeArraySerializer) + { + var index = indexExpression.GetConstantValue(containingExpression); + return fixedSizeArraySerializer.GetItemSerializer(index); + } + else + { + return serializer.GetItemSerializer(); + } + } + + public static bool HasNumericRepresentation(this IBsonSerializer serializer) + { + return + serializer is IHasRepresentationSerializer hasRepresentationSerializer && + hasRepresentationSerializer.Representation.IsNumeric(); + } + + public static bool IsKeyValuePairSerializer( + this IBsonSerializer serializer, + out string keyElementName, + out string valueElementName, + out IBsonSerializer keySerializer, + out IBsonSerializer valueSerializer) + { + // TODO: add properties to IKeyValuePairSerializer to let us extract the needed information + // note: we can only verify the existence of "Key" and "Value" properties, but can't verify there are no others + if (serializer.ValueType is var valueType && + valueType.IsConstructedGenericType && + valueType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>) && + serializer is IBsonDocumentSerializer documentSerializer && + documentSerializer.TryGetMemberSerializationInfo("Key", out var keySerializationInfo) && + documentSerializer.TryGetMemberSerializationInfo("Value", out var valueSerializationInfo)) + { + keyElementName = keySerializationInfo.ElementName; + valueElementName = valueSerializationInfo.ElementName; + keySerializer = keySerializationInfo.Serializer; + valueSerializer = valueSerializationInfo.Serializer; + return true; + } + + keyElementName = null; + valueElementName = null; + keySerializer = null; + valueSerializer = null; + return false; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs index d7a61ab07d5..89310a45b05 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System.Collections.Generic; using System.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Misc @@ -38,6 +39,24 @@ public static bool Is(this MethodInfo method, MethodInfo comparand) return false; } + public static bool IsOneOf(this MethodInfo method, HashSet comparands) + { + if (comparands != null) + { + if (method.IsGenericMethod) + { + var methodDefinition = method.GetGenericMethodDefinition(); + return comparands.Contains(methodDefinition); + } + else + { + return comparands.Contains(method); + } + } + + return false; + } + public static bool IsOneOf(this MethodInfo method, MethodInfo comparand1, MethodInfo comparand2) { return method.Is(comparand1) || method.Is(comparand2); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs index 0b34b0bd7cf..c0c4a473eb2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs @@ -97,7 +97,7 @@ public static BsonType GetRepresentation(IBsonSerializer serializer) return GetRepresentation(downcastingSerializer.DerivedSerializer); } - if (serializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + if (serializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) { return GetRepresentation(enumUnderlyingTypeSerializer.EnumSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs index ccb8f699740..0384dc7945a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs @@ -17,6 +17,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using MongoDB.Bson; namespace MongoDB.Driver.Linq.Linq3Implementation.Misc { @@ -136,6 +137,49 @@ public static bool ImplementsIList(this Type type, out Type itemType) return false; } + public static bool ImplementsIOrderedEnumerable(this Type type, out Type itemType) + { + if (TryGetIOrderedEnumerableGenericInterface(type, out var iOrderedEnumerableType)) + { + itemType = iOrderedEnumerableType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + + public static bool ImplementsIOrderedQueryable(this Type type, out Type itemType) + { + if (TryGetIOrderedQueryableGenericInterface(type, out var iorderedQueryableType)) + { + itemType = iorderedQueryableType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + + public static bool ImplementsIQueryable(this Type type, out Type itemType) + { + if (TryGetIQueryableGenericInterface(type, out var iqueryableType)) + { + itemType = iqueryableType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + + public static bool ImplementsIQueryableOf(this Type type, Type itemType) + { + return + ImplementsIEnumerable(type, out var actualItemType) && + actualItemType == itemType; + } + public static bool Is(this Type type, Type comparand) { if (type == comparand) @@ -175,6 +219,16 @@ public static bool IsArray(this Type type, out Type itemType) return false; } + public static bool IsBoolean(this Type type) + { + return type == typeof(bool); + } + + public static bool IsBooleanOrNullableBoolean(this Type type) + { + return type == typeof(bool) || type.IsNullable(out var valueType) && IsBoolean(valueType); + } + public static bool IsEnum(this Type type, out Type underlyingType) { if (type.IsEnum) @@ -248,6 +302,28 @@ public static bool IsNullableOf(this Type type, Type valueType) return type.IsNullable(out var nullableValueType) && nullableValueType == valueType; } + public static bool IsNumeric(this Type type) + { + return Type.GetTypeCode(type) is + TypeCode.Byte or + TypeCode.Char or // TODO: should we really treat char as numeric? + TypeCode.Decimal or + TypeCode.Double or + TypeCode.Int16 or + TypeCode.Int32 or + TypeCode.Int64 or + TypeCode.SByte or + TypeCode.Single or + TypeCode.UInt16 or + TypeCode.UInt32 or + TypeCode.UInt64; + } + + public static bool IsNumericOrNullableNumeric(this Type type) + { + return IsNumeric(type) || type.IsNullable(out var valueType) && IsNumeric(valueType); + } + public static bool IsSameAsOrNullableOf(this Type type, Type valueType) { return type == valueType || type.IsNullableOf(valueType); @@ -332,5 +408,69 @@ public static bool TryGetIListGenericInterface(this Type type, out Type ilistGen ilistGenericInterface = null; return false; } + + public static bool TryGetIOrderedEnumerableGenericInterface(this Type type, out Type iOrderedEnumerableGenericInterface) + { + if (type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>)) + { + iOrderedEnumerableGenericInterface = type; + return true; + } + + foreach (var interfaceType in type.GetInterfaces()) + { + if (interfaceType.IsConstructedGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>)) + { + iOrderedEnumerableGenericInterface = interfaceType; + return true; + } + } + + iOrderedEnumerableGenericInterface = null; + return false; + } + + public static bool TryGetIOrderedQueryableGenericInterface(this Type type, out Type iorderedQueryableGenericInterface) + { + if (type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(IOrderedQueryable<>)) + { + iorderedQueryableGenericInterface = type; + return true; + } + + foreach (var interfaceType in type.GetInterfaces()) + { + if (interfaceType.IsConstructedGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IOrderedQueryable<>)) + { + iorderedQueryableGenericInterface = interfaceType; + return true; + } + } + + iorderedQueryableGenericInterface = null; + return false; + } + + public static bool TryGetIQueryableGenericInterface(this Type type, out Type iqueryableGenericInterface) + { + if (type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(IQueryable<>)) + { + iqueryableGenericInterface = type; + return true; + } + + foreach (var interfaceType in type.GetInterfaces()) + { + if (interfaceType.IsConstructedGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IQueryable<>)) + { + iqueryableGenericInterface = interfaceType; + return true; + } + } + + iqueryableGenericInterface = null; + return false; + } + } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs index fe96bacae36..a868f3a8508 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs @@ -41,7 +41,7 @@ internal class MongoQuery : MongoQuery, IOrderedQue public MongoQuery(MongoQueryProvider provider) { _provider = provider; - _expression = Expression.Constant(this); + _expression = Expression.Constant(this, typeof(IQueryable<>).MakeGenericType(typeof(TDocument))); } public MongoQuery(MongoQueryProvider provider, Expression expression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs index fd12fccd4c9..1e4e3b582c4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs @@ -22,14 +22,20 @@ internal static class BsonDocumentMethod { // private static fields private static readonly MethodInfo __addWithNameAndValue; + private static readonly MethodInfo __getItemWithIndex; + private static readonly MethodInfo __getItemWithName; // static constructor static BsonDocumentMethod() { __addWithNameAndValue = ReflectionInfo.Method((BsonDocument document, string name, BsonValue value) => document.Add(name, value)); + __getItemWithIndex = ReflectionInfo.Method((BsonDocument document, int index) => document[index]); + __getItemWithName = ReflectionInfo.Method((BsonDocument document, string name) => document[name]); } // public static properties public static MethodInfo AddWithNameAndValue => __addWithNameAndValue; + public static MethodInfo GetItemWithIndex => __getItemWithIndex; + public static MethodInfo GetItemWithName => __getItemWithName; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs index 0ae3e99ca4a..f0f01b79052 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs @@ -30,6 +30,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __aggregateWithSeedAndFunc; private static readonly MethodInfo __aggregateWithSeedFuncAndResultSelector; private static readonly MethodInfo __all; + private static readonly MethodInfo __allWithPredicate; private static readonly MethodInfo __any; private static readonly MethodInfo __anyWithPredicate; private static readonly MethodInfo __append; @@ -73,7 +74,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __firstOrDefault; private static readonly MethodInfo __firstOrDefaultWithPredicate; private static readonly MethodInfo __firstWithPredicate; - private static readonly MethodInfo __groupBy; + private static readonly MethodInfo __groupByWithKeySelector; private static readonly MethodInfo __groupByWithKeySelectorAndElementSelector; private static readonly MethodInfo __groupByWithKeySelectorAndResultSelector; private static readonly MethodInfo __groupByWithKeySelectorElementSelectorAndResultSelector; @@ -144,7 +145,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __repeat; private static readonly MethodInfo __reverse; private static readonly MethodInfo __select; - private static readonly MethodInfo __selectMany; + private static readonly MethodInfo __selectManyWithSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorAndResultSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorTakingIndexAndResultSelector; private static readonly MethodInfo __selectManyWithSelectorTakingIndex; @@ -197,6 +198,7 @@ static EnumerableMethod() __aggregateWithSeedAndFunc = ReflectionInfo.Method((IEnumerable source, object seed, Func func) => source.Aggregate(seed, func)); __aggregateWithSeedFuncAndResultSelector = ReflectionInfo.Method((IEnumerable source, object seed, Func func, Func resultSelector) => source.Aggregate(seed, func, resultSelector)); __all = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.All(predicate)); + __allWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.All(predicate)); __any = ReflectionInfo.Method((IEnumerable source) => source.Any()); __anyWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Any(predicate)); __append = ReflectionInfo.Method((IEnumerable source, object element) => source.Append(element)); @@ -240,7 +242,7 @@ static EnumerableMethod() __firstOrDefault = ReflectionInfo.Method((IEnumerable source) => source.FirstOrDefault()); __firstOrDefaultWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.FirstOrDefault(predicate)); __firstWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.First(predicate)); - __groupBy = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.GroupBy(keySelector)); + __groupByWithKeySelector = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.GroupBy(keySelector)); __groupByWithKeySelectorAndElementSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func elementSelector) => source.GroupBy(keySelector, elementSelector)); __groupByWithKeySelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func resultSelector) => source.GroupBy(keySelector, resultSelector)); __groupByWithKeySelectorElementSelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func elementSelector, Func, object> resultSelector) => source.GroupBy(keySelector, elementSelector, resultSelector)); @@ -311,7 +313,7 @@ static EnumerableMethod() __repeat = ReflectionInfo.Method((object element, int count) => Enumerable.Repeat(element, count)); __reverse = ReflectionInfo.Method((IEnumerable source) => source.Reverse()); __select = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Select(selector)); - __selectMany = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); + __selectManyWithSelector = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); __selectManyWithCollectionSelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func> collectionSelector, Func resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithCollectionSelectorTakingIndexAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func> collectionSelector, Func resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); @@ -363,6 +365,7 @@ static EnumerableMethod() public static MethodInfo AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc; public static MethodInfo AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector; public static MethodInfo All => __all; + public static MethodInfo AllWithPredicate => __allWithPredicate; public static MethodInfo Any => __any; public static MethodInfo AnyWithPredicate => __anyWithPredicate; public static MethodInfo Append => __append; @@ -406,7 +409,7 @@ static EnumerableMethod() public static MethodInfo FirstOrDefault => __firstOrDefault; public static MethodInfo FirstOrDefaultWithPredicate => __firstOrDefaultWithPredicate; public static MethodInfo FirstWithPredicate => __firstWithPredicate; - public static MethodInfo GroupBy => __groupBy; + public static MethodInfo GroupByWithKeySelector => __groupByWithKeySelector; public static MethodInfo GroupByWithKeySelectorAndElementSelector => __groupByWithKeySelectorAndElementSelector; public static MethodInfo GroupByWithKeySelectorAndResultSelector => __groupByWithKeySelectorAndResultSelector; public static MethodInfo GroupByWithKeySelectorElementSelectorAndResultSelector => __groupByWithKeySelectorElementSelectorAndResultSelector; @@ -477,7 +480,7 @@ static EnumerableMethod() public static MethodInfo Repeat => __repeat; public static MethodInfo Reverse => __reverse; public static MethodInfo Select => __select; - public static MethodInfo SelectMany => __selectMany; + public static MethodInfo SelectManyWithSelector => __selectManyWithSelector; public static MethodInfo SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector; public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector; public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs new file mode 100644 index 00000000000..3b062091f53 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class HashSetConstructor + { + public static bool IsWithCollectionConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(HashSet<>) && + parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var itemType) && + itemType == declaringType.GenericTypeArguments[0]; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs new file mode 100644 index 00000000000..5ffa9e126f2 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs @@ -0,0 +1,44 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class KeyValuePairConstructor + { + public static bool IsWithKeyAndValueConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>) && + declaringType.GetGenericArguments() is var typeParameters && + typeParameters[0] is var keyType && + typeParameters[1] is var valueType && + parameters.Length == 2 && + parameters[0].ParameterType == keyType && + parameters[1].ParameterType == valueType; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs new file mode 100644 index 00000000000..21c731c7ceb --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class ListConstructor + { + public static bool IsWithCollectionConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(List<>) && + parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var itemType) && + itemType == declaringType.GenericTypeArguments[0]; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs index c10550024c3..79aa21966e9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs @@ -65,6 +65,46 @@ internal static class MongoEnumerableMethod private static readonly MethodInfo __percentileNullableSingleWithSelector; private static readonly MethodInfo __percentileSingle; private static readonly MethodInfo __percentileSingleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationDecimal; + private static readonly MethodInfo __standardDeviationPopulationDecimalWithSelector; + private static readonly MethodInfo __standardDeviationPopulationDouble; + private static readonly MethodInfo __standardDeviationPopulationDoubleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationInt32; + private static readonly MethodInfo __standardDeviationPopulationInt32WithSelector; + private static readonly MethodInfo __standardDeviationPopulationInt64; + private static readonly MethodInfo __standardDeviationPopulationInt64WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableDecimal; + private static readonly MethodInfo __standardDeviationPopulationNullableDecimalWithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableDouble; + private static readonly MethodInfo __standardDeviationPopulationNullableDoubleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableInt32; + private static readonly MethodInfo __standardDeviationPopulationNullableInt32WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableInt64; + private static readonly MethodInfo __standardDeviationPopulationNullableInt64WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableSingle; + private static readonly MethodInfo __standardDeviationPopulationNullableSingleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationSingle; + private static readonly MethodInfo __standardDeviationPopulationSingleWithSelector; + private static readonly MethodInfo __standardDeviationSampleDecimal; + private static readonly MethodInfo __standardDeviationSampleDecimalWithSelector; + private static readonly MethodInfo __standardDeviationSampleDouble; + private static readonly MethodInfo __standardDeviationSampleDoubleWithSelector; + private static readonly MethodInfo __standardDeviationSampleInt32; + private static readonly MethodInfo __standardDeviationSampleInt32WithSelector; + private static readonly MethodInfo __standardDeviationSampleInt64; + private static readonly MethodInfo __standardDeviationSampleInt64WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableDecimal; + private static readonly MethodInfo __standardDeviationSampleNullableDecimalWithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableDouble; + private static readonly MethodInfo __standardDeviationSampleNullableDoubleWithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableInt32; + private static readonly MethodInfo __standardDeviationSampleNullableInt32WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableInt64; + private static readonly MethodInfo __standardDeviationSampleNullableInt64WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableSingle; + private static readonly MethodInfo __standardDeviationSampleNullableSingleWithSelector; + private static readonly MethodInfo __standardDeviationSampleSingle; + private static readonly MethodInfo __standardDeviationSampleSingleWithSelector; private static readonly MethodInfo __whereWithLimit; // static constructor @@ -113,6 +153,46 @@ static MongoEnumerableMethod() __percentileNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); __percentileSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); __percentileSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __standardDeviationPopulationDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationSampleDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); __whereWithLimit = ReflectionInfo.Method((IEnumerable source, Func predicate, int limit) => source.Where(predicate, limit)); } @@ -160,6 +240,46 @@ static MongoEnumerableMethod() public static MethodInfo PercentileNullableSingleWithSelector => __percentileNullableSingleWithSelector; public static MethodInfo PercentileSingle => __percentileSingle; public static MethodInfo PercentileSingleWithSelector => __percentileSingleWithSelector; + public static MethodInfo StandardDeviationPopulationDecimal => __standardDeviationPopulationDecimal; + public static MethodInfo StandardDeviationPopulationDecimalWithSelector => __standardDeviationPopulationDecimalWithSelector; + public static MethodInfo StandardDeviationPopulationDouble => __standardDeviationPopulationDouble; + public static MethodInfo StandardDeviationPopulationDoubleWithSelector => __standardDeviationPopulationDoubleWithSelector; + public static MethodInfo StandardDeviationPopulationInt32 => __standardDeviationPopulationInt32; + public static MethodInfo StandardDeviationPopulationInt32WithSelector => __standardDeviationPopulationInt32WithSelector; + public static MethodInfo StandardDeviationPopulationInt64 => __standardDeviationPopulationInt64; + public static MethodInfo StandardDeviationPopulationInt64WithSelector => __standardDeviationPopulationInt64WithSelector; + public static MethodInfo StandardDeviationPopulationNullableDecimal => __standardDeviationPopulationNullableDecimal; + public static MethodInfo StandardDeviationPopulationNullableDecimalWithSelector => __standardDeviationPopulationNullableDecimalWithSelector; + public static MethodInfo StandardDeviationPopulationNullableDouble => __standardDeviationPopulationNullableDouble; + public static MethodInfo StandardDeviationPopulationNullableDoubleWithSelector => __standardDeviationPopulationNullableDoubleWithSelector; + public static MethodInfo StandardDeviationPopulationNullableInt32 => __standardDeviationPopulationNullableInt32; + public static MethodInfo StandardDeviationPopulationNullableInt32WithSelector => __standardDeviationPopulationNullableInt32WithSelector; + public static MethodInfo StandardDeviationPopulationNullableInt64 => __standardDeviationPopulationNullableInt64; + public static MethodInfo StandardDeviationPopulationNullableInt64WithSelector => __standardDeviationPopulationNullableInt64WithSelector; + public static MethodInfo StandardDeviationPopulationNullableSingle => __standardDeviationPopulationNullableSingle; + public static MethodInfo StandardDeviationPopulationNullableSingleWithSelector => __standardDeviationPopulationNullableSingleWithSelector; + public static MethodInfo StandardDeviationPopulationSingle => __standardDeviationPopulationSingle; + public static MethodInfo StandardDeviationPopulationSingleWithSelector => __standardDeviationPopulationSingleWithSelector; + public static MethodInfo StandardDeviationSampleDecimal => __standardDeviationSampleDecimal; + public static MethodInfo StandardDeviationSampleDecimalWithSelector => __standardDeviationSampleDecimalWithSelector; + public static MethodInfo StandardDeviationSampleDouble => __standardDeviationSampleDouble; + public static MethodInfo StandardDeviationSampleDoubleWithSelector => __standardDeviationSampleDoubleWithSelector; + public static MethodInfo StandardDeviationSampleInt32 => __standardDeviationSampleInt32; + public static MethodInfo StandardDeviationSampleInt32WithSelector => __standardDeviationSampleInt32WithSelector; + public static MethodInfo StandardDeviationSampleInt64 => __standardDeviationSampleInt64; + public static MethodInfo StandardDeviationSampleInt64WithSelector => __standardDeviationSampleInt64WithSelector; + public static MethodInfo StandardDeviationSampleNullableDecimal => __standardDeviationSampleNullableDecimal; + public static MethodInfo StandardDeviationSampleNullableDecimalWithSelector => __standardDeviationSampleNullableDecimalWithSelector; + public static MethodInfo StandardDeviationSampleNullableDouble => __standardDeviationSampleNullableDouble; + public static MethodInfo StandardDeviationSampleNullableDoubleWithSelector => __standardDeviationSampleNullableDoubleWithSelector; + public static MethodInfo StandardDeviationSampleNullableInt32 => __standardDeviationSampleNullableInt32; + public static MethodInfo StandardDeviationSampleNullableInt32WithSelector => __standardDeviationSampleNullableInt32WithSelector; + public static MethodInfo StandardDeviationSampleNullableInt64 => __standardDeviationSampleNullableInt64; + public static MethodInfo StandardDeviationSampleNullableInt64WithSelector => __standardDeviationSampleNullableInt64WithSelector; + public static MethodInfo StandardDeviationSampleNullableSingle => __standardDeviationSampleNullableSingle; + public static MethodInfo StandardDeviationSampleNullableSingleWithSelector => __standardDeviationSampleNullableSingleWithSelector; + public static MethodInfo StandardDeviationSampleSingle => __standardDeviationSampleSingle; + public static MethodInfo StandardDeviationSampleSingleWithSelector => __standardDeviationSampleSingleWithSelector; public static MethodInfo WhereWithLimit => __whereWithLimit; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs index 17896da1313..b027077ff7d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs @@ -28,6 +28,7 @@ internal static class QueryableMethod private static readonly MethodInfo __aggregateWithSeedAndFunc; private static readonly MethodInfo __aggregateWithSeedFuncAndResultSelector; private static readonly MethodInfo __all; + private static readonly MethodInfo __allWithPredicate; private static readonly MethodInfo __any; private static readonly MethodInfo __anyWithPredicate; private static readonly MethodInfo __append; @@ -90,7 +91,7 @@ internal static class QueryableMethod private static readonly MethodInfo __prepend; private static readonly MethodInfo __reverse; private static readonly MethodInfo __select; - private static readonly MethodInfo __selectMany; + private static readonly MethodInfo __selectManyWithSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorAndResultSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorTakingIndexAndResultSelector; private static readonly MethodInfo __selectManyWithSelectorTakingIndex; @@ -138,6 +139,7 @@ static QueryableMethod() __aggregateWithSeedAndFunc = ReflectionInfo.Method((IQueryable source, object seed, Expression> func) => source.Aggregate(seed, func)); __aggregateWithSeedFuncAndResultSelector = ReflectionInfo.Method((IQueryable source, object seed, Expression> func, Expression> selector) => source.Aggregate(seed, func, selector)); __all = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.All(predicate)); + __allWithPredicate = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.All(predicate)); __any = ReflectionInfo.Method((IQueryable source) => source.Any()); __anyWithPredicate = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.Any(predicate)); __append = ReflectionInfo.Method((IQueryable source, object element) => source.Append(element)); @@ -200,7 +202,7 @@ static QueryableMethod() __prepend = ReflectionInfo.Method((IQueryable source, object element) => source.Prepend(element)); __reverse = ReflectionInfo.Method((IQueryable source) => source.Reverse()); __select = ReflectionInfo.Method((IQueryable source, Expression> selector) => source.Select(selector)); - __selectMany = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); + __selectManyWithSelector = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); __selectManyWithCollectionSelectorAndResultSelector = ReflectionInfo.Method((IQueryable source, Expression>> collectionSelector, Expression> resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithCollectionSelectorTakingIndexAndResultSelector = ReflectionInfo.Method((IQueryable source, Expression>> collectionSelector, Expression> resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); @@ -247,6 +249,7 @@ static QueryableMethod() public static MethodInfo AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc; public static MethodInfo AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector; public static MethodInfo All => __all; + public static MethodInfo AllWithPredicate => __allWithPredicate; public static MethodInfo Any => __any; public static MethodInfo AnyWithPredicate => __anyWithPredicate; public static MethodInfo Append => __append; @@ -291,7 +294,7 @@ static QueryableMethod() public static MethodInfo GroupByWithKeySelectorAndResultSelector => __groupByWithKeySelectorAndResultSelector; public static MethodInfo GroupByWithKeySelectorElementSelectorAndResultSelector => __groupByWithKeySelectorElementSelectorAndResultSelector; public static MethodInfo GroupJoin => __groupJoin; - public static MethodInfo Interset => __intersect; + public static MethodInfo Intersect => __intersect; public static MethodInfo Join => __join; public static MethodInfo Last => __last; public static MethodInfo LastOrDefault => __lastOrDefault; @@ -309,7 +312,7 @@ static QueryableMethod() public static MethodInfo Prepend => __prepend; public static MethodInfo Reverse => __reverse; public static MethodInfo Select => __select; - public static MethodInfo SelectMany => __selectMany; + public static MethodInfo SelectManyWithSelector => __selectManyWithSelector; public static MethodInfo SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector; public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector; public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs index 279e382733c..c1c113a20c9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs @@ -44,6 +44,7 @@ internal static class StringMethod private static readonly MethodInfo __endsWithWithString; private static readonly MethodInfo __endsWithWithStringAndComparisonType; private static readonly MethodInfo __endsWithWithStringAndIgnoreCaseAndCulture; + private static readonly MethodInfo __equalsWithComparisonType; private static readonly MethodInfo __getChars; private static readonly MethodInfo __indexOfAny; private static readonly MethodInfo __indexOfAnyWithStartIndex; @@ -72,6 +73,7 @@ internal static class StringMethod private static readonly MethodInfo __startsWithWithString; private static readonly MethodInfo __startsWithWithStringAndComparisonType; private static readonly MethodInfo __startsWithWithStringAndIgnoreCaseAndCulture; + private static readonly MethodInfo __staticEqualsWithComparisonType; private static readonly MethodInfo __stringInWithEnumerable; private static readonly MethodInfo __stringInWithParams; private static readonly MethodInfo __stringNinWithEnumerable; @@ -124,6 +126,7 @@ static StringMethod() __endsWithWithString = ReflectionInfo.Method((string s, string value) => s.EndsWith(value)); __endsWithWithStringAndComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.EndsWith(value, comparisonType)); __endsWithWithStringAndIgnoreCaseAndCulture = ReflectionInfo.Method((string s, string value, bool ignoreCase, CultureInfo culture) => s.EndsWith(value, ignoreCase, culture)); + __equalsWithComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.Equals(value, comparisonType)); __getChars = ReflectionInfo.Method((string s, int index) => s[index]); __indexOfAny = ReflectionInfo.Method((string s, char[] anyOf) => s.IndexOfAny(anyOf)); __indexOfAnyWithStartIndex = ReflectionInfo.Method((string s, char[] anyOf, int startIndex) => s.IndexOfAny(anyOf, startIndex)); @@ -151,6 +154,7 @@ static StringMethod() __startsWithWithString = ReflectionInfo.Method((string s, string value) => s.StartsWith(value)); __startsWithWithStringAndComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.StartsWith(value, comparisonType)); __startsWithWithStringAndIgnoreCaseAndCulture = ReflectionInfo.Method((string s, string value, bool ignoreCase, CultureInfo culture) => s.StartsWith(value, ignoreCase, culture)); + __staticEqualsWithComparisonType = ReflectionInfo.Method((string a, string b, StringComparison comparisonType) => string.Equals(a, b, comparisonType)); __stringInWithEnumerable = ReflectionInfo.Method((string s, IEnumerable values) => s.StringIn(values)); __stringInWithParams = ReflectionInfo.Method((string s, StringOrRegularExpression[] values) => s.StringIn(values)); __stringNinWithEnumerable = ReflectionInfo.Method((string s, IEnumerable values) => s.StringNin(values)); @@ -192,6 +196,7 @@ static StringMethod() public static MethodInfo EndsWithWithString => __endsWithWithString; public static MethodInfo EndsWithWithStringAndComparisonType => __endsWithWithStringAndComparisonType; public static MethodInfo EndsWithWithStringAndIgnoreCaseAndCulture => __endsWithWithStringAndIgnoreCaseAndCulture; + public static MethodInfo EqualsWithComparisonType => __equalsWithComparisonType; public static MethodInfo GetChars => __getChars; public static MethodInfo IndexOfAny => __indexOfAny; public static MethodInfo IndexOfAnyWithStartIndex => __indexOfAnyWithStartIndex; @@ -220,6 +225,7 @@ static StringMethod() public static MethodInfo StartsWithWithString => __startsWithWithString; public static MethodInfo StartsWithWithStringAndComparisonType => __startsWithWithStringAndComparisonType; public static MethodInfo StartsWithWithStringAndIgnoreCaseAndCulture => __startsWithWithStringAndIgnoreCaseAndCulture; + public static MethodInfo StaticEqualsWithComparisonType => __staticEqualsWithComparisonType; public static MethodInfo StringInWithEnumerable => __stringInWithEnumerable; public static MethodInfo StringInWithParams => __stringInWithParams; public static MethodInfo StringNinWithEnumerable => __stringNinWithEnumerable; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs new file mode 100644 index 00000000000..16d5d36b504 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs @@ -0,0 +1,31 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class TupleOrValueTupleConstructor +{ + public static bool IsTupleOrValueTupleConstructor(ConstructorInfo constructor) + { + return + constructor != null && + constructor.DeclaringType is var declaringType && + declaringType.Namespace == "System" && + (declaringType.Name.StartsWith("Tuple") || declaringType.Name.StartsWith("ValueTuple")); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs similarity index 63% rename from src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs rename to src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs index 816e5fc237f..068c94581b7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs @@ -20,24 +20,24 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers { - internal interface IEnumUnderlyingTypeSerializer + internal interface IAsEnumUnderlyingTypeSerializer { IBsonSerializer EnumSerializer { get; } } - internal class EnumUnderlyingTypeSerializer : StructSerializerBase, IEnumUnderlyingTypeSerializer + internal class AsEnumUnderlyingTypeSerializer : StructSerializerBase, IAsEnumUnderlyingTypeSerializer where TEnum : Enum - where TEnumUnderlyingType : struct + where TUnderlyingType : struct { // private fields private readonly IBsonSerializer _enumSerializer; // constructors - public EnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) + public AsEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) { - if (typeof(TEnumUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) + if (typeof(TUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) { - throw new ArgumentException($"{typeof(TEnumUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); + throw new ArgumentException($"{typeof(TUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); } _enumSerializer = Ensure.IsNotNull(enumSerializer, nameof(enumSerializer)); } @@ -46,13 +46,13 @@ public EnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) public IBsonSerializer EnumSerializer => _enumSerializer; // explicitly implemented properties - IBsonSerializer IEnumUnderlyingTypeSerializer.EnumSerializer => EnumSerializer; + IBsonSerializer IAsEnumUnderlyingTypeSerializer.EnumSerializer => EnumSerializer; // public methods - public override TEnumUnderlyingType Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + public override TUnderlyingType Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) { var enumValue = _enumSerializer.Deserialize(context); - return (TEnumUnderlyingType)(object)enumValue; + return (TUnderlyingType)(object)enumValue; } /// @@ -62,28 +62,28 @@ public override bool Equals(object obj) if (object.ReferenceEquals(this, obj)) { return true; } return base.Equals(obj) && - obj is EnumUnderlyingTypeSerializer other && + obj is AsEnumUnderlyingTypeSerializer other && object.Equals(_enumSerializer, other._enumSerializer); } /// public override int GetHashCode() => _enumSerializer.GetHashCode(); - public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TEnumUnderlyingType value) + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TUnderlyingType value) { var enumValue = (TEnum)(object)value; _enumSerializer.Serialize(context, enumValue); } } - internal static class EnumUnderlyingTypeSerializer + internal static class AsEnumUnderlyingTypeSerializer { public static IBsonSerializer Create(IBsonSerializer enumSerializer) { var enumType = enumSerializer.ValueType; var underlyingType = Enum.GetUnderlyingType(enumType); - var enumUnderlyingTypeSerializerType = typeof(EnumUnderlyingTypeSerializer<,>).MakeGenericType(enumType, underlyingType); - return (IBsonSerializer)Activator.CreateInstance(enumUnderlyingTypeSerializerType, enumSerializer); + var toEnumUnderlyingTypeSerializerType = typeof(AsEnumUnderlyingTypeSerializer<,>).MakeGenericType(enumType, underlyingType); + return (IBsonSerializer)Activator.CreateInstance(toEnumUnderlyingTypeSerializerType, enumSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs new file mode 100644 index 00000000000..42992e91851 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs @@ -0,0 +1,88 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Core.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers +{ + internal interface IAsUnderlyingTypeEnumSerializer + { + IBsonSerializer UnderlyingTypeSerializer { get; } + } + + internal class AsUnderlyingTypeEnumSerializer : SerializerBase, IAsUnderlyingTypeEnumSerializer + where TEnum : Enum + where TUnderlyingType : struct + { + // private fields + private readonly IBsonSerializer _underlyingTypeSerializer; + + // constructors + public AsUnderlyingTypeEnumSerializer(IBsonSerializer underlyingTypeSerializer) + { + if (typeof(TUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) + { + throw new ArgumentException($"{typeof(TUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); + } + _underlyingTypeSerializer = Ensure.IsNotNull(underlyingTypeSerializer, nameof(underlyingTypeSerializer)); + } + + // public properties + public IBsonSerializer UnderlyingTypeSerializer => _underlyingTypeSerializer; + + // explicitly implemented properties + IBsonSerializer IAsUnderlyingTypeEnumSerializer.UnderlyingTypeSerializer => UnderlyingTypeSerializer; + + // public methods + public override TEnum Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var underlyingTypeValue = _underlyingTypeSerializer.Deserialize(context); + return (TEnum)(object)underlyingTypeValue; + } + + /// + public override bool Equals(object obj) + { + if (object.ReferenceEquals(obj, null)) { return false; } + if (object.ReferenceEquals(this, obj)) { return true; } + return + base.Equals(obj) && + obj is AsUnderlyingTypeEnumSerializer other && + object.Equals(_underlyingTypeSerializer, other._underlyingTypeSerializer); + } + + /// + public override int GetHashCode() => _underlyingTypeSerializer.GetHashCode(); + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TEnum value) + { + var underlyingTypeValue = (TUnderlyingType)(object)value; + _underlyingTypeSerializer.Serialize(context, underlyingTypeValue); + } + } + + internal static class AsUnderlyingTypeEnumSerializer + { + public static IBsonSerializer Create(Type enumType, IBsonSerializer underlyingTypeSerializer) + { + var underlyingType = Enum.GetUnderlyingType(enumType); + var toUnderlyingTypeEnumSerializerType = typeof(AsUnderlyingTypeEnumSerializer<,>).MakeGenericType(enumType, underlyingType); + return (IBsonSerializer)Activator.CreateInstance(toUnderlyingTypeEnumSerializerType, underlyingTypeSerializer); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs new file mode 100644 index 00000000000..46f892815aa --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs @@ -0,0 +1,43 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class DictionarySerializer +{ + public static IBsonSerializer Create(IBsonSerializer keySerializer, IBsonSerializer valueSerializer) + { + var serializerType = typeof(DictionarySerializer<,>).MakeGenericType(keySerializer.ValueType, valueSerializer.ValueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, keySerializer, valueSerializer); + } +} + +internal class DictionarySerializer : DictionaryInterfaceImplementerSerializer, TKey, TValue> +{ + public DictionarySerializer(IBsonSerializer keySerializer, IBsonSerializer valueSerializer) + : base(DictionaryRepresentation.Document, keySerializer, valueSerializer) + { + } + + protected override ICollection> CreateAccumulator() => new Dictionary(); + + protected override DictionaryFinalizeAccumulator(ICollection> accumulator) => (Dictionary)accumulator; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/FixedSizeArraySerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/FixedSizeArraySerializer.cs new file mode 100644 index 00000000000..909bfbfa2e3 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/FixedSizeArraySerializer.cs @@ -0,0 +1,98 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal interface IFixedSizeArraySerializer +{ + IBsonSerializer GetItemSerializer(int index); +} + +internal static class FixedSizeArraySerializer +{ + public static IBsonSerializer Create(Type itemType, IEnumerable itemSerializers) + { + var serializerType = typeof(FixedSizeArraySerializer<>).MakeGenericType(itemType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializers); + } +} + +internal sealed class FixedSizeArraySerializer : SerializerBase, IFixedSizeArraySerializer +{ + private readonly IReadOnlyList _itemSerializers; + + public FixedSizeArraySerializer(IEnumerable itemSerializers) + { + var itemSerializersArray = itemSerializers.ToArray(); + foreach (var itemSerializer in itemSerializersArray) + { + if (!typeof(TItem).IsAssignableFrom(itemSerializer.ValueType)) + { + throw new ArgumentException($"Serializer class {itemSerializer.ValueType} value type is not assignable to item type {typeof(TItem).Name}"); + } + } + + _itemSerializers = itemSerializersArray; + } + + public override TItem[] Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + var array = new TItem[_itemSerializers.Count]; + + reader.ReadStartArray(); + var i = 0; + while (reader.ReadBsonType() != BsonType.EndOfDocument) + { + if (i < array.Length) + { + array[i] = (TItem)_itemSerializers[i].Deserialize(context); + i++; + } + } + if (i != array.Length) + { + throw new BsonSerializationException($"Expected {array.Length} array items but found {i}."); + } + reader.ReadEndArray(); + + return array; + } + + IBsonSerializer IFixedSizeArraySerializer.GetItemSerializer(int index) => _itemSerializers[index]; + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TItem[] value) + { + if (value.Length != _itemSerializers.Count) + { + throw new BsonSerializationException($"Expected array value to have {_itemSerializers.Count} items but found {value.Length}."); + } + + var writer = context.Writer; + writer.WriteStartArray(); + for (var i = 0; i < value.Length; i++) + { + _itemSerializers[i].Serialize(context, args, value[i]); + } + writer.WriteEndArray(); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs new file mode 100644 index 00000000000..87a47747e5f --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs @@ -0,0 +1,42 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class HashSetSerializer +{ + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var serializerType = typeof(HashSetSerializer<>).MakeGenericType(itemSerializer.ValueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializer); + } +} + +internal class HashSetSerializer : EnumerableInterfaceImplementerSerializerBase, T> +{ + public HashSetSerializer(IBsonSerializer itemSerializer) + : base(itemSerializer) + { + } + + protected override object CreateAccumulator() => new HashSet(); + + protected override HashSet FinalizeResult(object accumulator) => (HashSet)accumulator; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IBsonSerializerExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IBsonSerializerExtensions.cs new file mode 100644 index 00000000000..e52a05b3c2b --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IBsonSerializerExtensions.cs @@ -0,0 +1,71 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Support; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IBsonSerializerExtensions +{ + public static IReadOnlyList GetMatchingMemberSerializationInfosForConstructorParameters( + this IBsonSerializer serializer, + Expression expression, + ConstructorInfo constructorInfo) + { + if (serializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer type {serializer.GetType().Name} does not implement IBsonDocumentSerializer"); + } + + var matchingMemberSerializationInfos = new List(); + foreach (var constructorParameter in constructorInfo.GetParameters()) + { + var matchingMemberSerializationInfo = GetMatchingMemberSerializationInfo(expression, documentSerializer, constructorParameter.Name); + matchingMemberSerializationInfos.Add(matchingMemberSerializationInfo); + } + + return matchingMemberSerializationInfos; + + static BsonSerializationInfo GetMatchingMemberSerializationInfo( + Expression expression, + IBsonDocumentSerializer documentSerializer, + string constructorParameterName) + { + var possibleMatchingMembers = documentSerializer.ValueType.GetMembers().Where(m => m.Name.Equals(constructorParameterName, StringComparison.OrdinalIgnoreCase)).ToArray(); + if (possibleMatchingMembers.Length == 0) + { + throw new ExpressionNotSupportedException(expression, because: $"no matching member found for constructor parameter: {constructorParameterName}"); + } + if (possibleMatchingMembers.Length > 1) + { + throw new ExpressionNotSupportedException(expression, because: $"multiple possible matching members found for constructor parameter: {constructorParameterName}"); + } + var matchingMemberName = possibleMatchingMembers[0].Name; + + if (!documentSerializer.TryGetMemberSerializationInfo(matchingMemberName, out var matchingMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer of type {documentSerializer.GetType().Name} did not provide serialization info for member {matchingMemberName}"); + } + + return matchingMemberSerializationInfo; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs new file mode 100644 index 00000000000..f03bf327711 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IEnumerableOrIQueryableSerializer +{ + public static IBsonSerializer Create(Type enumerableOrQueryableType, IBsonSerializer itemSerializer) + { + return enumerableOrQueryableType.ImplementsIQueryable(out _) ? + IQueryableSerializer.Create(itemSerializer) : + IEnumerableSerializer.Create(itemSerializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs new file mode 100644 index 00000000000..da44f92e218 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IOrderedEnumerableOrIOrderedQueryableSerializer +{ + public static IBsonSerializer Create(Type enumerableOrQueryableType, IBsonSerializer itemSerializer) + { + return enumerableOrQueryableType.ImplementsIOrderedQueryable(out _) ? + IOrderedQueryableSerializer.Create(itemSerializer) : + IOrderedEnumerableSerializer.Create(itemSerializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs index 2be9f49a1b3..b169febe181 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs @@ -24,7 +24,7 @@ internal interface ISetWindowFieldsPartitionSerializer IBsonSerializer InputSerializer { get; } } - internal class ISetWindowFieldsPartitionSerializer : IBsonSerializer>, ISetWindowFieldsPartitionSerializer + internal class ISetWindowFieldsPartitionSerializer : IBsonSerializer>, ISetWindowFieldsPartitionSerializer, IBsonArraySerializer { private readonly IBsonSerializer _inputSerializer; @@ -61,16 +61,20 @@ public void Serialize(BsonSerializationContext context, BsonSerializationArgs ar throw new InvalidOperationException("This serializer is not intended to be used."); } - public void Serialize(BsonSerializationContext context, BsonSerializationArgs args, object value) { throw new InvalidOperationException("This serializer is not intended to be used."); } - object IBsonSerializer.Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) { throw new InvalidOperationException("This serializer is not intended to be used."); } + + public bool TryGetItemSerializationInfo(out BsonSerializationInfo itemSerializationInfo) + { + itemSerializationInfo = new BsonSerializationInfo(null, _inputSerializer, _inputSerializer.ValueType); + return true; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs new file mode 100644 index 00000000000..23fb02f7db8 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs @@ -0,0 +1,33 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IgnoreNodeSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(IgnoreNodeSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal class IgnoreNodeSerializer : SerializerBase +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs new file mode 100644 index 00000000000..5476eb1e747 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs @@ -0,0 +1,37 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IgnoreSubtreeSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(IgnoreSubtreeSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal interface IIgnoreSubtreeSerializer +{ +} + +internal class IgnoreSubtreeSerializer : SerializerBase, IIgnoreSubtreeSerializer +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs new file mode 100644 index 00000000000..2a7044e7116 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs @@ -0,0 +1,42 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class ListSerializer +{ + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var serializerType = typeof(ListSerializer<>).MakeGenericType(itemSerializer.ValueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializer); + } +} + +internal class ListSerializer : EnumerableInterfaceImplementerSerializerBase, T> +{ + public ListSerializer(IBsonSerializer itemSerializer) + : base(itemSerializer) + { + } + + protected override object CreateAccumulator() => new List(); + + protected override List FinalizeResult(object accumulator) => (List)accumulator; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs new file mode 100644 index 00000000000..103d809afaf --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs @@ -0,0 +1,81 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class NumericConversionSerializer +{ + public static IBsonSerializer Create(Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var serializerType = typeof(NumericConversionSerializer<,>).MakeGenericType(sourceType, targetType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, sourceSerializer); + } +} + +internal class NumericConversionSerializer : SerializerBase, IHasRepresentationSerializer +{ + private readonly IBsonSerializer _sourceSerializer; + + public BsonType Representation + { + get + { + if (_sourceSerializer is not IHasRepresentationSerializer hasRepresentationSerializer) + { + throw new NotSupportedException($"Serializer class {_sourceSerializer.GetType().Name} does not implement IHasRepresentationSerializer."); + } + + return hasRepresentationSerializer.Representation; + } + } + + public NumericConversionSerializer(IBsonSerializer sourceSerializer) + { + _sourceSerializer = sourceSerializer; + } + + public override TTarget Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var sourceValue = _sourceSerializer.Deserialize(context); + return (TTarget)Convert(typeof(TSource), typeof(TTarget), sourceValue); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TTarget value) + { + var sourceValue = Convert(typeof(TTarget), typeof(TSource), value); + _sourceSerializer.Serialize(context, args, sourceValue); + } + + private object Convert(Type sourceType, Type targetType, object value) + { + return (Type.GetTypeCode(sourceType), Type.GetTypeCode(targetType)) switch + { + (TypeCode.Decimal, TypeCode.Double) => (object)(double)(decimal)value, + (TypeCode.Double, TypeCode.Decimal) => (object)(decimal)(double)value, + (TypeCode.Int16, TypeCode.Int32) => (object)(int)(short)value, + (TypeCode.Int16, TypeCode.Int64) => (object)(long)(short)value, + (TypeCode.Int32, TypeCode.Int16) => (object)(short)(int)value, + (TypeCode.Int32, TypeCode.Int64) => (object)(long)(int)value, + (TypeCode.Int64, TypeCode.Int16) => (object)(short)(long)value, + (TypeCode.Int64, TypeCode.Int32) => (object)(int)(long)value, + _ => throw new NotSupportedException($"Cannot convert {sourceType} to {targetType}."), + }; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs new file mode 100644 index 00000000000..84837a2d76c --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs @@ -0,0 +1,34 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class TupleOrValueTupleSerializer +{ + public static IBsonSerializer Create(Type type, IEnumerable itemSerializers) + { + return type.Name switch + { + _ when type.Name.StartsWith("Tuple") => TupleSerializer.Create(itemSerializers), + _ when type.Name.StartsWith("ValueTuple") => ValueTupleSerializer.Create(itemSerializers), + _ => throw new ArgumentException($"Unexpected type: {type.Name}", nameof(type)) + }; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs new file mode 100644 index 00000000000..e3e6583408b --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs @@ -0,0 +1,37 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class UnknowableSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(UnknowableSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal interface IUnknowableSerializer +{ +} + +internal class UnknowableSerializer : SerializerBase, IUnknowableSerializer +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs new file mode 100644 index 00000000000..e2843cb8602 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs @@ -0,0 +1,92 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers +{ + internal static class UpcastingSerializer + { + public static IBsonSerializer Create( + Type baseType, + Type derivedType, + IBsonSerializer baseTypeSerializer) + { + var upcastingSerializerType = typeof(UpcastingSerializer<,>).MakeGenericType(baseType, derivedType); + return (IBsonSerializer)Activator.CreateInstance(upcastingSerializerType, baseTypeSerializer); + } + } + + internal sealed class UpcastingSerializer : SerializerBase, IBsonArraySerializer, IBsonDocumentSerializer + where TDerived : TBase + { + private readonly IBsonSerializer _baseTypeSerializer; + + public UpcastingSerializer(IBsonSerializer baseTypeSerializer) + { + _baseTypeSerializer = baseTypeSerializer ?? throw new ArgumentNullException(nameof(baseTypeSerializer)); + } + + public Type BaseType => typeof(TBase); + + public IBsonSerializer BaseTypeSerializer => _baseTypeSerializer; + + public Type DerivedType => typeof(TDerived); + + public override TDerived Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + return (TDerived)_baseTypeSerializer.Deserialize(context); + } + + public override bool Equals(object obj) + { + if (object.ReferenceEquals(obj, null)) { return false; } + if (object.ReferenceEquals(this, obj)) { return true; } + return + base.Equals(obj) && + obj is UpcastingSerializer other && + object.Equals(_baseTypeSerializer, other._baseTypeSerializer); + } + + public override int GetHashCode() => 0; + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TDerived value) + { + _baseTypeSerializer.Serialize(context, value); + } + + public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo) + { + if (_baseTypeSerializer is not IBsonArraySerializer arraySerializer) + { + throw new NotSupportedException($"The class {_baseTypeSerializer.GetType().FullName} does not implement IBsonArraySerializer."); + } + + return arraySerializer.TryGetItemSerializationInfo(out serializationInfo); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + if (_baseTypeSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new NotSupportedException($"The class {_baseTypeSerializer.GetType().FullName} does not implement IBsonDocumentSerializer."); + } + + return documentSerializer.TryGetMemberSerializationInfo(memberName, out serializationInfo); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs index f3bb40aaf3a..c66f84b213e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs @@ -98,6 +98,20 @@ public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationI public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) { + if (_valueSerializer is IBsonDocumentSerializer documentSerializer) + { + if (documentSerializer.TryGetMemberSerializationInfo(memberName, out serializationInfo)) + { + serializationInfo = BsonSerializationInfo.CreateWithPath( + [_fieldName, serializationInfo.ElementName], + serializationInfo.Serializer, + serializationInfo.NominalType); + return true; + } + + return false; + } + throw new InvalidOperationException(); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs index 818a92fab7a..3462d1bcf3e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs @@ -14,8 +14,11 @@ */ using System.Linq.Expressions; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -30,7 +33,8 @@ public static TranslatedExpression Translate(TranslationContext context, BinaryE var indexExpression = expression.Right; var indexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, indexExpression); var ast = AstExpression.ArrayElemAt(arrayTranslation.Ast, indexTranslation.Ast); - var itemSerializer = ArraySerializerHelper.GetItemSerializer(arrayTranslation.Serializer); + var arraySerializer = arrayTranslation.Serializer; + var itemSerializer = arraySerializer.GetItemSerializer(indexExpression, arrayExpression); return new TranslatedExpression(expression, ast, itemSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs index 3881a1135a3..0b5e37352f0 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs @@ -18,11 +18,9 @@ using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; -using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators; -using MongoDB.Driver.Support; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs index 7487627213d..ef4c2d8a6d5 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs @@ -23,12 +23,11 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ConstantExpressionToAggregationExpressionTranslator { - public static TranslatedExpression Translate(ConstantExpression constantExpression) + public static TranslatedExpression Translate(TranslationContext context, ConstantExpression constantExpression) { - var constantType = constantExpression.Type; - var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + var constantSerializer = context.KnownSerializers.GetSerializer(constantExpression); return Translate(constantExpression, constantSerializer); - } + } public static TranslatedExpression Translate(ConstantExpression constantExpression, IBsonSerializer constantSerializer) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs index 532e10c1609..90cf9d8c45d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs @@ -214,7 +214,7 @@ private static TranslatedExpression TranslateConvertEnumToEnum(UnaryExpression e private static TranslatedExpression TranslateConvertEnumToUnderlyingType(UnaryExpression expression, Type sourceType, Type targetType, TranslatedExpression sourceTranslation) { var enumSerializer = sourceTranslation.Serializer; - var targetSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer); + var targetSerializer = AsEnumUnderlyingTypeSerializer.Create(enumSerializer); return new TranslatedExpression(expression, sourceTranslation.Ast, targetSerializer); } @@ -265,7 +265,7 @@ private static TranslatedExpression TranslateConvertUnderlyingTypeToEnum(UnaryEx var valueSerializer = sourceTranslation.Serializer; IBsonSerializer targetSerializer; - if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + if (valueSerializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) { targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index c2d8e0010e9..5eeb2857f9a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -67,7 +67,7 @@ public static TranslatedExpression Translate(TranslationContext context, Express case ExpressionType.Conditional: return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression); case ExpressionType.Constant: - return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression); + return ConstantExpressionToAggregationExpressionTranslator.Translate(context, (ConstantExpression)expression); case ExpressionType.Index: return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression); case ExpressionType.ListInit: diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index 20f7e81312c..b03a7135f8b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -13,16 +13,14 @@ * limitations under the License. */ -using System; using System.Collections.Generic; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -44,168 +42,63 @@ public static TranslatedExpression Translate( NewExpression newExpression, IReadOnlyList bindings) { + var knownSerializer = context.KnownSerializers.GetSerializer(expression); var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct var constructorArguments = newExpression.Arguments; - var computedFields = new List(); - var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); - if (constructorInfo != null && creatorMap != null) + var computedFields = new List(); + if (constructorInfo != null && constructorArguments.Count > 0) { - var constructorParameters = constructorInfo.GetParameters(); - var creatorMapParameters = creatorMap.Arguments?.ToArray(); - if (constructorParameters.Length > 0) + var matchingMemberSerializationInfos = knownSerializer.GetMatchingMemberSerializationInfosForConstructorParameters(expression, constructorInfo); + + for (var i = 0; i < constructorArguments.Count; i++) { - if (creatorMapParameters == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); - } - if (creatorMapParameters.Length != constructorParameters.Length) - { - throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); - } + var argument = constructorArguments[i]; + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argument); + var matchingMemberSerializationInfo = matchingMemberSerializationInfos[i]; - for (var i = 0; i < creatorMapParameters.Length; i++) + if (!argumentTranslation.Serializer.CanBeAssignedTo(matchingMemberSerializationInfo.Serializer)) { - var creatorMapParameter = creatorMapParameters[i]; - var constructorArgumentExpression = constructorArguments[i]; - var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression); - var constructorArgumentType = constructorArgumentExpression.Type; - var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType); - var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); - EnsureDefaultValue(memberMap); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast)); + throw new ExpressionNotSupportedException(argument, expression, because: "argument serializer is not equal to member serializer"); } - } - } - - foreach (var binding in bindings) - { - var memberAssignment = (MemberAssignment)binding; - var member = memberAssignment.Member; - var memberMap = FindMemberMap(expression, classMap, member.Name); - var valueExpression = memberAssignment.Expression; - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast)); - } - - var ast = AstExpression.ComputedDocument(computedFields); - classMap.Freeze(); - var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); - var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); - - return new TranslatedExpression(expression, ast, serializer); - } - - private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) - { - BsonClassMap baseClassMap = null; - if (classType.BaseType != null) - { - baseClassMap = CreateClassMap(classType.BaseType, null, out _); - } - - var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); - var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); - var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); - if (constructorInfo != null) - { - creatorMap = classMap.MapConstructor(constructorInfo); - } - else - { - creatorMap = null; - } - - classMap.AutoMap(); - classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here - - return classMap; - } - - private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) - { - var memberType = memberMap.MemberType; - var memberSerializer = memberMap.GetSerializer(); - var sourceType = sourceSerializer.ValueType; - if (memberType != sourceType && - memberType.ImplementsIEnumerable(out var memberItemType) && - sourceType.ImplementsIEnumerable(out var sourceItemType) && - sourceItemType == memberItemType && - sourceSerializer is IBsonArraySerializer sourceArraySerializer && - sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && - memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) - { - var sourceItemSerializer = sourceItemSerializationInfo.Serializer; - return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); - } - - return sourceSerializer; - } - - private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) - { - var declaringClassMap = classMap; - while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) - { - declaringClassMap = declaringClassMap.BaseClassMap; - - if (declaringClassMap == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + var computedField = AstExpression.ComputedField(matchingMemberSerializationInfo.ElementName, argumentTranslation.Ast); + computedFields.Add(computedField); } } - foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) + if (bindings.Count > 0) { - if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) + if (knownSerializer is not IBsonDocumentSerializer documentSerializer) { - return memberMap; + throw new ExpressionNotSupportedException(expression, because: $"serializer type {knownSerializer.GetType()} does not implement IBsonDocumentSerializer"); } - } - return declaringClassMap.MapMember(creatorMapParameter); + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; - static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) - { - var memberInfo = memberMap.MemberInfo; - return - memberInfo.MemberType == creatorMapParameter.MemberType && - memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); - } - } + if (!documentSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"member {member.Name} was not found"); + } - private static void EnsureDefaultValue(BsonMemberMap memberMap) - { - if (memberMap.IsDefaultValueSpecified) - { - return; - } + var valueExpression = memberAssignment.Expression; + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; - memberMap.SetDefaultValue(defaultValue); - } + if (!valueTranslation.Serializer.CanBeAssignedTo(memberSerializationInfo.Serializer)) + { + throw new ExpressionNotSupportedException(valueExpression, expression, because: $"value serializer is not equal to serializer for member {member.Name}"); + } - private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) - { - foreach (var memberMap in classMap.DeclaredMemberMaps) - { - if (memberMap.MemberName == memberName) - { - return memberMap; + var computedField = AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast); + computedFields.Add(computedField); } } - if (classMap.BaseClassMap != null) - { - return FindMemberMap(expression, classMap.BaseClassMap, memberName); - } - - throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + var ast = AstExpression.ComputedDocument(computedFields); + return new TranslatedExpression(expression, ast, knownSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs index 9f6844b3031..568a2187f73 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs @@ -42,8 +42,9 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var valueExpression = arguments[0]; var optionsExpression = arguments[1]; - var (toBsonType, toSerializer) = TranslateToType(expression, toType); + var toBsonType = GetResultRepresentation(expression, toType); var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var toSerializer = context.KnownSerializers.GetSerializer(expression); var (subType, byteOrder, format, onErrorAst, onNullAst) = TranslateOptions(context, expression, optionsExpression, toSerializer); var ast = AstExpression.Convert(valueTranslation.Ast, toBsonType.Render(), subType, byteOrder, format, onErrorAst, onNullAst); @@ -143,39 +144,39 @@ IBsonSerializer toSerializer return (subType, byteOrder, format, onErrorTranslation?.Ast, onNullTranslation?.Ast); } - private static (BsonType ToBsonType, IBsonSerializer ToSerializer) TranslateToType(Expression expression, Type toType) + private static BsonType GetResultRepresentation(Expression expression, Type toType) { var isNullable = toType.IsNullable(); var valueType = isNullable ? Nullable.GetUnderlyingType(toType) : toType; - var (bsonType, valueSerializer) = (ValueTuple)(Type.GetTypeCode(valueType) switch + var representation = Type.GetTypeCode(valueType) switch { - TypeCode.Boolean => (BsonType.Boolean, BooleanSerializer.Instance), - TypeCode.Byte => (BsonType.Int32, ByteSerializer.Instance), - TypeCode.Char => (BsonType.String, StringSerializer.Instance), - TypeCode.DateTime => (BsonType.DateTime, DateTimeSerializer.Instance), - TypeCode.Decimal => (BsonType.Decimal128, DecimalSerializer.Instance), - TypeCode.Double => (BsonType.Double, DoubleSerializer.Instance), - TypeCode.Int16 => (BsonType.Int32, Int16Serializer.Instance), - TypeCode.Int32 => (BsonType.Int32, Int32Serializer.Instance), - TypeCode.Int64 => (BsonType.Int64, Int64Serializer.Instance), - TypeCode.SByte => (BsonType.Int32, SByteSerializer.Instance), - TypeCode.Single => (BsonType.Double, SingleSerializer.Instance), - TypeCode.String => (BsonType.String, StringSerializer.Instance), - TypeCode.UInt16 => (BsonType.Int32, UInt16Serializer.Instance), - TypeCode.UInt32 => (BsonType.Int64, Int32Serializer.Instance), - TypeCode.UInt64 => (BsonType.Decimal128, UInt64Serializer.Instance), - - _ when valueType == typeof(byte[]) => (BsonType.Binary, ByteArraySerializer.Instance), - _ when valueType == typeof(BsonBinaryData) => (BsonType.Binary, BsonBinaryDataSerializer.Instance), - _ when valueType == typeof(Decimal128) => (BsonType.Decimal128, Decimal128Serializer.Instance), - _ when valueType == typeof(Guid) => (BsonType.Binary, GuidSerializer.StandardInstance), - _ when valueType == typeof(ObjectId) => (BsonType.ObjectId, ObjectIdSerializer.Instance), + TypeCode.Boolean => BsonType.Boolean, + TypeCode.Byte => BsonType.Int32, + TypeCode.Char => BsonType.String, + TypeCode.DateTime => BsonType.DateTime, + TypeCode.Decimal => BsonType.Decimal128, + TypeCode.Double => BsonType.Double, + TypeCode.Int16 => BsonType.Int32, + TypeCode.Int32 => BsonType.Int32, + TypeCode.Int64 => BsonType.Int64, + TypeCode.SByte => BsonType.Int32, + TypeCode.Single => BsonType.Double, + TypeCode.String => BsonType.String, + TypeCode.UInt16 => BsonType.Int32, + TypeCode.UInt32 => BsonType.Int64, + TypeCode.UInt64 => BsonType.Decimal128, + + _ when valueType == typeof(byte[]) => BsonType.Binary, + _ when valueType == typeof(BsonBinaryData) => BsonType.Binary, + _ when valueType == typeof(Decimal128) => BsonType.Decimal128, + _ when valueType == typeof(Guid) => BsonType.Binary, + _ when valueType == typeof(ObjectId) => BsonType.ObjectId, _ => throw new ExpressionNotSupportedException(expression, because: $"{toType} is not a valid TTo for Convert") - }); + }; - return (bsonType, isNullable ? NullableSerializer.Create(valueSerializer) : valueSerializer); + return representation; } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs index c5519f5547d..82086889cda 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs @@ -27,7 +27,7 @@ internal static class IntersectMethodToAggregationExpressionTranslator private static readonly MethodInfo[] __intersectMethods = { EnumerableMethod.Intersect, - QueryableMethod.Interset + QueryableMethod.Intersect }; public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs index 2eb826a8770..8c24213c01d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs @@ -26,8 +26,8 @@ internal static class SelectManyMethodToAggregationExpressionTranslator { private static readonly MethodInfo[] __selectManyMethods = { - EnumerableMethod.SelectMany, - QueryableMethod.SelectMany + EnumerableMethod.SelectManyWithSelector, + QueryableMethod.SelectManyWithSelector }; public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs index f95a2361fdc..263bd9ac6a8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs @@ -20,6 +20,7 @@ using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { @@ -37,10 +38,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var listItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var listItemType = listItemSerializer.ValueType; - var listType = typeof(List<>).MakeGenericType(listItemType); - var listSerializerType = typeof(EnumerableInterfaceImplementerSerializer<,>).MakeGenericType(listType, listItemType); - var listSerializer = (IBsonSerializer)Activator.CreateInstance(listSerializerType, listItemSerializer); + var listSerializer = ListSerializer.Create(listItemSerializer); return new TranslatedExpression(expression, sourceTranslation.Ast, listSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs index 250c8658210..81342854a49 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs @@ -70,7 +70,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC @as: predicateSymbol.Var.Name, limitTranslation?.Ast); - var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); + var resultSerializer = context.KnownSerializers.GetSerializer(expression); return new TranslatedExpression(expression, ast, resultSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs index c5eba340536..4bd724c4091 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs @@ -19,6 +19,8 @@ using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -27,28 +29,14 @@ internal static class NewArrayInitExpressionToAggregationExpressionTranslator public static TranslatedExpression Translate(TranslationContext context, NewArrayExpression expression) { var items = new List(); - IBsonSerializer itemSerializer = null; foreach (var itemExpression in expression.Expressions) { var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression); items.Add(itemTranslation.Ast); - itemSerializer ??= itemTranslation.Serializer; - - // make sure all items are serialized using the same serializer - if (!itemTranslation.Serializer.Equals(itemSerializer)) - { - throw new ExpressionNotSupportedException(expression, because: "all items in the array must be serialized using the same serializer"); - } } - var ast = AstExpression.ComputedArray(items); - var arrayType = expression.Type; - var itemType = arrayType.GetElementType(); - itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer will be null - var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); - var arraySerializer = (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); - + var arraySerializer = context.KnownSerializers.GetSerializer(expression); return new TranslatedExpression(expression, ast, arraySerializer); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs index aee174ac38d..af7b324c2f3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -39,34 +39,21 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression); var itemSerializer = ArraySerializerHelper.GetItemSerializer(collectionTranslation.Serializer); - IBsonSerializer keySerializer; - IBsonSerializer valueSerializer; AstExpression collectionTranslationAst; - if (itemSerializer is IBsonDocumentSerializer itemDocumentSerializer) + if (itemSerializer.IsKeyValuePairSerializer(out var keyElementName, out var valueElementName, out var keySerializer, out var valueSerializer)) { - if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo)) - { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Key member"); - } - keySerializer = keyMemberSerializationInfo.Serializer; - - if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo)) - { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Value member"); - } - valueSerializer = valueMemberSerializationInfo.Serializer; - - if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v") + if (keyElementName == "k" && valueElementName == "v") { collectionTranslationAst = collectionTranslation.Ast; } else { + // map keyElementName and valueElementName to "k" and "v" var pairVar = AstExpression.Var("pair"); var computedDocumentAst = AstExpression.ComputedDocument([ - AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)), - AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName)) + AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyElementName)), + AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueElementName)) ]); collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs index cfe4f67f6a8..5eaf71255b7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs @@ -13,13 +13,12 @@ * limitations under the License. */ -using System; using System.Collections.Generic; using System.Linq.Expressions; using MongoDB.Bson; -using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs index 692b3600ddd..486ae382721 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs @@ -24,6 +24,7 @@ public static TranslatedExpression Translate(TranslationContext context, UnaryEx { if (expression.NodeType == ExpressionType.Not) { + // TODO: check operand representation var operandTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, expression.Operand); var ast = expression.Type == typeof(bool) ? AstExpression.Not(operandTranslation.Ast) : AstExpression.BitNot(operandTranslation.Ast); return new TranslatedExpression(expression, ast, operandTranslation.Serializer); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index b96a193e323..255fc3de15e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -13,6 +13,8 @@ * limitations under the License. */ +using System; +using System.Linq; using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; @@ -31,7 +33,7 @@ public static ExecutableQuery> Translate TranslateScalar= 1) + { + var sourceParameter = parameters[0]; + var sourceParameterType = sourceParameter.ParameterType; + if (sourceParameterType.IsConstructedGenericType) + { + sourceParameterType = sourceParameterType.GetGenericTypeDefinition(); + } + + if (sourceParameterType == typeof(IQueryable) || + sourceParameterType == typeof(IQueryable<>) || + sourceParameterType == typeof(IOrderedQueryable) || + sourceParameterType == typeof(IOrderedQueryable<>)) + { + return GetUltimateSource(methodCallExpression.Arguments[0]); + } + } + + throw new ArgumentException($"No ultimate source found: {expression}."); } #endregion // private fields private readonly TranslationContextData _data; + private readonly KnownSerializerMap _knownSerializers; private readonly NameGenerator _nameGenerator; private readonly SymbolTable _symbolTable; private readonly ExpressionTranslationOptions _translationOptions; private TranslationContext( ExpressionTranslationOptions translationOptions, + KnownSerializerMap knownSerializers, TranslationContextData data, SymbolTable symbolTable, NameGenerator nameGenerator) { _translationOptions = translationOptions ?? new ExpressionTranslationOptions(); + _knownSerializers = Ensure.IsNotNull(knownSerializers, nameof(knownSerializers)); _data = data; // can be null _symbolTable = Ensure.IsNotNull(symbolTable, nameof(symbolTable)); _nameGenerator = Ensure.IsNotNull(nameGenerator, nameof(nameGenerator)); @@ -54,6 +140,7 @@ private TranslationContext( // public properties public TranslationContextData Data => _data; + public KnownSerializerMap KnownSerializers => _knownSerializers; public NameGenerator NameGenerator => _nameGenerator; public SymbolTable SymbolTable => _symbolTable; public ExpressionTranslationOptions TranslationOptions => _translationOptions; @@ -99,6 +186,11 @@ public Symbol CreateSymbolWithVarName(ParameterExpression parameter, string varN return CreateSymbol(parameter, name: parameterName, varName, serializer, isCurrent); } + public IBsonSerializer GetKnownSerializer(Expression parameter) + { + return _knownSerializers.GetSerializer(parameter); + } + public override string ToString() { return $"{{ SymbolTable : {_symbolTable} }}"; @@ -124,7 +216,7 @@ public TranslationContext WithSymbols(params Symbol[] newSymbols) public TranslationContext WithSymbolTable(SymbolTable symbolTable) { - return new TranslationContext(_translationOptions, _data, symbolTable, _nameGenerator); + return new TranslationContext(_translationOptions, _knownSerializers, _data, symbolTable, _nameGenerator); } } } diff --git a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs index 67ca25b4261..3c326b3fe6e 100644 --- a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs +++ b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs @@ -24,6 +24,7 @@ using MongoDB.Driver.Linq.Linq3Implementation; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages; +using MongoDB.Driver.Linq.Linq3Implementation.KnownSerializerFinders; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Translators; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators; @@ -61,7 +62,8 @@ internal static BsonValue TranslateExpressionToAggregateExpression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions, contextData); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: sourceSerializer, translationOptions: translationOptions, data: contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -76,7 +78,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( { expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var body = RemovePossibleConvertToObject(expression.Body); @@ -106,7 +108,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, expression.Body); @@ -125,8 +127,8 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: elementSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element context = context.WithSingleSymbol(symbol); // @ is the only symbol visible inside an $elemMatch var filter = ExpressionToFilterTranslator.Translate(context, expression.Body, exprOk: false); @@ -142,7 +144,8 @@ internal static BsonDocument TranslateExpressionToFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -176,7 +179,8 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec } expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var simplifier = forFind ? new AstFindProjectionSimplifier() : new AstSimplifier(); try @@ -215,8 +219,18 @@ internal static BsonDocument TranslateExpressionToSetStage( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - var context = TranslationContext.Create(translationOptions); // do not partially evaluate expression var parameter = expression.Parameters.Single(); + var body = expression.Body; + + var knownSerializers = new KnownSerializerMap(); + knownSerializers.AddSerializer(parameter, documentSerializer); + if (body.Type == typeof(TDocument)) + { + knownSerializers.AddSerializer(body, documentSerializer); + } + KnownSerializerFinder.FindKnownSerializers(expression, translationOptions, knownSerializers); + + var context = TranslationContext.Create(translationOptions, knownSerializers); // do not partially evaluate expression var symbol = context.CreateRootSymbol(parameter, documentSerializer); context = context.WithSymbol(symbol); var setStage = ExpressionToSetStageTranslator.Translate(context, documentSerializer, expression); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs index 035bba42f7e..fe4117ce430 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs @@ -17,6 +17,8 @@ using System.Collections.Generic; using System.Linq; using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.TestHelpers; using Xunit; @@ -79,7 +81,7 @@ public class C private class MyDTO { public DateTime timestamp { get; set; } - public decimal sqrt_calc { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal sqrt_calc { get; set; } } public sealed class ClassFixture : MongoCollectionFixture diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs index 524b72ff602..5177538f54c 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs @@ -43,6 +43,17 @@ from movieId in person.MovieIds join movie in movies.AsQueryable() on movieId equals movie.Id select new { person, movie }; + // equivalement method call syntax + // var queryable = people.AsQueryable() + // .SelectMany( + // person => person.MovieIds, + // (person, movieId) => new { person = person, movieId = movieId }) + // .Join( + // movies.AsQueryable(), + // transparentIdentifier => transparentIdentifier.movieId, + // movie => movie.Id, + // (transparentIdentifier, movie) => new { person = transparentIdentifier.person, movie = movie }); + var stages = Translate(people, queryable); AssertStages( stages, diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs new file mode 100644 index 00000000000..0224126f84d --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs @@ -0,0 +1,144 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4593Tests : LinqIntegrationTest +{ + public CSharp4593Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void First_example_should_work() + { + var collection = Fixture.Orders; + + var find = collection + .Find(o => o.RateBasisHistoryId == "abc") + .Project(r => r.Id); + + var translatedFilter = TranslateFindFilter(collection, find); + translatedFilter.Should().Be("{ RateBasisHistoryId : 'abc' }"); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ _id : 1 }"); + + var result = find.Single(); + result.Should().Be("a"); + } + + [Fact] + public void First_example_workaround_should_work() + { + var collection = Fixture.Orders; + + var find = collection + .Find(o => o.RateBasisHistoryId == "abc") + .Project(Builders.Projection.Include(o => o.Id)); + + var translatedFilter = TranslateFindFilter(collection, find); + translatedFilter.Should().Be("{ RateBasisHistoryId : 'abc' }"); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ _id : 1 }"); + + var result = find.Single(); + result["_id"].AsString.Should().Be("a"); + } + + [Fact] + public void Second_example_should_work() + { + var collection = Fixture.Entities; + var idsFilter = Builders.Filter.Eq(x => x.Id, 1); + + var aggregate = collection.Aggregate() + .Match(idsFilter) + .Project(e => new + { + _id = e.Id, + CampaignId = e.CampaignId, + Accepted = e.Status.Key == "Accepted" ? 1 : 0, + Rejected = e.Status.Key == "Rejected" ? 1 : 0, + }); + + var stages = Translate(collection, aggregate); + AssertStages( + stages, + "{ $match : { _id : 1 } }", + """ + { $project : + { + _id : "$_id", + CampaignId : "$CampaignId", + Accepted : { $cond : { if : { $eq : ["$Status.Key", "Accepted"] }, then : 1, else : 0 } }, + Rejected : { $cond : { if : { $eq : ["$Status.Key", "Rejected"] }, then : 1, else : 0 } } + } + } + """); + + var results = aggregate.ToList(); + results.Count.Should().Be(1); + results[0]._id.Should().Be(1); + results[0].CampaignId.Should().Be(11); + results[0].Accepted.Should().Be(1); + results[0].Rejected.Should().Be(0); + } + + public class Order + { + public string Id { get; set; } + public string RateBasisHistoryId { get; set; } + } + + public class Entity + { + public int Id { get; set; } + public int CampaignId { get; set; } + public Status Status { get; set; } + } + + public class Status + { + public string Key { get; set; } + } + + public sealed class ClassFixture : MongoDatabaseFixture + { + public IMongoCollection Orders { get; private set; } + public IMongoCollection Entities { get; private set; } + + protected override void InitializeFixture() + { + Orders = CreateCollection("orders"); + Orders.InsertMany( + [ + new Order { Id = "a", RateBasisHistoryId = "abc" } + ]); + + Entities = CreateCollection("entities"); + Entities.InsertMany( + [ + new Entity { Id = 1, CampaignId = 11, Status = new Status { Key = "Accepted" } }, + new Entity { Id = 2, CampaignId = 22, Status = new Status { Key = "Rejected" } } + ]); + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs index 2164f38e6a0..4225bef829f 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs @@ -355,7 +355,7 @@ public void Where_Document_item_with_int_using_call_to_get_item_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(int) }).GetGetMethod(), Expression.Constant(0)), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -379,7 +379,7 @@ public void Where_Document_item_with_int_using_MakeIndex_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(int) }), new Expression[] { Expression.Constant(0) }), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -418,7 +418,7 @@ public void Where_Document_item_with_string_using_call_to_get_item_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(string) }).GetGetMethod(), Expression.Constant("a")), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -442,7 +442,7 @@ public void Where_Document_item_with_string_using_MakeIndex_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(string) }), new Expression[] { Expression.Constant("a") }), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs new file mode 100644 index 00000000000..9f8f49eff4e --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs @@ -0,0 +1,68 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4819Tests : LinqIntegrationTest +{ + public CSharp4819Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void ReplaceWith_should_use_configured_element_name() + { + var collection = Fixture.Collection; + var stage = PipelineStageDefinitionBuilder + .ReplaceWith((User u) => new User { UserId = u.UserId }); + + var aggregate = collection.Aggregate() + .AppendStage(stage); + + var stages = Translate(collection, aggregate); + AssertStages( + stages, + "{ $replaceWith : { uuid : '$uuid' } }"); + + var result = aggregate.Single(); + result.Id.Should().Be(0); + result.UserId.Should().Be(Guid.Parse("00112233-4455-6677-8899-aabbccddeeff")); + } + + public class User + { + public int Id { get; set; } + [BsonElement("uuid")] + [BsonGuidRepresentation(GuidRepresentation.Standard)] + public Guid UserId { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new User { Id = 1, UserId = Guid.Parse("00112233-4455-6677-8899-aabbccddeeff") } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs new file mode 100644 index 00000000000..18be97f693c --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs @@ -0,0 +1,114 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4820Tests : LinqIntegrationTest +{ + public CSharp4820Tests(ClassFixture fixture) + : base(fixture) + { + } + + static CSharp4820Tests() + { + BsonClassMap.RegisterClassMap(cm => + { + cm.AutoMap(); + var readonlyCollectionMemberMap = cm.GetMemberMap(x => x.ReadOnlyCollection); + var readOnlyCollectionSerializer = readonlyCollectionMemberMap.GetSerializer(); + var bracketingCollectionSerializer = ((IChildSerializerConfigurable)readOnlyCollectionSerializer).WithChildSerializer(new StringBracketingSerializer()); + readonlyCollectionMemberMap.SetSerializer(bracketingCollectionSerializer); + }); + } + + [Fact] + public void Update_Set_with_List_should_work() + { + var values = new List() { "abc", "def" }; + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + [Fact] + public void Update_Set_with_Enumerable_should_throw() + { + var values = new[] { "abc", "def" }.Select(x => x); + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + [Fact] + public void Update_Set_with_Enumerable_ToList_should_work() + { + var values = new[] { "abc", "def" }.Select(x => x); + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values.ToList()); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + public class C + { + public int Id { get; set; } + public IReadOnlyCollection ReadOnlyCollection { get; set; } + } + + + private class StringBracketingSerializer : SerializerBase + { + public override string Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var bracketedValue = StringSerializer.Instance.Deserialize(context, args); + return bracketedValue.Substring(1, bracketedValue.Length - 2); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, string value) + { + var bracketedValue = "[" + value + "]"; + StringSerializer.Instance.Serialize(context, bracketedValue); + } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => null; + // [ + // new C { } + // ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs index 791ce3bcd75..e82194ef6cc 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs @@ -84,7 +84,7 @@ public void New_array_with_two_items_should_work() [Theory] [ParameterAttributeData] - public void New_array_with_two_items_with_different_serializers_should_throw( + public void New_array_with_two_items_with_different_serializers_should_work( [Values(false, true)] bool enableClientSideProjections) { RequireServer.Check().Supports(Feature.FindProjectionExpressions); @@ -94,21 +94,11 @@ public void New_array_with_two_items_with_different_serializers_should_throw( var queryable = collection.AsQueryable(translationOptions) .Select(x => new[] { x.X, x.Y }); - if (enableClientSideProjections) - { - var stages = Translate(collection, queryable, out var outputSerializer); - AssertStages(stages, "{ $project : { _snippets : ['$X', '$Y'], _id : 0 } }"); - outputSerializer.Should().BeAssignableTo(); - - var result = queryable.Single(); - result.Should().Equal(1, 2); - } - else - { - var exception = Record.Exception(() => Translate(collection, queryable)); - exception.Should().BeOfType(); - exception.Message.Should().Contain("all items in the array must be serialized using the same serializer"); - } + var stages = Translate(collection, queryable, out var outputSerializer); + AssertStages(stages, "{ $project : { _v : ['$X', '$Y'], _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(1, 2); } public class C diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs new file mode 100644 index 00000000000..a93e1b4f387 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs @@ -0,0 +1,75 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson.Serialization; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4967Tests : LinqIntegrationTest +{ + public CSharp4967Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Set_Nested_should_work() + { + var collection = Fixture.Collection; + var update = Builders.Update + .Pipeline(new EmptyPipelineDefinition() + .Set(c => new MyDocument + { + Nested = new MyNestedDocument + { + ValueCopy = c.Value, + }, + })); + + var renderedUpdate = update.Render(new(collection.DocumentSerializer, BsonSerializer.SerializerRegistry)).AsBsonArray; + renderedUpdate.Count.Should().Be(1); + renderedUpdate[0].Should().Be("{ $set : { Nested : { ValueCopy : '$Value' } } }"); + + collection.UpdateMany("{ }", update); + + var updatedDocument = collection.FindSync("{}").Single(); + updatedDocument.Nested.ValueCopy.Should().Be("Value"); + } + + public class MyDocument + { + public int Id { get; set; } + public string Value { get; set; } + public string AnotherValue { get; set; } + public MyNestedDocument Nested { get; set; } + } + + public class MyNestedDocument + { + public string ValueCopy { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new MyDocument { Id = 1, Value = "Value" } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs new file mode 100644 index 00000000000..d188c18fa1f --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -0,0 +1,225 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5435Tests : Linq3IntegrationTest + { + [Fact] + public void Test_set_ValueObject_Value_using_creator_map() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue(x.ValueObject == null ? 1 : x.ValueObject.Value + 1) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_Value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_to_derived_value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyDerivedValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1, + B = 42 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_X_using_constructor() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + X = new X(x.Y) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { X : { Y : '$Y' } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_A() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + A = new [] { 2, x.A[0] } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', 0] }] } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection.Database.GetCollection("test"), + BsonDocument.Parse("{ _id : 1 }"), + BsonDocument.Parse("{ _id : 2, X : null }"), + BsonDocument.Parse("{ _id : 3, X : 3 }")); + return collection; + } + + class MyDocument + { + [BsonRepresentation(MongoDB.Bson.BsonType.ObjectId)] + public string Id { get; set; } = ObjectId.GenerateNewId().ToString(); + + public MyValue ValueObject { get; set; } + + public long Long { get; set; } + + public X X { get; set; } + + public int Y { get; set; } + + [BsonRepresentation(BsonType.String)] + public int[] A { get; set; } + } + + class MyValue + { + [BsonConstructor] + public MyValue() { } + [BsonConstructor] + public MyValue(int value) { Value = value; } + public int Value { get; set; } + } + + class MyDerivedValue : MyValue + { + public int B { get; set; } + } + + [BsonSerializer(typeof(XSerializer))] + class X + { + public X(int y) + { + Y = y; + } + public int Y { get; } + } + + class XSerializer : SerializerBase, IBsonDocumentSerializer + { + public override X Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + reader.ReadStartArray(); + _ = reader.ReadName(); + var y = reader.ReadInt32(); + reader.ReadEndDocument(); + + return new X(y); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, X value) + { + var writer = context.Writer; + writer.WriteStartDocument(); + writer.WriteName("Y"); + writer.WriteInt32(value.Y); + writer.WriteEndDocument(); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + serializationInfo = memberName == "Y" ? new BsonSerializationInfo("Y", Int32Serializer.Instance, typeof(int)) : null; + return serializationInfo != null; + } + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs new file mode 100644 index 00000000000..30f3a73072a --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs @@ -0,0 +1,66 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5519Tests : LinqIntegrationTest +{ + public CSharp5519Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Array_constant_Any_should_serialize_array_correctly() + { + var collection = Fixture.Collection; + var array = new[] { E.A, E.B }; + + var find = collection.Find(x => array.Any(e => x.E == e)); + + var filter = TranslateFindFilter(collection, find); + filter.Should().Be("{ E : { $in : ['A', 'B'] } }"); + + var results = find.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + + public class C + { + public int Id { get; set; } + [BsonRepresentation(BsonType.String)] public E E { get; set; } + } + + public enum E { A, B, C } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C { Id = 1, E = E.A }, + new C { Id = 2, E = E.B }, + new C { Id = 3, E = E.C } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs new file mode 100644 index 00000000000..3e1c9e64e82 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs @@ -0,0 +1,189 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5532Tests : LinqIntegrationTest +{ + private static readonly ObjectId id1 = ObjectId.Parse("111111111111111111111111"); + private static readonly ObjectId id2 = ObjectId.Parse("222222222222222222222222"); + private static readonly ObjectId id3 = ObjectId.Parse("333333333333333333333333"); + + public CSharp5532Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Filter_should_translate_correctly() + { + var collection = Fixture.Collection; + List jobIds = [id2.ToString()]; + + var find = collection + .Find(x => x.Parts.Any(a => a.Refs.Any(b => jobIds.Contains(b.id)))); + + var filter = TranslateFindFilter(collection, find); + + filter.Should().Be("{ Parts : { $elemMatch : { Refs : { $elemMatch : { _id : { $in : [ObjectId('222222222222222222222222')] } } } } } }"); + } + + [Fact] + public void Projection_should_translate_correctly() + { + var collection = Fixture.Collection; + List jobIds = [id2.ToString()]; + + var find = collection + .Find("{}") + .Project(chain => + new + { + chain.Parts + .First(p => p.Refs.Any(j => jobIds.Contains(j.id))) + .Refs.First(j => jobIds.Contains(j.id)).id + });; + + var projectionTranslation = TranslateFindProjection(collection, find); + + projectionTranslation.Should().Be( + """ + { + _id : + { + $let : + { + vars : + { + this : + { + $arrayElemAt : + [ + { + $filter : + { + input : + { + $let : + { + vars : + { + this : + { + $arrayElemAt : + [ + { + $filter : + { + input : "$Parts", + as : "p", + cond : + { + $anyElementTrue : + { + $map : + { + input : "$$p.Refs", + as : "j", + in : { $in : ["$$j._id", [{ "$oid" : "222222222222222222222222" }]] } + } + } + } + } + }, + 0 + ] + } + }, + in : "$$this.Refs" + } + }, + as : "j", + cond : { $in : ['$$j._id', [{ "$oid" : "222222222222222222222222" }]] } + } + }, + 0 + ] + } + }, + in : "$$this._id" + } + } + } + """); + } + + public class Document + { + [BsonId] + [BsonRepresentation(BsonType.ObjectId)] + public string id { get; set; } + } + + public class Chain : Document + { + public ICollection Parts { get; set; } = new List(); + } + + public class Unit + { + public ICollection Refs { get; set; } + + public Unit() + { + Refs = new List(); + } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new Chain + { + id = "0102030405060708090a0b0c", + Parts = new List() + { + new() + { + Refs = new List() + { + new() + { + id = id1.ToString(), + }, + new() + { + id = id2.ToString(), + }, + new() + { + id = id3.ToString(), + }, + } + } + } + } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs similarity index 68% rename from tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs rename to tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs index f6f7ace6d48..de10de29a04 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs @@ -22,7 +22,7 @@ namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Serializers { - public class EnumUnderlyingTypeSerializerTests + public class AsEnumUnderlyingTypeSerializerTests { private static readonly IBsonSerializer __enumSerializer1 = new ESerializer1(); private static readonly IBsonSerializer __enumSerializer2 = new ESerializer2(); @@ -30,8 +30,8 @@ public class EnumUnderlyingTypeSerializerTests [Fact] public void Equals_derived_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new DerivedFromEnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new DerivedFromAsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(y); @@ -41,7 +41,7 @@ public void Equals_derived_should_return_false() [Fact] public void Equals_null_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(null); @@ -51,7 +51,7 @@ public void Equals_null_should_return_false() [Fact] public void Equals_object_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var y = new object(); var result = x.Equals(y); @@ -62,7 +62,7 @@ public void Equals_object_should_return_false() [Fact] public void Equals_self_should_return_true() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(x); @@ -72,8 +72,8 @@ public void Equals_self_should_return_true() [Fact] public void Equals_with_equal_fields_should_return_true() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(y); @@ -83,8 +83,8 @@ public void Equals_with_equal_fields_should_return_true() [Fact] public void Equals_with_not_equal_field_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new EnumUnderlyingTypeSerializer(__enumSerializer2); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new AsEnumUnderlyingTypeSerializer(__enumSerializer2); var result = x.Equals(y); @@ -94,18 +94,18 @@ public void Equals_with_not_equal_field_should_return_false() [Fact] public void GetHashCode_should_return_zero() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.GetHashCode(); result.Should().Be(0); } - internal class DerivedFromEnumUnderlyingTypeSerializer : EnumUnderlyingTypeSerializer + internal class DerivedFromAsEnumUnderlyingTypeSerializer : AsEnumUnderlyingTypeSerializer where TEnum : Enum where TEnumUnderlyingType : struct { - public DerivedFromEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) : base(enumSerializer) + public DerivedFromAsEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) : base(enumSerializer) { } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs index 10f3f2a5d14..08f071902ab 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs @@ -31,8 +31,8 @@ public class ModuloComparisonExpressionToFilterTranslatorTests [Fact] public void Translate_should_return_expected_result_with_byte_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Byte % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Byte % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -44,8 +44,8 @@ public void Translate_should_return_expected_result_with_byte_arguments() [Fact] public void Translate_should_return_expected_result_with_decimal_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Decimal % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Decimal % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -57,8 +57,8 @@ public void Translate_should_return_expected_result_with_decimal_arguments() [Fact] public void Translate_should_return_expected_result_with_double_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Double % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Double % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -70,8 +70,8 @@ public void Translate_should_return_expected_result_with_double_arguments() [Fact] public void Translate_should_return_expected_result_with_float_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Float % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Float % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -83,8 +83,8 @@ public void Translate_should_return_expected_result_with_float_arguments() [Fact] public void Translate_should_return_expected_result_with_int16_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int16 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int16 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -96,8 +96,8 @@ public void Translate_should_return_expected_result_with_int16_arguments() [Fact] public void Translate_should_return_expected_result_with_int32_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int32 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int32 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -109,8 +109,8 @@ public void Translate_should_return_expected_result_with_int32_arguments() [Fact] public void Translate_should_return_expected_result_with_int64_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int64 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int64 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -122,8 +122,8 @@ public void Translate_should_return_expected_result_with_int64_arguments() [Fact] public void Translate_should_return_expected_result_with_sbyte_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.SByte % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.SByte % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -135,8 +135,8 @@ public void Translate_should_return_expected_result_with_sbyte_arguments() [Fact] public void Translate_should_return_expected_result_with_uint16_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt16 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt16 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -148,8 +148,8 @@ public void Translate_should_return_expected_result_with_uint16_arguments() [Fact] public void Translate_should_return_expected_result_with_uint32_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt32 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt32 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -161,8 +161,8 @@ public void Translate_should_return_expected_result_with_uint32_arguments() [Fact] public void Translate_should_return_expected_result_with_uint64_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt64 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt64 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -180,19 +180,19 @@ private void Assert(AstFilter result, string path, BsonValue divisor, BsonValue modFilterOperation.Remainder.Should().Be(remainder); } - private TranslationContext CreateContext(ParameterExpression parameter) + private TranslationContext CreateContext(LambdaExpression lambda) { + var parameter = lambda.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(parameter.Type); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(lambda, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); return context.WithSymbol(symbol); } - private (ParameterExpression, BinaryExpression) CreateExpression(Expression> lambda) + private (LambdaExpression, BinaryExpression) CreateExpression(Expression> lambda) { - var parameter = lambda.Parameters.Single(); var expression = (BinaryExpression)lambda.Body; - return (parameter, expression); + return (lambda, expression); } private class C diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index a8f7428079b..361100e0240 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -641,7 +641,7 @@ private ProjectedResult Group(Expression Project(Expression var query = __collection.AsQueryable().Select(projector); var provider = (MongoQueryProvider)query.Provider; + var inputSerializer = (IBsonSerializer)provider.PipelineInputSerializer; + var serializerRegistry = provider.Collection.Settings.SerializerRegistry; var translationOptions = new ExpressionTranslationOptions { EnableClientSideProjections = false }; - var executableQuery = ExpressionToExecutableQueryTranslator.Translate(provider, query.Expression, translationOptions); - var projection = executableQuery.Pipeline.Ast.Stages.First().Render()["$project"].AsBsonDocument; + var renderedProjection = LinqProviderAdapter.TranslateExpressionToProjection( + projector, + inputSerializer, + serializerRegistry, + translationOptions); + + var projection = renderedProjection.Document; var value = query.Take(1).FirstOrDefault(); return new ProjectedResult diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs index cd49af1955f..1f3c8925d8e 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs @@ -1184,7 +1184,7 @@ private void Assert(Expression> expression, int var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(expression, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, expression.Body); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index 0869d70822e..df920d70e6a 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1152,9 +1152,9 @@ public List Assert(IMongoCollection collection, { filter = (Expression>)PartialEvaluator.EvaluatePartially(filter); - var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single(); - var context = TranslationContext.Create(translationOptions: null); + var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); + var context = TranslationContext.Create(filter, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, filter.Body);