Skip to content

Commit

Permalink
Do not report successful loads that are cancelled, unload immediately…
Browse files Browse the repository at this point in the history
… instead.

PiperOrigin-RevId: 419600593
  • Loading branch information
tensorflower-gardener authored and tensorflow-copybara committed Jan 4, 2022
1 parent 8548e2d commit 9232699
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
52 changes: 33 additions & 19 deletions tensorflow_serving/core/loader_harness.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,23 @@ Status LoaderHarness::Load() {
[&]() { return loader_->LoadWithMetadata({id_}); },
[&]() { return cancel_load_retry(); });

{
mutex_lock l(mu_);
if (status.ok()) {
if (status.ok()) {
if (cancel_load_retry()) {
// Servable is going to be unloaded very soon,
// we report a failure here so that we do not accidentally
// report that the servable is available.
TF_RETURN_IF_ERROR(UnloadDueToCancelledLoad());
return errors::Cancelled(
strings::StrCat("Loading of servable cancelled"));
}
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(TransitionState(State::kLoading, State::kReady));
LOG(INFO) << "Successfully loaded servable version " << id_;
} else {
ErrorInternal(status);
}
} else {
mutex_lock l(mu_);
ErrorInternal(status);
}

return status;
Expand All @@ -103,21 +112,11 @@ Status LoaderHarness::UnloadRequested() {
return Status::OK();
}

void LoaderHarness::set_cancel_load_retry(const bool value) {
mutex_lock l(mu_);
cancel_load_retry_ = value;
}

bool LoaderHarness::cancel_load_retry() {
mutex_lock l(mu_);
return cancel_load_retry_;
}

Status LoaderHarness::Unload() {
Status LoaderHarness::UnloadInternal(State from_state) {
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(TransitionState(State::kQuiesced, State::kUnloading));
LOG(INFO) << "Unloading servable version " << id_;
TF_RETURN_IF_ERROR(TransitionState(from_state, State::kUnloading));
LOG(INFO) << "Unloading just-loaded servable version " << id_;
}

loader_->Unload();
Expand All @@ -127,10 +126,25 @@ Status LoaderHarness::Unload() {
TF_RETURN_IF_ERROR(TransitionState(State::kUnloading, State::kDisabled));
LOG(INFO) << "Done unloading servable version " << id_;
}

return Status::OK();
}

Status LoaderHarness::UnloadDueToCancelledLoad() {
return UnloadInternal(State::kLoading);
}

void LoaderHarness::set_cancel_load_retry(const bool value) {
mutex_lock l(mu_);
cancel_load_retry_ = value;
}

bool LoaderHarness::cancel_load_retry() {
mutex_lock l(mu_);
return cancel_load_retry_;
}

Status LoaderHarness::Unload() { return UnloadInternal(State::kQuiesced); }

Status LoaderHarness::StartQuiescing() {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_serving/core/loader_harness.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ class LoaderHarness final {
// same error.
Status TransitionState(State from, State to) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

Status UnloadInternal(State from_state) TF_LOCKS_EXCLUDED(mu_);
Status UnloadDueToCancelledLoad() TF_LOCKS_EXCLUDED(mu_);

const ServableId id_;
const std::unique_ptr<Loader> loader_;
// Additional state that the manager uses.
Expand Down
23 changes: 23 additions & 0 deletions tensorflow_serving/core/loader_harness_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,29 @@ TEST(LoaderHarnessTest, RetryOnLoadErrorCancelledLoad) {
}));
}

// Tests unload when ongoing load is cancelled.
TEST(LoaderHarnessTest, UnloadDueToCancelledLoad) {
test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;

const ServableId servable_id = {"test", 0};
LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));

EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
.WillOnce(InvokeWithoutArgs([]() {
Env::Default()->SleepForMicroseconds(1000000);
return Status::OK();
}));

std::unique_ptr<Thread> test_thread(
Env::Default()->StartThread(ThreadOptions(), "test", [&harness]() {
TF_ASSERT_OK(harness.LoadRequested());
TF_ASSERT_OK(harness.LoadApproved());
harness.set_cancel_load_retry(true);
const Status status = harness.Load();
EXPECT_THAT(status.error_message(), HasSubstr("cancelled"));
}));
}

} // namespace
} // namespace serving
} // namespace tensorflow

0 comments on commit 9232699

Please sign in to comment.