Skip to content

Commit

Permalink
status: Support a rough equivalent of absl::Status's payloads in prot…
Browse files Browse the repository at this point in the history
…o form.

absl::Status supports URL->bytes payloads, but tensorflow.StatusProto doesn't
so we can't currently roundtrip payloads. We can add that support in a simple
form via a proto map from string to bytes, as values are often serialized
protos.
PiperOrigin-RevId: 684501070
  • Loading branch information
pizzud authored and copybara-github committed Oct 10, 2024
1 parent e8a355f commit 06f1899
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
2 changes: 2 additions & 0 deletions tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ cc_library(
hdrs = ["status_to_from_proto.h"],
deps = [
":status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@xla//xla/tsl/protobuf:error_codes_proto_impl_cc",
"@xla//xla/tsl/protobuf:status_proto_cc",
] + tf_platform_deps("status"),
Expand Down
15 changes: 11 additions & 4 deletions tsl/platform/status_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ limitations under the License.
==============================================================================*/
#include "tsl/platform/status.h"

#include <string>
#include <unordered_map>
#include <vector>

Expand All @@ -29,7 +30,9 @@ limitations under the License.
namespace tsl {
namespace {

using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::testing::Pair;
using ::tsl::testing::IsOk;
using ::tsl::testing::StatusIs;

Expand Down Expand Up @@ -188,16 +191,18 @@ TEST(Status, SaveOKStatusToProto) {
}

TEST(Status, SaveErrorStatusToProto) {
tensorflow::StatusProto status_proto =
StatusToProto(errors::NotFound("Not found"));
tensorflow::StatusProto status_proto = StatusToProto(errors::Create(
absl::StatusCode::kNotFound, "Not found", {{"foo", "bar"}}));
EXPECT_EQ(status_proto.code(), error::NOT_FOUND);
EXPECT_EQ(status_proto.message(), "Not found");
EXPECT_THAT(status_proto.payload(), ElementsAre(Pair("foo", "bar")));
}

TEST(Status, SaveEmptyStatusToProto) {
tensorflow::StatusProto status_proto = StatusToProto(absl::Status());
EXPECT_EQ(status_proto.code(), error::OK);
EXPECT_THAT(status_proto.message(), IsEmpty());
EXPECT_THAT(status_proto.payload(), IsEmpty());
}

TEST(Status, MakeOKStatusFromProto) {
Expand All @@ -210,8 +215,10 @@ TEST(Status, MakeErrorStatusFromProto) {
tensorflow::StatusProto status_proto;
status_proto.set_code(error::INVALID_ARGUMENT);
status_proto.set_message("Invalid argument");
EXPECT_THAT(StatusFromProto(status_proto),
StatusIs(error::INVALID_ARGUMENT, "Invalid argument"));
status_proto.mutable_payload()->insert({"foo", "bar"});
absl::Status s = StatusFromProto(status_proto);
EXPECT_THAT(s, StatusIs(error::INVALID_ARGUMENT, "Invalid argument"));
EXPECT_EQ(s.GetPayload("foo"), "bar");
}

TEST(Status, MakeStatusFromEmptyProto) {
Expand Down
22 changes: 19 additions & 3 deletions tsl/platform/status_to_from_proto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.

#include <string>

#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/protobuf/error_codes.pb.h"
#include "xla/tsl/protobuf/status.pb.h"
#include "tsl/platform/status.h"
Expand All @@ -32,6 +34,12 @@ tensorflow::StatusProto StatusToProto(const absl::Status& s) {
if (!s.message().empty()) {
status_proto.set_message(std::string(s.message()));
}

s.ForEachPayload(
[&status_proto](absl::string_view type_url, absl::Cord value) {
status_proto.mutable_payload()->insert(
{std::string(type_url), std::string(value)});
});
return status_proto;
}

Expand All @@ -41,15 +49,23 @@ absl::Status StatusFromProto(const tensorflow::StatusProto& proto,
if (proto.code() == tensorflow::error::OK) {
return absl::OkStatus();
}
return absl::Status(static_cast<absl::StatusCode>(proto.code()),
proto.message(), loc);
absl::Status s(static_cast<absl::StatusCode>(proto.code()), proto.message(),
loc);
for (const auto& [key, payload] : proto.payload()) {
s.SetPayload(key, absl::Cord(payload));
}
return s;
}
#else
Status StatusFromProto(const tensorflow::StatusProto& proto) {
if (proto.code() == tensorflow::error::OK) {
return OkStatus();
}
return Status(static_cast<absl::StatusCode>(proto.code()), proto.message());
Status s(static_cast<absl::StatusCode>(proto.code()), proto.message());
for (const auto& [key, payload] : proto.payload()) {
s.SetPayload(key, absl::Cord(payload));
}
return s;
}
#endif

Expand Down

0 comments on commit 06f1899

Please sign in to comment.