Skip to content

Commit

Permalink
remove the use of the transfer module
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Aug 9, 2024
1 parent ebc57c5 commit ab035a0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 58 deletions.
3 changes: 1 addition & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ rapids_find_package(
INSTALL_EXPORT_SET kvikio-exports
)

# TODO: make optional before PR merge
rapids_find_package(
AWSSDK REQUIRED COMPONENTS s3 transfer
AWSSDK REQUIRED COMPONENTS s3
BUILD_EXPORT_SET kvikio-exports
INSTALL_EXPORT_SET kvikio-exports
)
Expand Down
61 changes: 5 additions & 56 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@
#include <aws/core/Aws.h>
#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/core/utils/threading/Executor.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
#include <aws/transfer/TransferHandle.h>
#include <aws/transfer/TransferManager.h>

#include <kvikio/parallel_operation.hpp>
#include <kvikio/posix_io.hpp>
Expand All @@ -52,44 +49,11 @@ class BufferAsStream : public Aws::IOStream {
~BufferAsStream() override = default;
};

/**
* An executor that does not spawn any thread, instead, tasks are executed in the current thread
*/
// TODO: remove
class SameThreadExecutor : public Aws::Utils::Threading::Executor {
public:
virtual ~SameThreadExecutor() { SameThreadExecutor::WaitUntilStopped(); }
void WaitUntilStopped() override
{
while (!m_tasks.empty()) {
auto task = std::move(m_tasks.front());
m_tasks.pop_front();
assert(task);
if (task) { task(); }
}
}

protected:
bool SubmitToThread(std::function<void()>&& task) override
{
m_tasks.push_back(std::move(task));
return true;
}

using TaskFunc = std::function<void()>;
Aws::List<TaskFunc> m_tasks;
};

class S3Context {
public:
S3Context()
: _client{S3Context::create_client()},
_transfer_manager{S3Context::create_transfer_manager(_client, &_executor)}
{
}
S3Context() : _client{S3Context::create_client()} {}

Aws::S3::S3Client& client() { return *_client; }
Aws::Transfer::TransferManager& transfer_manager() { return *_transfer_manager; }

static S3Context& default_context()
{
Expand Down Expand Up @@ -143,19 +107,7 @@ class S3Context {
return ret;
}

static std::shared_ptr<Aws::Transfer::TransferManager> create_transfer_manager(
std::shared_ptr<Aws::S3::S3Client> client, Aws::Utils::Threading::Executor* executor)
{
Aws::Transfer::TransferManagerConfiguration transfer_config(executor);
transfer_config.s3Client = client;
transfer_config.bufferSize = posix_bounce_buffer_size;
transfer_config.transferBufferMaxHeapSize = posix_bounce_buffer_size * 2;
return Aws::Transfer::TransferManager::Create(transfer_config);
}

SameThreadExecutor _executor; // TODO: remove
std::shared_ptr<Aws::S3::S3Client> _client;
std::shared_ptr<Aws::Transfer::TransferManager> _transfer_manager; // TODO: remove
};

inline std::size_t get_s3_file_size(const std::string& bucket_name, const std::string& object_name)
Expand Down Expand Up @@ -243,13 +195,10 @@ class RemoteHandle {
"bytes=" + std::to_string(file_offset) + "-" + std::to_string(file_offset + size - 1);
req.SetRange(byte_range.c_str());

// The local variable 'streamBuffer' is captured by reference in a lambda.
// It must persist until all downloading by the 'transfer_manager' is complete.
Aws::Utils::Stream::PreallocatedStreamBuf streamBuffer(static_cast<unsigned char*>(buf), size);
req.SetResponseStreamFactory([&]() { // Define a lambda expression for the callback method
// parameter to stream back the data.
return Aws::New<detail::BufferAsStream>("TestTag", &streamBuffer);
});
// To write directly to `buf`, we register a "factory" that wraps a buffer as a output stream.
Aws::Utils::Stream::PreallocatedStreamBuf buf_stream(static_cast<unsigned char*>(buf), size);
req.SetResponseStreamFactory(
[&]() { return Aws::New<detail::BufferAsStream>("BufferAsStream", &buf_stream); });

Aws::S3::Model::GetObjectOutcome outcome = default_context.client().GetObject(req);
if (!outcome.IsSuccess()) {
Expand Down

0 comments on commit ab035a0

Please sign in to comment.