Testing CommandHandler(s) #27

Open
mohammadtaherri opened this issue Jul 8, 2024 · 5 comments

Comments

@mohammadtaherri

mohammadtaherri commented Jul 8, 2024

Hi Vladimir.

How can we integration-test a CommandHandler when it is internal?

1) We can make it public (just for the sake of testing!).

2) Or we can test the controller's action methods. If we test the controller's methods, our tests cover Messages too (which is good), but on the other hand we have to work with JSON in the tests, and the output JSON format changes frequently. Besides that, Messages uses an IServiceProvider, and it is hard to create an instance of it in tests.

What do you suggest?

public sealed class EnrollCommand : ICommand
{
    public long Id { get; }
    public string Course { get; }
    public string Grade { get; }

    public EnrollCommand(long id, string course, string grade)
    {
        Id = id;
        Course = course;
        Grade = grade;
    }

    internal sealed class EnrollCommandHandler : ICommandHandler<EnrollCommand>
    {
        private readonly SessionFactory _sessionFactory;

        public EnrollCommandHandler(SessionFactory sessionFactory)
        {
            _sessionFactory = sessionFactory;
        }

        public Result Handle(EnrollCommand command)
        {
            var unitOfWork = new UnitOfWork(_sessionFactory);
            var courseRepository = new CourseRepository(unitOfWork);
            var studentRepository = new StudentRepository(unitOfWork);
            Student student = studentRepository.GetById(command.Id);
            if (student == null)
                return Result.Fail($"No student found with Id '{command.Id}'");

            Course course = courseRepository.GetByName(command.Course);
            if (course == null)
                return Result.Fail($"Course is incorrect: '{command.Course}'");

            bool success = Enum.TryParse(command.Grade, out Grade grade);
            if (!success)
                return Result.Fail($"Grade is incorrect: '{command.Grade}'");

            student.Enroll(course, grade);

            unitOfWork.Commit();

            return Result.Ok();
        }
    }
}

public sealed class StudentController : BaseController
{
    private readonly Messages _messages;

    public StudentController(Messages messages)
    {
        _messages = messages;
    }

    [HttpPost("{id}/enrollments")]
    public IActionResult Enroll(long id, [FromBody] StudentEnrollmentDto dto)
    {
        Result result = _messages.Dispatch(new EnrollCommand(id, dto.Course, dto.Grade));
        return FromResult(result);
    }
}

public sealed class Messages
{
    private readonly IServiceProvider _provider;

    public Messages(IServiceProvider provider)
    {
        _provider = provider;
    }

    public Result Dispatch(ICommand command)
    {
        Type type = typeof(ICommandHandler<>);
        Type[] typeArgs = { command.GetType() };
        Type handlerType = type.MakeGenericType(typeArgs);

        dynamic handler = _provider.GetService(handlerType);
        Result result = handler.Handle((dynamic)command);

        return result;
    }

    public T Dispatch<T>(IQuery<T> query)
    {
        Type type = typeof(IQueryHandler<,>);
        Type[] typeArgs = { query.GetType(), typeof(T) };
        Type handlerType = type.MakeGenericType(typeArgs);

        dynamic handler = _provider.GetService(handlerType);
        T result = handler.Handle((dynamic)query);

        return result;
    }
}
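On the point that Messages is hard to instantiate in tests: System.IServiceProvider has a single method, GetService(Type), so a hand-rolled fake is one option. The sketch below assumes minimal stand-ins for ICommand, ICommandHandler<T> and Result (their real shapes in the course code may differ), keeps the same dynamic dispatch as Messages.Dispatch(ICommand), and uses hypothetical names (FakeServiceProvider, PingCommand, PingCommandHandler) purely for illustration.

```csharp
using System;
using System.Collections.Generic;

// Minimal stand-ins for the types in the question (assumed shapes).
public interface ICommand { }

public interface ICommandHandler<TCommand> where TCommand : ICommand
{
    Result Handle(TCommand command);
}

public sealed class Result
{
    public bool IsSuccess { get; }
    public string Error { get; }
    private Result(bool isSuccess, string error) { IsSuccess = isSuccess; Error = error; }
    public static Result Ok() => new Result(true, null);
    public static Result Fail(string error) => new Result(false, error);
}

// Hypothetical test double: a dictionary keyed by the requested service type.
public sealed class FakeServiceProvider : IServiceProvider
{
    private readonly Dictionary<Type, object> _services = new Dictionary<Type, object>();

    public FakeServiceProvider Register<TService>(TService instance)
    {
        _services[typeof(TService)] = instance;
        return this;
    }

    // Returns null for unregistered types, as the IServiceProvider contract allows.
    public object GetService(Type serviceType) =>
        _services.TryGetValue(serviceType, out object service) ? service : null;
}

// Same dispatch logic as Messages.Dispatch(ICommand) in the question.
public sealed class Messages
{
    private readonly IServiceProvider _provider;
    public Messages(IServiceProvider provider) => _provider = provider;

    public Result Dispatch(ICommand command)
    {
        Type handlerType = typeof(ICommandHandler<>).MakeGenericType(command.GetType());
        dynamic handler = _provider.GetService(handlerType);
        Result result = handler.Handle((dynamic)command);
        return result;
    }
}

// Hypothetical command and handler used only to exercise the dispatcher.
public sealed class PingCommand : ICommand
{
    public string Payload { get; }
    public PingCommand(string payload) => Payload = payload;
}

public sealed class PingCommandHandler : ICommandHandler<PingCommand>
{
    public Result Handle(PingCommand command) =>
        string.IsNullOrEmpty(command.Payload) ? Result.Fail("Payload is empty") : Result.Ok();
}
```

In a test you would register each handler under its closed interface type, e.g. Register<ICommandHandler<PingCommand>>(new PingCommandHandler()), and pass the provider to new Messages(provider). A real container (ServiceCollection from Microsoft.Extensions.DependencyInjection) would work just as well; the fake only avoids the extra package.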
@vkhorikov
Owner

Sorry for the late reply here.

Good trade-off analysis. I actually don't make command handlers internal in my projects. This was an idea I showed in the course, but I stopped using it in practice precisely because it hinders testability. There's no issue with making the command handler public, as it only exposes the "services" layer.

You could test controllers too, btw. The problem is that controllers often involve cross-cutting concerns, and exercising those in every test isn't ideal.
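With the handler public (option 1 in the question), a test can construct it and call Handle directly, with no dispatcher or JSON involved. A minimal sketch, assuming a hypothetical in-memory store (InMemoryEnrollments) in place of SessionFactory and the repositories; an integration test would use the real SessionFactory instead.

```csharp
using System;
using System.Collections.Generic;

public sealed class Result
{
    public bool IsSuccess { get; }
    public string Error { get; }
    private Result(bool isSuccess, string error) { IsSuccess = isSuccess; Error = error; }
    public static Result Ok() => new Result(true, null);
    public static Result Fail(string error) => new Result(false, error);
}

public enum Grade { A, B, C, D, F }

// Hypothetical in-memory store standing in for the database-backed repositories.
public sealed class InMemoryEnrollments
{
    public HashSet<long> StudentIds { get; } = new HashSet<long>();
    public HashSet<string> Courses { get; } = new HashSet<string>();
    public List<(long StudentId, string Course, Grade AssignedGrade)> Enrollments { get; } =
        new List<(long StudentId, string Course, Grade AssignedGrade)>();
}

public sealed class EnrollCommand
{
    public long Id { get; }
    public string Course { get; }
    public string Grade { get; }
    public EnrollCommand(long id, string course, string grade) { Id = id; Course = course; Grade = grade; }
}

// Public rather than internal, so a test project can construct and call it directly.
public sealed class EnrollCommandHandler
{
    private readonly InMemoryEnrollments _db;
    public EnrollCommandHandler(InMemoryEnrollments db) => _db = db;

    public Result Handle(EnrollCommand command)
    {
        if (!_db.StudentIds.Contains(command.Id))
            return Result.Fail($"No student found with Id '{command.Id}'");
        if (!_db.Courses.Contains(command.Course))
            return Result.Fail($"Course is incorrect: '{command.Course}'");
        if (!Enum.TryParse(command.Grade, out Grade grade))
            return Result.Fail($"Grade is incorrect: '{command.Grade}'");

        _db.Enrollments.Add((command.Id, command.Course, grade));
        return Result.Ok();
    }
}
```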

@mohammadtaherri
Author

mohammadtaherri commented Aug 11, 2024

Thanks Vladimir.

One more question.

A command handler does not return anything because it is a command; that's obvious. But suppose we have a command for creating (or adding) a new entity, for example AddNewCustomerCommand. To test it, we create an AddNewCustomerCommandHandler and an AddNewCustomerCommand (in the arrange section) and call Handle (in the act section). Then, to verify the outcome in the assert section, we must query the database, get the customer by Id, and check it. But we have no Id to fetch the customer with, because the command handler returns nothing.

Can we define the command handler like this?

public interface ICommand
{
}

public interface ICommandHandler<TCommand>
    where TCommand : ICommand
{
    Result<Guid, Error> Handle(TCommand command);
}

In this case the command handler returns an Id, which we can use in tests to verify the result. We can also return that Id to the user.

Is this approach reasonable? If not, how do we verify command handlers in tests when we have no Id?
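The proposal above can be sketched end to end: the handler creates the entity, saves it, and returns its Id, which the assert section then uses to load it back. This is a minimal sketch under stated assumptions: Result<T, E> and Error are simplified stand-ins for the thread's types, and CustomerRepository is a hypothetical in-memory replacement for the database.

```csharp
using System;
using System.Collections.Generic;

// Simplified stand-ins for the Result<T, E> and Error types used in the thread.
public sealed class Error
{
    public string Code { get; }
    public Error(string code) => Code = code;
}

public sealed class Result<T, E>
{
    public bool IsSuccess { get; }
    public T Value { get; }
    public E Error { get; }
    private Result(bool isSuccess, T value, E error) { IsSuccess = isSuccess; Value = value; Error = error; }
    public static Result<T, E> Success(T value) => new Result<T, E>(true, value, default);
    public static Result<T, E> Failure(E error) => new Result<T, E>(false, default, error);
}

public interface ICommand { }

public interface ICommandHandler<TCommand> where TCommand : ICommand
{
    Result<Guid, Error> Handle(TCommand command);
}

public sealed class Customer
{
    public Guid Id { get; } = Guid.NewGuid();
    public string Name { get; }
    public Customer(string name) => Name = name;
}

// Hypothetical in-memory repository standing in for the database.
public sealed class CustomerRepository
{
    private readonly Dictionary<Guid, Customer> _customers = new Dictionary<Guid, Customer>();
    public void Add(Customer customer) => _customers[customer.Id] = customer;
    public Customer GetById(Guid id) => _customers.TryGetValue(id, out Customer c) ? c : null;
}

public sealed class AddNewCustomerCommand : ICommand
{
    public string Name { get; }
    public AddNewCustomerCommand(string name) => Name = name;
}

public sealed class AddNewCustomerCommandHandler : ICommandHandler<AddNewCustomerCommand>
{
    private readonly CustomerRepository _repository;
    public AddNewCustomerCommandHandler(CustomerRepository repository) => _repository = repository;

    public Result<Guid, Error> Handle(AddNewCustomerCommand command)
    {
        if (string.IsNullOrWhiteSpace(command.Name))
            return Result<Guid, Error>.Failure(new Error("customer.name.empty"));

        var customer = new Customer(command.Name);
        _repository.Add(customer);

        // Returning the new Id lets both the API caller and a test look the entity up.
        return Result<Guid, Error>.Success(customer.Id);
    }
}
```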

@vkhorikov
Owner

Command handlers can and should return a response.

I think it'd be easier for me to just leave some code examples. These are from real projects where I employed this pattern:

public sealed class CreateSessionHandler : IRequestHandler<CreateSessionRequest, CreateSessionResponse>
{
    private readonly UserRepository _repository;
    private readonly UserContext _userContext;
    private readonly TestPackageCache _testPackageCache;
    private readonly ExposureCache _exposureCache;


    public CreateSessionHandler(
        UserRepository repository,
        UserContext userContext,
        TestPackageCache testPackageCache,
        ExposureCache exposureCache)
    {
        _repository = repository;
        _userContext = userContext;
        _testPackageCache = testPackageCache;
        _exposureCache = exposureCache;
    }


    public Result<CreateSessionResponse> Handle(CreateSessionRequest request)
    {
        Maybe<SectionOrder> sectionOrderOrNothing = SectionOrder.FromName(request.SectionOrder);
        if (sectionOrderOrNothing.HasNoValue)
            return Errors.Sessions.InvalidSectionOrder(request.SectionOrder);

        SectionOrder sectionOrder = sectionOrderOrNothing.Value;

        User user = GetOrCreateUser(_userContext.ExternalId);
        Session session = user.AddSession(
            _testPackageCache.GetAllTestPackages(null),
            _exposureCache.GetExposure(),
            sectionOrder,
            request.AppointmentId);

        _repository.Save(user);
        _repository.Save(session);
        _repository.UpdateItemExposure(session.NewExposedItems);

        return new CreateSessionResponse(session.Id);
    }


    private User GetOrCreateUser(Guid externalId)
    {
        Maybe<User> existingUser = _repository.GetByExternalId(externalId);

        if (existingUser.HasValue)
            return existingUser.Value;

        return new User(externalId);
    }
}


public sealed class CreateSessionRequest : IRequest<CreateSessionResponse>
{
    public string SectionOrder { get; set; }
    public string AppointmentId { get; set; }
    public string[] AccommodationCodes { get; set; } // ToDo
}


public sealed class CreateSessionResponse : IResponse
{
    public long Id { get; set; }


    private CreateSessionResponse()
    {
    }


    public CreateSessionResponse(long id)
        : this()
    {
        Id = id;
    }
}

public interface IRequest<TResponse>
    where TResponse : IResponse
{
}

public interface IResponse
{
}

// For requests that don't require any response:

public sealed class FinishSessionRequest : IRequest<Unit>
{
}

public sealed class Unit : IResponse
{
    public static readonly Unit Value = new Unit();
}


public interface IRequestHandler<TRequest, TResponse>
    where TRequest : IRequest<TResponse>
    where TResponse : IResponse
{
    Result<TResponse> Handle(TRequest request);
}

[ApiController]
public abstract class BaseController : ControllerBase
{
    private const int NumberOfRetries = 3;

    private readonly SessionFactory _sessionFactory;
    private readonly RequestDispatcher _requestDispatcher;
    private readonly ILogger _logger;


    protected BaseController(
        SessionFactory sessionFactory,
        RequestDispatcher requestDispatcher,
        ILogger logger)
    {
        _sessionFactory = sessionFactory;
        _requestDispatcher = requestDispatcher;
        _logger = logger;
    }


    protected abstract bool RequiresPlatformAuth();
    protected abstract bool RequiresAdminPlatform();
    protected abstract bool RequiresUserAuth();
    protected abstract bool RetriesAreEnabled();
    protected abstract bool RequiresSession();


    public async Task<Unit> Handle<TRequest, TResponse>()
        where TRequest : IRequest<TResponse>, new()
        where TResponse : IResponse
    {
        Maybe<string> configError = CheckConfig();
        if (configError.HasValue)
        {
            _logger.LogError("Config error: " + configError.Value);
            return await Error(Errors.Common.InternalServerError());
        }

        var request = await ReadRequest<TRequest>();
        int numberOfRetries = RetriesAreEnabled() ? NumberOfRetries : 1;

        for (int i = 0;; i++)
        {
            try
            {
                return await HandleCore<TRequest, TResponse>(request);
            }
            catch (Exception exception)
            {
                _logger.LogError(exception, $"Unhandled error in the application after retry #{i + 1}");
                _logger.LogInformation($"Request body:\r\n{CustomJsonSerializer.Serialize(request, true)}");

                if (i >= numberOfRetries)
                {
                    _logger.LogError(exception, $"Retries are exhausted (total number of retries: {numberOfRetries})");
                    return await Error(Errors.Common.InternalServerError());
                }
            }
        }
    }


    private Maybe<string> CheckConfig()
    {
        if (RequiresPlatformAuth() == false && RequiresUserAuth())
            return "Must require platform auth if user auth is required";

        if (RequiresPlatformAuth() == false && RequiresAdminPlatform())
            return "Must require platform auth if admin platform is required";

        if (RequiresSession() && RequiresUserAuth() == false)
            return "Must require user auth if session is required";

        return null;
    }


    public async Task<Unit> HandleCore<TRequest, TResponse>(TRequest request)
        where TRequest : IRequest<TResponse>, new()
        where TResponse : IResponse
    {
        using (var unitOfWork = new UnitOfWork(_sessionFactory))
        {
            UserContext userContext = GetUserContext(unitOfWork);

            if (IsAuthenticated(userContext) == false)
                return await Error(Errors.Common.AuthFailed());

            Result<Maybe<Session>> currentSession = GetCurrentSession(userContext, unitOfWork);
            if (currentSession.IsFailure)
                return await Error(currentSession.Error);

            Result<TResponse> result = Execute<TRequest, TResponse>(request, userContext, currentSession.Value, unitOfWork);

            _logger.LogInformation(GetLogMessage(request, result));

            if (result.IsFailure)
                return await Error(result.Error);

            return await Ok(result.Value);
        }
    }


    private string GetLogMessage<TResponse>(IRequest<TResponse> request, Result<TResponse> result)
        where TResponse : IResponse
    {
        StringBuilder sb = new StringBuilder();
        sb.Append($"Request body:\r\n{CustomJsonSerializer.Serialize(request, true)}\r\n");

        if (result.IsSuccess)
        {
            sb.Append($"Response body:\r\n{CustomJsonSerializer.Serialize(result.Value, true)}\r\n");
        }

        return sb.ToString();
    }


    private Result<TResponse> Execute<TRequest, TResponse>(
        TRequest request,
        UserContext userContext,
        Maybe<Session> currentSession,
        UnitOfWork unitOfWork)
        where TRequest : IRequest<TResponse>, new()
        where TResponse : IResponse
    {
        Result<TResponse> result = _requestDispatcher.Dispatch<TRequest, TResponse>(request, userContext, currentSession, unitOfWork);

        if (result.IsSuccess)
        {
            unitOfWork.Commit();
        }

        return result;
    }


    private Result<Maybe<Session>> GetCurrentSession(UserContext userContext, UnitOfWork unitOfWork)
    {
        if (RequiresSession() == false)
            return Maybe<Session>.None;
        
        if (userContext.SessionIdIsSpecified == false)
            return Errors.Sessions.InvalidSessionId();

        var repository = new UserRepository(unitOfWork);
        Maybe<User> userOrNothing = repository.GetByExternalId(userContext.ExternalId);
        if (userOrNothing.HasNoValue)
            return Errors.Sessions.InvalidUserId(userContext.ExternalId);

        User user = userOrNothing.Value;

        Maybe<Session> sessionOrNothing = user.GetSession(userContext.SessionId);
        if (sessionOrNothing.HasNoValue)
            return Errors.Sessions.InvalidSessionId(userContext.SessionId);

        return Maybe<Session>.From(sessionOrNothing.Value);
    }


    private bool IsAuthenticated(UserContext userContext)
    {
        if (RequiresPlatformAuth() == false)
            return true;

        if (userContext.IsPlatformAuthenticated == false)
            return false;

        if (RequiresAdminPlatform() && userContext.Platform.IsAdmin() == false)
            return false;

        if (RequiresUserAuth() == false)
            return true;

        return userContext.UserIdIsValid;
    }


    private async Task<TRequest> ReadRequest<TRequest>()
        where TRequest : new()
    {
        using (var reader = new StreamReader(Request.Body))
        {
            string json = await reader.ReadToEndAsync();
            var request = CustomJsonSerializer.Deserialize<TRequest>(json);
            return request ?? new TRequest();
        }
    }


    private UserContext GetUserContext(UnitOfWork unitOfWork)
    {
        Guid? userId = GetExternalUserId(Request);
        Maybe<Platform> platform = GetPlatform(Request, unitOfWork);
        long? sessionId = GetSessionId(Request);

        if (userId == null || platform.HasNoValue)
            return new UserContext(null, platform, sessionId);

        return new UserContext(userId.Value, platform, sessionId);
    }


    private static long? GetSessionId(HttpRequest request)
    {
        string path = request.Path;
        Match match = Regex.Match(path, @"\/sessions\/(\d*?)(?:\/|$)", RegexOptions.IgnoreCase);

        if (match.Groups.Count < 2)
            return null;

        string value = match.Groups[1].Value;

        if (string.IsNullOrWhiteSpace(value))
            return null;

        if (!long.TryParse(value, out long result))
            return null;

        return result;
    }


    private Maybe<Platform> GetPlatform(HttpRequest request, UnitOfWork unitOfWork)
    {
        var repository = new PlatformRepository(unitOfWork);

        string platformKey = request.Headers["Authentication"].FirstOrDefault();

        if (!Guid.TryParse(platformKey, out Guid key))
            return null;

        return repository.GetByKey(key);
    }


    private static Guid? GetExternalUserId(HttpRequest request)
    {
        string path = request.Path;
        Match match = Regex.Match(path, @"\/users\/(.*?)(?:\/|$)", RegexOptions.IgnoreCase);

        if (match.Groups.Count < 2)
            return null;

        string value = match.Groups[1].Value;

        if (string.IsNullOrWhiteSpace(value))
            return null;

        if (!Guid.TryParse(value, out Guid result))
            return null;

        return result;
    }


    private async Task<Unit> Ok<T>(T result)
    {
        await WriteResponse(HttpStatusCode.OK, ResponseEnvelope.Ok(result));

        return Unit.Value;
    }


    private async Task<Unit> Error(Error error, Maybe<string> additionalLogInfo = default)
    {
        string info = additionalLogInfo.GetValueOrDefault("Error");
        _logger.LogInformation($"{error}: {info}");

        await WriteResponse(GetStatusCode(error), ResponseEnvelope.Error(error));

        return Unit.Value;
    }


    private static HttpStatusCode GetStatusCode(Error error)
    {
        if (error == Errors.Common.AuthFailed())
            return HttpStatusCode.Unauthorized;

        if (error == Errors.Common.InternalServerError())
            return HttpStatusCode.InternalServerError;

        return HttpStatusCode.BadRequest;
    }

    private async Task WriteResponse(HttpStatusCode statusCode, ResponseEnvelope responseEnvelope)
    {
        Response.StatusCode = (int)statusCode;
        await Response.WriteAsJsonAsync(responseEnvelope);
    }
}


[Route("api/users/{userId}")]
public class UserController : BaseController
{
    public UserController(SessionFactory sessionFactory, RequestDispatcher requestDispatcher, ILogger<UserController> logger)
        : base(sessionFactory, requestDispatcher, logger)
    {
    }


    protected override bool RequiresPlatformAuth() => true;
    protected override bool RequiresAdminPlatform() => false;
    protected override bool RequiresUserAuth() => true;
    protected override bool RetriesAreEnabled() => true;
    protected override bool RequiresSession() => false;


    [HttpPost("sessions")]
    public async Task CreateSession() => await Handle<CreateSessionRequest, CreateSessionResponse>();
}


public sealed class RequestDispatcher
{
    private readonly IServiceProvider _provider;
    private readonly Dictionary<Type, HandlerInfo> _handlerMapping;


    public RequestDispatcher(IServiceProvider provider)
    {
        _provider = provider;
        _handlerMapping = BuildHandlerMapping();
    }


    private static Dictionary<Type, HandlerInfo> BuildHandlerMapping()
    {
        Type[] handlerTypes = Assembly.GetExecutingAssembly()
            .GetTypes()
            .Where(x => x.GetInterfaces().Any(y => IsHandlerInterface(y)))
            .Where(x => x.IsAbstract == false)
            .ToArray();

        var result = new Dictionary<Type, HandlerInfo>();

        foreach (Type handlerType in handlerTypes)
        {
            ConstructorInfo ctor = handlerType.GetConstructors().Single();
            ParameterInfo[] parameterInfos = ctor.GetParameters().ToArray();
            Type requestType = handlerType
                .GetInterfaces()
                .Single(y => IsHandlerInterface(y))
                .GenericTypeArguments[0];

            var handlerInfo = new HandlerInfo(ctor, parameterInfos);
            result.Add(requestType, handlerInfo);
        }

        return result;
    }


    private static bool IsHandlerInterface(Type type)
    {
        if (!type.IsGenericType)
            return false;

        Type typeDefinition = type.GetGenericTypeDefinition();

        return typeDefinition == typeof(IRequestHandler<,>);
    }


    public Result<TResponse> Dispatch<TRequest, TResponse>(
        TRequest request,
        UserContext userContext,
        Maybe<Session> currentSession,
        UnitOfWork unitOfWork)
        where TRequest : IRequest<TResponse>
        where TResponse : IResponse
    {
        (ConstructorInfo ctor, ParameterInfo[] parameterInfos) = _handlerMapping[request.GetType()];

        object[] parameters = GetParameters(parameterInfos, userContext, currentSession, unitOfWork);
        var handler = (IRequestHandler<TRequest, TResponse>)ctor.Invoke(parameters);

        return handler.Handle(request);
    }


    private object[] GetParameters(ParameterInfo[] parameterInfos, UserContext userContext, Maybe<Session> currentSession, UnitOfWork unitOfWork)
    {
        object[] result = new object[parameterInfos.Length];

        for (int i = 0; i < parameterInfos.Length; i++)
        {
            result[i] = GetParameter(parameterInfos[i], userContext, currentSession, unitOfWork);
        }

        return result;
    }


    private object GetParameter(ParameterInfo parameterInfo, UserContext userContext, Maybe<Session> currentSession, UnitOfWork unitOfWork)
    {
        Type parameterType = parameterInfo.ParameterType;

        if (parameterType == typeof(UserContext))
            return userContext;

        if (parameterType == typeof(Session))
            return currentSession.GetValueOrThrow("Controller requires current session");

        if (IsParameterRepository(parameterType))
            return CreateRepository(parameterType, unitOfWork);

        object service = _provider.GetService(parameterType);
        if (service == null)
            throw new ArgumentException($"Type '{parameterType}' not found");

        return service;
    }


    private bool IsParameterRepository(Type parameterType)
    {
        Type baseType = parameterType.BaseType;

        if (baseType == null)
            return false;

        if (baseType.IsGenericType == false)
            return false;

        return baseType.GetGenericTypeDefinition() == typeof(Repository<>);
    }


    private object CreateRepository(Type repositoryType, UnitOfWork unitOfWork)
    {
        return Activator.CreateInstance(repositoryType, unitOfWork);
    }


    private sealed record HandlerInfo(ConstructorInfo Ctor, ParameterInfo[] ParameterInfos);
}

public sealed class CreateSessionTests : IntegrationTests
{
    [Fact]
    public void Can_create_for_new_user()
    {
        TestPackage activePackage = GetActivePackage();
        var externalId = Guid.NewGuid();
        var request = new CreateSessionRequest
        {
            SectionOrder = "Q_V_DI"
        };

        Result<CreateSessionResponse> result = InvokeCreateSession(request, externalId);

        CreateSessionResponse response = result.ShouldBeOK();
        using (DB db = CreateDB())
        {
            db.GetAllUsers().Length.Should().Be(1);

            User user = db.GetUser(externalId);
            user.ExternalId.Should().Be(externalId);
            user.Sessions.Count.Should().Be(1);

            Session session = user.Sessions.Single();
            session.Id.Should().Be(response.Id);
            session.User.Should().Be(user);
            session.TestPackageId.Should().Be(activePackage.Id);
            session.Sections.Quant.IsStarted.Should().BeTrue();
            session.Sections.Verbal.IsStarted.Should().BeFalse();
            session.Sections.DI.IsStarted.Should().BeFalse();
            session.Sections.QuantScriptId.Should().BeGreaterThan(0);
            session.Sections.VerbalScriptId.Should().BeGreaterThan(0);
            session.Sections.DIScriptId.Should().BeGreaterThan(0);
            session.Sections.Order.Should().Be(SectionOrder.Q_V_DI);
            session.StartDateTime.ShouldBeAroundNow();
            session.FinishDateTime.Should().BeNull();
            session.AppointmentId.ShouldNotHaveValue();

            (ItemExposure itemExposure, TestPackageExposure packageExposure) = GetExposure();
            long leadingItemId = session.Sections.Quant.SelectedItems.LeadingItemId.Value;
            long nextCorrectItemId = session.Sections.Quant.SelectedItems.NextCorrectItemId.Value;
            long nextIncorrectItemId = session.Sections.Quant.SelectedItems.NextIncorrectItemId.Value;
            itemExposure.GetExposure(GetItem(leadingItemId, activePackage), activePackage).Should().Be(1);
            itemExposure.GetExposure(GetItem(nextCorrectItemId, activePackage), activePackage).Should().Be(0);
            itemExposure.GetExposure(GetItem(nextIncorrectItemId, activePackage), activePackage).Should().Be(0);
            packageExposure.GetExposure(session.TestPackageId).Should().Be(1);
        }
    }
}

Commands here are requests, and command handlers are request handlers. Note that the auth here is simple; most projects require OAuth2 integration (this wasn't a requirement for that project).
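Because the simple auth in BaseController hinges on pulling the user Id out of the request path, that parsing is easy to check in isolation. Below is a standalone variant of GetExternalUserId, assuming the path is passed as a plain string rather than an HttpRequest so it runs without ASP.NET Core; it checks match.Success, which per the .NET docs is equivalent to the original Groups.Count check (a failed match exposes a single group). RouteParsing is a hypothetical name.

```csharp
using System;
using System.Text.RegularExpressions;

public static class RouteParsing
{
    // Same pattern as BaseController.GetExternalUserId: the lazy group captures
    // everything after "/users/" up to the next "/" or the end of the path.
    private static readonly Regex UserIdPattern =
        new Regex(@"\/users\/(.*?)(?:\/|$)", RegexOptions.IgnoreCase);

    // Takes the path as a plain string (instead of HttpRequest) so it runs standalone.
    public static Guid? ParseExternalUserId(string path)
    {
        Match match = UserIdPattern.Match(path);
        if (!match.Success)
            return null;

        string value = match.Groups[1].Value;
        if (string.IsNullOrWhiteSpace(value))
            return null;

        return Guid.TryParse(value, out Guid result) ? result : (Guid?)null;
    }
}
```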

You can use the code as-is; it's been tested quite heavily. One thing I'd change in a new project: all the RequiresSession()-style methods exposed at the controller level should really live at the request handler level instead.
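One way to read the suggestion of moving RequiresSession() and friends to the request handler level is a class-level attribute on each handler that the dispatcher reads via reflection. This is only a sketch under that assumption: RequestPolicyAttribute, RequestPolicies and the stub handler classes are hypothetical names, and the default policy for unattributed handlers is an assumed choice. The consistency rules are copied from BaseController.CheckConfig.

```csharp
using System;
using System.Reflection;

// Hypothetical attribute: each handler declares what BaseController currently
// declares through RequiresPlatformAuth(), RequiresSession(), etc.
[AttributeUsage(AttributeTargets.Class, Inherited = false)]
public sealed class RequestPolicyAttribute : Attribute
{
    public bool RequiresPlatformAuth { get; set; }
    public bool RequiresAdminPlatform { get; set; }
    public bool RequiresUserAuth { get; set; }
    public bool RequiresSession { get; set; }
    public bool RetriesAreEnabled { get; set; }
}

public static class RequestPolicies
{
    // Reads the policy from the handler type; unattributed handlers fall back to
    // requiring platform and user auth (an assumed default, not from the thread).
    public static RequestPolicyAttribute For(Type handlerType) =>
        handlerType.GetCustomAttribute<RequestPolicyAttribute>()
        ?? new RequestPolicyAttribute { RequiresPlatformAuth = true, RequiresUserAuth = true };

    // Same rules as BaseController.CheckConfig, now checked per handler.
    // Returns null when the policy is consistent.
    public static string CheckConfig(RequestPolicyAttribute policy)
    {
        if (!policy.RequiresPlatformAuth && policy.RequiresUserAuth)
            return "Must require platform auth if user auth is required";
        if (!policy.RequiresPlatformAuth && policy.RequiresAdminPlatform)
            return "Must require platform auth if admin platform is required";
        if (policy.RequiresSession && !policy.RequiresUserAuth)
            return "Must require user auth if session is required";
        return null;
    }
}

// Hypothetical handlers showing how the settings travel with each request type.
[RequestPolicy(RequiresPlatformAuth = true, RequiresUserAuth = true, RetriesAreEnabled = true)]
public sealed class CreateSessionHandlerStub { }

// Deliberately inconsistent: requires a session without requiring user auth.
[RequestPolicy(RequiresSession = true)]
public sealed class MisconfiguredHandlerStub { }
```

Because the dispatcher already builds a handler mapping once at startup, it could validate every handler's policy there and fail fast, rather than checking configuration on each request as BaseController does.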

@mohammadtaherri
Author

mohammadtaherri commented Aug 12, 2024

Thanks.

So you're saying that commands (command handlers) in CQRS can return a response, and we can define commands like this:

public interface ICommand<TResult>
{
}

public interface ICommandHandler<TCommand, TResult>
    where TCommand : ICommand<TResult>
{
    Result<TResult, Error> Handle(TCommand command);
}

Right?

So we can conclude that command methods under the CQS principle have side effects and don't return anything (they are void), whereas command services in CQRS can return a response, and in this regard they aren't similar to command methods in CQS.

As a result, the command and query interfaces (and the command handler and query handler interfaces) are identical:

public interface ICommand<TResult>
{
}

public interface ICommandHandler<TCommand, TResult>
    where TCommand : ICommand<TResult>
{
    Result<TResult, Error> Handle(TCommand command);
}

public interface IQuery<TResult>
{
}

public interface IQueryHandler<TQuery, TResult>
    where TQuery : IQuery<TResult>
{
    Result<TResult, Error> Handle(TQuery query);
}

The only difference between them is that queries shouldn't have side effects, while commands should.

Right?
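To make the distinction concrete: a CQRS-style command handler can return a response while the domain methods it calls still follow CQS. A common trick is to generate the Id before invoking the void command method. The sketch is illustrative only; Cart and AddItemHandler are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical domain class following CQS at the method level.
public sealed class Cart
{
    private readonly Dictionary<Guid, string> _items = new Dictionary<Guid, string>();

    // CQS command: mutates state, returns void.
    public void Add(Guid id, string item) => _items[id] = item;

    // CQS queries: return data, no side effects.
    public int Count() => _items.Count;
    public string Get(Guid id) => _items.TryGetValue(id, out string item) ? item : null;
}

// CQRS-style handler: it has a side effect but still returns a response (the new Id).
public sealed class AddItemHandler
{
    private readonly Cart _cart;
    public AddItemHandler(Cart cart) => _cart = cart;

    public Guid Handle(string item)
    {
        Guid id = Guid.NewGuid(); // generated up front, so the domain method can stay void
        _cart.Add(id, item);      // command
        return id;                // response to the caller
    }
}
```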

@vkhorikov
Owner

Yes, the unification of commands and queries is partly why I called those handlers request handlers.

Unfortunately, the CQS principle is rarely possible to follow when it comes to the API. It's not a "hard" rule, so I wouldn't worry about it.
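Even the .NET base class library relaxes CQS where it's convenient: Stack<T>.Peek is a pure query, whereas Stack<T>.Pop both removes the top element and returns it. The wrapper class name (CqsInTheBcl) is hypothetical; the Stack<T> calls are the standard ones.

```csharp
using System.Collections.Generic;

public static class CqsInTheBcl
{
    // Peek only reads; Pop mutates the stack and returns the removed element,
    // so Pop is a command and a query at once.
    public static (int Peeked, int Popped, int CountAfterPop) Demo()
    {
        var stack = new Stack<int>();
        stack.Push(1);
        stack.Push(2);

        int peeked = stack.Peek(); // query: stack unchanged
        int popped = stack.Pop();  // command + query: removes and returns the top
        return (peeked, popped, stack.Count);
    }
}
```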
