Skip to content

Commit

Permalink
Merge branch 'branch-24.12' into bug/include_std_exchange_header
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham authored Oct 22, 2024
2 parents 43bbca6 + fcf4b15 commit c79e467
Show file tree
Hide file tree
Showing 15 changed files with 812 additions and 20 deletions.
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-python>=11.7.1,<12.0a0
Expand All @@ -18,6 +19,7 @@ dependencies:
- doxygen=1.9.1
- gcc_linux-aarch64=11.*
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-python>=11.7.1,<12.0a0
Expand All @@ -20,6 +21,7 @@ dependencies:
- libcufile-dev=1.4.0.31
- libcufile=1.4.0.31
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-nvcc
Expand All @@ -19,6 +20,7 @@ dependencies:
- gcc_linux-aarch64=11.*
- libcufile-dev
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-nvcc
Expand All @@ -19,6 +20,7 @@ dependencies:
- gcc_linux-64=11.*
- libcufile-dev
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
13 changes: 13 additions & 0 deletions cpp/cmake/thirdparty/get_libcurl.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
function(find_and_configure_libcurl)
include(${rapids-cmake-dir}/cpm/find.cmake)

# Work around https://github.com/curl/curl/issues/15351
if(DEFINED CACHE{BUILD_TESTING})
set(CACHE_HAS_BUILD_TESTING $CACHE{BUILD_TESTING})
endif()

rapids_cpm_find(
CURL 7.87.0
GLOBAL_TARGETS libcurl
Expand All @@ -27,6 +32,14 @@ function(find_and_configure_libcurl)
OPTIONS "BUILD_CURL_EXE OFF" "BUILD_SHARED_LIBS OFF" "BUILD_TESTING OFF" "CURL_USE_LIBPSL OFF"
"CURL_DISABLE_LDAP ON" "CMAKE_POSITION_INDEPENDENT_CODE ON"
)
if(DEFINED CACHE_HAS_BUILD_TESTING)
set(BUILD_TESTING
${CACHE_HAS_BUILD_TESTING}
CACHE BOOL "" FORCE
)
else()
unset(BUILD_TESTING CACHE)
endif()
endfunction()

find_and_configure_libcurl()
208 changes: 204 additions & 4 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <cstddef>
#include <cstring>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -89,7 +91,7 @@ inline std::size_t callback_device_memory(char* data,
void* context)
{
auto ctx = reinterpret_cast<CallbackContext*>(context);
const std::size_t nbytes = size * nmemb;
std::size_t const nbytes = size * nmemb;
if (ctx->size < ctx->offset + nbytes) {
ctx->overflow_error = true;
return CURL_WRITEFUNC_ERROR;
Expand Down Expand Up @@ -132,7 +134,7 @@ class RemoteEndpoint {
*
* @returns A string description.
*/
virtual std::string str() = 0;
virtual std::string str() const = 0;

virtual ~RemoteEndpoint() = default;
};
Expand All @@ -145,12 +147,203 @@ class HttpEndpoint : public RemoteEndpoint {
std::string _url;

public:
/**
* @brief Create an http endpoint from a url.
*
* @param url The full http url to the remote file.
*/
HttpEndpoint(std::string url) : _url{std::move(url)} {}
void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); }
std::string str() override { return _url; }
std::string str() const override { return _url; }
~HttpEndpoint() override = default;
};

/**
 * @brief A remote endpoint using AWS's S3 protocol.
 *
 * Holds the http(s) url of the object together with the two strings that are passed to curl
 * in `setopt()`: the `CURLOPT_AWS_SIGV4` value ("aws:amz:<region>:s3") and the
 * `CURLOPT_USERPWD` value ("<access_key>:<secret_access_key>").
 */
class S3Endpoint : public RemoteEndpoint {
 private:
  std::string _url;
  std::string _aws_sigv4;
  std::string _aws_userpwd;

  /**
   * @brief Unwrap an optional parameter, obtaining a default from the environment.
   *
   * If not nullopt, the optional's value is returned. Otherwise, the environment
   * variable `env_var` is used. If that also doesn't have a value:
   *  - if `err_msg` is empty, the empty string is returned.
   *  - if `err_msg` is not empty, `std::invalid_argument(err_msg)` is thrown.
   *
   * @throws std::invalid_argument if neither `aws_arg` nor the environment variable is set
   * and `err_msg` is not empty.
   *
   * @param aws_arg The value to unwrap.
   * @param env_var The name of the environment variable to check if `aws_arg` isn't set.
   * @param err_msg The error message to throw on error or the empty string.
   * @return The parsed AWS argument or the empty string.
   */
  static std::string unwrap_or_default(std::optional<std::string> aws_arg,
                                       std::string const& env_var,
                                       std::string const& err_msg = "")
  {
    // `aws_arg` is our own by-value copy, so its content can be moved out.
    if (aws_arg.has_value()) { return std::move(*aws_arg); }

    char const* env = std::getenv(env_var.c_str());
    if (env == nullptr) {
      if (err_msg.empty()) { return std::string(); }
      throw std::invalid_argument(err_msg);
    }
    return std::string(env);
  }

