forked from comfyanonymous/ComfyUI_TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 1
/
engine_info.py
83 lines (72 loc) · 2.44 KB
/
engine_info.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from dataclasses import dataclass, asdict, is_dataclass
import json
@dataclass
class EngineInfo:
    """Build-time metadata for an engine, persisted to a JSON file.

    The ``min_*`` / ``opt_*`` / ``max_*`` triples give the range of batch size,
    height, width and context length the engine was built for. The remaining
    fields record model dimensions and the dtype label. ``min_shapes``,
    ``opt_shapes`` and ``max_shapes`` turn those ranges into per-input shape
    tuples, and ``load`` / ``dump`` round-trip an instance through JSON.
    """

    # Batch-size range.
    min_batch: int
    opt_batch: int
    max_batch: int
    # Height range (integer-divided by 8 when shapes are built).
    min_height: int
    opt_height: int
    max_height: int
    # Width range (integer-divided by 8 when shapes are built).
    min_width: int
    opt_width: int
    max_width: int
    # Context-length range; see the review note in ``_get_shapes``.
    min_context_len: int
    opt_context_len: int
    max_context_len: int
    # Channel count of the first input tensor.
    in_channels: int
    # Last dimension of the context input.
    context_dim: int
    # Width of the optional extra input; when <= 0 that input is omitted.
    y_dim: int
    # Precision label, compared verbatim in ``verify`` (presumably e.g. "float16").
    dtype: str
    # base64 encoded pickle. NOTE(review): whatever decodes these must trust
    # the JSON file, since unpickling untrusted data can execute arbitrary code.
    model_config_init: str
    model_init: str  # base64 encoded pickle

    def verify(self, other: "EngineInfo") -> bool:
        """Return True if *other*'s requirements fit within this engine's build.

        ``dtype``, ``in_channels``, ``y_dim`` and ``context_dim`` must match
        exactly. ``other``'s opt height, width and context length must lie
        inside this instance's inclusive [min, max] ranges.
        """
        return (
            other.dtype == self.dtype
            and other.in_channels == self.in_channels
            and other.y_dim == self.y_dim
            and other.context_dim == self.context_dim
            # NOTE: batch size is deliberately not checked because a batch can be
            # split into smaller chunks:
            # and self.min_batch <= other.opt_batch <= self.max_batch
            and self.min_height <= other.opt_height <= self.max_height
            and self.min_width <= other.opt_width <= self.max_width
            and self.min_context_len <= other.opt_context_len <= self.max_context_len
        )

    def _get_shapes(
        self, batch_size: int, height: int, width: int, context_len: int
    ) -> tuple:
        """Build the tuple of input shapes for one size point.

        Returns, in order:
          * ``(batch_size, in_channels, height // 8, width // 8)``
          * ``(batch_size,)``
          * ``(batch_size, context_len * opt_context_len, context_dim)``
          * ``(batch_size, y_dim)`` -- only when ``y_dim > 0``

        The ``// 8`` presumably converts pixel size to latent size -- TODO
        confirm against the code that builds the engine.
        """
        s = (
            (
                batch_size,
                self.in_channels,
                height // 8,
                width // 8,
            ),
            (batch_size,),
            # NOTE(review): every profile (min/opt/max) is scaled by
            # self.opt_context_len, so the min and max shapes mix in the opt
            # value. This looks unintended; confirm the intended sequence-length
            # formula against the engine-building code before changing it.
            (batch_size, context_len * self.opt_context_len, self.context_dim),
        )
        if self.y_dim > 0:
            return s + ((batch_size, self.y_dim),)
        return s

    def min_shapes(self) -> tuple:
        """Input shapes at the minimum batch/height/width/context length."""
        return self._get_shapes(
            self.min_batch, self.min_height, self.min_width, self.min_context_len
        )

    def opt_shapes(self) -> tuple:
        """Input shapes at the optimal batch/height/width/context length."""
        return self._get_shapes(
            self.opt_batch, self.opt_height, self.opt_width, self.opt_context_len
        )

    def max_shapes(self) -> tuple:
        """Input shapes at the maximum batch/height/width/context length."""
        return self._get_shapes(
            self.max_batch, self.max_height, self.max_width, self.max_context_len
        )

    @classmethod
    def load(cls, info_path: str) -> "EngineInfo":
        """Read an ``EngineInfo`` from the JSON file at *info_path*.

        Raises ``json.JSONDecodeError`` on malformed JSON, and ``TypeError`` if
        the JSON object has missing or unexpected keys.
        """
        # JSON is UTF-8 by spec; do not depend on the platform locale encoding.
        with open(info_path, "r", encoding="utf-8") as f:
            return cls(**json.load(f))

    def dump(self, info_path: str) -> None:
        """Write this instance to *info_path* as 2-space-indented JSON."""
        with open(info_path, "w", encoding="utf-8") as f:
            # Equivalent output to json.dump(self, cls=EngineInfoJsonEncoder),
            # but self-contained.
            json.dump(asdict(self), f, indent=2)
class EngineInfoJsonEncoder(json.JSONEncoder):
    """JSON encoder that serializes dataclasses through ``dataclasses.asdict``.

    Objects that are not dataclasses are handed back to
    ``json.JSONEncoder.default``, which raises ``TypeError`` for them.
    """

    def default(self, o):
        # Guard clause: anything that isn't a dataclass gets the stock behavior.
        if not is_dataclass(o):
            return super().default(o)
        return asdict(o)