Skip to content

Commit

Permalink
[BugFix][UMA] Protect target registration (apache#13624)
Browse files Browse the repository at this point in the history
This PR address fixes for UMA target registration.
* Fix the doc issue apache#13304
* Continues stalled PR apache#12731

Changes:
* Incorporates all proposed fixes from mentioned [PR apache#12731](apache#12731)
* Address test case concerns and discussions from [PR apache#12731](apache#12731)
* **NEW:** Already exiting target cannot be created, explicit error on this.
* **NEW:** Attributes having special/reserved scope cannot be created explicitly.

It also address proper test cases for all the above.

Signed-off-by: tqchen <[email protected]>
  • Loading branch information
cbalint13 authored and tqchen committed Feb 20, 2023
1 parent 1d98634 commit 50b33ac
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 16 deletions.
2 changes: 1 addition & 1 deletion gallery/tutorial/uma.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
#

######################################################################
# .. image:: https://raw.githubusercontent.com/apache/tvm-site/main/images/tutorial/uma_vanilla_block_diagram.png
# .. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/tutorial/uma_vanilla_block_diagram.png
# :width: 100%
# :alt: A block diagram of Vanilla
#
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relay/backend/contrib/uma/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,12 @@ def register(self) -> None:
"""
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")

for name, attr in self._target_attrs:
for name, attr in self._target_attrs.items():
if attr is None:
raise ValueError("Target attribute None is not supported.")

if registration_func(self.target_name, self._target_attrs):
# skip if target is already registered
if self.target_name not in tvm.target.Target.list_kinds():
registration_func(self.target_name, self._target_attrs)
self._relay_to_relay.register()
self._relay_to_tir.register()
self._tir_to_runtime.register()
Expand Down
24 changes: 15 additions & 9 deletions src/relay/backend/contrib/uma/targets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,23 @@ namespace tvm {
namespace relay {
namespace contrib {
namespace uma {
tvm::transform::Pass RelayToTIR(String target_name);
transform::Pass RelayToTIR(String target_name);
runtime::Module TIRToRuntime(IRModule mod, Target target);
} // namespace uma
} // namespace contrib
} // namespace relay

TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
.set_body_typed([](String target_name, Map<String, ObjectRef> attr_options) -> bool {
// @todo(cgerum): We probably should get rid of target.register rather sooner than later
// And use a proper registry for uma backends
for (const String registered_target_name : ::tvm::TargetKindRegEntry::ListTargetKinds()) {
// create only new target and init only once
for (const String registered_target_name : TargetKindRegEntry::ListTargetKinds()) {
if (registered_target_name == target_name) {
return false;
LOG(FATAL) << "TVM UMA Error: Target is already registered: " << target_name;
}
}

auto target_kind =
::tvm::TargetKindRegEntry::RegisterOrGet(target_name)
TargetKindRegEntry::RegisterOrGet(target_name)
.set_name()
.set_default_device_type(kDLCPU)
.add_attr_option<Array<String>>("keys")
Expand All @@ -58,20 +57,27 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
.add_attr_option<Array<String>>("libs")
.add_attr_option<Target>("host")
.add_attr_option<Integer>("from_device")
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR,
.set_attr<FTVMRelayToTIR>(attr::kRelayToTIR,
relay::contrib::uma::RelayToTIR(target_name))
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::uma::TIRToRuntime);

// target kind attrs inventory
auto kind = TargetKind::Get(target_name).value();
auto list_attrs = TargetKindRegEntry::ListTargetKindOptions(kind);

for (auto& attr_option : attr_options) {
auto option_name = attr_option.first;
auto default_value = attr_option.second;
if (list_attrs.find(option_name) != list_attrs.end()) {
LOG(FATAL) << "TVM UMA Error: Attribute is already registered: " << option_name;
}
if (default_value->IsInstance<StringObj>()) {
target_kind.add_attr_option<String>(option_name, Downcast<String>(default_value));
} else if (default_value->IsInstance<IntImmNode>()) {
target_kind.add_attr_option<Integer>(option_name, Downcast<Integer>(default_value));
} else {
LOG(FATAL) << "Only String, Integer, or Bool are supported. Given attribute option type: "
<< attr_option.second->GetTypeKey();
LOG(FATAL) << "TypeError: Only String, Integer, or Bool are supported. "
<< "Given attribute option type: " << attr_option.second->GetTypeKey();
}
}
return true;
Expand Down
25 changes: 22 additions & 3 deletions tests/python/contrib/test_uma/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,42 @@ def test_uma_target(target_name, target_attrs, target_args):
[
("float_attr", 3.14),
("none_attr", None),
("model", "my_model"),
],
)
def test_invalid_attr_option(attr_name: str, target_attr: Union[str, int, bool, float, None]):
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
if target_attr is None:
# None cannot be caught as TVMError, as it causes a SIGKILL, therefore it must be prevented to be
# entered into relay.backend.contrib.uma.RegisterTarget at Python level.
with pytest.raises(ValueError):
with pytest.raises(ValueError, match=r"Target attribute None is not supported."):
uma_backend = VanillaAcceleratorBackend()
uma_backend._target_attrs = {attr_name: target_attr}
uma_backend.register()
elif "model" in attr_name:
target_name = f"{attr_name}_{target_attr}"
target_attr = {attr_name: target_attr}
with pytest.raises(tvm.TVMError, match=r"Attribute is already registered: .*"):
registration_func(target_name, target_attr)
else:
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
target_name = f"{attr_name}_{target_attr}"
target_attr = {attr_name: target_attr}
with pytest.raises(tvm.TVMError, match=r"Only String, Integer, or Bool are supported. .*"):
with pytest.raises(TypeError, match=r"Only String, Integer, or Bool are supported. .*"):
registration_func(target_name, target_attr)


@pytest.mark.parametrize(
"target_name",
[
"llvm",
"c",
],
)
def test_target_duplication(target_name: str):
with pytest.raises(tvm.TVMError, match=r"TVM UMA Error: Target is already registered: .*"):
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
registration_func(target_name, {})


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 50b33ac

Please sign in to comment.