Skip to content

Commit

Permalink
Created async overload for clausules that use a predicate (#228)
Browse files Browse the repository at this point in the history
* Created async overload for clausules that use a predicate

* Made GetCustomStruct public (build error)

---------

Co-authored-by: Danny Bos <[email protected]>
Co-authored-by: Steve Smith <[email protected]>
  • Loading branch information
3 people authored Feb 17, 2023
1 parent 6d41101 commit d54c66a
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 35 deletions.
21 changes: 21 additions & 0 deletions src/GuardClauses/GuardAgainstExpressionExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Threading.Tasks;

namespace Ardalis.GuardClauses
{
Expand All @@ -23,5 +24,25 @@ public static T AgainstExpression<T>(this IGuardClause guardClause, Func<T, bool

return input;
}

/// <summary>
/// Throws an <see cref="ArgumentException" /> if <paramref name="func"/> evaluates to false for given <paramref name="input"/>
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="func"></param>
/// <param name="guardClause"></param>
/// <param name="input"></param>
/// <param name="message"></param>
/// <returns><paramref name="input"/> if the <paramref name="func"/> evaluates to true </returns>
/// <exception cref="ArgumentException"></exception>
public static async Task<T> AgainstExpressionAsync<T>([JetBrainsNotNull] this IGuardClause guardClause, [JetBrainsNotNull] Func<T, Task<bool>> func, T input, string message) where T : struct
{
if (!await func(input))
{
throw new ArgumentException(message);
}

return input;
}
}
}
22 changes: 22 additions & 0 deletions src/GuardClauses/GuardAgainstInvalidFormatExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Text.RegularExpressions;
using System.Threading.Tasks;

namespace Ardalis.GuardClauses
{
Expand Down Expand Up @@ -50,5 +51,26 @@ public static T InvalidInput<T>(this IGuardClause guardClause, T input, string p

return input;
}

/// <summary>
/// Throws an <see cref="ArgumentException" /> if <paramref name="input"/> doesn't satisfy the <paramref name="predicate"/> function.
/// </summary>
/// <param name="guardClause"></param>
/// <param name="input"></param>
/// <param name="parameterName"></param>
/// <param name="predicate"></param>
/// <param name="message">Optional. Custom error message</param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public static async Task<T> InvalidInputAsync<T>([JetBrainsNotNull] this IGuardClause guardClause, [JetBrainsNotNull] T input, [JetBrainsNotNull][JetBrainsInvokerParameterName] string parameterName, Func<T, Task<bool>> predicate, string? message = null)
{
if (!await predicate(input))
{
throw new ArgumentException(message ?? $"Input {parameterName} did not satisfy the options", parameterName);
}

return input;
}
}
}
147 changes: 112 additions & 35 deletions test/GuardClauses.UnitTests/GuardAgainstOutOfRangeForInvalidInput.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading.Tasks;
using Ardalis.GuardClauses;
using Microsoft.VisualBasic;
using Xunit;
Expand All @@ -16,43 +17,88 @@ public void DoesNothingGivenInRangeValue<T>(T input, Func<T, bool> func)
Guard.Against.InvalidInput(input, nameof(input), func);
}

[Theory]
[ClassData(typeof(IncorrectClassData))]
public void ThrowsGivenOutOfRangeValue<T>(T input, Func<T, bool> func)
{
Assert.Throws<ArgumentException>(() => Guard.Against.InvalidInput(input, nameof(input), func));
}
[Theory]
[ClassData(typeof(CorrectAsyncClassData))]
public async Task DoesNothingGivenInRangeValueAsync<T>(T input, Func<T, Task<bool>> func)
{
await Guard.Against.InvalidInputAsync(input, nameof(input), func);
}

