Skip to content

Commit

Permalink
Fix that __main__ routine ignores device cpu setup
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex4386 committed Jun 15, 2024
1 parent 1337053 commit b41c754
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions server/model/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torchsummary import summary
from thop import profile
import time
from runner.utils import get_config


class PatchEmbedding(nn.Module):
Expand Down Expand Up @@ -102,6 +103,13 @@ def forward(self, x):


if __name__ == '__main__':
config = get_config('config.yaml')

if config['GPU']['cuda']:
device = 'cuda'
else:
device = 'cpu'

model = ViT(
in_channels=1,
patch_size=(2, 16),
Expand All @@ -111,7 +119,7 @@ def forward(self, x):
mlp_dim=32,
num_classes=24,
in_size=[2, 1024]
).to("cuda")
).to(device)

print(summary(model, (1, 2, 1024)))

Expand All @@ -128,7 +136,7 @@ def forward(self, x):

input = torch.randn(1, 1, 2, 1024)

macs, params = profile(model, inputs=(torch.Tensor(input).to(device="cuda"),))
macs, params = profile(model, inputs=(torch.Tensor(input).to(device=device),))
print(
"Param: %.2fM | FLOPs: %.3fG" % (params / (1000 ** 2), macs / (1000 ** 3))
)

0 comments on commit b41c754

Please sign in to comment.