-
Notifications
You must be signed in to change notification settings - Fork 466
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MINOR] Forward pass for ResNet18 and 34
This commit contains the building blocks for the ResNet primitive of ResNet18 and ResNet34. Closes #1944
- Loading branch information
1 parent
9de62d1
commit 8dfe211
Showing
3 changed files
with
223 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
#------------------------------------------------------------- | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# | ||
#------------------------------------------------------------- | ||
|
||
source("scripts/nn/networks/resnet.dml") as resnet | ||
|
||
forward = function(matrix[double] X, int Hin, int Win, | ||
list[unknown] model, string mode, | ||
list[unknown] ema_means_vars) | ||
return (matrix[double] out, list[unknown] ema_means_vars_upd) { | ||
/* | ||
* Forward pass of the ResNet 18 model as introduced in | ||
* "Deep Residual Learning for Image Recognition" by | ||
* Kaiming He et. al. and inspired by PyTorch. | ||
* | ||
* Inputs: | ||
* - X: Inputs, of shape (N, C_in*Hin*Win). | ||
* C_in = 3 is expected. | ||
* - Hin: Input height. | ||
* - Win: Input width. | ||
* - model: Weights and bias matrices of the model | ||
* with the following order/content: | ||
* -> 1: Weights of conv 1 7x7, of shape (64, 3*7*7) | ||
* -> 2: Weights of batch norm 1, of shape (64, 1). | ||
* -> 3: Bias of batch norm 1, of shape (64, 1). | ||
* -> 4: List of weights for first residual layer | ||
* with 64 base channels. | ||
* -> 5: List of weights for second residual layer | ||
* with 128 base channels. | ||
* -> 6: List of weights for third residual layer | ||
* with 256 base channels. | ||
* -> 7: List of weights for fourth residual layer | ||
* with 512 base channels. | ||
* List of residual layers 1, 2, 3 & 4 have | ||
* the content/order: | ||
* -> 1: List of weights for first residual | ||
* block. | ||
* -> 2: List of weights for second residual | ||
* block. | ||
* Each list of weights for a residual block | ||
* must follow the same order as defined in | ||
* the documentation of basic_block_forward(). | ||
* -> 8: Weights of fully connected layer, of shape (512, 1000) | ||
* -> 9: Bias of fully connected layer, of shape (1, 1000) | ||
* - mode: 'train' or 'test' to indicate if the model is currently | ||
* being trained or tested for badge normalization layers. | ||
* See badge_norm2d.dml docs for more info. | ||
* - ema_means_vars: List of exponential moving averages for mean | ||
* and variance for badge normalization layers. | ||
* -> 1: EMA for mean of badge norm 1, of shape (64, 1). | ||
* -> 2: EMA for variance of badge norm 1, of shape (64, 1). | ||
* -> 3: List of EMA means and vars for residual layer 1. | ||
* -> 4: List of EMA means and vars for residual layer 2. | ||
* -> 5: List of EMA means and vars for residual layer 3. | ||
* -> 6: List of EMA means and vars for residual layer 4. | ||
* Lists for EMAs of layer 1, 2, 3 & 4 must have the | ||
* following order: | ||
* -> 1: List of EMA means and vars for residual block 1. | ||
* -> 2: List of EMA means and vars for residual block 2. | ||
* Each list of EMAs for a residual block | ||
* must follow the same order as defined in | ||
* the documentation of basic_block_forward(). | ||
* - NOTICE: The lists of the first blocks for layer 2, 3 and 4 | ||
* must include weights and EMAs for 1 extra conv layer | ||
* and a batch norm layer for the downsampling on the | ||
* identity path. | ||
* | ||
* Outputs: | ||
* - out: Outputs, of shape (N, 1000) | ||
* - ema_means_vars_upd: List of updated exponential moving averages | ||
* for mean and variance of badge normalization layers. It follows | ||
* the same exact structure as the input EMAs list. | ||
*/ | ||
layer_sizes = list(2, 2, 2, 2) | ||
[out, ema_means_vars_upd] = resnet::resnet_basic_forward(X, Hin, Win, | ||
layer_sizes, model, mode, ema_means_vars) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#------------------------------------------------------------- | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# | ||
#------------------------------------------------------------- | ||
|
||
source("scripts/nn/networks/resnet.dml") as resnet | ||
|
||
forward = function(matrix[double] X, int Hin, int Win, | ||
list[unknown] model, string mode, | ||
list[unknown] ema_means_vars) | ||
return (matrix[double] out, list[unknown] ema_means_vars_upd) { | ||
/* | ||
* Forward pass of the ResNet 34 model as introduced in | ||
* "Deep Residual Learning for Image Recognition" by | ||
* Kaiming He et. al. and inspired by PyTorch. | ||
* | ||
* Inputs: | ||
* - X: Inputs, of shape (N, C_in*Hin*Win). | ||
* C_in = 3 is expected. | ||
* - Hin: Input height. | ||
* - Win: Input width. | ||
* - model: Weights and bias matrices of the model | ||
* with the following order/content: | ||
* -> 1: Weights of conv 1 7x7, of shape (64, 3*7*7) | ||
* -> 2: Weights of batch norm 1, of shape (64, 1). | ||
* -> 3: Bias of batch norm 1, of shape (64, 1). | ||
* -> 4: List of weights for first residual layer | ||
* with 64 base channels. | ||
* -> 5: List of weights for second residual layer | ||
* with 128 base channels. | ||
* -> 6: List of weights for third residual layer | ||
* with 256 base channels. | ||
* -> 7: List of weights for fourth residual layer | ||
* with 512 base channels. | ||
* List of residual layers 1, 2, 3 & 4 have | ||
* n lists of weights for a residual block. | ||
* Layer 1 has 3 lists, 2 has 4, 3 has 6 and | ||
* layer 4 has 3 lists. | ||
* Each list of weights for a residual block | ||
* must follow the same order as defined in | ||
* the documentation of basic_block_forward(). | ||
* -> 8: Weights of fully connected layer, of shape (512, 1000) | ||
* -> 9: Bias of fully connected layer, of shape (1, 1000) | ||
* - mode: 'train' or 'test' to indicate if the model is currently | ||
* being trained or tested for badge normalization layers. | ||
* See badge_norm2d.dml docs for more info. | ||
* - ema_means_vars: List of exponential moving averages for mean | ||
* and variance for badge normalization layers. | ||
* -> 1: EMA for mean of badge norm 1, of shape (64, 1). | ||
* -> 2: EMA for variance of badge norm 1, of shape (64, 1). | ||
* -> 3: List of EMA means and vars for residual layer 1. | ||
* -> 4: List of EMA means and vars for residual layer 2. | ||
* -> 5: List of EMA means and vars for residual layer 3. | ||
* -> 6: List of EMA means and vars for residual layer 4. | ||
* List of residual layers 1, 2, 3 & 4 have | ||
* n lists of EMAs for a residual block. | ||
* Layer 1 has 3 lists, 2 has 4, 3 has 6 and | ||
* layer 4 has 3 lists. | ||
* Each list of EMAs for a residual block | ||
* must follow the same order as defined in | ||
* the documentation of basic_block_forward(). | ||
* - NOTICE: The lists of the first residual blocks for layer 2, | ||
* 3 and 4 must include weights and EMAs for 1 extra | ||
* conv layer and a batch norm layer for the downsampling | ||
* on the identity path. | ||
* | ||
* Outputs: | ||
* - out: Outputs, of shape (N, 1000) | ||
* - ema_means_vars_upd: List of updated exponential moving averages | ||
* for mean and variance of badge normalization layers. It follows | ||
* the same exact structure as the input EMAs list. | ||
*/ | ||
layer_sizes = list(3, 4, 6, 3) | ||
[out, ema_means_vars_upd] = resnet::resnet_basic_forward(X, Hin, Win, | ||
layer_sizes, model, mode, ema_means_vars) | ||
} |