Skip to content

Commit

Permalink
Requantize lambda statically in PACT_IntegerBatchNormNd
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescoConti committed May 8, 2020
1 parent 1174ba7 commit 02dd866
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions nemo/quant/pact.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,10 +875,12 @@ def integerize_weights(self):
"""

# kappa_int = self.kappa.abs().max()
# lamda_int = self.lamda.abs().max()
self.kappa.data[:] = self.kappa / self.eps_kappa #torch.round(pact_quantize_signed_inference(self.kappa.data[:], self.eps_kappa, kappa_int) / self.eps_kappa)
self.lamda.data[:] = self.lamda / self.eps_lamda #torch.round(pact_quantize_signed_inference(self.lamda.data[:], self.eps_lamda, lamda_int) / self.eps_lamda)
self.kappa.data[:] = self.kappa / self.eps_kappa
self.lamda.data[:] = self.lamda / self.eps_lamda

# requantize lamda to eps_kappa*eps_in (which is the output precision)
self.lamda.data[:] = pact_integer_requantize(self.lamda, self.eps_lamda, self.eps_kappa*self.eps_in)


def get_output_eps(self, eps_in):
r"""Get the output quantum (:math:`\varepsilon`) given the input one.
Expand Down Expand Up @@ -913,9 +915,7 @@ def forward(self, x):
"""

# requantize lamda to eps_kappa*eps_in (which is the output precision)
lamda_rq = pact_integer_requantize(self.lamda, self.eps_lamda, self.eps_kappa*self.eps_in)
return self.kappa*x + lamda_rq
return self.kappa*x + self.lamda

class PACT_Identity(torch.nn.Module):
r"""Identity module.
Expand Down

0 comments on commit 02dd866

Please sign in to comment.