Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CI] fix timm related bug for master #1975

Merged
merged 1 commit into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .circleci/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ jobs:
command: |
python -V
python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
- when:
condition:
equal: [ "1.9.0", << parameters.torch >> ]
steps:
- run: python -m pip install timm
- when:
condition:
equal: [ "1.6.0", << parameters.torch >> ]
steps:
- run: python -m pip install timm==0.6.7
- run:
name: Install mmaction dependencies
command: |
Expand Down Expand Up @@ -124,6 +134,16 @@ jobs:
docker exec mmaction pip install git+https://github.com/open-mmlab/mmdetection/
docker exec mmaction pip install git+https://github.com/open-mmlab/mmclassification/
docker exec mmaction pip install -r requirements.txt
- when:
condition:
equal: [ "1.8.1", << parameters.torch >> ]
steps:
- run: docker exec mmaction pip install timm
- when:
condition:
equal: [ "1.6.0", << parameters.torch >> ]
steps:
- run: docker exec mmaction pip install timm==0.6.7
- when:
condition:
equal: [ "10.2", << parameters.cuda >> ]
Expand Down
12 changes: 11 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ jobs:
run: sudo apt-get install -y libturbojpeg
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install timm
run: python -m pip install timm==0.6.7
if: ${{matrix.torch == '1.5.0'}}
- name: Install timm
run: python -m pip install timm
if: ${{matrix.torch != '1.5.0'}}
- name: Install MMCV
run: pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
- name: Install MMDet
Expand Down Expand Up @@ -138,7 +144,7 @@ jobs:
run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report
run: |
coverage run --branch --source mmaction -m pytest tests/
coverage run --branch --source mmaction -m pytest tests/ -k 'not timm'
coverage xml
coverage report -m
# Only upload coverage report for python3.7 && pytorch1.5
Expand Down Expand Up @@ -188,6 +194,8 @@ jobs:
run: python -m pip install lmdb
- name: Install PyTorch
run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install timm
run: python -m pip install timm
- name: Install mmaction dependencies
run: |
python -V
Expand Down Expand Up @@ -229,6 +237,8 @@ jobs:
- name: Install PyTorch
# As a complement to Linux CI, we test on PyTorch LTS version
run: pip install torch==1.8.2+${{ matrix.platform }} torchvision==0.9.2+${{ matrix.platform }} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
- name: Install timm
run: python -m pip install timm
- name: Install MMCV
run: pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.8/index.html --only-binary mmcv-full
- name: Install mmaction dependencies
Expand Down
1 change: 0 additions & 1 deletion requirements/optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ onnxruntime
packaging
pims
PyTurboJPEG
timm
3 changes: 0 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ universal=1
[aliases]
test=pytest

[tool:pytest]
addopts=tests/

[yapf]
based_on_style = pep8
blank_line_before_nested_class_or_def = true
Expand Down
48 changes: 26 additions & 22 deletions tests/test_models/test_recognizers/test_recognizer2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,6 @@ def test_tsn():
for one_img in img_list:
recognizer(one_img, None, return_loss=False)

# test timm backbones
timm_backbone = dict(type='timm.efficientnet_b0', pretrained=False)
config.model['backbone'] = timm_backbone
config.model['cls_head']['in_channels'] = 1280

recognizer = build_recognizer(config.model)

input_shape = (1, 3, 3, 32, 32)
demo_inputs = generate_recognizer_demo_inputs(input_shape)

imgs = demo_inputs['imgs']
gt_labels = demo_inputs['gt_labels']

losses = recognizer(imgs, gt_labels)
assert isinstance(losses, dict)

# Test forward test
with torch.no_grad():
img_list = [img[None, :] for img in imgs]
for one_img in img_list:
recognizer(one_img, None, return_loss=False)


def test_tsm():
config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')
Expand Down Expand Up @@ -280,3 +258,29 @@ def test_tanet():
recognizer(imgs, gradcam=True)
for one_img in img_list:
recognizer(one_img, gradcam=True)


def test_timm_backbone():
# test tsn from timm
config = get_recognizer_cfg('tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py')
config.model['backbone']['pretrained'] = None
timm_backbone = dict(type='timm.efficientnet_b0', pretrained=False)
config.model['backbone'] = timm_backbone
config.model['cls_head']['in_channels'] = 1280

recognizer = build_recognizer(config.model)

input_shape = (1, 3, 3, 32, 32)
demo_inputs = generate_recognizer_demo_inputs(input_shape)

imgs = demo_inputs['imgs']
gt_labels = demo_inputs['gt_labels']

losses = recognizer(imgs, gt_labels)
assert isinstance(losses, dict)

# Test forward test
with torch.no_grad():
img_list = [img[None, :] for img in imgs]
for one_img in img_list:
recognizer(one_img, None, return_loss=False)