Skip to content

Commit

Permalink
Operator registry key fix
Browse files Browse the repository at this point in the history
Summary: We use ascii encoding instead of hex to make it consistent with codegen.

Reviewed By: larryliu0820

Differential Revision: D53681267

fbshipit-source-id: d459ae8157429d8a4cba223db7d4b14fa7abf0d6
  • Loading branch information
kirklandsign authored and facebook-github-bot committed Feb 13, 2024
1 parent a3dec31 commit 6d54911
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 59 deletions.
3 changes: 1 addition & 2 deletions runtime/executor/test/kernel_integration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ struct KernelControl {
// TensorMeta(ScalarType::Float, contiguous), // other
// TensorMeta(ScalarType::Float, contiguous), // out
// TensorMeta(ScalarType::Float, contiguous)}; // out (repeated)
KernelKey key = torch::executor::KernelKey(
"v0/\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01\xff");
KernelKey key = torch::executor::KernelKey("v1/6;0,1|6;0,1|6;0,1|6;0,1");
Kernel kernel = torch::executor::Kernel(
"aten::add.out", key, KernelControl::kernel_hook);
Error err = torch::executor::register_kernels({kernel});
Expand Down
3 changes: 1 addition & 2 deletions runtime/executor/test/kernel_resolution_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ TEST_F(KernelResolutionTest, ResolveKernelKeySuccess) {
// TensorMeta(ScalarType::Float, contiguous),
// TensorMeta(ScalarType::Float, contiguous),
// TensorMeta(ScalarType::Float, contiguous)};
KernelKey key = KernelKey(
"v0/\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01\xff");
KernelKey key = KernelKey("v1/6;0,1|6;0,1|6;0,1|6;0,1");
Kernel kernel_1 = Kernel(
"aten::add.out", key, [](KernelRuntimeContext& context, EValue** stack) {
(void)context;
Expand Down
41 changes: 29 additions & 12 deletions runtime/kernel/operator_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,33 +91,50 @@ bool hasOpsFn(const char* name, ArrayRef<TensorMeta> kernel_key) {
return getOperatorRegistry().hasOpsFn(name, kernel_key);
}

static void make_kernel_key_string(ArrayRef<TensorMeta> key, char* buf) {
static int copy_char_as_number_to_buf(char num, char* buf) {
if ((char)num < 10) {
*buf = '0' + (char)num;
buf += 1;
return 1;
} else {
*buf = '0' + ((char)num) / 10;
buf += 1;
*buf = '0' + ((char)num) % 10;
buf += 1;
return 2;
}
}

void make_kernel_key_string(ArrayRef<TensorMeta> key, char* buf);

void make_kernel_key_string(ArrayRef<TensorMeta> key, char* buf) {
if (key.empty()) {
// If no tensor is present in an op, kernel key does not apply
*buf = 0xff;
return;
}
strncpy(buf, "v0/", 3);
strncpy(buf, "v1/", 3);
buf += 3;
for (size_t i = 0; i < key.size(); i++) {
auto& meta = key[i];
*buf = (char)meta.dtype_;
buf += 1;
buf += copy_char_as_number_to_buf((char)meta.dtype_, buf);
*buf = ';';
buf += 1;
memcpy(buf, (char*)meta.dim_order_.data(), meta.dim_order_.size());
buf += meta.dim_order_.size();
*buf = (i < (key.size() - 1)) ? '|' : 0xff;
for (int j = 0; j < meta.dim_order_.size(); j++) {
buf += copy_char_as_number_to_buf((char)meta.dim_order_[j], buf);
if (j != meta.dim_order_.size() - 1) {
*buf = ',';
buf += 1;
}
}
*buf = (i < (key.size() - 1)) ? '|' : 0x00;
buf += 1;
}
}

constexpr int BUF_SIZE = 307;

bool OperatorRegistry::hasOpsFn(
const char* name,
ArrayRef<TensorMeta> meta_list) {
char buf[BUF_SIZE] = {0};
char buf[KernelKey::MAX_SIZE] = {0};
make_kernel_key_string(meta_list, buf);
KernelKey kernel_key = KernelKey(buf);

Expand All @@ -140,7 +157,7 @@ const OpFunction& getOpsFn(const char* name, ArrayRef<TensorMeta> kernel_key) {
const OpFunction& OperatorRegistry::getOpsFn(
const char* name,
ArrayRef<TensorMeta> meta_list) {
char buf[BUF_SIZE] = {0};
char buf[KernelKey::MAX_SIZE] = {0};
make_kernel_key_string(meta_list, buf);
KernelKey kernel_key = KernelKey(buf);

Expand Down
32 changes: 10 additions & 22 deletions runtime/kernel/operator_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,18 @@ struct TensorMeta {
* registered.
*
* The format of a kernel key data is a string:
* "v<version>/<tensor_meta>|<tensor_meta>...\xff"
* Size: Up to 307 1 1 1 (18 +1) * 16
* "v<version>/<tensor_meta>|<tensor_meta>..."
* Size: Up to 691 1 1 1 (42 +1) * 16
* Assuming max number of tensors is 16 ^
* Kernel key version is v0 for now. If the kernel key format changes,
* Kernel key version is v1 for now. If the kernel key format changes,
* update the version to avoid breaking pre-existing kernel keys.
* Example: v0/0x07;0x00 0x01 0x02 0x03 \xff
* Example: v1/7;0,1,2,3
* The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3
*
* The string is a byte array and contains non-printable characters. It must
* be terminated with a '\xff' so 0xff cannot be a scalar type.
*
* Each tensor_meta has the following format: "<dtype>;<dim_order...>"
* Size: Up to 18 1 1 16
* Assuming that the max number of dims is 16 ^
* Example: 0x07;0x00 0x01 0x02 0x03 for [double; 0, 1, 2, 3]
* Each tensor_meta has the following format: "<dtype>;<dim_order,...>"
* Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2
* for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example:
* 7;0,1,2,3 for [double; 0, 1, 2, 3]
*
* IMPORTANT:
* Users should not construct a kernel key manually. Instead, it should be
Expand All @@ -129,7 +126,7 @@ struct KernelKey {
/* implicit */ KernelKey(const char* kernel_key_data)
: kernel_key_data_(kernel_key_data), is_fallback_(false) {}

constexpr static char TERMINATOR = 0xff;
constexpr static int MAX_SIZE = 691;

bool operator==(const KernelKey& other) const {
return this->equals(other);
Expand All @@ -146,16 +143,7 @@ struct KernelKey {
if (is_fallback_) {
return true;
}
size_t i;
for (i = 0; kernel_key_data_[i] != TERMINATOR &&
other.kernel_key_data_[i] != TERMINATOR;
i++) {
if (kernel_key_data_[i] != other.kernel_key_data_[i]) {
return false;
}
}
return kernel_key_data_[i] == TERMINATOR &&
other.kernel_key_data_[i] == TERMINATOR;
return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0;
}

bool is_fallback() const {
Expand Down
23 changes: 2 additions & 21 deletions runtime/kernel/test/operator_registry_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <executorch/runtime/kernel/test/test_util.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>

Expand Down Expand Up @@ -43,27 +44,7 @@ TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) {
ET_EXPECT_DEATH({ auto res = register_kernels(kernels_array); }, "");
}

void make_kernel_key(
std::vector<std::pair<ScalarType, std::vector<exec_aten::DimOrderType>>>
tensors,
char* buf) {
char* start = buf;
strncpy(buf, "v0/", 3);
buf += 3;
for (size_t i = 0; i < tensors.size(); i++) {
auto& tensor = tensors[i];
*buf = (char)tensor.first;
buf += 1;
*buf = ';';
buf += 1;
memcpy(buf, (char*)tensor.second.data(), tensor.second.size());
buf += tensor.second.size();
*buf = (i < (tensors.size() - 1)) ? '|' : 0xff;
buf += 1;
}
}

constexpr int BUF_SIZE = 307;
constexpr int BUF_SIZE = KernelKey::MAX_SIZE;

TEST_F(OperatorRegistryTest, KernelKeyEquals) {
char buf_long_contiguous[BUF_SIZE];
Expand Down
1 change: 1 addition & 0 deletions runtime/kernel/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def define_common_targets():
srcs = [
"operator_registry_test.cpp",
],
headers = ["test_util.h"],
deps = [
"//executorch/runtime/kernel:operator_registry",
"//executorch/runtime/kernel:kernel_runtime_context",
Expand Down
34 changes: 34 additions & 0 deletions runtime/kernel/test/test_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <vector>

#include <executorch/runtime/core/exec_aten/exec_aten.h>

namespace torch {
namespace executor {
void make_kernel_key_string(ArrayRef<TensorMeta> key, char* buf);

inline void make_kernel_key(
std::vector<std::pair<ScalarType, std::vector<exec_aten::DimOrderType>>>
tensors,
char* buf) {
std::vector<TensorMeta> meta;
for (auto& t : tensors) {
ArrayRef<exec_aten::DimOrderType> dim_order(
t.second.data(), t.second.size());
meta.emplace_back(t.first, dim_order);
}
auto meatadata = ArrayRef<TensorMeta>(meta.data(), meta.size());
make_kernel_key_string(meatadata, buf);
}

} // namespace executor
} // namespace torch

0 comments on commit 6d54911

Please sign in to comment.