-
Notifications
You must be signed in to change notification settings - Fork 3
/
dynamic_simple_model_view.py
31 lines (25 loc) · 1.01 KB
/
dynamic_simple_model_view.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import sys
import torch
import random
from execution import runner
def optim_func(params):
    """Build the optimizer handed to the benchmark runner.

    Args:
        params: Parameters (or parameter groups) to optimize; forwarded
            unchanged to ``torch.optim.SGD``.

    Returns:
        A ``torch.optim.SGD`` optimizer with a fixed learning rate of 0.01.
    """
    learning_rate = 0.01
    optimizer = torch.optim.SGD(params, lr=learning_rate)
    return optimizer
def input_func(steps, dtype, device, *, batch_size=128, hidden_size=1024,
               min_seq_length=2, max_seq_length=128):
    """Generate one random input per benchmark step, each with its own sequence length.

    A sequence length is drawn for every step with ``random.randint`` (inclusive
    bounds), so consecutive steps see differently shaped tensors. The defaults
    reproduce the previously hard-coded sizes exactly.

    Args:
        steps: Number of inputs to generate.
        dtype: ``torch`` dtype of the generated tensors.
        device: Device the tensors are allocated on.
        batch_size: Size of dimension 0 of every tensor (default 128).
        hidden_size: Size of dimension 2 of every tensor (default 1024). It must
            match the model's feature width; ``TestModule`` uses
            ``Linear(1024, 1024)``.
        min_seq_length: Smallest sequence length that may be drawn (default 2).
        max_seq_length: Largest sequence length that may be drawn (default 128).

    Returns:
        A list of ``steps`` one-element lists. Each holds a ``torch.randn``
        tensor of shape ``(batch_size, seq_len, hidden_size)``. Presumably each
        inner list is the positional-argument list the runner passes to the
        model's ``forward`` — confirm against ``execution.runner``.

    Raises:
        ValueError: if ``min_seq_length > max_seq_length``.
    """
    if min_seq_length > max_seq_length:
        raise ValueError(
            f"min_seq_length ({min_seq_length}) must not exceed "
            f"max_seq_length ({max_seq_length})"
        )
    # Draw all lengths first, then allocate tensors. This keeps the global RNG
    # consumption order identical to the original implementation.
    seq_lengths = [random.randint(min_seq_length, max_seq_length) for _ in range(steps)]
    return [
        [torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)]
        for seq_len in seq_lengths
    ]
class TestModule(torch.nn.Module):
    """Linear + ReLU block with a residual connection, built around ``Tensor.view``.

    ``forward`` flattens a 3-D input to 2-D with ``view``, applies a square
    ``Linear(1024, 1024)``, restores the original shape with ``view``, applies
    ReLU, adds the input back and returns the scalar sum as a 1-tuple.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1024, 1024)
        self.act = torch.nn.ReLU()

    def forward(self, inputs):
        """Run the block on ``inputs``.

        Args:
            inputs: Tensor indexed as ``(d0, d1, d2)``; ``d2`` must equal 1024
                to match the linear layer. ``.view`` requires a view-compatible
                (e.g. contiguous) layout.

        Returns:
            A one-element tuple holding the sum of the block's output.
        """
        dim0, dim1, dim2 = inputs.size(0), inputs.size(1), inputs.size(2)
        # Collapse the two leading dims so Linear sees a 2-D matrix.
        flattened = inputs.view(dim0 * dim1, dim2)
        # Linear is square (1024 -> 1024), so the input's shape can be restored.
        projected = self.linear(flattened).view(dim0, dim1, dim2)
        residual = self.act(projected) + inputs
        return (residual.sum(),)
if __name__ == "__main__":
    # Script entry point: hand the model, optimizer factory and input
    # generator to the shared benchmark runner along with the CLI arguments.
    benchmark_name = 'Dynamic-Simple-Model-View'
    model = TestModule()
    runner.run(sys.argv, benchmark_name, model, optim_func, input_func, None)