Commit: Use package objects to apply loading workaround more reliably

- Improve doc grouping and macros
- Add more pooling ops
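The "loading workaround" concerns forcing JavaCPP to load the native libtorch binaries before any binding is touched. Below is a hypothetical sketch of the package-object approach (the exact storch code may differ; Loader.load is the standard JavaCPP entry point):

package torch
package nn

// Hypothetical sketch: initializing the package object forces JavaCPP to
// load the native PyTorch libraries exactly once, before any member of
// torch.nn.functional is used, instead of relying on each call site to
// trigger loading.
package object functional {
  org.bytedeco.javacpp.Loader.load(classOf[org.bytedeco.pytorch.global.torch])
}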
Showing 33 changed files with 2,968 additions and 2,547 deletions.
core/src/main/scala/torch/nn/functional/Activations.scala (new file, 72 lines)
/*
 * Copyright 2022 storch.dev
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package torch
package nn
package functional

import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.ScalarTypeOptional

private[torch] trait Activations {

  /** Applies a softmax followed by a logarithm.
    *
    * While mathematically equivalent to log(softmax(x)), doing these two operations separately is
    * slower and numerically unstable. This function uses an alternative formulation to compute the
    * output and gradient correctly.
    *
    * See `torch.nn.LogSoftmax` for more details.
    *
    * @group nn_activation
    */
  def logSoftmax[In <: DType, Out <: DType](input: Tensor[In], dim: Long)(
      dtype: Out = input.dtype
  ): Tensor[Out] =
    val nativeDType =
      if dtype == input.dtype then ScalarTypeOptional() else ScalarTypeOptional(dtype.toScalarType)
    Tensor(torchNative.log_softmax(input.native, dim, nativeDType))

  /** Applies the rectified linear unit function element-wise.
    *
    * See [[torch.nn.ReLU]] for more details.
    *
    * @group nn_activation
    */
  def relu[D <: DType](input: Tensor[D]): Tensor[D] = Tensor(torchNative.relu(input.native))

  /** Applies the element-wise function $\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}$.
    *
    * See `torch.nn.Sigmoid` for more details.
    *
    * @group nn_activation
    */
  def sigmoid[D <: DType](input: Tensor[D]): Tensor[D] = Tensor(torchNative.sigmoid(input.native))

  /** Applies a softmax function.
    *
    * @group nn_activation
    */
  def softmax[In <: DType, Out <: DType](input: Tensor[In], dim: Long)(
      dtype: Out = input.dtype
  ): Tensor[Out] =
    val nativeDType =
      if dtype == input.dtype then ScalarTypeOptional() else ScalarTypeOptional(dtype.toScalarType)
    Tensor(torchNative.softmax(input.native, dim, nativeDType))
}
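For reference, a minimal usage sketch of these ops, assuming torch.rand and the Tensor API from the rest of this library (names and shapes here are illustrative, not part of the commit):

import torch.*

// 2x3 tensor of random Float32 values in [0, 1)
val x = torch.rand(Seq(2, 3))

val probs     = nn.functional.softmax(x, dim = 1)()    // each row sums to 1
val logProbs  = nn.functional.logSoftmax(x, dim = 1)() // stable log(softmax(x))
val rectified = nn.functional.relu(x)                  // negatives clamped to 0
val squashed  = nn.functional.sigmoid(x)               // values mapped into (0, 1)

The empty second argument list keeps the output dtype equal to the input dtype; passing e.g. (dtype = float64) would upcast the result instead.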
core/src/main/scala/torch/nn/functional/Convolution.scala (new file, 189 additions, 0 deletions)
/*
 * Copyright 2022 storch.dev
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package torch
package nn
package functional

import org.bytedeco.pytorch.global.torch as torchNative
import torch.internal.NativeConverters.*

private[torch] trait Convolution {

  /** Applies a 1D convolution over an input signal composed of several input planes.
    *
    * @group nn_conv
    */
  def conv1d[D <: FloatNN | ComplexNN](
      input: Tensor[D],
      weight: Tensor[D],
      bias: Tensor[D] | Option[Tensor[D]] = None,
      stride: Int = 1,
      padding: Int = 0,
      dilation: Int = 1,
      groups: Int = 1
  ): Tensor[D] =
    Tensor(
      torchNative.conv1d(
        input.native,
        weight.native,
        toOptional(bias),
        Array(stride.toLong),
        Array(padding.toLong),
        Array(dilation.toLong),
        groups
      )
    )

  /** Applies a 2D convolution over an input signal composed of several input planes.
    *
    * @group nn_conv
    */
  def conv2d[D <: FloatNN | ComplexNN](
      input: Tensor[D],
      weight: Tensor[D],
      bias: Tensor[D] | Option[Tensor[D]] = None,
      stride: Int | (Int, Int) = 1,
      padding: Int | (Int, Int) = 0,
      dilation: Int | (Int, Int) = 1,
      groups: Int = 1
  ): Tensor[D] =
    Tensor(
      torchNative.conv2d(
        input.native,
        weight.native,
        toOptional(bias),
        toArray(stride),
        toArray(padding),
        toArray(dilation),
        groups
      )
    )

  /** Applies a 3D convolution over an input image composed of several input planes.
    *
    * @group nn_conv
    */
  def conv3d[D <: FloatNN | ComplexNN](
      input: Tensor[D],
      weight: Tensor[D],
      bias: Tensor[D] | Option[Tensor[D]] = None,
      stride: Int = 1,
      padding: Int = 0,
      dilation: Int = 1,
      groups: Int = 1
  ): Tensor[D] =
    Tensor(
      torchNative.conv3d(
        input.native,
        weight.native,
        toOptional(bias),
        Array(stride.toLong),
        Array(padding.toLong),
        Array(dilation.toLong),
        groups
      )
    )

  /** Applies a 1D transposed convolution operator over an input signal composed of several input
    * planes, sometimes also called “deconvolution”.
    *
    * @group nn_conv
    */
  def convTranspose1d[D <: FloatNN | ComplexNN](
      input: Tensor[D],
      weight: Tensor[D],
      bias: Tensor[D] | Option[Tensor[D]] = None,
      stride: Int | (Int, Int) = 1,
      padding: Int | (Int, Int) = 0,
      outputPadding: Int | (Int, Int) = 0,
      groups: Int = 1,
      dilation: Int | (Int, Int) = 1
  ): Tensor[D] =
    Tensor(
      torchNative.conv_transpose1d(
        input.native,
        weight.native,
        toOptional(bias),
        toArray(stride),
        toArray(padding),
        toArray(outputPadding),
        groups,
        toArray(dilation): _*
      )
    )

  /** Applies a 2D transposed convolution operator over an input image composed of several input
    * planes, sometimes also called “deconvolution”.
    *
    * @group nn_conv
    */
  def convTranspose2d[D <: FloatNN | ComplexNN](
      input: Tensor[D],
      weight: Tensor[D],
      bias: Tensor[D] | Option[Tensor[D]] = None,
      stride: Int | (Int, Int) = 1,
      padding: Int | (Int, Int) = 0,
      outputPadding: Int | (Int, Int) = 0,
      groups: Int = 1,
      dilation: Int | (Int, Int) = 1
  ): Tensor[D] =
    Tensor(
      torchNative.conv_transpose2d(
        input.native,
        weight.native,
        toOptional(bias),
        toArray(stride),
        toArray(padding),
        toArray(outputPadding),
        groups,
        toArray(dilation): _*
      )
    )

  /** Applies a 3D transposed convolution operator over an input image composed of several input
    * planes, sometimes also called “deconvolution”.
    *
    * @group nn_conv
    */
  def convTranspose3d[D <: FloatNN | ComplexNN](
      input: Tensor[D],
      weight: Tensor[D],
      bias: Tensor[D] | Option[Tensor[D]] = None,
      stride: Int | (Int, Int, Int) = 1,
      padding: Int | (Int, Int, Int) = 0,
      outputPadding: Int | (Int, Int, Int) = 0,
      groups: Int = 1,
      dilation: Int | (Int, Int, Int) = 1
  ): Tensor[D] =
    Tensor(
      torchNative.conv_transpose3d(
        input.native,
        weight.native,
        toOptional(bias),
        toArray(stride),
        toArray(padding),
        toArray(outputPadding),
        groups,
        toArray(dilation): _*
      )
    )

  // TODO unfold
  // TODO fold
}
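A quick usage sketch for conv2d under the usual PyTorch shape conventions (input is NCHW, weight is [outChannels, inChannels / groups, kH, kW]); torch.rand is assumed from the rest of the library and the numbers are illustrative:

import torch.*

val input  = torch.rand(Seq(1, 3, 8, 8)) // batch = 1, 3 channels, 8x8 image
val weight = torch.rand(Seq(6, 3, 3, 3)) // 6 output channels, 3x3 kernels
val bias   = torch.rand(Seq(6))

// padding = 1 with a 3x3 kernel preserves spatial size: shape [1, 6, 8, 8]
val out = nn.functional.conv2d(input, weight, bias, stride = 1, padding = 1)

convTranspose2d inverts the spatial mapping of conv2d with the same parameters, which is why it is the usual choice for learned upsampling.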
core/src/main/scala/torch/nn/functional/Dropout.scala (new file, 42 lines)
/*
 * Copyright 2022 storch.dev
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package torch
package nn
package functional

import org.bytedeco.pytorch.global.torch as torchNative

private[torch] trait Dropout {

  /** During training, randomly zeroes some of the elements of the input tensor with probability
    * `p`, using samples from a Bernoulli distribution.
    *
    * @see
    *   [[torch.nn.Dropout]] for details.
    *
    * @group nn_dropout
    */
  def dropout[D <: DType](input: Tensor[D], p: Double = 0.5, training: Boolean = true): Tensor[D] =
    Tensor(torchNative.dropout(input.native, p, training))

  // TODO alpha_dropout: applies alpha dropout to the input.
  // TODO feature_alpha_dropout: randomly masks out entire channels (a channel is a feature map).
  // TODO dropout1d: randomly zeroes out entire channels (a channel is a 1D feature map, e.g. the j-th channel of the i-th sample in the batched input is a 1D tensor input[i, j]).
  // TODO dropout2d: randomly zeroes out entire channels (a channel is a 2D feature map, e.g. the j-th channel of the i-th sample in the batched input is a 2D tensor input[i, j]).
  // TODO dropout3d: randomly zeroes out entire channels (a channel is a 3D feature map, e.g. the j-th channel of the i-th sample in the batched input is a 3D tensor input[i, j]).
}
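A short sketch of the train/eval behaviour (torch.ones assumed from the rest of the library): during training each element survives with probability 1 - p and survivors are scaled by 1 / (1 - p), so the expected value of the output matches the input; with training = false the op is the identity.

import torch.*

val x = torch.ones(Seq(2, 4))

// roughly half the entries become 0, the survivors become 2.0
val noisy = nn.functional.dropout(x, p = 0.5, training = true)

// inference mode: returns the input unchanged
val same = nn.functional.dropout(x, p = 0.5, training = false)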