Skip to content

Commit

Permalink
Dynamo fx completeness (#162)
Browse files Browse the repository at this point in the history
This PR takes care of
#139.

For the weird tensor metadata, we should only throw an error if the tensor
is quantized. If we start checking other fields such as stride,
memory_format, etc., the pytest suite fails.

For the TODO where we wanted to throw a KeyError: the if statement above
already checks whether parameter.name is in node.kwargs, so we will never
look up an invalid key.

The other issues that were part of this importer-completeness task have
been addressed and documented as well.

Fixes #139 
Fixes #140 
Fixes #141 
Fixes #142 
Fixes #143 
Fixes #144
  • Loading branch information
saienduri authored and IanNod committed Nov 9, 2023
1 parent f6039f4 commit 48a855f
Showing 1 changed file with 32 additions and 30 deletions.
62 changes: 32 additions & 30 deletions python/shark_turbine/dynamo/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,14 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
val = node.meta.get("val")
if tensor_meta is not None:
assert isinstance(tensor_meta, TensorMetadata)
# TODO: We should probably only be doing this if "vanilla".
# Specifically, there are strides/qparams/etc on there that
# should be annotated somewhere.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/139
return self.tensor_metadata_to_type(tensor_meta)
# Quantized tensor meta data is not preserved in our lowering,
# so throw error instead of silently doing wrong thing.
if (tensor_meta.is_quantized):
raise NotImplementedError(
f"Quantized tensor meta data is not supported."
)
else:
return self.tensor_metadata_to_type(tensor_meta)
elif val is not None:
# some nodes with symbolic inputs pass a 'val' attribute rather than
# tensor_meta
Expand Down Expand Up @@ -630,23 +633,16 @@ def _import_torch_op_overload(
op_overload = getattr(op_overload, op_attrs[i])
schema = op_overload._schema

if not self._c.is_registered_operation(mlir_op_name):
# TODO: Implement a config setting to allow these to flow through.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/141
raise NotImplementedError(
f"Unimplemented torch op in the IREE compiler: '{mlir_op_name}' "
f"(either implement this op/variant or configure the compiler to "
f"allow unknown operations and fallback to PyTorch)."
)

return_count = len(schema.returns)
if return_count == 1:
# Unary return directly maps a single meta["val"] and cannot be subscripted.
# if "tensor_meta" is None, this will throw unsupported placeholder node error
result_types = [self._cc.node_val_to_type(node)]
elif return_count == 0:
# TODO: Implement (https://github.com/nod-ai/SHARK-Turbine/issues/142)
raise NotImplementedError("FIXME: Zero ATen results")
# Some torch ops do have 0 returns, and these are supported with ZeroResults
# op trait. Python bindings for IR creation allow us to pass empty result_types
# for such ops. Therefore, we pass an empty result types for these cases.
result_types = []
else:
# Multi-return will unpack the meta["val"] and trigger our getitem subscripting
# short-circuit above. Note that if we ever choose to also fully reify Python
Expand All @@ -663,8 +659,6 @@ def _import_torch_op_overload(
operands = []
for i, parameter in enumerate(schema.arguments):
if parameter.kwarg_only and parameter.name in node.kwargs:
# TODO: Nice error if KeyError.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/143
operands.append(
self._import_argument(
loc, node.kwargs[parameter.name], parameter.type
Expand All @@ -681,12 +675,23 @@ def _import_torch_op_overload(
)
)

operation = Operation.create(
mlir_op_name,
results=result_types,
operands=operands,
loc=loc,
)
# Support unregistered torch ops using torch.operator.
# torch.operator is used to represent ops from registry
# which haven't been generated by torch_ods_gen.py.
if not self._c.is_registered_operation(mlir_op_name):
operation = Operation.create(
"torch.operator",
results=result_types,
operands=operands,
loc=loc,
)
else:
operation = Operation.create(
mlir_op_name,
results=result_types,
operands=operands,
loc=loc,
)

# Record value mapping.
for i, value in enumerate(operation.results):
Expand Down Expand Up @@ -830,18 +835,15 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
if isinstance(arg, list):
return self._import_list_argument(loc, arg, expected_jit_type)

# The LITERAL_CONVERTER_MAP maps each arg to its respective constant
# of the expected jit IR type (types like torch.dtype will form a chain of
# maps to get to constant of expected_jit_type).
cvt = LITERAL_CONVERTER_MAP.lookup(type(arg))
if cvt is None:
raise RuntimeError(f"Unhandled default value ({arg.__class__}): {arg})")
with loc:
return cvt(arg, self, self._cc)

# TODO: Support torch specific types which show up in function schemas.
# These all require an expected_jit_type to convert.
# torch.dtype, torch.device, torch.memory_format, torch.layout
# list
# See: https://github.com/nod-ai/SHARK-Turbine/issues/144


class TypeSubclassMap:
"""Mapping of super-types to values.
Expand Down

0 comments on commit 48a855f

Please sign in to comment.