diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala index 0e3a4e11..74834053 100644 --- a/core/src/main/scala/torch/internal/NativeConverters.scala +++ b/core/src/main/scala/torch/internal/NativeConverters.scala @@ -29,7 +29,7 @@ import org.bytedeco.pytorch.{ } import scala.reflect.Typeable -import org.bytedeco.javacpp.LongPointer +import org.bytedeco.javacpp.{LongPointer, DoublePointer} import org.bytedeco.pytorch.GenericDict import org.bytedeco.pytorch.GenericDictIterator import spire.math.Complex @@ -76,6 +76,9 @@ private[torch] object NativeConverters: case (h, w) => LongPointer(Array(h.toLong, w.toLong)*) case (t, h, w) => LongPointer(Array(t.toLong, h.toLong, w.toLong)*) + given doubleToDoublePointer: Conversion[Double, DoublePointer] = (input: Double) => + DoublePointer(Array(input)*) + extension (x: ScalaType) def toScalar: pytorch.Scalar = x match case x: Boolean => pytorch.AbstractTensor.create(x).item() diff --git a/core/src/main/scala/torch/nn/modules/Module.scala b/core/src/main/scala/torch/nn/modules/Module.scala index 04ad69c6..d6b3915b 100644 --- a/core/src/main/scala/torch/nn/modules/Module.scala +++ b/core/src/main/scala/torch/nn/modules/Module.scala @@ -121,3 +121,7 @@ trait HasWeight[ParamType <: FloatNN | ComplexNN]: /** Transforms a single tensor into another one of the same type. */ trait TensorModule[D <: DType] extends Module with (Tensor[D] => Tensor[D]): override def toString(): String = "TensorModule" + +trait TensorModuleBase[D <: DType, D2 <: DType] extends Module with (Tensor[D] => Tensor[D2]) { + override def toString() = "TensorModuleBase" +} diff --git a/core/src/main/scala/torch/nn/modules/activation/LogSoftmax.scala b/core/src/main/scala/torch/nn/modules/activation/LogSoftmax.scala new file mode 100644 index 00000000..e2f56267 --- /dev/null +++ b/core/src/main/scala/torch/nn/modules/activation/LogSoftmax.scala @@ -0,0 +1,44 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch +package nn +package modules +package activation + +import org.bytedeco.pytorch +import org.bytedeco.pytorch.LogSoftmaxImpl +import torch.nn.modules.Module +import torch.{DType, Tensor} + +/** Applies the log(Softmax(x)) function to an n-dimensional input Tensor. 
The LogSoftmax + * formulation can be simplified as: + * + * $$\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i)}{\sum_j \exp(x_j)}\right)$$ + * + * Example: + * + * ```scala sc + * import torch.* + * val m = nn.LogSoftmax(dim = 1) + * val input = torch.randn(Seq(2, 3)) + * val output = m(input) + * ``` + */ +final class LogSoftmax[D <: DType: Default](dim: Int) extends TensorModule[D]: + override val nativeModule: LogSoftmaxImpl = LogSoftmaxImpl(dim) + + def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/activation/Tanh.scala b/core/src/main/scala/torch/nn/modules/activation/Tanh.scala new file mode 100644 index 00000000..0257b1dc --- /dev/null +++ b/core/src/main/scala/torch/nn/modules/activation/Tanh.scala @@ -0,0 +1,46 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch +package nn +package modules +package activation + +import org.bytedeco.pytorch +import org.bytedeco.pytorch.TanhImpl +import torch.nn.modules.Module +import torch.{DType, Tensor} + +/** Applies the Hyperbolic Tangent (Tanh) function element-wise. Tanh is defined as: + * + * $$\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}$$ + * + * Example: + * + * ```scala sc + * import torch.* + * val m = nn.Tanh() + * val input = torch.randn(Seq(2)) + * val output = m(input) + * ``` + */ +final class Tanh[D <: DType: Default] extends TensorModule[D]: + + override protected[torch] val nativeModule: TanhImpl = new TanhImpl() + + def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) + + override def toString = getClass().getSimpleName() diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala new file mode 100644 index 00000000..28284b04 --- /dev/null +++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala @@ -0,0 +1,123 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch +package nn +package modules +package batchnorm + +import org.bytedeco.javacpp.LongPointer +import org.bytedeco.pytorch +import sourcecode.Name +import org.bytedeco.pytorch.BatchNorm1dImpl +import org.bytedeco.pytorch.BatchNormOptions +import torch.nn.modules.{HasParams, HasWeight, TensorModule} + +/** Applies Batch Normalization over a 2D or 3D input as described in the paper [Batch + * Normalization: Accelerating Deep Network Training by Reducing Internal Covariate + * Shift](https://arxiv.org/abs/1502.03167) .
+ * + * $$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$ + * + * The mean and standard-deviation are calculated per-dimension over the mini-batches and $\gamma$ + * and $\beta$ are learnable parameter vectors of size [C] (where [C] is the number of features or + * channels of the input). By default, the elements of $\gamma$ are set to 1 and the elements of + * $\beta$ are set to 0. The standard-deviation is calculated via the biased estimator, equivalent + * to *[torch.var(input, unbiased=False)]*. + * + * Also by default, during training this layer keeps running estimates of its computed mean and + * variance, which are then used for normalization during evaluation. The running estimates are + * kept with a default `momentum` of 0.1. + * + * If `trackRunningStats` is set to `false`, this layer then does not keep running estimates, and + * batch statistics are instead used during evaluation time as well. + * + * Example: + * + * ```scala sc + * import torch.nn + * // With Learnable Parameters + * var m = nn.BatchNorm1d(numFeatures = 100) + * // Without Learnable Parameters + * m = nn.BatchNorm1d(100, affine = false) + * val input = torch.randn(Seq(20, 100)) + * val output = m(input) + * ``` + * + * @note + * This `momentum` argument is different from one used in optimizer classes and the conventional + * notion of momentum. Mathematically, the update rule for running statistics here is + * $\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$, + * where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value. + * + * Because the Batch Normalization is done over the [C] dimension, computing statistics on [(N, L)] + * slices, it's common terminology to call this Temporal Batch Normalization. + * + * @param numFeatures + * number of features or channels $C$ of the input + * @param eps + * a value added to the denominator for numerical stability. Default: 1e-5 + * @param momentum + * the value used for the runningMean and runningVar computation. Can be set to `None` for + * cumulative moving average (i.e. simple average). Default: 0.1 + * @param affine + * a boolean value that when set to `true`, this module has learnable affine parameters. Default: + * `true` + * @param trackRunningStats + * a boolean value that when set to `true`, this module tracks the running mean and variance, and + * when set to `false`, this module does not track such statistics, and initializes statistics + * buffers `runningMean` and `runningVar` as `None`. When these buffers are `None`, this module + * always uses batch statistics in both training and eval modes.
Default: `true` + * + * Shape: + * + * - Input: $(N, C)$ or $(N, C, L)$, where $N$ is the batch size, $C$ is the number of features + * or channels, and $L$ is the sequence length + * - Output: $(N, C)$ or $(N, C, L)$ (same shape as input) + * + * @group nn_conv + * + * TODO use dtype + */ +final class BatchNorm1d[ParamType <: FloatNN | ComplexNN: Default]( + numFeatures: Int, + eps: Double = 1e-05, + momentum: Double = 0.1, + affine: Boolean = true, + trackRunningStats: Boolean = true +) extends HasParams[ParamType] + with HasWeight[ParamType] + with TensorModule[ParamType]: + + private val options = new BatchNormOptions(numFeatures) + options.eps().put(eps) + options.momentum().put(momentum) + options.affine().put(affine) + options.track_running_stats().put(trackRunningStats) + + override private[torch] val nativeModule: BatchNorm1dImpl = BatchNorm1dImpl(options) + nativeModule.to(paramType.toScalarType, false) + + // TODO weight, bias etc. are undefined if affine = false. We need to take that into account + val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight) + val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias) + // TODO running_mean, running_var, num_batches_tracked + + def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native)) + + override def toString(): String = s"${getClass().getSimpleName()}(numFeatures=$numFeatures)" diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala index e106056f..e08bb872 100644 --- a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala +++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala @@ -26,82 +26,71 @@ import org.bytedeco.pytorch.BatchNorm2dImpl import org.bytedeco.pytorch.BatchNormOptions import torch.nn.modules.{HasParams, HasWeight} - -// format: off -/** Applies Batch Normalization over a 2D or 3D input as described in the paper -[Batch Normalization: Accelerating Deep Network Training by Reducing -Internal Covariate Shift](https://arxiv.org/abs/1502.03167) . - -$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$ - -The mean and standard-deviation are calculated per-dimension over -the mini-batches and $\gamma$ and $\beta$ are learnable parameter vectors -of size [C]{.title-ref} (where [C]{.title-ref} is the number of features or channels of the input). By default, the -elements of $\gamma$ are set to 1 and the elements of $\beta$ are set to 0. The -standard-deviation is calculated via the biased estimator, equivalent to [torch.var(input, unbiased=False)]{.title-ref}. - -Also by default, during training this layer keeps running estimates of its -computed mean and variance, which are then used for normalization during -evaluation. The running estimates are kept with a default `momentum`{.interpreted-text role="attr"} -of 0.1. - -If `track_running_stats`{.interpreted-text role="attr"} is set to `False`, this layer then does not -keep running estimates, and batch statistics are instead used during -evaluation time as well. - -::: note -::: title -Note -::: - -This `momentum`{.interpreted-text role="attr"} argument is different from one used in optimizer -classes and the conventional notion of momentum. 
Mathematically, the -update rule for running statistics here is -$\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$, -where $\hat{x}$ is the estimated statistic and $x_t$ is the -new observed value. -::: - -Because the Batch Normalization is done over the [C]{.title-ref} dimension, computing statistics -on [(N, L)]{.title-ref} slices, it\'s common terminology to call this Temporal Batch Normalization. - -Args: - -: num_features: number of features or channels $C$ of the input - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Can be set to `None` for cumulative moving average - (i.e. simple average). Default: 0.1 - affine: a boolean value that when set to `True`, this module has - learnable affine parameters. Default: `True` - track_running_stats: a boolean value that when set to `True`, this - module tracks the running mean and variance, and when set to `False`, - this module does not track such statistics, and initializes statistics - buffers `running_mean`{.interpreted-text role="attr"} and `running_var`{.interpreted-text role="attr"} as `None`. - When these buffers are `None`, this module always uses batch statistics. - in both training and eval modes. Default: `True` - -Shape: - -: - Input: $(N, C)$ or $(N, C, L)$, where $N$ is the batch size, - $C$ is the number of features or channels, and $L$ is the sequence length - - Output: $(N, C)$ or $(N, C, L)$ (same shape as input) - -Examples: - - >>> # With Learnable Parameters - >>> m = nn.BatchNorm1d(100) - >>> # Without Learnable Parameters - >>> m = nn.BatchNorm1d(100, affine=False) - >>> input = torch.randn(20, 100) - >>> output = m(input) +/** Applies Batch Normalization over a 4D input as described in the paper [Batch Normalization: + * Accelerating Deep Network Training by Reducing Internal Covariate + * Shift](https://arxiv.org/abs/1502.03167) . + * + * $$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$ + * + * The mean and standard-deviation are calculated per-dimension over the mini-batches and $\gamma$ + * and $\beta$ are learnable parameter vectors of size [C] (where [C] is the number of features or + * channels of the input). By default, the elements of $\gamma$ are set to 1 and the elements of + * $\beta$ are set to 0. The standard-deviation is calculated via the biased estimator, equivalent + * to *[torch.var(input, unbiased=False)]*. + * + * Also by default, during training this layer keeps running estimates of its computed mean and + * variance, which are then used for normalization during evaluation. The running estimates are + * kept with a default `momentum` of 0.1. + * + * If `trackRunningStats` is set to `false`, this layer then does not keep running estimates, and + * batch statistics are instead used during evaluation time as well. + * + * Example: + * + * ```scala sc + * import torch.nn + * // With Learnable Parameters + * var m = nn.BatchNorm2d(numFeatures = 100) + * // Without Learnable Parameters + * m = nn.BatchNorm2d(100, affine = false) + * val input = torch.randn(Seq(20, 100, 35, 45)) + * val output = m(input) + * ``` + * + * @note + * This `momentum` argument is different from one used in optimizer classes and the conventional + * notion of momentum. 
Mathematically, the update rule for running statistics here is + * $\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$, + * where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value. + * + * Because the Batch Normalization is done over the C dimension, computing statistics on (N, H, W) + * slices, it's common terminology to call this Spatial Batch Normalization. + * + * @param numFeatures + * number of features or channels $C$ of the input + * @param eps + * a value added to the denominator for numerical stability. Default: 1e-5 + * @param momentum + * the value used for the runningMean and runningVar computation. Can be set to `None` for + * cumulative moving average (i.e. simple average). Default: 0.1 + * @param affine + * a boolean value that when set to `true`, this module has learnable affine parameters. Default: + * `true` + * @param trackRunningStats + * a boolean value that when set to `true`, this module tracks the running mean and variance, and + * when set to `false`, this module does not track such statistics, and initializes statistics + * buffers `runningMean` and `runningVar` as `None`. When these buffers are `None`, this module + * always uses batch statistics in both training and eval modes. Default: `true` + * + * Shape: + * + * - Input: $(N, C, H, W)$ + * - Output: $(N, C, H, W)$ (same shape as input) * * @group nn_conv - * + * * TODO use dtype */ -// format: on final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default]( numFeatures: Int, eps: Double = 1e-05, diff --git a/core/src/main/scala/torch/nn/modules/container/ModuleList.scala b/core/src/main/scala/torch/nn/modules/container/ModuleList.scala new file mode 100644 index 00000000..1ecc75ca --- /dev/null +++ b/core/src/main/scala/torch/nn/modules/container/ModuleList.scala @@ -0,0 +1,38 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch +package nn +package modules +package container + +import sourcecode.Name +import scala.util.Random + +final class ModuleList[D <: DType](override val modules: TensorModule[D]*) + extends Module + with Seq[TensorModule[D]]: + modules.zipWithIndex.foreach((module, index) => + this.register(module)(using Name(index.toString())) + ) + + override def iterator: Iterator[TensorModule[D]] = modules.iterator + + def apply(i: Int): torch.nn.modules.TensorModule[D] = modules(i) + + override def length: Int = modules.length + + override def toString = getClass().getSimpleName() diff --git a/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala new file mode 100644 index 00000000..17659aa5 --- /dev/null +++ b/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala @@ -0,0 +1,50 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch +package nn +package modules +package normalization + +import org.bytedeco.pytorch +import org.bytedeco.pytorch.{LayerNormImpl, LayerNormOptions, LongVector} +import torch.nn.modules.TensorModule +import torch.{DType, Tensor} + +/** Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer + * Normalization // TODO Add docs + */ +final class LayerNorm[ParamType <: DType: Default]( + normalizedShape: Seq[Int] | Int, + eps: Double = 1e-05, + elementwiseAffine: Boolean = true +) extends TensorModule[ParamType]: + private val options: LayerNormOptions = + val shape = normalizedShape match + case normalizedShape: Seq[Int] => normalizedShape.toArray.map(_.toLong) + case normalizedShape: Int => Array(normalizedShape.toLong) + LayerNormOptions(LongVector(shape*)) + + options.eps().put(eps) + options.elementwise_affine().put(elementwiseAffine) + + override private[torch] val nativeModule: LayerNormImpl = LayerNormImpl(options) + + val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight) + val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias) + + def apply(t: Tensor[ParamType]): Tensor[ParamType] = + Tensor[ParamType](nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala b/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala new file mode 100644 index 00000000..ae75c65f --- /dev/null +++ b/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala @@ -0,0 +1,85 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch +package nn +package modules +package sparse + +import org.bytedeco.javacpp.LongPointer +import org.bytedeco.pytorch +import sourcecode.Name +import org.bytedeco.pytorch.EmbeddingImpl +import org.bytedeco.pytorch.EmbeddingOptions +import torch.nn.modules.{HasParams, HasWeight, TensorModule} +import torch.internal.NativeConverters.{toNative, doubleToDoublePointer} + +/** A simple lookup table that stores embeddings of a fixed dictionary and size. + * + * This module is often used to store word embeddings and retrieve them using indices. The input to + * the module is a list of indices, and the output is the corresponding word embeddings. 
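+ * + * Example (a usage sketch mirroring the `EmbeddingSuite` test added in this change; the exact values in `output` depend on the random seed): + * + * ```scala sc + * import torch.* + * // an Embedding module containing 10 embeddings of dimension 3 + * val embedding = nn.Embedding(10, 3) + * // a batch of 2 samples of 4 indices each + * val input = torch.Tensor(Seq(Seq(1L, 2, 4, 5), Seq(4L, 3, 2, 9))) + * val output = embedding(input) // shape (2, 4, 3) + * ```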
+ * + * @group nn_sparse + * + * @param numEmbeddings + * Size of the dictionary of embeddings + * @param embeddingDim + * The size of each embedding vector + * @param paddingIdx + * If specified, the entries at `paddingIdx` do not contribute to the gradient; therefore, the + * embedding vector at `paddingIdx` is not updated during training, i.e. it remains as a fixed + * "pad". For a newly constructed Embedding, the embedding vector at `paddingIdx` will default to + * all zeros, but can be updated to another value to be used as the padding vector. + * @param maxNorm + * If given, each embedding vector with norm larger than `maxNorm` is renormalized to have norm + * `maxNorm`. + * @param normType + * The p of the p-norm to compute for the `maxNorm` option. Default: `2`. + * @param scaleGradByFreq + * If `true`, this will scale gradients by the inverse of frequency of the words in the + * mini-batch. Default: `false`. + * @param sparse + * If `true`, the gradient w.r.t. the `weight` matrix will be a sparse tensor. See the notes in the + * PyTorch documentation for more details regarding sparse gradients. + */ +final class Embedding[ParamType <: FloatNN | ComplexNN: Default]( + numEmbeddings: Int, + embeddingDim: Int, + paddingIdx: Option[Int] = None, + maxNorm: Option[Double] = None, + normType: Option[Double] = Some(2.0), + scaleGradByFreq: Boolean = false, + sparse: Boolean = false +) extends HasParams[ParamType] + with HasWeight[ParamType] + with TensorModuleBase[Int64, ParamType]: + + private val options = new EmbeddingOptions(numEmbeddings.toLong, embeddingDim.toLong) + paddingIdx.foreach(p => options.padding_idx().put(toNative(p))) + maxNorm.foreach(m => options.max_norm().put(m)) + normType.foreach(n => options.norm_type().put(n)) + options.scale_grad_by_freq().put(scaleGradByFreq) + options.sparse().put(sparse) + + override val nativeModule: EmbeddingImpl = EmbeddingImpl(options) + nativeModule.to(paramType.toScalarType, false) + + def weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight) + def weight_=(w: Tensor[ParamType]): Unit = nativeModule.weight(w.native) + + def apply(t: Tensor[Int64]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native)) + + override def toString(): String = s"${getClass().getSimpleName()}(numEmbeddings=$numEmbeddings)" diff --git a/core/src/main/scala/torch/nn/package.scala b/core/src/main/scala/torch/nn/package.scala index ded65ebb..48b7b367 100644 --- a/core/src/main/scala/torch/nn/package.scala +++ b/core/src/main/scala/torch/nn/package.scala @@ -27,14 +27,22 @@ package object nn { export modules.Module export modules.activation.Softmax + export modules.activation.LogSoftmax export modules.activation.ReLU + export modules.activation.Tanh + export modules.batchnorm.BatchNorm1d export modules.batchnorm.BatchNorm2d export modules.container.Sequential + export modules.container.ModuleList export modules.conv.Conv2d export modules.flatten.Flatten export modules.linear.Linear export modules.linear.Identity export modules.normalization.GroupNorm + export modules.normalization.LayerNorm export modules.pooling.AdaptiveAvgPool2d export modules.pooling.MaxPool2d + export modules.sparse.Embedding + + export loss.CrossEntropyLoss } diff --git a/core/src/test/scala/torch/nn/modules/ActivationSuite.scala b/core/src/test/scala/torch/nn/modules/ActivationSuite.scala new file mode 100644 index 00000000..3b48b511 --- /dev/null +++ b/core/src/test/scala/torch/nn/modules/ActivationSuite.scala @@ -0,0 +1,50 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch +package nn +package modules + +class ActivationSuite extends munit.FunSuite { + test("LogSoftmax") { + torch.manualSeed(0) + val m = nn.LogSoftmax(dim = 1) + val input = torch.randn(Seq(2, 3)) + val output = m(input) + assertEquals(output.shape, input.shape) + val expectedOutput = Tensor( + Seq( + Seq(-0.1689f, -2.0033f, -3.8886f), + Seq(-0.2862f, -1.9392f, -2.2532f) + ) + ) + assert(torch.allclose(output, expectedOutput, atol = 1e-4)) + } + + // TODO ReLU + // TODO Softmax + + test("Tanh") { + torch.manualSeed(0) + val m = nn.Tanh() + val input = torch.randn(Seq(2)) + val output = m(input) + assertEquals(output.shape, input.shape) + val expectedOutput = Tensor(Seq(0.9123f, -0.2853f)) + assert(torch.allclose(output, expectedOutput, atol = 1e-4)) + } + +} diff --git a/core/src/test/scala/torch/nn/modules/BatchNormSuite.scala b/core/src/test/scala/torch/nn/modules/BatchNormSuite.scala new file mode 100644 index 00000000..2f2f7408 --- /dev/null +++ b/core/src/test/scala/torch/nn/modules/BatchNormSuite.scala @@ -0,0 +1,54 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package torch +package nn +package modules + +class BatchNormSuite extends munit.FunSuite { + + test("BatchNorm1d") { + torch.manualSeed(0) + val m = nn.BatchNorm1d(numFeatures = 3) + val input = torch.randn(Seq(3, 3)) + val output = m(input) + assertEquals(output.shape, input.shape) + val expectedOutput = Tensor( + Seq( + Seq(1.4014f, -0.1438f, -1.2519f), + Seq(-0.5362f, -1.1465f, 0.0564f), + Seq(-0.8651f, 1.2903f, 1.1956f) + ) + ) + assert(torch.allclose(output, expectedOutput, atol = 1e-4)) + } + + test("BatchNorm2d") { + torch.manualSeed(0) + val m = nn.BatchNorm2d(numFeatures = 3) + val input = torch.randn(Seq(3, 3, 1, 1)) + val output = m(input) + assertEquals(output.shape, input.shape) + val expectedOutput = Tensor( + Seq( + Seq(1.4014f, -0.1438f, -1.2519f), + Seq(-0.5362f, -1.1465f, 0.0564f), + Seq(-0.8651f, 1.2903f, 1.1956f) + ) + ) + assert(torch.allclose(output.squeeze, expectedOutput, atol = 1e-4)) + } +} diff --git a/core/src/test/scala/torch/nn/modules/EmbeddingSuite.scala b/core/src/test/scala/torch/nn/modules/EmbeddingSuite.scala new file mode 100644 index 00000000..5ef332eb --- /dev/null +++ b/core/src/test/scala/torch/nn/modules/EmbeddingSuite.scala @@ -0,0 +1,92 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package torch +package nn +package modules + +class EmbeddingSuite extends munit.FunSuite { + + test("Embedding") { + { + torch.manualSeed(0) + val embedding = nn.Embedding(10, 3) + // a batch of 2 samples of 4 indices each + val input = torch.Tensor(Seq(Seq(1L, 2, 4, 5), Seq(4L, 3, 2, 9))) + val output = embedding(input) + val expectedOutput = Tensor( + Seq( + Seq( + Seq(-0.4339f, 0.8487f, 0.6920f), + Seq(-0.3160f, -2.1152f, 0.3223f), + Seq(0.1198f, 1.2377f, -0.1435f), + Seq(-0.1116f, -0.6136f, 0.0316f) + ), + Seq( + Seq(0.1198f, 1.2377f, -0.1435f), + Seq(-1.2633f, 0.3500f, 0.3081f), + Seq(-0.3160f, -2.1152f, 0.3223f), + Seq(0.0525f, 0.5229f, 2.3022f) + ) + ) + ) + assert(torch.allclose(output, expectedOutput, atol = 1e-4)) + } + { + torch.manualSeed(0) + // example with padding_idx + val embedding = nn.Embedding(5, 3, paddingIdx = Some(0)) + embedding.weight = Tensor[Float]( + Seq( + Seq(0f, 0f, 0f), + Seq(0.5684f, -1.0845f, -1.3986f), + Seq(0.4033f, 0.8380f, -0.7193f), + Seq(0.4033f, 0.8380f, -0.7193f), + Seq(-0.8567f, 1.1006f, -1.0712f) + ) + ) + val input = torch.Tensor(Seq(Seq(0L, 2, 0, 4))) + val output = embedding(input) + + val expectedOutput = Tensor( + Seq( + Seq(0f, 0f, 0f), + Seq(0.4033f, 0.8380f, -0.7193f), + Seq(0f, 0f, 0f), + Seq(-0.8567f, 1.1006f, -1.0712f) + ) + ).unsqueeze(0) + assert(torch.allclose(output, expectedOutput, atol = 1e-4)) + } + { + torch.manualSeed(0) + // example of changing `pad` vector + val paddingIdx = 0 + val embedding = nn.Embedding(3, 3, paddingIdx = Some(paddingIdx)) + noGrad { + embedding.weight(Seq(paddingIdx)) = torch.ones(3) + } + val expectedOutput = Tensor( + Seq( + Seq(1f, 1f, 1f), + Seq(0.5684f, -1.0845f, -1.3986f), + Seq(0.4033f, 0.8380f, -0.7193f) + ) + ) + assert(torch.allclose(embedding.weight, expectedOutput, atol = 1e-4)) + } + } +} diff --git a/core/src/test/scala/torch/nn/modules/flatten/FlattenSuite.scala b/core/src/test/scala/torch/nn/modules/FlattenSuite.scala similarity index 98% rename from core/src/test/scala/torch/nn/modules/flatten/FlattenSuite.scala rename to core/src/test/scala/torch/nn/modules/FlattenSuite.scala index 19a547e0..788f507c 100644 --- a/core/src/test/scala/torch/nn/modules/flatten/FlattenSuite.scala +++ b/core/src/test/scala/torch/nn/modules/FlattenSuite.scala @@ -17,7 +17,6 @@ package torch package nn package modules -package flatten class FlattenSuite extends munit.FunSuite { test("Flatten") { diff --git a/core/src/test/scala/torch/nn/modules/linear/LinearSuite.scala b/core/src/test/scala/torch/nn/modules/LinearSuite.scala similarity index 98% rename from core/src/test/scala/torch/nn/modules/linear/LinearSuite.scala rename to core/src/test/scala/torch/nn/modules/LinearSuite.scala index a611ef6f..4a4decd7 100644 --- a/core/src/test/scala/torch/nn/modules/linear/LinearSuite.scala +++ b/core/src/test/scala/torch/nn/modules/LinearSuite.scala @@ -17,7 +17,6 @@ package torch package nn package modules -package linear class LinearSuite extends munit.FunSuite { test("Linear shape") { diff --git a/core/src/test/scala/torch/nn/modules/NormalizationSuite.scala b/core/src/test/scala/torch/nn/modules/NormalizationSuite.scala new file mode 100644 index 00000000..2c5d90ce --- /dev/null +++ b/core/src/test/scala/torch/nn/modules/NormalizationSuite.scala @@ -0,0 +1,69 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch +package nn +package modules + +class NormalizationSuite extends munit.FunSuite { + + test("LayerNorm") { + { + torch.manualSeed(0) + val (batch, sentenceLength, embeddingDim) = (2, 2, 3) + val embedding = torch.randn(Seq(batch, sentenceLength, embeddingDim)) + val layerNorm = nn.LayerNorm(embeddingDim) + val output = layerNorm(embedding) + assertEquals(output.shape, embedding.shape) + val expectedOutput = Tensor( + Seq( + Seq( + Seq(1.2191f, 0.0112f, -1.2303f), + Seq(1.3985f, -0.5172f, -0.8813f) + ), + Seq( + Seq(0.3495f, 1.0120f, -1.3615f), + Seq(-0.3948f, -0.9786f, 1.3734f) + ) + ) + ) + assert(torch.allclose(output, expectedOutput, atol = 1e-4)) + } + { + torch.manualSeed(0) + val (n, c, h, w) = (1, 2, 2, 2) + val input = torch.randn(Seq(n, c, h, w)) + // Normalize over the last three dimensions (i.e. the channel and spatial dimensions) + val layerNorm = nn.LayerNorm(Seq(c, h, w)) + val output = layerNorm(input) + assertEquals(output.shape, (Seq(n, c, h, w))) + val expectedOutput = Tensor( + Seq( + Seq( + Seq(1.4715f, -0.0785f), + Seq(-1.6714f, 0.6497f) + ), + Seq( + Seq(-0.7469f, -1.0122f), + Seq(0.5103f, 0.8775f) + ) + ) + ).unsqueeze(0) + assert(torch.allclose(output, expectedOutput, atol = 1e-4)) + } + } + +} diff --git a/core/src/test/scala/torch/nn/modules/pooling/AdapativeAvgPoolSuite.scala b/core/src/test/scala/torch/nn/modules/PoolingSuite.scala similarity index 98% rename from core/src/test/scala/torch/nn/modules/pooling/AdapativeAvgPoolSuite.scala rename to core/src/test/scala/torch/nn/modules/PoolingSuite.scala index 2e18298f..4b75fd4b 100644 --- a/core/src/test/scala/torch/nn/modules/pooling/AdapativeAvgPoolSuite.scala +++ b/core/src/test/scala/torch/nn/modules/PoolingSuite.scala @@ -17,7 +17,6 @@ package torch package nn package modules -package pooling class AdapativeAvgPool2dSuite extends munit.FunSuite { test("AdapativeAvgPool2d output shapes") {