 public:
  /**
   * @brief Get url from a AWS S3 bucket and object name.
   *
   * The region is only consulted when no endpoint url override is in effect.
   *
   * @throws std::invalid_argument if the regular AWS url scheme is used and no region is
   * specified and no default region is specified in the environment.
   *
   * @param bucket_name The name of the S3 bucket.
   * @param object_name The name of the S3 object.
   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
   * `AWS_DEFAULT_REGION` environment variable is used.
   * @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
   * the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
   * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
   * url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
   * @return The http(s) url of the object.
   */
  static std::string url_from_bucket_and_object(std::string const& bucket_name,
                                                std::string const& object_name,
                                                std::optional<std::string> const& aws_region,
                                                std::optional<std::string> aws_endpoint_url)
  {
    auto const endpoint_url = unwrap_or_default(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL");
    if (!endpoint_url.empty()) { return endpoint_url + "/" + bucket_name + "/" + object_name; }

    // `aws_region` is a const reference: it is copied into the helper (a `std::move` here
    // would silently degrade to a copy anyway).
    auto const region =
      unwrap_or_default(aws_region,
                        "AWS_DEFAULT_REGION",
                        "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");
    // We default to the official AWS url scheme.
    return "https://" + bucket_name + ".s3." + region + ".amazonaws.com/" + object_name;
  }

  /**
   * @brief Given an url like "s3://<bucket>/<object>", return the name of the bucket and object.
   *
   * @throws std::invalid_argument if url is ill-formed or is missing the bucket or object name.
   *
   * @param s3_url S3 url.
   * @return Pair of strings: [bucket-name, object-name].
   */
  [[nodiscard]] static std::pair<std::string, std::string> parse_s3_url(std::string const& s3_url)
  {
    // Regular expression to match s3://<bucket>/<object>
    std::regex const pattern{R"(^s3://([^/]+)/(.+))", std::regex_constants::icase};
    std::smatch matches;
    if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; }
    throw std::invalid_argument("Input string does not match the expected S3 URL format.");
  }

  /**
   * @brief Create a S3 endpoint from a url.
   *
   * @throws std::invalid_argument if `url` doesn't start with "http://" or "https://", or if
   * the region, access key, or secret access key is neither given nor set in the environment.
   *
   * @param url The full http url to the S3 file. NB: this should be an url starting with
   * "http://" or "https://". If you have an S3 url of the form "s3://<bucket>/<object>", please
   * use `S3Endpoint::parse_s3_url()` and `S3Endpoint::url_from_bucket_and_object()` to convert
   * it.
   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
   * `AWS_DEFAULT_REGION` environment variable is used.
   * @param aws_access_key The AWS access key to use. If nullopt, the value of the
   * `AWS_ACCESS_KEY_ID` environment variable is used.
   * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
   * `AWS_SECRET_ACCESS_KEY` environment variable is used.
   */
  S3Endpoint(std::string url,
             std::optional<std::string> aws_region            = std::nullopt,
             std::optional<std::string> aws_access_key        = std::nullopt,
             std::optional<std::string> aws_secret_access_key = std::nullopt)
    : _url{std::move(url)}
  {
    // Regular expression to match http[s]://
    std::regex const pattern{R"(^https?://.*)", std::regex_constants::icase};
    if (!std::regex_search(_url, pattern)) {
      throw std::invalid_argument("url must start with http:// or https://");
    }

    auto const region =
      unwrap_or_default(std::move(aws_region),
                        "AWS_DEFAULT_REGION",
                        "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");

    auto const access_key =
      unwrap_or_default(std::move(aws_access_key),
                        "AWS_ACCESS_KEY_ID",
                        "S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set.");

    auto const secret_access_key = unwrap_or_default(
      std::move(aws_secret_access_key),
      "AWS_SECRET_ACCESS_KEY",
      "S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set.");

    // Create the CURLOPT_AWS_SIGV4 option
    _aws_sigv4 = "aws:amz:" + region + ":s3";

    // Create the CURLOPT_USERPWD option
    // Notice, curl uses `secret_access_key` to generate a AWS V4 signature. It is NOT included
    // in the http header. See
    // <https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html>
    _aws_userpwd = access_key + ":" + secret_access_key;
  }

