Skip to content

Commit

Permalink
[MINOR] Forward pass for ResNet18 and 34
Browse files Browse the repository at this point in the history
This commit contains the building blocks
for the ResNet primitive of ResNet18 and ResNet34.

Closes #1944
  • Loading branch information
MaximilianSchreff authored and Baunsgaard committed Nov 10, 2023
1 parent 9de62d1 commit 8dfe211
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 27 deletions.
64 changes: 37 additions & 27 deletions scripts/nn/networks/resnet.dml
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ basic_block_forward = function(matrix[double] X, list[unknown] weights,

ema_means_vars_upd = list(ema_mean_bn1_upd, ema_var_bn1_upd, ema_mean_bn2_upd, ema_var_bn2_upd)
if (downsample) {
ema_means_vars_upd = append(ema_means_vars, ema_mean_bn3_upd)
ema_means_vars_upd = append(ema_means_vars, ema_var_bn3_upd)
ema_means_vars_upd = append(ema_means_vars_upd, ema_mean_bn3_upd)
ema_means_vars_upd = append(ema_means_vars_upd, ema_var_bn3_upd)
}
}

Expand Down Expand Up @@ -224,21 +224,25 @@ basic_reslayer_forward = function(matrix[double] X, int Hin, int Win, int blocks
}
}

resnet18_forward = function(matrix[double] X, int Hin, int Win,
list[unknown] model, string mode,
list[unknown] ema_means_vars)
resnet_basic_forward = function(matrix[double] X, int Hin, int Win,
list[unknown] layer_sizes,
list[unknown] model, string mode,
list[unknown] ema_means_vars)
return (matrix[double] out, list[unknown] ema_means_vars_upd) {
/*
* Forward pass of the ResNet 18 model as introduced in
* "Deep Residual Learning for Image Recognition" by
* Kaiming He et. al. and inspired by the PyTorch
* implementation.
* Forward pass of the ResNet 18 and 34 model as introduced
* in "Deep Residual Learning for Image Recognition" by
* Kaiming He et. al. and inspired by the PyTorch.
*
* Inputs:
* - X: Inputs, of shape (N, C_in*Hin*Win).
* C_in = 3 is expected.
* - Hin: Input height.
* - Win: Input width.
* - layer_sizes: List of the sizes of each of
* the 4 residual layers.
* For ResNet18: [2, 2, 2, 2]
* For ResNet34: [3, 4, 6, 3]
* - model: Weights and bias matrices of the model
* with the following order/content:
* -> 1: Weights of conv 1 7x7, of shape (64, 3*7*7)
Expand All @@ -254,10 +258,8 @@ resnet18_forward = function(matrix[double] X, int Hin, int Win,
* with 512 base channels.
* List of residual layers 1, 2, 3 & 4 have
* the content/order:
* -> 1: List of weights for first residual
* block.
* -> 2: List of weights for second residual
* block.
* -> i: List of weights for residual block i.
* with i in {1, ..., layer_sizes[layer]}
* Each list of weights for a residual block
* must follow the same order as defined in
* the documentation of basic_block_forward().
Expand All @@ -276,8 +278,8 @@ resnet18_forward = function(matrix[double] X, int Hin, int Win,
* -> 6: List of EMA means and vars for residual layer 4.
* Lists for EMAs of layer 1, 2, 3 & 4 must have the
* following order:
* -> 1: List of EMA means and vars for residual block 1.
* -> 2: List of EMA means and vars for residual block 2.
* -> i: List of EMA means and vars for residual block i.
* with i in {1, ..., layer_sizes[layer]}
* Each list of EMAs for a residual block
* must follow the same order as defined in
* the documentation of basic_block_forward().
Expand Down Expand Up @@ -330,28 +332,36 @@ resnet18_forward = function(matrix[double] X, int Hin, int Win,
Wf=3, strideh=2, stridew=2, padh=1, padw=1)

# residual layer 1
block_count = as.integer(as.scalar(layer_sizes[1]))
[out, Hout, Wout, emas1_upd] = basic_reslayer_forward(X=out, Hin=Hout,
Win=Wout, blocks=2, strideh=1, stridew=1, C_in=C,
C_base=64, blocks_weights=weights_reslayer1,
mode=mode, ema_means_vars=emas_reslayer1)
Win=Wout, blocks=block_count, strideh=1,
stridew=1, C_in=C, C_base=64,
blocks_weights=weights_reslayer1, mode=mode,
ema_means_vars=emas_reslayer1)
C = 64
# residual layer 2
block_count = as.integer(as.scalar(layer_sizes[2]))
[out, Hout, Wout, emas2_upd] = basic_reslayer_forward(X=out, Hin=Hout,
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C,
C_base=128, blocks_weights=weights_reslayer2,
mode=mode, ema_means_vars=emas_reslayer2)
Win=Wout, blocks=block_count, strideh=2,
stridew=2, C_in=C, C_base=128,
blocks_weights=weights_reslayer2, mode=mode,
ema_means_vars=emas_reslayer2)
C = 128
# residual layer 3
block_count = as.integer(as.scalar(layer_sizes[3]))
[out, Hout, Wout, emas3_upd] = basic_reslayer_forward(X=out, Hin=Hout,
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C,
C_base=256, blocks_weights=weights_reslayer3,
mode=mode, ema_means_vars=emas_reslayer3)
Win=Wout, blocks=block_count, strideh=2,
stridew=2, C_in=C, C_base=256,
blocks_weights=weights_reslayer3, mode=mode,
ema_means_vars=emas_reslayer3)
C = 256
# residual layer 4
block_count = as.integer(as.scalar(layer_sizes[4]))
[out, Hout, Wout, emas4_upd] = basic_reslayer_forward(X=out, Hin=Hout,
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C,
C_base=512, blocks_weights=weights_reslayer4,
mode=mode, ema_means_vars=emas_reslayer4)
Win=Wout, blocks=block_count, strideh=2,
stridew=2, C_in=C, C_base=512,
blocks_weights=weights_reslayer4, mode=mode,
ema_means_vars=emas_reslayer4)
C = 512

# Global Average Pooling
Expand Down
94 changes: 94 additions & 0 deletions scripts/nn/networks/resnet18.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("scripts/nn/networks/resnet.dml") as resnet

forward = function(matrix[double] X, int Hin, int Win,
list[unknown] model, string mode,
list[unknown] ema_means_vars)
return (matrix[double] out, list[unknown] ema_means_vars_upd) {
/*
* Forward pass of the ResNet 18 model as introduced in
* "Deep Residual Learning for Image Recognition" by
* Kaiming He et. al. and inspired by PyTorch.
*
* Inputs:
* - X: Inputs, of shape (N, C_in*Hin*Win).
* C_in = 3 is expected.
* - Hin: Input height.
* - Win: Input width.
* - model: Weights and bias matrices of the model
* with the following order/content:
* -> 1: Weights of conv 1 7x7, of shape (64, 3*7*7)
* -> 2: Weights of batch norm 1, of shape (64, 1).
* -> 3: Bias of batch norm 1, of shape (64, 1).
* -> 4: List of weights for first residual layer
* with 64 base channels.
* -> 5: List of weights for second residual layer
* with 128 base channels.
* -> 6: List of weights for third residual layer
* with 256 base channels.
* -> 7: List of weights for fourth residual layer
* with 512 base channels.
* List of residual layers 1, 2, 3 & 4 have
* the content/order:
* -> 1: List of weights for first residual
* block.
* -> 2: List of weights for second residual
* block.
* Each list of weights for a residual block
* must follow the same order as defined in
* the documentation of basic_block_forward().
* -> 8: Weights of fully connected layer, of shape (512, 1000)
* -> 9: Bias of fully connected layer, of shape (1, 1000)
* - mode: 'train' or 'test' to indicate if the model is currently
* being trained or tested for badge normalization layers.
* See badge_norm2d.dml docs for more info.
* - ema_means_vars: List of exponential moving averages for mean
* and variance for badge normalization layers.
* -> 1: EMA for mean of badge norm 1, of shape (64, 1).
* -> 2: EMA for variance of badge norm 1, of shape (64, 1).
* -> 3: List of EMA means and vars for residual layer 1.
* -> 4: List of EMA means and vars for residual layer 2.
* -> 5: List of EMA means and vars for residual layer 3.
* -> 6: List of EMA means and vars for residual layer 4.
* Lists for EMAs of layer 1, 2, 3 & 4 must have the
* following order:
* -> 1: List of EMA means and vars for residual block 1.
* -> 2: List of EMA means and vars for residual block 2.
* Each list of EMAs for a residual block
* must follow the same order as defined in
* the documentation of basic_block_forward().
* - NOTICE: The lists of the first blocks for layer 2, 3 and 4
* must include weights and EMAs for 1 extra conv layer
* and a batch norm layer for the downsampling on the
* identity path.
*
* Outputs:
* - out: Outputs, of shape (N, 1000)
* - ema_means_vars_upd: List of updated exponential moving averages
* for mean and variance of badge normalization layers. It follows
* the same exact structure as the input EMAs list.
*/
layer_sizes = list(2, 2, 2, 2)
[out, ema_means_vars_upd] = resnet::resnet_basic_forward(X, Hin, Win,
layer_sizes, model, mode, ema_means_vars)
}
92 changes: 92 additions & 0 deletions scripts/nn/networks/resnet34.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("scripts/nn/networks/resnet.dml") as resnet

forward = function(matrix[double] X, int Hin, int Win,
list[unknown] model, string mode,
list[unknown] ema_means_vars)
return (matrix[double] out, list[unknown] ema_means_vars_upd) {
/*
* Forward pass of the ResNet 34 model as introduced in
* "Deep Residual Learning for Image Recognition" by
* Kaiming He et. al. and inspired by PyTorch.
*
* Inputs:
* - X: Inputs, of shape (N, C_in*Hin*Win).
* C_in = 3 is expected.
* - Hin: Input height.
* - Win: Input width.
* - model: Weights and bias matrices of the model
* with the following order/content:
* -> 1: Weights of conv 1 7x7, of shape (64, 3*7*7)
* -> 2: Weights of batch norm 1, of shape (64, 1).
* -> 3: Bias of batch norm 1, of shape (64, 1).
* -> 4: List of weights for first residual layer
* with 64 base channels.
* -> 5: List of weights for second residual layer
* with 128 base channels.
* -> 6: List of weights for third residual layer
* with 256 base channels.
* -> 7: List of weights for fourth residual layer
* with 512 base channels.
* List of residual layers 1, 2, 3 & 4 have
* n lists of weights for a residual block.
* Layer 1 has 3 lists, 2 has 4, 3 has 6 and
* layer 4 has 3 lists.
* Each list of weights for a residual block
* must follow the same order as defined in
* the documentation of basic_block_forward().
* -> 8: Weights of fully connected layer, of shape (512, 1000)
* -> 9: Bias of fully connected layer, of shape (1, 1000)
* - mode: 'train' or 'test' to indicate if the model is currently
* being trained or tested for badge normalization layers.
* See badge_norm2d.dml docs for more info.
* - ema_means_vars: List of exponential moving averages for mean
* and variance for badge normalization layers.
* -> 1: EMA for mean of badge norm 1, of shape (64, 1).
* -> 2: EMA for variance of badge norm 1, of shape (64, 1).
* -> 3: List of EMA means and vars for residual layer 1.
* -> 4: List of EMA means and vars for residual layer 2.
* -> 5: List of EMA means and vars for residual layer 3.
* -> 6: List of EMA means and vars for residual layer 4.
* List of residual layers 1, 2, 3 & 4 have
* n lists of EMAs for a residual block.
* Layer 1 has 3 lists, 2 has 4, 3 has 6 and
* layer 4 has 3 lists.
* Each list of EMAs for a residual block
* must follow the same order as defined in
* the documentation of basic_block_forward().
* - NOTICE: The lists of the first residual blocks for layer 2,
* 3 and 4 must include weights and EMAs for 1 extra
* conv layer and a batch norm layer for the downsampling
* on the identity path.
*
* Outputs:
* - out: Outputs, of shape (N, 1000)
* - ema_means_vars_upd: List of updated exponential moving averages
* for mean and variance of badge normalization layers. It follows
* the same exact structure as the input EMAs list.
*/
layer_sizes = list(3, 4, 6, 3)
[out, ema_means_vars_upd] = resnet::resnet_basic_forward(X, Hin, Win,
layer_sizes, model, mode, ema_means_vars)
}

0 comments on commit 8dfe211

Please sign in to comment.