Skip to content

Commit

Permalink
Remove pybind11_abseil from convert_wrapper
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631184467
  • Loading branch information
yijie-yang authored and copybara-github committed May 7, 2024
1 parent 751fb07 commit 601af5a
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/builtin-adapter/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ http_archive(
],
)

TENSORFLOW_COMMIT = "8b5370df5655da95b113362e18a9cd850ded7973"
TENSORFLOW_COMMIT = "7b26bb0c266f8b122932904b5f1216818429709d"

TENSORFLOW_SHA256 = "b0366cb1eef7bdc18f9e733c082115894f0147f87bf764d0474bd97f686d5d12"
TENSORFLOW_SHA256 = "418bd874023857039ffbadd7ecd95dd191a750aaa9da1170ae71e03946c4a4c0"

http_archive(
name = "org_tensorflow",
Expand Down
3 changes: 1 addition & 2 deletions src/builtin-adapter/python/convert_wrapper/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ pybind_extension(
"_pywrap_convert_wrapper.pyi",
],
deps = [
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:import_status_module",
"@pybind11_abseil//pybind11_abseil:status_casters",
"//:direct_flatbuffer_to_json_graph_convert",
"//:direct_saved_model_to_json_graph_convert",
"//:model_json_graph_convert",
Expand Down
77 changes: 53 additions & 24 deletions src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdexcept>
#include <string>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "pybind11/pybind11.h"
#include "pybind11_abseil/import_status_module.h"
#include "pybind11_abseil/status_casters.h" // IWYU pragma : keep
#include "direct_flatbuffer_to_json_graph_convert.h"
#include "direct_saved_model_to_json_graph_convert.h"
#include "model_json_graph_convert.h"
Expand All @@ -27,67 +29,94 @@ using tooling::visualization_client::VisualizeConfig;
namespace pybind11 {

PYBIND11_MODULE(_pywrap_convert_wrapper, m) {
pybind11::google::ImportStatusModule();

class_<VisualizeConfig>(m, "VisualizeConfig")
.def(init<>())
.def_readwrite("const_element_count_limit",
&VisualizeConfig::const_element_count_limit);

m.def(
"ConvertSavedModelToJson",
[](const VisualizeConfig& config, absl::string_view model_path) {
return ::tooling::visualization_client::ConvertSavedModelToJson(
config, model_path);
[](const VisualizeConfig& config,
absl::string_view model_path) -> std::string {
const absl::StatusOr<std::string> json_or_status =
::tooling::visualization_client::ConvertSavedModelToJson(
config, model_path);
if (!json_or_status.ok()) {
throw std::runtime_error(json_or_status.status().ToString());
}
return json_or_status.value();
},
R"pbdoc(
Converts a SavedModel to visualizer JSON string through tf dialect MLIR
module if succeeded, otherwise raises `StatusNotOk` exception.
module if succeeded, otherwise raises `RuntimeError` exception.
)pbdoc");

m.def(
"ConvertFlatbufferToJson",
[](const VisualizeConfig& config, absl::string_view model_path,
bool is_modelpath) {
return ::tooling::visualization_client::ConvertFlatbufferToJson(
config, model_path, is_modelpath);
bool is_modelpath) -> std::string {
const absl::StatusOr<std::string> json_or_status =
::tooling::visualization_client::ConvertFlatbufferToJson(
config, model_path, is_modelpath);
if (!json_or_status.ok()) {
throw std::runtime_error(json_or_status.status().ToString());
}
return json_or_status.value();
},
R"pbdoc(
Converts a Flatbuffer to visualizer JSON string through tfl dialect MLIR
module if succeeded, otherwise raises `StatusNotOk` exception.
module if succeeded, otherwise raises `RuntimeError` exception.
)pbdoc");

m.def(
"ConvertFlatbufferDirectlyToJson",
[](const VisualizeConfig& config, absl::string_view model_path) {
return ::tooling::visualization_client::ConvertFlatbufferDirectlyToJson(
config, model_path);
[](const VisualizeConfig& config,
absl::string_view model_path) -> std::string {
const absl::StatusOr<std::string> json_or_status =
::tooling::visualization_client::ConvertFlatbufferDirectlyToJson(
config, model_path);
if (!json_or_status.ok()) {
throw std::runtime_error(json_or_status.status().ToString());
}
return json_or_status.value();
},
R"pbdoc(
Converts a Flatbuffer directly to visualizer JSON string without MLIR or
execution. Raises `StatusNotOk` exception if failed.
execution. Raises `RuntimeError` exception if failed.
)pbdoc");

m.def(
"ConvertSavedModelDirectlyToJson",
[](const VisualizeConfig& config, absl::string_view model_path) {
return ::tooling::visualization_client::ConvertSavedModelDirectlyToJson(
config, model_path);
[](const VisualizeConfig& config,
absl::string_view model_path) -> std::string {
const absl::StatusOr<std::string> json_or_status =
::tooling::visualization_client::ConvertSavedModelDirectlyToJson(
config, model_path);
if (!json_or_status.ok()) {
throw std::runtime_error(json_or_status.status().ToString());
}
return json_or_status.value();
},
R"pbdoc(
Converts a SavedModel directly to visualizer JSON string without MLIR or
execution. Raises `StatusNotOk` exception if failed.
execution. Raises `RuntimeError` exception if failed.
)pbdoc");

m.def(
"ConvertGraphDefDirectlyToJson",
[](const VisualizeConfig& config, absl::string_view model_path) {
return ::tooling::visualization_client::ConvertGraphDefDirectlyToJson(
config, model_path);
[](const VisualizeConfig& config,
absl::string_view model_path) -> std::string {
const absl::StatusOr<std::string> json_or_status =
::tooling::visualization_client::ConvertGraphDefDirectlyToJson(
config, model_path);
if (!json_or_status.ok()) {
throw std::runtime_error(json_or_status.status().ToString());
}
return json_or_status.value();
},
R"pbdoc(
Converts a GraphDef directly to visualizer JSON string without MLIR or
execution. Raises `StatusNotOk` exception if failed.
execution. Raises `RuntimeError` exception if failed.
)pbdoc");
}

Expand Down
4 changes: 4 additions & 0 deletions src/builtin-adapter/tools/load_opdefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ namespace visualization_client {
struct OpMetadata {
std::vector<std::string> arguments;
std::vector<std::string> results;

OpMetadata(const std::vector<std::string>& arguments,
const std::vector<std::string>& results)
: arguments(arguments), results(results) {}
};

absl::flat_hash_map<std::string, OpMetadata> LoadTfliteOpdefs();
Expand Down

0 comments on commit 601af5a

Please sign in to comment.