diff --git a/core/src/main/scala/torch/nn/modules/container/ModuleList.scala b/core/src/main/scala/torch/nn/modules/container/ModuleList.scala
new file mode 100644
index 00000000..77524b72
--- /dev/null
+++ b/core/src/main/scala/torch/nn/modules/container/ModuleList.scala
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2022 storch.dev
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package torch
+package nn
+package modules
+package container
+
+import sourcecode.Name
+import scala.util.Random
+
+final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
+    extends Module
+    with Seq[TensorModule[D]]:
+  // with TensorModule[D]:
+  modules.zipWithIndex.foreach((module, index) =>
+    this.register(module)(using Name(index.toString()))
+  )
+
+  override def iterator: Iterator[TensorModule[D]] = modules.iterator
+
+  // def apply(v1: Int | torch.Tensor[D]): torch.nn.modules.TensorModule[D] & torch.Tensor[D] = ???
+  def apply(i: Int): torch.nn.modules.TensorModule[D] = modules(i)
+
+  override def length: Int = modules.length
+
+  override def toString = getClass().getSimpleName()
diff --git a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
index fd1bc268..f71fb2af 100644
--- a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
+++ b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
@@ -36,7 +36,7 @@ import torch.{DType, Tensor}
   *   a boolean value that when set to `true`, this module has learnable per-channel affine
   *   parameters initialized to ones (for weights) and zeros (for biases)
   */
-final class GroupNorm[ParamType <: DType](
+final class GroupNorm[ParamType <: DType: Default](
     numGroups: Int,
     numChannels: Int,
     eps: Double = 1e-05,
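Usage sketch (editorial note, not part of the patch): ModuleList registers each submodule under its index and behaves like a Seq[TensorModule[D]], so a variable number of layers can be stored and folded over in a forward pass. The sketch assumes nn.Linear is a TensorModule, that the Default type class resolves to Float32, and that torch.rand(Seq(...)) exists as in the current storch API; shapes are illustrative.

  import torch.*

  // Three dense layers, registered by ModuleList under the names "0", "1" and "2".
  val blocks = nn.ModuleList(
    nn.Linear(16, 16),
    nn.Linear(16, 16),
    nn.Linear(16, 4)
  )

  // ModuleList is a Seq, so ordinary collection operations drive the forward pass.
  val x   = torch.rand(Seq(8, 16))                     // batch of 8 feature vectors
  val out = blocks.foldLeft(x)((h, layer) => layer(h)) // expected shape: Seq(8, 4)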
diff --git a/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala
new file mode 100644
index 00000000..fc189e41
--- /dev/null
+++ b/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2022 storch.dev
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package torch
+package nn
+package modules
+package normalization
+
+import org.bytedeco.pytorch
+import org.bytedeco.pytorch.{LayerNormImpl, LayerNormOptions, LongVector}
+import torch.nn.modules.TensorModule
+import torch.{DType, Tensor}
+
+/** Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer
+  * Normalization // TODO Add docs
+  */
+final class LayerNorm[ParamType <: DType: Default](
+    normalizedShape: Seq[Int] | Int,
+    eps: Double = 1e-05,
+    elementwiseAffine: Boolean = true
+) extends TensorModule[ParamType]:
+  private val options: LayerNormOptions = normalizedShape match {
+    case normalizedShape: Seq[Int] =>
+      LayerNormOptions(LongVector(normalizedShape.map(_.toLong)*))
+    case normalizedShape: Int =>
+      LayerNormOptions(LongVector(normalizedShape.toLong))
+  }
+  options.eps().put(eps)
+  options.elementwise_affine().put(elementwiseAffine)
+
+  override private[torch] val nativeModule: LayerNormImpl = LayerNormImpl(options)
+
+  override def registerWithParent[M <: pytorch.Module](parent: M)(using
+      name: sourcecode.Name
+  ): Unit =
+    parent.register_module(name.value, nativeModule)
+
+  val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
+  val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias)
+
+  def apply(t: Tensor[ParamType]): Tensor[ParamType] =
+    Tensor[ParamType](nativeModule.forward(t.native))
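Usage sketch (editorial note, not part of the patch): LayerNorm normalizes over the trailing dimensions given by normalizedShape, so passing the feature size is enough for (batch, seq, features) activations. As above, torch.rand and a Float32 default via the Default type class are assumed; shapes are illustrative.

  import torch.*

  // Normalize the 64-element feature dimension of a (batch, seq, features) tensor.
  val layerNorm   = nn.LayerNorm(Seq(64))
  val activations = torch.rand(Seq(32, 10, 64))
  val normalized  = layerNorm(activations) // same shape; mean ~0, variance ~1 over the last dim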
diff --git a/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala b/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala
index 2b505cd1..30ebc7cd 100644
--- a/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala
+++ b/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala
@@ -27,6 +27,34 @@
 import org.bytedeco.pytorch.EmbeddingOptions
 import torch.nn.modules.{HasParams, HasWeight, TensorModule}
 import torch.internal.NativeConverters.{toNative, doubleToDoublePointer}
+/** A simple lookup table that stores embeddings of a fixed dictionary and size.
+  *
+  * This module is often used to store word embeddings and retrieve them using indices. The input to
+  * the module is a list of indices, and the output is the corresponding word embeddings.
+  *
+  * @group nn_sparse
+  *
+  * @param numEmbeddings
+  *   Size of the dictionary of embeddings
+  * @param embeddingDim
+  *   The size of each embedding vector
+  * @param paddingIdx
+  *   If specified, the entries at `paddingIdx` do not contribute to the gradient; therefore, the
+  *   embedding vector at `paddingIdx` is not updated during training, i.e. it remains as a fixed
+  *   "pad". For a newly constructed Embedding, the embedding vector at `paddingIdx` will default to
+  *   all zeros, but can be updated to another value to be used as the padding vector.
+  * @param maxNorm
+  *   If given, each embedding vector with norm larger than `maxNorm` is renormalized to have norm
+  *   `maxNorm`.
+  * @param normType
+  *   The p of the p-norm to compute for the `maxNorm` option. Default: `2`.
+  * @param scaleGradByFreq
+  *   If `true`, this will scale gradients by the inverse of frequency of the words in the
+  *   mini-batch. Default: `false`.
+  * @param sparse
+  *   If `true`, the gradient w.r.t. the `weight` matrix will be a sparse tensor. See Notes for more
+  *   details regarding sparse gradients.
+  */
 final class Embedding[ParamType <: FloatNN | ComplexNN: Default](
     numEmbeddings: Int,
     embeddingDim: Int,
diff --git a/core/src/main/scala/torch/nn/package.scala b/core/src/main/scala/torch/nn/package.scala
index 3a0f4a66..4628d8e2 100644
--- a/core/src/main/scala/torch/nn/package.scala
+++ b/core/src/main/scala/torch/nn/package.scala
@@ -34,11 +34,13 @@ package object nn {
   export modules.batchnorm.BatchNorm1d
   export modules.batchnorm.BatchNorm2d
   export modules.container.Sequential
+  export modules.container.ModuleList
   export modules.conv.Conv2d
   export modules.flatten.Flatten
   export modules.linear.Linear
   export modules.linear.Identity
   export modules.normalization.GroupNorm
+  export modules.normalization.LayerNorm
   export modules.pooling.AdaptiveAvgPool2d
   export modules.pooling.MaxPool2d
   export modules.sparse.Embedding
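Usage sketch (editorial note, not part of the patch): Embedding maps integer indices to rows of its weight matrix. The torch.arange call and the exact integer dtype accepted by Embedding's apply are assumptions about the surrounding storch API, not something this change defines; shapes are illustrative.

  import torch.*

  // A 1000-entry vocabulary with 32-dimensional vectors; weight has shape Seq(1000, 32).
  val embedding = nn.Embedding(numEmbeddings = 1000, embeddingDim = 32)

  // Eight token ids are looked up row-wise; the index tensor is assumed to be Int64.
  val tokenIds = torch.arange(0L, 8L)
  val vectors  = embedding(tokenIds) // expected shape: Seq(8, 32)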