[Theory]
[ClassData(typeof(CorrectClassData))]
public void ReturnsExpectedValueGivenInRangeValue<T>(T input, Func<T, bool> func)
{
var result = Guard.Against.InvalidInput(input, nameof(input), func);
Assert.Equal(input, result);
}
[Theory]
[ClassData(typeof(IncorrectClassData))]
public void ThrowsGivenOutOfRangeValue<T>(T input, Func<T, bool> func)
{
Assert.Throws<ArgumentException>(() => Guard.Against.InvalidInput(input, nameof(input), func));
}

[Theory]
[InlineData(null, "Input parameterName did not satisfy the options (Parameter 'parameterName')")]
[InlineData("Evaluation failed", "Evaluation failed (Parameter 'parameterName')")]
public void ErrorMessageMatchesExpected(string customMessage, string expectedMessage)
{
var exception = Assert.Throws<ArgumentException>(() => Guard.Against.InvalidInput(10, "parameterName", x => x > 20, customMessage));
Assert.NotNull(exception);
Assert.NotNull(exception.Message);
Assert.Equal(expectedMessage, exception.Message);
}
[Theory]
[ClassData(typeof(IncorrectAsyncClassData))]
public async Task ThrowsGivenOutOfRangeValueAsync<T>(T input, Func<T, Task<bool>> func)
{
await Assert.ThrowsAsync<ArgumentException>(async () => await Guard.Against.InvalidInputAsync(input, nameof(input), func));
}

[Theory]
[InlineData(null, null)]
[InlineData(null, "Please provide correct value")]
[InlineData("SomeParameter", null)]
[InlineData("SomeOtherParameter", "Value must be correct")]
public void ExceptionParamNameMatchesExpected(string expectedParamName, string customMessage)
{
var exception = Assert.Throws<ArgumentException>(() => Guard.Against.InvalidInput(10, expectedParamName, x => x > 20, customMessage));
Assert.NotNull(exception);
Assert.Equal(expectedParamName, exception.ParamName);
}
[Theory]
[ClassData(typeof(CorrectClassData))]
public void ReturnsExpectedValueGivenInRangeValue<T>(T input, Func<T, bool> func)
{
var result = Guard.Against.InvalidInput(input, nameof(input), func);
Assert.Equal(input, result);
}

[Theory]
[ClassData(typeof(CorrectAsyncClassData))]
public async Task ReturnsExpectedValueGivenInRangeValueAsync<T>(T input, Func<T, Task<bool>> func)
{
var result = await Guard.Against.InvalidInputAsync(input, nameof(input), func);
Assert.Equal(input, result);
}

[Theory]
[InlineData(null, "Input parameterName did not satisfy the options (Parameter 'parameterName')")]
[InlineData("Evaluation failed", "Evaluation failed (Parameter 'parameterName')")]
public void ErrorMessageMatchesExpected(string customMessage, string expectedMessage)
{
var exception = Assert.Throws<ArgumentException>(() => Guard.Against.InvalidInput(10, "parameterName", x => x > 20, customMessage));
Assert.NotNull(exception);
Assert.NotNull(exception.Message);
Assert.Equal(expectedMessage, exception.Message);
}

[Theory]
[InlineData(null, "Input parameterName did not satisfy the options (Parameter 'parameterName')")]
[InlineData("Evaluation failed", "Evaluation failed (Parameter 'parameterName')")]
public async Task ErrorMessageMatchesExpectedAsync(string customMessage, string expectedMessage)
{
var exception = await Assert.ThrowsAsync<ArgumentException>(async () => await Guard.Against.InvalidInputAsync(10, "parameterName", x => Task.FromResult(x > 20), customMessage));
Assert.NotNull(exception);
Assert.NotNull(exception.Message);
Assert.Equal(expectedMessage, exception.Message);
}

