Skip to content

Commit

Permalink
Add ModuleList, LayerNorm. Also misc improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
davoclavo committed Jul 27, 2023
1 parent fcad603 commit 3f94605
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 1 deletion.
40 changes: 40 additions & 0 deletions core/src/main/scala/torch/nn/modules/container/ModuleList.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package container

import sourcecode.Name
import scala.util.Random

final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
extends Module
with Seq[TensorModule[D]]:
// with TensorModule[D]:
modules.zipWithIndex.foreach((module, index) =>
this.register(module)(using Name(index.toString()))
)

override def iterator: Iterator[TensorModule[D]] = modules.iterator

// def apply(v1: Int | torch.Tensor[D]): torch.nn.modules.TensorModule[D] & torch.Tensor[D] = ???
def apply(i: Int): torch.nn.modules.TensorModule[D] = modules(i)

override def length: Int = modules.length

override def toString = getClass().getSimpleName()
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import torch.{DType, Tensor}
* a boolean value that when set to `true`, this module has learnable per-channel affine
* parameters initialized to ones (for weights) and zeros (for biases)
*/
final class GroupNorm[ParamType <: DType](
final class GroupNorm[ParamType <: DType: Default](
numGroups: Int,
numChannels: Int,
eps: Double = 1e-05,
Expand Down
55 changes: 55 additions & 0 deletions core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package normalization

import org.bytedeco.pytorch
import org.bytedeco.pytorch.{LayerNormImpl, LayerNormOptions, LongVector}
import torch.nn.modules.TensorModule
import torch.{DType, Tensor}

/** Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer
* Normalization // TODO Add docs
*/
final class LayerNorm[ParamType <: DType: Default](
normalizedShape: Seq[Int] | Int,
eps: Double = 1e-05,
elementwiseAffine: Boolean = true
) extends TensorModule[ParamType]:
private val options: LayerNormOptions = normalizedShape match {
case normalizedShape: Seq[Int] =>
LayerNormOptions(LongVector(normalizedShape.map(_.toLong)*))
case normalizedShape: Int =>
LayerNormOptions(LongVector(normalizedShape.toLong))
}
options.eps().put(eps)
options.elementwise_affine().put(elementwiseAffine)

override private[torch] val nativeModule: LayerNormImpl = LayerNormImpl(options)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias)

def apply(t: Tensor[ParamType]): Tensor[ParamType] =
Tensor[ParamType](nativeModule.forward(t.native))
28 changes: 28 additions & 0 deletions core/src/main/scala/torch/nn/modules/sparse/Embedding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,34 @@ import org.bytedeco.pytorch.EmbeddingOptions
import torch.nn.modules.{HasParams, HasWeight, TensorModule}
import torch.internal.NativeConverters.{toNative, doubleToDoublePointer}

/** A simple lookup table that stores embeddings of a fixed dictionary and size.
*
* This module is often used to store word embeddings and retrieve them using indices. The input to
* the module is a list of indices, and the output is the corresponding word embeddings.
*
* @group nn_sparse
*
* @param numEmbeddings
* Size of the dictionary of embeddings
* @param embeddingDim
* The size of each embedding vector
* @param paddingIdx
* If specified, the entries at `paddingIdx` do not contribute to the gradient; therefore, the
* embedding vector at `paddingIdx` is not updated during training, i.e. it remains as a fixed
* "pad". For a newly constructed Embedding, the embedding vector at `paddingIdx` will default to
* all zeros, but can be updated to another value to be used as the padding vector.
* @param maxNorm
* If given, each embedding vector with norm larger than `maxNorm` is renormalized to have norm
* `maxNorm`.
* @param normType
* The p of the p-norm to compute for the `maxNorm` option. Default `2`.
* @param scaleGradByFreq
* If given, this will scale gradients by the inverse of frequency of the words in the
* mini-batch. Default `false`.
* @param sparse
* If ``True``, gradient w.r.t. `weight` matrix will be a sparse tensor. See Notes for more
* details regarding sparse gradients.
*/
final class Embedding[ParamType <: FloatNN | ComplexNN: Default](
numEmbeddings: Int,
embeddingDim: Int,
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/torch/nn/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ package object nn {
export modules.batchnorm.BatchNorm1d
export modules.batchnorm.BatchNorm2d
export modules.container.Sequential
export modules.container.ModuleList
export modules.conv.Conv2d
export modules.flatten.Flatten
export modules.linear.Linear
export modules.linear.Identity
export modules.normalization.GroupNorm
export modules.normalization.LayerNorm
export modules.pooling.AdaptiveAvgPool2d
export modules.pooling.MaxPool2d
export modules.sparse.Embedding
Expand Down

0 comments on commit 3f94605

Please sign in to comment.