Skip to content

Commit

Permalink
Extra CR comments from #95621 (#96043)
Browse files Browse the repository at this point in the history
  • Loading branch information
ezyang authored and cyyever committed Mar 12, 2023
1 parent 92ef868 commit d50ade0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
9 changes: 4 additions & 5 deletions torch/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,11 +957,10 @@ def same(
log.error(f"Accuracy failed (float): {ref} != {res} (within tol={tol})")
return r
elif is_numpy_int_type(ref) or is_numpy_float_type(ref):
if relax_numpy_equality:
if is_numpy_int_type(ref):
ref = ref.item()
if is_numpy_int_type(res):
res = res.item()
if relax_numpy_equality and not (
is_numpy_int_type(res) or is_numpy_float_type(res)
):
ref = ref.item()
r = (type(ref) is type(res)) and (ref == res)
if not r:
log.error(f"Accuracy failed (numpy): {ref} != {res}")
Expand Down
8 changes: 4 additions & 4 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,10 @@ def wrap_unspecialized_primitive(self, value):
if self.name in self.tx.output.unspec_variable_map:
return self.tx.output.unspec_variable_map[self.name]
else:
# NB: We do not do float. For motivation, see
# https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
# but the general idea is that we generate kernels that can
# take unspecialized floats and use them in sizevar computation
if (
config.dynamic_shapes
and isinstance(value, int)
Expand All @@ -799,11 +803,7 @@ def wrap_unspecialized_primitive(self, value):
self.tx.output.tracked_fakes.append(
TrackedFake(wrapped_value, self.source)
)
# TODO: Do float?
# Not entirely clear we want to do this, as float inputs don't
# work with inductor codegen at the moment
else:
# TODO: Eliminate this case entirely
wrapped_value = torch.tensor(value)
if not isinstance(self.get_source(), RandomValueSource):
guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH, True)}
Expand Down

0 comments on commit d50ade0

Please sign in to comment.