[Theory]
[InlineData(null, null)]
[InlineData(null, "Please provide correct value")]
[InlineData("SomeParameter", null)]
[InlineData("SomeOtherParameter", "Value must be correct")]
public void ExceptionParamNameMatchesExpected(string expectedParamName, string customMessage)
{
var exception = Assert.Throws<ArgumentException>(() => Guard.Against.InvalidInput(10, expectedParamName, x => x > 20, customMessage));
Assert.NotNull(exception);
Assert.Equal(expectedParamName, exception.ParamName);
}

[Theory]
[InlineData(null, null)]
[InlineData(null, "Please provide correct value")]
[InlineData("SomeParameter", null)]
[InlineData("SomeOtherParameter", "Value must be correct")]
public async Task ExceptionParamNameMatchesExpectedAsync(string expectedParamName, string customMessage)
{
var exception = await Assert.ThrowsAsync<ArgumentException>(async () => await Guard.Against.InvalidInputAsync(10, expectedParamName, x => Task.FromResult(x > 20), customMessage));
Assert.NotNull(exception);
Assert.Equal(expectedParamName, exception.ParamName);
}

// TODO: Test decimal types outside of ClassData
// See: https://github.com/xunit/xunit/issues/2298
Expand All @@ -72,6 +118,21 @@ public IEnumerator<object[]> GetEnumerator()
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

public class CorrectAsyncClassData : IEnumerable<object[]> {
public IEnumerator<object[]> GetEnumerator()
{
yield return new object[] { 20, (Func<int, Task<bool>>)((x) => Task.FromResult(x > 10)) };
yield return new object[] { DateAndTime.Now, (Func<DateTime, Task<bool>>)((x) => Task.FromResult(x > DateTime.MinValue)) };
yield return new object[] { 20.0f, (Func<float, Task<bool>>)((x) => Task.FromResult(x > 10.0f)) };
//yield return new object[] { 20.0m, (Func<decimal, Task<bool>>)((x) => Task.FromResult(x > 10.0m)) };
yield return new object[] { 20.0, (Func<double, Task<bool>>)((x) => Task.FromResult(x > 10.0)) };
yield return new object[] { long.MaxValue, (Func<long, Task<bool>>)((x) => Task.FromResult(x > 1)) };
yield return new object[] { short.MaxValue, (Func<short, Task<bool>>)((x) => Task.FromResult(x > 1)) };
yield return new object[] { "abcd", (Func<string, Task<bool>>)((x) => Task.FromResult(x == x.ToLower())) };
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

public class IncorrectClassData : IEnumerable<object[]>
{
public IEnumerator<object[]> GetEnumerator()
Expand All @@ -85,6 +146,22 @@ public IEnumerator<object[]> GetEnumerator()
yield return new object[] { short.MaxValue, (Func<short, bool>)((x) => x < 1) };
yield return new object[] { "abcd", (Func<string, bool>)((x) => x == x.ToUpper()) };
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

public class IncorrectAsyncClassData : IEnumerable<object[]>
{
public IEnumerator<object[]> GetEnumerator()
{
yield return new object[] { 20, (Func<int, Task<bool>>)((x) => Task.FromResult(x < 10)) };
yield return new object[] { DateAndTime.Now, (Func<DateTime, Task<bool>>)((x) => Task.FromResult(x > DateTime.MaxValue)) };
yield return new object[] { 20.0f, (Func<float, Task<bool>>)((x) => Task.FromResult(x > 30.0f)) };
//yield return new object[] { 20.0m, (Func<decimal, bool>)((x) => x > 30.0m)) };
yield return new object[] { 20.0, (Func<double, Task<bool>>)((x) => Task.FromResult(x > 30.0)) };
yield return new object[] { long.MaxValue, (Func<long, Task<bool>>)((x) => Task.FromResult(x < 1)) };
yield return new object[] { short.MaxValue, (Func<short, Task<bool>>)((x) => Task.FromResult(x < 1)) };
yield return new object[] { "abcd", (Func<string, Task<bool>>)((x) => Task.FromResult(x == x.ToUpper())) };
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}
}

0 comments on commit d54c66a

Please sign in to comment.