From 801021043a53a42279700933e1926348a4cd0e76 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Wed, 4 Oct 2023 21:22:23 -0500 Subject: [PATCH] fix(render): nn.Parameter codegen (#122) * fix(render): remove extra bracket for nn.Parameter codegen * fix(render): reference nn.Parameter as attr in codegen * test(render): test codegen to nn.Parameter * fix(grad-check): make param names consistent after pt2 compilation --- nnsmith/backends/pt2.py | 10 +-- nnsmith/materialize/torch/__init__.py | 12 +++- nnsmith/materialize/torch/symbolnet.py | 29 +++++--- tests/torch/test_render.py | 95 +++++++++++++++++++++++++- 4 files changed, 127 insertions(+), 19 deletions(-) diff --git a/nnsmith/backends/pt2.py b/nnsmith/backends/pt2.py index ca399a3..351f2e8 100644 --- a/nnsmith/backends/pt2.py +++ b/nnsmith/backends/pt2.py @@ -39,6 +39,9 @@ def import_libs(self) -> List[str]: @dispatch(TorchModel) def make_backend(self, model: TorchModel) -> BackendCallable: torch_net = model.torch_model.to(self.device) + # Names for parameters can be changed implicitly after compilation + # We keep the original names to align with names in eager mode + param_names = [k for k, _ in model.torch_model.named_parameters()] do_grad_check = model.needs_grad_check() @@ -59,21 +62,20 @@ def closure(inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: input_ts = [torch.from_numpy(v).to(self.device) for _, v in inputs.items()] if do_grad_check: outputs: List[torch.Tensor] = compiled(*input_ts) - params = {k: v for k, v in compiled.named_parameters()} ret = {} for name, output in zip(torch_net.output_like.keys(), outputs): ret[name] = numpify(output) if output.requires_grad: - # get Vector-Jacobian product + # Get Vector-Jacobian product out_grad = torch.autograd.grad( outputs=output, - inputs=params.values(), + inputs=compiled.parameters(), grad_outputs=torch.ones_like(output), retain_graph=True, allow_unused=True, ) - for k, v in zip(params.keys(), out_grad): + for k, v in zip(param_names, out_grad): ret[name + "_vjp_" + k] = numpify(v) else: with torch.no_grad(): diff --git a/nnsmith/materialize/torch/__init__.py b/nnsmith/materialize/torch/__init__.py index 50f3ea5..117f054 100644 --- a/nnsmith/materialize/torch/__init__.py +++ b/nnsmith/materialize/torch/__init__.py @@ -117,10 +117,10 @@ def make_oracle(self) -> Oracle: for name, output in zip(self.output_like.keys(), outputs): output_dict[name] = numpify(output) if output.requires_grad: - # get Vector-Jacobian Product (VJP) + # Get Vector-Jacobian Product (VJP) out_grad = torch.autograd.grad( outputs=output, - inputs=params.values(), + inputs=self.torch_model.parameters(), grad_outputs=torch.ones_like(output), retain_graph=True, allow_unused=True, @@ -222,8 +222,14 @@ def emit_def(self, mod_name: str, mod_cls: str) -> str: tab = " " * 4 mod_text = "" + # _parameters only shows the upper-level parameters + # while + # named_parameters() shows all parameters recursively for name, param in self.native_model._parameters.items(): - mod_text += f"{2*tab}self.{name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype})))" + mod_text += ( + f"{2*tab}self.{name} = torch.nn.Parameter" + + f"(torch.empty({list(param.shape)}, dtype={param.dtype}), requires_grad={param.requires_grad})\n" + ) for name, mod in self.native_model.named_children(): if name == "": diff --git a/nnsmith/materialize/torch/symbolnet.py b/nnsmith/materialize/torch/symbolnet.py index 660cb29..7429e1c 100644 --- a/nnsmith/materialize/torch/symbolnet.py +++ b/nnsmith/materialize/torch/symbolnet.py @@ -5,6 +5,7 @@ from typing import Dict, Optional import torch +import torch.fx as fx from torch import nn from nnsmith.abstract.dtype import DType @@ -22,7 +23,7 @@ __ENABLE_RT_CHECK__ = os.getenv("NNSMITH_RT_CHECK", "0") == "1" -# Probablistically, sampling at positive domain is beneficial. +# Probabilistically, sampling at positive domain is beneficial. def random_tensor(shape, dtype: torch.dtype, margin=4, base=5, use_cuda=False): # center: -margin ~ 0 ~ +margin dev = torch.device("cuda" if use_cuda else "cpu") @@ -131,7 +132,7 @@ def __init__( self.add_module(f"m{i}", target) if isinstance(target, nn.Parameter): - self.register_parameter(inst.retval(), target) + setattr(self, inst.retval(), target) else: self.instructions.append( (torch_fn, inst.iexpr.args, inst.retvals(), inst.iexpr.op) @@ -342,12 +343,24 @@ def grad_input_gen( def use_cuda(self): self.cuda() + def make_param_map(self) -> Dict[str, torch.Tensor]: + tensor_map: Dict[str, torch.Tensor] = {} + + for k, v in self._parameters.items(): + # Workaround: https://github.com/ise-uiuc/nnsmith/pull/122 + if hasattr(self, k): + attr = getattr(self, k) + if isinstance(attr, (nn.Parameter, fx.Proxy)): + tensor_map[k] = attr + continue + tensor_map[k] = v + + return tensor_map + def forward(self, *args): self.differentiable = True - tensor_map: Dict[str, torch.Tensor] = { - k: v for k, v in self._parameters.items() - } + tensor_map: Dict[str, torch.Tensor] = self.make_param_map() for i, key in enumerate(self.input_map.keys()): tensor_map[key] = args[i] @@ -360,7 +373,7 @@ def forward(self, *args): # REAL FORWARD. output_tensors = inst(*input_tensors) - if isinstance(output_tensors, torch.fx.proxy.Proxy): + if isinstance(output_tensors, fx.proxy.Proxy): # TODO(@ganler, @co1lin): can we do systematic check through the output type? if output_tensors.node.target not in [torch.split, torch.chunk]: output_tensors = [output_tensors] @@ -382,9 +395,7 @@ def forward(self, *args): def forward_grad(self, *args): self.differentiable = True - tensor_map: Dict[str, torch.Tensor] = { - k: v for k, v in self._parameters.items() - } + tensor_map: Dict[str, torch.Tensor] = self.make_param_map() for i, key in enumerate(self.input_map.keys()): tensor_map[key] = args[i] diff --git a/tests/torch/test_render.py b/tests/torch/test_render.py index b50bd38..ace0bf5 100644 --- a/tests/torch/test_render.py +++ b/tests/torch/test_render.py @@ -245,7 +245,6 @@ def test_render_model_only(): # pickle is not used (no `ModuleList` in the code) # so no need to import pickle - print(render.render()) assert ( render.render() == R""" @@ -293,7 +292,97 @@ def forward(self, x): ) -def test_render_e2e_pt2(): +def test_render_e2e_param_pt2(): + model = TorchModelCPU() + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.nn.Parameter( + torch.empty([1], dtype=torch.int16), requires_grad=False + ) + + def forward(self, x): + squeeze = self.const.squeeze(0) + mul = torch.mul(squeeze, x) + expand = mul.expand(1) + expand_1 = mul.expand(1, 1, 1, 1) + max_1 = torch.max(expand_1, x) + return (expand, max_1) + + model.torch_model = M() + + model.torch_model.input_like = {"x": AbsTensor([], DType.int16)} + + render = Render() + render.emit_model(model) + render.emit_input(model) + render.emit_backend(PT2(target="cpu", optmax=True)) + + rendered = render.render() + + # pickle is not used (no `ModuleList` in the code) + # so no need to import pickle + assert ( + rendered + == R""" +import numpy as np +import torch +import pickle + +# Model definition +class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.nn.Parameter(torch.empty([1], dtype=torch.int16), requires_grad=False) + + def forward(self, x): + const = self.const + squeeze = const.squeeze(0); const = None + mul = torch.mul(squeeze, x); squeeze = None + expand = mul.expand(1) + expand_1 = mul.expand(1, 1, 1, 1); mul = None + max_1 = torch.max(expand_1, x); expand_1 = x = None + return (expand, max_1) + +m = M() + + +# Initialize weight +# None + +# Initialize input +inp = [np.zeros([], dtype='int16')] + +# Compile the model +opt = torch.compile(m, fullgraph=True, backend='inductor', mode=None) + +# Eager run +m_out = m(*[torch.from_numpy(v).to('cpu') for v in inp]) +m_out = [v.cpu().detach() for v in m_out] # torch2numpy +m_out = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in m_out] # torch2numpy + +# Compiled run +opt_out = opt(*[torch.from_numpy(v).to('cpu') for v in inp]) +opt_out = [v.cpu().detach() for v in opt_out] # torch2numpy +opt_out = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in opt_out] # torch2numpy + +# Differential testing +for i, (l, r) in enumerate(zip(m_out, opt_out)): + np.testing.assert_allclose(l, r, rtol=1e-2, atol=1e-3, err_msg=f"Result mismatch @ index {i}") +""" + ) + + # Run rendered code in a subprocess as a smoke test + subprocess.run( + ["python", "-c", rendered], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + +def test_render_e2e_cnn_pt2(): model = TorchModelCPU() model.torch_model = CNN() @@ -366,7 +455,7 @@ def forward(self, x): ) -def test_render_e2e_torchjit(): +def test_render_e2e_cnn_torchjit(): model = TorchModelCPU() model.torch_model = CNN()