-
Notifications
You must be signed in to change notification settings - Fork 47
/
omnivore.py
77 lines (61 loc) · 2.22 KB
/
omnivore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
from dataclasses import dataclass
from typing import Tuple
import torch
from ego4d.features.config import BaseModelConfig, InferenceConfig
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale
from torch.nn import Identity, Module
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import CenterCropVideo, NormalizeVideo
@dataclass
class ModelConfig(BaseModelConfig):
    """Configuration for the Omnivore swin-B feature extractor.

    Bundles the torch.hub model identifier with the preprocessing
    constants (ImageNet normalization, resize/crop sizes) consumed by
    ``load_model`` and ``get_transform``.
    """

    # torch.hub entry point under facebookresearch/omnivore
    model_name: str = "omnivore_swinB"
    # modality tag passed to the model's forward ("video" or "image")
    input_type: str = "video"
    # frames are resized so the short side equals this before cropping
    side_size: int = 256
    # final square center-crop fed to the model
    crop_size: int = 224
    # per-channel normalization stats (RGB order); the defaults are
    # 3-tuples, so the annotation must be a 3-element Tuple, not Tuple[float]
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
class WrapModel(Module):
    """Adapter between dict-style samples and Omnivore's tensor API.

    The inference pipeline hands each sample over as a dict holding the
    clip under the "video" key, while the wrapped model expects a raw
    tensor plus an ``input_type`` keyword. This module bridges the two
    calling conventions.
    """

    def __init__(self, model: Module, input_type: str):
        super().__init__()
        self.model = model
        self.input_type = input_type

    def forward(self, sample) -> torch.Tensor:
        # Pull the tensor out of the sample dict, then forward it with the
        # modality tag the wrapped model requires.
        clip = sample["video"]
        return self.model(clip, input_type=self.input_type)
def load_model(
    inference_config: InferenceConfig,
    config: ModelConfig,
    patch_final_layer: bool = True,
) -> Module:
    """Fetch the Omnivore model from torch.hub and prepare it for inference.

    Downloads ``config.model_name`` from facebookresearch/omnivore,
    optionally strips the classification heads so the model emits features,
    wraps it for dict-style inputs, and moves it (in eval mode) onto
    ``inference_config.device``.
    """
    base = torch.hub.load("facebookresearch/omnivore", model=config.model_name)

    # Replace every classification head with Identity so the forward pass
    # returns backbone features instead of class logits.
    if patch_final_layer:
        for head_name in ("image", "video", "rgbd"):
            setattr(base.heads, head_name, Identity())

    # Set to GPU or CPU
    wrapped = WrapModel(base, config.input_type).eval()
    return wrapped.to(inference_config.device)
def norm_pixels(x):
    """Rescale pixel values from the [0, 255] byte range into [0, 1]."""
    max_pixel_value = 255.0
    return x / max_pixel_value
def get_transform(inference_config: InferenceConfig, config: ModelConfig):
    """Build the preprocessing transform applied to each sample's "video" entry.

    The pipeline normalizes pixels to [0, 1], applies ImageNet channel
    normalization, resizes the short side, and center-crops — the same
    steps for both video and image input, so the transform list is built
    once. (The original if/else produced byte-identical lists in both
    branches; only the frame-window check differed.)
    """
    if config.input_type != "video":
        # Image input is processed one frame at a time.
        assert inference_config.frame_window == 1

    transforms = [
        Lambda(norm_pixels),
        NormalizeVideo(config.mean, config.std),
        ShortSideScale(size=config.side_size),
        CenterCropVideo(config.crop_size),
    ]
    return ApplyTransformToKey(
        key="video",
        transform=Compose(transforms),
    )