  /**
   * @brief Create a S3 endpoint from a bucket and object name.
   *
   * @throws std::invalid_argument if the region, access key, or secret access key is neither
   * given nor set in the environment.
   *
   * @param bucket_name The name of the S3 bucket.
   * @param object_name The name of the S3 object.
   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
   * `AWS_DEFAULT_REGION` environment variable is used.
   * @param aws_access_key The AWS access key to use. If nullopt, the value of the
   * `AWS_ACCESS_KEY_ID` environment variable is used.
   * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
   * `AWS_SECRET_ACCESS_KEY` environment variable is used.
   * @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
   * the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
   * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
   * url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
   */
  S3Endpoint(std::string const& bucket_name,
             std::string const& object_name,
             std::optional<std::string> aws_region            = std::nullopt,
             std::optional<std::string> aws_access_key        = std::nullopt,
             std::optional<std::string> aws_secret_access_key = std::nullopt,
             std::optional<std::string> aws_endpoint_url      = std::nullopt)
    : S3Endpoint(
        url_from_bucket_and_object(
          bucket_name, object_name, aws_region, std::move(aws_endpoint_url)),
        // `aws_region` is deliberately copied, not moved: the initialization of the delegated
        // constructor's parameters is indeterminately sequenced, so moving it here could leave
        // `url_from_bucket_and_object()` above reading a moved-from region.
        aws_region,
        std::move(aws_access_key),
        std::move(aws_secret_access_key))
  {
  }

  /**
   * @brief Set the url, AWS SigV4, and user/password options on a curl handle.
   *
   * @param curl The curl handle to configure.
   */
  void setopt(CurlHandle& curl) override
  {
    curl.setopt(CURLOPT_URL, _url.c_str());
    curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str());
    curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str());
  }

  /**
   * @brief Get a description of this remote endpoint.
   *
   * @returns The http(s) url of the endpoint.
   */
  std::string str() const override { return _url; }

  ~S3Endpoint() override = default;
};

/**
* @brief Handle of remote file.
*/
Expand Down Expand Up @@ -211,6 +404,13 @@ class RemoteHandle {
*/
[[nodiscard]] std::size_t nbytes() const noexcept { return _nbytes; }

/**
* @brief Get a const reference to the underlying remote endpoint.
*
* @return The remote endpoint.
*/
[[nodiscard]] RemoteEndpoint const& endpoint() const noexcept { return *_endpoint; }

/**
* @brief Read from remote source into buffer (host or device memory).
*
Expand All @@ -229,7 +429,7 @@ class RemoteHandle {
<< " bytes file (" << _endpoint->str() << ")";
throw std::invalid_argument(ss.str());
}
const bool is_host_mem = is_host_memory(buf);
bool const is_host_mem = is_host_memory(buf);
auto curl = create_curl_handle();
_endpoint->setopt(curl);

Expand Down
7 changes: 7 additions & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,13 @@ dependencies:
- pytest
- pytest-cov
- rangehttpserver
- boto3>=1.21.21
- output_types: [requirements, pyproject]
packages:
- moto[server]>=4.0.8
- output_types: conda
packages:
- moto>=4.0.8
specific:
- output_types: [conda, requirements, pyproject]
matrices:
Expand Down
15 changes: 8 additions & 7 deletions python/kvikio/kvikio/_lib/arr.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
# cython: language_level=3


from cpython.array cimport array, newarrayobject
from cpython.buffer cimport PyBuffer_IsContiguous
from cpython.mem cimport PyMem_Free, PyMem_Malloc
from cpython.memoryview cimport PyMemoryView_FromObject, PyMemoryView_GET_BUFFER
from cpython.object cimport PyObject
from cpython.ref cimport Py_INCREF
from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
from cython cimport auto_pickle, boundscheck, initializedcheck, nonecheck, wraparound
from cython.view cimport array
from libc.stdint cimport uintptr_t
from libc.string cimport memcpy

Expand Down Expand Up @@ -53,13 +53,14 @@ cdef dict itemsize_mapping = {
}


cdef array array_Py_ssize_t = array("q")
cdef sizeof_Py_ssize_t = sizeof(Py_ssize_t)


cdef inline Py_ssize_t[::1] new_Py_ssize_t_array(Py_ssize_t n):
return newarrayobject(
(<PyObject*>array_Py_ssize_t).ob_type, n, array_Py_ssize_t.ob_descr
)
cdef Py_ssize_t[::1] new_Py_ssize_t_array(Py_ssize_t n):
    # Build a C-contiguous 1-D array of `n` Py_ssize_t items. The cython array is created
    # without its own buffer (last argument False); the storage is allocated here with
    # PyMem_Malloc and PyMem_Free is registered as the array's free callback.
    cdef array a = array((n,), sizeof_Py_ssize_t, b"q", "c", False)
    a.data = <char*>PyMem_Malloc(n * sizeof(Py_ssize_t))
    if a.data == NULL:
        # PyMem_Malloc reports failure by returning NULL; surface it as an exception
        # instead of handing out a memoryview over a NULL buffer.
        raise MemoryError("failed to allocate Py_ssize_t array")
    a.callback_free_data = PyMem_Free
    return a


@auto_pickle(False)
Expand Down
Loading

0 comments on commit c79e467

Please sign in to comment.