Add more torch.nn modules #36

Merged
merged 8 commits on Sep 20, 2023
5 changes: 4 additions & 1 deletion core/src/main/scala/torch/internal/NativeConverters.scala
@@ -29,7 +29,7 @@ import org.bytedeco.pytorch.{
}

import scala.reflect.Typeable
import org.bytedeco.javacpp.LongPointer
import org.bytedeco.javacpp.{LongPointer, DoublePointer}
import org.bytedeco.pytorch.GenericDict
import org.bytedeco.pytorch.GenericDictIterator
import spire.math.Complex
@@ -76,6 +76,9 @@ private[torch] object NativeConverters:
case (h, w) => LongPointer(Array(h.toLong, w.toLong)*)
case (t, h, w) => LongPointer(Array(t.toLong, h.toLong, w.toLong)*)

given doubleToDoublePointer: Conversion[Double, DoublePointer] = (input: Double) =>
DoublePointer(Array(input)*)

extension (x: ScalaType)
def toScalar: pytorch.Scalar = x match
case x: Boolean => pytorch.AbstractTensor.create(x).item()
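A quick, hypothetical usage sketch of the new `doubleToDoublePointer` given (not part of the diff; the `readFirst` helper below is made up for illustration): with the conversion in scope, a plain `Double` can be passed wherever the JavaCPP API expects a `DoublePointer`.

```scala
import scala.language.implicitConversions
import org.bytedeco.javacpp.DoublePointer

// Same shape as the conversion added in this PR: wrap a Double in a one-element DoublePointer.
given Conversion[Double, DoublePointer] = (input: Double) => DoublePointer(Array(input)*)

// Hypothetical consumer of a DoublePointer, for illustration only.
def readFirst(p: DoublePointer): Double = p.get(0)

@main def doublePointerDemo(): Unit =
  println(readFirst(1e-5)) // the Double is implicitly wrapped before the call
```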
4 changes: 4 additions & 0 deletions core/src/main/scala/torch/nn/modules/Module.scala
@@ -121,3 +121,7 @@ trait HasWeight[ParamType <: FloatNN | ComplexNN]:
/** Transforms a single tensor into another one of the same type. */
trait TensorModule[D <: DType] extends Module with (Tensor[D] => Tensor[D]):
override def toString(): String = "TensorModule"

trait TensorModuleBase[D <: DType, D2 <: DType] extends Module with (Tensor[D] => Tensor[D2]) {
override def toString() = "TensorModuleBase"
}
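A brief sketch of what the new `TensorModuleBase` enables (not taken from this PR; the `Int64`/`Float32` dtype types are assumed to exist in the `torch` package): unlike `TensorModule`, the output dtype may differ from the input dtype, which is what index-in, float-out layers such as embeddings need.

```scala
// Illustrative only: the key difference between the two traits is the output dtype.
import torch.{DType, Tensor, Float32, Int64}
import torch.nn.modules.{TensorModule, TensorModuleBase}

// Same dtype in and out (e.g. activations like Tanh):
def applySameDType[D <: DType](m: TensorModule[D], t: Tensor[D]): Tensor[D] = m(t)

// Possibly different dtypes (e.g. an embedding: Int64 indices in, Float32 vectors out):
def applyAcrossDTypes(m: TensorModuleBase[Int64, Float32], t: Tensor[Int64]): Tensor[Float32] = m(t)
```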
44 changes: 44 additions & 0 deletions core/src/main/scala/torch/nn/modules/activation/LogSoftmax.scala
@@ -0,0 +1,44 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package activation

import org.bytedeco.pytorch
import org.bytedeco.pytorch.LogSoftmaxImpl
import torch.nn.modules.Module
import torch.{DType, Tensor}

/** Applies the log(Softmax(x)) function to an n-dimensional input Tensor. The LogSoftmax
* formulation can be simplified as:
*
 * $$\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i)}{\sum_j \exp(x_j)}\right)$$
*
* Example:
*
* ```scala sc
* import torch.*
* val m = nn.LogSoftmax(dim = 1)
* val input = torch.randn(Seq(2, 3))
* val output = m(input)
* ```
*/
final class LogSoftmax[D <: DType: Default](dim: Int) extends TensorModule[D]:
override val nativeModule: LogSoftmaxImpl = LogSoftmaxImpl(dim)

def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
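A small follow-up sketch to the scaladoc example above (it assumes `torch.exp` and a no-argument `Tensor.sum` are available in storch, mirroring PyTorch; treat both as assumptions): exponentiating the log-probabilities recovers a proper probability distribution along `dim`.

```scala
import torch.*

val m = nn.LogSoftmax(dim = 1)
val logProbs = m(torch.randn(Seq(2, 3)))
val probs = torch.exp(logProbs) // undo the log
println(probs.sum) // ≈ 2.0 overall: each of the 2 rows sums to 1 along dim = 1
```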
46 changes: 46 additions & 0 deletions core/src/main/scala/torch/nn/modules/activation/Tanh.scala
@@ -0,0 +1,46 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package activation

import org.bytedeco.pytorch
import org.bytedeco.pytorch.TanhImpl
import torch.nn.modules.Module
import torch.{DType, Tensor}

/** Applies the Hyperbolic Tangent (Tanh) function element-wise. Tanh is defined as:
*
 * $$\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}$$
*
* Example:
*
* ```scala sc
* import torch.*
* val m = nn.Tanh()
* val input = torch.randn(Seq(2))
* val output = m(input)
* ```
*/
final class Tanh[D <: DType: Default] extends TensorModule[D]:

override protected[torch] val nativeModule: TanhImpl = new TanhImpl()

def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))

override def toString = getClass().getSimpleName()
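Since the definition was only sketched in the scaladoc above, here is a tiny plain-Scala numerical check of it (no storch APIs involved):

```scala
// Numerical check of the Tanh definition, in plain Scala (no tensors involved).
val x = 0.5
val viaDefinition = (math.exp(x) - math.exp(-x)) / (math.exp(x) + math.exp(-x))
println(math.tanh(x))   // 0.46211715726000974
println(viaDefinition)  // same value, up to floating-point rounding
```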
123 changes: 123 additions & 0 deletions core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala
@@ -0,0 +1,123 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package batchnorm

import org.bytedeco.javacpp.LongPointer
import org.bytedeco.pytorch
import sourcecode.Name
import org.bytedeco.pytorch.BatchNorm1dImpl
import org.bytedeco.pytorch.BatchNormOptions
import torch.nn.modules.{HasParams, HasWeight, TensorModule}

/** Applies Batch Normalization over a 2D or 3D input as described in the paper [Batch
* Normalization: Accelerating Deep Network Training by Reducing Internal Covariate
* Shift](https://arxiv.org/abs/1502.03167) .
*
* $$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$
*
* The mean and standard-deviation are calculated per-dimension over the mini-batches and $\gamma$
 * and $\beta$ are learnable parameter vectors of size $C$ (where $C$ is the number of features or
 * channels of the input). By default, the elements of $\gamma$ are set to 1 and the elements of
 * $\beta$ are set to 0. The standard-deviation is calculated via the biased estimator, equivalent
 * to `torch.var(input, unbiased=False)`.
*
* Also by default, during training this layer keeps running estimates of its computed mean and
* variance, which are then used for normalization during evaluation. The running estimates are
* kept with a default `momentum` of 0.1.
*
* If `trackRunningStats` is set to `false`, this layer then does not keep running estimates, and
* batch statistics are instead used during evaluation time as well.
*
* Example:
*
* ```scala sc
* import torch.nn
* // With Learnable Parameters
* var m = nn.BatchNorm1d(numFeatures = 100)
* // Without Learnable Parameters
* m = nn.BatchNorm1d(100, affine = false)
* val input = torch.randn(Seq(20, 100))
* val output = m(input)
* ```
*
* @note
* This `momentum` argument is different from one used in optimizer classes and the conventional
* notion of momentum. Mathematically, the update rule for running statistics here is
* $\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,
* where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value.
*
 * Because Batch Normalization is done over the $C$ dimension, computing statistics on $(N, L)$
 * slices, it's common terminology to call this Temporal Batch Normalization.
*
 * @param numFeatures
 *   number of features or channels $C$ of the input
 * @param eps
 *   a value added to the denominator for numerical stability. Default: 1e-5
 * @param momentum
 *   the value used for the runningMean and runningVar computation. Can be set to `None` for
 *   cumulative moving average (i.e. simple average). Default: 0.1
 * @param affine
 *   a boolean value that, when set to `true`, gives this module learnable affine parameters.
 *   Default: `true`
 * @param trackRunningStats
 *   a boolean value that, when set to `true`, makes this module track the running mean and
 *   variance; when set to `false`, this module does not track such statistics and initializes the
 *   statistics buffers `runningMean` and `runningVar` as `None`. When these buffers are `None`,
 *   this module always uses batch statistics in both training and eval modes. Default: `true`
*
* Shape:
*
* - Input: $(N, C)$ or $(N, C, L)$, where $N$ is the batch size, $C$ is the number of features
* or channels, and $L$ is the sequence length
* - Output: $(N, C)$ or $(N, C, L)$ (same shape as input)
*
* @group nn_conv
*
* TODO use dtype
*/
final class BatchNorm1d[ParamType <: FloatNN | ComplexNN: Default](
numFeatures: Int,
eps: Double = 1e-05,
momentum: Double = 0.1,
affine: Boolean = true,
trackRunningStats: Boolean = true
) extends HasParams[ParamType]
with HasWeight[ParamType]
with TensorModule[ParamType]:

private val options = new BatchNormOptions(numFeatures)
options.eps().put(eps)
options.momentum().put(momentum)
options.affine().put(affine)
options.track_running_stats().put(trackRunningStats)

override private[torch] val nativeModule: BatchNorm1dImpl = BatchNorm1dImpl(options)
nativeModule.to(paramType.toScalarType, false)

// TODO weight, bias etc. are undefined if affine = false. We need to take that into account
val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias)
// TODO running_mean, running_var, num_batches_tracked

def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))

override def toString(): String = s"${getClass().getSimpleName()}(numFeatures=$numFeatures)"
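To connect the formula and the `momentum` note above to something concrete, here is a hand-worked sketch in plain Scala for a single feature column (illustrative only; the module does this per feature, on tensors, with learnable $\gamma$ and $\beta$):

```scala
val batch = Seq(1.0, 2.0, 3.0, 4.0) // one feature observed over a mini-batch of 4
val eps = 1e-5
val mean = batch.sum / batch.size
val biasedVar = batch.map(v => math.pow(v - mean, 2)).sum / batch.size // biased, like torch.var(unbiased=False)
val gamma = 1.0 // learnable scale, initialised to 1
val beta = 0.0  // learnable shift, initialised to 0
val normalised = batch.map(v => (v - mean) / math.sqrt(biasedVar + eps) * gamma + beta)
println(normalised) // roughly zero mean, unit variance

// Running-statistics update from the momentum note: new = (1 - momentum) * old + momentum * observed
val momentum = 0.1
val runningMean = (1 - momentum) * 0.0 + momentum * mean // 0.25 after the first batch
```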
133 changes: 61 additions & 72 deletions core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
@@ -26,82 +26,71 @@ import org.bytedeco.pytorch.BatchNorm2dImpl
import org.bytedeco.pytorch.BatchNormOptions
import torch.nn.modules.{HasParams, HasWeight}


// format: off
/** Applies Batch Normalization over a 2D or 3D input as described in the paper
[Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167) .

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$

The mean and standard-deviation are calculated per-dimension over
the mini-batches and $\gamma$ and $\beta$ are learnable parameter vectors
of size [C]{.title-ref} (where [C]{.title-ref} is the number of features or channels of the input). By default, the
elements of $\gamma$ are set to 1 and the elements of $\beta$ are set to 0. The
standard-deviation is calculated via the biased estimator, equivalent to [torch.var(input, unbiased=False)]{.title-ref}.

Also by default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default `momentum`{.interpreted-text role="attr"}
of 0.1.

If `track_running_stats`{.interpreted-text role="attr"} is set to `False`, this layer then does not
keep running estimates, and batch statistics are instead used during
evaluation time as well.

::: note
::: title
Note
:::

This `momentum`{.interpreted-text role="attr"} argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
$\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,
where $\hat{x}$ is the estimated statistic and $x_t$ is the
new observed value.
:::

Because the Batch Normalization is done over the [C]{.title-ref} dimension, computing statistics
on [(N, L)]{.title-ref} slices, it\'s common terminology to call this Temporal Batch Normalization.

Args:

: num_features: number of features or channels $C$ of the input
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to `None` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to `True`, this module has
learnable affine parameters. Default: `True`
track_running_stats: a boolean value that when set to `True`, this
module tracks the running mean and variance, and when set to `False`,
this module does not track such statistics, and initializes statistics
buffers `running_mean`{.interpreted-text role="attr"} and `running_var`{.interpreted-text role="attr"} as `None`.
When these buffers are `None`, this module always uses batch statistics.
in both training and eval modes. Default: `True`

Shape:

: - Input: $(N, C)$ or $(N, C, L)$, where $N$ is the batch size,
$C$ is the number of features or channels, and $L$ is the sequence length
- Output: $(N, C)$ or $(N, C, L)$ (same shape as input)

Examples:

>>> # With Learnable Parameters
>>> m = nn.BatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm1d(100, affine=False)
>>> input = torch.randn(20, 100)
>>> output = m(input)
/** Applies Batch Normalization over a 4D input as described in the paper [Batch Normalization:
* Accelerating Deep Network Training by Reducing Internal Covariate
* Shift](https://arxiv.org/abs/1502.03167) .
*
* $$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$
*
* The mean and standard-deviation are calculated per-dimension over the mini-batches and $\gamma$
 * and $\beta$ are learnable parameter vectors of size $C$ (where $C$ is the number of features or
 * channels of the input). By default, the elements of $\gamma$ are set to 1 and the elements of
 * $\beta$ are set to 0. The standard-deviation is calculated via the biased estimator, equivalent
 * to `torch.var(input, unbiased=False)`.
*
* Also by default, during training this layer keeps running estimates of its computed mean and
* variance, which are then used for normalization during evaluation. The running estimates are
* kept with a default `momentum` of 0.1.
*
* If `trackRunningStats` is set to `false`, this layer then does not keep running estimates, and
* batch statistics are instead used during evaluation time as well.
*
* Example:
*
* ```scala sc
* import torch.nn
* // With Learnable Parameters
* var m = nn.BatchNorm2d(numFeatures = 100)
* // Without Learnable Parameters
* m = nn.BatchNorm2d(100, affine = false)
* val input = torch.randn(Seq(20, 100, 35, 45))
* val output = m(input)
* ```
*
* @note
* This `momentum` argument is different from one used in optimizer classes and the conventional
* notion of momentum. Mathematically, the update rule for running statistics here is
* $\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,
* where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value.
*
 * Because Batch Normalization is done over the $C$ dimension, computing statistics on $(N, H, W)$
 * slices, it's common terminology to call this Spatial Batch Normalization.
*
* @param numFeatures
* number of features or channels $C$ of the input
 * @param eps
 *   a value added to the denominator for numerical stability. Default: 1e-5
 * @param momentum
 *   the value used for the runningMean and runningVar computation. Can be set to `None` for
 *   cumulative moving average (i.e. simple average). Default: 0.1
 * @param affine
 *   a boolean value that, when set to `true`, gives this module learnable affine parameters.
 *   Default: `true`
 * @param trackRunningStats
 *   a boolean value that, when set to `true`, makes this module track the running mean and
 *   variance; when set to `false`, this module does not track such statistics and initializes the
 *   statistics buffers `runningMean` and `runningVar` as `None`. When these buffers are `None`,
 *   this module always uses batch statistics in both training and eval modes. Default: `true`
*
* Shape:
*
* - Input: $(N, C, H, W)$
* - Output: $(N, C, H, W)$ (same shape as input)
*
* @group nn_conv
*
*
* TODO use dtype
*/
// format: on
final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default](
numFeatures: Int,
eps: Double = 1e-05,