-
Notifications
You must be signed in to change notification settings - Fork 4.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[API Proposal]: Generic overloads of existing TensorPrimitives methods #94553
Comments
Tagging subscribers to this area: @dotnet/area-system-numerics
Issue Details
Background and motivation
For .NET 8, we added the new TensorPrimitives type, with methods dedicated to handling float. For post-.NET 8, we're planning to augment this in three ways (#93286):
This issue covers (1).
API Proposal
Exactly the same signatures as on TensorPrimitives in .NET 8, with an overload that takes a T instead of float.
namespace System.Numerics.Tensors;
public static class TensorPrimitives
{
public static void Abs<T>(ReadOnlySpan<T> x, Span<T> destination) where T : INumber<T>;
public static void AddMultiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void AddMultiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void AddMultiply<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void Add<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
public static void Add<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : INumber<T>;
public static void ConvertToHalf(ReadOnlySpan<float> source, Span<Half> destination);
public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<float> destination);
public static void Cosh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IFloatingPointIeee754<T>;
public static T CosineSimilarity<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IFloatingPointIeee754<T>;
public static T Distance<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : INumberBase<T>, IRootFunctions<T>;
public static void Divide<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : IDivisionOperators<T, T, T>;
public static void Divide<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IDivisionOperators<T, T, T>;
public static T Dot<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : INumberBase<T>;
public static void Exp<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IFloatingPointIeee754<T>;
public static void Log2<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IFloatingPointIeee754<T>;
public static void Log<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IFloatingPointIeee754<T>;
public static T MaxMagnitude<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void MaxMagnitude<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
public static T Max<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void Max<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
public static T MinMagnitude<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void MinMagnitude<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
public static T Min<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void Min<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void Multiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
public static void Multiply<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : INumberBase<T>;
public static void Negate<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IUnaryNegationOperators<T, T>;
public static T Norm<T>(ReadOnlySpan<T> x) where T : IRootFunctions<T>;
public static T ProductOfDifferences<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : INumberBase<T>;
public static T ProductOfSums<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : INumberBase<T>;
public static T Product<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static void Sigmoid<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IFloatingPointIeee754<T>;
public static void Sinh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IFloatingPointIeee754<T>;
public static void SoftMax<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IFloatingPointIeee754<T>;
public static void Subtract<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : ISubtractionOperators<T, T, T>;
public static void Subtract<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : ISubtractionOperators<T, T, T>;
public static T SumOfMagnitudes<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static T SumOfSquares<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static T Sum<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static void Tanh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IFloatingPointIeee754<T>;
}
API Usage
double[] values1 = ..., values2 = ...;
double similarity = TensorPrimitives.CosineSimilarity(values1, values2);
Alternative Designs
No response
Risks
No response
|
A little bit torn here. But I think restricting to the smallest subset is sufficient and we can look into this again as we implement and before shipping to determine if it needs to be tweaked. I do think there are smaller subsets than you've picked for a number of them. I expect in practice Given that, we can then relax that in the future when a feature like However, some of the functionality like Likewise something such as |
I think the following are the APIs with the minimal set of interfaces needed on each. We could require that // IAdditionOperators<T, T, T>
public static void Add<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : IAdditionOperators<T, T, T>;
public static void Add<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IAdditionOperators<T, T, T>;
// IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>
public static T Sum<T>(ReadOnlySpan<T> x) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>;
// IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>
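public static T Dot<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>;
public static T SumOfSquares<T>(ReadOnlySpan<T> x) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>;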
// IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>
public static void AddMultiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void AddMultiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void AddMultiply<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
// IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>
public static T ProductOfSums<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
// IDivisionOperators<T, T, T>
public static void Divide<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : IDivisionOperators<T, T, T>;
public static void Divide<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IDivisionOperators<T, T, T>;
// IExponentialFunctions<T> -- Implies INumberBase<T>, IFloatingPointConstants<T>
public static void Exp<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
public static void SoftMax<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
public static void Sigmoid<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
// IHyperbolicFunctions<T> -- Implies INumberBase<T>, IFloatingPointConstants<T>
public static void Cosh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
public static void Sinh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
public static void Tanh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
// ILogarithmicFunctions<T> -- Implies INumberBase<T>, IFloatingPointConstants<T>
public static void Log<T>(ReadOnlySpan<T> x, Span<T> destination) where T : ILogarithmicFunctions<T>;
public static void Log2<T>(ReadOnlySpan<T> x, Span<T> destination) where T : ILogarithmicFunctions<T>;
// IMultiplyOperators<T, T, T>
public static void Multiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : IMultiplyOperators<T, T, T>;
public static void Multiply<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IMultiplyOperators<T, T, T>;
// IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>
public static T Product<T>(ReadOnlySpan<T> x) where T : IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
// INumber<T>
public static int IndexOfMax<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static int IndexOfMin<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static T Max<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void Max<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
public static T Min<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void Min<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
// INumberBase<T>
public static void Abs<T>(ReadOnlySpan<T> x, Span<T> destination) where T : INumberBase<T>;
public static int IndexOfMaxMagnitude<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static int IndexOfMinMagnitude<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static T MaxMagnitude<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static void MaxMagnitude<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
public static T MinMagnitude<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static void MinMagnitude<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
public static T SumOfMagnitudes<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
// IRootFunctions<T> -- Implies INumberBase<T>, IFloatingPointConstants<T>
public static T CosineSimilarity<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IRootFunctions<T>;
public static T Distance<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IRootFunctions<T>;
public static T Norm<T>(ReadOnlySpan<T> x) where T : IRootFunctions<T>;
// ISubtractionOperators<T, T, T>
public static void Subtract<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : ISubtractionOperators<T, T, T>;
public static void Subtract<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : ISubtractionOperators<T, T, T>;
// ISubtractionOperators<T, T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>
public static T ProductOfDifferences<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : ISubtractionOperators<T, T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
// IUnaryNegationOperators<T, T>
public static void Negate<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IUnaryNegationOperators<T, T>;
// Other -- supporting arbitrary floating-point types is possible, but very difficult
public static void ConvertToHalf(ReadOnlySpan<double> source, Span<Half> destination);
public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<double> destination); |
Yeah, I wrote before I updated most of them to use INumber, so my comment was inconsistent. That said, unless we change how things are implemented, we can't make many of them as fine-grained as you've outlined. For example, you've noted Add as only needing I can go through and make them the minimum possible with today's implementation and then we can possibly revise the implementation subsequently if we want to bring it down even further. |
It would really just need But, notably, the need for |
Ok.
Right, that's what I meant by "unless we change how things are implemented" and "we can possibly revise the implementation subsequently if we want to bring it down even further". I don't think it's worth adding more code here to reduce the constraints right now. |
Looks good as proposed (taking all of the constraints on faith)
namespace System.Numerics.Tensors;
public static partial class TensorPrimitives
{
public static void Abs<T>(ReadOnlySpan<T> x, Span<T> destination) where T : INumberBase<T>;
public static void AddMultiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void AddMultiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void AddMultiply<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void Add<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>;
public static void Add<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>;
public static void Cosh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
public static T CosineSimilarity<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IRootFunctions<T>;
public static T Distance<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : INumberBase<T>, IRootFunctions<T>;
public static void Divide<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : IDivisionOperators<T, T, T>;
public static void Divide<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IDivisionOperators<T, T, T>;
public static T Dot<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
public static void Exp<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
public static int IndexOfMax<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static int IndexOfMaxMagnitude<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static int IndexOfMin<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static int IndexOfMinMagnitude<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void Log2<T>(ReadOnlySpan<T> x, Span<T> destination) where T : ILogarithmicFunctions<T>;
public static void Log<T>(ReadOnlySpan<T> x, Span<T> destination) where T : ILogarithmicFunctions<T>;
public static T MaxMagnitude<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static void MaxMagnitude<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
public static T Max<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void Max<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
public static T MinMagnitude<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static void MinMagnitude<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
public static T Min<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static void Min<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void MultiplyAdd<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
public static void Multiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
public static void Multiply<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
public static void Negate<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IUnaryNegationOperators<T, T>;
public static T Norm<T>(ReadOnlySpan<T> x) where T : IRootFunctions<T>;
public static T ProductOfDifferences<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : ISubtractionOperators<T, T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
public static T ProductOfSums<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
public static T Product<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static void Sigmoid<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
public static void Sinh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
public static void SoftMax<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
public static void Subtract<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : ISubtractionOperators<T, T, T>;
public static void Subtract<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : ISubtractionOperators<T, T, T>;
public static T SumOfMagnitudes<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
public static T SumOfSquares<T>(ReadOnlySpan<T> x) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>;
public static T Sum<T>(ReadOnlySpan<T> x) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>;
public static void Tanh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
} |
Everything here has been done, except:
public static int IndexOfMax<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static int IndexOfMaxMagnitude<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static int IndexOfMin<T>(ReadOnlySpan<T> x) where T : INumber<T>;
public static int IndexOfMinMagnitude<T>(ReadOnlySpan<T> x) where T : INumber<T>;
Those still need to be added. |
Background and motivation
For .NET 8, we added the new TensorPrimitives type, with methods dedicated to handling float. For post-.NET 8, we're planning to augment this in three ways (#93286):
This issue covers (1).
API Proposal
Exactly the same signatures as on TensorPrimitives in .NET 8, with an overload that takes a T instead of float.
INumber<T> or IFloatingPointIeee754<T> for simplicity / consistency / future flexibility?
API Usage
Alternative Designs
No response
Risks
No response
The text was updated successfully, but these errors were encountered: