diff --git a/src/Examples/AlexNet.cs b/src/Examples/AlexNet.cs index 17b08b7c3..aa893255e 100644 --- a/src/Examples/AlexNet.cs +++ b/src/Examples/AlexNet.cs @@ -19,17 +19,17 @@ public AlexNet(string name, int numClasses, torch.Device device = null) : base(n features = Sequential( ("c1", Conv2d(3, 64, kernelSize: 3, stride: 2, padding: 1)), ("r1", ReLU(inplace: true)), - ("mp1", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("mp1", MaxPool2d(kernel_size: new long[] { 2, 2 })), ("c2", Conv2d(64, 192, kernelSize: 3, padding: 1)), ("r2", ReLU(inplace: true)), - ("mp2", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("mp2", MaxPool2d(kernel_size: new long[] { 2, 2 })), ("c3", Conv2d(192, 384, kernelSize: 3, padding: 1)), ("r3", ReLU(inplace: true)), ("c4", Conv2d(384, 256, kernelSize: 3, padding: 1)), ("r4", ReLU(inplace: true)), ("c5", Conv2d(256, 256, kernelSize: 3, padding: 1)), ("r5", ReLU(inplace: true)), - ("mp3", MaxPool2d(kernelSize: new long[] { 2, 2 }))); + ("mp3", MaxPool2d(kernel_size: new long[] { 2, 2 }))); avgPool = AdaptiveAvgPool2d(new long[] { 2, 2 }); diff --git a/src/Examples/MNIST.cs b/src/Examples/MNIST.cs index 73c8e69b0..d3059406a 100644 --- a/src/Examples/MNIST.cs +++ b/src/Examples/MNIST.cs @@ -105,7 +105,7 @@ internal class Model : Module // These don't have any parameters, so the only reason to instantiate // them is performance, since they will be used over and over. - private Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 }); + private Module pool1 = MaxPool2d(kernel_size: new long[] { 2, 2 }); private Module relu1 = ReLU(); private Module relu2 = ReLU(); diff --git a/src/Examples/VGG.cs b/src/Examples/VGG.cs index 0c12eca4e..033bf6491 100644 --- a/src/Examples/VGG.cs +++ b/src/Examples/VGG.cs @@ -38,7 +38,7 @@ public VGG(string name, int numClasses, Device device = null) : base(name) for (var i = 0; i < channels.Length; i++) { if (channels[i] == 0) { - modules.Add(($"MaxPool2d-{i}a", MaxPool2d(kernelSize: 2, stride: 2))); + modules.Add(($"MaxPool2d-{i}a", MaxPool2d(kernel_size: 2, stride: 2))); } else { modules.Add(($"conv2d-{i}a", Conv2d(in_channels, channels[i], kernelSize: 3, padding: 1))); modules.Add(($"bnrm2d-{i}a", BatchNorm2d(channels[i]))); diff --git a/src/FSharp.Examples/AlexNet.fs b/src/FSharp.Examples/AlexNet.fs index 9a1e3fbe1..5604f0923 100644 --- a/src/FSharp.Examples/AlexNet.fs +++ b/src/FSharp.Examples/AlexNet.fs @@ -49,17 +49,17 @@ type Model(name,device:torch.Device) as this = let features = Sequential(("c1", Conv2d(3L, 64L, kernelSize=3L, stride=2L, padding=1L) :> Module), ("r1", ReLU(inplace=true) :> Module), - ("mp1", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), + ("mp1", MaxPool2d(kernel_size=[|2L; 2L|]) :> Module), ("c2", Conv2d(64L, 192L, kernelSize=3L, padding=1L) :> Module), ("r2", ReLU(inplace=true) :> Module), - ("mp2", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), + ("mp2", MaxPool2d(kernel_size=[|2L; 2L|]) :> Module), ("c3", Conv2d(192L, 384L, kernelSize=3L, padding=1L) :> Module), ("r3", ReLU(inplace=true) :> Module), ("c4", Conv2d(384L, 256L, kernelSize=3L, padding=1L) :> Module), ("r4", ReLU(inplace=true) :> Module), ("c5", Conv2d(256L, 256L, kernelSize=3L, padding=1L) :> Module), ("r5", ReLU(inplace=true) :> Module), - ("mp3", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), + ("mp3", MaxPool2d(kernel_size=[|2L; 2L|]) :> Module), ("avg", AdaptiveAvgPool2d([|2L; 2L|]) :> Module)) let classifier = Sequential(("d1", Dropout() :> Module), diff --git a/src/FSharp.Examples/MNIST.fs b/src/FSharp.Examples/MNIST.fs index d042f6f97..6967ebdc6 100644 --- a/src/FSharp.Examples/MNIST.fs +++ b/src/FSharp.Examples/MNIST.fs @@ -51,7 +51,7 @@ type Model(name,device:torch.Device) as this = let fc1 = Linear(9216L, 128L) let fc2 = Linear(128L, 10L) - let pool1 = MaxPool2d(kernelSize=[|2L; 2L|]) + let pool1 = MaxPool2d(kernel_size=[|2L; 2L|]) let relu = ReLU() diff --git a/src/Native/LibTorchSharp/THSActivation.cpp b/src/Native/LibTorchSharp/THSActivation.cpp index 21b2e14a9..c89beaab6 100644 --- a/src/Native/LibTorchSharp/THSActivation.cpp +++ b/src/Native/LibTorchSharp/THSActivation.cpp @@ -2,330 +2,3 @@ #include "THSNN.h" #include - -NNModule THSNN_CELU_ctor(const double alpha, const bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::CELUOptions().alpha(alpha).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_CELU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ELU_ctor(const double alpha, const bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ELUOptions().alpha(alpha).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ELU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_GELU_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_GELU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_GLU_ctor(const int64_t dim, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::GLUOptions().dim(dim); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_GLU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Hardshrink_ctor(const double lambda, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::HardshrinkOptions(lambda); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Hardshrink_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Hardtanh_ctor(const double min_val, const double max_val, const bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::HardtanhOptions() - .min_val(min_val) - .max_val(max_val) - .inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Hardtanh_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - - -NNModule THSNN_LeakyReLU_ctor(const double negative_sloope, const bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::LeakyReLUOptions().negative_slope(negative_sloope).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_LeakyReLU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_LogSoftmax_ctor(int64_t dim, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::LogSoftmaxOptions(dim); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_LogSoftmax_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Mish_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_Mish_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_PReLU_ctor(const int64_t nparams, const double init, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::PReLUOptions().num_parameters(nparams).init(init); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_PReLU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_PReLU_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_PReLU_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - -NNModule THSNN_ReLU_ctor(bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReLUOptions(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ReLU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_RReLU_ctor(const double lower, const double upper, const bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::RReLUOptions().lower(lower).upper(upper).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_RReLU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ReLU6_ctor(bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReLU6Options(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ReLU6_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_SELU_ctor(bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::SELUOptions(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_SELU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Sigmoid_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_Sigmoid_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_SiLU_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_SiLU_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Softmax2d_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_Softmax2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Softmax_ctor(const int64_t dim, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::SoftmaxOptions(dim); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Softmax_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Softmin_ctor(const int64_t dim, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::SoftminOptions(dim); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Softmin_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Softplus_ctor(const double beta, const double threshold, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::SoftplusOptions().beta(beta).threshold(threshold); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Softplus_forward(const NNModule module, const Tensor tensor) { - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Softshrink_ctor(const double lambda, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::SoftshrinkOptions().lambda(lambda); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Softshrink_forward(const NNModule module, const Tensor tensor) { - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Softsign_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_Softsign_forward(const NNModule module, const Tensor tensor) { - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Tanh_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_Tanh_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Tanhshrink_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_Tanhshrink_forward(const NNModule module, const Tensor tensor) { - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Threshold_ctor(const double threshold, const double value, const bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ThresholdOptions(threshold, value).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Threshold_forward(const NNModule module, const Tensor tensor) { - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - diff --git a/src/Native/LibTorchSharp/THSConvolution.cpp b/src/Native/LibTorchSharp/THSConvolution.cpp index e57602dee..621f8935c 100644 --- a/src/Native/LibTorchSharp/THSConvolution.cpp +++ b/src/Native/LibTorchSharp/THSConvolution.cpp @@ -3,623 +3,6 @@ #include - - -NNModule THSNN_AvgPool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, - bool ceil_mode, bool count_include_pad, int64_t divisor_override, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AvgPool1dOptions(at::ArrayRef(kernelSize, 1)).ceil_mode(ceil_mode).count_include_pad(count_include_pad); - if (stride) - opts = opts.stride(at::ArrayRef(stride, 1)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, 1)); - if (divisor_override > 0) - opts = opts.divisor_override(divisor_override); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AvgPool1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, - bool ceil_mode, bool count_include_pad, int64_t divisor_override, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AvgPool2dOptions(at::ArrayRef(kernelSize, kernelSizeLength)).ceil_mode(ceil_mode).count_include_pad(count_include_pad); - if (stride) - opts = opts.stride(at::ArrayRef(stride, strideLength)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, paddingLength)); - if (divisor_override > 0) - opts = opts.divisor_override(divisor_override); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AvgPool2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AvgPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, - bool ceil_mode, bool count_include_pad, int64_t divisor_override, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AvgPool3dOptions(at::ArrayRef(kernelSize, kernelSizeLength)).ceil_mode(ceil_mode).count_include_pad(count_include_pad); - if (stride) - opts = opts.stride(at::ArrayRef(stride, strideLength)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, paddingLength)); - if (divisor_override > 0) - opts = opts.divisor_override(divisor_override); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AvgPool3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AdaptiveAvgPool1d_ctor(const int64_t* kernelSize, const int kernelSizeLength, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AdaptiveAvgPool1dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AdaptiveAvgPool1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AdaptiveAvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AdaptiveAvgPool2dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AdaptiveAvgPool2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AdaptiveAvgPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AdaptiveAvgPool3dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AdaptiveAvgPool3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AdaptiveMaxPool1d_ctor(const int64_t* kernelSize, const int kernelSizeLength, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AdaptiveMaxPool1dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AdaptiveMaxPool1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AdaptiveMaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AdaptiveMaxPool2dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AdaptiveMaxPool2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AdaptiveMaxPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AdaptiveMaxPool3dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AdaptiveMaxPool3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_LPPool1d_ctor(double norm_type, const int64_t* kernelSize, const int64_t* stride, const bool ceil_mode, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::LPPool1dOptions(norm_type, at::ArrayRef(kernelSize, 1)).ceil_mode(ceil_mode); - if (stride) - opts = opts.stride(at::ArrayRef(stride, 1)); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_LPPool1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_LPPool2d_ctor(double norm_type, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const bool ceil_mode, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::LPPool2dOptions(norm_type, at::ArrayRef(kernelSize, kernelSizeLength)).ceil_mode(ceil_mode); - if (stride) - opts = opts.stride(at::ArrayRef(stride, strideLength)); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_LPPool2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_MaxPool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, const int64_t* dilation, bool ceil_mode, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::MaxPool1dOptions(at::ArrayRef(kernelSize, 1)).ceil_mode(ceil_mode); - if (stride) - opts = opts.stride(at::ArrayRef(stride, 1)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, 1)); - if (dilation) - opts = opts.dilation(at::ArrayRef(dilation, 1)); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_MaxPool1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_MaxPool1d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices) -{ - std::tuple res; - CATCH(res = (*module)->as()->forward_with_indices(*tensor);); - *indices = ResultTensor(std::get<1>(res)); - return ResultTensor(std::get<0>(res)); -} - -NNModule THSNN_MaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, - const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, bool ceil_mode, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::MaxPool2dOptions(at::ArrayRef(kernelSize, kernelSizeLength)).ceil_mode(ceil_mode); - if (stride) - opts = opts.stride(at::ArrayRef(stride, strideLength)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, paddingLength)); - if (dilation) - opts = opts.dilation(at::ArrayRef(dilation, dilationLength)); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_MaxPool2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_MaxPool2d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices) -{ - std::tuple res; - CATCH(res = (*module)->as()->forward_with_indices(*tensor);); - *indices = ResultTensor(std::get<1>(res)); - return ResultTensor(std::get<0>(res)); -} - -NNModule THSNN_MaxPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, - const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, bool ceil_mode, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::MaxPool3dOptions(at::ArrayRef(kernelSize, kernelSizeLength)).ceil_mode(ceil_mode); - if (stride) - opts = opts.stride(at::ArrayRef(stride, strideLength)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, paddingLength)); - if (dilation) - opts = opts.dilation(at::ArrayRef(dilation, dilationLength)); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_MaxPool3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_MaxPool3d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices) -{ - std::tuple res; - CATCH(res = (*module)->as()->forward_with_indices(*tensor);); - *indices = ResultTensor(std::get<1>(res)); - return ResultTensor(std::get<0>(res)); -} - -NNModule THSNN_MaxUnpool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::MaxUnpool1dOptions(at::ArrayRef(kernelSize, 1)); - if (stride) - opts = opts.stride(at::ArrayRef(stride, 1)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, 1)); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_MaxUnpool1d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize) -{ - if (outputSize != nullptr) { - std::vector outSize; - outSize.push_back(*outputSize); - - CATCH_TENSOR((*module)->as()->forward(*tensor, *indices, outSize)); - } - else { - CATCH_TENSOR((*module)->as()->forward(*tensor, *indices)); - } -} - -NNModule THSNN_MaxUnpool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::MaxUnpool2dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - if (stride) - opts = opts.stride(at::ArrayRef(stride, strideLength)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, paddingLength)); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_MaxUnpool2d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize, const int outputSizeLength) -{ - if (outputSize != nullptr) { - std::vector outSize; - for (auto i = 0L; i < outputSizeLength; i++) { - outSize.push_back(outputSize[i]); - } - - CATCH_TENSOR((*module)->as()->forward(*tensor, *indices, outSize)); - } - else { - CATCH_TENSOR((*module)->as()->forward(*tensor, *indices)); - } -} - -NNModule THSNN_MaxUnpool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::MaxUnpool3dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - if (stride) - opts = opts.stride(at::ArrayRef(stride, strideLength)); - if (padding) - opts = opts.padding(at::ArrayRef(padding, paddingLength)); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_MaxUnpool3d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize, const int outputSizeLength) -{ - if (outputSize != nullptr) { - std::vector outSize; - for (auto i = 0L; i < outputSizeLength; i++) { - outSize.push_back(outputSize[i]); - } - - CATCH_TENSOR((*module)->as()->forward(*tensor, *indices, outSize)); - } - else { - CATCH_TENSOR((*module)->as()->forward(*tensor, *indices)); - } -} - - -NNModule THSNN_FractionalMaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::FractionalMaxPool2dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - if (outputSize) - opts = opts.output_size(at::ArrayRef(outputSize, outputSizeLength)); - if (outputRatio) - opts = opts.output_ratio(at::ArrayRef(outputRatio, outputRatioLength)); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_FractionalMaxPool2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_FractionalMaxPool2d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices) -{ - std::tuple res; - CATCH(res = (*module)->as()->forward_with_indices(*tensor);); - *indices = ResultTensor(std::get<1>(res)); - return ResultTensor(std::get<0>(res)); -} - -NNModule THSNN_FractionalMaxPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::FractionalMaxPool3dOptions(at::ArrayRef(kernelSize, kernelSizeLength)); - if (outputSize) - opts = opts.output_size(at::ArrayRef(outputSize, outputSizeLength)); - if (outputRatio) - opts = opts.output_ratio(at::ArrayRef(outputRatio, outputRatioLength)); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_FractionalMaxPool3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_FractionalMaxPool3d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices) -{ - std::tuple res; - CATCH(res = (*module)->as()->forward_with_indices(*tensor);); - *indices = ResultTensor(std::get<1>(res)); - return ResultTensor(std::get<0>(res)); -} - -NNModule THSNN_ZeroPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ZeroPad2dOptions(padding); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ZeroPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ZeroPad2dOptions({ padding_left, padding_right, padding_top, padding_bottom }); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ZeroPad2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ConstantPad1d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ConstantPad1dOptions(padding, value); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ConstantPad1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ConstantPad2d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ConstantPad2dOptions(padding, value); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ConstantPad2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ConstantPad3d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ConstantPad3dOptions(padding, value); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ConstantPad3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ConstantPad1d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ConstantPad1dOptions({ padding_left, padding_right }, value); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ConstantPad2d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ConstantPad2dOptions({ padding_left, padding_right, padding_top, padding_bottom }, value); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ConstantPad3d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ConstantPad3dOptions({ padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back }, value); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ReplicationPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReplicationPad1dOptions(padding); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ReplicationPad1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ReplicationPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReplicationPad2dOptions(padding); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ReplicationPad2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ReplicationPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReplicationPad3dOptions(padding); - res = create_module(opts, outAsAnyModule); - ); -} - - -Tensor THSNN_ReplicationPad3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ReplicationPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReplicationPad1dOptions({ padding_left, padding_right }); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ReplicationPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReplicationPad2dOptions({ padding_left, padding_right, padding_top, padding_bottom }); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ReplicationPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReplicationPad3dOptions({ padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back }); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ReflectionPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReflectionPad1dOptions(padding); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ReflectionPad1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ReflectionPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReflectionPad2dOptions(padding); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ReflectionPad2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ReflectionPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReflectionPad3dOptions(padding); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_ReflectionPad3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_ReflectionPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReflectionPad1dOptions({ padding_left, padding_right }); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ReflectionPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReflectionPad2dOptions({ padding_left, padding_right, padding_top, padding_bottom }); - res = create_module(opts, outAsAnyModule); - ); -} - -NNModule THSNN_ReflectionPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::ReflectionPad3dOptions({ padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back }); - res = create_module(opts, outAsAnyModule); - ); -} - - template void ApplyPaddingMode(T& opts, const int64_t padding) { diff --git a/src/Native/LibTorchSharp/THSNN.cpp b/src/Native/LibTorchSharp/THSNN.cpp index 12b6a461a..d1e6297e0 100644 --- a/src/Native/LibTorchSharp/THSNN.cpp +++ b/src/Native/LibTorchSharp/THSNN.cpp @@ -3,53 +3,6 @@ #include - -NNModule THSNN_Identity_ctor(NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - res = create_module(outAsAnyModule); - ); -} - -Tensor THSNN_Identity_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size, const bool bias, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::LinearOptions(input_size, output_size).bias(bias); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Linear_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_Linear_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_Linear_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_Linear_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_Linear_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - Tensor THSNN_functional_linear(const Tensor input, const Tensor weights, const Tensor bias) { CATCH_TENSOR(bias == nullptr ? @@ -64,40 +17,6 @@ Tensor THSNN_functional_bilinear(const Tensor input1, const Tensor input2, const torch::nn::functional::bilinear(*input1, *input2, *weights, *bias)); } -NNModule THSNN_Bilinear_ctor(const int64_t input_size_1, const int64_t input_size_2, const int64_t output_size, const bool bias, - NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::BilinearOptions(input_size_1, input_size_2, output_size).bias(bias); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Bilinear_forward(const NNModule module, const Tensor x1, const Tensor x2) -{ - CATCH_TENSOR((*module)->as()->forward(*x1, *x2)); -} - -Tensor THSNN_Bilinear_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_Bilinear_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_Bilinear_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_Bilinear_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - Tensor THSNN_dropout(const Tensor input, const double p, bool training, bool inplace) { auto opts = torch::nn::functional::DropoutFuncOptions() @@ -148,113 +67,16 @@ Tensor THSNN_feature_alpha_dropout(const Tensor input, const double p, bool trai CATCH_TENSOR(torch::nn::functional::feature_alpha_dropout(*input, opts)); } -NNModule THSNN_Dropout_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::DropoutOptions(probability).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Dropout_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_AlphaDropout_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::AlphaDropoutOptions(probability).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_AlphaDropout_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Dropout1d_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - // Creating a Dropout2d instance here is done on purpose. There's no torch::nn::Dropout1d - auto opts = torch::nn::Dropout2dOptions(probability).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Dropout1d_forward(const NNModule module, const Tensor tensor) -{ - auto drop1d = (*module)->as(); - CATCH_TENSOR(drop1d->options.inplace() - ? drop1d->forward((*tensor).unsqueeze_(-1)).squeeze_(-1) - : drop1d->forward((*tensor).unsqueeze(-1)).squeeze(-1)); -} - - -NNModule THSNN_Dropout2d_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::Dropout2dOptions(probability).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Dropout2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Dropout3d_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::Dropout3dOptions(probability).inplace(inplace); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Dropout3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_FeatureAlphaDropout_ctor(double probability, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::FeatureAlphaDropoutOptions(probability); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_FeatureAlphaDropout_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_PixelShuffle_ctor(const int64_t upscale_factor, NNAnyModule* outAsAnyModule) +Tensor THSNN_pixel_shuffle(const Tensor tensor, const int64_t upscale_factor) { - CATCH_RETURN_NNModule( - auto opts = torch::nn::PixelShuffleOptions(upscale_factor); - res = create_module(opts, outAsAnyModule); - ); + auto opts = torch::nn::functional::PixelShuffleFuncOptions(upscale_factor); + CATCH_TENSOR(torch::nn::functional::pixel_shuffle(*tensor, opts)); } -Tensor THSNN_PixelShuffle_forward(const NNModule module, const Tensor tensor) +Tensor THSNN_pixel_unshuffle(const Tensor tensor, const int64_t downscale_factor) { - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_PixelUnshuffle_ctor(const int64_t downscale_factor, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::PixelUnshuffleOptions(downscale_factor); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_PixelUnshuffle_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); + auto opts = torch::nn::functional::PixelUnshuffleFuncOptions(downscale_factor); + CATCH_TENSOR(torch::nn::functional::pixel_unshuffle(*tensor, opts)); } template @@ -289,38 +111,6 @@ void ApplyInterpolateMode(T& opts, const int8_t mode) opts = opts.mode(torch::kArea); } -NNModule THSNN_Upsample_ctor(const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, NNAnyModule* outAsAnyModule) -{ - auto opts = torch::nn::UpsampleOptions(); - // align_corners -- 0=None, 1=true, 2=false - if (align_corners != 0) - opts.align_corners(align_corners == 1); - ApplyUpsampleMode(opts, mode); - - CATCH_RETURN_NNModule( - if (size_len > 0) { - std::vector sizes; - for (int i = 0; i < size_len; ++i) { - sizes.push_back(size[i]); - } - opts.size(sizes); - } - if (scale_factor_len > 0) { - std::vector scales; - for (int i = 0; i < scale_factor_len; ++i) { - scales.push_back(scale_factor[i]); - } - opts.scale_factor(scales); - } - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Upsample_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - template void ApplyPadMode(T& opts, const int64_t padding) { @@ -743,54 +533,9 @@ Tensor THSNN_TransformerDecoder_forward(const NNModule module, const Tensor tg ); } -NNModule THSNN_Flatten_ctor(const int64_t start_dim, const int64_t end_dim, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::FlattenOptions() - .start_dim(start_dim) - .end_dim(end_dim); - - res = create_module(opts, outAsAnyModule); - ); -} -Tensor THSNN_Flatten_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_Unflatten_ctor(const int64_t dim, const int64_t* shape, const int64_t shape_len, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - std::vector sizes; - for (int64_t i = 0; i < shape_len; ++i) - { - sizes.push_back(shape[i]); - } - auto opts = torch::nn::UnflattenOptions(dim, sizes); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_Unflatten_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_CosineSimilarity_ctor(const int64_t dim, double eps, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::CosineSimilarityOptions() - .dim(dim) - .eps(eps); - - res = create_module(opts, outAsAnyModule); - ); - -} - -Tensor THSNN_CosineSimilarity_forward(const NNModule module, const Tensor input1, const Tensor input2) +Tensor THSNN_cosine_similarity(const Tensor input1, const Tensor input2, int64_t dim, double eps) { - CATCH_TENSOR((*module)->as()->forward(*input1, *input2)); + CATCH_TENSOR(torch::nn::functional::cosine_similarity(*input1, *input2, torch::nn::functional::CosineSimilarityFuncOptions().dim(dim).eps(eps))); } NNModule THSNN_PairwiseDistance_ctor(double p, double eps, bool keep_dim, NNAnyModule* outAsAnyModule) diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h index 78b39a3a4..5003874ac 100644 --- a/src/Native/LibTorchSharp/THSNN.h +++ b/src/Native/LibTorchSharp/THSNN.h @@ -37,99 +37,6 @@ EXPORT_API(void) THSNN_AnyModule_dispose(const NNAnyModule module); EXPORT_API(NNModule) THSNN_custom_module(const char* name, Tensor(*forward)(Tensor), NNAnyModule* outAsAnyModule); -// Pooling - -EXPORT_API(NNModule) THSNN_MaxPool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, const int64_t* dilation, bool ceil_mode, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_MaxPool1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_MaxPool1d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor *indices); - -EXPORT_API(NNModule) THSNN_MaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, bool ceil_mode, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_MaxPool2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_MaxPool2d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices); - -EXPORT_API(NNModule) THSNN_MaxPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, bool ceil_mode, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_MaxPool3d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_MaxPool3d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices); - -EXPORT_API(NNModule) THSNN_FractionalMaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_FractionalMaxPool2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_FractionalMaxPool2d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices); - -EXPORT_API(NNModule) THSNN_FractionalMaxPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_FractionalMaxPool3d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_FractionalMaxPool3d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices); - -EXPORT_API(NNModule) THSNN_MaxUnpool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_MaxUnpool1d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize); - -EXPORT_API(NNModule) THSNN_MaxUnpool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_MaxUnpool2d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize, const int outputSizeLength); - -EXPORT_API(NNModule) THSNN_MaxUnpool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_MaxUnpool3d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize, const int outputSizeLength); - -EXPORT_API(NNModule) THSNN_AdaptiveAvgPool1d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AdaptiveAvgPool1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_AdaptiveAvgPool2d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AdaptiveAvgPool2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_AdaptiveAvgPool3d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AdaptiveAvgPool3d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_AdaptiveMaxPool1d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AdaptiveMaxPool1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_AdaptiveMaxPool2d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AdaptiveMaxPool2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_AdaptiveMaxPool3d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AdaptiveMaxPool3d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_AvgPool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, bool ceil_mode, bool count_include_pad, int64_t divisor_override, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AvgPool1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_AvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, bool ceil_mode, bool count_include_pad, int64_t divisor_override, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AvgPool2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_AvgPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, bool ceil_mode, bool count_include_pad, int64_t divisor_override, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AvgPool3d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_LPPool1d_ctor(double norm_type, const int64_t* kernelSize, const int64_t* stride, bool ceil_mode, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_LPPool1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_LPPool2d_ctor(double norm_type, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, bool ceil_mode, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_LPPool2d_forward(const NNModule module, const Tensor tensor); - -// Padding - -EXPORT_API(NNModule) THSNN_ZeroPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ZeroPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ZeroPad2d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_ConstantPad1d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ConstantPad1d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ConstantPad1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_ConstantPad2d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ConstantPad2d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ConstantPad2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_ConstantPad3d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ConstantPad3d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ConstantPad3d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_ReplicationPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ReplicationPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ReplicationPad1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_ReplicationPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ReplicationPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ReplicationPad2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_ReplicationPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ReplicationPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ReplicationPad3d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_ReflectionPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ReflectionPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ReflectionPad1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_ReflectionPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ReflectionPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ReflectionPad2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_ReflectionPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule); -EXPORT_API(NNModule) THSNN_ReflectionPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ReflectionPad3d_forward(const NNModule module, const Tensor tensor); - // Convolution EXPORT_API(NNModule) THSNN_Conv1d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule); @@ -176,119 +83,6 @@ EXPORT_API(void) THSNN_ConvTranspose3d_set_bias(const NNModule module, const // Normalization -EXPORT_API(NNModule) THSNN_BatchNorm1d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_BatchNorm1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_BatchNorm2d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_BatchNorm2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_BatchNorm3d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_BatchNorm3d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(Tensor) THSNN_BatchNorm1d_bias(const NNModule module); -EXPORT_API(void) THSNN_BatchNorm1d_set_bias(const NNModule module, const Tensor bias); -EXPORT_API(Tensor) THSNN_BatchNorm1d_weight(const NNModule module); -EXPORT_API(void) THSNN_BatchNorm1d_set_weight(const NNModule module, const Tensor weight); - -EXPORT_API(Tensor) THSNN_BatchNorm2d_bias(const NNModule module); -EXPORT_API(void) THSNN_BatchNorm2d_set_bias(const NNModule module, const Tensor bias); -EXPORT_API(Tensor) THSNN_BatchNorm2d_weight(const NNModule module); -EXPORT_API(void) THSNN_BatchNorm2d_set_weight(const NNModule module, const Tensor weight); - -EXPORT_API(Tensor) THSNN_BatchNorm3d_bias(const NNModule module); -EXPORT_API(void) THSNN_BatchNorm3d_set_bias(const NNModule module, const Tensor bias); -EXPORT_API(Tensor) THSNN_BatchNorm3d_weight(const NNModule module); -EXPORT_API(void) THSNN_BatchNorm3d_set_weight(const NNModule module, const Tensor weight); - -EXPORT_API(void) THSNN_BatchNorm1d_reset_stats(const NNModule module); -EXPORT_API(void) THSNN_BatchNorm2d_reset_stats(const NNModule module); -EXPORT_API(void) THSNN_BatchNorm3d_reset_stats(const NNModule module); - -EXPORT_API(Tensor) THSNN_BatchNorm1d_get_mean(const NNModule module); -EXPORT_API(Tensor) THSNN_BatchNorm2d_get_mean(const NNModule module); -EXPORT_API(Tensor) THSNN_BatchNorm3d_get_mean(const NNModule module); - -EXPORT_API(void) THSNN_BatchNorm1d_set_mean(const NNModule module, const Tensor weight); -EXPORT_API(void) THSNN_BatchNorm2d_set_mean(const NNModule module, const Tensor weight); -EXPORT_API(void) THSNN_BatchNorm3d_set_mean(const NNModule module, const Tensor weight); - -EXPORT_API(Tensor) THSNN_BatchNorm1d_get_var(const NNModule module); -EXPORT_API(Tensor) THSNN_BatchNorm2d_get_var(const NNModule module); -EXPORT_API(Tensor) THSNN_BatchNorm3d_get_var(const NNModule module); - -EXPORT_API(void) THSNN_BatchNorm1d_set_var(const NNModule module, const Tensor weight); -EXPORT_API(void) THSNN_BatchNorm2d_set_var(const NNModule module, const Tensor weight); -EXPORT_API(void) THSNN_BatchNorm3d_set_var(const NNModule module, const Tensor weight); - -EXPORT_API(Tensor) THSNN_BatchNorm1d_get_batches(const NNModule module); -EXPORT_API(Tensor) THSNN_BatchNorm2d_get_batches(const NNModule module); -EXPORT_API(Tensor) THSNN_BatchNorm3d_get_batches(const NNModule module); - -EXPORT_API(NNModule) THSNN_InstanceNorm1d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_InstanceNorm1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_InstanceNorm2d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_InstanceNorm2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_InstanceNorm3d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_InstanceNorm3d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(Tensor) THSNN_InstanceNorm1d_bias(const NNModule module); -EXPORT_API(void) THSNN_InstanceNorm1d_set_bias(const NNModule module, const Tensor bias); -EXPORT_API(Tensor) THSNN_InstanceNorm1d_weight(const NNModule module); -EXPORT_API(void) THSNN_InstanceNorm1d_set_weight(const NNModule module, const Tensor weight); - -EXPORT_API(Tensor) THSNN_InstanceNorm2d_bias(const NNModule module); -EXPORT_API(void) THSNN_InstanceNorm2d_set_bias(const NNModule module, const Tensor bias); -EXPORT_API(Tensor) THSNN_InstanceNorm2d_weight(const NNModule module); -EXPORT_API(void) THSNN_InstanceNorm2d_set_weight(const NNModule module, const Tensor weight); - -EXPORT_API(Tensor) THSNN_InstanceNorm3d_bias(const NNModule module); -EXPORT_API(void) THSNN_InstanceNorm3d_set_bias(const NNModule module, const Tensor bias); -EXPORT_API(Tensor) THSNN_InstanceNorm3d_weight(const NNModule module); -EXPORT_API(void) THSNN_InstanceNorm3d_set_weight(const NNModule module, const Tensor weight); - -EXPORT_API(void) THSNN_InstanceNorm1d_reset_stats(const NNModule module); -EXPORT_API(void) THSNN_InstanceNorm2d_reset_stats(const NNModule module); -EXPORT_API(void) THSNN_InstanceNorm3d_reset_stats(const NNModule module); - -EXPORT_API(Tensor) THSNN_InstanceNorm1d_get_mean(const NNModule module); -EXPORT_API(Tensor) THSNN_InstanceNorm2d_get_mean(const NNModule module); -EXPORT_API(Tensor) THSNN_InstanceNorm3d_get_mean(const NNModule module); - -EXPORT_API(void) THSNN_InstanceNorm1d_set_mean(const NNModule module, const Tensor weight); -EXPORT_API(void) THSNN_InstanceNorm2d_set_mean(const NNModule module, const Tensor weight); -EXPORT_API(void) THSNN_InstanceNorm3d_set_mean(const NNModule module, const Tensor weight); - -EXPORT_API(Tensor) THSNN_InstanceNorm1d_get_var(const NNModule module); -EXPORT_API(Tensor) THSNN_InstanceNorm2d_get_var(const NNModule module); -EXPORT_API(Tensor) THSNN_InstanceNorm3d_get_var(const NNModule module); - -EXPORT_API(void) THSNN_InstanceNorm1d_set_var(const NNModule module, const Tensor weight); -EXPORT_API(void) THSNN_InstanceNorm2d_set_var(const NNModule module, const Tensor weight); -EXPORT_API(void) THSNN_InstanceNorm3d_set_var(const NNModule module, const Tensor weight); - -EXPORT_API(Tensor) THSNN_InstanceNorm1d_get_batches(const NNModule module); -EXPORT_API(Tensor) THSNN_InstanceNorm2d_get_batches(const NNModule module); -EXPORT_API(Tensor) THSNN_InstanceNorm3d_get_batches(const NNModule module); - - - -EXPORT_API(NNModule) THSNN_LayerNorm_ctor(const int64_t* norm_shape, const int64_t norm_shape_len, const double eps, const bool elementwise_affine, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_LayerNorm_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(Tensor) THSNN_LayerNorm_bias(const NNModule module); -EXPORT_API(void) THSNN_LayerNorm_set_bias(const NNModule module, const Tensor bias); -EXPORT_API(Tensor) THSNN_LayerNorm_weight(const NNModule module); -EXPORT_API(void) THSNN_LayerNorm_set_weight(const NNModule module, const Tensor weight); - -EXPORT_API(NNModule) THSNN_GroupNorm_ctor(const int64_t num_groups, const int64_t num_channels, const double eps, const bool affine, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_GroupNorm_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(Tensor) THSNN_GroupNorm_bias(const NNModule module); -EXPORT_API(void) THSNN_GroupNorm_set_bias(const NNModule module, const Tensor bias); -EXPORT_API(Tensor) THSNN_GroupNorm_weight(const NNModule module); -EXPORT_API(void) THSNN_GroupNorm_set_weight(const NNModule module, const Tensor weight); - -EXPORT_API(NNModule) THSNN_LocalResponseNorm_ctor(const int64_t size, const double alpha, const double beta, const double k, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_LocalResponseNorm_forward(const NNModule module, const Tensor tensor); - EXPORT_API(Tensor) THSNN_batch_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps); EXPORT_API(Tensor) THSNN_group_norm(const Tensor input, int64_t num_groups, const Tensor weight, const Tensor bias, const double eps); EXPORT_API(Tensor) THSNN_instance_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool use_input_stats, const double momentum, const double eps); @@ -296,22 +90,6 @@ EXPORT_API(Tensor) THSNN_layer_norm(const Tensor input, const int64_t* normali EXPORT_API(Tensor) THSNN_local_response_norm(const Tensor input, const int64_t size, const double alpha, const double beta, const double k); // Dropout - -EXPORT_API(NNModule) THSNN_Dropout_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Dropout_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Dropout1d_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Dropout1d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Dropout2d_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Dropout2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Dropout3d_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Dropout3d_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_AlphaDropout_ctor(double probability, bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_AlphaDropout_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_FeatureAlphaDropout_ctor(double probability, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_FeatureAlphaDropout_forward(const NNModule module, const Tensor tensor); - EXPORT_API(Tensor) THSNN_dropout(const Tensor input, const double p, bool training, bool inplace); EXPORT_API(Tensor) THSNN_dropout2d(const Tensor input, const double p, bool training, bool inplace); EXPORT_API(Tensor) THSNN_dropout3d(const Tensor input, const double p, bool training, bool inplace); @@ -325,33 +103,13 @@ EXPORT_API(Tensor) THSNN_unfold(const Tensor input, const int64_t kernel1, const // Linear -EXPORT_API(NNModule) THSNN_Identity_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Identity_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size, const bool with_bias, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Linear_forward(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_Linear_bias(const NNModule module); -EXPORT_API(void) THSNN_Linear_set_bias(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_Linear_weight(const NNModule module); -EXPORT_API(void) THSNN_Linear_set_weight(const NNModule module, const Tensor tensor); - EXPORT_API(Tensor) THSNN_functional_linear(const Tensor input, const Tensor weights, const Tensor bias); EXPORT_API(Tensor) THSNN_functional_bilinear(const Tensor input1, const Tensor input2, const Tensor weights, const Tensor bias); -EXPORT_API(NNModule) THSNN_Bilinear_ctor(const int64_t input_size_1, const int64_t input_size_2, const int64_t output_size, const bool with_bias, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Bilinear_forward(const NNModule module, const Tensor x1, const Tensor x2); -EXPORT_API(Tensor) THSNN_Bilinear_bias(const NNModule module); -EXPORT_API(void) THSNN_Bilinear_set_bias(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_Bilinear_weight(const NNModule module); -EXPORT_API(void) THSNN_Bilinear_set_weight(const NNModule module, const Tensor tensor); - // Vision -- Modules -EXPORT_API(NNModule) THSNN_PixelShuffle_ctor(const int64_t upscale_factor, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_PixelShuffle_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_PixelUnshuffle_ctor(const int64_t downscale_factor, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_PixelUnshuffle_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Upsample_ctor(const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Upsample_forward(const NNModule module, const Tensor tensor); +EXPORT_API(Tensor) THSNN_pixel_shuffle(const Tensor tensor, const int64_t upscale_factor); +EXPORT_API(Tensor) THSNN_pixel_unshuffle(const Tensor tensor, const int64_t downscale_fasctor); // Vision -- Functions @@ -360,61 +118,6 @@ EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, co EXPORT_API(Tensor) THSNN_grid_sample(const Tensor input, const Tensor grid, const int8_t mode, const int8_t padding_mode, const int8_t align_corners); EXPORT_API(Tensor) THSNN_affine_grid(const Tensor theta, const int64_t* size, const int size_len, const bool align_corners); -// Activation functions - -EXPORT_API(NNModule) THSNN_CELU_ctor(const double alpha, const bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_CELU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_ELU_ctor(const double alpha, const bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ELU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_GELU_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_GELU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_GLU_ctor(const int64_t dim, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_GLU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Hardshrink_ctor(const double lambda, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Hardshrink_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Hardtanh_ctor(const double min_val, const double max_val, const bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Hardtanh_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_LeakyReLU_ctor(const double negative_sloope, const bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_LeakyReLU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Mish_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Mish_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_PReLU_ctor(const int64_t nparams, const double init, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_PReLU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(Tensor) THSNN_PReLU_weight(const NNModule module); -EXPORT_API(void) THSNN_PReLU_set_weight(const NNModule module, const Tensor weight); -EXPORT_API(NNModule) THSNN_ReLU_ctor(bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ReLU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_ReLU6_ctor(bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_ReLU6_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_RReLU_ctor(const double lower, const double upper, const bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_RReLU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_LogSoftmax_ctor(int64_t dim, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_LogSoftmax_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_SELU_ctor(bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_SELU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Sigmoid_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Sigmoid_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_SiLU_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_SiLU_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Softmax_ctor(const int64_t dim, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Softmax_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Softmax2d_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Softmax2d_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Softmin_ctor(const int64_t dim, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Softmin_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Softplus_ctor(const double beta, const double threshold, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Softplus_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Softshrink_ctor(const double lambda, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Softshrink_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Softsign_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Softsign_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Tanh_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Tanh_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Tanhshrink_ctor(NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Tanhshrink_forward(const NNModule module, const Tensor tensor); -EXPORT_API(NNModule) THSNN_Threshold_ctor(const double threshold, const double value, const bool inplace, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Threshold_forward(const NNModule module, const Tensor tensor); - // Sparse EXPORT_API(NNModule) THSNN_Embedding_ctor(const int64_t num_embeddings, const int64_t embedding_dims, const int64_t padding_idx, bool has_pi, const double max_norm, const bool has_mn, const double norm_type, const bool scale_grad_by_freq, const bool sparse, NNAnyModule* outAsAnyModule); @@ -564,14 +267,7 @@ EXPORT_API(void) THSNN_SGD_set_lr(const Optimizer optimizer, const double lr); EXPORT_API(Tensor) THSNN_one_hot(const Tensor self, const int64_t num_classes); -EXPORT_API(NNModule) THSNN_Flatten_ctor(const int64_t start_dim, const int64_t end_dim, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Flatten_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_Unflatten_ctor(const int64_t dim, const int64_t* shape, const int64_t shape_len, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_Unflatten_forward(const NNModule module, const Tensor tensor); - -EXPORT_API(NNModule) THSNN_CosineSimilarity_ctor(const int64_t dim, double eps, NNAnyModule* outAsAnyModule); -EXPORT_API(Tensor) THSNN_CosineSimilarity_forward(const NNModule module, const Tensor input1, const Tensor input2); +EXPORT_API(Tensor) THSNN_cosine_similarity(const Tensor input1, const Tensor input2, int64_t dim, double eps); EXPORT_API(NNModule) THSNN_PairwiseDistance_ctor(double p, double eps, bool keep_dim, NNAnyModule* outAsAnyModule); EXPORT_API(Tensor) THSNN_PairwiseDistance_forward(const NNModule module, const Tensor input1, const Tensor input2); diff --git a/src/Native/LibTorchSharp/THSNormalization.cpp b/src/Native/LibTorchSharp/THSNormalization.cpp index 6d0e6d97e..c94db9896 100644 --- a/src/Native/LibTorchSharp/THSNormalization.cpp +++ b/src/Native/LibTorchSharp/THSNormalization.cpp @@ -3,608 +3,15 @@ #include -NNModule THSNN_BatchNorm1d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::BatchNorm1dOptions(features) - .eps(eps) - .momentum(momentum) - .affine(affine) - .track_running_stats(track_running_stats); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_BatchNorm1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_BatchNorm2d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::BatchNorm2dOptions(features) - .eps(eps) - .momentum(momentum) - .affine(affine) - .track_running_stats(track_running_stats); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_BatchNorm2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_BatchNorm3d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::BatchNorm3dOptions(features) - .eps(eps) - .momentum(momentum) - .affine(affine) - .track_running_stats(track_running_stats); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_BatchNorm3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_GroupNorm_ctor(const int64_t num_groups, const int64_t num_channels, const double eps, const bool affine, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::GroupNormOptions(num_groups, num_channels).eps(eps).affine(affine); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_GroupNorm_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_GroupNorm_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_GroupNorm_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_GroupNorm_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_GroupNorm_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - - -NNModule THSNN_InstanceNorm1d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::InstanceNorm1dOptions(features) - .eps(eps) - .momentum(momentum) - .affine(affine) - .track_running_stats(track_running_stats); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_InstanceNorm1d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_InstanceNorm2d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::InstanceNorm2dOptions(features) - .eps(eps) - .momentum(momentum) - .affine(affine) - .track_running_stats(track_running_stats); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_InstanceNorm2d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_InstanceNorm3d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::InstanceNorm3dOptions(features) - .eps(eps) - .momentum(momentum) - .affine(affine) - .track_running_stats(track_running_stats); - - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_InstanceNorm3d_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -NNModule THSNN_LayerNorm_ctor(const int64_t* norm_shape, const int64_t norm_shape_len, const double eps, const bool elementwise_affine, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - std::vector normalized_shape; - for (int64_t i = 0; i < norm_shape_len; ++i) - { - normalized_shape.push_back(norm_shape[i]); - } - auto opts = torch::nn::LayerNormOptions(normalized_shape).eps(eps).elementwise_affine(elementwise_affine); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_LayerNorm_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -Tensor THSNN_LayerNorm_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_LayerNorm_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_LayerNorm_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_LayerNorm_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - - -NNModule THSNN_LocalResponseNorm_ctor(const int64_t size, const double alpha, const double beta, const double k, NNAnyModule* outAsAnyModule) -{ - CATCH_RETURN_NNModule( - auto opts = torch::nn::LocalResponseNormOptions(size) - .alpha(alpha) - .beta(beta) - .k(k); - res = create_module(opts, outAsAnyModule); - ); -} - -Tensor THSNN_LocalResponseNorm_forward(const NNModule module, const Tensor tensor) -{ - CATCH_TENSOR((*module)->as()->forward(*tensor)); -} - -void THSNN_BatchNorm1d_reset_stats(const NNModule module) -{ - CATCH((*module)->as()->reset_running_stats();); -} - -Tensor THSNN_BatchNorm1d_get_mean(const NNModule module) -{ - CATCH( - auto m = (*module)->as()->running_mean; - return m.defined() ? ResultTensor(m) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_BatchNorm1d_get_var(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->running_var; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_BatchNorm1d_get_batches(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->num_batches_tracked; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -void THSNN_BatchNorm1d_set_mean(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_mean = *bias; - ); -} - -void THSNN_BatchNorm1d_set_var(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_var = *bias; - ); -} - -Tensor THSNN_BatchNorm1d_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_BatchNorm1d_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_BatchNorm1d_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_BatchNorm1d_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - -void THSNN_BatchNorm2d_reset_stats(const NNModule module) -{ - CATCH((*module)->as()->reset_running_stats();); -} - -Tensor THSNN_BatchNorm2d_get_mean(const NNModule module) -{ - CATCH( - auto m = (*module)->as()->running_mean; - return m.defined() ? ResultTensor(m) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_BatchNorm2d_get_var(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->running_var; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_BatchNorm2d_get_batches(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->num_batches_tracked; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -void THSNN_BatchNorm2d_set_mean(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_mean = *bias; - ); -} - -void THSNN_BatchNorm2d_set_var(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_var = *bias; - ); -} - -Tensor THSNN_BatchNorm2d_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_BatchNorm2d_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_BatchNorm2d_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_BatchNorm2d_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - -void THSNN_BatchNorm3d_reset_stats(const NNModule module) -{ - CATCH((*module)->as()->reset_running_stats();); -} - -Tensor THSNN_BatchNorm3d_get_mean(const NNModule module) -{ - CATCH( - auto m = (*module)->as()->running_mean; - return m.defined() ? ResultTensor(m) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_BatchNorm3d_get_var(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->running_var; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_BatchNorm3d_get_batches(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->num_batches_tracked; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -void THSNN_BatchNorm3d_set_mean(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_mean = *bias; - ); -} - -void THSNN_BatchNorm3d_set_var(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_var = *bias; - ); -} - -Tensor THSNN_BatchNorm3d_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_BatchNorm3d_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_BatchNorm3d_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_BatchNorm3d_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - -void THSNN_InstanceNorm1d_reset_stats(const NNModule module) -{ - CATCH((*module)->as()->reset_running_stats();); -} - -Tensor THSNN_InstanceNorm1d_get_mean(const NNModule module) -{ - CATCH( - auto m = (*module)->as()->running_mean; - return m.defined() ? ResultTensor(m) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_InstanceNorm1d_get_var(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->running_var; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_InstanceNorm1d_get_batches(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->num_batches_tracked; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -void THSNN_InstanceNorm1d_set_mean(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_mean = *bias; - ); -} - -void THSNN_InstanceNorm1d_set_var(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_var = *bias; - ); -} - -Tensor THSNN_InstanceNorm1d_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_InstanceNorm1d_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_InstanceNorm1d_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_InstanceNorm1d_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - -void THSNN_InstanceNorm2d_reset_stats(const NNModule module) -{ - CATCH((*module)->as()->reset_running_stats();); -} - -Tensor THSNN_InstanceNorm2d_get_mean(const NNModule module) -{ - CATCH( - auto m = (*module)->as()->running_mean; - return m.defined() ? ResultTensor(m) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_InstanceNorm2d_get_var(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->running_var; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_InstanceNorm2d_get_batches(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->num_batches_tracked; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -void THSNN_InstanceNorm2d_set_mean(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_mean = *bias; - ); -} - -void THSNN_InstanceNorm2d_set_var(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_var = *bias; - ); -} - -Tensor THSNN_InstanceNorm2d_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_InstanceNorm2d_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_InstanceNorm2d_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_InstanceNorm2d_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - -void THSNN_InstanceNorm3d_reset_stats(const NNModule module) -{ - CATCH((*module)->as()->reset_running_stats();); -} - -Tensor THSNN_InstanceNorm3d_get_mean(const NNModule module) -{ - CATCH( - auto m = (*module)->as()->running_mean; - return m.defined() ? ResultTensor(m) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_InstanceNorm3d_get_var(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->running_var; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -Tensor THSNN_InstanceNorm3d_get_batches(const NNModule module) -{ - CATCH( - auto v = (*module)->as()->num_batches_tracked; - return v.defined() ? ResultTensor(v) : nullptr; - ); - return nullptr; -} - -void THSNN_InstanceNorm3d_set_mean(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_mean = *bias; - ); -} - -void THSNN_InstanceNorm3d_set_var(const NNModule module, const Tensor bias) -{ - CATCH( - (*module)->as()->running_var = *bias; - ); -} - -Tensor THSNN_InstanceNorm3d_bias(const NNModule module) -{ - return get_bias(module); -} - -void THSNN_InstanceNorm3d_set_bias(const NNModule module, const Tensor bias) -{ - set_bias(module, bias); -} - -Tensor THSNN_InstanceNorm3d_weight(const NNModule module) -{ - return get_weight(module); -} - -void THSNN_InstanceNorm3d_set_weight(const NNModule module, const Tensor weight) -{ - set_weight(module, weight); -} - Tensor THSNN_batch_norm(const Tensor input, Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps) { - auto opts = torch::nn::functional::BatchNormFuncOptions() - .training(training) - .momentum(momentum) - .eps(eps); - if (weight != nullptr) opts.weight(*weight); - if (bias != nullptr) opts.bias(*bias); - CATCH_TENSOR(torch::nn::functional::batch_norm(*input, *running_mean, *running_var, opts)); + c10::optional w, b, rm, rv; + if (weight != nullptr) w.emplace(*weight); + if (bias != nullptr) b.emplace(*bias); + if (running_mean != nullptr) rm.emplace(*running_mean); + if (running_var != nullptr) rv.emplace(*running_var); + + CATCH_TENSOR(torch::batch_norm(*input, w, b, rm, rv, training, momentum, eps, false)); } Tensor THSNN_group_norm(const Tensor input, const int64_t num_groups, const Tensor weight, const Tensor bias, const double eps) diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp index 2bdc96a83..6f0e035d9 100644 --- a/src/Native/LibTorchSharp/THSTensor.cpp +++ b/src/Native/LibTorchSharp/THSTensor.cpp @@ -200,14 +200,29 @@ Tensor THSTensor_cat(const Tensor* tensors, const int length, const int64_t dim) CATCH_TENSOR(torch::cat(toTensors((torch::Tensor**)tensors, length), dim)); } -Tensor THSTensor_celu(const Tensor tensor) +Tensor THSTensor_celu(const Tensor tensor, const Scalar alpha) { - CATCH_TENSOR(torch::celu(*tensor)); + CATCH_TENSOR(torch::celu(*tensor, *alpha)); } -void THSTensor_celu_(const Tensor tensor) +void THSTensor_celu_(const Tensor tensor, const Scalar alpha) { - CATCH(torch::celu_(*tensor);); + CATCH(torch::celu_(*tensor, *alpha);); +} + +Tensor THSTensor_glu(const Tensor tensor, const int64_t dim) +{ + CATCH_TENSOR(torch::glu(*tensor, dim)); +} + +Tensor THSTensor_hardshrink(const Tensor tensor, const Scalar lambda) +{ + CATCH_TENSOR(torch::hardshrink(*tensor, *lambda)); +} + +Tensor THSTensor_softshrink(const Tensor tensor, const Scalar lambda) +{ + CATCH_TENSOR(torch::softshrink(*tensor, *lambda)); } void THSTensor_chunk(const Tensor tensor, Tensor* (*allocator)(size_t length), const int64_t chunks, const int64_t dim) @@ -561,6 +576,11 @@ Tensor THSTensor_gelu(const Tensor tensor) CATCH_TENSOR(torch::gelu(*tensor)); } +Tensor THSTensor_gelu_(const Tensor tensor) +{ + CATCH_TENSOR(torch::gelu_(*tensor)); +} + Tensor THSTensor_get1(const Tensor tensor, int64_t index) { CATCH_TENSOR((*tensor)[index]); @@ -1181,6 +1201,17 @@ void THSTensor_relu6_(const Tensor tensor) CATCH(torch::nn::functional::relu6(*tensor, torch::nn::functional::ReLU6FuncOptions().inplace(true));); } + +Tensor THSTensor_rrelu(const Tensor tensor, const double lower, const double upper) +{ + CATCH_TENSOR(torch::rrelu(*tensor, lower, upper)); +} + +void THSTensor_rrelu_(const Tensor tensor, const double lower, const double upper) +{ + CATCH(torch::rrelu_(*tensor, lower, upper);); +} + Tensor THSTensor_renorm(const Tensor tensor, const float p, const int64_t dim, const float maxnorm) { CATCH_TENSOR(tensor->renorm(p, dim, maxnorm)); @@ -1359,9 +1390,9 @@ Tensor THSTensor_slice(const Tensor tensor, int64_t dim, int64_t start, int64_t CATCH_TENSOR(tensor->slice(dim, start, finish, step)); } -Tensor THSTensor_softplus(const Tensor tensor) +Tensor THSTensor_softplus(const Tensor tensor, const Scalar beta, const Scalar threshold) { - CATCH_TENSOR(torch::softplus(*tensor)); + CATCH_TENSOR(torch::softplus(*tensor, *beta, *threshold)); } Tensor THSTensor_sort(const Tensor tensor, const int64_t dim, const bool descending, const bool stable, Tensor* indices) @@ -1873,6 +1904,17 @@ void THSTensor_transpose_(const Tensor tensor, const int64_t dim1, const int64_t CATCH(tensor->transpose_(dim1, dim2);); } +Tensor THSTensor_threshold(const Tensor tensor, const Scalar threshold, const Scalar value) +{ + CATCH_TENSOR(torch::threshold(*tensor, *threshold, *value)); +} + +void THSTensor_threshold_(const Tensor tensor, const Scalar threshold, const Scalar value) +{ + CATCH(torch::threshold_(*tensor, *threshold, *value);); +} + + Tensor THSTensor_view(const Tensor tensor, const int64_t* shape, const int length) { CATCH_TENSOR(tensor->view(at::ArrayRef(shape, length))); diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h index 6af55912b..5e60909c0 100644 --- a/src/Native/LibTorchSharp/THSTensor.h +++ b/src/Native/LibTorchSharp/THSTensor.h @@ -23,6 +23,20 @@ EXPORT_API(Tensor) THSTensor_adaptive_avg_pool3d(const Tensor tensor, const int6 EXPORT_API(Tensor) THSTensor_adaptive_avg_pool3d_backward_out(const Tensor grad_input, const Tensor grad_output, const Tensor tensor); +EXPORT_API(Tensor) THSTensor_adaptive_max_pool1d(const Tensor tensor, const int64_t* outputSize, const int outputSizeLength, Tensor* indices); + +EXPORT_API(Tensor) THSTensor_adaptive_max_pool2d(const Tensor tensor, const int64_t* outputSize, const int outputSizeLength, Tensor* indices); + +EXPORT_API(Tensor) THSTensor_adaptive_max_pool3d(const Tensor tensor, const int64_t* outputSize, const int outputSizeLength, Tensor* indices); + +EXPORT_API(Tensor) THSTensor_fractional_max_pool2d(const Tensor tensor, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, Tensor* indices); + +EXPORT_API(Tensor) THSTensor_fractional_max_pool3d(const Tensor tensor, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, Tensor* indices); + +EXPORT_API(Tensor) THSTensor_lp_pool1d(const Tensor tensor, const double norm_type, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const bool ceil_mode); + +EXPORT_API(Tensor) THSTensor_lp_pool2d(const Tensor tensor, const double norm_type, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const bool ceil_mode); + EXPORT_API(Tensor) THSTensor_add(const Tensor left, const Tensor right, const Scalar alpha); EXPORT_API(void) THSTensor_add_(const Tensor left, const Tensor right, const Scalar alpha); @@ -139,7 +153,8 @@ EXPORT_API(Tensor) THSTensor_avg_pool2d( const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, bool ceil_mode, - bool count_include_pad); + bool count_include_pad, + const int64_t divisor_override); EXPORT_API(Tensor) THSTensor_avg_pool2d_backward( const Tensor grad_output, @@ -157,7 +172,8 @@ EXPORT_API(Tensor) THSTensor_avg_pool3d( const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, bool ceil_mode, - bool count_include_pad); + bool count_include_pad, + const int64_t divisor_override); EXPORT_API(Tensor) THSTensor_avg_pool3d_backward( const Tensor grad_output, @@ -235,9 +251,13 @@ EXPORT_API(Tensor) THSTensor_ceil(const Tensor tensor); EXPORT_API(void) THSTensor_ceil_(const Tensor tensor); -EXPORT_API(Tensor) THSTensor_celu(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_celu(const Tensor tensor, const Scalar alpha); + +EXPORT_API(void) THSTensor_celu_(const Tensor tensor, const Scalar alpha); -EXPORT_API(void) THSTensor_celu_(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_hardshrink(const Tensor tensor, const Scalar lambda); + +EXPORT_API(Tensor) THSTensor_softshrink(const Tensor tensor, const Scalar lambda); EXPORT_API(Tensor) THSTensor_cholesky(const Tensor tensor, const bool upper); @@ -555,6 +575,9 @@ EXPORT_API(Tensor) THSTensor_ge_scalar(const Tensor left, const Scalar right); EXPORT_API(void) THSTensor_ge_scalar_(const Tensor left, const Scalar right); EXPORT_API(Tensor) THSTensor_gelu(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_gelu_(const Tensor tensor); + +EXPORT_API(Tensor) THSTensor_glu(const Tensor tensor, const int64_t dim); EXPORT_API(Tensor) THSTensor_get1(const Tensor tensor, int64_t index); @@ -792,68 +815,53 @@ EXPORT_API(void) THSTensor_max_along_dimension(const Tensor tensor, Tensor* (*al EXPORT_API(Tensor) THSTensor_max_elementwise(const Tensor tensor, const Tensor other); -EXPORT_API(Tensor) THSTensor_max_pool1d( - const Tensor tensor, - const int64_t* kernelSize, const int kernelSizeLength, - const int64_t* stride, const int strideLength, - const int64_t* padding, const int paddingLength, - const int64_t* dilation, const int dilationLength, - bool ceil_mode); - -EXPORT_API(Tensor) THSTensor_max_pool2d( +EXPORT_API(Tensor) THSTensor_max_pool1d_with_indices( const Tensor tensor, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, - bool ceil_mode); + bool ceil_mode, Tensor *indices); -EXPORT_API(Tensor) THSTensor_max_pool3d( +EXPORT_API(Tensor) THSTensor_max_pool2d_with_indices( const Tensor tensor, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, - bool ceil_mode); + bool ceil_mode, Tensor* indices); -EXPORT_API(void) THSTensor_max_pool1d_with_indices( +EXPORT_API(Tensor) THSTensor_max_pool3d_with_indices( const Tensor tensor, - Tensor* (*allocator)(size_t length), const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, - bool ceil_mode); + bool ceil_mode, Tensor* indices); -EXPORT_API(void) THSTensor_max_pool2d_with_indices( +EXPORT_API(Tensor) THSTensor_max_unpool1d( const Tensor tensor, - Tensor* (*allocator)(size_t length), + const Tensor indices, const int64_t* kernelSize, const int kernelSizeLength, - const int64_t* stride, const int strideLength, + const int64_t* outputSize, const int outputSizeLength, const int64_t* padding, const int paddingLength, - const int64_t* dilation, const int dilationLength, - bool ceil_mode); + const int64_t* stride, const int strideLength); -EXPORT_API(void) THSTensor_max_pool3d_with_indices( +EXPORT_API(Tensor) THSTensor_max_unpool2d( const Tensor tensor, - Tensor* (*allocator)(size_t length), + const Tensor indices, const int64_t* kernelSize, const int kernelSizeLength, - const int64_t* stride, const int strideLength, + const int64_t* outputSize, const int outputSizeLength, const int64_t* padding, const int paddingLength, - const int64_t* dilation, const int dilationLength, - bool ceil_mode); + const int64_t* stride, const int strideLength); -EXPORT_API(Tensor) THSTensor_maxunpool2d( - const Tensor tensor, - const Tensor indices, - const int64_t* outputSize, const int outputSizeLength); - -EXPORT_API(Tensor) THSTensor_maxunpool3d( +EXPORT_API(Tensor) THSTensor_max_unpool3d( const Tensor tensor, const Tensor indices, + const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, - const int64_t* stride, const int strideLength, - const int64_t* padding, const int paddingLength); + const int64_t* padding, const int paddingLength, + const int64_t* stride, const int strideLength); EXPORT_API(Tensor) THSTensor_mean(const Tensor tensor); @@ -1076,17 +1084,17 @@ EXPORT_API(Tensor) THSTensor_ravel(const Tensor tensor); EXPORT_API(Tensor) THSTensor_real(const Tensor tensor); EXPORT_API(Tensor) THSTensor_reciprocal(const Tensor tensor); - EXPORT_API(void) THSTensor_reciprocal_(const Tensor tensor); EXPORT_API(Tensor) THSTensor_relu(const Tensor tensor); - EXPORT_API(void) THSTensor_relu_(const Tensor tensor); EXPORT_API(Tensor) THSTensor_relu6(const Tensor tensor); - EXPORT_API(void) THSTensor_relu6_(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_rrelu(const Tensor tensor, const double lower, const double upper); +EXPORT_API(void) THSTensor_rrelu_(const Tensor tensor, const double lower, const double upper); + EXPORT_API(Tensor) THSTensor_repeat(const Tensor tensor, const int64_t* sizes, const int length); EXPORT_API(Tensor) THSTensor_repeat_interleave(const Tensor tensor, const Tensor repeats, const int64_t dim, const int64_t output_size); @@ -1158,7 +1166,7 @@ EXPORT_API(Tensor) THSTensor_sinh(const Tensor tensor); EXPORT_API(void) THSTensor_sinh_(const Tensor tensor); -EXPORT_API(Tensor) THSTensor_softplus(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_softplus(const Tensor tensor, const Scalar beta, const Scalar threshold); EXPORT_API(Tensor) THSTensor_sort(const Tensor tensor, const int64_t dim, const bool descending, const bool stable, Tensor* indices); @@ -1313,9 +1321,11 @@ EXPORT_API(Tensor) THSTensor_tril_indices(const int64_t row, const int64_t col, EXPORT_API(Tensor) THSTensor_triu_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index); EXPORT_API(Tensor) THSTensor_transpose(const Tensor tensor, const int64_t dim1, const int64_t dim2); - EXPORT_API(void) THSTensor_transpose_(const Tensor tensor, const int64_t dim1, const int64_t dim2); +EXPORT_API(Tensor) THSTensor_threshold(const Tensor tensor, const Scalar threshold, const Scalar value); +EXPORT_API(void) THSTensor_threshold_(const Tensor tensor, const Scalar threshold, const Scalar value); + EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_x(const Tensor y, const Tensor x, int64_t dim); EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_dx(const Tensor y, const double dx, int64_t dim); diff --git a/src/Native/LibTorchSharp/THSTensorConv.cpp b/src/Native/LibTorchSharp/THSTensorConv.cpp index 10daa3e72..2783371e5 100644 --- a/src/Native/LibTorchSharp/THSTensorConv.cpp +++ b/src/Native/LibTorchSharp/THSTensorConv.cpp @@ -44,6 +44,110 @@ Tensor THSTensor_adaptive_avg_pool3d_backward_out( *tensor)); } +Tensor THSTensor_adaptive_max_pool1d(const Tensor tensor, const int64_t* outputSize, const int outputSizeLength, Tensor *indices) +{ + Tensor output = nullptr; + *indices = nullptr; + CATCH( + auto result = torch::adaptive_max_pool1d(*tensor, at::ArrayRef(outputSize, outputSizeLength)); + output = new torch::Tensor(std::get<0>(result)); + *indices = new torch::Tensor(std::get<1>(result)); + ); + return output; +} + +Tensor THSTensor_adaptive_max_pool2d(const Tensor tensor, const int64_t* outputSize, const int outputSizeLength, Tensor* indices) +{ + Tensor output = nullptr; + *indices = nullptr; + CATCH( + auto result = torch::adaptive_max_pool2d(*tensor, at::ArrayRef(outputSize, outputSizeLength)); + output = new torch::Tensor(std::get<0>(result)); + *indices = new torch::Tensor(std::get<1>(result)); + ); + return output; +} + +Tensor THSTensor_adaptive_max_pool3d(const Tensor tensor, const int64_t* outputSize, const int outputSizeLength, Tensor* indices) +{ + Tensor output = nullptr; + *indices = nullptr; + CATCH( + auto result = torch::adaptive_max_pool3d(*tensor, at::ArrayRef(outputSize, outputSizeLength)); + output = new torch::Tensor(std::get<0>(result)); + *indices = new torch::Tensor(std::get<1>(result)); + ); + return output; +} + +Tensor THSTensor_fractional_max_pool2d(const Tensor tensor, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, Tensor* indices) +{ + Tensor output = nullptr; + *indices = nullptr; + auto opts = torch::nn::functional::FractionalMaxPool2dFuncOptions(at::ArrayRef(kernelSize, kernelSizeLength)); + if (outputSizeLength > 0) + opts = opts.output_size(at::ArrayRef(outputSize, outputSizeLength)); + if (outputRatioLength > 0) + opts = opts.output_ratio(at::ArrayRef(outputRatio, outputRatioLength)); + + CATCH( + auto result = torch::nn::functional::fractional_max_pool2d_with_indices(*tensor, opts); + output = new torch::Tensor(std::get<0>(result)); + *indices = new torch::Tensor(std::get<1>(result)); + ); + return output; +} + +Tensor THSTensor_fractional_max_pool3d(const Tensor tensor, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, Tensor* indices) +{ + Tensor output = nullptr; + *indices = nullptr; + auto opts = torch::nn::functional::FractionalMaxPool3dFuncOptions(at::ArrayRef(kernelSize, kernelSizeLength)); + if (outputSizeLength > 0) + opts = opts.output_size(at::ArrayRef(outputSize, outputSizeLength)); + if (outputRatioLength > 0) + opts = opts.output_ratio(at::ArrayRef(outputRatio, outputRatioLength)); + + CATCH( + auto result = torch::nn::functional::fractional_max_pool3d_with_indices(*tensor, opts); + output = new torch::Tensor(std::get<0>(result)); + *indices = new torch::Tensor(std::get<1>(result)); + ); + return output; +} + +Tensor THSTensor_lp_pool1d( + const Tensor tensor, + const double norm_type, + const int64_t* kernelSize, + const int kernelSizeLength, + const int64_t* stride, + const int strideLength, + const bool ceil_mode) +{ + auto opts = torch::nn::functional::LPPool1dFuncOptions(norm_type, at::ArrayRef(kernelSize, kernelSizeLength)).ceil_mode(ceil_mode); + if (strideLength > 0) + opts = opts.stride(at::ArrayRef(stride, strideLength)); + opts.ceil_mode(); + CATCH_TENSOR(torch::nn::functional::lp_pool1d(*tensor, opts)); +} + +Tensor THSTensor_lp_pool2d( + const Tensor tensor, + const double norm_type, + const int64_t* kernelSize, + const int kernelSizeLength, + const int64_t* stride, + const int strideLength, + const bool ceil_mode) +{ + auto opts = torch::nn::functional::LPPool2dFuncOptions(norm_type, at::ArrayRef(kernelSize, kernelSizeLength)).ceil_mode(ceil_mode); + if (strideLength > 0) + opts = opts.stride(at::ArrayRef(stride, strideLength)); + opts.ceil_mode(); + CATCH_TENSOR(torch::nn::functional::lp_pool2d(*tensor, opts)); +} + Tensor THSTensor_avg_pool1d( const Tensor tensor, const int64_t* kernelSize, const int kernelSizeLength, @@ -67,7 +171,8 @@ Tensor THSTensor_avg_pool2d( const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, bool ceil_mode, - bool count_include_pad) + bool count_include_pad, + const int64_t divisor_override) { CATCH_TENSOR(torch::avg_pool2d( *tensor, @@ -75,7 +180,8 @@ Tensor THSTensor_avg_pool2d( at::ArrayRef(stride, strideLength), at::ArrayRef(padding, paddingLength), ceil_mode, - count_include_pad)); + count_include_pad, + (divisor_override == 0 ? c10::nullopt : c10::optional(divisor_override)))); } Tensor THSTensor_avg_pool2d_backward( @@ -96,7 +202,7 @@ Tensor THSTensor_avg_pool2d_backward( at::ArrayRef(padding, paddingLength), ceil_mode, count_include_pad, - (divisor_override == 0 ? c10::optional() : c10::optional(divisor_override)))); + (divisor_override == 0 ? c10::nullopt : c10::optional(divisor_override)))); } Tensor THSTensor_avg_pool3d( @@ -105,7 +211,8 @@ Tensor THSTensor_avg_pool3d( const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, bool ceil_mode, - bool count_include_pad) + bool count_include_pad, + const int64_t divisor_override) { CATCH_TENSOR(torch::avg_pool3d( *tensor, @@ -113,7 +220,8 @@ Tensor THSTensor_avg_pool3d( at::ArrayRef(stride, strideLength), at::ArrayRef(padding, paddingLength), ceil_mode, - count_include_pad)); + count_include_pad, + (divisor_override == 0 ? c10::nullopt : c10::optional(divisor_override)))); } Tensor THSTensor_avg_pool3d_backward( @@ -232,33 +340,16 @@ Tensor THSTensor_conv3d( groups)); } - -Tensor THSTensor_max_pool1d( - const Tensor tensor, - const int64_t* kernelSize, const int kernelSizeLength, - const int64_t* stride, const int strideLength, - const int64_t* padding, const int paddingLength, - const int64_t* dilation, const int dilationLength, - bool ceil_mode) -{ - CATCH_TENSOR(torch::max_pool1d( - *tensor, - at::ArrayRef(kernelSize, kernelSizeLength), - at::ArrayRef(stride, strideLength), - at::ArrayRef(padding, paddingLength), - at::ArrayRef(dilation, dilationLength), - ceil_mode)); -} - -void THSTensor_max_pool1d_with_indices( +Tensor THSTensor_max_pool1d_with_indices( const Tensor tensor, - Tensor* (*allocator)(size_t length), const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, - bool ceil_mode) + bool ceil_mode, Tensor *indices) { + Tensor output = nullptr; + *indices = nullptr; CATCH( auto res = torch::max_pool1d_with_indices( *tensor, @@ -268,38 +359,22 @@ void THSTensor_max_pool1d_with_indices( at::ArrayRef(dilation, dilationLength), ceil_mode); - Tensor * result = allocator(2); - result[0] = new torch::Tensor(std::get<0>(res)); - result[1] = new torch::Tensor(std::get<1>(res)); + output = new torch::Tensor(std::get<0>(res)); + *indices = new torch::Tensor(std::get<1>(res)); ) + return output; } -Tensor THSTensor_max_pool2d( +Tensor THSTensor_max_pool2d_with_indices( const Tensor tensor, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, - bool ceil_mode) -{ - CATCH_TENSOR(torch::max_pool2d( - *tensor, - at::ArrayRef(kernelSize, kernelSizeLength), - at::ArrayRef(stride, strideLength), - at::ArrayRef(padding, paddingLength), - at::ArrayRef(dilation, dilationLength), - ceil_mode)); -} - -void THSTensor_max_pool2d_with_indices( - const Tensor tensor, - Tensor* (*allocator)(size_t length), - const int64_t* kernelSize, const int kernelSizeLength, - const int64_t* stride, const int strideLength, - const int64_t* padding, const int paddingLength, - const int64_t* dilation, const int dilationLength, - bool ceil_mode) + bool ceil_mode, Tensor* indices) { + Tensor output = nullptr; + *indices = nullptr; CATCH( auto res = torch::max_pool2d_with_indices( *tensor, @@ -308,38 +383,22 @@ void THSTensor_max_pool2d_with_indices( at::ArrayRef(padding, paddingLength), at::ArrayRef(dilation, dilationLength), ceil_mode); - Tensor * result = allocator(2); - result[0] = new torch::Tensor(std::get<0>(res)); - result[1] = new torch::Tensor(std::get<1>(res)); + output = new torch::Tensor(std::get<0>(res)); + *indices = new torch::Tensor(std::get<1>(res)); ) + return output; } -Tensor THSTensor_max_pool3d( - const Tensor tensor, - const int64_t* kernelSize, const int kernelSizeLength, - const int64_t* stride, const int strideLength, - const int64_t* padding, const int paddingLength, - const int64_t* dilation, const int dilationLength, - bool ceil_mode) -{ - CATCH_TENSOR(torch::max_pool3d( - *tensor, - at::ArrayRef(kernelSize, kernelSizeLength), - at::ArrayRef(stride, strideLength), - at::ArrayRef(padding, paddingLength), - at::ArrayRef(dilation, dilationLength), - ceil_mode)); -} - -void THSTensor_max_pool3d_with_indices( +Tensor THSTensor_max_pool3d_with_indices( const Tensor tensor, - Tensor* (*allocator)(size_t length), const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, - bool ceil_mode) + bool ceil_mode, Tensor* indices) { + Tensor output = nullptr; + *indices = nullptr; CATCH( auto res = torch::max_pool3d_with_indices( *tensor, @@ -348,36 +407,70 @@ void THSTensor_max_pool3d_with_indices( at::ArrayRef(padding, paddingLength), at::ArrayRef(dilation, dilationLength), ceil_mode); - Tensor * result = allocator(2); - result[0] = new torch::Tensor(std::get<0>(res)); - result[1] = new torch::Tensor(std::get<1>(res)); + output = new torch::Tensor(std::get<0>(res)); + *indices = new torch::Tensor(std::get<1>(res)); ) + return output; } -Tensor THSTensor_maxunpool2d( +Tensor THSTensor_max_unpool1d( const Tensor tensor, const Tensor indices, - const int64_t* outputSize, const int outputSizeLength) + const int64_t* kernelSize, const int kernelSizeLength, + const int64_t* outputSize, const int outputSizeLength, + const int64_t* padding, const int paddingLength, + const int64_t* stride, const int strideLength) { - CATCH_TENSOR(torch::max_unpool2d( - *tensor, - *indices, - at::ArrayRef(outputSize, outputSizeLength))); + + auto opts = torch::nn::functional::MaxUnpool1dFuncOptions(at::IntArrayRef(kernelSize, kernelSizeLength)); + if (outputSizeLength > 0) + opts = opts.output_size(std::vector(outputSize, outputSize + outputSizeLength)); + if (paddingLength > 0) + opts = opts.padding(at::IntArrayRef(padding, paddingLength)); + if (paddingLength > 0) + opts = opts.stride(at::IntArrayRef(stride, strideLength)); + + CATCH_TENSOR(torch::nn::functional::max_unpool1d(*tensor, *indices, opts)); } -Tensor THSTensor_maxunpool3d( + +Tensor THSTensor_max_unpool2d( const Tensor tensor, const Tensor indices, + const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, - const int64_t* stride, const int strideLength, - const int64_t* padding, const int paddingLength) + const int64_t* padding, const int paddingLength, + const int64_t* stride, const int strideLength) { - CATCH_TENSOR(torch::max_unpool3d( - *tensor, - *indices, - at::ArrayRef(outputSize, outputSizeLength), - at::ArrayRef(stride, strideLength), - at::ArrayRef(padding, paddingLength))); + + auto opts = torch::nn::functional::MaxUnpool2dFuncOptions(at::IntArrayRef(kernelSize, kernelSizeLength)); + if (outputSizeLength > 0) + opts = opts.output_size(std::vector(outputSize, outputSize + outputSizeLength)); + if (paddingLength > 0) + opts = opts.padding(at::IntArrayRef(padding, paddingLength)); + if (paddingLength > 0) + opts = opts.stride(at::IntArrayRef(stride, strideLength)); + + CATCH_TENSOR(torch::nn::functional::max_unpool2d(*tensor, *indices, opts)); +} + +Tensor THSTensor_max_unpool3d( + const Tensor tensor, + const Tensor indices, + const int64_t* kernelSize, const int kernelSizeLength, + const int64_t* outputSize, const int outputSizeLength, + const int64_t* padding, const int paddingLength, + const int64_t* stride, const int strideLength) +{ + auto opts = torch::nn::functional::MaxUnpool3dFuncOptions(at::IntArrayRef(kernelSize, kernelSizeLength)); + if (outputSizeLength > 0) + opts = opts.output_size(std::vector(outputSize, outputSize + outputSizeLength)); + if (paddingLength > 0) + opts = opts.padding(at::IntArrayRef(padding, paddingLength)); + if (paddingLength > 0) + opts = opts.stride(at::IntArrayRef(stride, strideLength)); + + CATCH_TENSOR(torch::nn::functional::max_unpool3d(*tensor, *indices, opts)); } diff --git a/src/TorchSharp/NN/Activation/CELU.cs b/src/TorchSharp/NN/Activation/CELU.cs index 0af37f6b6..685919099 100644 --- a/src/TorchSharp/NN/Activation/CELU.cs +++ b/src/TorchSharp/NN/Activation/CELU.cs @@ -12,27 +12,21 @@ namespace Modules /// /// This class is used to represent a CELU module. /// - public sealed class CELU : torch.nn.Module + public sealed class CELU : ParamLessModule { - internal CELU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal CELU(double alpha, bool inplace) : base(nameof(CELU)) { - var res = THSNN_CELU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.alpha = alpha; + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(CELU).Name; + return torch.nn.functional.celu(tensor, alpha, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public double alpha {get; set;} + public bool inplace {get; set; } } } @@ -48,9 +42,7 @@ public static partial class nn /// public static CELU CELU(double alpha = 1.0, bool inplace = false) { - var handle = THSNN_CELU_ctor(alpha, inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new CELU(handle, boxedHandle); + return new CELU(alpha, inplace); } public static partial class functional @@ -64,9 +56,7 @@ public static partial class functional /// public static Tensor celu(Tensor x, double alpha, bool inplace = false) { - using (var m = nn.CELU(alpha, inplace)) { - return m.call(x); - } + return inplace ? x.celu_(alpha).alias() : x.celu(alpha); } } } diff --git a/src/TorchSharp/NN/Activation/ELU.cs b/src/TorchSharp/NN/Activation/ELU.cs index 365280f87..078e29e2f 100644 --- a/src/TorchSharp/NN/Activation/ELU.cs +++ b/src/TorchSharp/NN/Activation/ELU.cs @@ -12,27 +12,22 @@ namespace Modules /// /// This class is used to represent a ELU module. /// - public sealed class ELU : torch.nn.Module + public sealed class ELU : ParamLessModule { - internal ELU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ELU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + internal ELU(double alpha, bool inplace) : base(nameof(ELU)) + { + this.alpha = alpha; + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(ELU).Name; + return torch.nn.functional.elu(tensor, alpha, inplace); } + + public double alpha {get; set;} - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set;} } } @@ -48,9 +43,7 @@ public static partial class nn /// public static ELU ELU(double alpha = 1.0, bool inplace = false) { - var handle = THSNN_ELU_ctor(alpha, inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ELU(handle, boxedHandle); + return new ELU(alpha, inplace); } public static partial class functional @@ -64,9 +57,7 @@ public static partial class functional /// public static Tensor elu(Tensor x, double alpha, bool inplace = false) { - using (var m = nn.ELU(alpha, inplace)) { - return m.call(x); - } + return inplace ? x.elu_(alpha).alias() : x.elu(alpha); } } } diff --git a/src/TorchSharp/NN/Activation/GELU.cs b/src/TorchSharp/NN/Activation/GELU.cs index 71dda326e..7ccb08c8c 100644 --- a/src/TorchSharp/NN/Activation/GELU.cs +++ b/src/TorchSharp/NN/Activation/GELU.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a GELU module. /// - public sealed class GELU : torch.nn.Module + public sealed class GELU : ParamLessModule { - internal GELU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal GELU(bool inplace) : base(nameof(GELU)) { - var res = THSNN_GELU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(GELU).Name; + return torch.nn.functional.gelu(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } @@ -43,12 +35,10 @@ public static partial class nn /// /// Gaussian Error Linear Units /// - /// - public static GELU GELU() + /// Do the operation in-place. Default: False + public static GELU GELU(bool inplace = false) { - var handle = THSNN_GELU_ctor(out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new GELU(handle, boxedHandle); + return new GELU(inplace); } public static partial class functional @@ -57,12 +47,10 @@ public static partial class functional /// Gaussian Error Linear Units /// /// The input tensor - /// - public static Tensor gelu(Tensor x) + /// Do the operation in-place. Default: False + public static Tensor gelu(Tensor x, bool inplace = false) { - using (var m = nn.GELU()) { - return m.call(x); - } + return inplace ? x.gelu_().alias() : x.gelu(); } } } diff --git a/src/TorchSharp/NN/Activation/GLU.cs b/src/TorchSharp/NN/Activation/GLU.cs index 15e717272..4fd759208 100644 --- a/src/TorchSharp/NN/Activation/GLU.cs +++ b/src/TorchSharp/NN/Activation/GLU.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a GLU (gated linear unit) module. /// - public sealed class GLU : torch.nn.Module + public sealed class GLU : ParamLessModule { - internal GLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) - { - var res = THSNN_GLU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + internal GLU(long dim) : base(nameof(GLU)) + { + this.dim = dim; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(GLU).Name; + return torch.nn.functional.glu(tensor, dim); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long dim {get; set;} } } @@ -47,9 +39,7 @@ public static partial class nn /// public static GLU GLU(long dim = -1) { - var handle = THSNN_GLU_ctor(dim, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new GLU(handle, boxedHandle); + return new GLU(dim); } public static partial class functional @@ -62,9 +52,7 @@ public static partial class functional /// public static Tensor glu(Tensor input, long dim = -1) { - using (var m = nn.GLU(dim)) { - return m.call(input); - } + return input.glu(dim); } } } diff --git a/src/TorchSharp/NN/Activation/Hardshrink.cs b/src/TorchSharp/NN/Activation/Hardshrink.cs index b3f2684d3..57b692b68 100644 --- a/src/TorchSharp/NN/Activation/Hardshrink.cs +++ b/src/TorchSharp/NN/Activation/Hardshrink.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a Hardshrink module. /// - public sealed class Hardshrink : torch.nn.Module + public sealed class Hardshrink : ParamLessModule { - internal Hardshrink(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) - { - var res = THSNN_Hardshrink_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + internal Hardshrink(double lambda = 0.5) : base(nameof(Hardshrink)) + { + this.lambda = lambda; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Hardshrink).Name; + return torch.nn.functional.hardshrink(tensor, lambda); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public double lambda {get; set; } } } @@ -47,9 +39,7 @@ public static partial class nn /// public static Hardshrink Hardshrink(double lambda = 0.5) { - var handle = THSNN_Hardshrink_ctor(lambda, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Hardshrink(handle, boxedHandle); + return new Hardshrink(lambda); } public static partial class functional @@ -60,11 +50,12 @@ public static partial class functional /// The input tensor /// The λ value for the Hardshrink formulation. Default: 0.5 /// - public static Tensor Hardshrink(Tensor x, double lambda = 0.5) + public static Tensor hardshrink(Tensor x, double lambda = 0.5) { - using (var m = nn.Hardshrink(lambda)) { - return m.call(x); - } + using var sc = (Scalar)lambda; + var result = THSTensor_hardshrink(x.Handle, sc.Handle); + if (result == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(result); } } } diff --git a/src/TorchSharp/NN/Activation/Hardsigmoid.cs b/src/TorchSharp/NN/Activation/Hardsigmoid.cs index c4f354cf2..74ef6af56 100644 --- a/src/TorchSharp/NN/Activation/Hardsigmoid.cs +++ b/src/TorchSharp/NN/Activation/Hardsigmoid.cs @@ -11,30 +11,19 @@ namespace Modules /// /// This class is used to represent a Hardsigmoid module. /// - public sealed class Hardsigmoid : torch.nn.Module + public sealed class Hardsigmoid : ParamLessModule { - private readonly bool inplace; - - internal Hardsigmoid(bool inplace = false) : base(nameof(Hardsigmoid)) + internal Hardsigmoid(bool inplace) : base(nameof(Hardsigmoid)) { this.inplace = inplace; } public override Tensor forward(Tensor tensor) { - return torch.nn.functional.hardsigmoid(tensor, this.inplace); - } - - public override string GetName() - { - return typeof(Hardsigmoid).Name; + return torch.nn.functional.hardsigmoid(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } diff --git a/src/TorchSharp/NN/Activation/Hardswish.cs b/src/TorchSharp/NN/Activation/Hardswish.cs index 6f8b401da..c9d39107a 100644 --- a/src/TorchSharp/NN/Activation/Hardswish.cs +++ b/src/TorchSharp/NN/Activation/Hardswish.cs @@ -11,9 +11,9 @@ namespace Modules /// /// This class is used to represent a Hardswish module. /// - public sealed class Hardswish : torch.nn.Module + public sealed class Hardswish : ParamLessModule { - private readonly bool inplace; + public bool inplace { get; set;} internal Hardswish(bool inplace = false) : base(nameof(Hardswish)) { @@ -24,17 +24,6 @@ public override Tensor forward(Tensor tensor) { return torch.nn.functional.hardswish(tensor, this.inplace); } - - public override string GetName() - { - return typeof(Hardswish).Name; - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; } } diff --git a/src/TorchSharp/NN/Activation/Hardtanh.cs b/src/TorchSharp/NN/Activation/Hardtanh.cs index f7236a72a..b245fec9d 100644 --- a/src/TorchSharp/NN/Activation/Hardtanh.cs +++ b/src/TorchSharp/NN/Activation/Hardtanh.cs @@ -12,15 +12,18 @@ namespace Modules /// /// This class is used to represent a Hardtanh module. /// - public sealed class Hardtanh : torch.nn.Module + public sealed class Hardtanh : ParamLessModule { - internal Hardtanh(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Hardtanh(double min_val = -1.0, double max_val = 1.0, bool inplace = false) : base(nameof(Hardtanh)) + { + this.min_val = min_val; + this.max_val = max_val; + this.inplace = inplace; + } public override Tensor forward(Tensor tensor) { - var res = THSNN_Hardtanh_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.hardtanh(tensor, min_val, max_val, inplace); } public override string GetName() @@ -28,11 +31,9 @@ public override string GetName() return typeof(Hardtanh).Name; } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public double min_val { get; set; } + public double max_val { get; set; } + public bool inplace {get; set; } } } @@ -49,9 +50,7 @@ public static partial class nn /// public static Hardtanh Hardtanh(double min_val = -1.0, double max_val = 1.0, bool inplace = false) { - var handle = THSNN_Hardtanh_ctor(min_val, max_val, inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Hardtanh(handle, boxedHandle); + return new Hardtanh(min_val, max_val, inplace); } public static partial class functional @@ -64,7 +63,7 @@ public static partial class functional /// Maximum value of the linear region range. /// Do the operation in-place /// - public static Tensor Hardtanh(Tensor x, double min_val = -1.0, double max_val = 1.0, bool inplace = false) + public static Tensor hardtanh(Tensor x, double min_val = -1.0, double max_val = 1.0, bool inplace = false) { return inplace ? x.hardtanh_(min_val, max_val).alias() : x.hardtanh(min_val, max_val); } diff --git a/src/TorchSharp/NN/Activation/LeakyReLU.cs b/src/TorchSharp/NN/Activation/LeakyReLU.cs index 51ec4ae31..f40866190 100644 --- a/src/TorchSharp/NN/Activation/LeakyReLU.cs +++ b/src/TorchSharp/NN/Activation/LeakyReLU.cs @@ -12,27 +12,21 @@ namespace Modules /// /// This class is used to represent a LeakyReLU module. /// - public sealed class LeakyReLU : torch.nn.Module + public sealed class LeakyReLU : ParamLessModule { - internal LeakyReLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal LeakyReLU(double negative_slope, bool inplace) : base(nameof(LeakyReLU)) { - var res = THSNN_LeakyReLU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; + this.negative_slope = negative_slope; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(LeakyReLU).Name; + return torch.nn.functional.leaky_relu(tensor, negative_slope, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } + public double negative_slope {get; set;} } } @@ -48,9 +42,7 @@ public static partial class nn /// public static LeakyReLU LeakyReLU(double negative_slope = 0.01, bool inplace = false) { - var handle = THSNN_LeakyReLU_ctor(negative_slope, inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new LeakyReLU(handle, boxedHandle); + return new LeakyReLU(negative_slope, inplace); } public static partial class functional @@ -64,9 +56,7 @@ public static partial class functional /// public static Tensor leaky_relu(Tensor input, double negative_slope = 0.01, bool inplace = false) { - using (var m = nn.LeakyReLU(negative_slope, inplace)) { - return m.call(input); - } + return inplace ? input.leaky_relu_(negative_slope).alias() : input.leaky_relu(negative_slope); } } } diff --git a/src/TorchSharp/NN/Activation/LogSigmoid.cs b/src/TorchSharp/NN/Activation/LogSigmoid.cs index 802784356..d9aaf1acd 100644 --- a/src/TorchSharp/NN/Activation/LogSigmoid.cs +++ b/src/TorchSharp/NN/Activation/LogSigmoid.cs @@ -12,25 +12,16 @@ namespace Modules /// /// This class is used to represent a LogSigmoid module. /// - public sealed class LogSigmoid : torch.nn.Module + public sealed class LogSigmoid : ParamLessModule { - internal LogSigmoid() : base(nameof(LogSigmoid)) { } - - public override Tensor forward(Tensor tensor) + internal LogSigmoid() : base(nameof(LogSigmoid)) { - return tensor.log_sigmoid(); } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(LogSigmoid).Name; + return torch.nn.functional.logsigmoid(tensor); } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; } } public static partial class torch diff --git a/src/TorchSharp/NN/Activation/LogSoftMax.cs b/src/TorchSharp/NN/Activation/LogSoftMax.cs index 302627c38..c4889174a 100644 --- a/src/TorchSharp/NN/Activation/LogSoftMax.cs +++ b/src/TorchSharp/NN/Activation/LogSoftMax.cs @@ -12,24 +12,19 @@ namespace Modules /// /// This class is used to represent a log softmax module. /// - public sealed class LogSoftmax : torch.nn.Module + public sealed class LogSoftmax : ParamLessModule { - internal LogSoftmax(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal LogSoftmax(long dim) : base(nameof(LogSoftmax)) { + this.dim = dim; } public override Tensor forward(Tensor tensor) { - var res = THSNN_LogSoftmax_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.log_softmax(tensor, dim); } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + + public long dim { get; set; } } } @@ -39,9 +34,7 @@ public static partial class nn { public static LogSoftmax LogSoftmax(long dim) { - var handle = THSNN_LogSoftmax_ctor(dim, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new LogSoftmax(handle, boxedHandle); + return new LogSoftmax(dim); } public static partial class functional diff --git a/src/TorchSharp/NN/Activation/Mish.cs b/src/TorchSharp/NN/Activation/Mish.cs index eb57c9914..cc00fe288 100644 --- a/src/TorchSharp/NN/Activation/Mish.cs +++ b/src/TorchSharp/NN/Activation/Mish.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a Mish module. /// - public sealed class Mish : torch.nn.Module + public sealed class Mish : ParamLessModule { - internal Mish(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal Mish(bool inplace) : base(nameof(Mish)) { - var res = THSNN_Mish_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Mish).Name; + return torch.nn.functional.mish(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } @@ -43,12 +35,10 @@ public static partial class nn /// /// A Self Regularized Non-Monotonic Neural Activation Function. /// - /// - public static Mish Mish() + /// Do the operation in-place. Default: False + public static Mish Mish(bool inplace = false) { - var handle = THSNN_Mish_ctor(out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Mish(handle, boxedHandle); + return new Mish(inplace); } public static partial class functional @@ -57,12 +47,12 @@ public static partial class functional /// A Self Regularized Non-Monotonic Neural Activation Function. /// /// The input tensor - /// - public static Tensor Mish(Tensor x) + /// Do the operation in-place. Default: False + public static Tensor mish(Tensor x, bool inplace = false) { - using (var m = nn.Mish()) { - return m.call(x); - } + using var t1 = softplus(x); + using var t2 = t1.tanh(); + return inplace ? x.mul_(t2).alias() : x.mul(t2); } } } diff --git a/src/TorchSharp/NN/Activation/PReLU.cs b/src/TorchSharp/NN/Activation/PReLU.cs index 3c8d666f5..2b48b4a6b 100644 --- a/src/TorchSharp/NN/Activation/PReLU.cs +++ b/src/TorchSharp/NN/Activation/PReLU.cs @@ -7,6 +7,7 @@ namespace TorchSharp { using Modules; + using TorchSharp.Utils; namespace Modules { @@ -15,13 +16,20 @@ namespace Modules /// public sealed class PReLU : torch.nn.Module { - internal PReLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal PReLU(long num_parameters, double init, Device? device = null, ScalarType? dtype = null) : base(nameof(PReLU)) + { + this.init = init; + this.num_parameters = num_parameters; + + var w = torch.empty(num_parameters, device:device, dtype:dtype); + w.fill_(init); + + this.weight = new Parameter(w); + } public override Tensor forward(Tensor tensor) { - var res = THSNN_PReLU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.prelu(tensor, weight); } public override string GetName() @@ -29,19 +37,35 @@ public override string GetName() return typeof(PReLU).Name; } - public Parameter? weight { - get { - var res = THSNN_PReLU_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } + public Parameter weight { + get => _weight!; set { - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_PReLU_set_weight(handle, value!.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); + if (value is null) throw new ArgumentNullException(nameof(weight)); + if (value.Handle != _weight?.Handle) { + _weight?.Dispose(); + _weight = (value.DetachFromDisposeScope() as Parameter)!; + ConditionallyRegisterParameter(nameof(weight), _weight); + } + } + } + + public long num_parameters { + get; private set; + } + + public double init { + get; private set; + } + + protected override void Dispose(bool disposing) + { + if (disposing) { + _weight?.Dispose(); } } + + [ComponentName(Name = nameof(weight))] + private Parameter? _weight; } } @@ -61,9 +85,7 @@ public static partial class nn /// The desired floating point or complex dtype of the parameters and buffers in this module public static PReLU PReLU(long num_parameters, double init = 0.25, Device? device = null, ScalarType? dtype = null) { - var handle = THSNN_PReLU_ctor(num_parameters, init, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new PReLU(handle, boxedHandle).MoveModule(device, dtype); + return new PReLU(num_parameters, init).MoveModule(device, dtype); } public static partial class functional diff --git a/src/TorchSharp/NN/Activation/RReLU.cs b/src/TorchSharp/NN/Activation/RReLU.cs index 289664874..1aea1e23c 100644 --- a/src/TorchSharp/NN/Activation/RReLU.cs +++ b/src/TorchSharp/NN/Activation/RReLU.cs @@ -12,27 +12,23 @@ namespace Modules /// /// This class is used to represent a RReLU module. /// - public sealed class RReLU : torch.nn.Module + public sealed class RReLU : ParamLessModule { - internal RReLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal RReLU(double lower, double upper, bool inplace) : base(nameof(RReLU)) { - var res = THSNN_RReLU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.lower = lower; + this.upper = upper; + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(RReLU).Name; + return torch.nn.functional.rrelu(tensor, lower, upper, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public double lower {get; set;} + public double upper {get; set;} + public bool inplace {get; set;} } } @@ -49,9 +45,7 @@ public static partial class nn /// public static RReLU RReLU(double lower = one_eighth, double upper = one_third, bool inplace = false) { - var handle = THSNN_RReLU_ctor(lower, upper, inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new RReLU(handle, boxedHandle); + return new RReLU(lower, upper, inplace); } private const double one_eighth = 1.0 / 8.0; @@ -67,11 +61,9 @@ public static partial class functional /// Upper bound of the uniform distribution. Default: 1/3 /// Do the operation in-place. Default: False /// - public static Tensor rrelu(Tensor x, double lower, double upper, bool inplace = false) + public static Tensor rrelu(Tensor x, double lower = one_eighth, double upper = one_third, bool inplace = false) { - using (var m = nn.RReLU(lower, upper, inplace)) { - return m.call(x); - } + return inplace ? x.rrelu_(lower, upper).alias() : x.rrelu(lower, upper); } } } diff --git a/src/TorchSharp/NN/Activation/ReLU6.cs b/src/TorchSharp/NN/Activation/ReLU6.cs index 1b89d60da..3ebb2c67d 100644 --- a/src/TorchSharp/NN/Activation/ReLU6.cs +++ b/src/TorchSharp/NN/Activation/ReLU6.cs @@ -14,27 +14,20 @@ namespace Modules /// /// This class is used to represent a ReLU6 module. /// - public sealed class ReLU6 : torch.nn.Module + public sealed class ReLU6 : ParamLessModule { - internal ReLU6(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal ReLU6(bool inplace) : base(nameof(ReLU6)) { - var res = NativeMethods.THSNN_ReLU6_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; } - public override string GetName() + + public override Tensor forward(Tensor tensor) { - return typeof(ReLU6).Name; + return torch.nn.functional.relu6(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } @@ -51,9 +44,7 @@ public static partial class nn /// public static ReLU6 ReLU6(bool inplace = false) { - var handle = NativeMethods.THSNN_ReLU6_ctor(inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReLU6(handle, boxedHandle); + return new ReLU6(inplace); } public static partial class functional @@ -68,9 +59,7 @@ public static partial class functional /// public static Tensor relu6(Tensor x, bool inplace = false) { - using (var m = nn.ReLU6(inplace)) { - return m.call(x); - } + return inplace ? x.relu6_().alias() : x.relu6(); } } } diff --git a/src/TorchSharp/NN/Activation/ReLu.cs b/src/TorchSharp/NN/Activation/ReLu.cs index 050a03235..68a16ea04 100644 --- a/src/TorchSharp/NN/Activation/ReLu.cs +++ b/src/TorchSharp/NN/Activation/ReLu.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a ReLU module. /// - public sealed class ReLU : torch.nn.Module + public sealed class ReLU : ParamLessModule { - internal ReLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal ReLU(bool inplace) : base(nameof(ReLU)) { - var res = THSNN_ReLU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(ReLU).Name; + return torch.nn.functional.relu(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } public static partial class torch @@ -46,9 +38,7 @@ public static partial class nn /// public static ReLU ReLU(bool inplace = false) { - var handle = THSNN_ReLU_ctor(inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReLU(handle, boxedHandle); + return new ReLU(inplace); } public static partial class functional diff --git a/src/TorchSharp/NN/Activation/SELU.cs b/src/TorchSharp/NN/Activation/SELU.cs index 774ab5a24..b75278bd3 100644 --- a/src/TorchSharp/NN/Activation/SELU.cs +++ b/src/TorchSharp/NN/Activation/SELU.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a SELU module. /// - public sealed class SELU : torch.nn.Module + public sealed class SELU : ParamLessModule { - internal SELU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal SELU(bool inplace) : base(nameof(SELU)) { - var res = THSNN_SELU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(SELU).Name; + return torch.nn.functional.selu(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } @@ -47,9 +39,7 @@ public static partial class nn /// public static SELU SELU(bool inplace = false) { - var handle = THSNN_SELU_ctor(inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new SELU(handle, boxedHandle); + return new SELU(inplace); } public static partial class functional diff --git a/src/TorchSharp/NN/Activation/SiLU.cs b/src/TorchSharp/NN/Activation/SiLU.cs index 051675f17..5abbc74ea 100644 --- a/src/TorchSharp/NN/Activation/SiLU.cs +++ b/src/TorchSharp/NN/Activation/SiLU.cs @@ -12,15 +12,16 @@ namespace Modules /// /// This class is used to represent a SiLU module. /// - public sealed class SiLU : torch.nn.Module + public sealed class SiLU : ParamLessModule { - internal SiLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal SiLU(bool inplace) : base(nameof(SiLU)) + { + this.inplace = inplace; + } public override Tensor forward(Tensor tensor) { - var res = THSNN_SiLU_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.silu(tensor, inplace); } public override string GetName() @@ -28,11 +29,7 @@ public override string GetName() return typeof(SiLU).Name; } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } public static partial class torch @@ -42,13 +39,9 @@ public static partial class nn /// /// Sigmoid-Weighted Linear Unit /// - /// - /// The native libreary does not take an 'inplace' option, even though the PyTorch documentation mentions the parameter. - public static SiLU SiLU() + public static SiLU SiLU(bool inplace = false) { - var handle = THSNN_SiLU_ctor(out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new SiLU(handle, boxedHandle); + return new SiLU(inplace); } public static partial class functional diff --git a/src/TorchSharp/NN/Activation/Sigmoid.cs b/src/TorchSharp/NN/Activation/Sigmoid.cs index 1513f06f5..4981c814a 100644 --- a/src/TorchSharp/NN/Activation/Sigmoid.cs +++ b/src/TorchSharp/NN/Activation/Sigmoid.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a Sigmoid module. /// - public sealed class Sigmoid : torch.nn.Module + public sealed class Sigmoid : ParamLessModule { - internal Sigmoid(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal Sigmoid(bool inplace) : base(nameof(Sigmoid)) { - var res = THSNN_Sigmoid_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Sigmoid).Name; + return torch.nn.functional.sigmoid(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } public static partial class torch @@ -42,12 +34,11 @@ public static partial class nn /// /// Sigmoid activation /// + /// Do the operation in-place. Default: False /// - public static Sigmoid Sigmoid() + public static Sigmoid Sigmoid(bool inplace = false) { - var handle = THSNN_Sigmoid_ctor(out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Sigmoid(handle, boxedHandle); + return new Sigmoid(inplace); } public static partial class functional @@ -56,10 +47,11 @@ public static partial class functional /// Sigmoid activation /// /// The input tensor + /// Do the operation in-place. Default: False /// - public static Tensor sigmoid(Tensor x) + public static Tensor sigmoid(Tensor x, bool inplace = false) { - return x.sigmoid(); + return inplace ? x.sigmoid_().alias() : x.sigmoid(); } } } diff --git a/src/TorchSharp/NN/Activation/Softmax.cs b/src/TorchSharp/NN/Activation/Softmax.cs index dc4fea3fe..32aeee6f4 100644 --- a/src/TorchSharp/NN/Activation/Softmax.cs +++ b/src/TorchSharp/NN/Activation/Softmax.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a Softmax module. /// - public sealed class Softmax : torch.nn.Module + public sealed class Softmax : ParamLessModule { - internal Softmax(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) - { - var res = THSNN_Softmax_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + internal Softmax(long dim) : base(nameof(Softmax)) + { + this.dim = dim; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Softmax).Name; + return torch.nn.functional.softmax(tensor, dim); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long dim {get; set;} } } @@ -47,9 +39,7 @@ public static partial class nn /// public static Softmax Softmax(long dim) { - var handle = THSNN_Softmax_ctor(dim, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Softmax(handle, boxedHandle); + return new Softmax(dim); } public static partial class functional @@ -60,7 +50,8 @@ public static partial class functional /// The input tensor /// A dimension along which softmax will be computed. /// The desired data type of returned tensor. - public static Tensor softmax(Tensor input, long dim, ScalarType? dtype = null) => torch.special.softmax(input, dim, dtype); + public static Tensor softmax(Tensor input, long dim, ScalarType? dtype = null) => + torch.special.softmax(input, dim, dtype); } } } diff --git a/src/TorchSharp/NN/Activation/Softmax2d.cs b/src/TorchSharp/NN/Activation/Softmax2d.cs index 58014fdc9..ed3977d26 100644 --- a/src/TorchSharp/NN/Activation/Softmax2d.cs +++ b/src/TorchSharp/NN/Activation/Softmax2d.cs @@ -12,27 +12,14 @@ namespace Modules /// /// This class is used to represent a Softmax2d module. /// - public sealed class Softmax2d : torch.nn.Module + public sealed class Softmax2d : ParamLessModule { - internal Softmax2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Softmax2d() : base(nameof(Softmax2d)) { } public override Tensor forward(Tensor tensor) { - var res = THSNN_Softmax2d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.softmax2d(tensor); } - - public override string GetName() - { - return typeof(Softmax2d).Name; - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; } } public static partial class torch @@ -45,9 +32,7 @@ public static partial class nn /// public static Softmax2d Softmax2d() { - var handle = THSNN_Softmax2d_ctor(out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Softmax2d(handle, boxedHandle); + return new Softmax2d(); } public static partial class functional @@ -59,9 +44,7 @@ public static partial class functional /// public static Tensor softmax2d(Tensor x) { - using (var m = nn.Softmax2d()) { - return m.call(x); - } + return torch.nn.functional.softmax(x, -3); } } } diff --git a/src/TorchSharp/NN/Activation/Softmin.cs b/src/TorchSharp/NN/Activation/Softmin.cs index 9c3c4eba7..be2b91761 100644 --- a/src/TorchSharp/NN/Activation/Softmin.cs +++ b/src/TorchSharp/NN/Activation/Softmin.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a Softmin module. /// - public sealed class Softmin : torch.nn.Module + public sealed class Softmin : ParamLessModule { - internal Softmin(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) - { - var res = THSNN_Softmin_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + internal Softmin(long dim) : base(nameof(Softmin)) + { + this.dim = dim; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Softmin).Name; + return torch.nn.functional.softmin(tensor, dim); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long dim {get; set;} } } @@ -47,9 +39,7 @@ public static partial class nn /// public static Softmin Softmin(long dim) { - var handle = THSNN_Softmin_ctor(dim, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Softmin(handle, boxedHandle); + return new Softmin(dim); } public static partial class functional @@ -62,9 +52,8 @@ public static partial class functional /// public static Tensor softmin(Tensor x, long dim) { - using (var m = nn.Softmin(dim)) { - return m.call(x); - } + using var minus_x = -x; + return softmax(minus_x, dim); } } } diff --git a/src/TorchSharp/NN/Activation/Softplus.cs b/src/TorchSharp/NN/Activation/Softplus.cs index d814b1042..a8651aab3 100644 --- a/src/TorchSharp/NN/Activation/Softplus.cs +++ b/src/TorchSharp/NN/Activation/Softplus.cs @@ -12,27 +12,21 @@ namespace Modules /// /// This class is used to represent a Softplus module. /// - public sealed class Softplus : torch.nn.Module + public sealed class Softplus : ParamLessModule { - internal Softplus(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal Softplus(int beta = 1, int threshold = 20) : base(nameof(Softplus)) { - var res = THSNN_Softplus_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.beta = beta; + this.threshold = threshold; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Softplus).Name; + return torch.nn.functional.softplus(tensor, beta, threshold); } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + + public int beta {get; set;} + public int threshold {get; set;} } } @@ -46,11 +40,9 @@ public static partial class nn /// The β value for the Softplus formulation. /// Values above this revert to a linear function /// - public static Softplus Softplus(double beta = 1.0, double threshold = 20.0) + public static Softplus Softplus(int beta = 1, int threshold = 20) { - var handle = THSNN_Softplus_ctor(beta, threshold, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Softplus(handle, boxedHandle); + return new Softplus(beta, threshold); } public static partial class functional @@ -62,11 +54,9 @@ public static partial class functional /// The β value for the Softplus formulation. /// Values above this revert to a linear function /// - public static Tensor softplus(Tensor x, double beta = 1.0, double threshold = 20.0) + public static Tensor softplus(Tensor x, int beta = 1, int threshold = 20) { - using (var m = nn.Softplus(beta, threshold)) { - return m.call(x); - } + return x.softplus(beta, threshold); } } } diff --git a/src/TorchSharp/NN/Activation/Softshrink.cs b/src/TorchSharp/NN/Activation/Softshrink.cs index de32e4dfb..beff46550 100644 --- a/src/TorchSharp/NN/Activation/Softshrink.cs +++ b/src/TorchSharp/NN/Activation/Softshrink.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a Softshrink module. /// - public sealed class Softshrink : torch.nn.Module + public sealed class Softshrink : ParamLessModule { - internal Softshrink(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) - { - var res = THSNN_Softshrink_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + internal Softshrink(double lambda = 0.5) : base(nameof(Softshrink)) + { + this.lambda = lambda; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Softshrink).Name; + return torch.nn.functional.softshrink(tensor, lambda); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public double lambda {get; set; } } } @@ -47,9 +39,7 @@ public static partial class nn /// public static Softshrink Softshrink(double lambda = 0.5) { - var handle = THSNN_Softshrink_ctor(lambda, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Softshrink(handle, boxedHandle); + return new Softshrink(lambda); } public static partial class functional @@ -60,11 +50,12 @@ public static partial class functional /// The input tensor /// The λ value for the Softshrink formulation. Default: 0.5 /// - public static Tensor Softshrink(Tensor x, double lambda = 0.5) + public static Tensor softshrink(Tensor x, double lambda = 0.5) { - using (var m = nn.Softshrink(lambda)) { - return m.call(x); - } + using var sc = (Scalar)lambda; + var result = THSTensor_softshrink(x.Handle, sc.Handle); + if (result == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(result); } } } diff --git a/src/TorchSharp/NN/Activation/Softsign.cs b/src/TorchSharp/NN/Activation/Softsign.cs index 6dbbffe96..9ac20d39a 100644 --- a/src/TorchSharp/NN/Activation/Softsign.cs +++ b/src/TorchSharp/NN/Activation/Softsign.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a Softsign module. /// - public sealed class Softsign : torch.nn.Module + public sealed class Softsign : ParamLessModule { - internal Softsign(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal Softsign(bool inplace) : base(nameof(Softsign)) { - var res = THSNN_Softsign_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Softsign).Name; + return torch.nn.functional.softsign(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } @@ -43,12 +35,10 @@ public static partial class nn /// /// Softsign /// - /// - public static Softsign Softsign() + /// Do the operation in-place. Default: False + public static Softsign Softsign(bool inplace = false) { - var handle = THSNN_Softsign_ctor(out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Softsign(handle, boxedHandle); + return new Softsign(inplace); } public static partial class functional @@ -57,12 +47,12 @@ public static partial class functional /// Softsign /// /// The input tensor - /// - public static Tensor Softsign(Tensor x) + /// Do the operation in-place. Default: False + public static Tensor softsign(Tensor x, bool inplace = false) { - using (var m = nn.Softsign()) { - return m.call(x); - } + using var abs = x.abs(); + using var y = 1 + abs; + return inplace ? x.div_(y).alias() : x.div(y); } } } diff --git a/src/TorchSharp/NN/Activation/Tanh.cs b/src/TorchSharp/NN/Activation/Tanh.cs index 9ff611792..34bfb6e5a 100644 --- a/src/TorchSharp/NN/Activation/Tanh.cs +++ b/src/TorchSharp/NN/Activation/Tanh.cs @@ -12,15 +12,16 @@ namespace Modules /// /// This class is used to represent a Tanh module. /// - public sealed class Tanh : torch.nn.Module + public sealed class Tanh : ParamLessModule { - internal Tanh(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Tanh(bool inplace) : base(nameof(Tanh)) + { + this.inplace = inplace; + } public override Tensor forward(Tensor tensor) { - var res = THSNN_Tanh_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.tanh(tensor, inplace); } public override string GetName() @@ -28,11 +29,7 @@ public override string GetName() return typeof(Tanh).Name; } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } @@ -44,11 +41,9 @@ public static partial class nn /// Tanh activation /// /// - public static Tanh Tanh() + public static Tanh Tanh(bool inplace = false) { - var handle = THSNN_Tanh_ctor(out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tanh(handle, boxedHandle); + return new Tanh(inplace); } public static partial class functional diff --git a/src/TorchSharp/NN/Activation/Tanhshrink.cs b/src/TorchSharp/NN/Activation/Tanhshrink.cs index 371200871..5e07d6c4c 100644 --- a/src/TorchSharp/NN/Activation/Tanhshrink.cs +++ b/src/TorchSharp/NN/Activation/Tanhshrink.cs @@ -12,27 +12,19 @@ namespace Modules /// /// This class is used to represent a Tanhshrink module. /// - public sealed class Tanhshrink : torch.nn.Module + public sealed class Tanhshrink : ParamLessModule { - internal Tanhshrink(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - public override Tensor forward(Tensor tensor) + internal Tanhshrink(bool inplace) : base(nameof(Tanhshrink)) { - var res = THSNN_Tanhshrink_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + this.inplace = inplace; } - public override string GetName() + public override Tensor forward(Tensor tensor) { - return typeof(Tanhshrink).Name; + return torch.nn.functional.tanhshrink(tensor, inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set; } } } @@ -43,12 +35,10 @@ public static partial class nn /// /// Tanhshrink /// - /// - public static Tanhshrink Tanhshrink() + /// Do the operation in-place. Default: False + public static Tanhshrink Tanhshrink(bool inplace = false) { - var handle = THSNN_Tanhshrink_ctor(out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tanhshrink(handle, boxedHandle); + return new Tanhshrink(inplace); } public static partial class functional @@ -57,12 +47,11 @@ public static partial class functional /// Tanhshrink /// /// The input tensor - /// - public static Tensor Tanhshrink(Tensor x) + /// Do the operation in-place. Default: False + public static Tensor tanhshrink(Tensor x, bool inplace = false) { - using (var m = nn.Tanhshrink()) { - return m.call(x); - } + using var tanh_x = x.tanh(); + return inplace ? x.sub_(tanh_x).alias() : x.sub(tanh_x); } } } diff --git a/src/TorchSharp/NN/Activation/Threshold.cs b/src/TorchSharp/NN/Activation/Threshold.cs index cfd9ea1c7..a3139fdbb 100644 --- a/src/TorchSharp/NN/Activation/Threshold.cs +++ b/src/TorchSharp/NN/Activation/Threshold.cs @@ -12,27 +12,25 @@ namespace Modules /// /// This class is used to represent a Threshold module. /// - public sealed class Threshold : torch.nn.Module + public sealed class Threshold : ParamLessModule { - internal Threshold(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Threshold(double threshold, double value, bool inplace) : base(nameof(Threshold)) + { + this.inplace = inplace; + this.threshold = threshold; + this.value = value; + } public override Tensor forward(Tensor tensor) { - var res = THSNN_Threshold_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.threshold(tensor, threshold, value, inplace); } + + public double threshold {get; set;} - public override string GetName() - { - return typeof(Threshold).Name; - } + public double value {get; set;} - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace {get; set;} } } @@ -49,9 +47,7 @@ public static partial class nn /// public static Threshold Threshold(double threshold, double value, bool inplace = false) { - var handle = THSNN_Threshold_ctor(threshold, value, inplace, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Threshold(handle, boxedHandle); + return new Threshold(threshold, value, inplace); } public static partial class functional @@ -64,11 +60,9 @@ public static partial class functional /// The value to replace with /// Do the operation in-place /// - public static Tensor Threshold(Tensor x, double threshold, double value, bool inplace = false) + public static Tensor threshold(Tensor x, double threshold, double value, bool inplace = false) { - using (var m = nn.Threshold(threshold, value, inplace)) { - return m.call(x); - } + return inplace ? x.threshold_(threshold, value).alias() : x.threshold(threshold, value); } } } diff --git a/src/TorchSharp/NN/AlphaDropout.cs b/src/TorchSharp/NN/AlphaDropout.cs index ba9916c85..663c8994b 100644 --- a/src/TorchSharp/NN/AlphaDropout.cs +++ b/src/TorchSharp/NN/AlphaDropout.cs @@ -17,7 +17,7 @@ namespace Modules /// The elements to masked are randomized on every forward call, and scaled and shifted to maintain zero mean and unit standard deviation. /// During evaluation the module simply computes an identity function. /// - public sealed class AlphaDropout : torch.nn.Module + public sealed class AlphaDropout : ParamLessModule { internal AlphaDropout(double p = 0.5, bool inplace = false) : base(nameof(Dropout1d)) { @@ -35,16 +35,9 @@ public override Tensor forward(Tensor tensor) return torch.nn.functional.alpha_dropout(tensor, this.p, this.training, inplace); } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; - - private bool inplace; - private double p; - } + public bool inplace { get; set; } + public double p { get; set;} + } } public static partial class torch diff --git a/src/TorchSharp/NN/Bilinear.cs b/src/TorchSharp/NN/Bilinear.cs index 8ba4efebb..8d281fc8c 100644 --- a/src/TorchSharp/NN/Bilinear.cs +++ b/src/TorchSharp/NN/Bilinear.cs @@ -8,49 +8,75 @@ namespace TorchSharp { using Modules; + using TorchSharp.Utils; namespace Modules { public sealed class Bilinear : Module { - internal Bilinear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + const string WeightComponentName = nameof(weight); + const string BiasComponentName = nameof(bias); + + internal Bilinear(long in1_features, long in2_features, long out_features, bool hasBias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Bilinear)) + { + this.in1_features = in1_features; + this.in2_features = in2_features; + this.out_features = out_features; + + weight = torch.empty(out_features, in1_features, in2_features, device: device, dtype: dtype).AsParameter(); + var bound = 1 / Math.Sqrt(weight!.shape[1]); + + init.uniform_(_weight, -bound, bound); + + if (hasBias) { + bias = torch.empty(out_features, device: device, dtype: dtype).AsParameter(); + init.uniform_(_bias, -bound, bound); + } + //NOTE: it's important not to call 'RegisterComponents' here. + } public override Tensor forward(Tensor input1, Tensor input2) { - var res = THSNN_Bilinear_forward(handle, input1.Handle, input2.Handle); - if (res == IntPtr.Zero) { CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.bilinear(input1, input2, _weight!, _bias); } - public Parameter? bias { - get { - var res = THSNN_Bilinear_bias(handle); - if (res == IntPtr.Zero) { CheckForErrors(); } - return ((res == IntPtr.Zero) ? null : new Parameter(res)); + protected override void Dispose(bool disposing) + { + if (disposing) { + _weight?.Dispose(); + _bias?.Dispose(); } + } + + public Parameter? bias { + get => _bias; set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_Bilinear_set_bias(handle, value?.Handle ?? IntPtr.Zero); - CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + _bias?.Dispose(); + _bias = value?.DetachFromDisposeScope() as Parameter; + ConditionallyRegisterParameter(BiasComponentName, _bias); } } - public Parameter? weight { - get { - var res = THSNN_Bilinear_weight(handle); - if (res == IntPtr.Zero) { CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } + public Parameter weight { + get => _weight!; set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_Bilinear_set_weight(handle, value?.Handle ?? IntPtr.Zero); - CheckForErrors(); - ConditionallyRegisterParameter("weight", value); + if (value is null) throw new ArgumentNullException(nameof(weight)); + if (value.Handle != _weight?.Handle) { + _weight?.Dispose(); + _weight = (value.DetachFromDisposeScope() as Parameter)!; + ConditionallyRegisterParameter(WeightComponentName, _weight); + } } } + + [ComponentName(Name = BiasComponentName)] + private Parameter? _bias; + [ComponentName(Name = WeightComponentName)] + private Parameter? _weight; + + public long in1_features { get; set; } + public long in2_features { get; set; } + public long out_features { get; set; } } } @@ -62,19 +88,16 @@ public static partial class nn /// /// Applies a bilinear transformation to the incoming data /// - /// size of each first input sample - /// size of each second input sample - /// size of each output sample + /// size of each first input sample + /// size of each second input sample + /// size of each output sample /// If set to false, the layer will not learn an additive bias /// The desired device of the parameters and buffers in this module /// The desired floating point or complex dtype of the parameters and buffers in this module /// - public static Bilinear Bilinear(long in1Features, long in2Features, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null) + public static Bilinear Bilinear(long in1_features, long in2_features, long out_features, bool hasBias = true, Device? device = null, ScalarType? dtype = null) { - var res = THSNN_Bilinear_ctor(in1Features, in2Features, outputSize, hasBias, out var boxedHandle); - if (res == IntPtr.Zero) { CheckForErrors(); } - - return new Bilinear(res, boxedHandle).MoveModule(device, dtype); + return new Bilinear(in1_features, in2_features, out_features, hasBias, device, dtype); } public static partial class functional @@ -92,7 +115,7 @@ public static Tensor bilinear(Tensor input1, Tensor input2, Tensor weight, Tenso { IntPtr bPtr = bias?.Handle ?? IntPtr.Zero; var res = THSNN_functional_bilinear(input1.Handle, input2.Handle, weight.Handle, bPtr); - if (res == IntPtr.Zero) { CheckForErrors(); } + if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } } diff --git a/src/TorchSharp/NN/CosineSimilarity.cs b/src/TorchSharp/NN/CosineSimilarity.cs index b4c4802ae..eab4e964e 100644 --- a/src/TorchSharp/NN/CosineSimilarity.cs +++ b/src/TorchSharp/NN/CosineSimilarity.cs @@ -12,18 +12,21 @@ namespace Modules /// /// A cosine similarity module. /// - public sealed class CosineSimilarity : torch.nn.Module + public sealed class CosineSimilarity : ParamLessModule { - internal CosineSimilarity(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal CosineSimilarity(long dim = 1, double eps = 1e-8) : base(nameof(CosineSimilarity)) { + this.dim = dim; + this.eps = eps; } public override Tensor forward(Tensor input1, Tensor input2) { - var res = THSNN_CosineSimilarity_forward(handle, input1.Handle, input2.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.cosine_similarity(input1, input2, this.dim, this.eps); } + + public long dim { get; set; } + public double eps { get; set; } } } @@ -39,9 +42,7 @@ public static partial class nn /// public static CosineSimilarity CosineSimilarity(long dim = 1, double eps = 1e-8) { - var handle = THSNN_CosineSimilarity_ctor(dim, eps, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new CosineSimilarity(handle, boxedHandle); + return new CosineSimilarity(dim, eps); } public static partial class functional @@ -56,9 +57,9 @@ public static partial class functional /// public static Tensor cosine_similarity(Tensor x1, Tensor x2, long dim = 1, double eps = 1e-8) { - using (var f = nn.CosineSimilarity(dim, eps)) { - return f.call(x1, x2); - } + var res = THSNN_cosine_similarity(x1.Handle, x2.Handle, dim, eps); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); } } } diff --git a/src/TorchSharp/NN/Dropout.cs b/src/TorchSharp/NN/Dropout.cs index b2d31dbae..b8e351a80 100644 --- a/src/TorchSharp/NN/Dropout.cs +++ b/src/TorchSharp/NN/Dropout.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a dropout module. /// - public sealed class Dropout : torch.nn.Module + public sealed class Dropout : ParamLessModule { internal Dropout(double p = 0.5, bool inplace = false) : base(nameof(Dropout)) { @@ -30,14 +30,8 @@ public override Tensor forward(Tensor tensor) return torch.nn.functional.dropout(tensor, this.p, this.training, this.inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; - - private bool inplace; - private double p; + public bool inplace { get; set; } + public double p { get; set;} } } diff --git a/src/TorchSharp/NN/Dropout1d.cs b/src/TorchSharp/NN/Dropout1d.cs index e06eecafa..70a51f96f 100644 --- a/src/TorchSharp/NN/Dropout1d.cs +++ b/src/TorchSharp/NN/Dropout1d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Dropout2d module. /// - public sealed class Dropout1d : torch.nn.Module + public sealed class Dropout1d : ParamLessModule { internal Dropout1d(double p = 0.5, bool inplace = false) : base(nameof(Dropout1d)) { @@ -27,14 +27,8 @@ public override Tensor forward(Tensor tensor) return torch.nn.functional.dropout1d(tensor, this.p, this.training, this.inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; - - private bool inplace; - private double p; + public bool inplace { get; set; } + public double p { get; set;} } } diff --git a/src/TorchSharp/NN/Dropout2d.cs b/src/TorchSharp/NN/Dropout2d.cs index 363cb40d5..9005c745a 100644 --- a/src/TorchSharp/NN/Dropout2d.cs +++ b/src/TorchSharp/NN/Dropout2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Dropout2d module. /// - public sealed class Dropout2d : torch.nn.Module + public sealed class Dropout2d : ParamLessModule { internal Dropout2d(double p = 0.5, bool inplace = false) : base(nameof(Dropout2d)) { @@ -22,19 +22,11 @@ internal Dropout2d(double p = 0.5, bool inplace = false) : base(nameof(Dropout2d public override Tensor forward(Tensor input) { - var res = THSNN_dropout2d(input.Handle, p, this.training, inplace); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.dropout2d(input, this.p, this.training, this.inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; - - private bool inplace; - private double p; + public bool inplace { get; set; } + public double p { get; set;} } } diff --git a/src/TorchSharp/NN/Dropout3d.cs b/src/TorchSharp/NN/Dropout3d.cs index 8c70d2f79..4f447a149 100644 --- a/src/TorchSharp/NN/Dropout3d.cs +++ b/src/TorchSharp/NN/Dropout3d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Dropout3d module. /// - public sealed class Dropout3d : nn.Module + public sealed class Dropout3d : ParamLessModule { internal Dropout3d(double p = 0.5, bool inplace = false) : base(nameof(Dropout3d)) { @@ -22,19 +22,11 @@ internal Dropout3d(double p = 0.5, bool inplace = false) : base(nameof(Dropout3d public override Tensor forward(Tensor input) { - var res = THSNN_dropout3d(input.Handle, p, this.training, inplace); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.dropout3d(input, this.p, this.training, this.inplace); } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; - - private bool inplace; - private double p; + + public bool inplace { get; set; } + public double p { get; set;} } } diff --git a/src/TorchSharp/NN/FeatureDropout.cs b/src/TorchSharp/NN/FeatureDropout.cs index 0e016b385..16712c2ab 100644 --- a/src/TorchSharp/NN/FeatureDropout.cs +++ b/src/TorchSharp/NN/FeatureDropout.cs @@ -12,24 +12,21 @@ namespace Modules /// /// This class is used to represent a dropout module for 2d/3d convolutational layers. /// - public sealed class FeatureAlphaDropout : torch.nn.Module + public sealed class FeatureAlphaDropout : ParamLessModule { - internal FeatureAlphaDropout(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal FeatureAlphaDropout(double p = 0.5, bool inplace = false) : base(nameof(FeatureAlphaDropout)) { + this.p = p; + this.inplace = inplace; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_FeatureAlphaDropout_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.feature_alpha_dropout(input, this.p, this.training, this.inplace); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public bool inplace { get; set; } + public double p { get; set; } } } @@ -44,11 +41,10 @@ public static partial class nn /// randomized on every forward call, and scaled and shifted to maintain zero mean and unit variance. /// /// Dropout probability of a channel to be zeroed. Default: 0.5 - public static FeatureAlphaDropout FeatureAlphaDropout(double p = 0.5) + /// If set to true, will do this operation in-place. Default: false + public static FeatureAlphaDropout FeatureAlphaDropout(double p = 0.5, bool inplace = false) { - var handle = THSNN_FeatureAlphaDropout_ctor(p, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new FeatureAlphaDropout(handle, boxedHandle); + return new FeatureAlphaDropout(p, inplace); } public static partial class functional diff --git a/src/TorchSharp/NN/Flatten.cs b/src/TorchSharp/NN/Flatten.cs index b1568938c..1301870b1 100644 --- a/src/TorchSharp/NN/Flatten.cs +++ b/src/TorchSharp/NN/Flatten.cs @@ -10,26 +10,23 @@ namespace TorchSharp namespace Modules { /// - /// This class is used to represent a dropout module for 2d/3d convolutational layers. + /// This class is used to represent a flattening of the input tensors. /// - public sealed class Flatten : torch.nn.Module + public sealed class Flatten : ParamLessModule { - internal Flatten(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal Flatten(long start_dim = 1, long end_dim = -1) : base(nameof(Flatten)) { + this.start_dim = start_dim; + this.end_dim = end_dim; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_Flatten_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return input.flatten(start_dim, end_dim); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long start_dim { get; set; } + public long end_dim { get; set; } } } @@ -40,14 +37,12 @@ public static partial class nn /// /// Flattens a contiguous range of dims into a tensor. For use with Sequential. /// - /// First dim to flatten (default = 1). - /// Last dim to flatten (default = -1). + /// First dim to flatten (default = 1). + /// Last dim to flatten (default = -1). /// - public static Flatten Flatten(long startDim = 1, long endDim = -1) + public static Flatten Flatten(long start_dim = 1, long end_dim = -1) { - var handle = THSNN_Flatten_ctor(startDim, endDim, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new Flatten(handle, boxedHandle); + return new Flatten(start_dim, end_dim); } } } diff --git a/src/TorchSharp/NN/Fold.cs b/src/TorchSharp/NN/Fold.cs index 5c4de0ff0..9d1c0a9b6 100644 --- a/src/TorchSharp/NN/Fold.cs +++ b/src/TorchSharp/NN/Fold.cs @@ -11,12 +11,12 @@ namespace TorchSharp namespace Modules { - public sealed class Fold : torch.nn.Module + public sealed class Fold : ParamLessModule { internal Fold((long, long) output_size, (long, long) kernel_size, (long, long) dilation, (long, long) padding, (long, long) stride) : base(nameof(Fold)) { - this.outputSize = output_size; - this.kernelSize = kernel_size; + this.output_size = output_size; + this.kernel_size = kernel_size; this.dilation = dilation; this.padding = padding; this.stride = stride; @@ -24,14 +24,14 @@ internal Fold((long, long) output_size, (long, long) kernel_size, (long, long) d public override Tensor forward(Tensor tensor) { - return torch.nn.functional.fold(tensor, outputSize , kernelSize, dilation, padding, stride); + return torch.nn.functional.fold(tensor, output_size , kernel_size, dilation, padding, stride); } - private (long, long) outputSize; - private (long, long) kernelSize; - private (long, long) dilation; - private (long, long) padding; - private (long, long) stride; + public (long, long) output_size { get; set; } + public (long, long) kernel_size { get; set; } + public (long, long) dilation { get; set; } + public (long, long) padding { get; set; } + public (long, long) stride { get; set; } } } diff --git a/src/TorchSharp/NN/Identity.cs b/src/TorchSharp/NN/Identity.cs index 7296a52a1..fd0c26760 100644 --- a/src/TorchSharp/NN/Identity.cs +++ b/src/TorchSharp/NN/Identity.cs @@ -10,22 +10,14 @@ namespace TorchSharp namespace Modules { - public sealed class Identity : torch.nn.Module + public sealed class Identity : ParamLessModule { - internal Identity(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Identity() : base(nameof(Identity)) { } public override Tensor forward(Tensor tensor) { - var res = THSNN_Identity_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return tensor.alias(); } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; } } @@ -39,9 +31,7 @@ public static partial class nn /// The same tensor as is input. public static Identity Identity() { - var res = THSNN_Identity_ctor(out var boxedHandle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Identity(res, boxedHandle); + return new Identity(); } } } diff --git a/src/TorchSharp/NN/Linear.cs b/src/TorchSharp/NN/Linear.cs index 4595582d7..c9adb78cd 100644 --- a/src/TorchSharp/NN/Linear.cs +++ b/src/TorchSharp/NN/Linear.cs @@ -8,49 +8,75 @@ namespace TorchSharp { using Modules; + using TorchSharp.Utils; namespace Modules { public sealed class Linear : torch.nn.Module { - internal Linear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + const string WeightComponentName = nameof(weight); + const string BiasComponentName = nameof(bias); + + internal Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Linear)) { + this.in_features = inputSize; + this.out_features = outputSize; + + weight = torch.empty(outputSize, inputSize, device: device, dtype: dtype).AsParameter(); + init.kaiming_uniform_(weight, a: _sqrt5); + + if (hasBias) { + bias = torch.empty(outputSize, device: device, dtype: dtype).AsParameter(); + var (fanIn, _) = init.CalculateFanInAndFanOut(weight); + var bound = fanIn > 0 ? 1 / Math.Sqrt(fanIn) : 0; + init.uniform_(_bias, -bound, bound); + } + //NOTE: it's important not to call 'RegisterComponents' here. } public override Tensor forward(Tensor tensor) { - var res = THSNN_Linear_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.linear(tensor, _weight!, _bias); } - public Parameter? bias { - get { - var res = THSNN_Linear_bias(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return ((res == IntPtr.Zero) ? null : new Parameter(res)); + protected override void Dispose(bool disposing) + { + if (disposing) { + _weight?.Dispose(); + _bias?.Dispose(); } + } + + public Parameter? bias { + get => _bias; set { - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_Linear_set_bias(handle, value?.Handle ?? IntPtr.Zero); - torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + _bias?.Dispose(); + _bias = value?.DetachFromDisposeScope() as Parameter; + ConditionallyRegisterParameter(BiasComponentName, _bias); } } - public Parameter? weight { - get { - var res = THSNN_Linear_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } + public Parameter weight { + get => _weight!; set { - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_Linear_set_weight(handle, value!.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); + if (value is null) throw new ArgumentNullException(nameof(weight)); + if (value.Handle != _weight?.Handle) { + _weight?.Dispose(); + _weight = (value.DetachFromDisposeScope() as Parameter)!; + ConditionallyRegisterParameter(WeightComponentName, _weight); + } } } + + [ComponentName(Name = BiasComponentName)] + private Parameter? _bias; + [ComponentName(Name = WeightComponentName)] + private Parameter? _weight; + + public long in_features { get; set; } + public long out_features { get; set; } + + private static readonly double _sqrt5 = Math.Sqrt(5); } } @@ -68,10 +94,7 @@ public static partial class nn /// The desired floating point or complex dtype of the parameters and buffers in this module public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null) { - var res = THSNN_Linear_ctor(inputSize, outputSize, hasBias, out var boxedHandle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - - return new Linear(res, boxedHandle).MoveModule(device, dtype); + return new Linear(inputSize, outputSize, hasBias, device, dtype); } public static partial class functional diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs index 6d6cca212..e4fbabcf3 100644 --- a/src/TorchSharp/NN/Module.cs +++ b/src/TorchSharp/NN/Module.cs @@ -127,10 +127,10 @@ protected virtual void Dispose(bool disposing) if (disposing && !handle.IsInvalid) { foreach (var (_, p) in named_buffers(false)) { - p.Dispose(); + p.DetachFromDisposeScope().Dispose(); } foreach (var (_, b) in named_parameters(false)) { - b.Dispose(); + b.DetachFromDisposeScope().Dispose(); } foreach (var (_, m) in named_modules()) { @@ -785,6 +785,8 @@ protected void ConditionallyRegisterBuffer(string name, Tensor value, bool persi public virtual string GetName() { + if (!string.IsNullOrEmpty(this.name)) return this.name; + var res = THSNN_Module_name(handle); CheckForErrors(); return res; diff --git a/src/TorchSharp/NN/Normalization/BatchNorm.cs b/src/TorchSharp/NN/Normalization/BatchNorm.cs new file mode 100644 index 000000000..398eae63c --- /dev/null +++ b/src/TorchSharp/NN/Normalization/BatchNorm.cs @@ -0,0 +1,53 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using static TorchSharp.PInvoke.NativeMethods; +#nullable enable +namespace TorchSharp +{ + using System.Globalization; + using System.Transactions; + using Modules; + using TorchSharp.Utils; + using F = TorchSharp.torch.nn.functional; + + namespace Modules + { + public abstract class BatchNorm : NormBase + { + public BatchNorm(long num_features, + double eps, + double momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype, + string name) : base(num_features, eps, momentum, affine, track_running_stats, device, dtype, name) + { + } + + public override Tensor forward(Tensor input) + { + ValidateInputDimensions(input); + + double exponential_average_factor = (this.momentum is null) ? 0.0 : this.momentum.Value; + + if (training && track_running_stats) + { + if (num_batches_tracked is not null) + { + num_batches_tracked.add_(1); + exponential_average_factor = (this.momentum is null) ? (1.0 / (double)num_batches_tracked) : momentum.Value; + } + } + + var bn_training = training ? true : running_mean is null && running_var is null; + var pr = !training || track_running_stats; + + return F.batch_norm(input, pr ? running_mean : null, pr ? running_var : null, weight, bias, bn_training, exponential_average_factor, eps); + } + + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/NN/Normalization/BatchNorm1D.cs b/src/TorchSharp/NN/Normalization/BatchNorm1D.cs index a28bb9057..3633f82b8 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm1D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm1D.cs @@ -13,92 +13,22 @@ namespace Modules /// /// This class is used to represent a BatchNorm1D module. /// - public sealed class BatchNorm1d : torch.nn.Module + public sealed class BatchNorm1d : BatchNorm { - internal BatchNorm1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal BatchNorm1d(long num_features, + double eps, + double momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype) : base(num_features, eps, momentum, affine, track_running_stats, device, dtype, nameof(BatchNorm1d)) { } - public override Tensor forward(Tensor tensor) + protected override void ValidateInputDimensions(Tensor input) { - if (tensor.Dimensions < 2 || tensor.Dimensions > 3) throw new ArgumentException($"Invalid number of dimensions for BatchNorm argument: {tensor.Dimensions}"); - var res = THSNN_BatchNorm1d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - public Parameter? bias { - get { - var res = THSNN_BatchNorm1d_bias(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_BatchNorm1d_set_bias(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); - } - } - - public Parameter? weight { - get { - var res = THSNN_BatchNorm1d_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_BatchNorm1d_set_weight(handle, value is null ? IntPtr.Zero : value.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); - } - } - - public Tensor? running_mean { - get { - var res = THSNN_BatchNorm1d_get_mean(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_mean cannot be set to 'null'"); - THSNN_BatchNorm1d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_mean", value); - } - } - - public Tensor? running_var { - get { - var res = THSNN_BatchNorm1d_get_var(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_var cannot be set to 'null'"); - THSNN_BatchNorm1d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_var", value); - } - } - - public Tensor? num_batches_tracked { - get { - var res = THSNN_BatchNorm1d_get_batches(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - } - - public void reset_running_stats() - { - THSNN_BatchNorm1d_reset_stats(handle); - torch.CheckForErrors(); + if (input.ndim != 2 && input.ndim != 3) + throw new ArgumentException($"expected 2D or 3D input, but got {input.ndim}D input."); } } } @@ -110,7 +40,7 @@ public static partial class nn /// /// Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . /// - /// C from an expected input of size (N,C,L) or LL from input of size (N, L) + /// C from an expected input of size (N,C,L) or LL from input of size (N, L) /// A value added to the denominator for numerical stability. Default: 1e-5 /// The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1 /// A boolean value that when set to True, this module has learnable affine parameters. Default: true @@ -120,13 +50,9 @@ public static partial class nn /// The desired device of the parameters and buffers in this module /// The desired floating point or complex dtype of the parameters and buffers in this module /// - public static BatchNorm1d BatchNorm1d(long features, double eps = 1e-05, double momentum = 0.1, bool affine = true, bool track_running_stats = true, Device? device = null, ScalarType? dtype = null) + public static BatchNorm1d BatchNorm1d(long num_features, double eps = 1e-05, double momentum = 0.1, bool affine = true, bool track_running_stats = true, Device? device = null, ScalarType? dtype = null) { - unsafe { - var handle = THSNN_BatchNorm1d_ctor(features, eps, momentum, affine, track_running_stats, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new BatchNorm1d(handle, boxedHandle).MoveModule(device, dtype); - } + return new BatchNorm1d(num_features, eps, momentum, affine, track_running_stats, device, dtype); } } } diff --git a/src/TorchSharp/NN/Normalization/BatchNorm2D.cs b/src/TorchSharp/NN/Normalization/BatchNorm2D.cs index 391b8a6eb..051605f30 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm2D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm2D.cs @@ -13,92 +13,22 @@ namespace Modules /// /// This class is used to represent a BatchNorm2D module. /// - public sealed class BatchNorm2d : torch.nn.Module + public sealed class BatchNorm2d : BatchNorm { - internal BatchNorm2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal BatchNorm2d(long num_features, + double eps, + double momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype) : base(num_features, eps, momentum, affine, track_running_stats, device, dtype, nameof(BatchNorm1d)) { } - public override Tensor forward(Tensor tensor) + protected override void ValidateInputDimensions(Tensor input) { - if (tensor.Dimensions != 4) throw new ArgumentException($"Invalid number of dimensions for BatchNorm argument: {tensor.Dimensions}"); - var res = THSNN_BatchNorm2d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - public Parameter? bias { - get { - var res = THSNN_BatchNorm2d_bias(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_BatchNorm2d_set_bias(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); - } - } - - public Parameter? weight { - get { - var res = THSNN_BatchNorm2d_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_BatchNorm2d_set_weight(handle, value is null ? IntPtr.Zero : value.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); - } - } - - public Tensor? running_mean { - get { - var res = THSNN_BatchNorm2d_get_mean(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_mean cannot be set to 'null'"); - THSNN_BatchNorm2d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_mean", value); - } - } - - public Tensor? running_var { - get { - var res = THSNN_BatchNorm2d_get_var(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_var cannot be set to 'null'"); - THSNN_BatchNorm2d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_var", value); - } - } - - public Tensor? num_batches_tracked { - get { - var res = THSNN_BatchNorm2d_get_batches(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - } - - public void reset_running_stats() - { - THSNN_BatchNorm2d_reset_stats(handle); - torch.CheckForErrors(); + if (input.ndim != 4) + throw new ArgumentException($"expected 4D input, but got {input.ndim}D input."); } } } @@ -110,7 +40,7 @@ public static partial class nn /// /// Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. /// - /// C from an expected input of size (N,C,H,W) + /// C from an expected input of size (N,C,H,W) /// A value added to the denominator for numerical stability. Default: 1e-5 /// The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1 /// A boolean value that when set to True, this module has learnable affine parameters. Default: true @@ -120,13 +50,9 @@ public static partial class nn /// The desired device of the parameters and buffers in this module /// The desired floating point or complex dtype of the parameters and buffers in this module /// - public static BatchNorm2d BatchNorm2d(long features, double eps = 1e-05, double momentum = 0.1, bool affine = true, bool track_running_stats = true, Device? device = null, ScalarType? dtype = null) + public static BatchNorm2d BatchNorm2d(long num_features, double eps = 1e-05, double momentum = 0.1, bool affine = true, bool track_running_stats = true, Device? device = null, ScalarType? dtype = null) { - unsafe { - var handle = THSNN_BatchNorm2d_ctor(features, eps, momentum, affine, track_running_stats, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new BatchNorm2d(handle, boxedHandle).MoveModule(device, dtype); - } + return new BatchNorm2d(num_features, eps, momentum, affine, track_running_stats, device, dtype); } } } diff --git a/src/TorchSharp/NN/Normalization/BatchNorm3D.cs b/src/TorchSharp/NN/Normalization/BatchNorm3D.cs index 4af5f9f60..f434073d9 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm3D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm3D.cs @@ -13,92 +13,22 @@ namespace Modules /// /// This class is used to represent a BatchNorm3D module. /// - public sealed class BatchNorm3d : torch.nn.Module + public sealed class BatchNorm3d : BatchNorm { - internal BatchNorm3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal BatchNorm3d(long num_features, + double eps, + double momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype) : base(num_features, eps, momentum, affine, track_running_stats, device, dtype, nameof(BatchNorm1d)) { } - public override Tensor forward(Tensor tensor) + protected override void ValidateInputDimensions(Tensor input) { - if (tensor.Dimensions != 5) throw new ArgumentException($"Invalid number of dimensions for BatchNorm argument: {tensor.Dimensions}"); - var res = THSNN_BatchNorm3d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - public Parameter? bias { - get { - var res = THSNN_BatchNorm3d_bias(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_BatchNorm3d_set_bias(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); - } - } - - public Parameter? weight { - get { - var res = THSNN_BatchNorm3d_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_BatchNorm3d_set_weight(handle, value is null ? IntPtr.Zero : value.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); - } - } - - public Tensor? running_mean { - get { - var res = THSNN_BatchNorm3d_get_mean(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_mean cannot be set to 'null'"); - THSNN_BatchNorm3d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_mean", value); - } - } - - public Tensor? running_var { - get { - var res = THSNN_BatchNorm3d_get_var(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_var cannot be set to 'null'"); - THSNN_BatchNorm3d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_var", value); - } - } - - public Tensor? num_batches_tracked { - get { - var res = THSNN_BatchNorm3d_get_batches(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - } - - public void reset_running_stats() - { - THSNN_BatchNorm3d_reset_stats(handle); - torch.CheckForErrors(); + if (input.ndim != 5) + throw new ArgumentException($"expected 4D input, but got {input.ndim}D input."); } } } @@ -110,7 +40,7 @@ public static partial class nn /// /// Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. /// - /// C from an expected input of size (N,C,D,H,W) + /// C from an expected input of size (N,C,D,H,W) /// A value added to the denominator for numerical stability. Default: 1e-5 /// The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1 /// A boolean value that when set to True, this module has learnable affine parameters. Default: true @@ -120,13 +50,9 @@ public static partial class nn /// The desired device of the parameters and buffers in this module /// The desired floating point or complex dtype of the parameters and buffers in this module /// - public static BatchNorm3d BatchNorm3d(long features, double eps = 1e-05, double momentum = 0.1, bool affine = true, bool track_running_stats = true, Device? device = null, ScalarType? dtype = null) + public static BatchNorm3d BatchNorm3d(long num_features, double eps = 1e-05, double momentum = 0.1, bool affine = true, bool track_running_stats = true, Device? device = null, ScalarType? dtype = null) { - unsafe { - var handle = THSNN_BatchNorm3d_ctor(features, eps, momentum, affine, track_running_stats, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new BatchNorm3d(handle, boxedHandle).MoveModule(device, dtype); - } + return new BatchNorm3d(num_features, eps, momentum, affine, track_running_stats, device, dtype); } } } diff --git a/src/TorchSharp/NN/Normalization/Functional.cs b/src/TorchSharp/NN/Normalization/Functional.cs index 2f8bcd1e4..0fdbf1c54 100644 --- a/src/TorchSharp/NN/Normalization/Functional.cs +++ b/src/TorchSharp/NN/Normalization/Functional.cs @@ -17,8 +17,8 @@ public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor runnin { var res = THSNN_batch_norm( input.Handle, - running_mean.Handle, - running_var.Handle, + running_mean is not null ? running_mean.Handle : IntPtr.Zero, + running_var is not null ? running_var.Handle : IntPtr.Zero, weight is not null ? weight.Handle : IntPtr.Zero, bias is not null ? bias.Handle : IntPtr.Zero, training, @@ -84,16 +84,6 @@ public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor we return new Tensor(res); } - /// - /// Applies Local Normalization. - /// - public static Tensor local_response_norm(Tensor input, long size, double alpha = 0.0001, double beta = 0.75, double k = 1.0) - { - var res = THSNN_local_response_norm(input.Handle, size, alpha, beta, k); - if (res == IntPtr.Zero) - torch.CheckForErrors(); - return new Tensor(res); - } } } } diff --git a/src/TorchSharp/NN/Normalization/GroupNorm.cs b/src/TorchSharp/NN/Normalization/GroupNorm.cs index e63b5c8c7..e6bfcd991 100644 --- a/src/TorchSharp/NN/Normalization/GroupNorm.cs +++ b/src/TorchSharp/NN/Normalization/GroupNorm.cs @@ -1,12 +1,15 @@ // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. using System; using static TorchSharp.torch; +using static TorchSharp.torch.nn; using static TorchSharp.PInvoke.NativeMethods; #nullable enable namespace TorchSharp { using Modules; + using TorchSharp.Utils; + using F = TorchSharp.torch.nn.functional; namespace Modules { @@ -16,47 +19,60 @@ namespace Modules /// public sealed class GroupNorm : torch.nn.Module { - internal GroupNorm(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal GroupNorm(long num_groups, long num_channels, double eps, bool affine, Device? device, ScalarType? dtype) : base(nameof(GroupNorm)) { + this.eps = eps; + this.affine = affine; + this.num_groups = num_groups; + + if (affine) { + weight = Parameter(torch.empty(num_channels, dtype, device)); + this.bias = Parameter(torch.empty(num_channels, dtype, device)); + } } public override Tensor forward(Tensor tensor) { - if (tensor.Dimensions < 3) throw new ArgumentException($"Invalid number of dimensions for GroupNorm argument: {tensor.Dimensions}"); - var res = THSNN_GroupNorm_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + if (tensor.Dimensions < 3) + throw new ArgumentException($"Invalid number of dimensions for GroupNorm argument: {tensor.Dimensions}"); + return F.group_norm(tensor, num_groups, weight, bias, eps); + } + + protected override void Dispose(bool disposing) + { + _weight?.Dispose(); + _bias?.Dispose(); + base.Dispose(disposing); } public Parameter? bias { - get { - var res = THSNN_GroupNorm_bias(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } + get => _bias; set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_GroupNorm_set_bias(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + _bias?.Dispose(); + _bias = value?.DetachFromDisposeScope() as Parameter; + ConditionallyRegisterParameter(nameof(bias), _bias); } } - public Parameter? weight { - get { - var res = THSNN_GroupNorm_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } + public Parameter weight { + get => _weight!; set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_GroupNorm_set_weight(handle, value is null ? IntPtr.Zero : value.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); + if (value is null) throw new ArgumentNullException(nameof(weight)); + if (value.Handle != _weight?.Handle) { + _weight?.Dispose(); + _weight = (value.DetachFromDisposeScope() as Parameter)!; + ConditionallyRegisterParameter(nameof(weight), _weight); + } } } + + [ComponentName(Name = nameof(bias))] + private Parameter? _bias; + [ComponentName(Name = nameof(weight))] + private Parameter? _weight; + public long num_groups { get; set; } + public double eps { get; set; } + public bool affine { get; set; } } } @@ -76,11 +92,7 @@ public static partial class nn /// public static GroupNorm GroupNorm(long num_groups, long num_channels, double eps = 1e-05, bool affine = true, Device? device = null, ScalarType? dtype = null) { - unsafe { - var handle = THSNN_GroupNorm_ctor(num_groups, num_channels, eps, affine, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new GroupNorm(handle, boxedHandle).MoveModule(device, dtype); - } + return new GroupNorm(num_groups, num_channels, eps, affine, device, dtype); } } } diff --git a/src/TorchSharp/NN/Normalization/InstanceNorm.cs b/src/TorchSharp/NN/Normalization/InstanceNorm.cs new file mode 100644 index 000000000..43ecd9023 --- /dev/null +++ b/src/TorchSharp/NN/Normalization/InstanceNorm.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using static TorchSharp.PInvoke.NativeMethods; +#nullable enable +namespace TorchSharp +{ + using System.Globalization; + using System.Transactions; + using Modules; + using TorchSharp.Utils; + using F = TorchSharp.torch.nn.functional; + + namespace Modules + { + public abstract class InstanceNorm : NormBase + { + public InstanceNorm(long num_features, + double eps, + double? momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype, + string name) : base(num_features, eps, momentum.HasValue ? momentum : 0.1, affine, track_running_stats, device, dtype, name) + { + } + + protected abstract long GetNumberOfBatchDimensions(); + + public override Tensor forward(Tensor input) + { + ValidateInputDimensions(input); + + var feature_dim = (int)(input.ndim - GetNumberOfBatchDimensions()); + + if (input.size((int)feature_dim) != num_features) { + throw new ArgumentException($"expected input's size at dim={feature_dim} to match num_features ({this.num_features}), but got: {input.size(feature_dim)}."); + } + + if (feature_dim == 0) { + using var t0 = input.unsqueeze(0); + return ApplyInstanceNorm(t0).squeeze_(0); + } + else { + return ApplyInstanceNorm(input); + } + } + + private Tensor ApplyInstanceNorm(Tensor input) + { + return F.instance_norm(input, running_mean, running_var, weight, bias, training || !track_running_stats, momentum.HasValue ? momentum.Value : 0.1, eps); + } + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/NN/Normalization/InstanceNorm1d.cs b/src/TorchSharp/NN/Normalization/InstanceNorm1d.cs index f9fb5836c..10040c349 100644 --- a/src/TorchSharp/NN/Normalization/InstanceNorm1d.cs +++ b/src/TorchSharp/NN/Normalization/InstanceNorm1d.cs @@ -14,92 +14,24 @@ namespace Modules /// /// This class is used to represent a InstanceNorm1D module. /// - public sealed class InstanceNorm1d : torch.nn.Module + public sealed class InstanceNorm1d : InstanceNorm { - internal InstanceNorm1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal InstanceNorm1d(long num_features, + double eps, + double momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype) : base(num_features, eps, momentum, affine, track_running_stats, device, dtype, nameof(InstanceNorm1d)) { } - public override Tensor forward(Tensor tensor) - { - if (tensor.Dimensions < 2 || tensor.Dimensions > 3) throw new ArgumentException($"Invalid number of dimensions for InstanceNorm argument: {tensor.Dimensions}"); - var res = THSNN_InstanceNorm1d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - public Parameter? bias { - get { - var res = THSNN_InstanceNorm1d_bias(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_InstanceNorm1d_set_bias(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); - } - } - - public Parameter? weight { - get { - var res = THSNN_InstanceNorm1d_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_InstanceNorm1d_set_weight(handle, value is null ? IntPtr.Zero : value.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); - } - } - - public Tensor? running_mean { - get { - var res = THSNN_InstanceNorm1d_get_mean(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_mean cannot be set to 'null'"); - THSNN_InstanceNorm1d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_mean", value); - } - } - - public Tensor? running_var { - get { - var res = THSNN_InstanceNorm1d_get_var(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_var cannot be set to 'null'"); - THSNN_InstanceNorm1d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_var", value); - } - } - - public Tensor? num_batches_tracked { - get { - var res = THSNN_InstanceNorm1d_get_batches(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - } + protected override long GetNumberOfBatchDimensions() => 2; - public void reset_running_stats() + protected override void ValidateInputDimensions(Tensor input) { - THSNN_InstanceNorm1d_reset_stats(handle); - torch.CheckForErrors(); + if (input.ndim != 2 && input.ndim != 3) + throw new ArgumentException($"expected 2D or 3D input, but got {input.ndim}D input."); } } } @@ -111,7 +43,7 @@ public static partial class nn /// /// Applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization. /// - /// C from an expected input of size (N,C,L) or LL from input of size (N, L) + /// C from an expected input of size (N,C,L) or LL from input of size (N, L) /// A value added to the denominator for numerical stability. Default: 1e-5 /// The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1 /// A boolean value that when set to True, this module has learnable affine parameters. Default: true @@ -121,13 +53,9 @@ public static partial class nn /// The desired device of the parameters and buffers in this module /// The desired floating point or complex dtype of the parameters and buffers in this module /// - public static InstanceNorm1d InstanceNorm1d(long features, double eps = 1e-05, double momentum = 0.1, bool affine = false, bool track_running_stats = false, Device? device = null, ScalarType? dtype = null) + public static InstanceNorm1d InstanceNorm1d(long num_features, double eps = 1e-05, double momentum = 0.1, bool affine = false, bool track_running_stats = false, Device? device = null, ScalarType? dtype = null) { - unsafe { - var handle = THSNN_InstanceNorm1d_ctor(features, eps, momentum, affine, track_running_stats, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new InstanceNorm1d(handle, boxedHandle).MoveModule(device, dtype); - } + return new InstanceNorm1d(num_features, eps, momentum, affine, track_running_stats, device, dtype); } } } diff --git a/src/TorchSharp/NN/Normalization/InstanceNorm2d.cs b/src/TorchSharp/NN/Normalization/InstanceNorm2d.cs index 9a7b35d1d..7e5c6bd78 100644 --- a/src/TorchSharp/NN/Normalization/InstanceNorm2d.cs +++ b/src/TorchSharp/NN/Normalization/InstanceNorm2d.cs @@ -14,92 +14,24 @@ namespace Modules /// /// This class is used to represent a InstanceNorm2D module. /// - public sealed class InstanceNorm2d : torch.nn.Module + public sealed class InstanceNorm2d : InstanceNorm { - internal InstanceNorm2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal InstanceNorm2d(long num_features, + double eps, + double momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype) : base(num_features, eps, momentum, affine, track_running_stats, device, dtype, nameof(InstanceNorm1d)) { } - public override Tensor forward(Tensor tensor) - { - if (tensor.Dimensions != 4) throw new ArgumentException($"Invalid number of dimensions for InstanceNorm argument: {tensor.Dimensions}"); - var res = THSNN_InstanceNorm2d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - public Parameter? bias { - get { - var res = THSNN_InstanceNorm2d_bias(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_InstanceNorm2d_set_bias(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); - } - } - - public Parameter? weight { - get { - var res = THSNN_InstanceNorm2d_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_InstanceNorm2d_set_weight(handle, value is null ? IntPtr.Zero : value.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); - } - } - - public Tensor? running_mean { - get { - var res = THSNN_InstanceNorm2d_get_mean(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_mean cannot be set to 'null'"); - THSNN_InstanceNorm2d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_mean", value); - } - } - - public Tensor? running_var { - get { - var res = THSNN_InstanceNorm2d_get_var(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_var cannot be set to 'null'"); - THSNN_InstanceNorm2d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_var", value); - } - } - - public Tensor? num_batches_tracked { - get { - var res = THSNN_InstanceNorm2d_get_batches(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - } + protected override long GetNumberOfBatchDimensions() => 3; - public void reset_running_stats() + protected override void ValidateInputDimensions(Tensor input) { - THSNN_InstanceNorm2d_reset_stats(handle); - torch.CheckForErrors(); + if (input.ndim != 3 && input.ndim != 4) + throw new ArgumentException($"expected 3D or 4D input, but got {input.ndim}D input."); } } } @@ -111,7 +43,7 @@ public static partial class nn /// /// Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization. /// - /// C from an expected input of size (N,C,H,W) + /// C from an expected input of size (N,C,H,W) /// A value added to the denominator for numerical stability. Default: 1e-5 /// The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1 /// A boolean value that when set to True, this module has learnable affine parameters. Default: true @@ -121,13 +53,9 @@ public static partial class nn /// The desired device of the parameters and buffers in this module /// The desired floating point or complex dtype of the parameters and buffers in this module /// - public static InstanceNorm2d InstanceNorm2d(long features, double eps = 1e-05, double momentum = 0.1, bool affine = false, bool track_running_stats = false, Device? device = null, ScalarType? dtype = null) + public static InstanceNorm2d InstanceNorm2d(long num_features, double eps = 1e-05, double momentum = 0.1, bool affine = false, bool track_running_stats = false, Device? device = null, ScalarType? dtype = null) { - unsafe { - var handle = THSNN_InstanceNorm2d_ctor(features, eps, momentum, affine, track_running_stats, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new InstanceNorm2d(handle, boxedHandle).MoveModule(device, dtype); - } + return new InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats, device, dtype); } } } diff --git a/src/TorchSharp/NN/Normalization/InstanceNorm3d.cs b/src/TorchSharp/NN/Normalization/InstanceNorm3d.cs index e74cbc278..99ca44a15 100644 --- a/src/TorchSharp/NN/Normalization/InstanceNorm3d.cs +++ b/src/TorchSharp/NN/Normalization/InstanceNorm3d.cs @@ -14,92 +14,24 @@ namespace Modules /// /// This class is used to represent a InstanceNorm3D module. /// - public sealed class InstanceNorm3d : torch.nn.Module + public sealed class InstanceNorm3d : InstanceNorm { - internal InstanceNorm3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal InstanceNorm3d(long num_features, + double eps, + double momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype) : base(num_features, eps, momentum, affine, track_running_stats, device, dtype, nameof(InstanceNorm3d)) { } - public override Tensor forward(Tensor tensor) - { - if (tensor.Dimensions != 5) throw new ArgumentException($"Invalid number of dimensions for InstanceNorm argument: {tensor.Dimensions}"); - var res = THSNN_InstanceNorm3d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - public Parameter? bias { - get { - var res = THSNN_InstanceNorm3d_bias(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'"); - THSNN_InstanceNorm3d_set_bias(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); - } - } - - public Parameter? weight { - get { - var res = THSNN_InstanceNorm3d_weight(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return (res == IntPtr.Zero) ? null : new Parameter(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'"); - THSNN_InstanceNorm3d_set_weight(handle, value is null ? IntPtr.Zero : value.Handle); - torch.CheckForErrors(); - ConditionallyRegisterParameter("weight", value); - } - } - - public Tensor? running_mean { - get { - var res = THSNN_InstanceNorm3d_get_mean(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_mean cannot be set to 'null'"); - THSNN_InstanceNorm3d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_mean", value); - } - } - - public Tensor? running_var { - get { - var res = THSNN_InstanceNorm3d_get_var(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - set { - // Please ignore, for now, that the litorch call thinks you *can* set it to null. - if (value is null) throw new ArgumentNullException("running_var cannot be set to 'null'"); - THSNN_InstanceNorm3d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); - torch.CheckForErrors(); - ConditionallyRegisterBuffer("running_var", value); - } - } - - public Tensor? num_batches_tracked { - get { - var res = THSNN_InstanceNorm3d_get_batches(handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } - return new Tensor(res); - } - } + protected override long GetNumberOfBatchDimensions() => 4; - public void reset_running_stats() + protected override void ValidateInputDimensions(Tensor input) { - THSNN_InstanceNorm3d_reset_stats(handle); - torch.CheckForErrors(); + if (input.ndim != 4 && input.ndim != 5) + throw new ArgumentException($"expected 4D or 4D input, but got {input.ndim}D input."); } } } @@ -111,7 +43,7 @@ public static partial class nn /// /// Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization. /// - /// C from an expected input of size (N,C,D,H,W) + /// C from an expected input of size (N,C,D,H,W) /// A value added to the denominator for numerical stability. Default: 1e-5 /// The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1 /// A boolean value that when set to True, this module has learnable affine parameters. Default: true @@ -121,13 +53,9 @@ public static partial class nn /// The desired device of the parameters and buffers in this module /// The desired floating point or complex dtype of the parameters and buffers in this module /// - public static InstanceNorm3d InstanceNorm3d(long features, double eps = 1e-05, double momentum = 0.1, bool affine = false, bool track_running_stats = false, Device? device = null, ScalarType? dtype = null) + public static InstanceNorm3d InstanceNorm3d(long num_features, double eps = 1e-05, double momentum = 0.1, bool affine = false, bool track_running_stats = false, Device? device = null, ScalarType? dtype = null) { - unsafe { - var handle = THSNN_InstanceNorm3d_ctor(features, eps, momentum, affine, track_running_stats, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new InstanceNorm3d(handle, boxedHandle).MoveModule(device, dtype); - } + return new InstanceNorm3d(num_features, eps, momentum, affine, track_running_stats, device, dtype); } } } diff --git a/src/TorchSharp/NN/Normalization/LayerNorm.cs b/src/TorchSharp/NN/Normalization/LayerNorm.cs index 7010e754e..7ae96ee71 100644 --- a/src/TorchSharp/NN/Normalization/LayerNorm.cs +++ b/src/TorchSharp/NN/Normalization/LayerNorm.cs @@ -18,13 +18,15 @@ namespace Modules /// public sealed class LayerNorm : torch.nn.Module { - private long[] _normalized_shape; - private double _eps; + const string WeightComponentName = nameof(weight); + const string BiasComponentName = nameof(bias); internal LayerNorm(long[] normalized_shape, double eps, bool elementwise_affine, bool bias, Device? device, ScalarType? dtype) : base(nameof(LayerNorm)) { - _normalized_shape = normalized_shape; - _eps = eps; + this.normalized_shape = normalized_shape; + this.eps = eps; + this.elementwise_affine = elementwise_affine; + if (elementwise_affine) { weight = Parameter(torch.empty(normalized_shape, dtype, device)); @@ -34,11 +36,10 @@ internal LayerNorm(long[] normalized_shape, double eps, bool elementwise_affine, } } - - reset_parameters(elementwise_affine); + reset_parameters(); } - private void reset_parameters(bool elementwise_affine) + public void reset_parameters() { if (elementwise_affine) { @@ -52,7 +53,14 @@ private void reset_parameters(bool elementwise_affine) public override Tensor forward(Tensor tensor) { - return F.layer_norm(tensor, _normalized_shape, weight, bias, _eps); + return F.layer_norm(tensor, normalized_shape, weight, bias, eps); + } + + protected override void Dispose(bool disposing) + { + _weight?.Dispose(); + _bias?.Dispose(); + base.Dispose(disposing); } public Parameter? bias { @@ -60,7 +68,7 @@ public Parameter? bias { set { _bias?.Dispose(); _bias = value?.DetachFromDisposeScope() as Parameter; - ConditionallyRegisterParameter(nameof(bias), _bias); + ConditionallyRegisterParameter(BiasComponentName, _bias); } } @@ -71,22 +79,20 @@ public Parameter weight { if (value.Handle != _weight?.Handle) { _weight?.Dispose(); _weight = (value.DetachFromDisposeScope() as Parameter)!; - ConditionallyRegisterParameter(nameof(weight), _weight); + ConditionallyRegisterParameter(WeightComponentName, _weight); } } } - [ComponentName(Name = "bias")] + [ComponentName(Name = BiasComponentName)] private Parameter? _bias; - [ComponentName(Name = "weight")] + [ComponentName(Name = WeightComponentName)] private Parameter? _weight; - protected override void Dispose(bool disposing) - { - _weight?.Dispose(); - _bias?.Dispose(); - base.Dispose(disposing); - } + + public long[] normalized_shape { get; set; } + public double eps { get; set; } + public bool elementwise_affine { get; set; } } } diff --git a/src/TorchSharp/NN/Normalization/LocalResponseNorm.cs b/src/TorchSharp/NN/Normalization/LocalResponseNorm.cs index 5fc5f07b7..6adad3d0c 100644 --- a/src/TorchSharp/NN/Normalization/LocalResponseNorm.cs +++ b/src/TorchSharp/NN/Normalization/LocalResponseNorm.cs @@ -12,19 +12,25 @@ namespace Modules /// /// This class is used to represent a LocalResponseNorm module. /// - public sealed class LocalResponseNorm : torch.nn.Module + public sealed class LocalResponseNorm : ParamLessModule { - internal LocalResponseNorm(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal LocalResponseNorm(long size, double alpha = 0.0001, double beta = 0.75, double k = 1.0) : base(nameof(LocalResponseNorm)) { + this.size = size; + this.alpha = alpha; + this.beta = beta; + this.k = k; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - if (tensor.Dimensions < 3) throw new ArgumentException($"Invalid number of dimensions for LocalResponseNorm argument: {tensor.Dimensions}"); - var res = THSNN_LocalResponseNorm_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.local_response_norm(input, this.size, this.alpha, this.beta, this.k); } + + public long size { get; set; } + public double alpha { get; set; } + public double beta { get; set; } + public double k { get; set; } } } @@ -37,10 +43,24 @@ public static partial class nn /// public static LocalResponseNorm LocalResponseNorm(long size, double alpha = 0.0001, double beta = 0.75, double k = 1.0) { - unsafe { - var handle = THSNN_LocalResponseNorm_ctor(size, alpha, beta, k, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new LocalResponseNorm(handle, boxedHandle); + return new LocalResponseNorm(size, alpha, beta, k); + } + + public static partial class functional + { + + /// + /// Applies local response normalization over an input signal. + /// The input signal is composed of several input planes, where channels occupy the second dimension. + /// Applies normalization across channels. + /// + public static Tensor local_response_norm(Tensor input, long size, double alpha = 0.0001, double beta = 0.75, double k = 1.0) + { + if (input.Dimensions < 3) throw new ArgumentException($"Invalid number of dimensions for LocalResponseNorm argument: {input.Dimensions}"); + var res = THSNN_local_response_norm(input.Handle, size, alpha, beta, k); + if (res == IntPtr.Zero) + torch.CheckForErrors(); + return new Tensor(res); } } } diff --git a/src/TorchSharp/NN/Normalization/NormBase.cs b/src/TorchSharp/NN/Normalization/NormBase.cs new file mode 100644 index 000000000..13338b0d5 --- /dev/null +++ b/src/TorchSharp/NN/Normalization/NormBase.cs @@ -0,0 +1,145 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using static TorchSharp.PInvoke.NativeMethods; +#nullable enable +namespace TorchSharp +{ + using Modules; + using TorchSharp.Utils; + using F = TorchSharp.torch.nn.functional; + + namespace Modules + { + public abstract class NormBase : torch.nn.Module + { + public NormBase(long num_features, + double eps, + double? momentum, + bool affine, + bool track_running_stats, + Device? device, + ScalarType? dtype, + string name) : base(name) + { + this.num_features = num_features; + this.eps = eps; + this.momentum = momentum; + this.affine = affine; + this.track_running_stats = track_running_stats; + + if (affine) { + this.weight = Parameter(torch.empty(num_features, dtype, device)); + this.bias = Parameter(torch.empty(num_features, dtype, device)); + } + + if (track_running_stats) { + this.running_mean = torch.zeros(num_features, dtype, device); + this.running_var = torch.ones(num_features, dtype, device); + this.num_batches_tracked = torch.tensor(0L, dtype, device); + } + reset_parameters(); + } + + private void ResetRunningStats() + { + if (track_running_stats){ + init.zeros_(this._running_mean); + init.ones_(this._running_var); + init.zeros_(this._num_batches_tracked); + } + } + + public void reset_parameters() { + ResetRunningStats(); + if (affine) { + init.ones_(this._weight); + init.zeros_(this._bias); + } + } + + protected abstract void ValidateInputDimensions(Tensor input); + + protected override void Dispose(bool disposing) + { + _weight?.Dispose(); + _bias?.Dispose(); + base.Dispose(disposing); + } + + public Parameter? bias { + get => _bias; + set { + _bias?.Dispose(); + _bias = value?.DetachFromDisposeScope() as Parameter; + ConditionallyRegisterParameter(nameof(bias), _bias); + } + } + + public Parameter weight { + get => _weight!; + set { + if (value is null) throw new ArgumentNullException(nameof(weight)); + if (value.Handle != _weight?.Handle) { + _weight?.Dispose(); + _weight = (value.DetachFromDisposeScope() as Parameter)!; + ConditionallyRegisterParameter(nameof(weight), _weight); + } + } + } + + public Tensor? running_mean { + get => _running_mean; + set { + _running_mean?.Dispose(); + _running_mean = value?.DetachFromDisposeScope(); + ConditionallyRegisterBuffer(nameof(running_mean), _running_mean); + } + } + + public Tensor? running_var { + get => _running_var; + set { + _running_var?.Dispose(); + _running_var = value?.DetachFromDisposeScope(); + ConditionallyRegisterBuffer(nameof(running_var), _running_var); + } + } + + public Tensor? num_batches_tracked { + get => _num_batches_tracked; + set { + _num_batches_tracked?.Dispose(); + _num_batches_tracked = value?.DetachFromDisposeScope(); + ConditionallyRegisterBuffer(nameof(num_batches_tracked), _num_batches_tracked); + } + } + + public long num_features { get; private set; } + + public double eps { get; set; } + + public double? momentum { get; set; } + + public bool affine { get; private set; } + + public bool track_running_stats { get; private set; } + + [ComponentName(Name = nameof(bias))] + private Parameter? _bias; + + [ComponentName(Name = nameof(weight))] + private Parameter? _weight; + + [ComponentName(Name = nameof(running_mean))] + private Tensor? _running_mean; + + [ComponentName(Name = nameof(running_var))] + private Tensor? _running_var; + + [ComponentName(Name = nameof(num_batches_tracked))] + private Tensor? _num_batches_tracked; + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/NN/Padding/ConstantPad1d.cs b/src/TorchSharp/NN/Padding/ConstantPad1d.cs index ad6771e7b..ec905b4b7 100644 --- a/src/TorchSharp/NN/Padding/ConstantPad1d.cs +++ b/src/TorchSharp/NN/Padding/ConstantPad1d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ConstantPad1d module. /// - public sealed class ConstantPad1d : torch.nn.Module + public sealed class ConstantPad1d : PadBase { - internal ConstantPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ConstantPad1d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ConstantPad1d(double value, params long[] padding) : base(nameof(ConstantPad1d), PaddingModes.Constant, value, padding) { } } } @@ -48,9 +30,7 @@ public static partial class nn /// public static ConstantPad1d ConstantPad1d(long padding, double value) { - var handle = THSNN_ConstantPad1d_ctor(value, padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ConstantPad1d(handle, boxedHandle); + return new ConstantPad1d(value, padding, padding); } /// @@ -61,9 +41,7 @@ public static ConstantPad1d ConstantPad1d(long padding, double value) /// public static ConstantPad1d ConstantPad1d((long, long) padding, double value) { - var handle = THSNN_ConstantPad1d_ctor_tuple(value, padding.Item1, padding.Item2, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ConstantPad1d(handle, boxedHandle); + return new ConstantPad1d(value, padding.Item1, padding.Item2); } } } diff --git a/src/TorchSharp/NN/Padding/ConstantPad2d.cs b/src/TorchSharp/NN/Padding/ConstantPad2d.cs index 7d54b7bc6..9bc47b2be 100644 --- a/src/TorchSharp/NN/Padding/ConstantPad2d.cs +++ b/src/TorchSharp/NN/Padding/ConstantPad2d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ConstantPad2d module. /// - public sealed class ConstantPad2d : torch.nn.Module + public sealed class ConstantPad2d : PadBase { - internal ConstantPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ConstantPad2d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ConstantPad2d(double value, params long[] padding) : base(nameof(ConstantPad2d), PaddingModes.Constant, value, padding) { } } } @@ -48,9 +30,7 @@ public static partial class nn /// public static ConstantPad2d ConstantPad2d(long padding, double value) { - var handle = THSNN_ConstantPad2d_ctor(value, padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ConstantPad2d(handle, boxedHandle); + return new ConstantPad2d(value, padding, padding, padding, padding); } /// @@ -61,9 +41,7 @@ public static ConstantPad2d ConstantPad2d(long padding, double value) /// public static ConstantPad2d ConstantPad2d((long, long, long, long) padding, double value) { - var handle = THSNN_ConstantPad2d_ctor_tuple(value, padding.Item1, padding.Item2, padding.Item3, padding.Item4, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ConstantPad2d(handle, boxedHandle); + return new ConstantPad2d(value, padding.Item1, padding.Item2, padding.Item3, padding.Item4); } } } diff --git a/src/TorchSharp/NN/Padding/ConstantPad3d.cs b/src/TorchSharp/NN/Padding/ConstantPad3d.cs index 4ab2c55fb..4da9344e0 100644 --- a/src/TorchSharp/NN/Padding/ConstantPad3d.cs +++ b/src/TorchSharp/NN/Padding/ConstantPad3d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ConstantPad3d module. /// - public sealed class ConstantPad3d : torch.nn.Module + public sealed class ConstantPad3d : PadBase { - internal ConstantPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ConstantPad3d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ConstantPad3d(double value, params long[] padding) : base(nameof(ConstantPad3d), PaddingModes.Constant, value, padding) { } } } @@ -48,9 +30,7 @@ public static partial class nn /// public static ConstantPad3d ConstantPad3d(long padding, double value) { - var handle = THSNN_ConstantPad3d_ctor(value, padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ConstantPad3d(handle, boxedHandle); + return new ConstantPad3d(value, padding, padding, padding, padding, padding, padding); } /// @@ -61,9 +41,7 @@ public static ConstantPad3d ConstantPad3d(long padding, double value) /// public static ConstantPad3d ConstantPad3d((long, long, long, long, long, long) padding, double value) { - var handle = THSNN_ConstantPad3d_ctor_tuple(value, padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ConstantPad3d(handle, boxedHandle); + return new ConstantPad3d(value, padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6); } } } diff --git a/src/TorchSharp/NN/Padding/PadBase.cs b/src/TorchSharp/NN/Padding/PadBase.cs new file mode 100644 index 000000000..3a10de24c --- /dev/null +++ b/src/TorchSharp/NN/Padding/PadBase.cs @@ -0,0 +1,39 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using static TorchSharp.torch; +using static TorchSharp.PInvoke.NativeMethods; + +namespace TorchSharp +{ + using Modules; + + namespace Modules + { + /// + /// This class is used to represent the base of all padding-related modules. + /// + public abstract class PadBase : ParamLessModule + { + protected PadBase(string name, PaddingModes mode, double value, params long[] padding) : base(name) + { + this.value = value; + this.padding = padding; + padding_mode = mode; + } + + /// + /// Forward pass. + /// + /// Input tensor + /// + public override Tensor forward(Tensor input) + { + return nn.functional.pad(input, padding, padding_mode, value); + } + + private PaddingModes padding_mode { get; set; } + public long[] padding { get; set; } + public double value { get; set; } + } + } +} diff --git a/src/TorchSharp/NN/Padding/ReflectionPad1d.cs b/src/TorchSharp/NN/Padding/ReflectionPad1d.cs index 1a975dd7d..780f77550 100644 --- a/src/TorchSharp/NN/Padding/ReflectionPad1d.cs +++ b/src/TorchSharp/NN/Padding/ReflectionPad1d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ReflectionPad1d module. /// - public sealed class ReflectionPad1d : torch.nn.Module + public sealed class ReflectionPad1d : PadBase { - internal ReflectionPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ReflectionPad1d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ReflectionPad1d(params long[] padding) : base(nameof(ReflectionPad1d), PaddingModes.Reflect, 0, padding) { } } } @@ -47,9 +29,7 @@ public static partial class nn /// public static ReflectionPad1d ReflectionPad1d(long padding) { - var handle = THSNN_ReflectionPad1d_ctor(padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReflectionPad1d(handle, boxedHandle); + return new ReflectionPad1d(padding, padding); } /// @@ -59,9 +39,7 @@ public static ReflectionPad1d ReflectionPad1d(long padding) /// public static ReflectionPad1d ReflectionPad1d((long, long) padding) { - var handle = THSNN_ReflectionPad1d_ctor_tuple(padding.Item1, padding.Item2, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReflectionPad1d(handle, boxedHandle); + return new ReflectionPad1d(padding.Item1, padding.Item2); } } } diff --git a/src/TorchSharp/NN/Padding/ReflectionPad2d.cs b/src/TorchSharp/NN/Padding/ReflectionPad2d.cs index 418e971c3..f2a505528 100644 --- a/src/TorchSharp/NN/Padding/ReflectionPad2d.cs +++ b/src/TorchSharp/NN/Padding/ReflectionPad2d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ReflectionPad2d module. /// - public sealed class ReflectionPad2d : torch.nn.Module + public sealed class ReflectionPad2d : PadBase { - internal ReflectionPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ReflectionPad2d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ReflectionPad2d(params long[] padding) : base(nameof(ReflectionPad2d), PaddingModes.Reflect, 0, padding) { } } } @@ -47,9 +29,7 @@ public static partial class nn /// public static ReflectionPad2d ReflectionPad2d(long padding) { - var handle = THSNN_ReflectionPad2d_ctor(padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReflectionPad2d(handle, boxedHandle); + return new ReflectionPad2d(padding, padding, padding, padding); } /// @@ -59,9 +39,7 @@ public static ReflectionPad2d ReflectionPad2d(long padding) /// public static ReflectionPad2d ReflectionPad2d((long, long, long, long) padding) { - var handle = THSNN_ReflectionPad2d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReflectionPad2d(handle, boxedHandle); + return new ReflectionPad2d(padding.Item1, padding.Item2, padding.Item3, padding.Item4); } } } diff --git a/src/TorchSharp/NN/Padding/ReflectionPad3d.cs b/src/TorchSharp/NN/Padding/ReflectionPad3d.cs index 18db464be..d1dbd584b 100644 --- a/src/TorchSharp/NN/Padding/ReflectionPad3d.cs +++ b/src/TorchSharp/NN/Padding/ReflectionPad3d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ReflectionPad3d module. /// - public sealed class ReflectionPad3d : torch.nn.Module + public sealed class ReflectionPad3d : PadBase { - internal ReflectionPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ReflectionPad3d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ReflectionPad3d(params long[] padding) : base(nameof(ReflectionPad3d), PaddingModes.Reflect, 0, padding) { } } } @@ -47,9 +29,7 @@ public static partial class nn /// public static ReflectionPad3d ReflectionPad3d(long padding) { - var handle = THSNN_ReflectionPad3d_ctor(padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReflectionPad3d(handle, boxedHandle); + return new ReflectionPad3d(padding, padding, padding, padding, padding, padding); } /// @@ -59,9 +39,7 @@ public static ReflectionPad3d ReflectionPad3d(long padding) /// public static ReflectionPad3d ReflectionPad3d((long, long, long, long, long, long) padding) { - var handle = THSNN_ReflectionPad3d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReflectionPad3d(handle, boxedHandle); + return new ReflectionPad3d(padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6); } } } diff --git a/src/TorchSharp/NN/Padding/ReplicationPad1d.cs b/src/TorchSharp/NN/Padding/ReplicationPad1d.cs index 55f572ee8..fb3744f5b 100644 --- a/src/TorchSharp/NN/Padding/ReplicationPad1d.cs +++ b/src/TorchSharp/NN/Padding/ReplicationPad1d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ReplicationPad1d module. /// - public sealed class ReplicationPad1d : torch.nn.Module + public sealed class ReplicationPad1d : PadBase { - internal ReplicationPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ReplicationPad1d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ReplicationPad1d(params long[] padding) : base(nameof(ReplicationPad1d), PaddingModes.Replicate, 0, padding) { } } } @@ -47,9 +29,7 @@ public static partial class nn /// public static ReplicationPad1d ReplicationPad1d(long padding) { - var handle = THSNN_ReplicationPad1d_ctor(padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReplicationPad1d(handle, boxedHandle); + return new ReplicationPad1d(padding, padding); } /// @@ -59,9 +39,7 @@ public static ReplicationPad1d ReplicationPad1d(long padding) /// public static ReplicationPad1d ReplicationPad1d((long, long) padding) { - var handle = THSNN_ReplicationPad1d_ctor_tuple(padding.Item1, padding.Item2, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReplicationPad1d(handle, boxedHandle); + return new ReplicationPad1d(padding.Item1, padding.Item2); } } } diff --git a/src/TorchSharp/NN/Padding/ReplicationPad2d.cs b/src/TorchSharp/NN/Padding/ReplicationPad2d.cs index 205ac9e59..81b25ee27 100644 --- a/src/TorchSharp/NN/Padding/ReplicationPad2d.cs +++ b/src/TorchSharp/NN/Padding/ReplicationPad2d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ReplicationPad2d module. /// - public sealed class ReplicationPad2d : torch.nn.Module + public sealed class ReplicationPad2d : PadBase { - internal ReplicationPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ReplicationPad2d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ReplicationPad2d(params long[] padding) : base(nameof(ReplicationPad2d), PaddingModes.Replicate, 0, padding) { } } } @@ -41,15 +23,13 @@ public static partial class torch public static partial class nn { /// - /// Pads the input tensor using replication of the input boundary. + /// Pads the input tensor using the replication of the input boundary. /// /// The size of the padding. /// public static ReplicationPad2d ReplicationPad2d(long padding) { - var handle = THSNN_ReplicationPad2d_ctor(padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReplicationPad2d(handle, boxedHandle); + return new ReplicationPad2d(padding, padding, padding, padding); } /// @@ -59,9 +39,7 @@ public static ReplicationPad2d ReplicationPad2d(long padding) /// public static ReplicationPad2d ReplicationPad2d((long, long, long, long) padding) { - var handle = THSNN_ReplicationPad2d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReplicationPad2d(handle, boxedHandle); + return new ReplicationPad2d(padding.Item1, padding.Item2, padding.Item3, padding.Item4); } } } diff --git a/src/TorchSharp/NN/Padding/ReplicationPad3d.cs b/src/TorchSharp/NN/Padding/ReplicationPad3d.cs index 6b92f2972..7eddd4c8c 100644 --- a/src/TorchSharp/NN/Padding/ReplicationPad3d.cs +++ b/src/TorchSharp/NN/Padding/ReplicationPad3d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ReplicationPad3d module. /// - public sealed class ReplicationPad3d : torch.nn.Module + public sealed class ReplicationPad3d : PadBase { - internal ReplicationPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ReplicationPad3d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ReplicationPad3d(params long[] padding) : base(nameof(ReplicationPad3d), PaddingModes.Replicate, 0, padding) { } } } @@ -47,9 +29,7 @@ public static partial class nn /// public static ReplicationPad3d ReplicationPad3d(long padding) { - var handle = THSNN_ReplicationPad3d_ctor(padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReplicationPad3d(handle, boxedHandle); + return new ReplicationPad3d(padding, padding, padding, padding, padding, padding); } /// @@ -59,9 +39,7 @@ public static ReplicationPad3d ReplicationPad3d(long padding) /// public static ReplicationPad3d ReplicationPad3d((long, long, long, long, long, long) padding) { - var handle = THSNN_ReplicationPad3d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ReplicationPad3d(handle, boxedHandle); + return new ReplicationPad3d(padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6); } } } diff --git a/src/TorchSharp/NN/Padding/ZeroPad2d.cs b/src/TorchSharp/NN/Padding/ZeroPad2d.cs index 82a075d86..679e96e4d 100644 --- a/src/TorchSharp/NN/Padding/ZeroPad2d.cs +++ b/src/TorchSharp/NN/Padding/ZeroPad2d.cs @@ -12,27 +12,9 @@ namespace Modules /// /// This class is used to represent a ZeroPad2d module. /// - public sealed class ZeroPad2d : torch.nn.Module + public sealed class ZeroPad2d : PadBase { - internal ZeroPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_ZeroPad2d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + internal ZeroPad2d(params long[] padding) : base(nameof(ZeroPad2d), PaddingModes.Zeros, 0, padding) { } } } @@ -47,9 +29,7 @@ public static partial class nn /// public static ZeroPad2d ZeroPad2d(long padding) { - var handle = THSNN_ZeroPad2d_ctor(padding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ZeroPad2d(handle, boxedHandle); + return new ZeroPad2d(padding, padding, padding, padding); } /// @@ -59,9 +39,7 @@ public static ZeroPad2d ZeroPad2d(long padding) /// public static ZeroPad2d ZeroPad2d((long, long, long, long) padding) { - var handle = THSNN_ZeroPad2d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new ZeroPad2d(handle, boxedHandle); + return new ZeroPad2d(padding.Item1, padding.Item2, padding.Item3, padding.Item4); } } } diff --git a/src/TorchSharp/NN/ParamLessModule.cs b/src/TorchSharp/NN/ParamLessModule.cs new file mode 100644 index 000000000..c3824a2ad --- /dev/null +++ b/src/TorchSharp/NN/ParamLessModule.cs @@ -0,0 +1,123 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using static TorchSharp.torch; +using static TorchSharp.PInvoke.NativeMethods; + +namespace TorchSharp +{ + using Modules; + + namespace Modules + { + public interface IParameterLessModule { + + } + /// + /// Base class for all modules that do not have any tensor parameters or buffers, and + /// for which the `_to()` implementation can therefore be simplified. + /// + public abstract class ParamLessModule : nn.Module, IParameterLessModule + { + protected ParamLessModule(string name) : base(name) { } + + protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) {} + + // Rather than spending cycles only to discover that this module has neither + // parameters nor buffers, just shortcut the move completely. + protected internal override nn.Module _to(Device device, ScalarType dtype) => this; + + protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; + + protected internal override nn.Module _to(ScalarType dtype) => this; + + public override void register_buffer(string name, Tensor tensor, bool persistent = true) + { + throw new InvalidOperationException($"Cannot register a buffer on a module that is declared 'parameter-less.'"); + } + + public override void register_parameter(string name, Parameter param) + { + throw new InvalidOperationException($"Cannot register a parameter on a module that is declared 'parameter-less.'"); + } + + public override void register_module(string name, nn.Module submodule) + { + if (submodule is not IParameterLessModule) + throw new InvalidOperationException($"Submodules of a parameter-less module must also be parameter-less."); + base.register_module(name, submodule); + } + } + + /// + /// Base class for all modules that do not have any tensor parameters or buffers, and + /// for which the `_to()` implementation can therefore be simplified. + /// + public abstract class ParamLessModule : nn.Module, IParameterLessModule + { + protected ParamLessModule(string name) : base(name) { } + + protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) {} + + // Rather than spending cycles only to discover that this module has neither + // parameters nor buffers, just shortcut the move completely. + protected internal override nn.Module _to(Device device, ScalarType dtype) => this; + + protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; + + protected internal override nn.Module _to(ScalarType dtype) => this; + + public override void register_buffer(string name, Tensor tensor, bool persistent = true) + { + throw new InvalidOperationException($"Cannot register a buffer on a module that is declared 'parameter-less.'"); + } + + public override void register_parameter(string name, Parameter param) + { + throw new InvalidOperationException($"Cannot register a parameter on a module that is declared 'parameter-less.'"); + } + + public override void register_module(string name, nn.Module submodule) + { + if (submodule is not IParameterLessModule) + throw new InvalidOperationException($"Submodules of a parameter-less module must also be parameter-less."); + base.register_module(name, submodule); + } + } + + /// + /// Base class for all modules that do not have any tensor parameters or buffers, and + /// for which the `_to()` implementation can therefore be simplified. + /// + public abstract class ParamLessModule : nn.Module, IParameterLessModule + { + protected ParamLessModule(string name) : base(name) { } + + protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) {} + + // Rather than spending cycles only to discover that this module has neither + // parameters nor buffers, just shortcut the move completely. + protected internal override nn.Module _to(Device device, ScalarType dtype) => this; + + protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; + + protected internal override nn.Module _to(ScalarType dtype) => this; + + public override void register_buffer(string name, Tensor tensor, bool persistent = true) + { + throw new InvalidOperationException($"Cannot register a buffer on a module that is declared 'parameter-less.'"); + } + + public override void register_parameter(string name, Parameter param) + { + throw new InvalidOperationException($"Cannot register a parameter on a module that is declared 'parameter-less.'"); + } + + public override void register_module(string name, nn.Module submodule) + { + if (submodule is not IParameterLessModule) + throw new InvalidOperationException($"Submodules of a parameter-less module must also be parameter-less."); + base.register_module(name, submodule); + } + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/NN/Parameter.cs b/src/TorchSharp/NN/Parameter.cs index 81e9051d8..4c1faa01e 100644 --- a/src/TorchSharp/NN/Parameter.cs +++ b/src/TorchSharp/NN/Parameter.cs @@ -26,6 +26,12 @@ public class Parameter : Tensor public Parameter(Tensor data, bool requires_grad = true) : base(data.with_requires_grad(requires_grad).MoveHandle()) { + var scope = data.OwningDisposeScope; + if (scope is not null) { + this.OwningDisposeScope = scope; + scope.Include(this); + scope.Detach(data); + } } /// @@ -35,7 +41,6 @@ public Parameter(Tensor data, bool requires_grad = true) : internal Parameter(System.IntPtr handle) : base(handle) { } - }; } diff --git a/src/TorchSharp/NN/PixelShuffle.cs b/src/TorchSharp/NN/PixelShuffle.cs index fe1d94bd5..ddb459a57 100644 --- a/src/TorchSharp/NN/PixelShuffle.cs +++ b/src/TorchSharp/NN/PixelShuffle.cs @@ -12,21 +12,24 @@ namespace Modules /// /// This class is used to represent a dropout module. /// - public sealed class PixelShuffle : torch.nn.Module + public sealed class PixelShuffle : ParamLessModule { - internal PixelShuffle(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal PixelShuffle(long upscale_factor) : base(nameof(PixelShuffle)) + { + this.upscale_factor = upscale_factor; + } /// /// Forward pass. /// - /// Input tensor + /// Input tensor /// - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_PixelShuffle_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.pixel_shuffle(input, this.upscale_factor); } + + public long upscale_factor { get; set; } } } @@ -38,13 +41,11 @@ public static partial class nn /// Rearranges elements in a tensor of shape (*, C * r^2, H, W) to a tensor of shape(*, C, H * r, W * r), where r is an upscale factor. /// This is useful for implementing efficient sub-pixel convolution with a stride of 1/r. /// - /// Factor to increase spatial resolution by + /// Factor to increase spatial resolution by /// - public static PixelShuffle PixelShuffle(long upscaleFactor) + public static PixelShuffle PixelShuffle(long upscale_factor) { - var handle = THSNN_PixelShuffle_ctor(upscaleFactor, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new PixelShuffle(handle, boxedHandle); + return new PixelShuffle(upscale_factor); } public static partial class functional @@ -53,15 +54,15 @@ public static partial class functional /// Rearranges elements in a tensor of shape (*, C * r^2, H, W) to a tensor of shape(*, C, H * r, W * r), where r is an upscale factor. /// This is useful for implementing efficient sub-pixel convolution with a stride of 1/r. /// - /// Input tensor - /// Factor to increase spatial resolution by + /// Input tensor + /// Factor to increase spatial resolution by /// /// - public static Tensor pixel_shuffle(Tensor x, long upscaleFactor) + public static Tensor pixel_shuffle(Tensor input, long upscale_factor) { - using (var d = nn.PixelShuffle(upscaleFactor)) { - return d.call(x); - } + var res = THSNN_pixel_shuffle(input.Handle, upscale_factor); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); } } } diff --git a/src/TorchSharp/NN/PixelUnshuffle.cs b/src/TorchSharp/NN/PixelUnshuffle.cs index e6d3f120a..6b4ab8b23 100644 --- a/src/TorchSharp/NN/PixelUnshuffle.cs +++ b/src/TorchSharp/NN/PixelUnshuffle.cs @@ -12,21 +12,24 @@ namespace Modules /// /// This class is used to represent a dropout module. /// - public sealed class PixelUnshuffle : torch.nn.Module + public sealed class PixelUnshuffle : ParamLessModule { - internal PixelUnshuffle(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal PixelUnshuffle(long downscale_factor) : base(nameof(PixelUnshuffle)) + { + this.downscale_factor = downscale_factor; + } /// /// Forward pass. /// - /// Input tensor + /// Input tensor /// - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_PixelUnshuffle_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.pixel_unshuffle(input, downscale_factor); } + + public long downscale_factor { get; set; } } } @@ -38,13 +41,11 @@ public static partial class nn /// /// Reverses the PixelShuffle operation by rearranging elements in a tensor of shape (*, C, H * r, W * r) to a tensor of shape (*, C * r^2, H, W), where r is an downscale factor. /// - /// Factor to increase spatial resolution by + /// Factor to increase spatial resolution by /// - public static PixelUnshuffle PixelUnshuffle(long downscaleFactor) + public static PixelUnshuffle PixelUnshuffle(long downscale_factor) { - var handle = THSNN_PixelUnshuffle_ctor(downscaleFactor, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new PixelUnshuffle(handle, boxedHandle); + return new PixelUnshuffle(downscale_factor); } public static partial class functional @@ -53,15 +54,15 @@ public static partial class functional /// Reverses the PixelShuffle operation by rearranging elements in a tensor of shape (*, C * r^2, H, W) to a tensor of shape(*, C, H * r, W * r), where r is an downscale factor. /// This is useful for implementing efficient sub-pixel convolution with a stride of 1/r. /// - /// Input tensor - /// Factor to increase spatial resolution by + /// Input tensor + /// Factor to increase spatial resolution by /// /// - public static Tensor pixel_unshuffle(Tensor x, long downscaleFactor) + public static Tensor pixel_unshuffle(Tensor input, long downscale_factor) { - using (var d = nn.PixelUnshuffle(downscaleFactor)) { - return d.call(x); - } + var res = THSNN_pixel_unshuffle(input.Handle, downscale_factor); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); } } } diff --git a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool1D.cs b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool1D.cs index ec03a74f6..f1136eaa1 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool1D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool1D.cs @@ -12,24 +12,19 @@ namespace Modules /// /// This class is used to represent a AdaptiveAvgPool1D module. /// - public sealed class AdaptiveAvgPool1d : torch.nn.Module + public sealed class AdaptiveAvgPool1d : ParamLessModule { - internal AdaptiveAvgPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AdaptiveAvgPool1d(long output_size) : base(nameof(AdaptiveAvgPool1d)) { + this.output_size = output_size; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_AdaptiveAvgPool1d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.adaptive_avg_pool1d(input, this.output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long output_size { get; set; } } } @@ -41,14 +36,11 @@ public static partial class nn /// Applies a 1D adaptive average pooling over an input signal composed of several input planes. /// The output size is H, for any input size.The number of output features is equal to the number of input planes. /// - /// the target output size H + /// the target output size H /// - public static unsafe AdaptiveAvgPool1d AdaptiveAvgPool1d(long outputSize) + public static unsafe AdaptiveAvgPool1d AdaptiveAvgPool1d(long output_size) { - long* pkernelSize = stackalloc long[1] { outputSize }; - var handle = THSNN_AdaptiveAvgPool1d_ctor((IntPtr)pkernelSize, 1, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveAvgPool1d(handle, boxedHandle); + return new AdaptiveAvgPool1d(output_size); } public static partial class functional diff --git a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool2D.cs b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool2D.cs index 3f481d9d2..04871729f 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool2D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool2D.cs @@ -12,24 +12,19 @@ namespace Modules /// /// This class is used to represent a AdaptiveAvgPool2D module. /// - public sealed class AdaptiveAvgPool2d : torch.nn.Module + public sealed class AdaptiveAvgPool2d : ParamLessModule { - internal AdaptiveAvgPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AdaptiveAvgPool2d(long[] output_size) : base(nameof(AdaptiveAvgPool2d)) { + this.output_size = output_size; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_AdaptiveAvgPool2d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.adaptive_avg_pool2d(input, this.output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] output_size { get; set; } } } @@ -41,43 +36,33 @@ public static partial class nn /// Applies a 2D adaptive average pooling over an input signal composed of several input planes. /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// The target output size (H,W) of the image of the form H x W. + /// The target output size (H,W) of the image of the form H x W. /// - public static unsafe AdaptiveAvgPool2d AdaptiveAvgPool2d(long[] outputSize) + public static unsafe AdaptiveAvgPool2d AdaptiveAvgPool2d(long[] output_size) { - fixed (long* poutputSize = outputSize) { - var handle = THSNN_AdaptiveAvgPool2d_ctor((IntPtr)poutputSize, outputSize.Length, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveAvgPool2d(handle, boxedHandle); - } + return new AdaptiveAvgPool2d(output_size); } /// /// Applies a 2D adaptive average pooling over an input signal composed of several input planes. /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// The target output size (H,W) of the image of the form H x W. + /// The target output size (H,W) of the image of the form H x W. /// - public static unsafe AdaptiveAvgPool2d AdaptiveAvgPool2d((long,long) outputSize) + public static unsafe AdaptiveAvgPool2d AdaptiveAvgPool2d((long,long) output_size) { - long* poutputSize = stackalloc long[2] { outputSize.Item1, outputSize.Item2 }; - var handle = THSNN_AdaptiveAvgPool2d_ctor((IntPtr)poutputSize, 2, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveAvgPool2d(handle, boxedHandle); + return new AdaptiveAvgPool2d(new[] { output_size.Item1, output_size.Item2 }); } /// /// Applies a 2D adaptive average pooling over an input signal composed of several input planes. /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// The target output size (H,W) of the image of the form H x W. + /// The target output size (H,W) of the image of the form H x W. /// - public static unsafe AdaptiveAvgPool2d AdaptiveAvgPool2d(long outputSize) + public static unsafe AdaptiveAvgPool2d AdaptiveAvgPool2d(long output_size) { - long* poutputSize = stackalloc long[2] { outputSize, outputSize }; - var handle = THSNN_AdaptiveAvgPool2d_ctor((IntPtr)poutputSize, 2, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveAvgPool2d(handle, boxedHandle); + return new AdaptiveAvgPool2d(new[] { output_size, output_size }); } public static partial class functional diff --git a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool3D.cs b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool3D.cs index 862bf0b02..ce37c4f67 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool3D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool3D.cs @@ -12,24 +12,19 @@ namespace Modules /// /// This class is used to represent a AdaptiveAvgPool3D module. /// - public sealed class AdaptiveAvgPool3d : torch.nn.Module + public sealed class AdaptiveAvgPool3d : ParamLessModule { - internal AdaptiveAvgPool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AdaptiveAvgPool3d(long[] output_size) : base(nameof(AdaptiveAvgPool3d)) { + this.output_size = output_size; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_AdaptiveAvgPool3d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.adaptive_avg_pool3d(input, this.output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] output_size { get; set; } } } @@ -41,44 +36,33 @@ public static partial class nn /// Applies a 3D adaptive average pooling over an input signal composed of several input planes. /// The output is of size D x H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// The target output size of the image of the form D x H x W. + /// The target output size of the image of the form D x H x W. /// - public static unsafe AdaptiveAvgPool3d AdaptiveAvgPool3d(long[] outputSize) + public static unsafe AdaptiveAvgPool3d AdaptiveAvgPool3d(long[] output_size) { - fixed (long* pkernelSize = outputSize) { - var handle = THSNN_AdaptiveAvgPool3d_ctor((IntPtr)pkernelSize, outputSize.Length, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveAvgPool3d(handle, boxedHandle); - } + return new AdaptiveAvgPool3d(output_size); } /// /// Applies a 3D adaptive average pooling over an input signal composed of several input planes. /// The output is of size D x H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// The target output size (D,H,W) of the image of the form D x H x W. + /// The target output size (D,H,W) of the image of the form D x H x W. /// - public static unsafe AdaptiveAvgPool3d AdaptiveAvgPool3d((long, long, long) outputSize) + public static unsafe AdaptiveAvgPool3d AdaptiveAvgPool3d((long, long, long) output_size) { - long* pkernelSize = stackalloc long[3] { outputSize.Item1, outputSize.Item2, outputSize.Item3 }; - - var handle = THSNN_AdaptiveAvgPool3d_ctor((IntPtr)pkernelSize, 3, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveAvgPool3d(handle, boxedHandle); + return new AdaptiveAvgPool3d(new[] { output_size.Item1, output_size.Item2, output_size.Item3 }); } /// /// Applies a 3D adaptive average pooling over an input signal composed of several input planes. /// The output is of size D x H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// The target output size (D,H,W) of the image of the form H x W. + /// The target output size (D,H,W) of the image of the form H x W. /// - public static unsafe AdaptiveAvgPool3d AdaptiveAvgPool3d(long outputSize) + public static unsafe AdaptiveAvgPool3d AdaptiveAvgPool3d(long output_size) { - long* pkernelSize = stackalloc long[3] { outputSize, outputSize, outputSize }; - var handle = THSNN_AdaptiveAvgPool3d_ctor((IntPtr)pkernelSize, 3, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveAvgPool3d(handle, boxedHandle); + return new AdaptiveAvgPool3d(new [] { output_size, output_size, output_size }); } public static partial class functional diff --git a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool1D.cs b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool1D.cs index 269d5aeec..199aefaa5 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool1D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool1D.cs @@ -12,24 +12,24 @@ namespace Modules /// /// This class is used to represent a AdaptiveMaxPool1D module. /// - public sealed class AdaptiveMaxPool1d : torch.nn.Module + public sealed class AdaptiveMaxPool1d : ParamLessModule { - internal AdaptiveMaxPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AdaptiveMaxPool1d(long output_size) : base(nameof(AdaptiveMaxPool1d)) { + this.output_size = output_size; } - public override Tensor forward(Tensor tensor) + public (Tensor Values, Tensor Indices) forward_with_indices(Tensor input) { - var res = THSNN_AdaptiveMaxPool1d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.adaptive_max_pool1d_with_indices(input, this.output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public override Tensor forward(Tensor input) + { + return torch.nn.functional.adaptive_max_pool1d(input, this.output_size); + } + + public long output_size { get; set; } } } @@ -41,17 +41,11 @@ public static partial class nn /// Applies a 1D adaptive max pooling over an input signal composed of several input planes. /// The output size is H, for any input size.The number of output features is equal to the number of input planes. /// - /// The target output size H. + /// The target output size H. /// - public static AdaptiveMaxPool1d AdaptiveMaxPool1d(long outputSize) + public static AdaptiveMaxPool1d AdaptiveMaxPool1d(long output_size) { - unsafe { - fixed (long* pkernelSize = new long[] { outputSize }) { - var handle = THSNN_AdaptiveMaxPool1d_ctor((IntPtr)pkernelSize, 1, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveMaxPool1d(handle, boxedHandle); - } - } + return new AdaptiveMaxPool1d(output_size); } public static partial class functional @@ -60,13 +54,32 @@ public static partial class functional /// Applies a 1D adaptive max pooling over an input signal composed of several input planes. /// The output size is H, for any input size.The number of output features is equal to the number of input planes. /// - /// - /// The target output size H. + /// + /// The target output size H. + /// + public static Tensor adaptive_max_pool1d(Tensor input, long output_size) + { + var ret = adaptive_max_pool1d_with_indices(input, output_size); + ret.Indices.Dispose(); + return ret.Values; + } + + /// + /// Applies a 1D adaptive max pooling over an input signal composed of several input planes. + /// The output size is H, for any input size.The number of output features is equal to the number of input planes. + /// + /// + /// The target output size H. /// - public static Tensor adaptive_max_pool1d(Tensor x, long outputSize) + public static (Tensor Values, Tensor Indices) adaptive_max_pool1d_with_indices(Tensor input, long output_size) { - using (var d = nn.AdaptiveMaxPool1d(outputSize)) { - return d.call(x); + var outputSizes = new long[] { output_size }; + unsafe { + fixed (long* poutputSize = outputSizes) { + var resOutput = THSTensor_adaptive_max_pool1d(input.Handle, (IntPtr)poutputSize, outputSizes.Length, out var resIndices); + if (resOutput == IntPtr.Zero || resIndices == IntPtr.Zero) { torch.CheckForErrors(); } + return (new Tensor(resOutput), new Tensor(resIndices)); + } } } } diff --git a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool2D.cs b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool2D.cs index 15dae4187..8cba5c401 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool2D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool2D.cs @@ -12,24 +12,24 @@ namespace Modules /// /// This class is used to represent a AdaptiveMaxPool2D module. /// - public sealed class AdaptiveMaxPool2d : torch.nn.Module + public sealed class AdaptiveMaxPool2d : ParamLessModule { - internal AdaptiveMaxPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AdaptiveMaxPool2d(long[] output_size) : base(nameof(AdaptiveMaxPool2d)) { + this.output_size = output_size; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_AdaptiveMaxPool2d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.adaptive_max_pool2d(input, this.output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public (Tensor output, Tensor indices) forward_with_indices(Tensor input) + { + return torch.nn.functional.adaptive_max_pool2d_with_indices(input, this.output_size); + } + + public long[] output_size { get; set; } } } @@ -41,18 +41,12 @@ public static partial class nn /// Applies a 2D adaptive max pooling over an input signal composed of several input planes. /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// Applies a 2D adaptive max pooling over an input signal composed of several input planes. + /// Applies a 2D adaptive max pooling over an input signal composed of several input planes. /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. /// - public static AdaptiveMaxPool2d AdaptiveMaxPool2d(long[] outputSize) + public static AdaptiveMaxPool2d AdaptiveMaxPool2d(long[] output_size) { - unsafe { - fixed (long* pkernelSize = outputSize) { - var handle = THSNN_AdaptiveMaxPool2d_ctor((IntPtr)pkernelSize, outputSize.Length, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveMaxPool2d(handle, boxedHandle); - } - } + return new AdaptiveMaxPool2d(output_size); } public static partial class functional @@ -61,14 +55,33 @@ public static partial class functional /// Applies a 2D adaptive max pooling over an input signal composed of several input planes. /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// - /// Applies a 2D adaptive max pooling over an input signal composed of several input planes. + /// + /// Applies a 2D adaptive max pooling over an input signal composed of several input planes. + /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. + /// + public static Tensor adaptive_max_pool2d(Tensor input, long[] output_size) + { + var ret = adaptive_max_pool2d_with_indices(input, output_size); + ret.Indices.Dispose(); + return ret.Values; + } + + /// + /// Applies a 2D adaptive max pooling over an input signal composed of several input planes. + /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. + /// + /// + /// Applies a 2D adaptive max pooling over an input signal composed of several input planes. /// The output is of size H x W, for any input size.The number of output features is equal to the number of input planes. /// - public static Tensor adaptive_max_pool2d(Tensor x, long[] outputSize) + public static (Tensor Values, Tensor Indices) adaptive_max_pool2d_with_indices(Tensor input, long[] output_size) { - using (var d = nn.AdaptiveMaxPool2d(outputSize)) { - return d.call(x); + unsafe { + fixed (long* poutputSize = output_size) { + var resOutput = THSTensor_adaptive_max_pool2d(input.Handle, (IntPtr)poutputSize, output_size.Length, out var resIndices); + if (resOutput == IntPtr.Zero || resIndices == IntPtr.Zero) { torch.CheckForErrors(); } + return (new Tensor(resOutput), new Tensor(resIndices)); + } } } } diff --git a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool3D.cs b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool3D.cs index 3a07b1aa8..e59ce6565 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool3D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool3D.cs @@ -12,24 +12,24 @@ namespace Modules /// /// This class is used to represent a AdaptiveMaxPool3D module. /// - public sealed class AdaptiveMaxPool3d : torch.nn.Module + public sealed class AdaptiveMaxPool3d : ParamLessModule { - internal AdaptiveMaxPool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AdaptiveMaxPool3d(long[] output_size) : base(nameof(AdaptiveMaxPool3d)) { } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_AdaptiveMaxPool3d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.adaptive_max_pool3d(input, output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public (Tensor output, Tensor indices) forward_with_indices(Tensor input) + { + return torch.nn.functional.adaptive_max_pool3d_with_indices(input, output_size); + } + + + public long[] output_size { get; set; } } } @@ -41,18 +41,12 @@ public static partial class nn /// Applies a 3D adaptive max pooling over an input signal composed of several input planes. /// The output is of size D x H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// The target output size of the image of the form D x H x W. + /// The target output size of the image of the form D x H x W. /// Can be a tuple (D, H, W) or a single D for a cube D x D x D. D, H and W can be either a int, or null which means the size will be the same as that of the input. /// - public static AdaptiveMaxPool3d AdaptiveMaxPool3d(long[] outputSize) + public static AdaptiveMaxPool3d AdaptiveMaxPool3d(long[] output_size) { - unsafe { - fixed (long* pkernelSize = outputSize) { - var handle = THSNN_AdaptiveMaxPool3d_ctor((IntPtr)pkernelSize, outputSize.Length, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AdaptiveMaxPool3d(handle, boxedHandle); - } - } + return new AdaptiveMaxPool3d(output_size); } public static partial class functional @@ -61,14 +55,33 @@ public static partial class functional /// Applies a 3D adaptive max pooling over an input signal composed of several input planes. /// The output is of size D x H x W, for any input size.The number of output features is equal to the number of input planes. /// - /// The input tensor - /// The target output size of the image of the form D x H x W. + /// The input tensor + /// The target output size of the image of the form D x H x W. + /// Can be a tuple (D, H, W) or a single D for a cube D x D x D. D, H and W can be either a int, or null which means the size will be the same as that of the input. + /// + public static Tensor adaptive_max_pool3d(Tensor input, long[] output_size) + { + var ret = adaptive_max_pool3d_with_indices(input, output_size); + ret.Indices.Dispose(); + return ret.Values; + } + + /// + /// Applies a 3D adaptive max pooling over an input signal composed of several input planes. + /// The output is of size D x H x W, for any input size.The number of output features is equal to the number of input planes. + /// + /// The input tensor + /// The target output size of the image of the form D x H x W. /// Can be a tuple (D, H, W) or a single D for a cube D x D x D. D, H and W can be either a int, or null which means the size will be the same as that of the input. /// - public static Tensor adaptive_max_pool3d(Tensor x, long[] outputSize) + public static (Tensor Values, Tensor Indices) adaptive_max_pool3d_with_indices(Tensor input, long[] output_size) { - using (var d = nn.AdaptiveMaxPool3d(outputSize)) { - return d.call(x); + unsafe { + fixed (long* poutputSize = output_size) { + var resOutput = THSTensor_adaptive_max_pool1d(input.Handle, (IntPtr)poutputSize, output_size.Length, out var resIndices); + if (resOutput == IntPtr.Zero || resIndices == IntPtr.Zero) { torch.CheckForErrors(); } + return (new Tensor(resOutput), new Tensor(resIndices)); + } } } } diff --git a/src/TorchSharp/NN/Pooling/AvgPool1D.cs b/src/TorchSharp/NN/Pooling/AvgPool1D.cs index 3430fcdaa..7dcd374b0 100644 --- a/src/TorchSharp/NN/Pooling/AvgPool1D.cs +++ b/src/TorchSharp/NN/Pooling/AvgPool1D.cs @@ -12,24 +12,27 @@ namespace Modules /// /// This class is used to represent a AvgPool1D module. /// - public sealed class AvgPool1d : torch.nn.Module + public sealed class AvgPool1d : ParamLessModule { - internal AvgPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AvgPool1d(long kernel_size, long? stride = null, long? padding = null, bool ceil_mode = false, bool count_include_pad = true) : base(nameof(AvgPool1d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; + this.ceil_mode = ceil_mode; + this.count_include_pad = count_include_pad; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_AvgPool1d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long kernel_size { get; set; } + public long? stride { get; set; } + public long? padding { get; set; } + public bool ceil_mode { get; set; } + public bool count_include_pad { get; set; } } } @@ -45,32 +48,9 @@ public static partial class nn /// implicit zero padding to be added on both sides /// Whether to use ceil instead of floor to compute the output shape /// Whether to include the zero-padding in the averaging calculation - /// If specified, it will be used as divisor, otherwise size of the pooling region will be used - public static AvgPool1d AvgPool1d(long kernel_size, long? stride = null, long padding = 0, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) + public static AvgPool1d AvgPool1d(long kernel_size, long? stride = null, long padding = 0, bool ceil_mode = false, bool count_include_pad = true) { - return stride.HasValue ? - AvgPool1d(new long[] { kernel_size }, new long[] { stride.Value }, new long[] { padding }, ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0) : - AvgPool1d(new long[] { kernel_size }, null, new long[] { padding }, ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0); - } - - /// - /// Applies a 1D average pooling over an input signal composed of several input planes. - /// - /// The size of the window - /// The stride of the window. Default value is kernel_size - /// implicit zero padding to be added on both sides - /// Whether to use ceil instead of floor to compute the output shape - /// Whether to include the zero-padding in the averaging calculation - /// If specified, it will be used as divisor, otherwise size of the pooling region will be used - private static AvgPool1d AvgPool1d(long[] kernel_size, long[] strides = null, long[] padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) - { - unsafe { - fixed (long* pkernelSize = kernel_size, pstrides = strides, ppadding = padding) { - var handle = THSNN_AvgPool1d_ctor((IntPtr)pkernelSize, (IntPtr)pstrides, (IntPtr)ppadding, ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AvgPool1d(handle, boxedHandle); - } - } + return new AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad); } public static partial class functional diff --git a/src/TorchSharp/NN/Pooling/AvgPool2D.cs b/src/TorchSharp/NN/Pooling/AvgPool2D.cs index bdbd3d41e..f44dfe42c 100644 --- a/src/TorchSharp/NN/Pooling/AvgPool2D.cs +++ b/src/TorchSharp/NN/Pooling/AvgPool2D.cs @@ -5,6 +5,7 @@ namespace TorchSharp { + using System.Data; using Modules; namespace Modules @@ -12,24 +13,29 @@ namespace Modules /// /// This class is used to represent a AvgPool2D module. /// - public sealed class AvgPool2d : torch.nn.Module + public sealed class AvgPool2d : ParamLessModule { - internal AvgPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AvgPool2d(long[] kernel_size, long[] stride = null, long[] padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) : base(nameof(AvgPool2d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; + this.ceil_mode = ceil_mode; + this.count_include_pad = count_include_pad; + this.divisor_override = divisor_override; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_AvgPool2d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] kernel_size { get; set; } + public long[] stride { get; set; } + public long[] padding { get; set; } + public bool ceil_mode { get; set; } + public bool count_include_pad { get; set; } + public long? divisor_override { get; set; } } } @@ -41,18 +47,14 @@ public static partial class nn /// Applies a 2D average pooling over an input signal composed of several input planes. /// /// The size of the window - /// The stride of the window. Default value is kernel_size + /// The stride of the window. Default value is kernel_size /// implicit zero padding to be added on both sides /// Whether to use ceil instead of floor to compute the output shape /// Whether to include the zero-padding in the averaging calculation /// If specified, it will be used as divisor, otherwise size of the pooling region will be used - public static unsafe AvgPool2d AvgPool2d(long[] kernel_size, long[] strides = null, long[] padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) + public static AvgPool2d AvgPool2d(long[] kernel_size, long[] stride = null, long[] padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) { - fixed (long* pkernelSize = kernel_size, pstrides = strides, ppadding = padding) { - var handle = THSNN_AvgPool2d_ctor((IntPtr)pkernelSize, kernel_size.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), (IntPtr)ppadding, (padding == null ? 0 : padding.Length), ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AvgPool2d(handle, boxedHandle); - } + return new AvgPool2d(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); } /// @@ -66,19 +68,10 @@ public static unsafe AvgPool2d AvgPool2d(long[] kernel_size, long[] strides = nu /// If specified, it will be used as divisor, otherwise size of the pooling region will be used public static unsafe AvgPool2d AvgPool2d((long,long) kernel_size, (long,long)? stride = null, (long,long)? padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) { - long svalue1 = (stride == null) ? kernel_size.Item1 : stride.Value.Item1; - long svalue2 = (stride == null) ? kernel_size.Item2 : stride.Value.Item2; - - long pvalue1 = (padding == null) ? 0 : padding.Value.Item1; - long pvalue2 = (padding == null) ? 0 : padding.Value.Item2; - - long* pkernelSize = stackalloc long[2] { kernel_size.Item1, kernel_size.Item2 }; - long* pstrides = stackalloc long[2] { svalue1, svalue2 }; - long* ppadding = stackalloc long[2] { pvalue1, pvalue2 }; - - var handle = THSNN_AvgPool2d_ctor((IntPtr)pkernelSize, 2, (IntPtr)pstrides, 2, (IntPtr)ppadding, 2, ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AvgPool2d(handle, boxedHandle); + long[] kernelValue = new[] { kernel_size.Item1, kernel_size.Item2 }; + long[] strideValue = stride == null ? null : new[] { stride.Value.Item1, stride.Value.Item2 }; + long[] paddingValue = padding == null ? null : new[] { padding.Value.Item1, padding.Value.Item2 }; + return new AvgPool2d(kernelValue, strideValue, paddingValue, ceil_mode, count_include_pad, divisor_override); } /// @@ -90,18 +83,12 @@ public static unsafe AvgPool2d AvgPool2d((long,long) kernel_size, (long,long)? s /// Whether to use ceil instead of floor to compute the output shape /// Whether to include the zero-padding in the averaging calculation /// If specified, it will be used as divisor, otherwise size of the pooling region will be used - public static unsafe AvgPool2d AvgPool2d(long kernel_size, long? stride = null, long? padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) + public static AvgPool2d AvgPool2d(long kernel_size, long? stride = null, long? padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) { - long svalue = (stride == null) ? kernel_size : stride.Value; - long pvalue = (padding == null) ? 0 : padding.Value; - - long* pkernelSize = stackalloc long[2] { kernel_size, kernel_size }; - long* pstrides = stackalloc long[2] { svalue, svalue }; - long* ppadding = stackalloc long[2] { pvalue, pvalue }; - - var handle = THSNN_AvgPool2d_ctor((IntPtr)pkernelSize, 2, (IntPtr)pstrides, 2, (IntPtr)ppadding, 2, ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AvgPool2d(handle, boxedHandle); + long[] kernelValue = new[] { kernel_size, kernel_size }; + long[] strideValue = stride == null ? null : new[] { stride.Value, stride.Value }; + long[] paddingValue = padding == null ? null : new[] { padding.Value, padding.Value }; + return new AvgPool2d(kernelValue, strideValue, paddingValue, ceil_mode, count_include_pad, divisor_override); } public static partial class functional @@ -110,29 +97,32 @@ public static partial class functional /// Applies 2D average-pooling operation in kH × kW regions by step size sH * sW steps. The number of output features is equal to the number of input planes. /// /// The input tensor. - /// - /// - /// + /// + /// + /// /// /// + /// /// - public static Tensor avg_pool2d(Tensor input, long[] kernelSizes, - long[] strides = null, - long[] paddings = null, + public static Tensor avg_pool2d(Tensor input, long[] kernel_size, + long[] stride = null, + long[] padding = null, bool ceil_mode = false, - bool count_include_pad = true) + bool count_include_pad = true, + long? divisor_override = null) { - strides = (strides == null) ? new long[] { 1 } : strides; - paddings = (paddings == null) ? new long[] { 0 } : paddings; + stride = (stride == null) ? kernel_size : stride; + padding = (padding == null) ? new long[] { 0 } : padding; unsafe { - fixed (long* pkernelSize = kernelSizes, pstrides = strides, ppadding = paddings) { + fixed (long* pkernelSize = kernel_size, pstrides = stride, ppadding = padding) { var res = THSTensor_avg_pool2d(input.Handle, - (IntPtr)pkernelSize, kernelSizes.Length, - (IntPtr)pstrides, strides.Length, - (IntPtr)ppadding, paddings.Length, + (IntPtr)pkernelSize, kernel_size.Length, + (IntPtr)pstrides, stride.Length, + (IntPtr)ppadding, padding.Length, ceil_mode, - count_include_pad); + count_include_pad, + divisor_override ?? 0); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } @@ -143,21 +133,23 @@ public static Tensor avg_pool2d(Tensor input, long[] kernelSizes, /// Applies 2D average-pooling operation in kH × kW regions by step size sH * sW steps. The number of output features is equal to the number of input planes. /// /// The input tensor. - /// + /// /// /// /// /// + /// /// - public static unsafe Tensor avg_pool2d(Tensor input, long kernelSize, + public static unsafe Tensor avg_pool2d(Tensor input, long kernel_size, long? stride = null, long padding = 0, bool ceil_mode = false, - bool count_include_pad = true) + bool count_include_pad = true, + long? divisor_override = null) { - long svalue = (stride == null) ? kernelSize : stride.Value; + long svalue = (stride == null) ? kernel_size : stride.Value; - long* pkernelSize = stackalloc long[2] { kernelSize, kernelSize }; + long* pkernelSize = stackalloc long[2] { kernel_size, kernel_size }; long* pstrides = stackalloc long[2] { svalue, svalue }; long* ppadding = stackalloc long[2] { padding, padding }; @@ -167,7 +159,8 @@ public static unsafe Tensor avg_pool2d(Tensor input, long kernelSize, (IntPtr)pstrides, 2, (IntPtr)ppadding, 2, ceil_mode, - count_include_pad); + count_include_pad, + divisor_override ?? 0); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } @@ -176,27 +169,29 @@ public static unsafe Tensor avg_pool2d(Tensor input, long kernelSize, /// Applies 2D average-pooling operation in kH × kW regions by step size sH * sW steps. The number of output features is equal to the number of input planes. /// /// The input tensor. - /// + /// /// /// /// /// + /// /// - public static unsafe Tensor avg_pool2d(Tensor input, (long, long) kernelSize, + public static unsafe Tensor avg_pool2d(Tensor input, (long, long) kernel_size, (long, long)? stride = null, (long, long)? padding = null, bool ceil_mode = false, - bool count_include_pad = true) + bool count_include_pad = true, + long? divisor_override = null) { - long svalue1 = (stride == null) ? kernelSize.Item1 : stride.Value.Item1; - long svalue2 = (stride == null) ? kernelSize.Item2 : stride.Value.Item2; + long svalue1 = (stride == null) ? kernel_size.Item1 : stride.Value.Item1; + long svalue2 = (stride == null) ? kernel_size.Item2 : stride.Value.Item2; long pvalue1 = padding != null ? padding.Value.Item1 : 0; long pvalue2 = padding != null ? padding.Value.Item2 : 0; long* pstrides = stackalloc long[2] { svalue1, svalue2 }; long* ppadding = stackalloc long[2] { pvalue1, pvalue2 }; - long* pkernelSize = stackalloc long[2] { kernelSize.Item1, kernelSize.Item2 }; + long* pkernelSize = stackalloc long[2] { kernel_size.Item1, kernel_size.Item2 }; var res = THSTensor_avg_pool2d(input.Handle, @@ -204,7 +199,8 @@ public static unsafe Tensor avg_pool2d(Tensor input, (long, long) kernelSize, (IntPtr)pstrides, 2, (IntPtr)ppadding, 2, ceil_mode, - count_include_pad); + count_include_pad, + divisor_override ?? 0); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } @@ -215,7 +211,7 @@ public static Tensor avg_pool2d_backward(Tensor input, Tensor originalInput, long[] paddings = null, bool ceil_mode = false, bool count_include_pad = true, - long divisorOverride = 0) + long? divisor_override = null) { strides = (strides == null) ? new long[] { 1 } : strides; paddings = (paddings == null) ? new long[] { 0 } : paddings; @@ -228,7 +224,7 @@ public static Tensor avg_pool2d_backward(Tensor input, Tensor originalInput, (IntPtr)ppadding, paddings.Length, ceil_mode, count_include_pad, - divisorOverride); + divisor_override ?? 0); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } diff --git a/src/TorchSharp/NN/Pooling/AvgPool3D.cs b/src/TorchSharp/NN/Pooling/AvgPool3D.cs index cf499c839..d08304bdc 100644 --- a/src/TorchSharp/NN/Pooling/AvgPool3D.cs +++ b/src/TorchSharp/NN/Pooling/AvgPool3D.cs @@ -12,24 +12,29 @@ namespace Modules /// /// This class is used to represent a AvgPool3D module. /// - public sealed class AvgPool3d : torch.nn.Module + public sealed class AvgPool3d : ParamLessModule { - internal AvgPool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal AvgPool3d(long[] kernel_size, long[] stride = null, long[] padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) : base(nameof(AvgPool3d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; + this.ceil_mode = ceil_mode; + this.count_include_pad = count_include_pad; + this.divisor_override = divisor_override; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_AvgPool3d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.avg_pool3d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] kernel_size { get; set; } + public long[] stride { get; set; } + public long[] padding { get; set; } + public bool ceil_mode { get; set; } + public bool count_include_pad { get; set; } + public long? divisor_override { get; set; } } } @@ -41,20 +46,14 @@ public static partial class nn /// Applies a 3D average pooling over an input signal composed of several input planes. /// /// The size of the window - /// The stride of the window. Default value is kernel_size + /// The stride of the window. Default value is kernel_size /// implicit zero padding to be added on both sides /// Whether to use ceil instead of floor to compute the output shape /// Whether to include the zero-padding in the averaging calculation /// If specified, it will be used as divisor, otherwise size of the pooling region will be used - public static AvgPool3d AvgPool3d(long[] kernel_size, long[] strides = null, long[] padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) + public static AvgPool3d AvgPool3d(long[] kernel_size, long[] stride = null, long[] padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) { - unsafe { - fixed (long* pkernelSize = kernel_size, pstrides = strides, ppadding = padding) { - var handle = THSNN_AvgPool3d_ctor((IntPtr)pkernelSize, kernel_size.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), (IntPtr)ppadding, (padding == null ? 0 : padding.Length), ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AvgPool3d(handle, boxedHandle); - } - } + return new AvgPool3d(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); } /// @@ -68,21 +67,10 @@ public static AvgPool3d AvgPool3d(long[] kernel_size, long[] strides = null, lon /// If specified, it will be used as divisor, otherwise size of the pooling region will be used public static unsafe AvgPool3d AvgPool3d((long, long, long) kernel_size, (long, long, long)? stride = null, (long, long, long)? padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) { - long svalue1 = (stride == null) ? kernel_size.Item1 : stride.Value.Item1; - long svalue2 = (stride == null) ? kernel_size.Item2 : stride.Value.Item2; - long svalue3 = (stride == null) ? kernel_size.Item3 : stride.Value.Item3; - - long pvalue1 = (padding == null) ? 0 : padding.Value.Item1; - long pvalue2 = (padding == null) ? 0 : padding.Value.Item2; - long pvalue3 = (padding == null) ? 0 : padding.Value.Item3; - - long* pkernelSize = stackalloc long[3] { kernel_size.Item1, kernel_size.Item2, kernel_size.Item3 }; - long* pstrides = stackalloc long[3] { svalue1, svalue2, svalue3 }; - long* ppadding = stackalloc long[3] { pvalue1, pvalue2, pvalue3 }; - - var handle = THSNN_AvgPool3d_ctor((IntPtr)pkernelSize, 3, (IntPtr)pstrides, 3, (IntPtr)ppadding, 3, ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AvgPool3d(handle, boxedHandle); + long[] kernelValue = new[] { kernel_size.Item1, kernel_size.Item2, kernel_size.Item3 }; + long[] strideValue = stride == null ? null : new[] { stride.Value.Item1, stride.Value.Item2, stride.Value.Item3 }; + long[] paddingValue = padding == null ? null : new[] { padding.Value.Item1, padding.Value.Item2, padding.Value.Item3 }; + return new AvgPool3d(kernelValue, strideValue, paddingValue, ceil_mode, count_include_pad, divisor_override); } /// @@ -96,16 +84,10 @@ public static unsafe AvgPool3d AvgPool3d((long, long, long) kernel_size, (long, /// If specified, it will be used as divisor, otherwise size of the pooling region will be used public static unsafe AvgPool3d AvgPool3d(long kernel_size, long? stride = null, long? padding = null, bool ceil_mode = false, bool count_include_pad = true, long? divisor_override = null) { - long svalue = (stride == null) ? kernel_size : stride.Value; - long pvalue = (padding == null) ? 0 : padding.Value; - - long* pkernelSize = stackalloc long[3] { kernel_size, kernel_size, kernel_size }; - long* pstrides = stackalloc long[3] { svalue, svalue, svalue }; - long* ppadding = stackalloc long[3] { pvalue, pvalue, pvalue }; - - var handle = THSNN_AvgPool3d_ctor((IntPtr)pkernelSize, 3, (IntPtr)pstrides, 3, (IntPtr)ppadding, 3, ceil_mode, count_include_pad, divisor_override.HasValue ? divisor_override.Value : 0, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new AvgPool3d(handle, boxedHandle); + long[] kernelValue = new[] { kernel_size, kernel_size, kernel_size }; + long[] strideValue = stride == null ? null : new[] { stride.Value, stride.Value, stride.Value }; + long[] paddingValue = padding == null ? null : new[] { padding.Value, padding.Value, padding.Value }; + return new AvgPool3d(kernelValue, strideValue, paddingValue, ceil_mode, count_include_pad, divisor_override); } public static partial class functional @@ -114,29 +96,31 @@ public static partial class functional /// Applies 3D average-pooling operation in kT x kH x kW regions by step size sT x sH x sW steps. /// /// The input tensor. - /// - /// - /// + /// + /// + /// /// /// + /// /// - public static Tensor avg_pool3d(Tensor input, long[] kernelSizes, - long[] strides = null, - long[] paddings = null, + public static Tensor avg_pool3d(Tensor input, long[] kernel_size, + long[] stride = null, + long[] padding = null, bool ceil_mode = false, - bool count_include_pad = true) + bool count_include_pad = true, + long? divisor_override = null) { - strides = (strides == null) ? new long[] { 1 } : strides; - paddings = (paddings == null) ? new long[] { 0 } : paddings; + stride = (stride == null) ? kernel_size : stride; + padding = (padding == null) ? new long[] { 0 } : padding; unsafe { - fixed (long* pkernelSize = kernelSizes, pstrides = strides, ppadding = paddings) { + fixed (long* pkernelSize = kernel_size, pstrides = stride, ppadding = padding) { var res = THSTensor_avg_pool3d(input.Handle, - (IntPtr)pkernelSize, kernelSizes.Length, - (IntPtr)pstrides, strides.Length, - (IntPtr)ppadding, paddings.Length, + (IntPtr)pkernelSize, kernel_size.Length, + (IntPtr)pstrides, stride.Length, + (IntPtr)ppadding, padding.Length, ceil_mode, - count_include_pad); + count_include_pad, divisor_override ?? 0); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } @@ -149,9 +133,9 @@ public static Tensor avg_pool3d_backward(Tensor input, Tensor originalInput, long[] paddings = null, bool ceil_mode = false, bool count_include_pad = true, - long divisorOverride = 0) + long? divisor_override = null) { - strides = (strides == null) ? new long[] { 1 } : strides; + strides = (strides == null) ? kernelSizes : strides; paddings = (paddings == null) ? new long[] { 0 } : paddings; unsafe { fixed (long* pkernelSize = kernelSizes, pstrides = strides, ppadding = paddings) { @@ -162,7 +146,7 @@ public static Tensor avg_pool3d_backward(Tensor input, Tensor originalInput, (IntPtr)ppadding, paddings.Length, ceil_mode, count_include_pad, - divisorOverride); + divisor_override ?? 0); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } diff --git a/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs b/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs index 7fbfa371a..e37d17083 100644 --- a/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs +++ b/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs @@ -5,6 +5,7 @@ namespace TorchSharp { + using System.Data; using Modules; namespace Modules @@ -12,31 +13,28 @@ namespace Modules /// /// This class is used to represent a FractionalMaxPool2D module. /// - public sealed class FractionalMaxPool2d : torch.nn.Module + public sealed class FractionalMaxPool2d : ParamLessModule { - internal FractionalMaxPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal FractionalMaxPool2d(long[] kernel_size, long[] output_size = null, double[] output_ratio = null) : base(nameof(FractionalMaxPool2d)) { + this.kernel_size = kernel_size; + this.output_size = output_size; + this.output_ratio = output_ratio; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_FractionalMaxPool2d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.fractional_max_pool2d(input, kernel_size, output_size, output_ratio); } - public (Tensor Values, Tensor Indices) forward_with_indices(Tensor tensor) + public (Tensor Values, Tensor Indices) forward_with_indices(Tensor input) { - var res = THSNN_FractionalMaxPool2d_forward_with_indices(handle, tensor.Handle, out var indices); - if (res == IntPtr.Zero || indices == IntPtr.Zero) { torch.CheckForErrors(); } - return (new Tensor(res), new Tensor(indices)); + return torch.nn.functional.fractional_max_pool2d_with_indices(input, kernel_size, output_size, output_ratio); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] kernel_size { get; set; } + public long[] output_size { get; set; } + public double[] output_ratio { get; set; } } } @@ -101,16 +99,135 @@ public static FractionalMaxPool2d FractionalMaxPool2d(long[] kernel_size, long[] if (output_size != null && output_ratio != null) throw new ArgumentNullException("FractionalMaxPool2d requires specifying either an output size, or a pooling ratio."); - unsafe { - fixed (long* pkernelSize = kernel_size, pSize = output_size) { - fixed (double* pRatio = output_ratio) { - var handle = THSNN_FractionalMaxPool2d_ctor( - (IntPtr)pkernelSize, kernel_size.Length, - (IntPtr)pSize, (output_size == null ? 0 : output_size.Length), - (IntPtr)pRatio, (output_ratio == null ? 0 : output_ratio.Length), - out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new FractionalMaxPool2d(handle, boxedHandle); + return new FractionalMaxPool2d(kernel_size, output_size, output_ratio); + } + + public static partial class functional + { + /// + /// Applies a 2D fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static Tensor fractional_max_pool2d(Tensor input, long kernel_size, long? output_size = null, double? output_ratio = null) + { + var pSize = output_size.HasValue ? new long[] { output_size.Value, output_size.Value } : null; + var pRatio = output_ratio.HasValue ? new double[] { output_ratio.Value, output_ratio.Value } : null; + return fractional_max_pool2d(input, new long[] { kernel_size, kernel_size }, pSize, pRatio); + } + + /// + /// Applies a 2D fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static Tensor fractional_max_pool2d(Tensor input, (long, long) kernel_size, (long, long)? output_size = null, (double, double)? output_ratio = null) + { + var pSize = output_size.HasValue ? new long[] { output_size.Value.Item1, output_size.Value.Item2 } : null; + var pRatio = output_ratio.HasValue ? new double[] { output_ratio.Value.Item1, output_ratio.Value.Item2 } : null; + return fractional_max_pool2d(input, new long[] { kernel_size.Item1, kernel_size.Item2 }, pSize, pRatio); + } + + /// + /// Applies a 2D fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static Tensor fractional_max_pool2d(Tensor input, long[] kernel_size, long[] output_size = null, double[] output_ratio = null) + { + var ret = fractional_max_pool2d_with_indices(input, kernel_size, output_size, output_ratio); + ret.Indices.Dispose(); + return ret.Values; + } + + /// + /// Applies a 2D fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static (Tensor Values, Tensor Indices) fractional_max_pool2d_with_indices(Tensor input, long kernel_size, long? output_size = null, double? output_ratio = null) + { + var pSize = output_size.HasValue ? new long[] { output_size.Value, output_size.Value } : null; + var pRatio = output_ratio.HasValue ? new double[] { output_ratio.Value, output_ratio.Value } : null; + return fractional_max_pool2d_with_indices(input, new long[] { kernel_size, kernel_size }, pSize, pRatio); + } + + /// + /// Applies a 2D fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static (Tensor Values, Tensor Indices) fractional_max_pool2d_with_indices(Tensor input, (long, long) kernel_size, (long, long)? output_size = null, (double, double)? output_ratio = null) + { + var pSize = output_size.HasValue ? new long[] { output_size.Value.Item1, output_size.Value.Item2 } : null; + var pRatio = output_ratio.HasValue ? new double[] { output_ratio.Value.Item1, output_ratio.Value.Item2 } : null; + return fractional_max_pool2d_with_indices(input, new long[] { kernel_size.Item1, kernel_size.Item2 }, pSize, pRatio); + } + + /// + /// Applies a 2D fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static (Tensor Values, Tensor Indices) fractional_max_pool2d_with_indices(Tensor input, long[] kernel_size, long[] output_size = null, double[] output_ratio = null) + { + if (kernel_size == null || kernel_size.Length != 2) + throw new ArgumentException("Kernel size must contain two elements."); + if (output_size != null && output_size.Length != 2) + throw new ArgumentException("output_size must contain two elements."); + if (output_ratio != null && output_ratio.Length != 2) + throw new ArgumentException("output_ratio must contain two elements."); + if (output_size == null && output_ratio == null) + throw new ArgumentNullException("Only one of output_size and output_ratio may be specified."); + if (output_size != null && output_ratio != null) + throw new ArgumentNullException("FractionalMaxPool2d requires specifying either an output size, or a pooling ratio."); + + output_size ??= Array.Empty(); + output_ratio ??= Array.Empty(); + + unsafe { + fixed (long* pkernelSize = kernel_size, poutputSize = output_size) { + fixed (double* poutputRatio = output_ratio) { + var resOutput = THSTensor_fractional_max_pool2d(input.Handle, (IntPtr)pkernelSize, kernel_size.Length, (IntPtr)poutputSize, output_size.Length, (IntPtr)poutputRatio, output_ratio.Length, out var resIndices); + if (resOutput == IntPtr.Zero || resIndices == IntPtr.Zero) { torch.CheckForErrors(); } + return (new Tensor(resOutput), new Tensor(resIndices)); + } } } } diff --git a/src/TorchSharp/NN/Pooling/FractionalMaxPool3d.cs b/src/TorchSharp/NN/Pooling/FractionalMaxPool3d.cs index 59be5e2b4..98ac3d0eb 100644 --- a/src/TorchSharp/NN/Pooling/FractionalMaxPool3d.cs +++ b/src/TorchSharp/NN/Pooling/FractionalMaxPool3d.cs @@ -12,42 +12,28 @@ namespace Modules /// /// This class is used to represent a FractionalMaxPool3d module. /// - public sealed class FractionalMaxPool3d : torch.nn.Module + public sealed class FractionalMaxPool3d : ParamLessModule { - internal FractionalMaxPool3d(IntPtr handle, IntPtr boxedHandle, bool ratio) : base(handle, boxedHandle) + internal FractionalMaxPool3d(long[] kernel_size, long[] output_size = null, double[] output_ratio = null) : base(nameof(FractionalMaxPool3d)) { - _used_ratio = ratio; + this.kernel_size = kernel_size; + this.output_size = output_size; + this.output_ratio = output_ratio; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - if (_used_ratio && tensor.ndim != 5) - // Not sure why this is the case, but there's an exception in the native runtime - // unless there's both a batch dimension and a channel dimension. - throw new ArgumentException("FractionalMaxPool3d: input tensor must have 5 dimensions: [N, C, D, H, W]"); - var res = THSNN_FractionalMaxPool3d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.fractional_max_pool3d(input, kernel_size, output_size, output_ratio); } - public (Tensor Values, Tensor Indices) forward_with_indices(Tensor tensor) + public (Tensor Values, Tensor Indices) forward_with_indices(Tensor input) { - if (_used_ratio && tensor.ndim != 5) - // Not sure why this is the case, but there's an exception in the native runtime - // unless there's both a batch dimension and a channel dimension. - throw new ArgumentException("FractionalMaxPool3d: input tensor must have 5 dimensions: [N, C, D, H, W]"); - var res = THSNN_FractionalMaxPool3d_forward_with_indices(handle, tensor.Handle, out var indices); - if (res == IntPtr.Zero || indices == IntPtr.Zero) { torch.CheckForErrors(); } - return (new Tensor(res), new Tensor(indices)); + return torch.nn.functional.fractional_max_pool3d_with_indices(input, kernel_size, output_size, output_ratio); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; - - private bool _used_ratio = false; + public long[] kernel_size { get; set; } + public long[] output_size { get; set; } + public double[] output_ratio { get; set; } } } @@ -112,16 +98,139 @@ public static FractionalMaxPool3d FractionalMaxPool3d(long[] kernel_size, long[] if (output_size != null && output_ratio != null) throw new ArgumentNullException("FractionalMaxPool3d requires specifying either an output size, or a pooling ratio."); - unsafe { - fixed (long* pkernelSize = kernel_size, pSize = output_size) { - fixed (double* pRatio = output_ratio) { - var handle = THSNN_FractionalMaxPool3d_ctor( - (IntPtr)pkernelSize, kernel_size.Length, - (IntPtr)pSize, (output_size == null ? 0 : output_size.Length), - (IntPtr)pRatio, (output_ratio == null ? 0 : output_ratio.Length), - out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new FractionalMaxPool3d(handle, boxedHandle, output_ratio != null); + return new FractionalMaxPool3d(kernel_size, output_size, output_ratio); + } + + public static partial class functional + { + /// + /// Applies a 3d fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static Tensor fractional_max_pool3d(Tensor input, long kernel_size, long? output_size = null, double? output_ratio = null) + { + var pSize = output_size.HasValue ? new long[] { output_size.Value, output_size.Value, output_size.Value } : null; + var pRatio = output_ratio.HasValue ? new double[] { output_ratio.Value, output_ratio.Value, output_ratio.Value } : null; + return fractional_max_pool3d(input, new long[] { kernel_size, kernel_size, kernel_size }, pSize, pRatio); + } + + /// + /// Applies a 3d fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static Tensor fractional_max_pool3d(Tensor input, (long, long, long) kernel_size, (long, long, long)? output_size = null, (double, double, double)? output_ratio = null) + { + var pSize = output_size.HasValue ? new long[] { output_size.Value.Item1, output_size.Value.Item2, output_size.Value.Item3 } : null; + var pRatio = output_ratio.HasValue ? new double[] { output_ratio.Value.Item1, output_ratio.Value.Item2, output_ratio.Value.Item3 } : null; + return fractional_max_pool3d(input, new long[] { kernel_size.Item1, kernel_size.Item2, kernel_size.Item3 }, pSize, pRatio); + } + + /// + /// Applies a 3d fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static Tensor fractional_max_pool3d(Tensor input, long[] kernel_size, long[] output_size = null, double[] output_ratio = null) + { + var ret = fractional_max_pool3d_with_indices(input, kernel_size, output_size, output_ratio); + ret.Indices.Dispose(); + return ret.Values; + } + + /// + /// Applies a 3d fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static (Tensor Values, Tensor Indices) fractional_max_pool3d_with_indices(Tensor input, long kernel_size, long? output_size = null, double? output_ratio = null) + { + var pSize = output_size.HasValue ? new long[] { output_size.Value, output_size.Value, output_size.Value } : null; + var pRatio = output_ratio.HasValue ? new double[] { output_ratio.Value, output_ratio.Value, output_ratio.Value } : null; + return fractional_max_pool3d_with_indices(input, new long[] { kernel_size, kernel_size, kernel_size }, pSize, pRatio); + } + + /// + /// Applies a 3d fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static (Tensor Values, Tensor Indices) fractional_max_pool3d_with_indices(Tensor input, (long, long, long) kernel_size, (long, long, long)? output_size = null, (double, double, double)? output_ratio = null) + { + var pSize = output_size.HasValue ? new long[] { output_size.Value.Item1, output_size.Value.Item2, output_size.Value.Item3 } : null; + var pRatio = output_ratio.HasValue ? new double[] { output_ratio.Value.Item1, output_ratio.Value.Item2, output_ratio.Value.Item3 } : null; + return fractional_max_pool3d_with_indices(input, new long[] { kernel_size.Item1, kernel_size.Item2, kernel_size.Item3 }, pSize, pRatio); + } + + /// + /// Applies a 3d fractional max pooling over an input signal composed of several input planes. + /// + /// Fractional MaxPooling is described in detail in the paper Fractional MaxPooling by Ben Graham, + /// see: https://arxiv.org/abs/1412.6071 + /// + /// The input tensor + /// The size of the sliding window, must be > 0. + /// The target output size of the image of the form oH x oW. Can be a tuple (oH, oW) or a single number oH for a square image oH x oH + /// If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1) + /// + public static (Tensor Values, Tensor Indices) fractional_max_pool3d_with_indices(Tensor input, long[] kernel_size, long[] output_size = null, double[] output_ratio = null) + { + if (kernel_size == null || kernel_size.Length != 3) + throw new ArgumentException("Kernel size must contain three elements."); + if (output_size != null && output_size.Length != 3) + throw new ArgumentException("output_size must contain three elements."); + if (output_ratio != null && output_ratio.Length != 3) + throw new ArgumentException("output_ratio must contain three elements."); + if (output_size == null && output_ratio == null) + throw new ArgumentNullException("Only one of output_size and output_ratio may be specified."); + if (output_size != null && output_ratio != null) + throw new ArgumentNullException("FractionalMaxPool3d requires specifying either an output size, or a pooling ratio."); + if (output_ratio != null && input.ndim != 5) + // Not sure why this is the case, but there's an exception in the native runtime + // unless there's both a batch dimension and a channel dimension. + throw new ArgumentException("FractionalMaxPool3d: input tensor must have 5 dimensions: [N, C, D, H, W]"); + + output_size ??= Array.Empty(); + output_ratio ??= Array.Empty(); + + unsafe { + fixed (long* pkernelSize = kernel_size, poutputSize = output_size) { + fixed (double* poutputRatio = output_ratio) { + var resOutput = THSTensor_fractional_max_pool3d(input.Handle, (IntPtr)pkernelSize, kernel_size.Length, (IntPtr)poutputSize, output_size.Length, (IntPtr)poutputRatio, output_ratio.Length, out var resIndices); + if (resOutput == IntPtr.Zero || resIndices == IntPtr.Zero) { torch.CheckForErrors(); } + return (new Tensor(resOutput), new Tensor(resIndices)); + } } } } diff --git a/src/TorchSharp/NN/Pooling/LPPool1d.cs b/src/TorchSharp/NN/Pooling/LPPool1d.cs index 424da18d5..9babde806 100644 --- a/src/TorchSharp/NN/Pooling/LPPool1d.cs +++ b/src/TorchSharp/NN/Pooling/LPPool1d.cs @@ -12,24 +12,25 @@ namespace Modules /// /// This class is used to represent a LPPool1D module. /// - public sealed class LPPool1d : torch.nn.Module + public sealed class LPPool1d : ParamLessModule { - internal LPPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal LPPool1d(double norm_type, long kernel_size, long? stride = null, bool ceil_mode = false) : base(nameof(LPPool1d)) { + this.norm_type = norm_type; + this.kernel_size = kernel_size; + this.stride = stride; + this.ceil_mode = ceil_mode; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_LPPool1d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.lp_pool1d(input, norm_type, kernel_size, stride, ceil_mode); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public double norm_type { get; set; } + public long kernel_size { get; set; } + public long? stride { get; set; } + public bool ceil_mode { get; set; } } } @@ -41,32 +42,37 @@ public static partial class nn /// Applies a 1D power-average pooling over an input signal composed of several input planes. /// /// The LP norm (exponent) - /// The size of the window + /// The size of the window /// The stride of the window. Default value is kernel_size /// Use ceil instead of floor to compute the output shape /// - public static LPPool1d LPPool1d(double norm_type, long kernelSize, long? stride = null, bool ceil_mode = false) + public static LPPool1d LPPool1d(double norm_type, long kernel_size, long? stride = null, bool ceil_mode = false) { - return stride.HasValue ? - LPPool1d(norm_type, new long[] { kernelSize }, new long[] { stride.Value }, ceil_mode) : - LPPool1d(norm_type, new long[] { kernelSize }, null); + return new LPPool1d(norm_type, kernel_size, stride, ceil_mode); } - /// - /// Applies a 1D power-average pooling over an input signal composed of several input planes. - /// - /// The LP norm (exponent) - /// The size of the window - /// The stride of the window. Default value is kernel_size - /// Use ceil instead of floor to compute the output shape - /// - private static LPPool1d LPPool1d(double norm_type, long[] kernelSize, long[] strides = null, bool ceil_mode = false) + public static partial class functional { - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides) { - var handle = THSNN_LPPool1d_ctor(norm_type, (IntPtr)pkernelSize, (IntPtr)pstrides, ceil_mode, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new LPPool1d(handle, boxedHandle); + /// + /// Applies a 1D power-average pooling over an input signal composed of several input planes. + /// + /// The input tensor + /// The LP norm (exponent) + /// The size of the window + /// The stride of the window. Default value is kernel_size + /// Use ceil instead of floor to compute the output shape + /// + public static Tensor lp_pool1d(Tensor input, double norm_type, long kernel_size, long? stride = null, bool ceil_mode = false) + { + var kernels = new[] { kernel_size }; + var strides = stride.HasValue ? new[] { stride.Value } : Array.Empty(); + + unsafe { + fixed (long* pkernelSize = kernels, pstrides = strides) { + var res = THSTensor_lp_pool1d(input.Handle, norm_type, (IntPtr)pkernelSize, kernels.Length, (IntPtr)pstrides, strides.Length, ceil_mode); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); + } } } } diff --git a/src/TorchSharp/NN/Pooling/LPPool2d.cs b/src/TorchSharp/NN/Pooling/LPPool2d.cs index 67c06b58b..4fe66dafd 100644 --- a/src/TorchSharp/NN/Pooling/LPPool2d.cs +++ b/src/TorchSharp/NN/Pooling/LPPool2d.cs @@ -12,24 +12,25 @@ namespace Modules /// /// This class is used to represent a LPPool2D module. /// - public sealed class LPPool2d : torch.nn.Module + public sealed class LPPool2d : ParamLessModule { - internal LPPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal LPPool2d(double norm_type, long[] kernel_size, long[] stride = null, bool ceil_mode = false) : base(nameof(LPPool2d)) { + this.norm_type = norm_type; + this.kernel_size = kernel_size; + this.stride = stride; + this.ceil_mode = ceil_mode; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_LPPool2d_forward(handle.DangerousGetHandle(), tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.lp_pool2d(input, norm_type, kernel_size, stride, ceil_mode); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public double norm_type { get; set; } + public long[] kernel_size { get; set; } + public long[] stride { get; set; } + public bool ceil_mode { get; set; } } } @@ -42,18 +43,12 @@ public static partial class nn /// /// The LP norm (exponent) /// The size of the window - /// The stride of the window. Default value is kernel_size + /// The stride of the window. Default value is kernel_size /// Use ceil instead of floor to compute the output shape /// - public static LPPool2d LPPool2d(double norm_type, long[] kernel_size, long[] strides = null, bool ceil_mode = false) + public static LPPool2d LPPool2d(double norm_type, long[] kernel_size, long[] stride = null, bool ceil_mode = false) { - unsafe { - fixed (long* pkernelSize = kernel_size, pstrides = strides) { - var handle = THSNN_LPPool2d_ctor(norm_type, (IntPtr)pkernelSize, kernel_size.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), ceil_mode, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new LPPool2d(handle, boxedHandle); - } - } + return new LPPool2d(norm_type, kernel_size, stride, ceil_mode); } /// @@ -66,9 +61,46 @@ public static LPPool2d LPPool2d(double norm_type, long[] kernel_size, long[] str /// public static LPPool2d LPPool2d(double norm_type, long kernel_size, long? stride = null, bool ceil_mode = false) { - return stride.HasValue ? - LPPool2d(norm_type, new long[] { kernel_size, kernel_size }, new long[] { stride.Value, stride.Value }, ceil_mode) : - LPPool2d(norm_type, new long[] { kernel_size, kernel_size }, null, ceil_mode); + return new LPPool2d(norm_type, new[] { kernel_size, kernel_size }, stride.HasValue ? new[] { stride.Value, stride.Value } : null, ceil_mode); + } + + public static partial class functional + { + /// + /// Applies a 2D power-average pooling over an input signal composed of several input planes. + /// + /// The input tensor + /// The LP norm (exponent) + /// The size of the window + /// The stride of the window. Default value is kernel_size + /// Use ceil instead of floor to compute the output shape + /// + public static Tensor lp_pool2d(Tensor input, double norm_type, long[] kernel_size, long[] stride = null, bool ceil_mode = false) + { + stride ??= Array.Empty(); + + unsafe { + fixed (long* pkernelSize = kernel_size, pstrides = stride) { + var res = THSTensor_lp_pool2d(input.Handle, norm_type, (IntPtr)pkernelSize, kernel_size.Length, (IntPtr)pstrides, stride.Length, ceil_mode); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); + } + } + } + + /// + /// Applies a 2D power-average pooling over an input signal composed of several input planes. + /// + /// The input tensor + /// The LP norm (exponent) + /// The size of the window + /// The stride of the window. + /// Use ceil instead of floor to compute the output shape + /// + public static Tensor lp_pool2d(Tensor input, double norm_type, long kernel_size, long? stride = null, bool ceil_mode = false) + { + return lp_pool2d(input, norm_type, new[] { kernel_size, kernel_size }, stride.HasValue ? new[] { stride.Value, stride.Value } : null, ceil_mode); + } } } } diff --git a/src/TorchSharp/NN/Pooling/MaxPool1D.cs b/src/TorchSharp/NN/Pooling/MaxPool1D.cs index 79a521f59..558664c74 100644 --- a/src/TorchSharp/NN/Pooling/MaxPool1D.cs +++ b/src/TorchSharp/NN/Pooling/MaxPool1D.cs @@ -13,31 +13,32 @@ namespace Modules /// /// This class is used to represent a MaxPool1D module. /// - public sealed class MaxPool1d : torch.nn.Module + public sealed class MaxPool1d : ParamLessModule { - internal MaxPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal MaxPool1d(long kernel_size, long? stride = null, long? padding = null, long? dilation = null, bool ceil_mode = false) : base(nameof(MaxPool1d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; + this.dilation = dilation; + this.ceil_mode = ceil_mode; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_MaxPool1d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode); } - public (Tensor Values, Tensor Indices) forward_with_indices(Tensor tensor) + public (Tensor Values, Tensor Indices) forward_with_indices(Tensor input) { - var res = THSNN_MaxPool1d_forward_with_indices(handle, tensor.Handle, out var indices); - if (res == IntPtr.Zero || indices == IntPtr.Zero) { torch.CheckForErrors(); } - return (new Tensor(res), new Tensor(indices)); + return torch.nn.functional.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long kernel_size { get; set; } + public long? stride { get; set; } + public long? padding { get; set; } + public long? dilation { get; set; } + public bool ceil_mode { get; set; } } } @@ -48,100 +49,56 @@ public static partial class nn /// /// Applies a 1D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// The stride between elements within a sliding window, must be > 0. - /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. + /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. /// - public static MaxPool1d MaxPool1d(long kernelSize, long? stride = null, long? padding = null, long? dilation = null, bool ceilMode = false) + public static MaxPool1d MaxPool1d(long kernel_size, long? stride = null, long? padding = null, long? dilation = null, bool ceil_mode = false) { - var pStride = stride.HasValue ? new long[] { stride.Value } : null; - var pPadding = padding.HasValue ? new long[] { padding.Value } : null; - var pDilation = dilation.HasValue ? new long[] { dilation.Value } : null; - return MaxPool1d(new long[] { kernelSize }, pStride, pPadding, pDilation, ceilMode); - } - - private static MaxPool1d MaxPool1d(long[] kernelSize, long[] strides = null, long[] padding = null, long[] dilation = null, bool ceilMode = false) - { - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, pPadding = padding, pDilation = dilation) { - var handle = THSNN_MaxPool1d_ctor((IntPtr)pkernelSize, (IntPtr)pstrides, (IntPtr)pPadding, (IntPtr)pDilation, ceilMode, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new MaxPool1d(handle, boxedHandle); - } - } + return new MaxPool1d(kernel_size, stride, padding, dilation, ceil_mode); } public static partial class functional { - /// - /// Applies a 1D max pooling over an input signal composed of several input planes. - /// - /// The input tensor. - /// - /// - /// - /// - /// - /// - public static Tensor max_pool1d(Tensor input, long kernelSize, long? stride = null, + public static Tensor max_pool1d(Tensor input, long kernel_size, long? stride = null, long? padding = null, long? dilation = null, bool ceil_mode = false) { - var kernelSizes = new long[] { kernelSize }; - var strides = new long[] { stride ?? kernelSize }; - var paddings = new long[] { padding ?? 0 }; - var dilations = new long[] { dilation ?? 1 }; - unsafe { - fixed (long* pkernelSize = kernelSizes, pstrides = strides, ppadding = paddings, pdilation = dilations) { - var res = - THSTensor_max_pool1d(input.Handle, - (IntPtr)pkernelSize, kernelSizes.Length, - (IntPtr)pstrides, strides.Length, - (IntPtr)ppadding, paddings.Length, - (IntPtr)pdilation, dilations.Length, - ceil_mode); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - } + var ret = max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode); + ret.Indices.Dispose(); + return ret.Values; } /// /// Applies a 1D max pooling over an input signal composed of several input planes. /// /// The input tensor. - /// + /// /// /// /// /// /// - public static (Tensor output, Tensor indices) max_pool1d_with_indices(Tensor input, long kernelSize, long? stride = null, + public static (Tensor Values, Tensor Indices) max_pool1d_with_indices(Tensor input, long kernel_size, long? stride = null, long? padding = null, long? dilation = null, bool ceil_mode = false) { - var kernelSizes = new long[] { kernelSize }; - var strides = new long[] { stride ?? kernelSize }; + var kernelSizes = new long[] { kernel_size }; + var strides = new long[] { stride ?? kernel_size }; var paddings = new long[] { padding ?? 0 }; var dilations = new long[] { dilation ?? 1 }; - IntPtr[] ptrArray; - - using (var pa = new PinnedArray()) { - unsafe { - fixed (long* pkernelSize = kernelSizes, pstrides = strides, ppadding = paddings, pdilation = dilations) { - THSTensor_max_pool1d_with_indices(input.Handle, - pa.CreateArray, + unsafe { + fixed (long* pkernelSize = kernelSizes, pstrides = strides, ppadding = paddings, pdilation = dilations) { + var resOutput = THSTensor_max_pool1d_with_indices(input.Handle, (IntPtr)pkernelSize, kernelSizes.Length, (IntPtr)pstrides, strides.Length, (IntPtr)ppadding, paddings.Length, (IntPtr)pdilation, dilations.Length, - ceil_mode); - torch.CheckForErrors(); - } + ceil_mode, out var resIndices); + if (resOutput == IntPtr.Zero || resIndices == IntPtr.Zero) { torch.CheckForErrors(); } + return (new Tensor(resOutput), new Tensor(resIndices)); } - ptrArray = pa.Array; } - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); } } } diff --git a/src/TorchSharp/NN/Pooling/MaxPool2D.cs b/src/TorchSharp/NN/Pooling/MaxPool2D.cs index 55808c454..2cef87ce6 100644 --- a/src/TorchSharp/NN/Pooling/MaxPool2D.cs +++ b/src/TorchSharp/NN/Pooling/MaxPool2D.cs @@ -13,30 +13,32 @@ namespace Modules /// /// This class is used to represent a MaxPool2D module. /// - public sealed class MaxPool2d : torch.nn.Module + public sealed class MaxPool2d : ParamLessModule { - internal MaxPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal MaxPool2d(long[] kernel_size, long[] stride = null, long[] padding = null, long[] dilation = null, bool ceil_mode = false) : base(nameof(MaxPool2d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; + this.dilation = dilation; + this.ceil_mode = ceil_mode; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_MaxPool2d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode); } - public (Tensor Values, Tensor Indices) forward_with_indices(Tensor tensor) + + public (Tensor Values, Tensor Indices) forward_with_indices(Tensor input) { - var res = THSNN_MaxPool2d_forward_with_indices(handle, tensor.Handle, out var indices); - if (res == IntPtr.Zero || indices == IntPtr.Zero) { torch.CheckForErrors(); } - return (new Tensor(res), new Tensor(indices)); + return torch.nn.functional.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] kernel_size { get; set; } + public long[] stride { get; set; } + public long[] padding { get; set; } + public long[] dilation { get; set; } + public bool ceil_mode { get; set; } } } @@ -47,74 +49,53 @@ public static partial class nn /// /// Applies a 2D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// The stride between elements within a sliding window, must be > 0. - /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. + /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. /// - public static unsafe MaxPool2d MaxPool2d(long kernelSize, long? stride = null, long? padding = null, long? dilation = null, bool ceilMode = false) + public static MaxPool2d MaxPool2d(long kernel_size, long? stride = null, long? padding = null, long? dilation = null, bool ceil_mode = false) { - long svalue = stride.HasValue ? stride.Value : kernelSize; - long pvalue = padding.HasValue ? padding.Value : 0; - long dvalue = dilation.HasValue ? dilation.Value : 1; - - long* pStride = stackalloc long[2] { svalue, svalue }; - long* pPadding = stackalloc long[2] { pvalue, pvalue }; - long* pDilation = stackalloc long[2] { dvalue, dvalue }; + long[] kernelValue = new[] { kernel_size, kernel_size }; + long[] strideValue = stride.HasValue ? new[] { stride.Value, stride.Value } : kernelValue.ToArray(); + long[] paddingValue = padding.HasValue ? new[] { padding.Value, padding.Value } : new[] { 0L, 0L }; + long[] dilationValue = dilation.HasValue ? new[] { dilation.Value, dilation.Value } : new[] { 1L, 1L }; - long* pkernelSize = stackalloc long[2] { kernelSize, kernelSize }; - - var handle = THSNN_MaxPool2d_ctor((IntPtr)pkernelSize, 2, (IntPtr)pStride, 2, (IntPtr)pPadding, 2, (IntPtr)pDilation, 2, ceilMode, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new MaxPool2d(handle, boxedHandle); + return new MaxPool2d(kernelValue, strideValue, paddingValue, dilationValue, ceil_mode); } /// /// Applies a 2D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// The stride between elements within a sliding window, must be > 0. - /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. + /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. /// - public static unsafe MaxPool2d MaxPool2d((long, long) kernelSize, (long, long)? stride = null, (long, long)? padding = null, (long, long)? dilation = null, bool ceilMode = false) + public static unsafe MaxPool2d MaxPool2d((long, long) kernel_size, (long, long)? stride = null, (long, long)? padding = null, (long, long)? dilation = null, bool ceil_mode = false) { - long svalue1 = stride != null ? stride.Value.Item1 : kernelSize.Item1; - long svalue2 = stride != null ? stride.Value.Item2 : kernelSize.Item2; - long pvalue1 = padding != null ? padding.Value.Item1 : 0; - long pvalue2 = padding != null ? padding.Value.Item2 : 0; - long dvalue1 = dilation != null ? dilation.Value.Item1 : 1; - long dvalue2 = dilation != null ? dilation.Value.Item2 : 1; - - long* pStride = stackalloc long[2] { svalue1, svalue2 }; - long* pPadding = stackalloc long[2] { pvalue1, pvalue2 }; - long* pDilation = stackalloc long[2] { dvalue1, dvalue2 }; + long[] kernelValue = new[] { kernel_size.Item1, kernel_size.Item2 }; + long[] strideValue = stride.HasValue ? new[] { stride.Value.Item1, stride.Value.Item2 } : kernelValue.ToArray(); + long[] paddingValue = padding.HasValue ? new[] { padding.Value.Item1, padding.Value.Item2 } : new[] { 0L, 0L }; + long[] dilationValue = dilation.HasValue ? new[] { dilation.Value.Item1, dilation.Value.Item2 } : new[] { 1L, 1L }; - long* pkernelSize = stackalloc long[2] { kernelSize.Item1, kernelSize.Item2 }; - - var handle = THSNN_MaxPool2d_ctor((IntPtr)pkernelSize, 2, (IntPtr)pStride, 2, (IntPtr)pPadding, 2, (IntPtr)pDilation, 2, ceilMode, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new MaxPool2d(handle, boxedHandle); + return new MaxPool2d(kernelValue, strideValue, paddingValue, dilationValue, ceil_mode); } /// /// Applies a 2D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. - /// The stride of the sliding window, must be > 0. Default value is kernel_size. + /// The size of the sliding window, must be > 0. + /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// The stride between elements within a sliding window, must be > 0. - /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. + /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. /// - public static unsafe MaxPool2d MaxPool2d(long[] kernelSize, long[] strides = null, long[] padding = null, long[] dilation = null, bool ceilMode = false) + public static MaxPool2d MaxPool2d(long[] kernel_size, long[] stride = null, long[] padding = null, long[] dilation = null, bool ceil_mode = false) { - fixed (long* pkernelSize = kernelSize, pstrides = strides, pPadding = padding, pDilation = dilation) { - var handle = THSNN_MaxPool2d_ctor((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), (IntPtr)pPadding, (padding == null ? 0 : padding.Length), (IntPtr)pDilation, (dilation == null ? 0 : dilation.Length), ceilMode, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new MaxPool2d(handle, boxedHandle); - } + return new MaxPool2d(kernel_size, stride, padding, dilation, ceil_mode); } public static partial class functional @@ -123,100 +104,64 @@ public static partial class functional /// Applies a 2D max pooling over an input signal composed of several input planes. /// /// The input tensor. - /// - /// + /// + /// /// /// /// /// - public static Tensor max_pool2d(Tensor input, long[] kernelSize, long[] strides = null, + public static Tensor max_pool2d(Tensor input, long[] kernel_size, long[] stride = null, long[] padding = null, long[] dilation = null, bool ceil_mode = false) { - strides = strides ?? kernelSize; - padding = padding ?? kernelSize.Select(x => 0L).ToArray(); - dilation = dilation ?? kernelSize.Select(x => 1L).ToArray(); - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, ppadding = padding, pdilation = dilation) { - var res = - THSTensor_max_pool2d(input.Handle, - (IntPtr)pkernelSize, kernelSize.Length, - (IntPtr)pstrides, strides.Length, - (IntPtr)ppadding, padding.Length, - (IntPtr)pdilation, dilation.Length, - ceil_mode); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - } + var ret = max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode); + ret.Indices.Dispose(); + return ret.Values; } /// /// Applies a 2D max pooling over an input signal composed of several input planes. /// /// The input tensor. - /// + /// /// /// /// /// /// - public static unsafe Tensor max_pool2d(Tensor input, long kernelSize, long? stride = null, + public static Tensor max_pool2d(Tensor input, long kernel_size, long? stride = null, long? padding = null, long? dilation = null, bool ceil_mode = false) { - long svalue = stride.HasValue ? stride.Value : kernelSize; - long pvalue = padding.HasValue ? padding.Value : 0; - long dvalue = dilation.HasValue ? dilation.Value : 1; - - long* pStride = stackalloc long[2] { svalue, svalue }; - long* pPadding = stackalloc long[2] { pvalue, pvalue }; - long* pDilation = stackalloc long[2] { dvalue, dvalue }; - - long* pkernelSize = stackalloc long[2] { kernelSize, kernelSize }; - - var res = THSTensor_max_pool2d(input.Handle, - (IntPtr)pkernelSize, 2, - (IntPtr)pStride, 2, - (IntPtr)pPadding, 2, - (IntPtr)pDilation, 2, - ceil_mode); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + long[] kernelValue = new[] { kernel_size, kernel_size }; + long[] strideValue = stride.HasValue ? new[] { stride.Value, stride.Value } : kernelValue.ToArray(); + long[] paddingValue = padding.HasValue ? new[] { padding.Value, padding.Value } : new[] { 0L, 0L }; + long[] dilationValue = dilation.HasValue ? new[] { dilation.Value, dilation.Value } : new[] { 1L, 1L }; + + var ret = max_pool2d_with_indices(input, kernelValue, strideValue, paddingValue, dilationValue, ceil_mode); + ret.Indices.Dispose(); + return ret.Values; } /// /// Applies a 2D max pooling over an input signal composed of several input planes. /// /// The input tensor. - /// + /// /// /// /// /// /// - public static unsafe Tensor max_pool2d(Tensor input, (long, long) kernelSize, (long, long)? stride = null, + public static unsafe Tensor max_pool2d(Tensor input, (long, long) kernel_size, (long, long)? stride = null, (long, long)? padding = null, (long, long)? dilation = null, bool ceil_mode = false) { - long svalue1 = stride != null ? stride.Value.Item1 : kernelSize.Item1; - long svalue2 = stride != null ? stride.Value.Item2 : kernelSize.Item2; - long pvalue1 = padding != null ? padding.Value.Item1 : 0; - long pvalue2 = padding != null ? padding.Value.Item2 : 0; - long dvalue1 = dilation != null ? dilation.Value.Item1 : 1; - long dvalue2 = dilation != null ? dilation.Value.Item2 : 1; - - long* pStride = stackalloc long[2] { svalue1, svalue2 }; - long* pPadding = stackalloc long[2] { pvalue1, pvalue2 }; - long* pDilation = stackalloc long[2] { dvalue1, dvalue2 }; - - long* pkernelSize = stackalloc long[2] { kernelSize.Item1, kernelSize.Item2 }; - - var res = THSTensor_max_pool2d(input.Handle, - (IntPtr)pkernelSize, 2, - (IntPtr)pStride, 2, - (IntPtr)pPadding, 2, - (IntPtr)pDilation, 2, - ceil_mode); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + long[] kernelValue = new[] { kernel_size.Item1, kernel_size.Item2 }; + long[] strideValue = stride.HasValue ? new[] { stride.Value.Item1, stride.Value.Item2 } : kernelValue.ToArray(); + long[] paddingValue = padding.HasValue ? new[] { padding.Value.Item1, padding.Value.Item2 } : new[] { 0L, 0L }; + long[] dilationValue = dilation.HasValue ? new[] { dilation.Value.Item1, dilation.Value.Item2 } : new[] { 1L, 1L }; + + var ret = max_pool2d_with_indices(input, kernelValue, strideValue, paddingValue, dilationValue, ceil_mode); + ret.Indices.Dispose(); + return ret.Values; } /// @@ -229,30 +174,25 @@ public static unsafe Tensor max_pool2d(Tensor input, (long, long) kernelSize, (l /// /// /// - public static (Tensor output, Tensor indices) max_pool2d_with_indices(Tensor input, long[] kernelSize, long[] strides = null, + public static (Tensor Values, Tensor Indices) max_pool2d_with_indices(Tensor input, long[] kernelSize, long[] strides = null, long[] padding = null, long[] dilation = null, bool ceil_mode = false) { - strides = strides ?? kernelSize; - padding = padding ?? kernelSize.Select(x => 0L).ToArray(); - dilation = dilation ?? kernelSize.Select(x => 1L).ToArray(); - IntPtr[] ptrArray; - - using (var pa = new PinnedArray()) { - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, ppadding = padding, pdilation = dilation) { - THSTensor_max_pool2d_with_indices(input.Handle, - pa.CreateArray, - (IntPtr)pkernelSize, kernelSize.Length, - (IntPtr)pstrides, strides.Length, - (IntPtr)ppadding, padding.Length, - (IntPtr)pdilation, dilation.Length, - ceil_mode); - torch.CheckForErrors(); - } + strides ??= kernelSize; + padding ??= kernelSize.Select(x => 0L).ToArray(); + dilation ??= kernelSize.Select(x => 1L).ToArray(); + unsafe { + fixed (long* pkernelSize = kernelSize, pstrides = strides, ppadding = padding, pdilation = dilation) { + var resOutput = THSTensor_max_pool2d_with_indices(input.Handle, + (IntPtr)pkernelSize, kernelSize.Length, + (IntPtr)pstrides, strides.Length, + (IntPtr)ppadding, padding.Length, + (IntPtr)pdilation, dilation.Length, + ceil_mode, out var resIndices); + + if (resOutput == IntPtr.Zero || resIndices == IntPtr.Zero) { torch.CheckForErrors(); } + return (new Tensor(resOutput), new Tensor(resIndices)); } - ptrArray = pa.Array; } - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); } } } diff --git a/src/TorchSharp/NN/Pooling/MaxPool3D.cs b/src/TorchSharp/NN/Pooling/MaxPool3D.cs index 1ab30d15d..2b26b3dad 100644 --- a/src/TorchSharp/NN/Pooling/MaxPool3D.cs +++ b/src/TorchSharp/NN/Pooling/MaxPool3D.cs @@ -6,6 +6,7 @@ namespace TorchSharp { + using Google.Protobuf.WellKnownTypes; using Modules; namespace Modules @@ -13,31 +14,32 @@ namespace Modules /// /// This class is used to represent a MaxPool3D module. /// - public sealed class MaxPool3d : torch.nn.Module + public sealed class MaxPool3d : ParamLessModule { - internal MaxPool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal MaxPool3d(long[] kernel_size, long[] stride = null, long[] padding = null, long[] dilation = null, bool ceil_mode = false) : base(nameof(MaxPool3d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; + this.dilation = dilation; + this.ceil_mode = ceil_mode; } - public override Tensor forward(Tensor tensor) + public override Tensor forward(Tensor input) { - var res = THSNN_MaxPool3d_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return torch.nn.functional.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode); } - public (Tensor Values, Tensor Indices) forward_with_indices(Tensor tensor) + public (Tensor Values, Tensor Indices) forward_with_indices(Tensor input) { - var res = THSNN_MaxPool3d_forward_with_indices(handle, tensor.Handle, out var indices); - if (res == IntPtr.Zero || indices == IntPtr.Zero) { torch.CheckForErrors(); } - return (new Tensor(res), new Tensor(indices)); + return torch.nn.functional.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] kernel_size { get; set; } + public long[] stride { get; set; } + public long[] padding { get; set; } + public long[] dilation { get; set; } + public bool ceil_mode { get; set; } } } @@ -48,55 +50,49 @@ public static partial class nn /// /// Applies a 3D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// The stride between elements within a sliding window, must be > 0. - /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. + /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. /// - public static MaxPool3d MaxPool3d(long kernelSize, long? stride = null, long? padding = null, long? dilation = null, bool ceilMode = false) + public static MaxPool3d MaxPool3d(long kernel_size, long? stride = null, long? padding = null, long? dilation = null, bool ceil_mode = false) { var pStride = stride.HasValue ? new long[] { stride.Value, stride.Value, stride.Value } : null; var pPadding = padding.HasValue ? new long[] { padding.Value, padding.Value, padding.Value } : null; var pDilation = dilation.HasValue ? new long[] { dilation.Value, dilation.Value, dilation.Value } : null; - return MaxPool3d(new long[] { kernelSize, kernelSize, kernelSize }, pStride, pPadding, pDilation, ceilMode); + return MaxPool3d(new long[] { kernel_size, kernel_size, kernel_size }, pStride, pPadding, pDilation, ceil_mode); } /// /// Applies a 3D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// The stride between elements within a sliding window, must be > 0. - /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. + /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. /// - public static MaxPool3d MaxPool3d((long, long, long) kernelSize, (long, long, long)? stride = null, (long, long, long)? padding = null, (long, long, long)? dilation = null, bool ceilMode = false) + public static MaxPool3d MaxPool3d((long, long, long) kernel_size, (long, long, long)? stride = null, (long, long, long)? padding = null, (long, long, long)? dilation = null, bool ceil_mode = false) { var pStride = stride.HasValue ? new long[] { stride.Value.Item1, stride.Value.Item2, stride.Value.Item3 } : null; var pPadding = padding.HasValue ? new long[] { padding.Value.Item1, padding.Value.Item2, padding.Value.Item3 } : null; var pDilation = dilation.HasValue ? new long[] { dilation.Value.Item1, dilation.Value.Item2, dilation.Value.Item3 } : null; - return MaxPool3d(new long[] { kernelSize.Item1, kernelSize.Item2, kernelSize.Item3 }, pStride, pPadding, pDilation, ceilMode); + return MaxPool3d(new long[] { kernel_size.Item1, kernel_size.Item2, kernel_size.Item3 }, pStride, pPadding, pDilation, ceil_mode); } /// /// Applies a 3D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. - /// The stride of the sliding window, must be > 0. Default value is kernel_size. + /// The size of the sliding window, must be > 0. + /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// The stride between elements within a sliding window, must be > 0. - /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. + /// If true, will use ceil instead of floor to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. /// - public static MaxPool3d MaxPool3d(long[] kernelSize, long[] strides = null, long[] padding = null, long[] dilation = null, bool ceilMode = false) + public static MaxPool3d MaxPool3d(long[] kernel_size, long[] stride = null, long[] padding = null, long[] dilation = null, bool ceil_mode = false) { - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, pPadding = padding, pDilation = dilation) { - var handle = THSNN_MaxPool3d_ctor((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), (IntPtr)pPadding, (padding == null ? 0 : padding.Length), (IntPtr)pDilation, (dilation == null ? 0 : dilation.Length), ceilMode, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new MaxPool3d(handle, boxedHandle); - } - } + return new MaxPool3d(kernel_size, stride, padding, dilation, ceil_mode); } public static partial class functional @@ -105,31 +101,18 @@ public static partial class functional /// Applies a 3D max pooling over an input signal composed of several input planes. /// /// The input tensor. - /// - /// + /// + /// /// /// /// /// - public static Tensor max_pool3d(Tensor input, long[] kernelSize, long[] strides = null, + public static Tensor max_pool3d(Tensor input, long[] kernel_size, long[] stride = null, long[] padding = null, long[] dilation = null, bool ceil_mode = false) { - strides = strides ?? kernelSize; - padding = padding ?? kernelSize.Select(x => 0L).ToArray(); - dilation = dilation ?? kernelSize.Select(x => 1L).ToArray(); - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, ppadding = padding, pdilation = dilation) { - var res = - THSTensor_max_pool3d(input.Handle, - (IntPtr)pkernelSize, kernelSize.Length, - (IntPtr)pstrides, strides.Length, - (IntPtr)ppadding, padding.Length, - (IntPtr)pdilation, dilation.Length, - ceil_mode); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - } + var ret = max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode); + ret.Indices.Dispose(); + return ret.Values; } /// @@ -142,30 +125,26 @@ public static Tensor max_pool3d(Tensor input, long[] kernelSize, long[] strides /// /// /// - public static (Tensor output, Tensor indices) max_pool3d_with_indices(Tensor input, long[] kernelSize, long[] strides = null, + public static (Tensor Values, Tensor Indices) max_pool3d_with_indices(Tensor input, long[] kernelSize, long[] strides = null, long[] padding = null, long[] dilation = null, bool ceil_mode = false) { - strides = strides ?? kernelSize; - padding = padding ?? kernelSize.Select(x => 0L).ToArray(); - dilation = dilation ?? kernelSize.Select(x => 1L).ToArray(); - IntPtr[] ptrArray; + strides ??= kernelSize; + padding ??= kernelSize.Select(x => 0L).ToArray(); + dilation ??= kernelSize.Select(x => 1L).ToArray(); + + unsafe { + fixed (long* pkernelSize = kernelSize, pstrides = strides, ppadding = padding, pdilation = dilation) { + var resOutput = THSTensor_max_pool3d_with_indices(input.Handle, + (IntPtr)pkernelSize, kernelSize.Length, + (IntPtr)pstrides, strides.Length, + (IntPtr)ppadding, padding.Length, + (IntPtr)pdilation, dilation.Length, + ceil_mode, out var resIndices); - using (var pa = new PinnedArray()) { - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, ppadding = padding, pdilation = dilation) { - THSTensor_max_pool3d_with_indices(input.Handle, - pa.CreateArray, - (IntPtr)pkernelSize, kernelSize.Length, - (IntPtr)pstrides, strides.Length, - (IntPtr)ppadding, padding.Length, - (IntPtr)pdilation, dilation.Length, - ceil_mode); - torch.CheckForErrors(); - } + if (resOutput == IntPtr.Zero || resIndices == IntPtr.Zero) { torch.CheckForErrors(); } + return (new Tensor(resOutput), new Tensor(resIndices)); } - ptrArray = pa.Array; } - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); } } } diff --git a/src/TorchSharp/NN/Pooling/MaxUnpool1d.cs b/src/TorchSharp/NN/Pooling/MaxUnpool1d.cs index 2d8d7e908..9c2afb04d 100644 --- a/src/TorchSharp/NN/Pooling/MaxUnpool1d.cs +++ b/src/TorchSharp/NN/Pooling/MaxUnpool1d.cs @@ -5,6 +5,7 @@ namespace TorchSharp { + using System.Runtime.CompilerServices; using Modules; namespace Modules @@ -12,21 +13,18 @@ namespace Modules /// /// This class is used to represent a MaxUnpool1D module. /// - public sealed class MaxUnpool1d : torch.nn.Module + public sealed class MaxUnpool1d : ParamLessModule { - internal MaxUnpool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal MaxUnpool1d(long kernel_size, long? stride = null, long? padding = null) : base(nameof(MaxUnpool1d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; } public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size = null) { - unsafe { - fixed (long* pOutSize = output_size) { - var res = THSNN_MaxUnpool1d_forward(handle, tensor.Handle, indices.Handle, (IntPtr)pOutSize); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - } + return torch.nn.functional.max_unpool1d(tensor, indices, kernel_size, stride, padding, output_size); } public new Tensor call(Tensor tensor, Tensor indices, long[] output_size = null) @@ -34,11 +32,9 @@ public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size return base.call(tensor, indices, output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long kernel_size { get; set; } + public long? stride { get; set; } + public long? padding { get; set; } } } @@ -47,26 +43,42 @@ public static partial class torch public static partial class nn { /// - /// Applies a 1D max pooling over an input signal composed of several input planes. + /// Computes a partial inverse of :class:`MaxPool1d`. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// - public static MaxUnpool1d MaxUnpool1d(long kernelSize, long? stride = null, long? padding = null) + public static MaxUnpool1d MaxUnpool1d(long kernel_size, long? stride = null, long? padding = null) { - var pStride = stride.HasValue ? new long[] { stride.Value } : null; - var pPadding = padding.HasValue ? new long[] { padding.Value } : null; - return MaxUnpool1d(new long[] { kernelSize }, pStride, pPadding); + return new MaxUnpool1d(kernel_size, stride, padding); } - private static MaxUnpool1d MaxUnpool1d(long[] kernelSize, long[] strides = null, long[] padding = null) + public static partial class functional { - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, pPadding = padding) { - var handle = THSNN_MaxUnpool1d_ctor((IntPtr)pkernelSize, (IntPtr)pstrides, (IntPtr)pPadding, out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new MaxUnpool1d(handle, boxedHandle); + /// + /// Applies a 1D max pooling over an input signal composed of several input planes. + /// + /// the input Tensor to invert + /// the indices given out by :class:`~torch.nn.MaxPool1d` + /// The size of the sliding window, must be > 0. + /// The stride of the sliding window, must be > 0. Default value is kernel_size. + /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 + /// (optional): The targeted output size + /// + public static Tensor max_unpool1d(Tensor input, Tensor indices, long kernel_size, long? stride = null, long? padding = null, long[] output_size = null) + { + long[] kernels = new[] { kernel_size }; + long[] strides = stride.HasValue ? new[] { stride.Value } : Array.Empty(); + long[] paddings = padding.HasValue ? new[] { padding.Value } : Array.Empty(); + output_size ??= Array.Empty(); + + unsafe { + fixed (long* pkernels = kernels, pstrides = strides, ppaddings = paddings, poutputSize = output_size) { + var res = THSTensor_max_unpool1d(input.Handle, indices.Handle, (IntPtr)pkernels, kernels.Length, (IntPtr)poutputSize, output_size.Length, (IntPtr)ppaddings, paddings.Length, (IntPtr)pstrides, strides.Length); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); + } } } } diff --git a/src/TorchSharp/NN/Pooling/MaxUnpool2d.cs b/src/TorchSharp/NN/Pooling/MaxUnpool2d.cs index 84e8c6cb3..f342049fb 100644 --- a/src/TorchSharp/NN/Pooling/MaxUnpool2d.cs +++ b/src/TorchSharp/NN/Pooling/MaxUnpool2d.cs @@ -5,6 +5,7 @@ namespace TorchSharp { + using Microsoft.VisualBasic; using Modules; namespace Modules @@ -12,21 +13,18 @@ namespace Modules /// /// This class is used to represent a MaxUnpool2D module. /// - public sealed class MaxUnpool2d : torch.nn.Module + public sealed class MaxUnpool2d : ParamLessModule { - internal MaxUnpool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal MaxUnpool2d(long[] kernel_size, long[] stride = null, long[] padding = null) : base(nameof(MaxUnpool2d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; } public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size = null) { - unsafe { - fixed (long* pOutSize = output_size) { - var res = THSNN_MaxUnpool2d_forward(handle, tensor.Handle, indices.Handle, (IntPtr)pOutSize, output_size == null ? 0 : output_size.Length); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - } + return torch.nn.functional.max_unpool2d(tensor, indices, kernel_size, stride, padding, output_size); } public new Tensor call(Tensor tensor, Tensor indices, long[] output_size = null) @@ -34,11 +32,9 @@ public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size return base.call(tensor, indices, output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] kernel_size { get; set; } + public long[] stride { get; set; } + public long[] padding { get; set; } } } @@ -49,47 +45,41 @@ public static partial class nn /// /// Applies a 2D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// - public static MaxUnpool2d MaxUnpool2d(long kernelSize, long? stride = null, long? padding = null) + public static MaxUnpool2d MaxUnpool2d(long kernel_size, long? stride = null, long? padding = null) { var pStride = stride.HasValue ? new long[] { stride.Value, stride.Value } : null; var pPadding = padding.HasValue ? new long[] { padding.Value, padding.Value } : null; - return MaxUnpool2d(new long[] { kernelSize, kernelSize }, pStride, pPadding); + return new MaxUnpool2d(new[] { kernel_size, kernel_size }, pStride, pPadding); } /// /// Applies a 2D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// - public static MaxUnpool2d MaxUnpool2d((long, long) kernelSize, (long, long)? stride = null, (long, long)? padding = null) + public static MaxUnpool2d MaxUnpool2d((long, long) kernel_size, (long, long)? stride = null, (long, long)? padding = null) { var pStride = stride.HasValue ? new long[] { stride.Value.Item1, stride.Value.Item2 } : null; var pPadding = padding.HasValue ? new long[] { padding.Value.Item1, padding.Value.Item2 } : null; - return MaxUnpool2d(new long[] { kernelSize.Item1, kernelSize.Item2 }, pStride, pPadding); + return new MaxUnpool2d(new[] { kernel_size.Item1, kernel_size.Item2 }, pStride, pPadding); } /// /// Applies a 2D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. - /// The stride of the sliding window, must be > 0. Default value is kernel_size. + /// The size of the sliding window, must be > 0. + /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// - public static MaxUnpool2d MaxUnpool2d(long[] kernelSize, long[] strides = null, long[] padding = null) + public static MaxUnpool2d MaxUnpool2d(long[] kernel_size, long[] stride = null, long[] padding = null) { - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, pPadding = padding) { - var handle = THSNN_MaxUnpool2d_ctor((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), (IntPtr)pPadding, (padding == null ? 0 : padding.Length), out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new MaxUnpool2d(handle, boxedHandle); - } - } + return new MaxUnpool2d(kernel_size, stride, padding); } public static partial class functional @@ -97,16 +87,22 @@ public static partial class functional /// /// Computes a partial inverse of MaxPool2d. /// - /// The input tensor. - /// - /// + /// the input Tensor to invert + /// the indices given out by :class:`~torch.nn.MaxPool2d` + /// The size of the sliding window, must be > 0. + /// The stride of the sliding window, must be > 0. Default value is kernel_size. + /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 + /// (optional): The targeted output size /// - public static Tensor max_unpool2d(Tensor input, Tensor indices, long[] outputSize) + public static Tensor max_unpool2d(Tensor input, Tensor indices, long[] kernel_size, long[] stride = null, long[] padding = null, long[] output_size = null) { + stride ??= Array.Empty(); + padding ??= Array.Empty(); + output_size ??= Array.Empty(); + unsafe { - fixed (long* poutputSize = outputSize) { - var res = THSTensor_maxunpool2d(input.Handle, indices.Handle, - (IntPtr)poutputSize, outputSize.Length); + fixed (long* pkernels = kernel_size, pstrides = stride, ppaddings = padding, poutputSize = output_size) { + var res = THSTensor_max_unpool2d(input.Handle, indices.Handle, (IntPtr)pkernels, kernel_size.Length, (IntPtr)poutputSize, output_size.Length, (IntPtr)ppaddings, padding.Length, (IntPtr)pstrides, stride.Length); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } diff --git a/src/TorchSharp/NN/Pooling/MaxUnpool3d.cs b/src/TorchSharp/NN/Pooling/MaxUnpool3d.cs index a5473d8d6..33abc0429 100644 --- a/src/TorchSharp/NN/Pooling/MaxUnpool3d.cs +++ b/src/TorchSharp/NN/Pooling/MaxUnpool3d.cs @@ -5,28 +5,26 @@ namespace TorchSharp { + using Microsoft.VisualBasic; using Modules; namespace Modules { /// - /// This class is used to represent a MaxUnpool3d module. + /// This class is used to represent a MaxUnpool3D module. /// - public sealed class MaxUnpool3d : torch.nn.Module + public sealed class MaxUnpool3d : ParamLessModule { - internal MaxUnpool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal MaxUnpool3d(long[] kernel_size, long[] stride = null, long[] padding = null) : base(nameof(MaxUnpool3d)) { + this.kernel_size = kernel_size; + this.stride = stride; + this.padding = padding; } public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size = null) { - unsafe { - fixed (long* pOutSize = output_size) { - var res = THSNN_MaxUnpool3d_forward(handle, tensor.Handle, indices.Handle, (IntPtr)pOutSize, output_size == null ? 0 : output_size.Length); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - } + return torch.nn.functional.max_unpool3d(tensor, indices, kernel_size, stride, padding, output_size); } public new Tensor call(Tensor tensor, Tensor indices, long[] output_size = null) @@ -34,11 +32,9 @@ public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size return base.call(tensor, indices, output_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long[] kernel_size { get; set; } + public long[] stride { get; set; } + public long[] padding { get; set; } } } @@ -47,69 +43,66 @@ public static partial class torch public static partial class nn { /// - /// Applies a 2D max pooling over an input signal composed of several input planes. + /// Applies a 3D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// - public static MaxUnpool3d MaxUnpool3d(long kernelSize, long? stride = null, long? padding = null) + public static MaxUnpool3d MaxUnpool3d(long kernel_size, long? stride = null, long? padding = null) { var pStride = stride.HasValue ? new long[] { stride.Value, stride.Value, stride.Value } : null; var pPadding = padding.HasValue ? new long[] { padding.Value, padding.Value, padding.Value } : null; - return MaxUnpool3d(new long[] { kernelSize, kernelSize, kernelSize }, pStride, pPadding); + return new MaxUnpool3d(new[] { kernel_size, kernel_size, kernel_size }, pStride, pPadding); } /// - /// Applies a 2D max pooling over an input signal composed of several input planes. + /// Applies a 3D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. + /// The size of the sliding window, must be > 0. /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// - public static MaxUnpool3d MaxUnpool3d((long, long, long) kernelSize, (long, long, long)? stride = null, (long, long, long)? padding = null) + public static MaxUnpool3d MaxUnpool3d((long, long, long) kernel_size, (long, long, long)? stride = null, (long, long, long)? padding = null) { var pStride = stride.HasValue ? new long[] { stride.Value.Item1, stride.Value.Item2, stride.Value.Item3 } : null; var pPadding = padding.HasValue ? new long[] { padding.Value.Item1, padding.Value.Item2, padding.Value.Item3 } : null; - return MaxUnpool3d(new long[] { kernelSize.Item1, kernelSize.Item2, kernelSize.Item3 }, pStride, pPadding); + return new MaxUnpool3d(new[] { kernel_size.Item1, kernel_size.Item2, kernel_size.Item3 }, pStride, pPadding); } /// - /// Applies a 2D max pooling over an input signal composed of several input planes. + /// Applies a 3D max pooling over an input signal composed of several input planes. /// - /// The size of the sliding window, must be > 0. - /// The stride of the sliding window, must be > 0. Default value is kernel_size. + /// The size of the sliding window, must be > 0. + /// The stride of the sliding window, must be > 0. Default value is kernel_size. /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 /// - public static MaxUnpool3d MaxUnpool3d(long[] kernelSize, long[] strides = null, long[] padding = null) + public static MaxUnpool3d MaxUnpool3d(long[] kernel_size, long[] stride = null, long[] padding = null) { - unsafe { - fixed (long* pkernelSize = kernelSize, pstrides = strides, pPadding = padding) { - var handle = THSNN_MaxUnpool3d_ctor((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), (IntPtr)pPadding, (padding == null ? 0 : padding.Length), out var boxedHandle); - if (handle == IntPtr.Zero) { torch.CheckForErrors(); } - return new MaxUnpool3d(handle, boxedHandle); - } - } + return new MaxUnpool3d(kernel_size, stride, padding); } + public static partial class functional { /// /// Computes a partial inverse of MaxPool3d. /// - /// The input tensor. - /// - /// - /// - /// + /// the input Tensor to invert + /// the indices given out by :class:`~torch.nn.MaxPool3d` + /// The size of the sliding window, must be > 0. + /// The stride of the sliding window, must be > 0. Default value is kernel_size. + /// Implicit negative infinity padding to be added on both sides, must be >= 0 and less than or equal to kernel_size / 2 + /// (optional): The targeted output size /// - public static Tensor max_unpool3d(Tensor input, Tensor indices, long[] outputSize, long[] strides, long[] padding) + public static Tensor max_unpool3d(Tensor input, Tensor indices, long[] kernel_size, long[] stride = null, long[] padding = null, long[] output_size = null) { + stride ??= Array.Empty(); + padding ??= Array.Empty(); + output_size ??= Array.Empty(); + unsafe { - fixed (long* poutputSize = outputSize, pstrides = strides, ppadding = padding) { - var res = THSTensor_maxunpool3d(input.Handle, indices.Handle, - (IntPtr)poutputSize, outputSize.Length, - (IntPtr)pstrides, strides.Length, - (IntPtr)ppadding, padding.Length); + fixed (long* pkernels = kernel_size, pstrides = stride, ppaddings = padding, poutputSize = output_size) { + var res = THSTensor_max_unpool3d(input.Handle, indices.Handle, (IntPtr)pkernels, kernel_size.Length, (IntPtr)poutputSize, output_size.Length, (IntPtr)ppaddings, padding.Length, (IntPtr)pstrides, stride.Length); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } diff --git a/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs b/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs index f820cae57..6d39ecb1c 100644 --- a/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs +++ b/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs @@ -11,14 +11,13 @@ namespace Modules /// /// This class is used to represent a ChannelShuffle module. /// - public sealed class ChannelShuffle : torch.nn.Module + public sealed class ChannelShuffle : ParamLessModule { internal ChannelShuffle(long groups) : base(nameof(ChannelShuffle)) { this.groups = groups; } - private long groups; - + public override Tensor forward(Tensor tensor) { return tensor.channel_shuffle(groups); @@ -29,11 +28,7 @@ public override string GetName() return typeof(ChannelShuffle).Name; } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + public long groups { get; set; } } } diff --git a/src/TorchSharp/NN/Unflatten.cs b/src/TorchSharp/NN/Unflatten.cs index 71c7b6a23..eeb91e1a9 100644 --- a/src/TorchSharp/NN/Unflatten.cs +++ b/src/TorchSharp/NN/Unflatten.cs @@ -12,24 +12,22 @@ namespace Modules /// /// This class is used to represent an unflattening operation. /// - public sealed class Unflatten : torch.nn.Module + public sealed class Unflatten : ParamLessModule { - internal Unflatten(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal Unflatten(long dim, long[] unflattened_size) : base(nameof(Unflatten)) { + this.dim = dim; + this.unflattened_size = unflattened_size; } public override Tensor forward(Tensor tensor) { - var res = THSNN_Unflatten_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return tensor.unflatten(dim, unflattened_size); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; + + public long dim { get; set; } + public long[] unflattened_size { get; set; } } } @@ -41,17 +39,11 @@ public static partial class nn /// Unflattens a tensor dim expanding it to a desired shape. For use with Sequential. /// /// Dimension to be unflattened - /// New shape of the unflattened dimension + /// New shape of the unflattened dimension /// - public static Unflatten Unflatten(long dim, long[] unflattenedSize) + public static Unflatten Unflatten(long dim, long[] unflattened_size) { - unsafe { - fixed (long* pUnflattenedSize = unflattenedSize) { - var handle = THSNN_Unflatten_ctor(dim, (IntPtr)pUnflattenedSize, unflattenedSize.Length, out var boxedHandle); - if (handle == IntPtr.Zero) { CheckForErrors(); } - return new Unflatten(handle, boxedHandle); - } - } + return new Unflatten(dim, unflattened_size); } } } diff --git a/src/TorchSharp/NN/Unfold.cs b/src/TorchSharp/NN/Unfold.cs index 7575c6169..050479623 100644 --- a/src/TorchSharp/NN/Unfold.cs +++ b/src/TorchSharp/NN/Unfold.cs @@ -6,16 +6,15 @@ #nullable enable namespace TorchSharp { - using System.Security.Cryptography; using Modules; namespace Modules { - public sealed class Unfold : torch.nn.Module + public sealed class Unfold : ParamLessModule { internal Unfold((long, long) kernel_size, (long, long) dilation, (long, long) padding, (long, long) stride) : base(nameof(Unfold)) { - this.kernelSize = kernel_size; + this.kernel_size = kernel_size; this.dilation = dilation; this.padding = padding; this.stride = stride; @@ -23,19 +22,13 @@ internal Unfold((long, long) kernel_size, (long, long) dilation, (long, long) pa public override Tensor forward(Tensor tensor) { - return torch.nn.functional.unfold(tensor, kernelSize, dilation, padding, stride); + return torch.nn.functional.unfold(tensor, kernel_size, dilation, padding, stride); } - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; - - private (long, long) kernelSize; - private (long, long) dilation; - private (long, long) padding; - private (long, long) stride; + public (long, long) kernel_size { get; set; } + public (long, long) dilation { get; set; } + public (long, long) padding { get; set; } + public (long, long) stride { get; set; } } } @@ -51,7 +44,7 @@ public static partial class nn /// Implicit zero padding to be added on both sides of input. /// The stride of the sliding blocks in the input spatial dimensions. /// Currently, only 4-D input tensors (batched image-like tensors) are supported. - public unsafe static Unfold Unfold(long kernel_size, long dilation = 1, long padding = 0, long stride = 1) + public static Unfold Unfold(long kernel_size, long dilation = 1, long padding = 0, long stride = 1) { return new Unfold((kernel_size, kernel_size), (dilation, dilation), (padding, padding), (stride, stride)); } @@ -64,7 +57,7 @@ public unsafe static Unfold Unfold(long kernel_size, long dilation = 1, long pad /// Implicit zero padding to be added on both sides of input. /// The stride of the sliding blocks in the input spatial dimensions. /// Currently, only 4-D input tensors (batched image-like tensors) are supported. - public unsafe static Unfold Unfold((long, long) kernel_size, (long, long)? dilation = null, (long, long)? padding = null, (long, long)? stride = null) + public static Unfold Unfold((long, long) kernel_size, (long, long)? dilation = null, (long, long)? padding = null, (long, long)? stride = null) { dilation ??= (1, 1); stride ??= (1, 1); @@ -83,7 +76,7 @@ public static partial class functional /// A parameter that controls the stride of elements within the neighborhood. /// Implicit zero padding to be added on both sides of input. /// The stride of the sliding blocks in the input spatial dimensions. - public unsafe static Tensor unfold(Tensor input, long kernel_size, long dilation = 1, long padding = 0, long stride = 1) + public static Tensor unfold(Tensor input, long kernel_size, long dilation = 1, long padding = 0, long stride = 1) { var res = THSNN_unfold(input.Handle, kernel_size, kernel_size, stride, stride, padding, padding, dilation, dilation); if (res == IntPtr.Zero) { torch.CheckForErrors(); } @@ -98,7 +91,7 @@ public unsafe static Tensor unfold(Tensor input, long kernel_size, long dilation /// A parameter that controls the stride of elements within the neighborhood. /// Implicit zero padding to be added on both sides of input. /// The stride of the sliding blocks in the input spatial dimensions. - public unsafe static Tensor unfold(Tensor input, (long, long) kernel_size, (long, long)? dilation = null, (long, long)? padding = null, (long, long)? stride = null) + public static Tensor unfold(Tensor input, (long, long) kernel_size, (long, long)? dilation = null, (long, long)? padding = null, (long, long)? stride = null) { dilation ??= (1, 1); stride ??= (1, 1); diff --git a/src/TorchSharp/NN/Upsample.cs b/src/TorchSharp/NN/Upsample.cs index f313677f8..2f3b43707 100644 --- a/src/TorchSharp/NN/Upsample.cs +++ b/src/TorchSharp/NN/Upsample.cs @@ -8,6 +8,40 @@ namespace TorchSharp { using Modules; + namespace Modules + { + /// + /// This class is used to represent an Upsample module. + /// + public sealed class Upsample : ParamLessModule + { + internal Upsample(long[]? size, double[]? scale_factor, UpsampleMode mode, bool? align_corners, bool? recompute_scale_factor) : base(nameof(Upsample)) + { + this.size = size; + this.scale_factor = scale_factor; + this.mode = mode; + this.align_corners = align_corners; + this.recompute_scale_factor = recompute_scale_factor; + } + + /// + /// Forward pass. + /// + /// Input tensor + /// + public override Tensor forward(Tensor input) + { + return torch.nn.functional.interpolate(input, size, scale_factor, (InterpolationMode)mode, align_corners, recompute_scale_factor ?? false); + } + + public long[]? size { get; set; } + public double[]? scale_factor { get; set; } + public UpsampleMode mode { get; set; } + public bool? align_corners { get; set; } + public bool? recompute_scale_factor { get; set; } + } + } + public static partial class torch { public static partial class nn @@ -22,19 +56,11 @@ public static partial class nn /// The upsampling algorithm: one of 'nearest', 'linear', 'bilinear', 'bicubic' and 'trilinear'. Default: 'nearest' /// If true, the corner pixels of the input and output tensors are aligned, and thus preserving the values at those pixels. /// This only has effect when mode is 'linear', 'bilinear', or 'trilinear'. Default: false + /// recompute the scale_factor for use in the interpolation calculation. If `recompute_scale_factor` is ``True``, then `scale_factor` must be passed in and `scale_factor` is used to compute the output `size`. The computed output `size` will be used to infer new scales for the interpolation. Note that when `scale_factor` is floating-point, it may differ from the recomputed `scale_factor` due to rounding and precision issues. If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will be used directly for interpolation. /// - public static Upsample Upsample(long[]? size = null, double[]? scale_factor = null, UpsampleMode mode = UpsampleMode.Nearest, bool? align_corners = null) + public static Upsample Upsample(long[]? size = null, double[]? scale_factor = null, UpsampleMode mode = UpsampleMode.Nearest, bool? align_corners = null, bool? recompute_scale_factor = null) { - unsafe { - fixed (long* psize = size) { - fixed (double* pSF = scale_factor) { - byte ac = (byte)((align_corners.HasValue) ? (align_corners.Value ? 1 : 2) : 0); - var res = THSNN_Upsample_ctor((IntPtr)psize, size is null ? 0 : size.Length, (IntPtr)pSF, scale_factor is null ? 0 : scale_factor.Length, (byte)mode, ac, out var boxedHandle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Upsample(res, boxedHandle, size, scale_factor, mode, align_corners); - } - } - } + return new Upsample(size, scale_factor, mode, align_corners, recompute_scale_factor); } public static partial class functional @@ -44,21 +70,18 @@ public static partial class functional /// The input data is assumed to be of the form minibatch x channels x[optional depth] x[optional height] x width. /// Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. /// - /// Input tensor + /// Input tensor /// Output spatial sizes /// Multiplier for spatial size. Has to match input size /// The upsampling algorithm: one of 'nearest', 'linear', 'bilinear', 'bicubic' and 'trilinear'. Default: 'nearest' /// If true, the corner pixels of the input and output tensors are aligned, and thus preserving the values at those pixels. /// This only has effect when mode is 'linear', 'bilinear', or 'trilinear'. Default: false /// - public static Tensor upsample(Tensor x, long[]? size = null, double[]? scale_factor = null, UpsampleMode mode = UpsampleMode.Nearest, bool align_corners = false) + public static Tensor upsample(Tensor input, long[]? size = null, double[]? scale_factor = null, UpsampleMode mode = UpsampleMode.Nearest, bool align_corners = false) { - using (var d = nn.Upsample(size, scale_factor, mode, align_corners)) { - return d.call(x); - } + return interpolate(input, size, scale_factor, (InterpolationMode)mode, align_corners); } - /// /// Upsamples the input, using nearest neighbours’ pixel values. /// @@ -198,54 +221,4 @@ public static Tensor upsample_nearest3d(Tensor input, long[]? outputSizes = null } } } - - namespace Modules - { - /// - /// This class is used to represent an Upsample module. - /// - public sealed class Upsample : torch.nn.Module - { - internal Upsample(IntPtr handle, IntPtr boxedHandle, long[]? size, double[]? scale_factor, UpsampleMode mode, bool? align_corners) : base(handle, boxedHandle) - { - this._size = size; - this._scale_factor = scale_factor; - this.mode = mode; - this.align_corners = align_corners; - } - - /// - /// Forward pass. - /// - /// Input tensor - /// - public override Tensor forward(Tensor tensor) - { - var res = THSNN_Upsample_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - - public UpsampleMode mode { get; private set; } - - public bool? align_corners { get; private set; } - - public ReadOnlySpan size { - get { return _size is null ? null : new ReadOnlySpan(_size!); } - } - - public ReadOnlySpan scale_factor { - get { return _scale_factor is null ? null : new ReadOnlySpan(_scale_factor!); } - } - - private long[]? _size; - private double[]? _scale_factor; - - // Rather than spending cycles only to discover that this module has neither - // parameters nor buffers, just shortcut the move completely. - protected internal override nn.Module _to(Device device, ScalarType dtype) => this; - protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; - protected internal override nn.Module _to(ScalarType dtype) => this; - } - } } diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs index 45fca3604..57ac023fc 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs @@ -50,41 +50,14 @@ internal static extern IntPtr THSNN_custom_module( // align_corners -- 0=None, 1=true, 2=false internal static extern IntPtr THSNN_grid_sample(IntPtr input, IntPtr grid, byte mode, byte padding_mode, byte align_corners); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AlphaDropout_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AlphaDropout_ctor(double p, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_alpha_dropout(IntPtr input, double p, [MarshalAs(UnmanagedType.U1)] bool training, [MarshalAs(UnmanagedType.U1)] bool inplace); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Bilinear_forward(torch.nn.Module.HType module, IntPtr input1, IntPtr input2); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Bilinear_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_Bilinear_set_bias(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Bilinear_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_Bilinear_set_weight(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Bilinear_ctor(long in1_features, long in2_features, long output_size, [MarshalAs(UnmanagedType.U1)] bool bias, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_functional_bilinear(IntPtr input1, IntPtr input2, IntPtr weights, IntPtr bias); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_CosineSimilarity_ctor(long dim, double eps, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_CosineSimilarity_forward(torch.nn.Module.HType module, IntPtr input1, IntPtr input2); + internal static extern IntPtr THSNN_cosine_similarity(IntPtr input1, IntPtr input2, long dim, double eps); [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_dropout(IntPtr input, double p, [MarshalAs(UnmanagedType.U1)] bool training, [MarshalAs(UnmanagedType.U1)] bool inplace); @@ -151,12 +124,6 @@ internal static extern IntPtr THSNN_custom_module( [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_EmbeddingBag_from_pretrained(IntPtr embeddings, [MarshalAs(UnmanagedType.U1)] bool freeze, double max_norm, [MarshalAs(UnmanagedType.U1)] bool hasMN, double norm_type, [MarshalAs(UnmanagedType.U1)] bool scale_grad_by_freq, long mode, [MarshalAs(UnmanagedType.U1)] bool sparse, [MarshalAs(UnmanagedType.U1)] bool include_last_offset, long padding_idx, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_FeatureAlphaDropout_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_FeatureAlphaDropout_ctor(double p, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_feature_alpha_dropout(IntPtr input, double p, [MarshalAs(UnmanagedType.U1)] bool training, [MarshalAs(UnmanagedType.U1)] bool inplace); @@ -441,16 +408,10 @@ internal static extern IntPtr THSNN_custom_module( internal static extern IntPtr THSNN_PairwiseDistance_ctor(double p, double eps, [MarshalAs(UnmanagedType.U1)] bool keep_dim, out IntPtr pBoxedModule); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_PixelUnshuffle_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_PixelUnshuffle_ctor(long downscaleFactor, out IntPtr pBoxedModule); + internal static extern IntPtr THSNN_pixel_unshuffle(IntPtr tensor, long downscale_factor); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_PixelShuffle_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_PixelShuffle_ctor(long upscaleFactor, out IntPtr pBoxedModule); + internal static extern IntPtr THSNN_pixel_shuffle(IntPtr tensor, long upscale_factor); [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_GRUCell_forward(torch.nn.Module.HType module, IntPtr input, IntPtr h_0); @@ -548,150 +509,6 @@ internal static extern IntPtr THSNN_custom_module( [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_affine_grid(IntPtr theta, IntPtr size, int size_len, [MarshalAs(UnmanagedType.U1)] bool align_corners); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_CELU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_CELU_ctor(double alpha, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LeakyReLU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LeakyReLU_ctor(double negative_slope, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LogSoftmax_forward(torch.nn.Module.HType handle, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LogSoftmax_ctor(long dim, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm3d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm3d_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm3d_set_bias(torch.nn.Module.HType module, IntPtr bias); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm3d_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm3d_set_weight(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm3d_reset_stats(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm3d_get_mean(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm3d_get_var(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm3d_get_batches(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm3d_set_mean(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm3d_set_var(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm3d_ctor(long features, double eps, double momentum, [MarshalAs(UnmanagedType.U1)] bool affine, [MarshalAs(UnmanagedType.U1)] bool track_running_stats, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LayerNorm_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LayerNorm_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_LayerNorm_set_bias(torch.nn.Module.HType module, IntPtr bias); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LayerNorm_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_LayerNorm_set_weight(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LayerNorm_ctor(IntPtr norm_shape, long norm_shape_len, double eps, [MarshalAs(UnmanagedType.U1)] bool elementwise_affine, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm2d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm2d_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm2d_set_bias(torch.nn.Module.HType module, IntPtr bias); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm2d_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm2d_set_weight(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm2d_reset_stats(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm2d_get_mean(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm2d_get_var(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm2d_get_batches(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm2d_set_mean(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm2d_set_var(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm2d_ctor(long features, double eps, double momentum, [MarshalAs(UnmanagedType.U1)] bool affine, [MarshalAs(UnmanagedType.U1)] bool track_running_stats, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm1d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm1d_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm1d_set_bias(torch.nn.Module.HType module, IntPtr bias); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm1d_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm1d_set_weight(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm1d_reset_stats(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm1d_get_mean(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm1d_get_var(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm1d_get_batches(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm1d_set_mean(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_InstanceNorm1d_set_var(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_InstanceNorm1d_ctor(long features, double eps, double momentum, [MarshalAs(UnmanagedType.U1)] bool affine, [MarshalAs(UnmanagedType.U1)] bool track_running_stats, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_Conv1d_forward(torch.nn.Module.HType module, IntPtr tensor); @@ -791,60 +608,6 @@ internal static extern IntPtr THSNN_custom_module( [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_ConvTranspose3d_ctor_1(long inputChannel, long outputChannel, long kernelSizeX, long kernelSizeY, long kernelSizeZ, long strideX, long strideY, long strideZ, long paddingX, long paddingY, long paddingZ, long outputPaddingX, long outputPaddingY, long outputPaddingZ, long dilationX, long dilationY, long dilationZ, long paddingMode, long groups, [MarshalAs(UnmanagedType.U1)] bool bias, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm1d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm1d_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm1d_set_bias(torch.nn.Module.HType module, IntPtr bias); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm1d_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm1d_set_weight(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm1d_reset_stats(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm1d_get_mean(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm1d_get_var(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm1d_get_batches(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm1d_set_mean(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm1d_set_var(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm1d_ctor(long features, double eps, double momentum, [MarshalAs(UnmanagedType.U1)] bool affine, [MarshalAs(UnmanagedType.U1)] bool track_running_stats, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_GroupNorm_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_GroupNorm_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_GroupNorm_set_bias(torch.nn.Module.HType module, IntPtr bias); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_GroupNorm_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_GroupNorm_set_weight(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_GroupNorm_ctor(long num_groups, long num_channels, double eps, [MarshalAs(UnmanagedType.U1)] bool affine, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_Unflatten_forward(torch.nn.Module.HType module, IntPtr tensor); @@ -887,443 +650,20 @@ internal static extern IntPtr THSNN_custom_module( [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_ConvTranspose2d_ctor_1(long inputChannel, long outputChannel, long kernelSizeX, long kernelSizeY, long strideX, long strideY, long paddingX, long paddingY, long outputPaddingX, long outputPaddingY, long dilationX, long dilationY, long paddingMode, long groups, [MarshalAs(UnmanagedType.U1)] bool bias, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm2d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm2d_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm2d_set_bias(torch.nn.Module.HType module, IntPtr bias); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm2d_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm2d_set_weight(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm2d_reset_stats(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm2d_get_mean(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm2d_get_var(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm2d_get_batches(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm2d_set_mean(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm2d_set_var(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm2d_ctor(long features, double eps, double momentum, [MarshalAs(UnmanagedType.U1)] bool affine, [MarshalAs(UnmanagedType.U1)] bool track_running_stats, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm3d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm3d_bias(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm3d_set_bias(torch.nn.Module.HType module, IntPtr bias); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm3d_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm3d_set_weight(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm3d_reset_stats(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm3d_get_mean(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm3d_get_var(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm3d_get_batches(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm3d_set_mean(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_BatchNorm3d_set_var(torch.nn.Module.HType module, IntPtr weight); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_BatchNorm3d_ctor(long features, double eps, double momentum, [MarshalAs(UnmanagedType.U1)] bool affine, [MarshalAs(UnmanagedType.U1)] bool track_running_stats, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool1d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool1d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool1d_ctor(IntPtr pkernelSize, IntPtr pStrides, IntPtr pPadding, IntPtr pDilation, [MarshalAs(UnmanagedType.U1)] bool ceilMode, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxUnpool3d_forward(torch.nn.Module.HType module, IntPtr tensor, IntPtr indices, IntPtr outSize, int outputSizeLength); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxUnpool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr pPadding, int paddingLength, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ELU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ELU_ctor(double alpha, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_GELU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_GELU_ctor(out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_GLU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_GLU_ctor(long dim, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Hardshrink_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Hardshrink_ctor(double lambd, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Hardtanh_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Hardtanh_ctor(double min_val, double max_val, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Mish_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Mish_ctor(out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_PReLU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_PReLU_ctor(long nparams, double init, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_PReLU_weight(torch.nn.Module.HType module); - - [DllImport("LibTorchSharp")] - internal static extern void THSNN_PReLU_set_weight(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReLU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReLU_ctor([MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReLU6_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReLU6_ctor([MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_RReLU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_RReLU_ctor(double lower, double upper, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool casual); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_SELU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_SELU_ctor([MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Sigmoid_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Sigmoid_ctor(out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_SiLU_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_SiLU_ctor(out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softmax_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softmax_ctor(long dim, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softmax2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softmax2d_ctor(out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softmin_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softmin_ctor(long dim, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softplus_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softplus_ctor(double beta, double threshold, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_Softshrink_forward(torch.nn.Module.HType module, IntPtr tensor); [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_Softshrink_ctor(double lambd, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softsign_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Softsign_ctor(out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Tanh_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Tanh_ctor(out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Tanhshrink_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Tanhshrink_ctor(out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_Threshold_forward(torch.nn.Module.HType module, IntPtr tensor); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_Threshold_ctor(double threshold, double value, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LocalResponseNorm_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LocalResponseNorm_ctor(long size, double alpha, double beta, double k, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad1d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad1d_ctor(double value, long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad1d_ctor_tuple(double value, long padding_left, long padding_right, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad2d_ctor(double value, long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad2d_ctor_tuple(double value, long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad3d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad3d_ctor(double value, long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad3d_ctor_tuple(double value, long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad1d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad1d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad1d_ctor_tuple(long padding_left, long padding_right, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad2d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad3d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad3d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad3d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad1d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad1d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad1d_ctor_tuple(long padding_left, long padding_right, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad2d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad3d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad3d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad3d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ZeroPad2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ZeroPad2d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ZeroPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveAvgPool1d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveAvgPool1d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveAvgPool2d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveAvgPool2d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveAvgPool3d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveAvgPool3d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveMaxPool1d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveMaxPool1d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveMaxPool2d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveMaxPool2d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveMaxPool3d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AdaptiveMaxPool3d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool1d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool1d_ctor(IntPtr pkernelSize, IntPtr pstrides, IntPtr ppadding, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool2d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool3d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_FractionalMaxPool2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_FractionalMaxPool2d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_FractionalMaxPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pOutputSize, int sizeLength, IntPtr pOutputRatio, int ratioLength, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_FractionalMaxPool3d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_FractionalMaxPool3d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_FractionalMaxPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pOutputSize, int sizeLength, IntPtr pOutputRatio, int ratioLength, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LPPool1d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LPPool1d_ctor(double norm_type, IntPtr pkernelSize, IntPtr pstrides, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LPPool2d_forward(IntPtr module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_LPPool2d_ctor(double norm_type, IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool2d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr pPadding, int paddingLength, IntPtr pDilation, int dilationLength, [MarshalAs(UnmanagedType.U1)] bool ceilMode, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool3d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool3d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr pPadding, int paddingLength, IntPtr pDilation, int dilationLength, [MarshalAs(UnmanagedType.U1)] bool ceilMode, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxUnpool1d_forward(torch.nn.Module.HType module, IntPtr tensor, IntPtr indices, IntPtr outSize); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxUnpool1d_ctor(IntPtr pkernelSize, IntPtr pStrides, IntPtr pPadding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxUnpool2d_forward(torch.nn.Module.HType module, IntPtr tensor, IntPtr indices, IntPtr outSize, int outputSizeLength); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_MaxUnpool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr pPadding, int paddingLength, out IntPtr pBoxedModule); + internal static extern IntPtr THSNN_Threshold_ctor(double threshold, double value, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule); } #pragma warning restore CA2101 } diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs index c82b659a3..3055e8c83 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs @@ -64,56 +64,49 @@ internal static extern IntPtr THSTensor_conv_transpose3d(IntPtr input, IntPtr we long groups); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_max_pool1d(IntPtr input, + internal static extern IntPtr THSTensor_max_pool1d_with_indices(IntPtr input, IntPtr kernelSize, int kernelSizeLength, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, IntPtr dilation, int dilationLength, - [MarshalAs(UnmanagedType.U1)] bool ceil_mode); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode, out IntPtr indices); [DllImport("LibTorchSharp")] - internal static extern void THSTensor_max_pool1d_with_indices(IntPtr input, AllocatePinnedArray allocator, - IntPtr kernelSize, int kernelSizeLength, - IntPtr strides, int stridesLength, - IntPtr padding, int paddingLength, - IntPtr dilation, int dilationLength, - [MarshalAs(UnmanagedType.U1)] bool ceil_mode); + internal static extern IntPtr THSTensor_max_pool2d_with_indices(IntPtr input, + IntPtr kernelSize, int kernelSizeLength, + IntPtr strides, int stridesLength, + IntPtr padding, int paddingLength, + IntPtr dilation, int dilationLength, + [MarshalAs(UnmanagedType.U1)] bool ceil_mode, out IntPtr indices); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_max_pool2d(IntPtr input, + internal static extern IntPtr THSTensor_max_pool3d_with_indices(IntPtr input, IntPtr kernelSize, int kernelSizeLength, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, IntPtr dilation, int dilationLength, - [MarshalAs(UnmanagedType.U1)] bool ceil_mode); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode, out IntPtr indices); [DllImport("LibTorchSharp")] - internal static extern void THSTensor_max_pool2d_with_indices(IntPtr input, AllocatePinnedArray allocator, - IntPtr kernelSize, int kernelSizeLength, - IntPtr strides, int stridesLength, - IntPtr padding, int paddingLength, - IntPtr dilation, int dilationLength, - [MarshalAs(UnmanagedType.U1)] bool ceil_mode); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_max_pool3d(IntPtr input, - IntPtr kernelSize, int kernelSizeLength, - IntPtr strides, int stridesLength, - IntPtr padding, int paddingLength, - IntPtr dilation, int dilationLength, - [MarshalAs(UnmanagedType.U1)] bool ceil_mode); + internal static extern IntPtr THSTensor_max_unpool1d(IntPtr tensor, IntPtr indices, + IntPtr kernelSize, int kernelSizeLength, + IntPtr outputSize, int outputSizeLength, + IntPtr padding, int paddingLength, + IntPtr strides, int stridesLength); [DllImport("LibTorchSharp")] - internal static extern void THSTensor_max_pool3d_with_indices(IntPtr input, AllocatePinnedArray allocator, + internal static extern IntPtr THSTensor_max_unpool2d(IntPtr tensor, IntPtr indices, IntPtr kernelSize, int kernelSizeLength, - IntPtr strides, int stridesLength, + IntPtr outputSize, int outputSizeLength, IntPtr padding, int paddingLength, - IntPtr dilation, int dilationLength, - [MarshalAs(UnmanagedType.U1)] bool ceil_mode); + IntPtr strides, int stridesLength); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_maxunpool3d(IntPtr input, IntPtr indices, IntPtr outputSize, int outputSizeLength, IntPtr strides, int stridesLength, - IntPtr padding, int paddingLength); + internal static extern IntPtr THSTensor_max_unpool3d(IntPtr tensor, IntPtr indices, + IntPtr kernelSize, int kernelSizeLength, + IntPtr outputSize, int outputSizeLength, + IntPtr padding, int paddingLength, + IntPtr strides, int stridesLength); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_avg_pool1d(IntPtr input, @@ -129,7 +122,7 @@ internal static extern IntPtr THSTensor_avg_pool2d(IntPtr input, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, - [MarshalAs(UnmanagedType.U1)] bool count_include_pad); + [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_avg_pool3d(IntPtr input, @@ -137,7 +130,7 @@ internal static extern IntPtr THSTensor_avg_pool3d(IntPtr input, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, - [MarshalAs(UnmanagedType.U1)] bool count_include_pad); + [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_avg_pool2d_backward(IntPtr gradOutput, IntPtr originalInput, @@ -157,14 +150,6 @@ internal static extern IntPtr THSTensor_avg_pool3d_backward(IntPtr gradOutput, I [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisorOverride); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_adaptive_avg_pool1d(IntPtr input, - IntPtr outputSize, int outputSizeLength); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_adaptive_avg_pool2d(IntPtr input, - IntPtr outputSize, int outputSizeLength); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_upsample_nearest1d(IntPtr input, IntPtr outputSize, int outputSizeLength, @@ -505,6 +490,15 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_transpose(IntPtr tensor, long dim1, long dim2); + [DllImport("LibTorchSharp")] + internal static extern void THSTensor_transpose_(IntPtr tensor, long dim1, long dim2); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_threshold(IntPtr tensor, IntPtr threshold, IntPtr value); + + [DllImport("LibTorchSharp")] + internal static extern void THSTensor_threshold_(IntPtr tensor, IntPtr threshold, IntPtr value); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_tril(IntPtr tensor, long diagonal, [MarshalAs(UnmanagedType.U1)] bool inplace); @@ -517,9 +511,6 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_triu_indices(long row, long col, long offset, sbyte scalar_type, int device_type, int device_index); - [DllImport("LibTorchSharp")] - internal static extern void THSTensor_transpose_(IntPtr tensor, long dim1, long dim2); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_view(IntPtr tensor, IntPtr shape, int length); @@ -620,7 +611,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern IntPtr THSTensor_positive(IntPtr tensor); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_softplus(IntPtr tensor); + internal static extern IntPtr THSTensor_softplus(IntPtr tensor, IntPtr beta, IntPtr threshold); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_ravel(IntPtr tensor); @@ -638,10 +629,22 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern void THSTensor_relu6_(IntPtr tensor); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_celu(IntPtr tensor); + internal static extern IntPtr THSTensor_rrelu(IntPtr tensor, double lower, double upper); + + [DllImport("LibTorchSharp")] + internal static extern void THSTensor_rrelu_(IntPtr tensor, double lower, double upper); [DllImport("LibTorchSharp")] - internal static extern void THSTensor_celu_(IntPtr tensor); + internal static extern IntPtr THSTensor_celu(IntPtr tensor, IntPtr alpha); + + [DllImport("LibTorchSharp")] + internal static extern void THSTensor_celu_(IntPtr tensor, IntPtr alpha); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_hardshrink(IntPtr tensor, IntPtr lambda); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_softshrink(IntPtr tensor, IntPtr lambda); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_elu(IntPtr tensor, IntPtr alpha, IntPtr scale, IntPtr input_scale); @@ -652,6 +655,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_gelu(IntPtr tensor); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_gelu_(IntPtr tensor); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_glu(IntPtr tensor, long dim); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_hardsigmoid(IntPtr tensor); @@ -2086,7 +2095,10 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern IntPtr THSTensor_eye(long rows, long columns, sbyte scalarType, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_maxunpool2d(IntPtr input, IntPtr indices, IntPtr outputSize, int outputSizeLength); + internal static extern IntPtr THSTensor_adaptive_avg_pool1d(IntPtr input, IntPtr outputSize, int outputSizeLength); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_adaptive_avg_pool2d(IntPtr input, IntPtr outputSize, int outputSizeLength); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_adaptive_avg_pool3d(IntPtr input, IntPtr outputSize, int outputSizeLength); @@ -2094,6 +2106,27 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_adaptive_avg_pool3d_backward_out(IntPtr gradInput, IntPtr gradOutput, IntPtr originalInput); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_adaptive_max_pool1d(IntPtr input, IntPtr outputSize, int outputSizeLength, out IntPtr indices); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_adaptive_max_pool2d(IntPtr input, IntPtr outputSize, int outputSizeLength, out IntPtr indices); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_adaptive_max_pool3d(IntPtr input, IntPtr outputSize, int outputSizeLength, out IntPtr indices); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_fractional_max_pool2d(IntPtr input, IntPtr kernelSize, int kernelSizeLength, IntPtr outputSize, int outputSizeLength, IntPtr outputRatio, int outputRatioLength, out IntPtr indices); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_fractional_max_pool3d(IntPtr input, IntPtr kernelSize, int kernelSizeLength, IntPtr outputSize, int outputSizeLength, IntPtr outputRatio, int outputRatioLength, out IntPtr indices); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_lp_pool1d(IntPtr input, double norm_type, IntPtr kernelSize, int kernelSizeLength, IntPtr stride, int strideLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_lp_pool2d(IntPtr input, double norm_type, IntPtr kernelSize, int kernelSizeLength, IntPtr stride, int strideLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_searchsorted_t(IntPtr sorted_sequence, IntPtr values, bool out_int32, bool right, IntPtr sorter); [DllImport("LibTorchSharp")] diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index 04fbb43da..f72a2e812 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -2166,6 +2166,21 @@ public Tensor transpose_(long dim0, long dim1) CheckForErrors(); return this; } + + public Tensor threshold(Scalar threshold, Scalar value) + { + var res = NativeMethods.THSTensor_threshold(Handle, threshold.Handle, value.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + public Tensor threshold_(Scalar threshold, Scalar value) + { + NativeMethods.THSTensor_threshold_(Handle, threshold.Handle, value.Handle); + CheckForErrors(); + return this; + } /// /// Returns a view of the tensor conjugated and with the last two dimensions transposed. @@ -2695,9 +2710,13 @@ public Tensor positive() public Tensor softmax(long dim, ScalarType? dtype = null) => torch.special.softmax(this, dim, dtype); - public Tensor softplus() + + public Tensor softplus(int beta = 1, int threshold = 20) => + softplus1(beta, threshold); + + private Tensor softplus1(Scalar beta, Scalar threshold) { - var res = NativeMethods.THSTensor_softplus(Handle); + var res = NativeMethods.THSTensor_softplus(Handle, beta.Handle, threshold.Handle); if (res == IntPtr.Zero) CheckForErrors(); return new Tensor(res); @@ -2741,22 +2760,50 @@ public Tensor relu6_() return this; } - public Tensor celu() + + + private const double one_eighth = 1.0 / 8.0; + private const double one_third = 1.0 / 3.0; + + public Tensor rrelu(double lower = one_eighth, double upper = one_third) { - var res = NativeMethods.THSTensor_celu(Handle); + var res = NativeMethods.THSTensor_rrelu(Handle, lower, upper); if (res == IntPtr.Zero) CheckForErrors(); return new Tensor(res); } - public Tensor celu_() + public Tensor rrelu_(double lower = one_eighth, double upper = one_third) { - NativeMethods.THSTensor_celu_(Handle); + NativeMethods.THSTensor_rrelu_(Handle, lower, upper); CheckForErrors(); return this; } - public Tensor elu(Scalar alpha, Scalar scale, Scalar input_scale) + public Tensor celu() => this.celu(1.0); + + public Tensor celu_() => this.celu_(1.0); + + public Tensor celu(Scalar alpha) + { + var res = NativeMethods.THSTensor_celu(Handle, alpha.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + public Tensor celu_(Scalar alpha) + { + NativeMethods.THSTensor_celu_(Handle, alpha.Handle); + CheckForErrors(); + return this; + } + + public Tensor elu(double alpha = 1) => elu1(alpha, 1.0, 1.0); + + public Tensor elu_(double alpha = 1) => elu2(alpha, 1.0, 1.0); + + private Tensor elu1(Scalar alpha, Scalar scale, Scalar input_scale) { var res = NativeMethods.THSTensor_elu(Handle, alpha.Handle, scale.Handle, input_scale.Handle); if (res == IntPtr.Zero) @@ -2764,7 +2811,7 @@ public Tensor elu(Scalar alpha, Scalar scale, Scalar input_scale) return new Tensor(res); } - public Tensor elu_(Scalar alpha, Scalar scale, Scalar input_scale) + private Tensor elu2(Scalar alpha, Scalar scale, Scalar input_scale) { NativeMethods.THSTensor_elu_(Handle, alpha.Handle, scale.Handle, input_scale.Handle); CheckForErrors(); @@ -2779,6 +2826,22 @@ public Tensor gelu() return new Tensor(res); } + public Tensor gelu_() + { + var res = NativeMethods.THSTensor_gelu_(Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + public Tensor glu(long dim = -1) + { + var res = NativeMethods.THSTensor_glu(Handle, dim); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + public Tensor hardsigmoid() { var res = NativeMethods.THSTensor_hardsigmoid(Handle); diff --git a/src/TorchVision/Ops/DropBlock.cs b/src/TorchVision/Ops/DropBlock.cs index 2aece98d0..dea7ed8a4 100644 --- a/src/TorchVision/Ops/DropBlock.cs +++ b/src/TorchVision/Ops/DropBlock.cs @@ -53,7 +53,7 @@ public static Tensor drop_block2d(Tensor input, double p, long block_size, bool var pad = block_size / 2; noise = torch.nn.functional.pad(noise, (pad, pad, pad, pad), value: 0); - noise = torch.nn.functional.max_pool2d(noise, stride: 1, kernelSize: block_size, padding: block_size / 2); + noise = torch.nn.functional.max_pool2d(noise, stride: 1, kernel_size: block_size, padding: block_size / 2); noise = 1 - noise; var normalize_scale = noise.numel() / (eps + noise.sum()); @@ -94,7 +94,7 @@ public static Tensor drop_block3d(Tensor input, double p, long block_size, bool var pad = block_size / 2; var padding = new[] { pad, pad, pad, pad, pad, pad }; noise = torch.nn.functional.pad(noise, padding, value: 0); - noise = torch.nn.functional.max_pool3d(noise, strides: new long[] { 1, 1, 1 }, kernelSize: new[] { block_size, block_size, block_size }, padding: new long[] { pad }); + noise = torch.nn.functional.max_pool3d(noise, stride: new long[] { 1, 1, 1 }, kernel_size: new[] { block_size, block_size, block_size }, padding: new long[] { pad }); noise = 1 - noise; var normalize_scale = noise.numel() / (eps + noise.sum()); diff --git a/src/TorchVision/models/AlexNet.cs b/src/TorchVision/models/AlexNet.cs index 33372af06..a26fd3718 100644 --- a/src/TorchVision/models/AlexNet.cs +++ b/src/TorchVision/models/AlexNet.cs @@ -76,17 +76,17 @@ public AlexNet(int numClasses, float dropout = 0.5f, string? weights_file = null features = Sequential( Conv2d(3, 64, kernelSize: 11, stride: 4, padding: 2), ReLU(inplace: true), - MaxPool2d(kernelSize: 3, stride: 2), + MaxPool2d(kernel_size: 3, stride: 2), Conv2d(64, 192, kernelSize: 5, padding: 2), ReLU(inplace: true), - MaxPool2d(kernelSize: 3, stride: 2), + MaxPool2d(kernel_size: 3, stride: 2), Conv2d(192, 384, kernelSize: 3, padding: 1), ReLU(inplace: true), Conv2d(384, 256, kernelSize: 3, padding: 1), ReLU(inplace: true), Conv2d(256, 256, kernelSize: 3, padding: 1), ReLU(inplace: true), - MaxPool2d(kernelSize: 3, stride: 2) + MaxPool2d(kernel_size: 3, stride: 2) ); avgpool = AdaptiveAvgPool2d(new long[] { 6, 6 }); diff --git a/src/TorchVision/models/GoogleNet.cs b/src/TorchVision/models/GoogleNet.cs index 861ccde6f..105c24d35 100644 --- a/src/TorchVision/models/GoogleNet.cs +++ b/src/TorchVision/models/GoogleNet.cs @@ -119,21 +119,21 @@ public GoogleNet(int numClasses = 1000, this.transform_input = transform_input; conv1 = conv_block(3, 64, kernel_size: 7, stride: 2, padding: 3); - maxpool1 = MaxPool2d(kernelSize: 3, stride: 2, ceilMode: true); + maxpool1 = MaxPool2d(kernel_size: 3, stride: 2, ceil_mode: true); conv2 = conv_block(64, 64, kernel_size: 1); conv3 = conv_block(64, 192, kernel_size: 3, padding: 1); - maxpool2 = MaxPool2d(kernelSize: 3, stride: 2, ceilMode: true); + maxpool2 = MaxPool2d(kernel_size: 3, stride: 2, ceil_mode: true); inception3a = inception_block(192, 64, 96, 128, 16, 32, 32); inception3b = inception_block(256, 128, 128, 192, 32, 96, 64); - maxpool3 = nn.MaxPool2d(3, stride: 2, ceilMode: true); + maxpool3 = nn.MaxPool2d(3, stride: 2, ceil_mode: true); inception4a = inception_block(480, 192, 96, 208, 16, 48, 64); inception4b = inception_block(512, 160, 112, 224, 24, 64, 64); inception4c = inception_block(512, 128, 128, 256, 24, 64, 64); inception4d = inception_block(512, 112, 144, 288, 32, 64, 64); inception4e = inception_block(528, 256, 160, 320, 32, 128, 128); - maxpool4 = nn.MaxPool2d(2, stride: 2, ceilMode: true); + maxpool4 = nn.MaxPool2d(2, stride: 2, ceil_mode: true); inception5a = inception_block(832, 256, 160, 320, 32, 128, 128); inception5b = inception_block(832, 384, 192, 384, 48, 128, 128); @@ -280,7 +280,7 @@ public Inception(int in_channels, int ch1x1, int ch3x3red, int ch3x3, int ch5x5r conv_block(ch5x5red, ch5x5, kernel_size: 3, padding: 1) ); branch4 = nn.Sequential( - nn.MaxPool2d(kernelSize: 3, stride: 1, padding: 1, ceilMode: true), + nn.MaxPool2d(kernel_size: 3, stride: 1, padding: 1, ceil_mode: true), conv_block(in_channels, pool_proj, kernel_size: 1) ); RegisterComponents(); diff --git a/src/TorchVision/models/InceptionV3.cs b/src/TorchVision/models/InceptionV3.cs index e7f7791c3..7b68b01f2 100644 --- a/src/TorchVision/models/InceptionV3.cs +++ b/src/TorchVision/models/InceptionV3.cs @@ -119,10 +119,10 @@ public InceptionV3(int numClasses = 1000, Conv2d_1a_3x3 = conv_block(3, 32, kernel_size: 3, stride: 2); Conv2d_2a_3x3 = conv_block(32, 32, kernel_size: 3); Conv2d_2b_3x3 = conv_block(32, 64, kernel_size: 3, padding: 1); - maxpool1 = MaxPool2d(kernelSize: 3, stride: 2); + maxpool1 = MaxPool2d(kernel_size: 3, stride: 2); Conv2d_3b_1x1 = conv_block(64, 80, kernel_size: 1); Conv2d_4a_3x3 = conv_block(80, 192, kernel_size: 3); - maxpool2 = MaxPool2d(kernelSize: 3, stride: 2); + maxpool2 = MaxPool2d(kernel_size: 3, stride: 2); Mixed_5b = inception_a(192, pool_features: 32); Mixed_5c = inception_a(256, pool_features: 64); @@ -292,7 +292,7 @@ public override Tensor forward(Tensor x) branch3x3dbl = branch3x3dbl_2.call(branch3x3dbl); branch3x3dbl = branch3x3dbl_3.call(branch3x3dbl); - var branch_pool_ = functional.avg_pool2d(x, kernelSize: 3, stride: 1, padding: 1); + var branch_pool_ = functional.avg_pool2d(x, kernel_size: 3, stride: 1, padding: 1); branch_pool_ = branch_pool.call(branch_pool_); var outputs = new [] { branch1x1_, branch5x5, branch3x3dbl, branch_pool_ }; @@ -341,7 +341,7 @@ public override Tensor forward(Tensor x) branch3x3dbl = branch3x3dbl_2.call(branch3x3dbl); branch3x3dbl = branch3x3dbl_3.call(branch3x3dbl); - var branch_pool = functional.max_pool2d(x, kernelSize: 3, stride: 2); + var branch_pool = functional.max_pool2d(x, kernel_size: 3, stride: 2); var outputs = new[] { branch3x3_, branch3x3dbl, branch_pool }; @@ -425,7 +425,7 @@ public override Tensor forward(Tensor x) branch7x7dbl = branch7x7dbl_4.call(branch7x7dbl); branch7x7dbl = branch7x7dbl_5.call(branch7x7dbl); - var branch_pool_ = functional.avg_pool2d(x, kernelSize: 3, stride: 1, padding: 1); + var branch_pool_ = functional.avg_pool2d(x, kernel_size: 3, stride: 1, padding: 1); branch_pool_ = branch_pool.call(branch_pool_); var outputs = new[] { branch1x1_, branch7x7, branch7x7dbl, branch_pool_ }; @@ -475,7 +475,7 @@ public override Tensor forward(Tensor x) branch7x7x3 = branch7x7x3_3.call(branch7x7x3); branch7x7x3 = branch7x7x3_4.call(branch7x7x3); - var branch_pool = functional.max_pool2d(x, kernelSize: 3, stride: 2); + var branch_pool = functional.max_pool2d(x, kernel_size: 3, stride: 2); var outputs = new[] { branch3x3, branch7x7x3, branch_pool }; @@ -537,7 +537,7 @@ public override Tensor forward(Tensor x) branch3x3dbl = branch3x3dbl_2.call(branch3x3dbl); branch3x3dbl = torch.cat(new[] { branch3x3dbl_3a.call(branch3x3dbl), branch3x3dbl_3b.call(branch3x3dbl) }, 1); - var branch_pool_ = functional.avg_pool2d(x, kernelSize: 3, stride: 1, padding: 1); + var branch_pool_ = functional.avg_pool2d(x, kernel_size: 3, stride: 1, padding: 1); branch_pool_ = branch_pool.call(branch_pool_); var outputs = new[] { branch1x1_, branch3x3, branch3x3dbl, branch_pool_ }; @@ -575,7 +575,7 @@ public InceptionAux(int in_channels, int num_classes) : base("InceptionAux") public override Tensor forward(Tensor x) { // N x 768 x 17 x 17 - x = functional.avg_pool2d(x, kernelSize: 5, stride: 3); + x = functional.avg_pool2d(x, kernel_size: 5, stride: 3); // N x 768 x 5 x 5 x = conv0.call(x); // N x 128 x 5 x 5 diff --git a/src/TorchVision/models/ResNet.cs b/src/TorchVision/models/ResNet.cs index 654d587c3..91c6fa7c4 100644 --- a/src/TorchVision/models/ResNet.cs +++ b/src/TorchVision/models/ResNet.cs @@ -773,7 +773,7 @@ public ResNet(string name, conv1 = Conv2d(3, in_planes, kernelSize: 7, stride: 2, padding: 3, bias: false); bn1 = norm_layer(in_planes); relu = ReLU(inplace: true); - maxpool = MaxPool2d(kernelSize: 3, stride: 2, padding: 1); + maxpool = MaxPool2d(kernel_size: 3, stride: 2, padding: 1); MakeLayer(layer1, block, expansion, 64, layers[0], 1); MakeLayer(layer2, block, expansion, 128, layers[1], 2, rswd.Item1); diff --git a/src/TorchVision/models/VGG.cs b/src/TorchVision/models/VGG.cs index e79f9ddec..08e6d8a3c 100644 --- a/src/TorchVision/models/VGG.cs +++ b/src/TorchVision/models/VGG.cs @@ -363,7 +363,7 @@ public VGG(string name, for (var i = 0; i < channels.Length; i++) { if (channels[i] == 0) { - layers.Add(MaxPool2d(kernelSize: 2, stride: 2)); + layers.Add(MaxPool2d(kernel_size: 2, stride: 2)); } else { layers.Add(Conv2d(in_channels, channels[i], kernelSize: 3, padding: 1)); if (batch_norm) { diff --git a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj index 055fb9ffc..e16e6025e 100644 --- a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj +++ b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj @@ -22,9 +22,11 @@ + Always + @@ -34,15 +36,16 @@ + + - diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index cd59ad5b9..b1380c9da 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -19,11 +19,12 @@ namespace TorchSharp { static internal class TestUtils { - public static IList AvailableDevices(bool cuda = true) + public static IList AvailableDevices(bool cuda = true, bool mps = false) { List result = new List(); result.Add(torch.CPU); if (cuda && torch.cuda_is_available()) result.Add(torch.CUDA); + if (mps && torch.mps_is_available()) result.Add(torch.MPS); return result; } } @@ -269,8 +270,10 @@ public void TestLinearNullBias() { var device = torch.CPU; - var lin = Linear(1000, 100, true, device: device); - Assert.Throws(() => lin.bias = null); + var lin = Linear(100, 100, true, device: device); + // This should not throw: + lin.bias = null; + lin.call(torch.rand(100)); } [Fact] @@ -333,6 +336,8 @@ public void TestIdentity() var input = torch.randn(new long[] { 1, 1000 }, device: device); var output = lin.call(input); + output[0, 511] = 10; // When we modify the copy, the original should be altered, too. + Assert.Equal(device.type, output.device_type); Assert.Equal(input.data(), output.data()); } @@ -593,9 +598,15 @@ public void EvaluateGELU() [Fact] public void EvaluatePReLU() { + var rel = PReLU(1, 0.35, torch.CPU); + + Assert.Equal(1, rel.num_parameters); + Assert.Equal(0.35f, rel.weight.item()); + Assert.True(rel.weight.requires_grad); + foreach (var device in TestUtils.AvailableDevices()) { - var rel = PReLU(1, 0.35, device); + rel = rel.to(device); var input = torch.randn(new long[] { 4, 3, 8, 8 }, device: device) * 5.0; var output = rel.call(input); @@ -713,11 +724,15 @@ public void EvaluateSiLU() public void EvaluateSoftmax2d() { var rel = Softmax2d(); + var rel_x = Softmax(-3); + foreach (var device in TestUtils.AvailableDevices()) { var input = torch.randn(new long[] { 64, 3, 8, 8 }, device: device) * 25.0; var output = rel.call(input); Assert.Equal(device.type, output.device_type); + Assert.True(torch.allclose(rel_x.call(input), output)); + var values = output.data().ToArray(); Assert.Equal(input.shape, output.shape); Assert.All(values, val => Assert.True(val >= 0.0 && val <= 1.0)); @@ -2361,12 +2376,12 @@ public void TestConv1d() var shape = new long[] { 16, 3, 28 }; foreach (var device in TestUtils.AvailableDevices(false)) { Tensor t = torch.rand(shape, device: device); - var conv = Conv1d(3, 64, 3, device: device); + var conv = Conv1d(3, 64, 5, device: device); var output = conv.call(t); Assert.Equal(device.type, output.device_type); Assert.Equal(16, output.shape[0]); Assert.Equal(64, output.shape[1]); - Assert.Equal(26, output.shape[2]); + Assert.Equal(24, output.shape[2]); } } @@ -3510,7 +3525,7 @@ public void AvgPool3DBackwardTensorExplicitDivisor() var ones = torch.ones(new long[] { 4, 2, 2, 2, 2 }, device: device); var kernelSize = new long[] { 2, 2, 2 }; var avg = torch.ones(new long[] { 4, 2, 1, 1, 1 }, device: device); - var res = torch.nn.functional.avg_pool3d_backward(avg, ones, kernelSize, divisorOverride: 6) * 6.0; + var res = torch.nn.functional.avg_pool3d_backward(avg, ones, kernelSize, divisor_override: 6) * 6.0; var ones0000 = ones.cpu()[0, 0, 0, 0, 0].ToSingle(); var res0000 = res.cpu()[0, 0, 0, 0, 0].ToSingle(); @@ -4636,6 +4651,35 @@ public void TestInstanceNorm1D() } } + [Fact] + public void TestInstanceNorm1dWeightAndBias() + { + foreach (var device in TestUtils.AvailableDevices()) { + var ones = torch.ones(new long[] { 16, 3, 28 }, device: device); + + using (var norm = InstanceNorm1d(3, affine:true, track_running_stats: false, device: device)) { + var w = norm.weight; + var b = norm.bias; + + Assert.NotNull(w); + Assert.NotNull(b); + + Assert.Equal(device.type, w.device_type); + Assert.Equal(device.type, b.device_type); + + var pooled = norm.call(ones); + Assert.Equal(device.type, pooled.device_type); + Assert.Equal(ones.shape, pooled.shape); + + Assert.Null(norm.running_mean); + Assert.Null(norm.running_var); + + Assert.Equal(new long[] { 3 }, w.shape); + Assert.Equal(new long[] { 3 }, b.shape); + } + } + } + [Fact] public void TestInstanceNorm2D() { @@ -4650,7 +4694,7 @@ public void TestInstanceNorm2D() Assert.Null(pool.running_mean); Assert.Null(pool.running_var); Assert.Equal(ones.shape, pooled.shape); - Assert.Throws(() => pool.call(torch.ones(new long[] { 16, 2, 2 }, device: device))); + Assert.Throws(() => pool.call(torch.ones(new long[] { 16, 2 }, device: device))); Assert.Throws(() => pool.call(torch.ones(new long[] { 2, 2, 2, 2, 2 }, device: device))); } } @@ -4696,6 +4740,35 @@ public void TestInstanceNorm2D() } } + [Fact] + public void TestInstanceNorm2dWeightAndBias() + { + foreach (var device in TestUtils.AvailableDevices()) { + var ones = torch.ones(new long[] { 16, 3, 28, 28 }, device: device); + + using (var norm = InstanceNorm2d(3, affine:true, track_running_stats: false, device: device)) { + var w = norm.weight; + var b = norm.bias; + + Assert.NotNull(w); + Assert.NotNull(b); + + Assert.Equal(device.type, w.device_type); + Assert.Equal(device.type, b.device_type); + + var pooled = norm.call(ones); + Assert.Equal(device.type, pooled.device_type); + Assert.Equal(ones.shape, pooled.shape); + + Assert.Null(norm.running_mean); + Assert.Null(norm.running_var); + + Assert.Equal(new long[] { 3 }, w.shape); + Assert.Equal(new long[] { 3 }, b.shape); + } + } + } + [Fact] public void TestInstanceNorm3D() { @@ -4710,7 +4783,7 @@ public void TestInstanceNorm3D() Assert.Null(pool.running_mean); Assert.Null(pool.running_var); Assert.Equal(ones.shape, pooled.shape); - Assert.Throws(() => pool.call(torch.ones(new long[] { 16, 2, 2, 2 }, device: device))); + Assert.Throws(() => pool.call(torch.ones(new long[] { 16, 2, 2 }, device: device))); Assert.Throws(() => pool.call(torch.ones(new long[] { 2, 2, 2, 2, 2, 2 }, device: device))); } } @@ -4756,6 +4829,35 @@ public void TestInstanceNorm3D() } } + [Fact] + public void TestInstanceNorm3dWeightAndBias() + { + foreach (var device in TestUtils.AvailableDevices()) { + var ones = torch.ones(new long[] { 16, 3, 28, 28, 28 }, device: device); + + using (var norm = InstanceNorm3d(3, affine:true, track_running_stats: false, device: device)) { + var w = norm.weight; + var b = norm.bias; + + Assert.NotNull(w); + Assert.NotNull(b); + + Assert.Equal(device.type, w.device_type); + Assert.Equal(device.type, b.device_type); + + var pooled = norm.call(ones); + Assert.Equal(device.type, pooled.device_type); + Assert.Equal(ones.shape, pooled.shape); + + Assert.Null(norm.running_mean); + Assert.Null(norm.running_var); + + Assert.Equal(new long[] { 3 }, w.shape); + Assert.Equal(new long[] { 3 }, b.shape); + } + } + } + [Fact] public void TestLayerNorm() { @@ -5236,6 +5338,7 @@ public void TestMultiheadAttention() using (var V = torch.tensor(v_data, src_seq_len, batch_size, vembed_dim)) using (var Attn = torch.tensor(attn_data, batch_size, src_seq_len, src_seq_len)) { + var children = mha.children().ToList(); mha.eval(); Assert.False(mha.training); @@ -5428,13 +5531,13 @@ public void TestFlatten() Assert.Equal(new long[] { 32, 360 }, output.shape); } - using (var flat = Flatten(startDim: 2)) { + using (var flat = Flatten(start_dim: 2)) { var output = flat.call(data); Assert.Equal(device.type, output.device_type); Assert.Equal(new long[] { 32, 3, 120 }, output.shape); } - using (var flat = Flatten(startDim: 0)) { + using (var flat = Flatten(start_dim: 0)) { var output = flat.call(data); Assert.Equal(device.type, output.device_type); Assert.Equal(new long[] { 32 * 360 }, output.shape); @@ -6583,5 +6686,37 @@ public void TestModulePostHooks() lin1.call(input); Assert.Equal(1, counter); } + + [Fact] + public void TestCustomParameterLessModule() + { + var cnp = new CustomNoParameters("test"); + + // Should not throw + cnp.register_module("sub", new CustomNoParameters("test")); + + Assert.True(cnp.named_modules().Count() > 0); + Assert.Equal("sub", cnp.named_modules().First().name); + + Assert.Throws(() => cnp.register_module("test", torch.nn.Linear(10,10, true))); + Assert.Throws(() => cnp.register_buffer("test", torch.rand(10))); + Assert.Throws(() => cnp.register_parameter("test", new Parameter(torch.rand(10)))); + } + + class CustomNoParameters : ParamLessModule + { + public CustomNoParameters(string name) : base(name) + { + } + + public CustomNoParameters(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + { + } + + public override Tensor forward(Tensor input) + { + throw new NotImplementedException(); + } + } } } \ No newline at end of file diff --git a/test/TorchSharpTest/TestTorchTensorBugs.cs b/test/TorchSharpTest/TestTorchTensorBugs.cs index e2b65ac13..5653c0ea0 100644 --- a/test/TorchSharpTest/TestTorchTensorBugs.cs +++ b/test/TorchSharpTest/TestTorchTensorBugs.cs @@ -1172,7 +1172,7 @@ public void Validate1089_2d() () => Assert.Equal(expectedShape, functional.max_pool2d(t, new long[] { 2, 2 }).shape) ); - Assert.Equal(expectedShape, functional.max_pool2d_with_indices(t, new long[] { 2, 2 }).output.shape); + Assert.Equal(expectedShape, functional.max_pool2d_with_indices(t, new long[] { 2, 2 }).Values.shape); } [Fact] @@ -1182,7 +1182,7 @@ public void Validate1089_3d() var expectedShape = new long[] { 1, 6, 14, 14, 14 }; Assert.Equal(expectedShape, functional.max_pool3d(t, new long[] { 2, 2, 2 }).shape); - Assert.Equal(expectedShape, functional.max_pool3d_with_indices(t, new long[] { 2, 2, 2 }).output.shape); + Assert.Equal(expectedShape, functional.max_pool3d_with_indices(t, new long[] { 2, 2, 2 }).Values.shape); } [Fact] diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs index 08aba9e03..c8f1bc341 100644 --- a/test/TorchSharpTest/TestTorchVision.cs +++ b/test/TorchSharpTest/TestTorchVision.cs @@ -842,7 +842,7 @@ public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveDifferen double[] stdevs = { 0.229, 0.224, 0.225, 0.222 }; // Different length // Act & Assert - Assert.Throws(() => torchvision.transforms.Normalize(means, stdevs)); + Assert.Throws(() => Normalize(means, stdevs)); } [Fact] @@ -853,7 +853,7 @@ public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveWrongLen double[] stdevs = { 0.229, 0.224 }; // Not 1 or 3 // Act & Assert - Assert.Throws(() => torchvision.transforms.Normalize(means, stdevs)); + Assert.Throws(() => Normalize(means, stdevs)); } [Fact] @@ -864,7 +864,7 @@ public void TestConstructor_CreatesNewNormalizeObject_WithValidArguments() double[] stdevs = { 0.229, 0.224, 0.225 }; // Act - var result = torchvision.transforms.Normalize(means, stdevs); + var result = Normalize(means, stdevs); // Assert Assert.NotNull(result); @@ -876,7 +876,7 @@ public void TestCall_ThrowsArgumentException_IfNumberOfChannelsIsNotEqual() // Arrange double[] means = { 0.485, 0.456, 0.406 }; double[] stdevs = { 0.229, 0.224, 0.225 }; - var sut = torchvision.transforms.Normalize(means, stdevs); + var sut = Normalize(means, stdevs); var wrongSizeInput = torch.rand(new long[] { 1, 4, 32, 32 }); // wrong number of input channels // Act & Assert @@ -889,7 +889,7 @@ public void TestCall_CallsOperatorsCorrectly() // Arrange double[] means = { 0.485, 0.456, 0.406 }; double[] stdevs = { 0.229, 0.224, 0.225 }; - var sut = torchvision.transforms.Normalize(means, stdevs); + var sut = Normalize(means, stdevs); var inputChannels = 3; var input = torch.rand(new long[] { 1, inputChannels, 32, 32 }, dtype: float64); @@ -905,12 +905,11 @@ public void TestCall_CallsOperatorsCorrectly() [Fact] public void Call_ThrowsException_WithWrongNumberOfChannels() { - // Act - Assert.Throws(() => torchvision.transforms.Grayscale(outputChannels: 2)); + Assert.Throws(() => Grayscale(outputChannels: 2)); Tensor input = torch.rand(new long[] { 1, 2, 128, 128 }); - var tfrm = torchvision.transforms.Grayscale(outputChannels: 1); + var tfrm = Grayscale(outputChannels: 1); Assert.Throws(() => tfrm.call(input)); } @@ -922,7 +921,7 @@ public void Resize_WithHeightAndWidth_ReturnsTensor() int height = 20; int width = 30; var input = torch.randn(1, 3, 256, 256); - var transform = torchvision.transforms.Resize(height, width); + var transform = Resize(height, width); //Act var result = transform.call(input); @@ -939,7 +938,7 @@ public void Resize_WithSizeAndMaxSize_ReturnsTensor() int size = 20; int? maxSize = 30; var input = torch.randn(1, 3, 256, 256); - var transform = torchvision.transforms.Resize(size, maxSize); + var transform = Resize(size, maxSize); //Act var result = transform.call(input); @@ -1214,5 +1213,381 @@ public void Adjust_Contrast_ReturnsTensorWithCorrectDtype() var img2 = torchvision.transforms.functional.adjust_contrast(img1, 2); Assert.Equal(img1.dtype, img2.dtype); } + + + [Fact] + public void RgbToGrayscale_ReturnsCorrectNumberOfChannels() + { + int numChannels = 3; + int numOutputChannels = 1; + var shape = new long[] { numChannels, 10, 10 }; + + var input = torch.rand(shape); + + var output = functional.rgb_to_grayscale(input, numOutputChannels); + + Assert.Equal(numOutputChannels, output.shape[0]); + } + + [Fact] + public void RgbToGrayscale_ThrowsArgumentException_ForInvalidOutputChannels() + { + int numChannels = 3; + int numOutputChannels = 2; + var shape = new long[] { numChannels, 10, 10 }; + + var input = torch.rand(shape); + + Assert.Throws(() => functional.rgb_to_grayscale(input, numOutputChannels)); + } + + [Fact] + public void RgbToGrayscale_AlreadyGrayscale_ReturnsInputTensorAsIs() + { + int numChannels = 1; + int numOutputChannels = 1; + var shape = new long[] { numChannels, 10, 10 }; + + var input = torch.rand(shape); + + var output = functional.rgb_to_grayscale(input, numOutputChannels); + + Assert.Equal(input, output); + } + + [Fact] + public void RgbToGrayscale_ConvertsInputToFloatTensor() + { + int numChannels = 3; + int numOutputChannels = 1; + var shape = new long[] { numChannels, 10, 10 }; + + var input = torch.randint(0, 255, shape, dtype:ScalarType.Byte); + + var output = functional.rgb_to_grayscale(input, numOutputChannels); + + Assert.True(output.is_floating_point()); + } + + [Fact] + public void RgbToGrayscale_ReturnsTensorWithCorrectShape() + { + int numChannels = 3; + int numOutputChannels = 1; + var shape = new long[] { numChannels, 10, 10 }; + + var input = torch.rand(shape); + + var output = functional.rgb_to_grayscale(input, numOutputChannels); + + Assert.Equal(new long[] { numOutputChannels, 10, 10 }, output.shape); + } + + [Fact] + public void Resize_WhenSizeNotChanged_ReturnsSameTensor() + { + // Arrange + var input = torch.rand( 3, 2, 2 ); + int height = 2; + int width = 2; + + // Act + var output = functional.resize(input, height, width); + + // Assert + Assert.Equal(input.Dimensions, output.Dimensions); + Assert.Equal(input.shape, output.shape); + Assert.Equal(input, output); + } + + [Fact] + public void Resize_WhenWidthChange_ReturnsTensorWithSameHeight() + { + // Arrange + var input = torch.rand( 3, 2, 4 ); + int height = 2; + int width = 3; + + // Act + var output = functional.resize(input, height, width); + + // Assert + Assert.Equal(input.Dimensions, output.Dimensions); + Assert.Equal(input.shape[0], output.shape[0]); + Assert.Equal(height, output.shape[1]); + Assert.Equal(width, output.shape[2]); + } + + [Fact] + public void Resize_WhenHeightChange_ReturnsTensorWithSameWidth() + { + // Arrange + var input = torch.rand( 3, 4, 2); + int height = 3; + int width = 2; + + // Act + var output = functional.resize(input, height, width); + + // Assert + Assert.Equal(input.Dimensions, output.Dimensions); + Assert.Equal(input.shape[0], output.shape[0]); + Assert.Equal(height, output.shape[1]); + Assert.Equal(width, output.shape[2]); + } + + [Fact] + public void Resize_WhenMaxSizeNotMet_ThrowsArgumentException() + { + // Arrange + var input = torch.rand( 3, 5, 4 ); + int height = 10; + int? maxSize = 8; + + // Act + Assert + Assert.Throws(() => functional.resize(input, height, -1, maxSize)); + } + + [Fact] + public void Resize_WhenMaxSizeMet_DoesNotThrowException() + { + // Arrange + var input = torch.rand( 3, 5, 4 ); + int height = 8; + int? maxSize = 10; + + // Act + Assert + functional.resize(input, height, -1, maxSize); + } + + + + [Fact] + public void CanApplyPerspective() + { + using var tensor = torch.rand(new long[] { 3, 256, 256 }); + + var startpoints = new List>() + { + new List(){ 10, 10 }, + new List(){ 10, 246 }, + new List(){ 246, 10 }, + new List(){ 246, 246 }, + }; + var endpoints = new List>() + { + new List(){ 0, 0 }, + new List(){ 0, 256 }, + new List(){ 256, 0 }, + new List(){ 256, 256 }, + }; + + using var output = functional.perspective(tensor, startpoints, endpoints); + + Assert.NotNull(output); + Assert.Equal(tensor.shape, output.shape); + } + + [Fact] + public void CanApplyPerspectiveWithInterpolation() + { + using var tensor = torch.rand(new long[] { 3, 256, 256 }); + + var startpoints = new List>() + { + new List(){ 10, 10 }, + new List(){ 10, 246 }, + new List(){ 246, 10 }, + new List(){ 246, 246 }, + }; + var endpoints = new List>() + { + new List(){ 0, 0 }, + new List(){ 0, 256 }, + new List(){ 256, 0 }, + new List(){ 256, 256 }, + }; + var interpolation = InterpolationMode.Nearest; + + using var output = functional.perspective(tensor, startpoints, endpoints, interpolation); + + Assert.NotNull(output); + Assert.Equal(tensor.shape, output.shape); + } + + [Fact] + public void CanApplyPerspectiveWithFill() + { + using var tensor = torch.rand(new long[] { 3, 256, 256 }); + + var startpoints = new List>() + { + new List(){ 10, 10 }, + new List(){ 10, 246 }, + new List(){ 246, 10 }, + new List(){ 246, 246 }, + }; + var endpoints = new List>() + { + new List(){ 0, 0 }, + new List(){ 0, 256 }, + new List(){ 256, 0 }, + new List(){ 256, 256 }, + }; + var fill = new List() { 0.5f }; + + using var output = functional.perspective(tensor, startpoints, endpoints, fill: fill); + + Assert.NotNull(output); + Assert.Equal(tensor.shape, output.shape); + } + + [Fact] + public void TestPadZeroes() + { + var input = torch.ones(3, 3, dtype: int64); + { + var padding = new long[] { 1, 2 }; + var padding_mode = PaddingModes.Zeros; + + var expectedOutput = torch.tensor(new long[,] { + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 1, 1, 1, 0}, + {0, 1, 1, 1, 0}, + {0, 1, 1, 1, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0} + }); + + var actualOutput = functional.pad(input, padding, padding_mode: padding_mode); + + Assert.Equal(expectedOutput, actualOutput); + } + { + var padding = new long[] { 1, 1, 2, 2 }; + var padding_mode = PaddingModes.Zeros; + + var expectedOutput = torch.tensor(new long[,] { + {0, 0, 0, 0, 0, 0}, + {0, 1, 1, 1, 0, 0}, + {0, 1, 1, 1, 0, 0}, + {0, 1, 1, 1, 0, 0}, + {0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0} + }); + + var actualOutput = functional.pad(input, padding, padding_mode: padding_mode); + + Assert.Equal(expectedOutput, actualOutput); + } + } + + [Fact] + public void TestPadConstant() + { + var input = torch.ones(3, 3, dtype: int64); + { + var padding = new long[] { 1, 2 }; + var fill = 0; + var padding_mode = PaddingModes.Constant; + + var expectedOutput = torch.tensor(new long[,] { + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 1, 1, 1, 0}, + {0, 1, 1, 1, 0}, + {0, 1, 1, 1, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0} + }); + + var actualOutput = functional.pad(input, padding, fill, padding_mode); + + Assert.Equal(expectedOutput, actualOutput); + } + { + var padding = new long[] { 1, 1, 2, 2 }; + var fill = 0; + var padding_mode = PaddingModes.Constant; + + var expectedOutput = torch.tensor(new long[,] { + {0, 0, 0, 0, 0, 0}, + {0, 1, 1, 1, 0, 0}, + {0, 1, 1, 1, 0, 0}, + {0, 1, 1, 1, 0, 0}, + {0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0} + }); + + var actualOutput = functional.pad(input, padding, fill, padding_mode); + + Assert.Equal(expectedOutput, actualOutput); + } + } + + [Fact] + public void TestPadReflect() + { + var input = torch.arange(1, 10, dtype:float32).reshape(1, 3, 3); + { + var padding = new long[] { 1, 2 }; + var padding_mode = PaddingModes.Reflect; + + var expectedOutput = torch.tensor(new float[,] { + {8, 7, 8, 9, 8}, + {5, 4, 5, 6, 5}, + {2, 1, 2, 3, 2}, + {5, 4, 5, 6, 5}, + {8, 7, 8, 9, 8}, + {5, 4, 5, 6, 5}, + {2, 1, 2, 3, 2} + }).reshape(1, 7, 5); + + var actualOutput = functional.pad(input, padding, padding_mode: padding_mode); + + Assert.Equal(expectedOutput, actualOutput); + } + { + var padding = new long[] { 1, 1, 2, 2 }; + var padding_mode = PaddingModes.Reflect; + + var expectedOutput = torch.tensor(new float[,] { + {5, 4, 5, 6, 5, 4}, + {2, 1, 2, 3, 2, 1}, + {5, 4, 5, 6, 5, 4}, + {8, 7, 8, 9, 8, 7}, + {5, 4, 5, 6, 5, 4}, + {2, 1, 2, 3, 2, 1} + }).reshape(1, 6, 6); + + var actualOutput = functional.pad(input, padding, padding_mode: padding_mode); + + Assert.Equal(expectedOutput, actualOutput); + } + } + + [Fact] + public void TestGaussianBlur() + { + var input = torch.arange(1 * 3 * 3 * 5).reshape(1, 3, 3, 5).to(float32) / 5.0f; + var kernelSize = new List { 3, 5 }; + var sigma = new List { 1.0f, 2.0f }; + + var actual = functional.gaussian_blur(input, kernelSize, sigma); + var expected = torch.tensor(new float[]{ + 2f, 2f, 2.2f, 2.4f, 2.4f, + 1.2f, 1.2f, 1.4f, 1.6f, 1.6f, + 0.4f, 0.4f, 0.6f, 0.8f, 0.8f, + 5f, 5f, 5.2f, 5.4f, 5.4f, + 4.2f, 4.2f, 4.4f, 4.6f, 4.6f, + 3.4f, 3.4f, 3.6f, 3.8f, 3.8f, + 8f, 8f, 8.2f, 8.4f, 8.4f, + 7.2f, 7.2f, 7.4f, 7.6f, 7.6f, + 6.4f, 6.4f, 6.6f, 6.8f, 6.8f + }).reshape(1, 3, 3, 5); + + Assert.True(expected.allclose(actual, rtol: 1e-4, atol: 1e-6)); + } } }