Skip to content

Commit

Permalink
Support timm zero shot learning (open-mmlab#1975)
Browse files Browse the repository at this point in the history
* support timm zero shot

* fix

* fix

* add cnn tests
  • Loading branch information
zhiqiangdon authored Jul 22, 2022
1 parent f54cea3 commit 54f7c2f
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 20 deletions.
30 changes: 15 additions & 15 deletions multimodal/src/autogluon/multimodal/models/timm_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from timm import create_model
from .utils import (
assign_layer_ids,
init_weights,
get_column_features,
get_model_head,
)
from ..constants import (
IMAGE,
Expand All @@ -33,7 +33,7 @@ def __init__(
self,
prefix: str,
checkpoint_name: str,
num_classes: Optional[int] = 0,
num_classes: Optional[int] = None,
mix_choice: Optional[str] = "all_logits",
pretrained: Optional[bool] = True,
):
Expand Down Expand Up @@ -61,11 +61,12 @@ def __init__(
# In TIMM, if num_classes==0, then create_model would automatically set self.model.head = nn.Identity()
logger.debug(f"initializing {checkpoint_name}")
self.checkpoint_name = checkpoint_name
self.num_classes = num_classes
self.model = create_model(checkpoint_name, pretrained=pretrained, num_classes=0)
self.pretrained = pretrained
self.model = create_model(checkpoint_name, pretrained=pretrained, num_classes=num_classes)
self.num_classes = self.model.num_classes
self.out_features = self.model.num_features
self.head = nn.Linear(self.out_features, num_classes) if num_classes > 0 else nn.Identity()
self.head.apply(init_weights)
self.head = get_model_head(model=self.model)
self.model.reset_classifier(0) # remove the internal head

self.mix_choice = mix_choice
logger.debug(f"mix_choice: {mix_choice}")
Expand Down Expand Up @@ -114,7 +115,9 @@ def forward(
image_valid_num = batch[self.image_valid_num_key]
ret = {COLUMN_FEATURES: {FEATURES: {}, MASKS: {}}}
if self.mix_choice == "all_images": # mix inputs
mixed_images = images.sum(dim=1) / image_valid_num[:, None, None, None] # mixed shape: (b, 3, h, w)
mixed_images = (
images.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None, None, None]
) # mixed shape: (b, 3, h, w)
features = self.model(mixed_images)
logits = self.head(features)

Expand All @@ -137,18 +140,15 @@ def forward(
ret[COLUMN_FEATURES][FEATURES].update(column_features)
ret[COLUMN_FEATURES][MASKS].update(column_feature_masks)

features = features.sum(dim=1) # (b, num_features)
logits = logits.sum(dim=1) # (b, num_classes)
features = features.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None] # (b, num_features)
logits = logits.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None] # (b, num_classes)

else:
raise ValueError(f"unknown mix_choice: {self.mix_choice}")

ret.update(
{
LOGITS: logits,
FEATURES: features,
}
)
ret[FEATURES] = features
if self.num_classes > 0:
ret[LOGITS] = logits

return {self.prefix: ret}

Expand Down
25 changes: 25 additions & 0 deletions multimodal/src/autogluon/multimodal/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,28 @@ def inject_lora_to_linear_layer(
setattr(model, n, lora_layer)

return model # return model to enable method chaining


def get_model_head(model: nn.Module):
    """
    Locate and return the head of a model.

    Different models may store their head under different attribute names,
    so the known names are probed in a fixed priority order:
    "head", then "last_linear", then "fc". The first attribute that exists
    on the model is returned unchanged.

    Parameters
    ----------
    model
        A Pytorch model.

    Returns
    -------
    The model's head.

    Raises
    ------
    ValueError
        If the model has none of the known head attributes.
    """
    # Order matters: a model exposing several of these names yields the earliest one.
    for head_attr in ("head", "last_linear", "fc"):
        if hasattr(model, head_attr):
            return getattr(model, head_attr)

    raise ValueError(f"Model {type(model)} doesn't have head. Need to check its implementation.")
5 changes: 4 additions & 1 deletion multimodal/src/autogluon/multimodal/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,7 +1528,10 @@ def predict(
y_pred=logits_or_prob,
)
else:
pred = logits_or_prob
if logits_or_prob.ndim == 2:
pred = logits_or_prob.argmax(axis=1)
else:
pred = logits_or_prob

if (as_pandas is None and isinstance(data, pd.DataFrame)) or as_pandas is True:
pred = self._as_pandas(data=data, to_be_converted=pred)
Expand Down
2 changes: 1 addition & 1 deletion multimodal/src/autogluon/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1652,7 +1652,7 @@ def init_zero_shot(
assert (
len(config.model.names) == 1
), f"Zero shot mode only supports using one model, but detects multiple models {config.model.names}"
model = create_model(config=config)
model = create_model(config=config, pretrained=True)

data_processors = init_data_processors(
config=config,
Expand Down
53 changes: 50 additions & 3 deletions multimodal/tests/unittests/test_zero_shot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from PIL import Image
import requests
import pytest
import numpy as np
from autogluon.multimodal import MultiModalPredictor

pytest.skip("Temporarily skip this test to pass the Jenkins build.", allow_module_level=True)


def test_clip_zero_shot():
def download_sample_images():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
cat_image_name = "cat.jpg"
Expand All @@ -17,6 +16,12 @@ def test_clip_zero_shot():
dog_image_name = "dog.jpg"
image.save(dog_image_name)

return cat_image_name, dog_image_name


def test_clip_zero_shot():
cat_image_name, dog_image_name = download_sample_images()

cat_text = "a photo of a cat"
dog_text = "a photo of a dog"
bird_text = "a photo of a bird"
Expand Down Expand Up @@ -98,3 +103,45 @@ def test_clip_zero_shot():
# invalid API usage 2: predicting probability with only one dictionary as input.
with pytest.raises(AssertionError):
prob = predictor.predict_proba({"image": [cat_image_name], "text": [cat_text]})


@pytest.mark.parametrize(
    "checkpoint_name",
    [
        "swin_tiny_patch4_window7_224",
        "vit_tiny_patch16_224",
        "resnet18",
        "legacy_seresnet18",
    ],
)
def test_timm_zero_shot(checkpoint_name):
    """
    Exercise the zero-shot API of a predictor backed by a single timm image model.

    For each parametrized checkpoint this checks the output shapes of
    ``predict`` and ``predict_proba`` on two images, and the features/masks
    returned by ``extract_embedding`` for one and for two input columns.
    """
    cat_path, dog_path = download_sample_images()

    hyperparameters = {
        "model.names": ["timm_image"],
        "model.timm_image.checkpoint_name": checkpoint_name,
    }
    zero_shot_predictor = MultiModalPredictor(
        hyperparameters=hyperparameters,
        problem_type="zero_shot",
    )

    # Class predictions: one label per image.
    labels = zero_shot_predictor.predict({"image": [cat_path, dog_path]})
    assert labels.shape == (2,)

    # Probabilities: 1000 columns expected (presumably the label space of the
    # pretrained checkpoints -- confirm if new checkpoints are added).
    probabilities = zero_shot_predictor.predict_proba({"image": [cat_path, dog_path]})
    assert probabilities.shape == (2, 1000)

    # Embeddings are keyed by the caller-chosen column name, here "abc".
    embeddings = zero_shot_predictor.extract_embedding({"abc": [cat_path, dog_path]})
    assert embeddings["abc"].ndim == 2
    assert embeddings["abc"].shape[0] == 2

    embeddings, valid_masks = zero_shot_predictor.extract_embedding(
        {"abc": [cat_path, dog_path]},
        return_masks=True,
    )
    assert embeddings["abc"].ndim == 2
    assert embeddings["abc"].shape[0] == 2
    assert np.all(valid_masks["abc"] == np.array([1, 1]))

    # Two separate columns with one image each.
    embeddings, valid_masks = zero_shot_predictor.extract_embedding(
        {"abc": [cat_path], "123": [dog_path]},
        return_masks=True,
    )
    for column in ("abc", "123"):
        assert embeddings[column].ndim == 2
        assert embeddings[column].shape[0] == 1
    for column in ("abc", "123"):
        assert np.all(valid_masks[column] == np.array([1]))

0 comments on commit 54f7c2f

Please sign in to comment.