Skip to content

Commit

Permalink
[Anomaly Task] 🐞 Fix inference when model backbone changes (#1242)
Browse files Browse the repository at this point in the history
* Add check for change in model backbone

* Add load model to train

* Limit padim to only resnet18

* Fix comment
  • Loading branch information
ashwinvaidya17 authored Sep 8, 2022
1 parent 5f622b4 commit d9ad2d7
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 8 deletions.
4 changes: 4 additions & 0 deletions external/anomaly/configs/base/padim/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ class LearningParameters(BaseAnomalyConfig.LearningParameters):
header = string_attribute("Learning Parameters")
description = header

# Editable is set to false as WideResNet50 is very large for
# onnx's protobuf (2gb) limit. This ends up crashing the export.
backbone = selectable(
default_value=ModelBackbone.RESNET18,
header="Model Backbone",
description="Pre-trained backbone used for feature extraction",
editable=False,
visible_in_ui=False,
)

learning_parameters = add_parameter_group(LearningParameters)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ learning_parameters:
auto_hpo_value: null
default_value: resnet18
description: Pre-trained backbone used for feature extraction
editable: true
editable: false
enum_name: ModelBackbone
header: Model Backbone
options:
Expand All @@ -48,7 +48,7 @@ learning_parameters:
rules: []
type: UI_RULES
value: resnet18
visible_in_ui: true
visible_in_ui: false
warning: null
description: Learning Parameters
header: Learning Parameters
Expand Down
4 changes: 2 additions & 2 deletions external/anomaly/configs/detection/padim/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ learning_parameters:
auto_hpo_value: null
default_value: resnet18
description: Pre-trained backbone used for feature extraction
editable: true
editable: false
enum_name: ModelBackbone
header: Model Backbone
options:
Expand All @@ -48,7 +48,7 @@ learning_parameters:
rules: []
type: UI_RULES
value: resnet18
visible_in_ui: true
visible_in_ui: false
warning: null
description: Learning Parameters
header: Learning Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ learning_parameters:
auto_hpo_value: null
default_value: resnet18
description: Pre-trained backbone used for feature extraction
editable: true
editable: false
enum_name: ModelBackbone
header: Model Backbone
options:
Expand All @@ -48,7 +48,7 @@ learning_parameters:
rules: []
type: UI_RULES
value: resnet18
visible_in_ui: true
visible_in_ui: false
warning: null
description: Learning Parameters
header: Learning Parameters
Expand Down
10 changes: 8 additions & 2 deletions external/anomaly/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
AnomalyModule: Anomalib
classification or segmentation model with/without weights.
"""
model = get_model(config=self.config)
if ote_model is None:
model = get_model(config=self.config)
logger.info(
"No trained model in project yet. Created new model with '%s'",
self.model_name,
Expand All @@ -130,10 +130,16 @@ def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
buffer = io.BytesIO(ote_model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))

if model_data["config"]["model"]["backbone"] != self.config["model"]["backbone"]:
logger.warning(
"Backbone of the model in the Task Environment is different from the one in the template. "
f"creating model with backbone={model_data['config']['model']['backbone']}"
)
self.config["model"]["backbone"] = model_data["config"]["model"]["backbone"]
try:
model = get_model(config=self.config)
model.load_state_dict(model_data["model"])
logger.info("Loaded model weights from Task Environment")

except BaseException as exception:
raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

Expand Down
42 changes: 42 additions & 0 deletions external/anomaly/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import io
from typing import Optional

import torch
from adapters.anomalib.callbacks import ProgressCallback
from adapters.anomalib.data import OTEAnomalyDataModule
from adapters.anomalib.logger import get_logger
from anomalib.models import AnomalyModule, get_model
from anomalib.utils.callbacks import (
MetricsConfigurationCallback,
MinMaxNormalizationCallback,
Expand Down Expand Up @@ -83,3 +86,42 @@ def train(
self.save_model(output_model)

logger.info("Training completed.")

def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
"""Create and Load Anomalib Module from OTE Model.
This method checks if the task environment has a saved OTE Model,
and creates one. If the OTE model already exists, it returns the
the model with the saved weights.
Args:
ote_model (Optional[ModelEntity]): OTE Model from the
task environment.
Returns:
AnomalyModule: Anomalib
classification or segmentation model with/without weights.
"""
model = get_model(config=self.config)
if ote_model is None:
logger.info(
"No trained model in project yet. Created new model with '%s'",
self.model_name,
)
else:
buffer = io.BytesIO(ote_model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))

try:
if model_data["config"]["model"]["backbone"] == self.config["model"]["backbone"]:
model.load_state_dict(model_data["model"])
logger.info("Loaded model weights from Task Environment")
else:
logger.info(
"Model backbone does not match. Created new model with '%s'",
self.model_name,
)
except BaseException as exception:
raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

return model

0 comments on commit d9ad2d7

Please sign in to comment.