From 923269955f5582cc26d0454992afa5c888a9377f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 4 Jan 2022 08:22:52 -0800
Subject: [PATCH] Do not report successful loads that are cancelled, unload immediately instead.

If the load is cancelled while the loader is still running and the loader
then succeeds, LoaderHarness::Load() now unloads the servable right away and
returns a Cancelled error instead of moving the harness to kReady.

Unload() and the new cancelled-load path share UnloadInternal(), which takes
the state the harness is expected to be in when unloading starts.

PiperOrigin-RevId: 419600593
---
 tensorflow_serving/core/loader_harness.cc     | 55 ++++++++++++-------
 tensorflow_serving/core/loader_harness.h      |  9 +++
 .../core/loader_harness_test.cc               | 28 ++++++++++
 3 files changed, 73 insertions(+), 19 deletions(-)

diff --git a/tensorflow_serving/core/loader_harness.cc b/tensorflow_serving/core/loader_harness.cc
index 95a7bfb2e08..81ded30a9ac 100644
--- a/tensorflow_serving/core/loader_harness.cc
+++ b/tensorflow_serving/core/loader_harness.cc
@@ -80,14 +80,22 @@ Status LoaderHarness::Load() {
       [&]() { return loader_->LoadWithMetadata({id_}); },
       [&]() { return cancel_load_retry(); });
 
-  {
-    mutex_lock l(mu_);
-    if (status.ok()) {
+  if (status.ok()) {
+    if (cancel_load_retry()) {
+      // The load was cancelled while it was in flight. Unload right away and
+      // report a failure so that the servable is never announced as
+      // available.
+      TF_RETURN_IF_ERROR(UnloadDueToCancelledLoad());
+      return errors::Cancelled("Loading of servable cancelled");
+    }
+    {
+      mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(TransitionState(State::kLoading, State::kReady));
       LOG(INFO) << "Successfully loaded servable version " << id_;
-    } else {
-      ErrorInternal(status);
     }
+  } else {
+    mutex_lock l(mu_);
+    ErrorInternal(status);
   }
 
   return status;
@@ -103,21 +111,15 @@ Status LoaderHarness::UnloadRequested() {
   return Status::OK();
 }
 
-void LoaderHarness::set_cancel_load_retry(const bool value) {
-  mutex_lock l(mu_);
-  cancel_load_retry_ = value;
-}
-
-bool LoaderHarness::cancel_load_retry() {
-  mutex_lock l(mu_);
-  return cancel_load_retry_;
-}
-
-Status LoaderHarness::Unload() {
+Status LoaderHarness::UnloadInternal(State from_state) {
   {
     mutex_lock l(mu_);
-    TF_RETURN_IF_ERROR(TransitionState(State::kQuiesced, State::kUnloading));
-    LOG(INFO) << "Unloading servable version " << id_;
+    TF_RETURN_IF_ERROR(TransitionState(from_state, State::kUnloading));
+    // Only the cancelled-load path starts from kLoading; a regular Unload()
+    // starts from kQuiesced and must not be logged as "just-loaded".
+    LOG(INFO) << "Unloading "
+              << (from_state == State::kLoading ? "just-loaded " : "")
+              << "servable version " << id_;
   }
 
   loader_->Unload();
@@ -127,10 +129,25 @@ Status LoaderHarness::Unload() {
     TF_RETURN_IF_ERROR(TransitionState(State::kUnloading, State::kDisabled));
     LOG(INFO) << "Done unloading servable version " << id_;
   }
-
   return Status::OK();
 }
 
+Status LoaderHarness::UnloadDueToCancelledLoad() {
+  return UnloadInternal(State::kLoading);
+}
+
+void LoaderHarness::set_cancel_load_retry(const bool value) {
+  mutex_lock l(mu_);
+  cancel_load_retry_ = value;
+}
+
+bool LoaderHarness::cancel_load_retry() {
+  mutex_lock l(mu_);
+  return cancel_load_retry_;
+}
+
+Status LoaderHarness::Unload() { return UnloadInternal(State::kQuiesced); }
+
 Status LoaderHarness::StartQuiescing() {
   mutex_lock l(mu_);
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow_serving/core/loader_harness.h b/tensorflow_serving/core/loader_harness.h
index eb5f8ab79b8..7fbb2fac4ee 100644
--- a/tensorflow_serving/core/loader_harness.h
+++ b/tensorflow_serving/core/loader_harness.h
@@ -229,6 +229,15 @@ class LoaderHarness final {
   // same error.
   Status TransitionState(State from, State to) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
+  // Transitions from 'from_state' to kUnloading, calls Unload() on the loader,
+  // and then transitions to kDisabled. Returns the error from
+  // TransitionState() if either transition is rejected.
+  Status UnloadInternal(State from_state) TF_LOCKS_EXCLUDED(mu_);
+
+  // Unloads a servable whose load was cancelled, starting from kLoading
+  // instead of kQuiesced.
+  Status UnloadDueToCancelledLoad() TF_LOCKS_EXCLUDED(mu_);
+
   const ServableId id_;
   const std::unique_ptr<Loader> loader_;
   // Additional state that the manager uses.
diff --git a/tensorflow_serving/core/loader_harness_test.cc b/tensorflow_serving/core/loader_harness_test.cc
index 2740ecb8366..02f0a2c3fa1 100644
--- a/tensorflow_serving/core/loader_harness_test.cc
+++ b/tensorflow_serving/core/loader_harness_test.cc
@@ -325,6 +325,34 @@ TEST(LoaderHarnessTest, RetryOnLoadErrorCancelledLoad) {
       }));
 }
 
+// Tests unload when ongoing load is cancelled.
+TEST(LoaderHarnessTest, UnloadDueToCancelledLoad) {
+  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
+
+  const ServableId servable_id = {"test", 0};
+  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
+
+  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
+      .WillOnce(InvokeWithoutArgs([]() {
+        Env::Default()->SleepForMicroseconds(1000000);
+        return Status::OK();
+      }));
+  // The cancelled load must be undone by unloading the servable exactly once.
+  EXPECT_CALL(*loader, Unload()).Times(1);
+
+  std::unique_ptr<Thread> test_thread(
+      Env::Default()->StartThread(ThreadOptions(), "test", [&harness]() {
+        TF_ASSERT_OK(harness.LoadRequested());
+        TF_ASSERT_OK(harness.LoadApproved());
+        harness.set_cancel_load_retry(true);
+        const Status status = harness.Load();
+        EXPECT_THAT(status.error_message(), HasSubstr("cancelled"));
+      }));
+  // Destroying the thread joins it, so Load() has returned past this point.
+  test_thread.reset();
+  EXPECT_EQ(LoaderHarness::State::kDisabled, harness.state());
+}
+
 }  // namespace
 }  // namespace serving
 }  // namespace tensorflow