Skip to content

Commit

Permalink
fix(render): nn.Parameter codegen (#122)
Browse files Browse the repository at this point in the history
* fix(render): remove extra bracket for nn.Parameter codegen

* fix(render): reference nn.Parameter as attr in codegen

* test(render): test codegen to nn.Parameter

* fix(grad-check): make param names consistent after pt2 compilation
  • Loading branch information
ganler committed Oct 5, 2023
1 parent 71278ac commit 8010210
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 19 deletions.
10 changes: 6 additions & 4 deletions nnsmith/backends/pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def import_libs(self) -> List[str]:
@dispatch(TorchModel)
def make_backend(self, model: TorchModel) -> BackendCallable:
torch_net = model.torch_model.to(self.device)
# Names for parameters can be changed implicitly after compilation
# We keep the original names to align with names in eager mode
param_names = [k for k, _ in model.torch_model.named_parameters()]

do_grad_check = model.needs_grad_check()

Expand All @@ -59,21 +62,20 @@ def closure(inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
input_ts = [torch.from_numpy(v).to(self.device) for _, v in inputs.items()]
if do_grad_check:
outputs: List[torch.Tensor] = compiled(*input_ts)
params = {k: v for k, v in compiled.named_parameters()}
ret = {}

for name, output in zip(torch_net.output_like.keys(), outputs):
ret[name] = numpify(output)
if output.requires_grad:
# get Vector-Jacobian product
# Get Vector-Jacobian product
out_grad = torch.autograd.grad(
outputs=output,
inputs=params.values(),
inputs=compiled.parameters(),
grad_outputs=torch.ones_like(output),
retain_graph=True,
allow_unused=True,
)
for k, v in zip(params.keys(), out_grad):
for k, v in zip(param_names, out_grad):
ret[name + "_vjp_" + k] = numpify(v)
else:
with torch.no_grad():
Expand Down
12 changes: 9 additions & 3 deletions nnsmith/materialize/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def make_oracle(self) -> Oracle:
for name, output in zip(self.output_like.keys(), outputs):
output_dict[name] = numpify(output)
if output.requires_grad:
# get Vector-Jacobian Product (VJP)
# Get Vector-Jacobian Product (VJP)
out_grad = torch.autograd.grad(
outputs=output,
inputs=params.values(),
inputs=self.torch_model.parameters(),
grad_outputs=torch.ones_like(output),
retain_graph=True,
allow_unused=True,
Expand Down Expand Up @@ -222,8 +222,14 @@ def emit_def(self, mod_name: str, mod_cls: str) -> str:

tab = " " * 4
mod_text = ""
# `_parameters` only holds the parameters registered directly on this module,
# whereas `named_parameters()` yields all parameters recursively (submodules included).
for name, param in self.native_model._parameters.items():
mod_text += f"{2*tab}self.{name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype})))"
mod_text += (
f"{2*tab}self.{name} = torch.nn.Parameter"
+ f"(torch.empty({list(param.shape)}, dtype={param.dtype}), requires_grad={param.requires_grad})\n"
)

for name, mod in self.native_model.named_children():
if name == "":
Expand Down
29 changes: 20 additions & 9 deletions nnsmith/materialize/torch/symbolnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, Optional

import torch
import torch.fx as fx
from torch import nn

from nnsmith.abstract.dtype import DType
Expand All @@ -22,7 +23,7 @@
__ENABLE_RT_CHECK__ = os.getenv("NNSMITH_RT_CHECK", "0") == "1"


# Probablistically, sampling at positive domain is beneficial.
# Probabilistically, sampling at positive domain is beneficial.
def random_tensor(shape, dtype: torch.dtype, margin=4, base=5, use_cuda=False):
# center: -margin ~ 0 ~ +margin
dev = torch.device("cuda" if use_cuda else "cpu")
Expand Down Expand Up @@ -131,7 +132,7 @@ def __init__(
self.add_module(f"m{i}", target)

if isinstance(target, nn.Parameter):
self.register_parameter(inst.retval(), target)
setattr(self, inst.retval(), target)
else:
self.instructions.append(
(torch_fn, inst.iexpr.args, inst.retvals(), inst.iexpr.op)
Expand Down Expand Up @@ -342,12 +343,24 @@ def grad_input_gen(
# Switch this object to CUDA by delegating to its own `cuda()` method.
# NOTE(review): presumably this is `torch.nn.Module.cuda` (the enclosing class
# header is not visible in this hunk) — confirm.
def use_cuda(self):
self.cuda()

def make_param_map(self) -> Dict[str, torch.Tensor]:
    """Build the initial ``name -> tensor`` table seeded with the parameters.

    ``forward`` and ``forward_grad`` start from this map and then add the
    input tensors to it.

    For every key in ``self._parameters`` the value obtained through normal
    attribute access is preferred when it is an ``nn.Parameter`` or an
    ``fx.Proxy``; otherwise the raw ``_parameters`` entry is used.

    Returns:
        A new dict with one entry per key of ``self._parameters``, in the
        same order.
    """
    param_map: Dict[str, torch.Tensor] = {}

    for name, registered in self._parameters.items():
        # Workaround: https://github.com/ise-uiuc/nnsmith/pull/122
        # NOTE(review): presumably attribute access is what yields a Proxy
        # while the module is being traced — confirm against the PR.
        candidate = getattr(self, name, None)
        if isinstance(candidate, (nn.Parameter, fx.Proxy)):
            param_map[name] = candidate
        else:
            # Attribute missing or of another kind: keep the registered value.
            param_map[name] = registered

    return param_map

def forward(self, *args):
self.differentiable = True

tensor_map: Dict[str, torch.Tensor] = {
k: v for k, v in self._parameters.items()
}
tensor_map: Dict[str, torch.Tensor] = self.make_param_map()
for i, key in enumerate(self.input_map.keys()):
tensor_map[key] = args[i]

Expand All @@ -360,7 +373,7 @@ def forward(self, *args):

# REAL FORWARD.
output_tensors = inst(*input_tensors)
if isinstance(output_tensors, torch.fx.proxy.Proxy):
if isinstance(output_tensors, fx.proxy.Proxy):
# TODO(@ganler, @co1lin): can we do systematic check through the output type?
if output_tensors.node.target not in [torch.split, torch.chunk]:
output_tensors = [output_tensors]
Expand All @@ -382,9 +395,7 @@ def forward(self, *args):
def forward_grad(self, *args):
self.differentiable = True

tensor_map: Dict[str, torch.Tensor] = {
k: v for k, v in self._parameters.items()
}
tensor_map: Dict[str, torch.Tensor] = self.make_param_map()
for i, key in enumerate(self.input_map.keys()):
tensor_map[key] = args[i]

Expand Down
95 changes: 92 additions & 3 deletions tests/torch/test_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def test_render_model_only():

# pickle is not used (no `ModuleList` in the code)
# so no need to import pickle
print(render.render())
assert (
render.render()
== R"""
Expand Down Expand Up @@ -293,7 +292,97 @@ def forward(self, x):
)


def test_render_e2e_pt2():
# End-to-end render test for a model that owns an `nn.Parameter`:
# render the model, its input and the PT2 backend, compare the generated
# script against the expected text, then execute that script.
def test_render_e2e_param_pt2():
model = TorchModelCPU()

# Minimal module holding a single non-trainable int16 parameter (`const`).
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.nn.Parameter(
torch.empty([1], dtype=torch.int16), requires_grad=False
)

def forward(self, x):
squeeze = self.const.squeeze(0)
mul = torch.mul(squeeze, x)
expand = mul.expand(1)
expand_1 = mul.expand(1, 1, 1, 1)
max_1 = torch.max(expand_1, x)
return (expand, max_1)

model.torch_model = M()

# One int16 input named "x"; it is rendered below as np.zeros([], dtype='int16').
model.torch_model.input_like = {"x": AbsTensor([], DType.int16)}

render = Render()
render.emit_model(model)
render.emit_input(model)
render.emit_backend(PT2(target="cpu", optmax=True))

rendered = render.render()

# NOTE(review): the expected text below DOES contain `import pickle`, so the
# "no need to import pickle" remark used in other render tests does not hold
# here — confirm whether the import is always emitted once a backend is rendered.
# The comparison is an exact string match, so whitespace matters.
assert (
rendered
== R"""
import numpy as np
import torch
import pickle
# Model definition
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.nn.Parameter(torch.empty([1], dtype=torch.int16), requires_grad=False)
def forward(self, x):
const = self.const
squeeze = const.squeeze(0); const = None
mul = torch.mul(squeeze, x); squeeze = None
expand = mul.expand(1)
expand_1 = mul.expand(1, 1, 1, 1); mul = None
max_1 = torch.max(expand_1, x); expand_1 = x = None
return (expand, max_1)
m = M()
# Initialize weight
# None
# Initialize input
inp = [np.zeros([], dtype='int16')]
# Compile the model
opt = torch.compile(m, fullgraph=True, backend='inductor', mode=None)
# Eager run
m_out = m(*[torch.from_numpy(v).to('cpu') for v in inp])
m_out = [v.cpu().detach() for v in m_out] # torch2numpy
m_out = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in m_out] # torch2numpy
# Compiled run
opt_out = opt(*[torch.from_numpy(v).to('cpu') for v in inp])
opt_out = [v.cpu().detach() for v in opt_out] # torch2numpy
opt_out = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in opt_out] # torch2numpy
# Differential testing
for i, (l, r) in enumerate(zip(m_out, opt_out)):
np.testing.assert_allclose(l, r, rtol=1e-2, atol=1e-3, err_msg=f"Result mismatch @ index {i}")
"""
)

# Run rendered code in a subprocess as a smoke test.
# check=True turns a non-zero exit status into CalledProcessError, failing the test.
# NOTE(review): "python" is resolved from PATH and may not be the interpreter
# running the tests; sys.executable would pin it — confirm intent.
subprocess.run(
["python", "-c", rendered],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)


def test_render_e2e_cnn_pt2():
model = TorchModelCPU()
model.torch_model = CNN()

Expand Down Expand Up @@ -366,7 +455,7 @@ def forward(self, x):
)


def test_render_e2e_torchjit():
def test_render_e2e_cnn_torchjit():
model = TorchModelCPU()
model.torch_model = CNN()

Expand Down

0 comments on commit 8010210

Please sign in to comment.