Skip to content

Commit

Permalink
Implemented gzip uncompression support in net_http.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 217447041
  • Loading branch information
wenbozhu authored and tensorflower-gardener committed Oct 17, 2018
1 parent 63982f7 commit b94f6c8
Show file tree
Hide file tree
Showing 11 changed files with 606 additions and 12 deletions.
7 changes: 2 additions & 5 deletions tensorflow_serving/util/net_http/compression/gzip_zlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ namespace net_http {
// XFL 2-4? DEFLATE flags
// OS ???? Operating system indicator (255 means unknown)

// gzip header, as a macro for sizeof()
#define GZIP_HEADER "\037\213\010\000\000\000\000\000\002\377"

constexpr char GZIP_HEADER[] = "\037\213\010\000\000\000\000\000\002\377";
constexpr uint8_t kMagicHeader[2] = {0x1f, 0x8b}; // gzip magic header

GZipHeader::Status GZipHeader::ReadMore(const char *inbuf, int inbuf_len,
Expand Down Expand Up @@ -808,8 +806,7 @@ int ZLib::UncompressGzipAndAllocate(Bytef **dest, uLongf *destLen,
*destLen = uncompress_length;

*dest = (Bytef *)malloc(*destLen);
if (*dest == nullptr) // probably a corrupted gzip buffer
return Z_MEM_ERROR;
if (*dest == nullptr) return Z_MEM_ERROR;

const int retval = Uncompress(*dest, destLen, source, sourceLen);
if (retval != Z_OK) { // just to make life easier for them
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_serving/util/net_http/compression/gzip_zlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ class ZLib {
ZLib();
~ZLib();

// The max length of the buffer to store uncompressed data
static constexpr int64_t kMaxUncompressedBytes = 10 * 1024 * 1024; // 10MB

// Wipe a ZLib object to a virgin state. This differs from Reset()
// in that it also breaks any dictionary, gzip, etc, state.
void Reinit();
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_serving/util/net_http/server/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ cc_library(
"server_support.h",
],
deps = [
"//tensorflow_serving/util/net_http/compression:gzip_zlib",
"//tensorflow_serving/util/net_http/server/public:http_server_api",
"@com_github_libevent_libevent//:libevent",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@zlib_archive//:zlib",
],
)

Expand Down Expand Up @@ -55,6 +58,7 @@ cc_test(
":evhttp_server",
"//tensorflow_serving/core/test_util:test_main",
"//tensorflow_serving/util/net_http/client:evhttp_client",
"//tensorflow_serving/util/net_http/compression:gzip_zlib",
"//tensorflow_serving/util/net_http/internal:fixed_thread_pool",
"//tensorflow_serving/util/net_http/server/public:http_server",
"//tensorflow_serving/util/net_http/server/public:http_server_api",
Expand Down
60 changes: 57 additions & 3 deletions tensorflow_serving/util/net_http/server/internal/evhttp_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,24 @@ limitations under the License.

#include "tensorflow_serving/util/net_http/server/internal/evhttp_request.h"

#include <zlib.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>

#include "absl/base/internal/raw_logging.h"
#include "absl/strings/string_view.h"

#include "libevent/include/event2/buffer.h"
#include "libevent/include/event2/event.h"
#include "libevent/include/event2/http.h"
#include "libevent/include/event2/keyvalq_struct.h"

#include "tensorflow_serving/util/net_http/compression/gzip_zlib.h"
#include "tensorflow_serving/util/net_http/server/public/header_names.h"

namespace tensorflow {
namespace serving {
namespace net_http {
Expand Down Expand Up @@ -145,10 +151,10 @@ std::unique_ptr<char, FreeDeleter> EvHTTPRequest::ReadRequestBytes(
evbuffer* input_buf =
evhttp_request_get_input_buffer(parsed_request_->request);
if (input_buf == nullptr) {
return nullptr; // nobody
return nullptr; // no body
}

size_t* buf_size = reinterpret_cast<size_t*>(size);
auto buf_size = reinterpret_cast<size_t*>(size);

*buf_size = evbuffer_get_contiguous_space(input_buf);

Expand All @@ -162,11 +168,59 @@ std::unique_ptr<char, FreeDeleter> EvHTTPRequest::ReadRequestBytes(
if (ret != *buf_size) {
ABSL_RAW_LOG(ERROR, "Unexpected: read less than specified num_bytes : %zu",
*buf_size);
free(block);
*buf_size = 0;
return nullptr; // don't return corrupted buffer
}

// Uncompress the entire body
if (NeedUncompressGzipContent()) {
void* new_block;
UncompressGzipContent(block, *buf_size, &new_block, buf_size);
free(block);
if (new_block != nullptr) {
block = new_block;
} else {
ABSL_RAW_LOG(ERROR, "Failed to uncompress the gzipped body");
*buf_size = 0;
return nullptr; // don't return corrupted buffer
}
}

return std::unique_ptr<char, FreeDeleter>(static_cast<char*>(block));
}

bool EvHTTPRequest::NeedUncompressGzipContent() {
if (handler_options_ != nullptr &&
handler_options_->auto_uncompress_input()) {
auto content_encoding = GetRequestHeader(HTTPHeaders::CONTENT_ENCODING);
if (content_encoding != nullptr) {
return content_encoding.find("gzip") != absl::string_view::npos;
}
}

return false;
}

// Gunzips `input_size` bytes starting at `input` into a buffer allocated by
// ZLib::UncompressGzipAndAllocate().
//
// On entry *uncompressed_input_size is overwritten with the maximum number of
// uncompressed bytes to accept: the handler's auto_uncompress_max_size() when
// positive, otherwise ZLib::kMaxUncompressedBytes. The same variable is passed
// to zlib as the in/out destination length.
//
// *uncompressed_input is nullptr unless zlib stores a buffer pointer into it;
// any zlib error code is logged.
//
// Requires handler_options_ to be non-null (dereferenced unconditionally).
void EvHTTPRequest::UncompressGzipContent(void* input, size_t input_size,
                                          void** uncompressed_input,
                                          size_t* uncompressed_input_size) {
  // Callers pass an uninitialized pointer and test it against nullptr
  // afterwards, so clear it up front rather than relying on every error path
  // inside zlib to do so.
  *uncompressed_input = nullptr;

  const int64_t max_uncompressed_bytes =
      handler_options_->auto_uncompress_max_size() > 0
          ? handler_options_->auto_uncompress_max_size()
          : ZLib::kMaxUncompressedBytes;

  // Our APIs don't need to expose the actual content-length, so the size
  // variable doubles as the upper bound handed to zlib.
  *uncompressed_input_size = static_cast<size_t>(max_uncompressed_bytes);

  ZLib zlib;
  const int err = zlib.UncompressGzipAndAllocate(
      reinterpret_cast<Bytef**>(uncompressed_input), uncompressed_input_size,
      reinterpret_cast<Bytef*>(input), input_size);
  if (err != Z_OK) {
    ABSL_RAW_LOG(ERROR, "Got zlib error: %d", err);
  }
}

// Note: passing string_view incurs a copy of underlying std::string data
// (stack)
absl::string_view EvHTTPRequest::GetRequestHeader(
Expand All @@ -182,7 +236,7 @@ std::vector<absl::string_view> EvHTTPRequest::request_headers() const {

for (evkeyval* header = ev_headers->tqh_first; header;
header = header->next.tqe_next) {
result.push_back(absl::string_view(header->key));
result.emplace_back(header->key);
}

return result;
Expand Down
16 changes: 16 additions & 0 deletions tensorflow_serving/util/net_http/server/internal/evhttp_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <memory>

#include "tensorflow_serving/util/net_http/server/internal/server_support.h"
#include "tensorflow_serving/util/net_http/server/public/httpserver_interface.h"
#include "tensorflow_serving/util/net_http/server/public/server_request_interface.h"

struct evbuffer;
Expand Down Expand Up @@ -100,11 +101,26 @@ class EvHTTPRequest final : public ServerRequestInterface {
// Initializes the resource and returns false if any error.
bool Initialize();

// Keeps a reference to the registered RequestHandlerOptions
void SetHandlerOptions(const RequestHandlerOptions& handler_options) {
this->handler_options_ = &handler_options;
}

private:
void EvSendReply(HTTPStatusCode status);

// Returns true if the request data needs to be uncompressed
bool NeedUncompressGzipContent();

// Must set *uncompressed_input to nullptr if uncompression fails
void UncompressGzipContent(void* input, size_t input_size,
void** uncompressed_input,
size_t* uncompressed_input_size);

ServerSupport* server_;

// Non-owning pointer to the options of the handler serving this request;
// assigned by SetHandlerOptions(). Defaults to nullptr so that code which
// null-checks it is well-defined even when no options were ever attached.
const RequestHandlerOptions* handler_options_ = nullptr;

std::unique_ptr<ParsedEvRequest> parsed_request_;

evbuffer* output_buf; // owned by this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/memory/memory.h"

#include "tensorflow_serving/util/net_http/client/evhttp_connection.h"
#include "tensorflow_serving/util/net_http/compression/gzip_zlib.h"
#include "tensorflow_serving/util/net_http/internal/fixed_thread_pool.h"
#include "tensorflow_serving/util/net_http/server/public/httpserver.h"
#include "tensorflow_serving/util/net_http/server/public/httpserver_interface.h"
Expand Down Expand Up @@ -215,6 +216,146 @@ TEST_F(EvHTTPRequestTest, ResponseHeaders) {
server->WaitForTermination();
}

// === gzip support ====

// A POST whose Content-Encoding contains "gzip" but whose payload is not a
// gzip stream must reach the handler as a null body with a zero byte count.
// (Default-constructed RequestHandlerOptions are used; presumably they leave
// automatic uncompression enabled - confirm against RequestHandlerOptions.)
TEST_F(EvHTTPRequestTest, InvalidGzipPost) {
  auto expect_no_body = [](ServerRequestInterface* req) {
    int64_t bytes_read;
    auto body = req->ReadRequestBytes(&bytes_read);
    EXPECT_TRUE(body == nullptr);
    EXPECT_EQ(0, bytes_read);

    req->Reply();
  };
  server->RegisterRequestHandler("/ok", std::move(expect_no_body),
                                 RequestHandlerOptions());
  server->StartAcceptingRequests();

  auto client = EvHTTPConnection::Connect("localhost", server->listen_port());
  ASSERT_TRUE(client != nullptr);

  // "abcde" is plain text, not compressed data.
  ClientRequest post = {"/ok", "POST", {}, "abcde"};
  post.headers.emplace_back("Content-Encoding", "my_gzip");
  ClientResponse reply = {};

  // The handler still replies, so the exchange itself succeeds.
  EXPECT_TRUE(client->BlockingSendRequest(post, &reply));
  EXPECT_EQ(reply.status, 200);

  server->Terminate();
  server->WaitForTermination();
}

// With auto_uncompress_input switched off, the body is delivered untouched
// even though Content-Encoding contains "gzip": the handler sees all 5 raw
// bytes of "abcde".
TEST_F(EvHTTPRequestTest, DisableGzipPost) {
  auto expect_raw_body = [](ServerRequestInterface* req) {
    int64_t bytes_read;
    auto body = req->ReadRequestBytes(&bytes_read);
    EXPECT_EQ(5, bytes_read);

    req->Reply();
  };

  RequestHandlerOptions no_uncompress;
  no_uncompress.set_auto_uncompress_input(false);
  server->RegisterRequestHandler("/ok", std::move(expect_raw_body),
                                 no_uncompress);
  server->StartAcceptingRequests();

  auto client = EvHTTPConnection::Connect("localhost", server->listen_port());
  ASSERT_TRUE(client != nullptr);

  ClientRequest post = {"/ok", "POST", {}, "abcde"};
  post.headers.emplace_back("Content-Encoding", "my_gzip");
  ClientResponse reply = {};

  EXPECT_TRUE(client->BlockingSendRequest(post, &reply));
  EXPECT_EQ(reply.status, 200);

  server->Terminate();
  server->WaitForTermination();
}

std::string CompressString(const char* data, size_t size) {
ZLib zlib;
std::string buf(1000, '\0');
size_t compressed_size = buf.size();
zlib.Compress((Bytef*)buf.data(), &compressed_size, (Bytef*)data, size);

return std::string(buf.data(), compressed_size);
}

// A correctly compressed body sent with a Content-Encoding containing "gzip"
// must be handed to the handler already uncompressed, byte-for-byte equal to
// the original payload.
TEST_F(EvHTTPRequestTest, ValidGzipPost) {
  constexpr char kBody[] = "abcdefg12345";
  // sizeof - 1 drops the trailing NUL of the literal.
  std::string compressed = CompressString(kBody, sizeof(kBody) - 1);

  auto expect_uncompressed = [&](ServerRequestInterface* req) {
    int64_t bytes_read;
    auto body = req->ReadRequestBytes(&bytes_read);

    std::string received(body.get(), static_cast<size_t>(bytes_read));
    EXPECT_EQ(received, std::string(kBody));
    EXPECT_EQ(sizeof(kBody) - 1, bytes_read);

    req->Reply();
  };
  server->RegisterRequestHandler("/ok", std::move(expect_uncompressed),
                                 RequestHandlerOptions());
  server->StartAcceptingRequests();

  auto client = EvHTTPConnection::Connect("localhost", server->listen_port());
  ASSERT_TRUE(client != nullptr);

  ClientRequest post = {"/ok", "POST", {}, compressed};
  post.headers.emplace_back("Content-Encoding", "my_gzip");
  ClientResponse reply = {};

  EXPECT_TRUE(client->BlockingSendRequest(post, &reply));
  EXPECT_EQ(reply.status, 200);

  server->Terminate();
  server->WaitForTermination();
}

// A valid compressed body whose uncompressed size exceeds the handler's
// auto_uncompress_max_size must be rejected: the handler observes a null
// body and a zero byte count.
TEST_F(EvHTTPRequestTest, GzipExceedingLimit) {
  constexpr char kBody[] = "abcdefg12345";
  constexpr int bodySize = sizeof(kBody) - 1;  // excludes the trailing NUL
  std::string compressed = CompressString(kBody, static_cast<size_t>(bodySize));

  auto handler = [&](ServerRequestInterface* request) {
    int64_t num_bytes;
    auto request_body = request->ReadRequestBytes(&num_bytes);

    // Do not build a std::string from request_body here: the pointer is
    // expected to be null, and the string was never used anyway.
    EXPECT_TRUE(request_body == nullptr);
    EXPECT_EQ(0, num_bytes);

    request->Reply();
  };

  RequestHandlerOptions options;
  options.set_auto_uncompress_max_size(bodySize - 1);  // not enough buffer
  server->RegisterRequestHandler("/ok", std::move(handler), options);
  server->StartAcceptingRequests();

  auto connection =
      EvHTTPConnection::Connect("localhost", server->listen_port());
  ASSERT_TRUE(connection != nullptr);

  ClientRequest request = {"/ok", "POST", {}, compressed};
  request.headers.emplace_back("Content-Encoding", "my_gzip");
  ClientResponse response = {};

  // The handler still replies, so the HTTP exchange itself succeeds.
  EXPECT_TRUE(connection->BlockingSendRequest(request, &response));
  EXPECT_EQ(response.status, 200);

  server->Terminate();
  server->WaitForTermination();
}

} // namespace
} // namespace net_http
} // namespace serving
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ void EvHTTPServer::DispatchEvRequest(evhttp_request* req) {

auto handler_map_it = uri_handlers_.find(path);
if (handler_map_it != uri_handlers_.end()) {
ev_request->SetHandlerOptions(handler_map_it->second.options);
IncOps();
dispatched = true;
ScheduleHandlerReference(handler_map_it->second.handler,
Expand All @@ -151,6 +152,7 @@ void EvHTTPServer::DispatchEvRequest(evhttp_request* req) {
if (handler == nullptr) {
continue;
}
ev_request->SetHandlerOptions(dispatcher.options);
IncOps();
dispatched = true;
ScheduleHandler(std::move(handler), ev_request.release());
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_serving/util/net_http/server/public/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ licenses(["notice"]) # Apache 2.0

cc_library(
name = "http_server_api",
srcs = [
"header_names.cc",
],
hdrs = [
"header_names.h",
"httpserver_interface.h",
"response_code_enum.h",
"server_request_interface.h",
Expand Down
Loading

0 comments on commit b94f6c8

Please sign in to comment.