Validate batching params when creating the wrapped (batching) session,
instead of doing this when creating the shared batch scheduler.

This is a more logical place to validate, as the shared batch scheduler
can potentially be shared across multiple sessions, each with its own
batching params.

PiperOrigin-RevId: 454176092
netfs authored and tensorflow-copybara committed Jun 10, 2022
1 parent 8484770 commit 48ff72d
Showing 3 changed files with 26 additions and 19 deletions.
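
As context for the rationale above, a minimal sketch (not part of this commit; 'scheduler', 'signature_a'/'signature_b', and 'session_a'/'session_b' are hypothetical placeholders) of one shared batch scheduler backing two wrapped sessions, each with its own batching params:

  // One scheduler shared by two sessions; with this commit, each session's
  // params are validated independently at wrap time.
  BatchingParameters params_a;
  params_a.mutable_max_batch_size()->set_value(4);
  params_a.add_allowed_batch_sizes(2);
  params_a.add_allowed_batch_sizes(4);  // Last entry == max_batch_size: valid.

  BatchingParameters params_b;
  params_b.mutable_max_batch_size()->set_value(8);
  params_b.add_allowed_batch_sizes(4);  // Last entry != max_batch_size: invalid.

  std::shared_ptr<Batcher> scheduler;
  TF_RETURN_IF_ERROR(CreateBatchScheduler(params_a, &scheduler));

  TF_RETURN_IF_ERROR(WrapSessionForBatching(params_a, scheduler,
                                            {signature_a}, &session_a));
  Status status = WrapSessionForBatching(params_b, scheduler,
                                         {signature_b}, &session_b);
  // 'status' is InvalidArgument; only the second session's params are rejected.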
15 changes: 15 additions & 0 deletions tensorflow_serving/servables/tensorflow/bundle_factory_util.cc
@@ -71,6 +71,21 @@ Status WrapSessionForBatching(const BatchingParameters& batching_config,
     return errors::Internal("session not set");
   }
 
+  if (!batching_config.allowed_batch_sizes().empty()) {
+    // Verify that the last allowed batch size matches the max batch size.
+    const int last_allowed_size = batching_config.allowed_batch_sizes(
+        batching_config.allowed_batch_sizes().size() - 1);
+    const int max_size = batching_config.has_max_batch_size()
+                             ? batching_config.max_batch_size().value()
+                             : Batcher::QueueOptions().input_batch_size_limit;
+    if (last_allowed_size != max_size) {
+      return errors::InvalidArgument(
+          "Last entry in allowed_batch_sizes must match max_batch_size; last "
+          "entry was ",
+          last_allowed_size, "; expected ", max_size);
+    }
+  }
+
   auto queue_options = GetQueueOptions<
       tensorflow::serving::BatchingSessionTask>(
       batching_config,
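
Note the fallback in the added check: when 'max_batch_size' is unset, the last 'allowed_batch_sizes' entry is compared against the scheduler's default 'input_batch_size_limit'. A short sketch of a config exercising that branch (illustrative only, not part of this commit):

  // No explicit max_batch_size, so the wrap-time check compares the last
  // allowed batch size against Batcher::QueueOptions().input_batch_size_limit.
  BatchingParameters params;
  params.add_allowed_batch_sizes(
      Batcher::QueueOptions().input_batch_size_limit);  // Passes the check.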
17 changes: 0 additions & 17 deletions tensorflow_serving/servables/tensorflow/bundle_factory_util.h
@@ -46,23 +46,6 @@ template <typename TaskType>
 Status CreateBatchScheduler(
     const BatchingParameters& batching_config,
     std::shared_ptr<SharedBatchScheduler<TaskType>>* batch_scheduler) {
-  if (!batching_config.allowed_batch_sizes().empty()) {
-    // Verify that the last allowed batch size matches the max batch size.
-    const int last_allowed_size = batching_config.allowed_batch_sizes(
-        batching_config.allowed_batch_sizes().size() - 1);
-    const int max_size =
-        batching_config.has_max_batch_size()
-            ? batching_config.max_batch_size().value()
-            : typename SharedBatchScheduler<TaskType>::QueueOptions()
-                  .input_batch_size_limit;
-    if (last_allowed_size != max_size) {
-      return errors::InvalidArgument(
-          "Last entry in allowed_batch_sizes must match max_batch_size; last "
-          "entry was ",
-          last_allowed_size, "; expected ", max_size);
-    }
-  }
-
   typename SharedBatchScheduler<TaskType>::Options options;
   if (batching_config.has_num_batch_threads()) {
     options.num_batch_threads = batching_config.num_batch_threads().value();
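
One consequence of deleting the check here, sketched below (mirrors the updated test; 'bad_params', 'scheduler', 'signature', and 'session' are placeholders): CreateBatchScheduler now accepts params whose last 'allowed_batch_sizes' entry mismatches 'max_batch_size', and the error surfaces only when the session is wrapped.

  // Scheduler creation no longer validates allowed_batch_sizes...
  TF_ASSERT_OK(CreateBatchScheduler(bad_params, &scheduler));
  // ...the mismatch is reported at wrap time instead.
  Status status = WrapSessionForBatching(bad_params, scheduler,
                                         {signature}, &session);
  ASSERT_TRUE(errors::IsInvalidArgument(status));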
13 changes: 11 additions & 2 deletions tensorflow_serving/servables/tensorflow/bundle_factory_util_test.cc
@@ -150,15 +150,24 @@ TEST_F(BundleFactoryUtilTest, WrapSessionForBatching) {
   test_util::TestMultipleRequests(10, bundle.session.get());
 }
 
-TEST_F(BundleFactoryUtilTest, BatchingConfigError) {
+TEST_F(BundleFactoryUtilTest, WrapSessionForBatchingConfigError) {
   BatchingParameters batching_params;
   batching_params.mutable_max_batch_size()->set_value(2);
   // The last entry in 'allowed_batch_sizes' is supposed to equal
   // 'max_batch_size'. Let's violate that constraint and ensure we get an error.
   batching_params.add_allowed_batch_sizes(1);
   batching_params.add_allowed_batch_sizes(3);
+
   std::shared_ptr<Batcher> batch_scheduler;
-  EXPECT_FALSE(CreateBatchScheduler(batching_params, &batch_scheduler).ok());
+  TF_ASSERT_OK(CreateBatchScheduler(batching_params, &batch_scheduler));
+
+  SavedModelBundle bundle;
+  TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(), export_dir_,
+                              {"serve"}, &bundle));
+  auto status = WrapSessionForBatching(batching_params, batch_scheduler,
+                                       {test_util::GetTestSessionSignature()},
+                                       &bundle.session);
+  ASSERT_TRUE(errors::IsInvalidArgument(status));
 }
 
 TEST_F(BundleFactoryUtilTest, EstimateResourceFromPathWithBadExport) {