diff --git a/.autover/changes/5bdfbea6-1f68-4d45-b380-4b705ed18fe2.json b/.autover/changes/5bdfbea6-1f68-4d45-b380-4b705ed18fe2.json new file mode 100644 index 00000000..ab7d1a54 --- /dev/null +++ b/.autover/changes/5bdfbea6-1f68-4d45-b380-4b705ed18fe2.json @@ -0,0 +1,11 @@ +{ + "Projects": [ + { + "Name": "AWS.Messaging", + "Type": "Minor", + "ChangelogMessages": [ + "Added subscriber middleware with optional error handler to override result or retry execution." + ] + } + ] +} \ No newline at end of file diff --git a/AWS.Messaging.lutconfig b/AWS.Messaging.lutconfig new file mode 100644 index 00000000..596a8603 --- /dev/null +++ b/AWS.Messaging.lutconfig @@ -0,0 +1,6 @@ + + + true + true + 180000 + \ No newline at end of file diff --git a/sampleapps/SubscriberService/Middleware/SampleMiddleware.cs b/sampleapps/SubscriberService/Middleware/SampleMiddleware.cs new file mode 100644 index 00000000..4596e23f --- /dev/null +++ b/sampleapps/SubscriberService/Middleware/SampleMiddleware.cs @@ -0,0 +1,16 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using AWS.Messaging; + +namespace SubscriberService.Middleware; + +public class SampleMiddleware : IMiddleware +{ + public Task InvokeAsync(MessageEnvelope messageEnvelope, RequestDelegate next, CancellationToken token = default) + { + // This middleware does not do anything, but exists to demonstrate how to implement a middleware + + return next(); + } +} diff --git a/sampleapps/SubscriberService/Program.cs b/sampleapps/SubscriberService/Program.cs index 4bb45ce2..3b3933e0 100644 --- a/sampleapps/SubscriberService/Program.cs +++ b/sampleapps/SubscriberService/Program.cs @@ -11,6 +11,7 @@ using OpenTelemetry.Resources; using OpenTelemetry.Trace; using SubscriberService.MessageHandlers; +using SubscriberService.Middleware; using SubscriberService.Models; await Host.CreateDefaultBuilder(args) @@ -34,6 +35,8 @@ await Host.CreateDefaultBuilder(args) if (string.IsNullOrEmpty(mpfQueueUrl)) throw new InvalidOperationException("Missing required configuration parameter 'AWS:Resources:MPFQueueUrl'."); + builder.AddMiddleware(); + builder.AddSQSPoller(mpfQueueUrl); builder.AddMessageHandler("chatMessage"); diff --git a/src/AWS.Messaging/Configuration/IMessageBusBuilder.cs b/src/AWS.Messaging/Configuration/IMessageBusBuilder.cs index b377f874..b48f4031 100644 --- a/src/AWS.Messaging/Configuration/IMessageBusBuilder.cs +++ b/src/AWS.Messaging/Configuration/IMessageBusBuilder.cs @@ -53,6 +53,17 @@ public interface IMessageBusBuilder IMessageBusBuilder AddMessageHandler<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] THandler, TMessage>(string? messageTypeIdentifier = null) where THandler : IMessageHandler; + /// + /// Adds a middleware to the subscriber message bus pipeline. + /// + /// + /// Middleware will be executed in the order in which it is added. + /// + /// The type that implements + /// The lifetime of the middleware. + IMessageBusBuilder AddMiddleware<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] TMiddleware>(ServiceLifetime serviceLifetime = ServiceLifetime.Singleton) + where TMiddleware : class, IMiddleware; + /// /// Adds an SQS queue to poll for messages. /// diff --git a/src/AWS.Messaging/Configuration/IMessageConfiguration.cs b/src/AWS.Messaging/Configuration/IMessageConfiguration.cs index 3e34389c..415004b7 100644 --- a/src/AWS.Messaging/Configuration/IMessageConfiguration.cs +++ b/src/AWS.Messaging/Configuration/IMessageConfiguration.cs @@ -46,6 +46,11 @@ public interface IMessageConfiguration /// The containing routing info for the specified message type. SubscriberMapping? GetSubscriberMapping(string messageTypeIdentifier); + /// + /// Maps the middleware types to be used in order of execution. + /// + IList SubscriberMiddleware { get; } + /// /// List of configurations for subscriber to poll for messages from an AWS service endpoint. /// diff --git a/src/AWS.Messaging/Configuration/MessageBusBuilder.cs b/src/AWS.Messaging/Configuration/MessageBusBuilder.cs index a87a6cca..ce4c50f8 100644 --- a/src/AWS.Messaging/Configuration/MessageBusBuilder.cs +++ b/src/AWS.Messaging/Configuration/MessageBusBuilder.cs @@ -2,6 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; +using System.Reflection; using AWS.Messaging.Configuration.Internal; using AWS.Messaging.Publishers; using AWS.Messaging.Publishers.EventBridge; @@ -12,13 +15,11 @@ using AWS.Messaging.Services.Backoff; using AWS.Messaging.Services.Backoff.Policies; using AWS.Messaging.Telemetry; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; -using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; namespace AWS.Messaging.Configuration; @@ -99,16 +100,33 @@ private IMessageBusBuilder AddPublisher([DynamicallyAccessedMembers(DynamicallyA public IMessageBusBuilder AddMessageHandler<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] THandler, TMessage>(string? messageTypeIdentifier = null) where THandler : IMessageHandler { - return AddMessageHandler(typeof(THandler), typeof(TMessage), () => new MessageEnvelope(), messageTypeIdentifier); + var subscriberMapping = SubscriberMapping.Create(messageTypeIdentifier); + _messageConfiguration.SubscriberMappings.Add(subscriberMapping); + return this; } - private IMessageBusBuilder AddMessageHandler([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] Type handlerType, Type messageType, Func envelopeFactory, string? messageTypeIdentifier = null) + private IMessageBusBuilder AddMessageHandler([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] Type handlerType, Type messageType, Func envelopeFactory, HandlerInvokerDelegate handlerInvokerDelegate, string? messageTypeIdentifier = null) { - var subscriberMapping = new SubscriberMapping(handlerType, messageType, envelopeFactory, messageTypeIdentifier); + var subscriberMapping = new SubscriberMapping(handlerType, messageType, envelopeFactory, handlerInvokerDelegate, messageTypeIdentifier); _messageConfiguration.SubscriberMappings.Add(subscriberMapping); return this; } + public IMessageBusBuilder AddMessageErrorHandler<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] T>(ServiceLifetime serviceLifetime = ServiceLifetime.Singleton) + where T: IMessageErrorHandler + { + AddAdditionalService(new ServiceDescriptor(typeof(IMessageErrorHandler), typeof(T), serviceLifetime)); + return this; + } + + public IMessageBusBuilder AddMiddleware<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] TMiddleware>(ServiceLifetime serviceLifetime = ServiceLifetime.Singleton) + where TMiddleware : class, IMiddleware + { + var subscriberMiddleware = SubscriberMiddleware.Create(serviceLifetime); + _messageConfiguration.SubscriberMiddleware.Add(subscriberMiddleware); + return this; + } + /// public IMessageBusBuilder AddSQSPoller(string queueUrl, Action? options = null) { @@ -223,12 +241,13 @@ public IMessageBusBuilder LoadConfigurationFromSettings(IConfiguration configura var handlerType = GetTypeFromAssemblies(callingAssembly, messageHandler.HandlerType) ?? throw new InvalidAppSettingsConfigurationException($"Unable to find the provided message handler type '{messageHandler.HandlerType}'."); + var messageEnvelopeType = typeof(MessageEnvelope<>).MakeGenericType(messageType); + // This func is not Native AOT compatible but the method in general is marked // as not being Native AOT compatible due to loading dynamic types. So this // func not being Native AOT compatible is okay. MessageEnvelope envelopeFactory() { - var messageEnvelopeType = typeof(MessageEnvelope<>).MakeGenericType(messageType); var envelope = Activator.CreateInstance(messageEnvelopeType); if (envelope == null || envelope is not MessageEnvelope) { @@ -238,7 +257,8 @@ MessageEnvelope envelopeFactory() return (MessageEnvelope)envelope; } - AddMessageHandler(handlerType, messageType, envelopeFactory, messageHandler.MessageTypeIdentifier); + var handlerInoker = BuildHandlerInvoker(messageType, messageEnvelopeType); + AddMessageHandler(handlerType, messageType, envelopeFactory, handlerInoker, messageHandler.MessageTypeIdentifier); } } @@ -283,6 +303,29 @@ MessageEnvelope envelopeFactory() } return this; + + // This is not Native AOT compatible but the method in general is marked + // as not being Native AOT compatible due to loading dynamic types. So this + // func not being Native AOT compatible is okay. + static HandlerInvokerDelegate BuildHandlerInvoker(Type messageType, Type messageEnvelopeType) + { + var invokerParam = Expression.Parameter(typeof(HandlerInvoker), "invoker"); + var envelopeParam = Expression.Parameter(typeof(MessageEnvelope), "envelope"); + var mappingParam = Expression.Parameter(typeof(SubscriberMapping), "mapping"); + var tokenParam = Expression.Parameter(typeof(CancellationToken), "token"); + + // invoker.InvokeAsync( (MessageEnvelope) envelope, mapping, token ) + var genericMethodDef = typeof(HandlerInvoker) + .GetMethods(BindingFlags.Public | BindingFlags.Instance) + .First(m => m.Name == nameof(HandlerInvoker.InvokeAsync) && m.IsGenericMethodDefinition) + .GetGenericMethodDefinition(); + + var closedMethod = genericMethodDef.MakeGenericMethod(messageType); + var typedEnvelope = Expression.Convert(envelopeParam, messageEnvelopeType); + var call = Expression.Call(invokerParam, closedMethod, typedEnvelope, mappingParam, tokenParam); + var lambda = Expression.Lambda(call, invokerParam, envelopeParam, mappingParam, tokenParam); + return lambda.Compile(); + } } [RequiresUnreferencedCode("This method requires loading types dynamically as defined in the configuration system.")] @@ -378,6 +421,11 @@ internal void Build() { _serviceCollection.TryAddScoped(subscriberMapping.HandlerType); } + + foreach (var subscriberMiddleware in _messageConfiguration.SubscriberMiddleware) + { + _serviceCollection.TryAdd(new ServiceDescriptor(subscriberMiddleware.Type, subscriberMiddleware.Type, subscriberMiddleware.ServiceLifetime)); + } } if (_messageConfiguration.MessagePollerConfigurations.Any()) diff --git a/src/AWS.Messaging/Configuration/MessageConfiguration.cs b/src/AWS.Messaging/Configuration/MessageConfiguration.cs index db93d17f..4184c440 100644 --- a/src/AWS.Messaging/Configuration/MessageConfiguration.cs +++ b/src/AWS.Messaging/Configuration/MessageConfiguration.cs @@ -39,6 +39,9 @@ public class MessageConfiguration : IMessageConfiguration return subscriberMapping; } + /// + public IList SubscriberMiddleware { get; } = new List(); + /// public IList MessagePollerConfigurations { get; set; } = new List(); diff --git a/src/AWS.Messaging/Configuration/SubscriberMapping.cs b/src/AWS.Messaging/Configuration/SubscriberMapping.cs index 86af401a..63625a28 100644 --- a/src/AWS.Messaging/Configuration/SubscriberMapping.cs +++ b/src/AWS.Messaging/Configuration/SubscriberMapping.cs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 using System.Diagnostics.CodeAnalysis; +using AWS.Messaging.Services; namespace AWS.Messaging.Configuration; @@ -17,6 +18,9 @@ public class SubscriberMapping /// public Type MessageType { get; } + /// + public HandlerInvokerDelegate HandlerInvoker { get; } + /// public string MessageTypeIdentifier { get; } @@ -32,9 +36,10 @@ public class SubscriberMapping /// The type that implements /// The type that will be message data will deserialized into /// Func for creating + /// Delegate to invoke handler of . /// Optional message type identifier. If not set the full name of the is used. - internal SubscriberMapping([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] Type handlerType, Type messageType, Func envelopeFactory, string? messageTypeIdentifier = null) + internal SubscriberMapping([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] Type handlerType, Type messageType, Func envelopeFactory, HandlerInvokerDelegate handlerInvoker, string? messageTypeIdentifier = null) { HandlerType = handlerType; MessageType = messageType; @@ -44,6 +49,7 @@ internal SubscriberMapping([DynamicallyAccessedMembers(DynamicallyAccessedMember messageType.FullName ?? throw new InvalidMessageTypeException("Unable to retrieve the Full Name of the provided Message Type."); MessageEnvelopeFactory = envelopeFactory; + HandlerInvoker = handlerInvoker; } /// @@ -61,6 +67,13 @@ static MessageEnvelope envelopeFactory() return new MessageEnvelope(); } - return new SubscriberMapping(typeof(THandler), typeof(TMessage), envelopeFactory, messageTypeIdentifier); + static Task handlerInvoker(HandlerInvoker invoker, MessageEnvelope messageEnvelope, SubscriberMapping subscriberMapping, CancellationToken token = default) + { + return invoker.InvokeAsync((MessageEnvelope)messageEnvelope, subscriberMapping, token); + } + + return new SubscriberMapping(typeof(THandler), typeof(TMessage), envelopeFactory, handlerInvoker, messageTypeIdentifier); } } + +public delegate Task HandlerInvokerDelegate(HandlerInvoker invoker, MessageEnvelope messageEnvelope, SubscriberMapping subscriberMapping, CancellationToken token = default); diff --git a/src/AWS.Messaging/Configuration/SubscriberMiddleware.cs b/src/AWS.Messaging/Configuration/SubscriberMiddleware.cs new file mode 100644 index 00000000..100e8189 --- /dev/null +++ b/src/AWS.Messaging/Configuration/SubscriberMiddleware.cs @@ -0,0 +1,47 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.DependencyInjection; + +namespace AWS.Messaging.Configuration; + +/// +/// Tracks the to be processed by the implementation and its . +/// +public class SubscriberMiddleware +{ + /// + /// Constructs an instance of + /// + /// The type that implements . + /// The lifetime of the middleware. + internal SubscriberMiddleware([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] Type type, ServiceLifetime serviceLifetime) + { + Type = type; + ServiceLifetime = serviceLifetime; + } + + /// + /// Type that implements . + /// + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] + public Type Type { get; } + + /// + /// Service lifetime of the middleware. + /// + public ServiceLifetime ServiceLifetime { get; } + + /// + /// Creates a SubscriberMiddleware from the generic parameters for the middleware. + /// + /// The type that implements + /// The lifetime of the middleware. + /// + public static SubscriberMiddleware Create<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] TMiddleware>(ServiceLifetime serviceLifetime = ServiceLifetime.Singleton) + where TMiddleware : class, IMiddleware + { + return new SubscriberMiddleware(typeof(TMiddleware), serviceLifetime); + } +} diff --git a/src/AWS.Messaging/IMessageErrorHandler.cs b/src/AWS.Messaging/IMessageErrorHandler.cs new file mode 100644 index 00000000..f3fd380f --- /dev/null +++ b/src/AWS.Messaging/IMessageErrorHandler.cs @@ -0,0 +1,35 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +namespace AWS.Messaging; + +public interface IMessageErrorHandler +{ + /// + /// Handles errors that occur during message processing. + /// + /// The message being processed. + /// raised while processing message. + /// Number of attempts made at processing this message + /// + /// + public ValueTask OnHandleError(MessageEnvelope messageEnvelope, Exception exception, int attempts, CancellationToken token); +} + +public enum MessageErrorHandlerResponse +{ + /// + /// Failed response. + /// + Failed, + + /// + /// Retry the message processing in the same process. + /// + Retry, + + /// + /// Success response. + /// + Success +} diff --git a/src/AWS.Messaging/IMiddleware.cs b/src/AWS.Messaging/IMiddleware.cs new file mode 100644 index 00000000..07b56072 --- /dev/null +++ b/src/AWS.Messaging/IMiddleware.cs @@ -0,0 +1,30 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +namespace AWS.Messaging; + +/// +/// This interface is implemented by the users of this library for each layer of middleware that should be processed. +/// +public interface IMiddleware +{ + /// + /// Processes a message through the middleware pipeline, invoking the next middleware or the message handler. + /// + /// The message read from the message source wrapped around a message envelope containing message metadata. + /// Delegate to execute the next layer of middleware. When no further middleware remains, the delegate will execute the message handler. + /// The optional cancellation token. + /// + /// The status of the processed message. For example whether the message was successfully processed. + /// Default implementations should return the result returned from the next delegate. + /// + Task InvokeAsync(MessageEnvelope messageEnvelope, RequestDelegate next, CancellationToken token = default); +} + +/// +/// The delegate used to invoke the next middleware layer or the message handler. +/// +/// +/// The status of the processed message. For example whether the message was successfully processed. +/// +public delegate Task RequestDelegate(); diff --git a/src/AWS.Messaging/Services/HandlerInvoker.cs b/src/AWS.Messaging/Services/HandlerInvoker.cs index 3144a297..18e49c66 100644 --- a/src/AWS.Messaging/Services/HandlerInvoker.cs +++ b/src/AWS.Messaging/Services/HandlerInvoker.cs @@ -1,7 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -using System.Collections.Concurrent; using System.Reflection; using AWS.Messaging.Configuration; using AWS.Messaging.Telemetry; @@ -16,12 +15,7 @@ public class HandlerInvoker : IHandlerInvoker private readonly IServiceProvider _serviceProvider; private readonly ILogger _logger; private readonly ITelemetryFactory _telemetryFactory; - - /// - /// Caches the of the - /// method that will be invoked with the message envelope for each handler - /// - private readonly ConcurrentDictionary _handlerMethods = new(); + private readonly IMessageConfiguration _messageConfiguration; /// /// Constructs an instance of HandlerInvoker @@ -29,18 +23,27 @@ public class HandlerInvoker : IHandlerInvoker /// Service provider used to resolve handler objects /// Logger for debugging information /// Factory for telemetry data + /// Messaging configuration holding middleware configuration public HandlerInvoker( IServiceProvider serviceProvider, ILogger logger, - ITelemetryFactory telemetryFactory) + ITelemetryFactory telemetryFactory, + IMessageConfiguration messageConfiguration) { _serviceProvider = serviceProvider; _logger = logger; _telemetryFactory = telemetryFactory; + _messageConfiguration = messageConfiguration; + } + + public Task InvokeAsync(MessageEnvelope messageEnvelope, SubscriberMapping subscriberMapping, CancellationToken token = default) + { + // redirect to the generic version of InvokeAsync + return subscriberMapping.HandlerInvoker(this, messageEnvelope, subscriberMapping, token); } /// - public async Task InvokeAsync(MessageEnvelope messageEnvelope, SubscriberMapping subscriberMapping, CancellationToken token = default) + public async Task InvokeAsync(MessageEnvelope messageEnvelope, SubscriberMapping subscriberMapping, CancellationToken token = default) { using (var trace = _telemetryFactory.Trace("Processing message", messageEnvelope)) { @@ -54,67 +57,74 @@ public async Task InvokeAsync(MessageEnvelope messageEnvel trace.AddMetadata(TelemetryKeys.SqsMessageId, messageEnvelope.SQSMetadata.MessageID); } - await using (var scope = _serviceProvider.CreateAsyncScope()) + var attempt = 0; + while (true) { - object handler; + await using var scope = _serviceProvider.CreateAsyncScope(); try { - handler = scope.ServiceProvider.GetRequiredService(subscriberMapping.HandlerType); - } - catch (Exception e) - { - _logger.LogError("Unable to resolve a handler for {HandlerType} while handling message ID {MessageEnvelopeId}.", subscriberMapping.HandlerType, messageEnvelope.Id); - throw new InvalidMessageHandlerSignatureException($"Unable to resolve a handler for {subscriberMapping.HandlerType} " + - $"while handling message ID {messageEnvelope.Id}.", e); - } - var method = _handlerMethods.GetOrAdd(subscriberMapping.MessageType, x => - { - return subscriberMapping.HandlerType.GetMethod( // Look up the method on the handler type with: - nameof(IMessageHandler.HandleAsync), // name "HandleAsync" - new Type[] { messageEnvelope.GetType(), typeof(CancellationToken) }); // parameters (MessageEnvelope, CancellationToken) - }); + IMessageHandler handler; + try + { + handler = (IMessageHandler)scope.ServiceProvider.GetRequiredService(subscriberMapping.HandlerType); + } + catch (Exception ex) + { + _logger.LogError(ex, "Unable to resolve a handler for {HandlerType} while handling message ID {MessageEnvelopeId}.", subscriberMapping.HandlerType, messageEnvelope.Id); + throw new InvalidMessageHandlerSignatureException($"Unable to resolve a handler for {subscriberMapping.HandlerType} while handling message ID {messageEnvelope.Id}.", ex); + } - if (method == null) - { - _logger.LogError("Unable to resolve a compatible HandleAsync method for {HandlerType} while handling message ID {MessageEnvelopeId}.", subscriberMapping.HandlerType, messageEnvelope.Id); - throw new InvalidMessageHandlerSignatureException($"Unable to resolve a compatible HandleAsync method for {subscriberMapping.HandlerType} while handling message ID {messageEnvelope.Id}."); + var middlewares = _messageConfiguration.SubscriberMiddleware.Select(type => (IMiddleware)scope.ServiceProvider.GetRequiredService(type.Type)!).ToList(); + return await ExecutePipelineAsync(messageEnvelope, middlewares, handler, token).ConfigureAwait(false); } - - try + catch (Exception ex) when (ex is not InvalidMessageHandlerSignatureException) { - var task = method.Invoke(handler, new object[] { messageEnvelope, token }) as Task; - - if (task == null) + if (ex is TargetInvocationException targetInvocationException) { - _logger.LogError("Unexpected return type for the HandleAsync method on {HandlerType} while handling message ID {MessageEnvelopeId}. Expected {ExpectedType}", subscriberMapping.HandlerType, messageEnvelope.Id, nameof(Task)); - throw new InvalidMessageHandlerSignatureException($"Unexpected return type for the HandleAsync method on {subscriberMapping.HandlerType} while handling message ID {messageEnvelope.Id}. Expected {nameof(Task)}"); + // Since we are invoking HandleAsync via reflection, we need to unwrap the TargetInvocationException + // containing application exceptions that happened inside the IMessageHandler + if (targetInvocationException.InnerException != null) + { + ex = targetInvocationException.InnerException; + } + } + else + { + trace.AddException(ex, false); } - return await task; - } - // Since we are invoking HandleAsync via reflection, we need to unwrap the TargetInvocationException - // containing application exceptions that happened inside the IMessageHandler - catch (TargetInvocationException ex) - { - trace.AddException(ex, false); + _logger.LogError(ex, "An unexpected exception occurred while handling message ID {MessageId}.", messageEnvelope.Id); - if (ex.InnerException != null) + try { - _logger.LogError(ex.InnerException, "A handler exception occurred while handling message ID {MessageId}.", messageEnvelope.Id); - return MessageProcessStatus.Failed(); + var retryHandler = scope.ServiceProvider.GetService(); + if (retryHandler != null) + { + switch (await retryHandler!.OnHandleError(messageEnvelope, ex, ++attempt, token)) + { + case MessageErrorHandlerResponse.Failed: + _logger.LogError(ex, "An unexpected exception occurred while determining if message ID {MessageId} should be retried.", messageEnvelope.Id); + return MessageProcessStatus.Failed(); + + case MessageErrorHandlerResponse.Success: + return MessageProcessStatus.Success(); + + case MessageErrorHandlerResponse.Retry: + trace.AddMetadata(TelemetryKeys.Retry, attempt); + continue; + + default: + throw new NotImplementedException(); + } + } } - else + catch (Exception retryException) { - _logger.LogError(ex, "An unexpected exception occurred while handling message ID {MessageId}.", messageEnvelope.Id); + _logger.LogError(retryException, "An unexpected exception occurred while determining if message ID {MessageId} should be retried.", messageEnvelope.Id); return MessageProcessStatus.Failed(); } - } - catch (Exception ex) - { - trace.AddException(ex, false); - _logger.LogError(ex, "An unexpected exception occurred while handling message ID {MessageId}.", messageEnvelope.Id); return MessageProcessStatus.Failed(); } } @@ -126,4 +136,18 @@ public async Task InvokeAsync(MessageEnvelope messageEnvel } } } + + private static async Task ExecutePipelineAsync(MessageEnvelope messageEnvelope, List middlewares, IMessageHandler handler, CancellationToken token) + { + RequestDelegate next = () => handler.HandleAsync(messageEnvelope, token); + + for (var i = middlewares.Count - 1; i >= 0; i--) + { + var capturedNext = next; + var middleware = middlewares[i]; + next = () => middleware.InvokeAsync(messageEnvelope, capturedNext, token); + } + + return await next().ConfigureAwait(false); + } } diff --git a/src/AWS.Messaging/Telemetry/TelemetryKeys.cs b/src/AWS.Messaging/Telemetry/TelemetryKeys.cs index 315c49a4..05f2b715 100644 --- a/src/AWS.Messaging/Telemetry/TelemetryKeys.cs +++ b/src/AWS.Messaging/Telemetry/TelemetryKeys.cs @@ -24,4 +24,5 @@ public static class TelemetryKeys internal const string MessageId = "aws.messaging.messageId"; internal const string PublishTargetType = "aws.messaging.publishTargetType"; internal const string HandlerType = "aws.messaging.handlerType"; + internal const string Retry = "aws.messaging.retry"; } diff --git a/test/AWS.Messaging.Tests.Common/AWSUtilities.cs b/test/AWS.Messaging.Tests.Common/AWSUtilities.cs index 19e36ac3..0dcef8ca 100644 --- a/test/AWS.Messaging.Tests.Common/AWSUtilities.cs +++ b/test/AWS.Messaging.Tests.Common/AWSUtilities.cs @@ -118,7 +118,7 @@ public async static Task CreateBucketWithDeploymentZipAsync(IAmazonS3 s3Client, { // Create bucket if it doesn't exist var listBucketsResponse = await s3Client.ListBucketsAsync(); - if (listBucketsResponse.Buckets.Find((bucket) => bucket.BucketName == bucketName) == null) + if (listBucketsResponse.Buckets?.Find((bucket) => bucket.BucketName == bucketName) == null) { var putBucketRequest = new PutBucketRequest { diff --git a/test/AWS.Messaging.UnitTests/HandlerInvokerTests.cs b/test/AWS.Messaging.UnitTests/HandlerInvokerTests.cs index fb70b9f3..eec349a6 100644 --- a/test/AWS.Messaging.UnitTests/HandlerInvokerTests.cs +++ b/test/AWS.Messaging.UnitTests/HandlerInvokerTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using AWS.Messaging.Configuration; using AWS.Messaging.Services; @@ -39,7 +40,8 @@ public async Task HandlerInvoker_HappyPath() var handlerInvoker = new HandlerInvoker( serviceProvider, new NullLogger(), - new DefaultTelemetryFactory(serviceProvider)); + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); var envelope = new MessageEnvelope(); var subscriberMapping = SubscriberMapping.Create(); @@ -67,7 +69,8 @@ public async Task HandlerInvoker_DualHandler_InvokesCorrectMethod() var handlerInvoker = new HandlerInvoker( serviceProvider, new NullLogger(), - new DefaultTelemetryFactory(serviceProvider)); + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); // Assert that ChatMessage is routed to the right handler method, which always succeeds var chatEnvelope = new MessageEnvelope(); @@ -105,7 +108,8 @@ public async Task HandlerInvoker_UnwrapsTargetInvocationException() var handlerInvoker = new HandlerInvoker( serviceProvider, mockLogger.Object, - new DefaultTelemetryFactory(serviceProvider)); + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); var envelope = new MessageEnvelope() { Id = "123" @@ -114,7 +118,7 @@ public async Task HandlerInvoker_UnwrapsTargetInvocationException() await handlerInvoker.InvokeAsync(envelope, subscriberMapping); - mockLogger.VerifyLogError(typeof(CustomHandlerException), "A handler exception occurred while handling message ID 123."); + mockLogger.VerifyLogError(typeof(CustomHandlerException), "An unexpected exception occurred while handling message ID 123."); } /// @@ -137,7 +141,8 @@ public async Task HandlerInvoker_VerifyHandlersAreRetrievedAsScopedDependencies( var handlerInvoker = new HandlerInvoker( serviceProvider, new NullLogger(), - new DefaultTelemetryFactory(serviceProvider)); + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); // ACT and ASSERT - Invoke the GreetingHandler multiple times and verify that a new instance of IGreeter is created each time. var envelope = new MessageEnvelope(); @@ -174,7 +179,8 @@ public async Task HandlerInvoker_VerifyHandlersFatalErrorWhenDIFails() var handlerInvoker = new HandlerInvoker( serviceProvider, new NullLogger(), - new DefaultTelemetryFactory(serviceProvider)); + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); var envelope = new MessageEnvelope(); var subscriberMapping = SubscriberMapping.Create(); @@ -204,7 +210,8 @@ public async Task HandlerInvoker_ScopeDispose() var handlerInvoker = new HandlerInvoker( serviceProvider, new NullLogger(), - new DefaultTelemetryFactory(serviceProvider)); + new DefaultTelemetryFactory(serviceProvider), + serviceProvider.GetRequiredService()); var envelope = new MessageEnvelope(); var subscriberMapping = SubscriberMapping.Create(); @@ -213,4 +220,260 @@ public async Task HandlerInvoker_ScopeDispose() Assert.Equal(1, ChatMessageHandlerWithDisposableServices.TestDisposableService.CallCount); Assert.Equal(1, ChatMessageHandlerWithDisposableServices.TestDisposableServiceAsync.CallCount); } + + /// + /// Tests that middleware is executed in the order of registration. + /// + [Fact] + public async Task Middleware_IsExecutedInOrderOfRegistration() + { + var serviceCollection = new ServiceCollection() + .AddAWSMessageBus(builder => + { + builder.AddMessageHandler, ChatMessage>("sqsQueueUrl"); + + builder.AddMiddleware(); + builder.AddMiddleware(); + builder.AddMiddleware(); + }); + + var middlewareTracker = new SubscriberMiddlewareModels.MiddlewareTracker(); + serviceCollection.AddSingleton(middlewareTracker); + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var handlerInvoker = new HandlerInvoker( + serviceProvider, + new NullLogger(), + new DefaultTelemetryFactory(serviceProvider), + serviceProvider.GetRequiredService()); + + var envelope = new MessageEnvelope(); + var subscriberMapping = SubscriberMapping.Create, ChatMessage>(); + var messageProcessStatus = await handlerInvoker.InvokeAsync(envelope, subscriberMapping); + + Assert.Equal(MessageProcessStatus.Success(), messageProcessStatus); + + Assert.Equal(4, middlewareTracker.Executed.Count); + Assert.Equal(typeof(SubscriberMiddlewareModels.A), middlewareTracker.Executed[0]); + Assert.Equal(typeof(SubscriberMiddlewareModels.B), middlewareTracker.Executed[1]); + Assert.Equal(typeof(SubscriberMiddlewareModels.C), middlewareTracker.Executed[2]); + Assert.Equal(typeof(SubscriberMiddlewareModels.SuccessMessageHandler), middlewareTracker.Executed[3]); + } + + /// + /// Tests that middleware can propagate the message process status. + /// + [Fact] + public async Task Middleware_MessageProcessStatusIsPropagated() + { + var serviceCollection = new ServiceCollection() + .AddAWSMessageBus(builder => + { + builder.AddMessageHandler, ChatMessage>("sqsQueueUrl"); + + builder.AddMiddleware(); + builder.AddMiddleware(); + builder.AddMiddleware(); + }); + + var middlewareTracker = new SubscriberMiddlewareModels.MiddlewareTracker(); + serviceCollection.AddSingleton(middlewareTracker); + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var handlerInvoker = new HandlerInvoker( + serviceProvider, + new NullLogger(), + new DefaultTelemetryFactory(serviceProvider), + serviceProvider.GetRequiredService()); + + var envelope = new MessageEnvelope(); + var subscriberMapping = SubscriberMapping.Create, ChatMessage>(); + var messageProcessStatus = await handlerInvoker.InvokeAsync(envelope, subscriberMapping); + + Assert.Equal(MessageProcessStatus.Failed(), messageProcessStatus); + + Assert.Equal(4, middlewareTracker.Executed.Count); + Assert.Equal(typeof(SubscriberMiddlewareModels.A), middlewareTracker.Executed[0]); + Assert.Equal(typeof(SubscriberMiddlewareModels.B), middlewareTracker.Executed[1]); + Assert.Equal(typeof(SubscriberMiddlewareModels.C), middlewareTracker.Executed[2]); + Assert.Equal(typeof(SubscriberMiddlewareModels.FailMessageHandler), middlewareTracker.Executed[3]); + } + + /// + /// Tests that handlers that do not throw exceptions do not trigger the message error handler. + /// + [Fact] + public async Task MessageErrorHandler_HandlerWithoutException_DoesNotExecuteMessageErrorHandler() + { + var mockMessageErrorHandler = new Mock(); + + var serviceCollection = new ServiceCollection() + .AddAWSMessageBus(builder => + { + builder.AddMessageHandler("sqsQueueUrl"); + builder.AddAdditionalService(ServiceDescriptor.Singleton(mockMessageErrorHandler.Object)); + }); + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var handlerInvoker = new HandlerInvoker( + serviceProvider, + new NullLogger(), + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); + + var envelope = new MessageEnvelope(); + var subscriberMapping = SubscriberMapping.Create(); + var messageProcessStatus = await handlerInvoker.InvokeAsync(envelope, subscriberMapping); + + Assert.Equal(MessageProcessStatus.Success(), messageProcessStatus); + mockMessageErrorHandler.Verify(x => x.OnHandleError(It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); + } + + /// + /// Tests that when the message error handler returns a failed response, the handler does not retry and that response is returned by the invoker. + /// + [Fact] + public async Task MessageErrorHandler_WithFailed_DoesNotRetry() + { + var mockMessageErrorHandler = new Mock(); + mockMessageErrorHandler.Setup(x => x.OnHandleError(It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(ValueTask.FromResult(MessageErrorHandlerResponse.Failed)); + + var serviceCollection = new ServiceCollection() + .AddAWSMessageBus(builder => + { + builder.AddMessageHandler("sqsQueueUrl"); + builder.AddAdditionalService(ServiceDescriptor.Singleton(mockMessageErrorHandler.Object)); + }); + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var handlerInvoker = new HandlerInvoker( + serviceProvider, + new NullLogger(), + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); + + var envelope = new MessageEnvelope(); + var subscriberMapping = SubscriberMapping.Create(); + var messageProcessStatus = await handlerInvoker.InvokeAsync(envelope, subscriberMapping); + + Assert.Equal(MessageProcessStatus.Failed(), messageProcessStatus); + mockMessageErrorHandler.Verify(x => x.OnHandleError(It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); + } + + /// + /// Tests that when the message error handler returns a success response, the handler does not retry and that response is returned by the invoker. + /// + [Fact] + public async Task MessageErrorHandler_WithSuccess_DoesNotRetry() + { + var mockMessageErrorHandler = new Mock(); + mockMessageErrorHandler.Setup(x => x.OnHandleError(It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(ValueTask.FromResult(MessageErrorHandlerResponse.Success)); + + var serviceCollection = new ServiceCollection() + .AddAWSMessageBus(builder => + { + builder.AddMessageHandler("sqsQueueUrl"); + builder.AddAdditionalService(ServiceDescriptor.Singleton(mockMessageErrorHandler.Object)); + }); + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var handlerInvoker = new HandlerInvoker( + serviceProvider, + new NullLogger(), + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); + + var envelope = new MessageEnvelope(); + var subscriberMapping = SubscriberMapping.Create(); + var messageProcessStatus = await handlerInvoker.InvokeAsync(envelope, subscriberMapping); + + Assert.Equal(MessageProcessStatus.Success(), messageProcessStatus); + mockMessageErrorHandler.Verify(x => x.OnHandleError(It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); + } + + /// + /// Tests that when the message error handler returns a retry response, the pipeline is re-executed. + /// + [Fact] + public async Task MessageErrorHandler_WithRetry_ReExecutesPipeline() + { + var mockMessageErrorHandler = new Mock(); + mockMessageErrorHandler.SetupSequence(x => x.OnHandleError(It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(ValueTask.FromResult(MessageErrorHandlerResponse.Retry)) + .Returns(ValueTask.FromResult(MessageErrorHandlerResponse.Retry)) + .Returns(ValueTask.FromResult(MessageErrorHandlerResponse.Success)); + + var serviceCollection = new ServiceCollection() + .AddAWSMessageBus(builder => + { + builder.AddMessageHandler("sqsQueueUrl"); + builder.AddAdditionalService(ServiceDescriptor.Singleton(mockMessageErrorHandler.Object)); + }); + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var handlerInvoker = new HandlerInvoker( + serviceProvider, + new NullLogger(), + new DefaultTelemetryFactory(serviceProvider), + new MessageConfiguration()); + + var envelope = new MessageEnvelope(); + var subscriberMapping = SubscriberMapping.Create(); + var messageProcessStatus = await handlerInvoker.InvokeAsync(envelope, subscriberMapping); + + Assert.Equal(MessageProcessStatus.Success(), messageProcessStatus); + mockMessageErrorHandler.Verify(x => x.OnHandleError(It.IsAny>(), It.IsAny(), 1, It.IsAny()), Times.Once); + mockMessageErrorHandler.Verify(x => x.OnHandleError(It.IsAny>(), It.IsAny(), 2, It.IsAny()), Times.Once); + mockMessageErrorHandler.Verify(x => x.OnHandleError(It.IsAny>(), It.IsAny(), 3, It.IsAny()), Times.Once); + } + + /// + /// Tests that when the message error handler returns a retry response, the pipeline is executed in a new DI scope. + /// + [Fact] + public async Task MessageErrorHandler_Retry_UsesNewScope() + { + var mockMessageErrorHandler = new Mock(); + mockMessageErrorHandler.SetupSequence(x => x.OnHandleError(It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(ValueTask.FromResult(MessageErrorHandlerResponse.Retry)) + .Returns(ValueTask.FromResult(MessageErrorHandlerResponse.Retry)) + .Returns(ValueTask.FromResult(MessageErrorHandlerResponse.Success)); + + var serviceCollection = new ServiceCollection() + .AddAWSMessageBus(builder => + { + builder.AddMessageHandler(); + builder.AddAdditionalService(ServiceDescriptor.Singleton(mockMessageErrorHandler.Object)); + }); + + serviceCollection.AddScoped(); + serviceCollection.AddScoped(); + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + ChatExceptionHandlerAndDisposableServices.TestDisposableService.CallCount = 0; + ChatExceptionHandlerAndDisposableServices.TestDisposableServiceAsync.CallCount = 0; + + var handlerInvoker = new HandlerInvoker( + serviceProvider, + new NullLogger(), + new DefaultTelemetryFactory(serviceProvider), + serviceProvider.GetRequiredService()); + + var envelope = new MessageEnvelope(); + var subscriberMapping = SubscriberMapping.Create(); + await handlerInvoker.InvokeAsync(envelope, subscriberMapping); + + mockMessageErrorHandler.Verify(x => x.OnHandleError(It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Exactly(3)); + Assert.Equal(3, ChatExceptionHandlerAndDisposableServices.TestDisposableService.CallCount); + Assert.Equal(3, ChatExceptionHandlerAndDisposableServices.TestDisposableServiceAsync.CallCount); + } } diff --git a/test/AWS.Messaging.UnitTests/MessageHandlers/Handlers.cs b/test/AWS.Messaging.UnitTests/MessageHandlers/Handlers.cs index 717f97b0..609a9ba2 100644 --- a/test/AWS.Messaging.UnitTests/MessageHandlers/Handlers.cs +++ b/test/AWS.Messaging.UnitTests/MessageHandlers/Handlers.cs @@ -74,6 +74,37 @@ public void Dispose() } } +public class ChatExceptionHandlerAndDisposableServices : IMessageHandler +{ + public ChatExceptionHandlerAndDisposableServices(TestDisposableServiceAsync testDisposableServiceAsync, TestDisposableService testDisposable) + { + } + + public Task HandleAsync(MessageEnvelope messageEnvelope, CancellationToken token = default) + { + throw new CustomHandlerException($"Unable to process message {messageEnvelope.Id}"); + } + + public class TestDisposableServiceAsync : IAsyncDisposable + { + public static long CallCount { get; set; } = 0; + public ValueTask DisposeAsync() + { + CallCount++; + return ValueTask.CompletedTask; + } + } + + public class TestDisposableService : IDisposable + { + public static long CallCount { get; set; } = 0; + public void Dispose() + { + CallCount++; + } + } +} + public class PlainTextHandler : IMessageHandler { public Task HandleAsync(MessageEnvelope messageEnvelope, CancellationToken token = default) diff --git a/test/AWS.Messaging.UnitTests/Models/SubscriberMiddlewareModels.cs b/test/AWS.Messaging.UnitTests/Models/SubscriberMiddlewareModels.cs new file mode 100644 index 00000000..6ed12f52 --- /dev/null +++ b/test/AWS.Messaging.UnitTests/Models/SubscriberMiddlewareModels.cs @@ -0,0 +1,103 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace AWS.Messaging.UnitTests.Models; + +public static class SubscriberMiddlewareModels +{ + public class MiddlewareTracker + { + private readonly List _executed = []; + + public IReadOnlyList Executed => _executed.AsReadOnly(); + + public void Add(object middleware) + { + _executed.Add(middleware.GetType()); + } + } + + public class SuccessMessageHandler : IMessageHandler + { + private readonly MiddlewareTracker _tracker; + + public SuccessMessageHandler(MiddlewareTracker tracker) + { + _tracker = tracker; + } + + public Task HandleAsync(MessageEnvelope messageEnvelope, CancellationToken token = default) + { + _tracker.Add(this); + return Task.FromResult(MessageProcessStatus.Success()); + } + } + + public class FailMessageHandler : IMessageHandler + { + private readonly MiddlewareTracker _tracker; + + public FailMessageHandler(MiddlewareTracker tracker) + { + _tracker = tracker; + } + + public Task HandleAsync(MessageEnvelope messageEnvelope, CancellationToken token = default) + { + _tracker.Add(this); + return Task.FromResult(MessageProcessStatus.Failed()); + } + } + + public abstract class TrackedMiddleware : IMiddleware + { + private readonly MiddlewareTracker _tracker; + + protected TrackedMiddleware(MiddlewareTracker tracker) + { + _tracker = tracker; + } + + public virtual Task InvokeAsync(MessageEnvelope messageEnvelope, RequestDelegate next, CancellationToken cancellationToken = default) + { + _tracker.Add(this); + return next(); + } + } + + public class A : TrackedMiddleware + { + public A(MiddlewareTracker tracker) : base(tracker) { } + } + + public class B : TrackedMiddleware + { + public B(MiddlewareTracker tracker) : base(tracker) { } + } + + public class C : TrackedMiddleware + { + public C(MiddlewareTracker tracker) : base(tracker) { } + } + + public class Error : IMiddleware + { + private readonly MiddlewareTracker _tracker; + + public Error(MiddlewareTracker tracker) + { + _tracker = tracker; + } + + public Task InvokeAsync(MessageEnvelope messageEnvelope, RequestDelegate next, CancellationToken cancellationToken = default) + { + _tracker.Add(this); + throw new Exception("Error in middleware"); + } + } +} diff --git a/test/AWS.Messaging.UnitTests/OpenTelemetryTests.cs b/test/AWS.Messaging.UnitTests/OpenTelemetryTests.cs index 4ec20467..3f66f518 100644 --- a/test/AWS.Messaging.UnitTests/OpenTelemetryTests.cs +++ b/test/AWS.Messaging.UnitTests/OpenTelemetryTests.cs @@ -87,7 +87,8 @@ public OpenTelemetryTests() _handler = new HandlerInvoker( _serviceProvider, new NullLogger(), - new DefaultTelemetryFactory(_serviceProvider)); + new DefaultTelemetryFactory(_serviceProvider), + new MessageConfiguration()); } ///