diff --git a/cpp/.clang-format b/.clang-format similarity index 100% rename from cpp/.clang-format rename to .clang-format diff --git a/.gitattributes b/.gitattributes index fbfe7434d50..ed8e5e1425a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,5 @@ python/cudf/cudf/_version.py export-subst +python/strings_udf/strings_udf/_version.py export-subst python/cudf_kafka/cudf_kafka/_version.py export-subst python/custreamz/custreamz/_version.py export-subst python/dask_cudf/dask_cudf/_version.py export-subst diff --git a/.gitignore b/.gitignore index 29df683e9ec..0d63c76bf9f 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,8 @@ python/cudf_kafka/*/_lib/**/*.cpp python/cudf_kafka/*/_lib/**/*.h python/custreamz/*/_lib/**/*.cpp python/custreamz/*/_lib/**/*.h +python/strings_udf/strings_udf/_lib/*.cpp +python/strings_udf/strings_udf/*.ptx .Python env/ develop-eggs/ diff --git a/build.sh b/build.sh index eee3ee512fa..ac283d01fc9 100755 --- a/build.sh +++ b/build.sh @@ -17,7 +17,7 @@ ARGS=$* # script, and that this script resides in the repo dir! 
REPODIR=$(cd $(dirname $0); pwd) -VALIDARGS="clean libcudf cudf cudfjar dask_cudf benchmarks tests libcudf_kafka cudf_kafka custreamz -v -g -n -l --allgpuarch --disable_nvtx --opensource_nvcomp --show_depr_warn --ptds -h --build_metrics --incl_cache_stats" +VALIDARGS="clean libcudf cudf cudfjar dask_cudf benchmarks tests libcudf_kafka cudf_kafka custreamz strings_udf -v -g -n -l --allgpuarch --disable_nvtx --opensource_nvcomp --show_depr_warn --ptds -h --build_metrics --incl_cache_stats" HELP="$0 [clean] [libcudf] [cudf] [cudfjar] [dask_cudf] [benchmarks] [tests] [libcudf_kafka] [cudf_kafka] [custreamz] [-v] [-g] [-n] [-h] [--cmake-args=\\\"\\\"] clean - remove all existing build artifacts and configuration (start over) @@ -335,6 +335,15 @@ if buildAll || hasArg cudf; then fi fi +if buildAll || hasArg strings_udf; then + + cd ${REPODIR}/python/strings_udf + python setup.py build_ext --inplace -- -DCMAKE_PREFIX_PATH=${INSTALL_PREFIX} -DCMAKE_LIBRARY_PATH=${LIBCUDF_BUILD_DIR} ${EXTRA_CMAKE_ARGS} -- -j${PARALLEL_LEVEL:-1} + if [[ ${INSTALL_TARGET} != "" ]]; then + python setup.py install --single-version-externally-managed --record=record.txt -- -DCMAKE_PREFIX_PATH=${INSTALL_PREFIX} -DCMAKE_LIBRARY_PATH=${LIBCUDF_BUILD_DIR} ${EXTRA_CMAKE_ARGS} -- -j${PARALLEL_LEVEL:-1} + fi +fi + # Build and install the dask_cudf Python package if buildAll || hasArg dask_cudf; then diff --git a/ci/cpu/build.sh b/ci/cpu/build.sh index f5ea2c902ef..a931546292e 100755 --- a/ci/cpu/build.sh +++ b/ci/cpu/build.sh @@ -80,6 +80,14 @@ fi if [ "$BUILD_LIBCUDF" == '1' ]; then gpuci_logger "Build conda pkg for libcudf" gpuci_conda_retry mambabuild --no-build-id --croot ${CONDA_BLD_DIR} conda/recipes/libcudf $CONDA_BUILD_ARGS + + # BUILD_LIBCUDF == 1 means this job is being run on the cpu_build jobs + # that is where we must also build the strings_udf package + gpuci_logger "Build conda pkg for strings_udf (python 3.8)" + gpuci_conda_retry mambabuild --no-build-id --croot ${CONDA_BLD_DIR} 
conda/recipes/strings_udf $CONDA_BUILD_ARGS --python=3.8 + gpuci_logger "Build conda pkg for strings_udf (python 3.9)" + gpuci_conda_retry mambabuild --no-build-id --croot ${CONDA_BLD_DIR} conda/recipes/strings_udf $CONDA_BUILD_ARGS --python=3.9 + mkdir -p ${CONDA_BLD_DIR}/libcudf/work cp -r ${CONDA_BLD_DIR}/work/* ${CONDA_BLD_DIR}/libcudf/work gpuci_logger "sccache stats" @@ -108,6 +116,10 @@ if [ "$BUILD_CUDF" == '1' ]; then gpuci_logger "Build conda pkg for custreamz" gpuci_conda_retry mambabuild --croot ${CONDA_BLD_DIR} conda/recipes/custreamz --python=$PYTHON $CONDA_BUILD_ARGS $CONDA_CHANNEL + + gpuci_logger "Build conda pkg for strings_udf" + gpuci_conda_retry mambabuild --croot ${CONDA_BLD_DIR} conda/recipes/strings_udf --python=$PYTHON $CONDA_BUILD_ARGS $CONDA_CHANNEL + fi ################################################################################ # UPLOAD - Conda packages diff --git a/ci/cpu/upload.sh b/ci/cpu/upload.sh index 29f6265ec63..771b7853ade 100755 --- a/ci/cpu/upload.sh +++ b/ci/cpu/upload.sh @@ -33,6 +33,12 @@ if [[ "$BUILD_LIBCUDF" == "1" && "$UPLOAD_LIBCUDF" == "1" ]]; then export LIBCUDF_FILES=$(conda build --no-build-id --croot "${CONDA_BLD_DIR}" conda/recipes/libcudf --output) LIBCUDF_FILES=$(echo "$LIBCUDF_FILES" | sed 's/.*libcudf-example.*//') # skip libcudf-example pkg upload gpuci_retry anaconda -t ${MY_UPLOAD_KEY} upload -u ${CONDA_USERNAME:-rapidsai} ${LABEL_OPTION} --skip-existing --no-progress $LIBCUDF_FILES + + # also build strings_udf on cpu machines + export STRINGS_UDF_FILE=$(conda build --croot "${CONDA_BLD_DIR}" conda/recipes/strings_udf --python=$PYTHON --output) + test -e ${STRINGS_UDF_FILE} + echo "Upload strings_udf: ${STRINGS_UDF_FILE}" + gpuci_retry anaconda -t ${MY_UPLOAD_KEY} upload -u ${CONDA_USERNAME:-rapidsai} ${LABEL_OPTION} --skip-existing ${STRINGS_UDF_FILE} --no-progress fi if [[ "$BUILD_CUDF" == "1" && "$UPLOAD_CUDF" == "1" ]]; then diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 
316dbcbaa1d..ae19362ca11 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -121,11 +121,11 @@ if [[ -z "$PROJECT_FLASH" || "$PROJECT_FLASH" == "0" ]]; then install_dask ################################################################################ - # BUILD - Build libcudf, cuDF, libcudf_kafka, and dask_cudf from source + # BUILD - Build libcudf, cuDF, libcudf_kafka, dask_cudf, and strings_udf from source ################################################################################ gpuci_logger "Build from source" - "$WORKSPACE/build.sh" clean libcudf cudf dask_cudf libcudf_kafka cudf_kafka benchmarks tests --ptds + "$WORKSPACE/build.sh" clean libcudf cudf dask_cudf libcudf_kafka cudf_kafka strings_udf benchmarks tests --ptds ################################################################################ # TEST - Run GoogleTest @@ -183,7 +183,11 @@ else gpuci_conda_retry mambabuild --croot ${CONDA_BLD_DIR} conda/recipes/cudf_kafka --python=$PYTHON -c ${CONDA_ARTIFACT_PATH} gpuci_conda_retry mambabuild --croot ${CONDA_BLD_DIR} conda/recipes/custreamz --python=$PYTHON -c ${CONDA_ARTIFACT_PATH} - gpuci_logger "Installing cudf, dask-cudf, cudf_kafka and custreamz" + # the CUDA component of strings_udf must be built on cuda 11.5 just like libcudf + # but because there is no separate python package, we must also build the python on the 11.5 jobs + # this means that at this point (on the GPU test jobs) the whole package is already built and has been + # copied by CI from the upstream 11.5 jobs into $CONDA_ARTIFACT_PATH + gpuci_logger "Installing cudf, dask-cudf, cudf_kafka, and custreamz" gpuci_mamba_retry install cudf dask-cudf cudf_kafka custreamz -c "${CONDA_BLD_DIR}" -c "${CONDA_ARTIFACT_PATH}" gpuci_logger "GoogleTests" @@ -258,6 +262,31 @@ cd "$WORKSPACE/python/custreamz" gpuci_logger "Python py.test for cuStreamz" py.test -n 8 --cache-clear --basetemp="$WORKSPACE/custreamz-cuda-tmp" --junitxml="$WORKSPACE/junit-custreamz.xml" -v --cov-config=.coveragerc 
--cov=custreamz --cov-report=xml:"$WORKSPACE/python/custreamz/custreamz-coverage.xml" --cov-report term custreamz +gpuci_logger "Installing strings_udf" +gpuci_mamba_retry install strings_udf -c "${CONDA_BLD_DIR}" -c "${CONDA_ARTIFACT_PATH}" + +cd "$WORKSPACE/python/strings_udf/strings_udf" +gpuci_logger "Python py.test for strings_udf" + +# We do not want to exit with a nonzero exit code in the case where no +# strings_udf tests are run because that will always happen when the local CUDA +# version is not 11.5. We need to suppress the exit code because this script is +# run with set -e and we're already setting a trap that we don't want to +# override here. + +STRINGS_UDF_PYTEST_RETCODE=0 +py.test -n 8 --cache-clear --basetemp="$WORKSPACE/strings-udf-cuda-tmp" --junitxml="$WORKSPACE/junit-strings-udf.xml" -v --cov-config=.coveragerc --cov=strings_udf --cov-report=xml:"$WORKSPACE/python/strings_udf/strings-udf-coverage.xml" --cov-report term tests || STRINGS_UDF_PYTEST_RETCODE=$? + +if [ ${STRINGS_UDF_PYTEST_RETCODE} -eq 5 ]; then + echo "No strings UDF tests were run, but this script will continue to execute." +elif [ ${STRINGS_UDF_PYTEST_RETCODE} -ne 0 ]; then + exit ${STRINGS_UDF_PYTEST_RETCODE} +else + cd "$WORKSPACE/python/cudf/cudf" + gpuci_logger "Python py.test retest cuDF UDFs" + py.test tests/test_udf_masked_ops.py -n 8 --cache-clear +fi + # Run benchmarks with both cudf and pandas to ensure compatibility is maintained. # Benchmarks are run in DEBUG_ONLY mode, meaning that only small data sizes are used. # Therefore, these runs only verify that benchmarks are valid. diff --git a/conda/recipes/strings_udf/build.sh b/conda/recipes/strings_udf/build.sh new file mode 100644 index 00000000000..2de1325347b --- /dev/null +++ b/conda/recipes/strings_udf/build.sh @@ -0,0 +1,4 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +# This assumes the script is executed from the root of the repo directory +./build.sh strings_udf diff --git a/conda/recipes/strings_udf/conda_build_config.yaml b/conda/recipes/strings_udf/conda_build_config.yaml new file mode 100644 index 00000000000..d9c3f21448f --- /dev/null +++ b/conda/recipes/strings_udf/conda_build_config.yaml @@ -0,0 +1,14 @@ +c_compiler_version: + - 9 + +cxx_compiler_version: + - 9 + +sysroot_version: + - "2.17" + +cmake_version: + - ">=3.20.1,!=3.23.0" + +cuda_compiler: + - nvcc diff --git a/conda/recipes/strings_udf/meta.yaml b/conda/recipes/strings_udf/meta.yaml new file mode 100644 index 00000000000..e29fb55ce63 --- /dev/null +++ b/conda/recipes/strings_udf/meta.yaml @@ -0,0 +1,65 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +{% set version = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').lstrip('v') + environ.get('VERSION_SUFFIX', '') %} +{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} +{% set py_version=environ.get('CONDA_PY', 38) %} +{% set cuda_version='.'.join(environ.get('CUDA', '11.5').split('.')[:2]) %} +{% set cuda_major=cuda_version.split('.')[0] %} + +package: + name: strings_udf + version: {{ version }} + +source: + git_url: ../../.. 
+ +build: + number: {{ GIT_DESCRIBE_NUMBER }} + string: cuda_{{ cuda_major }}_py{{ py_version }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} + script_env: + - VERSION_SUFFIX + - PARALLEL_LEVEL + # libcudf's run_exports pinning is looser than we would like + ignore_run_exports: + - libcudf + ignore_run_exports_from: + - {{ compiler('cuda') }} + +requirements: + build: + - cmake {{ cmake_version }} + - {{ compiler('c') }} + - {{ compiler('cxx') }} + - {{ compiler('cuda') }} {{ cuda_version }} + - sysroot_{{ target_platform }} {{ sysroot_version }} + host: + - python + - cython >=0.29,<0.30 + - scikit-build>=0.13.1 + - setuptools + - numba >=0.54 + - libcudf ={{ version }} + - cudf ={{ version }} + - cudatoolkit ={{ cuda_version }} + run: + - python + - typing_extensions + - numba >=0.54 + - numpy + - libcudf ={{ version }} + - cudf ={{ version }} + - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} + - cachetools + - ptxcompiler # [linux64] # CUDA enhanced compatibility. 
See https://github.com/rapidsai/ptxcompiler +test: # [linux64] + requires: # [linux64] + - cudatoolkit {{ cuda_version }}.* # [linux64] + imports: # [linux64] + - strings_udf # [linux64] + +about: + home: https://rapids.ai/ + license: Apache-2.0 + license_family: APACHE + license_file: LICENSE + summary: strings_udf library diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 741aa62d1a0..30b1bc704c8 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -57,7 +57,12 @@ from cudf.core.missing import NA from cudf.core.multiindex import MultiIndex from cudf.core.resample import _Resampler -from cudf.core.udf.utils import _compile_or_get, _supported_cols_from_frame +from cudf.core.udf.utils import ( + _compile_or_get, + _get_input_args_from_frame, + _post_process_output_col, + _return_arr_from_dtype, +) from cudf.utils import docutils from cudf.utils.utils import _cudf_nvtx_annotate @@ -1819,30 +1824,19 @@ def _apply(self, func, kernel_getter, *args, **kwargs): ) from e # Mask and data column preallocated - ans_col = cp.empty(len(self), dtype=retty) + ans_col = _return_arr_from_dtype(retty, len(self)) ans_mask = cudf.core.column.column_empty(len(self), dtype="bool") - launch_args = [(ans_col, ans_mask), len(self)] - offsets = [] - - # if _compile_or_get succeeds, it is safe to create a kernel that only - # consumes the columns that are of supported dtype - for col in _supported_cols_from_frame(self).values(): - data = col.data - mask = col.mask - if mask is None: - launch_args.append(data) - else: - launch_args.append((data, mask)) - offsets.append(col.offset) - launch_args += offsets - launch_args += list(args) + output_args = [(ans_col, ans_mask), len(self)] + input_args = _get_input_args_from_frame(self) + launch_args = output_args + input_args + list(args) try: kernel.forall(len(self))(*launch_args) except Exception as e: raise RuntimeError("UDF kernel execution failed.") 
from e - col = cudf.core.column.as_column(ans_col) + col = _post_process_output_col(ans_col, retty) + col.set_base_mask(libcudf.transform.bools_to_mask(ans_mask)) result = cudf.Series._from_data({None: col}, self._index) diff --git a/python/cudf/cudf/core/udf/__init__.py b/python/cudf/cudf/core/udf/__init__.py index 97ca9df9ef4..c128bc2436c 100644 --- a/python/cudf/cudf/core/udf/__init__.py +++ b/python/cudf/cudf/core/udf/__init__.py @@ -1,3 +1,65 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2022, NVIDIA CORPORATION. +import numpy as np +from numba import cuda, types +from numba.cuda.cudaimpl import ( + lower as cuda_lower, + registry as cuda_lowering_registry, +) -from . import lowering, typing +from cudf.core.dtypes import dtype +from cudf.core.udf import api, row_function, utils +from cudf.utils.dtypes import STRING_TYPES + +from . import masked_lowering, masked_typing + +_units = ["ns", "ms", "us", "s"] +_datetime_cases = {types.NPDatetime(u) for u in _units} +_timedelta_cases = {types.NPTimedelta(u) for u in _units} + + +_supported_masked_types = ( + types.integer_domain + | types.real_domain + | _datetime_cases + | _timedelta_cases + | {types.boolean} +) + +_STRING_UDFS_ENABLED = False +try: + import strings_udf + + if strings_udf.ENABLED: + from . import strings_typing # isort: skip + from . 
import strings_lowering # isort: skip + from strings_udf import ptxpath + from strings_udf._lib.cudf_jit_udf import to_string_view_array + from strings_udf._typing import str_view_arg_handler, string_view + + # add an overload of MaskedType.__init__(string_view, bool) + cuda_lower(api.Masked, strings_typing.string_view, types.boolean)( + masked_lowering.masked_constructor + ) + + # add an overload of pack_return(string_view) + cuda_lower(api.pack_return, strings_typing.string_view)( + masked_lowering.pack_return_scalar_impl + ) + + _supported_masked_types |= {strings_typing.string_view} + utils.launch_arg_getters[dtype("O")] = to_string_view_array + utils.masked_array_types[dtype("O")] = string_view + utils.JIT_SUPPORTED_TYPES |= STRING_TYPES + utils.ptx_files.append(ptxpath) + utils.arg_handlers.append(str_view_arg_handler) + row_function.itemsizes[dtype("O")] = string_view.size_bytes + + _STRING_UDFS_ENABLED = True + else: + del strings_udf + +except ImportError: + # allow cuDF to work without strings_udf + pass + +masked_typing.register_masked_constructor(_supported_masked_types) diff --git a/python/cudf/cudf/core/udf/lowering.py b/python/cudf/cudf/core/udf/masked_lowering.py similarity index 99% rename from python/cudf/cudf/core/udf/lowering.py rename to python/cudf/cudf/core/udf/masked_lowering.py index 7dfe8427bfd..f825b6538bf 100644 --- a/python/cudf/cudf/core/udf/lowering.py +++ b/python/cudf/cudf/core/udf/masked_lowering.py @@ -18,7 +18,7 @@ comparison_ops, unary_ops, ) -from cudf.core.udf.typing import MaskedType, NAType +from cudf.core.udf.masked_typing import MaskedType, NAType @cuda_lowering_registry.lower_constant(NAType) @@ -62,7 +62,6 @@ def masked_scalar_op_impl(context, builder, sig, args): result = cgutils.create_struct_proxy(masked_return_type)( context, builder ) - # compute output validity valid = builder.and_(m1.valid, m2.valid) result.valid = valid diff --git a/python/cudf/cudf/core/udf/typing.py 
b/python/cudf/cudf/core/udf/masked_typing.py similarity index 85% rename from python/cudf/cudf/core/udf/typing.py rename to python/cudf/cudf/core/udf/masked_typing.py index 073900d115d..a815a9f6dae 100644 --- a/python/cudf/cudf/core/udf/typing.py +++ b/python/cudf/cudf/core/udf/masked_typing.py @@ -1,6 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. import operator +from typing import Any, Dict from numba import types from numba.core.extending import ( @@ -26,6 +27,12 @@ comparison_ops, unary_ops, ) +from cudf.utils.dtypes import ( + DATETIME_TYPES, + NUMERIC_TYPES, + STRING_TYPES, + TIMEDELTA_TYPES, +) SUPPORTED_NUMBA_TYPES = ( types.Number, @@ -34,29 +41,60 @@ types.NPTimedelta, ) +SUPPORTED_NUMPY_TYPES = ( + NUMERIC_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES | STRING_TYPES +) +supported_type_str = "\n".join(sorted(list(SUPPORTED_NUMPY_TYPES) + ["bool"])) +MASKED_INIT_MAP: Dict[Any, Any] = {} -class MaskedType(types.Type): + +def _format_error_string(err): """ - A Numba type consisting of a value of some primitive type - and a validity boolean, over which we can define math ops + Wrap an error message in newlines and color it red. """ + return "\033[91m" + "\n" + err + "\n" + "\033[0m" - def __init__(self, value): - # MaskedType in Numba shall be parameterized - # with a value type - if isinstance(value, SUPPORTED_NUMBA_TYPES): - self.value_type = value + +def _type_to_masked_type(t): + result = MASKED_INIT_MAP.get(t) + if result is None: + if isinstance(t, SUPPORTED_NUMBA_TYPES): + return t else: # Unsupported Dtype. Numba tends to print out the type info # for whatever operands and operation failed to type and then # output its own error message. Putting the message in the repr # then is one way of getting the true cause to the user - self.value_type = types.Poison( - "\n\n\n Unsupported MaskedType. This is usually caused by " + err = _format_error_string( + "Unsupported MaskedType. 
This is usually caused by " "attempting to use a column of unsupported dtype in a UDF. " - f"Supported dtypes are {SUPPORTED_NUMBA_TYPES}" + f"Supported dtypes are:\n{supported_type_str}" ) - super().__init__(name=f"Masked{self.value_type}") + return types.Poison(err) + else: + return result + + +MASKED_INIT_MAP[types.pyobject] = types.Poison( + _format_error_string( + "strings_udf library required for usage of string dtypes " + "inside user defined functions." + ) +) + + +# Masked scalars of all types +class MaskedType(types.Type): + """ + A Numba type consisting of a value of some primitive type + and a validity boolean, over which we can define math ops + """ + + def __init__(self, value): + # MaskedType in Numba shall be parameterized + # with a value type + self.value_type = _type_to_masked_type(value) + super().__init__(name=f"Masked({self.value_type})") def __hash__(self): """ @@ -131,44 +169,35 @@ def typeof_masked(val, c): # Implemented typing for Masked(value, valid) - the construction of a Masked # type in a kernel. -@cuda_decl_registry.register -class MaskedConstructor(ConcreteTemplate): - key = api.Masked - units = ["ns", "ms", "us", "s"] - datetime_cases = {types.NPDatetime(u) for u in units} - timedelta_cases = {types.NPTimedelta(u) for u in units} - cases = [ - nb_signature(MaskedType(t), t, types.boolean) - for t in ( - types.integer_domain - | types.real_domain - | datetime_cases - | timedelta_cases - | {types.boolean} - ) - ] +def register_masked_constructor(supported_masked_types): + class MaskedConstructor(ConcreteTemplate): + key = api.Masked + cases = [ + nb_signature(MaskedType(t), t, types.boolean) + for t in supported_masked_types + ] + cuda_decl_registry.register(MaskedConstructor) -# Provide access to `m.value` and `m.valid` in a kernel for a Masked `m`. 
-make_attribute_wrapper(MaskedType, "value", "value") -make_attribute_wrapper(MaskedType, "valid", "valid") - + # Typing for `api.Masked` + @cuda_decl_registry.register_attr + class ClassesTemplate(AttributeTemplate): + key = types.Module(api) -# Typing for `api.Masked` -@cuda_decl_registry.register_attr -class ClassesTemplate(AttributeTemplate): - key = types.Module(api) + def resolve_Masked(self, mod): + return types.Function(MaskedConstructor) - def resolve_Masked(self, mod): - return types.Function(MaskedConstructor) + # Registration of the global is also needed for Numba to type api.Masked + cuda_decl_registry.register_global(api, types.Module(api)) + # For typing bare Masked (as in `from .api import Masked`) + cuda_decl_registry.register_global( + api.Masked, types.Function(MaskedConstructor) + ) -# Registration of the global is also needed for Numba to type api.Masked -cuda_decl_registry.register_global(api, types.Module(api)) -# For typing bare Masked (as in `from .api import Masked` -cuda_decl_registry.register_global( - api.Masked, types.Function(MaskedConstructor) -) +# Provide access to `m.value` and `m.valid` in a kernel for a Masked `m`. +make_attribute_wrapper(MaskedType, "value", "value") +make_attribute_wrapper(MaskedType, "valid", "valid") # Tell numba how `MaskedType` is constructed on the backend in terms diff --git a/python/cudf/cudf/core/udf/row_function.py b/python/cudf/cudf/core/udf/row_function.py index 1d0bd5ac99d..8d887a37706 100644 --- a/python/cudf/cudf/core/udf/row_function.py +++ b/python/cudf/cudf/core/udf/row_function.py @@ -1,5 +1,6 @@ # Copyright (c) 2021-2022, NVIDIA CORPORATION. 
import math +from typing import Any, Dict import numpy as np from numba import cuda @@ -7,13 +8,13 @@ from numba.types import Record from cudf.core.udf.api import Masked, pack_return +from cudf.core.udf.masked_typing import MaskedType from cudf.core.udf.templates import ( masked_input_initializer_template, row_initializer_template, row_kernel_template, unmasked_input_initializer_template, ) -from cudf.core.udf.typing import MaskedType from cudf.core.udf.utils import ( _all_dtypes_from_frame, _construct_signature, @@ -24,6 +25,8 @@ _supported_dtypes_from_frame, ) +itemsizes: Dict[Any, int] = {} + def _get_frame_row_type(dtype): """ @@ -31,12 +34,10 @@ def _get_frame_row_type(dtype): Models each column and its mask as a MaskedType and models the row as a dictionary like data structure containing these MaskedTypes. - Large parts of this function are copied with comments from the Numba internals and slightly modified to account for validity bools to be present in the final struct. - See numba.np.numpy_support.from_struct_dtype for details. """ @@ -45,7 +46,9 @@ def _get_frame_row_type(dtype): fields = [] offset = 0 - sizes = [val[0].itemsize for val in dtype.fields.values()] + sizes = [ + itemsizes.get(val[0], val[0].itemsize) for val in dtype.fields.values() + ] for i, (name, info) in enumerate(dtype.fields.items()): # *info* consists of the element dtype, its offset from the beginning # of the record, and an optional "title" containing metadata. 
@@ -62,7 +65,8 @@ def _get_frame_row_type(dtype): fields.append((name, infos)) # increment offset by itemsize plus one byte for validity - offset += elemdtype.itemsize + 1 + itemsize = itemsizes.get(elemdtype, elemdtype.itemsize) + offset += itemsize + 1 # Align the next member of the struct to be a multiple of the # memory access size, per PTX ISA 7.4/5.4.5 @@ -127,10 +131,8 @@ def _get_row_kernel(frame, func, args): np.dtype(list(_all_dtypes_from_frame(frame).items())) ) scalar_return_type = _get_udf_return_type(row_type, func, args) - # this is the signature for the final full kernel compilation sig = _construct_signature(frame, scalar_return_type, args) - # this row type is used within the kernel to pack up the column and # mask data into the dict like data structure the user udf expects np_field_types = np.dtype( diff --git a/python/cudf/cudf/core/udf/scalar_function.py b/python/cudf/cudf/core/udf/scalar_function.py index a7b887dd2d5..31599f4151e 100644 --- a/python/cudf/cudf/core/udf/scalar_function.py +++ b/python/cudf/cudf/core/udf/scalar_function.py @@ -4,12 +4,12 @@ from numba.np import numpy_support from cudf.core.udf.api import Masked, pack_return +from cudf.core.udf.masked_typing import MaskedType from cudf.core.udf.templates import ( masked_input_initializer_template, scalar_kernel_template, unmasked_input_initializer_template, ) -from cudf.core.udf.typing import MaskedType from cudf.core.udf.utils import ( _construct_signature, _get_kernel, diff --git a/python/cudf/cudf/core/udf/strings_lowering.py b/python/cudf/cudf/core/udf/strings_lowering.py new file mode 100644 index 00000000000..5b69d1a9da3 --- /dev/null +++ b/python/cudf/cudf/core/udf/strings_lowering.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +import operator + +from numba import types +from numba.core import cgutils +from numba.core.typing import signature as nb_signature +from numba.cuda.cudaimpl import lower as cuda_lower + +from strings_udf._typing import size_type, string_view +from strings_udf.lowering import ( + contains_impl, + count_impl, + endswith_impl, + find_impl, + isalnum_impl, + isalpha_impl, + isdecimal_impl, + isdigit_impl, + islower_impl, + isspace_impl, + isupper_impl, + len_impl, + rfind_impl, + startswith_impl, +) + +from cudf.core.udf.masked_typing import MaskedType + + +@cuda_lower(len, MaskedType(string_view)) +def masked_len_impl(context, builder, sig, args): + ret = cgutils.create_struct_proxy(sig.return_type)(context, builder) + masked_sv_ty = sig.args[0] + masked_sv = cgutils.create_struct_proxy(masked_sv_ty)( + context, builder, value=args[0] + ) + result = len_impl( + context, builder, size_type(string_view), (masked_sv.value,) + ) + ret.value = result + ret.valid = masked_sv.valid + + return ret._getvalue() + + +def create_binary_string_func(op, cuda_func, retty): + """ + Provide a wrapper around numba's low-level extension API which + produces the boilerplate needed to implement a binary function + of two masked strings. 
+ """ + + def masked_binary_func_impl(context, builder, sig, args): + ret = cgutils.create_struct_proxy(sig.return_type)(context, builder) + + lhs_masked = cgutils.create_struct_proxy(sig.args[0])( + context, builder, value=args[0] + ) + rhs_masked = cgutils.create_struct_proxy(sig.args[0])( + context, builder, value=args[1] + ) + + result = cuda_func( + context, + builder, + nb_signature(retty, string_view, string_view), + (lhs_masked.value, rhs_masked.value), + ) + + ret.value = result + ret.valid = builder.and_(lhs_masked.valid, rhs_masked.valid) + + return ret._getvalue() + + cuda_lower(op, MaskedType(string_view), MaskedType(string_view))( + masked_binary_func_impl + ) + + +create_binary_string_func( + "MaskedType.startswith", + startswith_impl, + types.boolean, +) +create_binary_string_func("MaskedType.endswith", endswith_impl, types.boolean) +create_binary_string_func("MaskedType.find", find_impl, size_type) +create_binary_string_func("MaskedType.rfind", rfind_impl, size_type) +create_binary_string_func("MaskedType.count", count_impl, size_type) +create_binary_string_func(operator.contains, contains_impl, types.boolean) + + +def create_masked_unary_identifier_func(op, cuda_func): + """ + Provide a wrapper around numba's low-level extension API which + produces the boilerplate needed to implement a unary function + of a masked string. 
+ """ + + def masked_unary_func_impl(context, builder, sig, args): + ret = cgutils.create_struct_proxy(sig.return_type)(context, builder) + masked_str = cgutils.create_struct_proxy(sig.args[0])( + context, builder, value=args[0] + ) + + result = cuda_func( + context, + builder, + types.boolean(string_view, string_view), + (masked_str.value,), + ) + ret.value = result + ret.valid = masked_str.valid + return ret._getvalue() + + cuda_lower(op, MaskedType(string_view))(masked_unary_func_impl) + + +create_masked_unary_identifier_func("MaskedType.isalnum", isalnum_impl) +create_masked_unary_identifier_func("MaskedType.isalpha", isalpha_impl) +create_masked_unary_identifier_func("MaskedType.isdigit", isdigit_impl) +create_masked_unary_identifier_func("MaskedType.isupper", isupper_impl) +create_masked_unary_identifier_func("MaskedType.islower", islower_impl) +create_masked_unary_identifier_func("MaskedType.isspace", isspace_impl) +create_masked_unary_identifier_func("MaskedType.isdecimal", isdecimal_impl) diff --git a/python/cudf/cudf/core/udf/strings_typing.py b/python/cudf/cudf/core/udf/strings_typing.py new file mode 100644 index 00000000000..1179688651f --- /dev/null +++ b/python/cudf/cudf/core/udf/strings_typing.py @@ -0,0 +1,182 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +import operator + +from numba import types +from numba.core.typing import signature as nb_signature +from numba.core.typing.templates import AbstractTemplate, AttributeTemplate +from numba.cuda.cudadecl import registry as cuda_decl_registry + +from strings_udf._typing import ( + StringView, + bool_binary_funcs, + id_unary_funcs, + int_binary_funcs, + size_type, + string_view, +) + +from cudf.core.udf import masked_typing +from cudf.core.udf._ops import comparison_ops +from cudf.core.udf.masked_typing import MaskedType + +masked_typing.MASKED_INIT_MAP[types.pyobject] = string_view +masked_typing.MASKED_INIT_MAP[string_view] = string_view + + +def _is_valid_string_arg(ty): + return ( + isinstance(ty, MaskedType) and isinstance(ty.value_type, StringView) + ) or isinstance(ty, types.StringLiteral) + + +def register_string_function(func): + """ + Helper function wrapping numba's low level extension API. Provides + the boilerplate needed to associate a signature with a function or + operator to be overloaded. 
+ """ + + def deco(generic): + class MaskedStringFunction(AbstractTemplate): + pass + + MaskedStringFunction.generic = generic + cuda_decl_registry.register_global(func)(MaskedStringFunction) + + return deco + + +@register_string_function(len) +def len_typing(self, args, kws): + if isinstance(args[0], MaskedType) and isinstance( + args[0].value_type, StringView + ): + return nb_signature(MaskedType(size_type), args[0]) + elif isinstance(args[0], types.StringLiteral) and len(args) == 1: + return nb_signature(size_type, args[0]) + + +@register_string_function(operator.contains) +def contains_typing(self, args, kws): + if _is_valid_string_arg(args[0]) and _is_valid_string_arg(args[1]): + return nb_signature( + MaskedType(types.boolean), + MaskedType(string_view), + MaskedType(string_view), + ) + + +class MaskedStringViewCmpOp(AbstractTemplate): + """ + return the boolean result of `cmpop` between to strings + since the typing is the same for every comparison operator, + we can reuse this class for all of them. + """ + + def generic(self, args, kws): + if _is_valid_string_arg(args[0]) and _is_valid_string_arg(args[1]): + return nb_signature( + MaskedType(types.boolean), + MaskedType(string_view), + MaskedType(string_view), + ) + + +for op in comparison_ops: + cuda_decl_registry.register_global(op)(MaskedStringViewCmpOp) + + +def create_masked_binary_attr(attrname, retty): + """ + Helper function wrapping numba's low level extension API. Provides + the boilerplate needed to register a binary function of two masked + string objects as an attribute of one, e.g. `string.func(other)`. 
+ """ + + class MaskedStringViewBinaryAttr(AbstractTemplate): + key = attrname + + def generic(self, args, kws): + return nb_signature( + MaskedType(retty), MaskedType(string_view), recvr=self.this + ) + + def attr(self, mod): + return types.BoundFunction( + MaskedStringViewBinaryAttr, + MaskedType(string_view), + ) + + return attr + + +def create_masked_identifier_attr(attrname): + """ + Helper function wrapping numba's low level extension API. Provides + the boilerplate needed to register a unary function of a masked + string object as an attribute, e.g. `string.func()`. + """ + + class MaskedStringViewIdentifierAttr(AbstractTemplate): + key = attrname + + def generic(self, args, kws): + return nb_signature(MaskedType(types.boolean), recvr=self.this) + + def attr(self, mod): + return types.BoundFunction( + MaskedStringViewIdentifierAttr, + MaskedType(string_view), + ) + + return attr + + +class MaskedStringViewCount(AbstractTemplate): + key = "MaskedType.count" + + def generic(self, args, kws): + return nb_signature( + MaskedType(size_type), MaskedType(string_view), recvr=self.this + ) + + +class MaskedStringViewAttrs(AttributeTemplate): + key = MaskedType(string_view) + + def resolve_count(self, mod): + return types.BoundFunction( + MaskedStringViewCount, MaskedType(string_view) + ) + + def resolve_value(self, mod): + return string_view + + def resolve_valid(self, mod): + return types.boolean + + +# Build attributes for `MaskedType(string_view)` +for func in bool_binary_funcs: + setattr( + MaskedStringViewAttrs, + f"resolve_{func}", + create_masked_binary_attr(f"MaskedType.{func}", types.boolean), + ) + +for func in int_binary_funcs: + setattr( + MaskedStringViewAttrs, + f"resolve_{func}", + create_masked_binary_attr(f"MaskedType.{func}", size_type), + ) + +for func in id_unary_funcs: + setattr( + MaskedStringViewAttrs, + f"resolve_{func}", + create_masked_identifier_attr(f"MaskedType.{func}"), + ) + +cuda_decl_registry.register_attr(MaskedStringViewAttrs) diff 
--git a/python/cudf/cudf/core/udf/utils.py b/python/cudf/cudf/core/udf/utils.py index 5e46c6d0d77..fa79088046c 100644 --- a/python/cudf/cudf/core/udf/utils.py +++ b/python/cudf/cudf/core/udf/utils.py @@ -1,16 +1,18 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. -from typing import Callable +from typing import Any, Callable, Dict, List import cachetools +import cupy as cp import numpy as np from numba import cuda, typeof from numba.core.errors import TypingError from numba.np import numpy_support -from numba.types import Poison, Tuple, boolean, int64, void +from numba.types import CPointer, Poison, Tuple, boolean, int64, void +from cudf.core.column.column import as_column from cudf.core.dtypes import CategoricalDtype -from cudf.core.udf.typing import MaskedType +from cudf.core.udf.masked_typing import MaskedType from cudf.utils import cudautils from cudf.utils.dtypes import ( BOOL_TYPES, @@ -23,11 +25,12 @@ JIT_SUPPORTED_TYPES = ( NUMERIC_TYPES | BOOL_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES ) - libcudf_bitmask_type = numpy_support.from_dtype(np.dtype("int32")) MASK_BITSIZE = np.dtype("int32").itemsize * 8 precompiled: cachetools.LRUCache = cachetools.LRUCache(maxsize=32) +arg_handlers: List[Any] = [] +ptx_files: List[Any] = [] @_cudf_nvtx_annotate @@ -109,6 +112,9 @@ def _supported_cols_from_frame(frame): } +masked_array_types: Dict[Any, Any] = {} + + def _masked_array_type_from_col(col): """ Return a type representing a tuple of arrays, @@ -116,11 +122,18 @@ def _masked_array_type_from_col(col): corresponding to `dtype`, and the second an array of bools representing a mask. 
""" - nb_scalar_ty = numpy_support.from_dtype(col.dtype) + + col_type = masked_array_types.get(col.dtype) + if col_type: + col_type = CPointer(col_type) + else: + nb_scalar_ty = numpy_support.from_dtype(col.dtype) + col_type = nb_scalar_ty[::1] + if col.mask is None: - return nb_scalar_ty[::1] + return col_type else: - return Tuple((nb_scalar_ty[::1], libcudf_bitmask_type[::1])) + return Tuple((col_type, libcudf_bitmask_type[::1])) def _construct_signature(frame, return_type, args): @@ -200,7 +213,6 @@ def _compile_or_get(frame, func, args, kernel_getter=None): # could be a MaskedType or a scalar type. kernel, scalar_return_type = kernel_getter(frame, func, args) - np_return_type = numpy_support.as_dtype(scalar_return_type) precompiled[cache_key] = (kernel, np_return_type) @@ -213,6 +225,37 @@ def _get_kernel(kernel_string, globals_, sig, func): globals_["f_"] = f_ exec(kernel_string, globals_) _kernel = globals_["_kernel"] - kernel = cuda.jit(sig)(_kernel) + kernel = cuda.jit(sig, link=ptx_files, extensions=arg_handlers)(_kernel) return kernel + + +launch_arg_getters: Dict[Any, Any] = {} + + +def _get_input_args_from_frame(fr): + args = [] + offsets = [] + for col in _supported_cols_from_frame(fr).values(): + getter = launch_arg_getters.get(col.dtype) + if getter: + data = getter(col) + else: + data = col.data + if col.mask is not None: + # argument is a tuple of data, mask + args.append((data, col.mask)) + else: + # argument is just the data pointer + args.append(data) + offsets.append(col.offset) + + return args + offsets + + +def _return_arr_from_dtype(dt, size): + return cp.empty(size, dtype=dt) + + +def _post_process_output_col(col, retty): + return as_column(col, retty) diff --git a/python/cudf/cudf/tests/test_extension_compilation.py b/python/cudf/cudf/tests/test_extension_compilation.py index 692f40873d7..f1ed17c5df5 100644 --- a/python/cudf/cudf/tests/test_extension_compilation.py +++ b/python/cudf/cudf/tests/test_extension_compilation.py @@ -10,7 +10,7 
@@ from cudf import NA from cudf.core.udf.api import Masked -from cudf.core.udf.typing import MaskedType +from cudf.core.udf.masked_typing import MaskedType from cudf.testing._utils import parametrize_numeric_dtypes_pairwise arith_ops = ( diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py index 4f385656405..2b96c920765 100644 --- a/python/cudf/cudf/tests/test_udf_masked_ops.py +++ b/python/cudf/cudf/tests/test_udf_masked_ops.py @@ -8,6 +8,7 @@ import cudf from cudf.core.missing import NA +from cudf.core.udf import _STRING_UDFS_ENABLED from cudf.core.udf._ops import ( arith_ops, bitwise_ops, @@ -22,6 +23,49 @@ ) +# only run string udf tests if library exists and is enabled +def string_udf_test(f): + if _STRING_UDFS_ENABLED: + return f + else: + return pytest.mark.skip(reason="String UDFs not enabled")(f) + + +@pytest.fixture(scope="module") +def str_udf_data(): + return cudf.DataFrame( + { + "str_col": [ + "abc", + "ABC", + "AbC", + "123", + "123aBc", + "123@.!", + "", + "rapids ai", + "gpu", + "True", + "False", + "1.234", + ".123a", + "0.013", + "1.0", + "01", + "20010101", + "cudf", + "cuda", + "gpu", + ] + } + ) + + +@pytest.fixture(params=["a", "cu", "2", "gpu", "", " "]) +def substr(request): + return request.param + + def run_masked_udf_test(func, data, args=(), **kwargs): gdf = data pdf = data.to_pandas(nullable=True) @@ -537,7 +581,6 @@ def func(row): @pytest.mark.parametrize( "unsupported_col", [ - ["a", "b", "c"], _decimal_series( ["1.0", "2.0", "3.0"], dtype=cudf.Decimal64Dtype(2, 1) ), @@ -682,6 +725,128 @@ def f(x): assert precompiled.currsize == 1 +@string_udf_test +def test_string_udf_len(str_udf_data): + def func(row): + return len(row["str_col"]) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_startswith(str_udf_data, substr): + def func(row): + return row["str_col"].startswith(substr) + + run_masked_udf_test(func, str_udf_data, 
check_dtype=False) + + +@string_udf_test +def test_string_udf_endswith(str_udf_data, substr): + def func(row): + return row["str_col"].endswith(substr) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_find(str_udf_data, substr): + def func(row): + return row["str_col"].find(substr) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_rfind(str_udf_data, substr): + def func(row): + return row["str_col"].rfind(substr) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_contains(str_udf_data, substr): + def func(row): + return substr in row["str_col"] + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +@pytest.mark.parametrize("other", ["cudf", "123", "", " "]) +@pytest.mark.parametrize("cmpop", comparison_ops) +def test_string_udf_cmpops(str_udf_data, other, cmpop): + def func(row): + return cmpop(row["str_col"], other) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_isalnum(str_udf_data): + def func(row): + return row["str_col"].isalnum() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_isalpha(str_udf_data): + def func(row): + return row["str_col"].isalpha() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_isdigit(str_udf_data): + def func(row): + return row["str_col"].isdigit() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_isdecimal(str_udf_data): + def func(row): + return row["str_col"].isdecimal() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_isupper(str_udf_data): + def func(row): + return row["str_col"].isupper() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + 
+@string_udf_test +def test_string_udf_islower(str_udf_data): + def func(row): + return row["str_col"].islower() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_isspace(str_udf_data): + def func(row): + return row["str_col"].isspace() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_count(str_udf_data, substr): + def func(row): + return row["str_col"].count(substr) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + @pytest.mark.parametrize( "data", [[1.0, 0.0, 1.5], [1, 0, 2], [True, False, True]] ) diff --git a/python/cudf/cudf/utils/cudautils.py b/python/cudf/cudf/utils/cudautils.py index 8b9a6be0ffe..e2bd4556ce8 100755 --- a/python/cudf/cudf/utils/cudautils.py +++ b/python/cudf/cudf/utils/cudautils.py @@ -197,13 +197,15 @@ def make_cache_key(udf, sig): """ codebytes = udf.__code__.co_code constants = udf.__code__.co_consts + names = udf.__code__.co_names + if udf.__closure__ is not None: cvars = tuple(x.cell_contents for x in udf.__closure__) cvarbytes = dumps(cvars) else: cvarbytes = b"" - return constants, codebytes, cvarbytes, sig + return names, constants, codebytes, cvarbytes, sig def compile_udf(udf, type_signature): @@ -248,7 +250,7 @@ def compile_udf(udf, type_signature): ptx_code, return_type = cuda.compile_ptx_for_current_device( udf, type_signature, device=True ) - if not isinstance(return_type, cudf.core.udf.typing.MaskedType): + if not isinstance(return_type, cudf.core.udf.masked_typing.MaskedType): output_type = numpy_support.as_dtype(return_type).type else: output_type = return_type diff --git a/python/cudf/setup.cfg b/python/cudf/setup.cfg index 1f7cfeb49ae..8a648097ac8 100644 --- a/python/cudf/setup.cfg +++ b/python/cudf/setup.cfg @@ -25,6 +25,7 @@ known_dask= dask_cuda known_rapids= rmm + strings_udf known_first_party= cudf default_section=THIRDPARTY @@ -41,4 +42,4 @@ skip= buck-out build dist - __init__.py \ No 
newline at end of file + __init__.py diff --git a/python/strings_udf/CMakeLists.txt b/python/strings_udf/CMakeLists.txt new file mode 100644 index 00000000000..59d8ae795f2 --- /dev/null +++ b/python/strings_udf/CMakeLists.txt @@ -0,0 +1,43 @@ +# ============================================================================= +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +cmake_minimum_required(VERSION 3.20.1 FATAL_ERROR) + +set(strings_udf_version 22.10.00) + +include(../../fetch_rapids.cmake) + +project( + strings-udf-python + VERSION ${strings_udf_version} + LANGUAGES CXX + # TODO: Building Python extension modules via the python_extension_module requires the C + # language to be enabled here. The test project that is built in scikit-build to verify + # various linking options for the python library is hardcoded to build with C, so until + # that is fixed we need to keep C. + C + # TODO: Enabling CUDA will not be necessary once we upgrade to CMake 3.22, which will + # pull in the required languages for the C++ project even if this project does not + # require those languages.
+ CUDA +) + +find_package(cudf ${strings_udf_version} REQUIRED) + +add_subdirectory(cpp) + +include(rapids-cython) +rapids_cython_init() + +add_subdirectory(strings_udf/_lib) diff --git a/python/strings_udf/cpp/CMakeLists.txt b/python/strings_udf/cpp/CMakeLists.txt new file mode 100644 index 00000000000..d157acfefde --- /dev/null +++ b/python/strings_udf/cpp/CMakeLists.txt @@ -0,0 +1,111 @@ +# ============================================================================= +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +cmake_minimum_required(VERSION 3.20.1) + +include(rapids-cmake) +include(rapids-cpm) +include(rapids-cuda) +include(rapids-find) + +rapids_cpm_init() + +rapids_cuda_init_architectures(STRINGS_UDF) + +# Create a project so that we can enable CUDA architectures in this file. 
+project( + strings-udf-cpp + VERSION ${strings_udf_version} + LANGUAGES CUDA +) + +rapids_find_package( + CUDAToolkit REQUIRED + BUILD_EXPORT_SET strings-udf-exports + INSTALL_EXPORT_SET strings-udf-exports +) + +include(${rapids-cmake-dir}/cpm/libcudacxx.cmake) +rapids_cpm_libcudacxx(BUILD_EXPORT_SET strings-udf-exports INSTALL_EXPORT_SET strings-udf-exports) + +add_library(cudf_strings_udf SHARED src/strings/udf/udf_apis.cu) +target_include_directories( + cudf_strings_udf PUBLIC "$" +) + +set_target_properties( + cudf_strings_udf + PROPERTIES BUILD_RPATH "\$ORIGIN" + INSTALL_RPATH "\$ORIGIN" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON + INTERFACE_POSITION_INDEPENDENT_CODE ON +) + +set(UDF_CXX_FLAGS) +set(UDF_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr) +target_compile_options( + cudf_strings_udf PRIVATE "$<$:${UDF_CXX_FLAGS}>" + "$<$:${UDF_CUDA_FLAGS}>" +) +target_link_libraries(cudf_strings_udf PUBLIC cudf::cudf CUDA::nvrtc) +install(TARGETS cudf_strings_udf DESTINATION ./strings_udf/_lib/) + +# This function will copy the generated PTX file from its generator-specific location in the build +# tree into a specified location in the build tree from which we can install it. +function(copy_ptx_to_location target destination) + set(cmake_generated_file + "${CMAKE_CURRENT_BINARY_DIR}/cmake/cp_${target}_$>_ptx.cmake" + ) + file( + GENERATE + OUTPUT "${cmake_generated_file}" + CONTENT + " +set(ptx_paths \"$\") +file(COPY \${ptx_paths} DESTINATION \"${destination}\")" + ) + + add_custom_target( + ${target}_cp_ptx ALL + COMMAND ${CMAKE_COMMAND} -P "${cmake_generated_file}" + DEPENDS $ + COMMENT "Copying PTX files to '${destination}'" + ) +endfunction() + +# Create the shim library for each architecture. 
+set(SHIM_CUDA_FLAGS --expt-relaxed-constexpr -rdc=true) + +foreach(arch IN LISTS CMAKE_CUDA_ARCHITECTURES) + set(tgt shim_${arch}) + + add_library(${tgt} OBJECT src/strings/udf/shim.cu) + + set_target_properties(${tgt} PROPERTIES CUDA_ARCHITECTURES ${arch} CUDA_PTX_COMPILATION ON) + + target_include_directories(${tgt} PUBLIC include) + target_compile_options(${tgt} PRIVATE "$<$:${SHIM_CUDA_FLAGS}>") + target_link_libraries(${tgt} PUBLIC cudf::cudf) + + copy_ptx_to_location(${tgt} "${CMAKE_CURRENT_BINARY_DIR}/../strings_udf") + install( + FILES $ + DESTINATION ./strings_udf + RENAME ${tgt}.ptx + ) +endforeach() diff --git a/python/strings_udf/cpp/include/cudf/strings/udf/char_types.cuh b/python/strings_udf/cpp/include/cudf/strings/udf/char_types.cuh new file mode 100644 index 00000000000..e28111fd1f2 --- /dev/null +++ b/python/strings_udf/cpp/include/cudf/strings/udf/char_types.cuh @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace cudf { +namespace strings { +namespace udf { + +/** + * @brief Returns true if all characters in the string are of the type specified. + * + * The output will be false if the string is empty or has at least one character + * not of the specified type. If all characters fit the type then true is returned. 
+ * + * To ignore all but specific types, set the `verify_types` to those types + * which should be checked. Otherwise, the default `ALL_TYPES` will verify all + * characters match `types`. + * + * @code{.pseudo} + * Examples: + * s = ['ab', 'a b', 'a7', 'a B'] + * all_characters_of_type('ab', LOWER) => true + * all_characters_of_type('a b', LOWER) => false + * all_characters_of_type('a7', LOWER) => false + * all_characters_of_type('a B', LOWER) => false + * all_characters_of_type('ab', LOWER, LOWER|UPPER) => true + * all_characters_of_type('a b', LOWER, LOWER|UPPER) => true + * all_characters_of_type('a7', LOWER, LOWER|UPPER) => true + * all_characters_of_type('a B', LOWER, LOWER|UPPER) => false + * @endcode + * + * @param flags_table Table of character-type flags + * @param d_str String for this operation + * @param types The character types to check in the string + * @param verify_types Only verify against these character types. + * Default `ALL_TYPES` means return `true` + * iff all characters match `types`. + * @return True if all characters match the type conditions + */ +__device__ inline bool all_characters_of_type( + cudf::strings::detail::character_flags_table_type* flags_table, + string_view d_str, + string_character_types types, + string_character_types verify_types = string_character_types::ALL_TYPES) +{ + bool check = !d_str.empty(); // require at least one character + size_type check_count = 0; + for (auto itr = d_str.begin(); check && (itr != d_str.end()); ++itr) { + auto code_point = cudf::strings::detail::utf8_to_codepoint(*itr); + // lookup flags in table by code-point + auto flag = code_point <= 0x00FFFF ?
flags_table[code_point] : 0; + if ((verify_types & flag) || // should flag be verified + (flag == 0 && verify_types == ALL_TYPES)) // special edge case + { + check = (types & flag) > 0; + ++check_count; + } + } + return check && (check_count > 0); +} + +/** + * @brief Returns true if all characters are alphabetic only + * + * @param flags_table Table required for checking character types + * @param d_str Input string to check + * @return True if characters alphabetic + */ +__device__ inline bool is_alpha(cudf::strings::detail::character_flags_table_type* flags_table, + string_view d_str) +{ + return all_characters_of_type(flags_table, d_str, string_character_types::ALPHA); +} + +/** + * @brief Returns true if all characters are alphanumeric only + * + * @param flags_table Table required for checking character types + * @param d_str Input string to check + * @return True if characters are alphanumeric + */ +__device__ inline bool is_alpha_numeric( + cudf::strings::detail::character_flags_table_type* flags_table, string_view d_str) +{ + return all_characters_of_type(flags_table, d_str, string_character_types::ALPHANUM); +} + +/** + * @brief Returns true if all characters are numeric only + * + * @param flags_table Table required for checking character types + * @param d_str Input string to check + * @return True if characters are numeric + */ +__device__ inline bool is_numeric(cudf::strings::detail::character_flags_table_type* flags_table, + string_view d_str) +{ + return all_characters_of_type(flags_table, d_str, string_character_types::NUMERIC); +} + +/** + * @brief Returns true if all characters are digits only + * + * @param flags_table Table required for checking character types + * @param d_str Input string to check + * @return True if characters are digits + */ +__device__ inline bool is_digit(cudf::strings::detail::character_flags_table_type* flags_table, + string_view d_str) +{ + return all_characters_of_type(flags_table, d_str, 
string_character_types::DIGIT); +} + +/** + * @brief Returns true if all characters are decimal only + * + * @param flags_table Table required for checking character types + * @param d_str Input string to check + * @return True if characters are decimal + */ +__device__ inline bool is_decimal(cudf::strings::detail::character_flags_table_type* flags_table, + string_view d_str) +{ + return all_characters_of_type(flags_table, d_str, string_character_types::DECIMAL); +} + +/** + * @brief Returns true if all characters are spaces only + * + * @param flags_table Table required for checking character types + * @param d_str Input string to check + * @return True if characters spaces + */ +__device__ inline bool is_space(cudf::strings::detail::character_flags_table_type* flags_table, + string_view d_str) +{ + return all_characters_of_type(flags_table, d_str, string_character_types::SPACE); +} + +/** + * @brief Returns true if all characters are upper case only + * + * @param flags_table Table required for checking character types + * @param d_str Input string to check + * @return True if characters are upper case + */ +__device__ inline bool is_upper(cudf::strings::detail::character_flags_table_type* flags_table, + string_view d_str) +{ + return all_characters_of_type( + flags_table, d_str, string_character_types::UPPER, string_character_types::CASE_TYPES); +} + +/** + * @brief Returns true if all characters are lower case only + * + * @param flags_table Table required for checking character types + * @param d_str Input string to check + * @return True if characters are lower case + */ +__device__ inline bool is_lower(cudf::strings::detail::character_flags_table_type* flags_table, + string_view d_str) +{ + return all_characters_of_type( + flags_table, d_str, string_character_types::LOWER, string_character_types::CASE_TYPES); +} + +} // namespace udf +} // namespace strings +} // namespace cudf diff --git a/python/strings_udf/cpp/include/cudf/strings/udf/search.cuh 
b/python/strings_udf/cpp/include/cudf/strings/udf/search.cuh new file mode 100644 index 00000000000..ef15886f1f5 --- /dev/null +++ b/python/strings_udf/cpp/include/cudf/strings/udf/search.cuh @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace cudf { +namespace strings { +namespace udf { + +/** + * @brief Returns the number of times that the target string appears + * in the source string. + * + * If `start <= 0` the search begins at the beginning of the `source` string. + * If `end <= 0` or `end` is greater than the length of the `source` string, + * the search stops at the end of the string. + * + * @param source Source string to search + * @param target String to match within source + * @param start First character position within source to start the search + * @param end Last character position (exclusive) within source to search + * @return Number of matches + */ +__device__ inline cudf::size_type count(string_view const source, + string_view const target, + cudf::size_type start = 0, + cudf::size_type end = -1) +{ + auto const tgt_length = target.length(); + auto const src_length = source.length(); + + start = start < 0 ? 0 : start; + end = (end < 0 || end > src_length) ?
src_length : end; + + if (tgt_length == 0) { return (end - start) + 1; } + cudf::size_type count = 0; + cudf::size_type pos = start; + while (pos != cudf::string_view::npos) { + pos = source.find(target, pos, end - pos); + if (pos != cudf::string_view::npos) { + ++count; + pos += tgt_length; + } + } + return count; +} + +} // namespace udf +} // namespace strings +} // namespace cudf diff --git a/python/strings_udf/cpp/include/cudf/strings/udf/starts_with.cuh b/python/strings_udf/cpp/include/cudf/strings/udf/starts_with.cuh new file mode 100644 index 00000000000..38c609ae505 --- /dev/null +++ b/python/strings_udf/cpp/include/cudf/strings/udf/starts_with.cuh @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cudf { +namespace strings { +namespace udf { + +/** + * @brief Returns true if the beginning of the specified string + * matches the given character array. 
+ * + * @param dstr String to check + * @param tgt Character array encoded in UTF-8 + * @param bytes Number of bytes to read from `tgt` + * @return true if `tgt` matches the beginning of `dstr` + */ +__device__ inline bool starts_with(cudf::string_view const dstr, + char const* tgt, + cudf::size_type bytes) +{ + if (bytes > dstr.size_bytes()) { return false; } + auto const start_str = cudf::string_view{dstr.data(), bytes}; + return start_str.compare(tgt, bytes) == 0; +} + +/** + * @brief Returns true if the beginning of the specified string + * matches the given target string. + * + * @param dstr String to check + * @param tgt String to match + * @return true if `tgt` matches the beginning of `dstr` + */ +__device__ inline bool starts_with(cudf::string_view const dstr, cudf::string_view const& tgt) +{ + return starts_with(dstr, tgt.data(), tgt.size_bytes()); +} + +/** + * @brief Returns true if the end of the specified string + * matches the given character array. + * + * @param dstr String to check + * @param tgt Character array encoded in UTF-8 + * @param bytes Number of bytes to read from `tgt` + * @return true if `tgt` matches the end of `dstr` + */ +__device__ inline bool ends_with(cudf::string_view const dstr, + char const* tgt, + cudf::size_type bytes) +{ + if (bytes > dstr.size_bytes()) { return false; } + auto const end_str = cudf::string_view{dstr.data() + dstr.size_bytes() - bytes, bytes}; + return end_str.compare(tgt, bytes) == 0; +} + +/** + * @brief Returns true if the end of the specified string + * matches the given target string.
+ * + * @param dstr String to check + * @param tgt String to match + * @return true if `tgt` matches the end of `dstr` + */ +__device__ inline bool ends_with(cudf::string_view const dstr, cudf::string_view const& tgt) +{ + return ends_with(dstr, tgt.data(), tgt.size_bytes()); +} + +} // namespace udf +} // namespace strings +} // namespace cudf diff --git a/python/strings_udf/cpp/include/cudf/strings/udf/udf_apis.hpp b/python/strings_udf/cpp/include/cudf/strings/udf/udf_apis.hpp new file mode 100644 index 00000000000..6de9b91de08 --- /dev/null +++ b/python/strings_udf/cpp/include/cudf/strings/udf/udf_apis.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +#include + +namespace cudf { +namespace strings { +namespace udf { + +/** + * @brief Return a cudf::string_view array for the given strings column + * + * @param input Strings column to convert to a string_view array. + * @throw cudf::logic_error if input is not a strings column. 
+ */ +std::unique_ptr to_string_view_array(cudf::column_view const input); + +} // namespace udf +} // namespace strings +} // namespace cudf diff --git a/python/strings_udf/cpp/src/strings/udf/shim.cu b/python/strings_udf/cpp/src/strings/udf/shim.cu new file mode 100644 index 00000000000..656861f9cd6 --- /dev/null +++ b/python/strings_udf/cpp/src/strings/udf/shim.cu @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +using namespace cudf::strings::udf; + +extern "C" __device__ int len(int* nb_retval, void const* str) +{ + auto sv = reinterpret_cast(str); + *nb_retval = sv->length(); + return 0; +} + +extern "C" __device__ int startswith(bool* nb_retval, void const* str, void const* substr) +{ + auto str_view = reinterpret_cast(str); + auto substr_view = reinterpret_cast(substr); + + *nb_retval = starts_with(*str_view, *substr_view); + return 0; +} + +extern "C" __device__ int endswith(bool* nb_retval, void const* str, void const* substr) +{ + auto str_view = reinterpret_cast(str); + auto substr_view = reinterpret_cast(substr); + + *nb_retval = ends_with(*str_view, *substr_view); + return 0; +} + +extern "C" __device__ int contains(bool* nb_retval, void const* str, void const* substr) +{ + auto str_view = reinterpret_cast(str); + auto substr_view = reinterpret_cast(substr); + + *nb_retval = (str_view->find(*substr_view) != cudf::string_view::npos); + return 0; +} + +extern "C" __device__ int find(int* nb_retval, void const* str, void const* substr) +{ + auto str_view = reinterpret_cast(str); + auto substr_view = reinterpret_cast(substr); + + *nb_retval = str_view->find(*substr_view); + return 0; +} + +extern "C" __device__ int rfind(int* nb_retval, void const* str, void const* substr) +{ + auto str_view = reinterpret_cast(str); + auto substr_view = reinterpret_cast(substr); + + *nb_retval = str_view->rfind(*substr_view); + return 0; +} + +extern "C" __device__ int eq(bool* nb_retval, void const* str, void const* rhs) +{ + auto str_view = reinterpret_cast(str); + auto rhs_view = reinterpret_cast(rhs); + + *nb_retval = (*str_view == *rhs_view); + return 0; +} + +extern "C" __device__ int ne(bool* nb_retval, void const* str, void const* rhs) +{ + auto str_view = reinterpret_cast(str); + auto rhs_view = reinterpret_cast(rhs); + + *nb_retval = (*str_view != *rhs_view); + return 0; +} + +extern "C" __device__ int ge(bool* nb_retval, void const* 
str, void const* rhs) +{ + auto str_view = reinterpret_cast(str); + auto rhs_view = reinterpret_cast(rhs); + + *nb_retval = (*str_view >= *rhs_view); + return 0; +} + +extern "C" __device__ int le(bool* nb_retval, void const* str, void const* rhs) +{ + auto str_view = reinterpret_cast(str); + auto rhs_view = reinterpret_cast(rhs); + + *nb_retval = (*str_view <= *rhs_view); + return 0; +} + +extern "C" __device__ int gt(bool* nb_retval, void const* str, void const* rhs) +{ + auto str_view = reinterpret_cast(str); + auto rhs_view = reinterpret_cast(rhs); + + *nb_retval = (*str_view > *rhs_view); + return 0; +} + +extern "C" __device__ int lt(bool* nb_retval, void const* str, void const* rhs) +{ + auto str_view = reinterpret_cast(str); + auto rhs_view = reinterpret_cast(rhs); + + *nb_retval = (*str_view < *rhs_view); + return 0; +} + +extern "C" __device__ int pyislower(bool* nb_retval, void const* str, std::int64_t chars_table) +{ + auto str_view = reinterpret_cast(str); + + *nb_retval = is_lower( + reinterpret_cast(chars_table), *str_view); + return 0; +} + +extern "C" __device__ int pyisupper(bool* nb_retval, void const* str, std::int64_t chars_table) +{ + auto str_view = reinterpret_cast(str); + + *nb_retval = is_upper( + reinterpret_cast(chars_table), *str_view); + return 0; +} + +extern "C" __device__ int pyisspace(bool* nb_retval, void const* str, std::int64_t chars_table) +{ + auto str_view = reinterpret_cast(str); + + *nb_retval = is_space( + reinterpret_cast(chars_table), *str_view); + return 0; +} + +extern "C" __device__ int pyisdecimal(bool* nb_retval, void const* str, std::int64_t chars_table) +{ + auto str_view = reinterpret_cast(str); + + *nb_retval = is_decimal( + reinterpret_cast(chars_table), *str_view); + return 0; +} + +extern "C" __device__ int pyisnumeric(bool* nb_retval, void const* str, std::int64_t chars_table) +{ + auto str_view = reinterpret_cast(str); + + *nb_retval = is_numeric( + reinterpret_cast(chars_table), *str_view); + return 0; +} 
+ +extern "C" __device__ int pyisdigit(bool* nb_retval, void const* str, std::int64_t chars_table) +{ + auto str_view = reinterpret_cast(str); + + *nb_retval = is_digit( + reinterpret_cast(chars_table), *str_view); + return 0; +} + +extern "C" __device__ int pyisalnum(bool* nb_retval, void const* str, std::int64_t chars_table) +{ + auto str_view = reinterpret_cast(str); + + *nb_retval = is_alpha_numeric( + reinterpret_cast(chars_table), *str_view); + return 0; +} + +extern "C" __device__ int pyisalpha(bool* nb_retval, void const* str, std::int64_t chars_table) +{ + auto str_view = reinterpret_cast(str); + + *nb_retval = is_alpha( + reinterpret_cast(chars_table), *str_view); + return 0; +} + +extern "C" __device__ int pycount(int* nb_retval, void const* str, void const* substr) +{ + auto str_view = reinterpret_cast(str); + auto substr_view = reinterpret_cast(substr); + + *nb_retval = count(*str_view, *substr_view); + return 0; +} diff --git a/python/strings_udf/cpp/src/strings/udf/udf_apis.cu b/python/strings_udf/cpp/src/strings/udf/udf_apis.cu new file mode 100644 index 00000000000..dfef1be39f5 --- /dev/null +++ b/python/strings_udf/cpp/src/strings/udf/udf_apis.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include +#include +#include + +#include + +namespace cudf { +namespace strings { +namespace udf { +namespace detail { + +std::unique_ptr to_string_view_array(cudf::column_view const input, + rmm::cuda_stream_view stream) +{ + return std::make_unique( + std::move(cudf::strings::detail::create_string_vector_from_column( + cudf::strings_column_view(input), stream) + .release())); +} + +} // namespace detail + +std::unique_ptr to_string_view_array(cudf::column_view const input) +{ + return detail::to_string_view_array(input, rmm::cuda_stream_default); +} + +} // namespace udf +} // namespace strings +} // namespace cudf diff --git a/python/strings_udf/setup.cfg b/python/strings_udf/setup.cfg new file mode 100644 index 00000000000..9f29b26b5e0 --- /dev/null +++ b/python/strings_udf/setup.cfg @@ -0,0 +1,41 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +[versioneer] +VCS = git +style = pep440 +versionfile_source = strings_udf/_version.py +versionfile_build = strings_udf/_version.py +tag_prefix = v +parentdir_prefix = strings_udf- + +[isort] +line_length=79 +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +combine_as_imports=True +order_by_type=True +known_dask= + dask + distributed + dask_cuda +known_rapids= + rmm + cudf +known_first_party= + strings_udf +default_section=THIRDPARTY +sections=FUTURE,STDLIB,THIRDPARTY,DASK,RAPIDS,FIRSTPARTY,LOCALFOLDER +skip= + thirdparty + .eggs + .git + .hg + .mypy_cache + .tox + .venv + _build + buck-out + build + dist + __init__.py diff --git a/python/strings_udf/setup.py b/python/strings_udf/setup.py new file mode 100644 index 00000000000..c8cafe978f7 --- /dev/null +++ b/python/strings_udf/setup.py @@ -0,0 +1,81 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
def get_cuda_version_from_header(cuda_include_dir, delimeter=""):
    """Return the CUDA toolkit version encoded in ``cuda.h``.

    Parameters
    ----------
    cuda_include_dir : str
        Directory containing ``cuda.h``.
    delimeter : str, optional
        Separator placed between the major and minor components (parameter
        name kept as-is, typo included, for backward compatibility).

    Returns
    -------
    str
        ``"<major><delimeter><minor>"``, e.g. ``"115"`` or ``"11.5"``.

    Raises
    ------
    TypeError
        If no ``#define CUDA_VERSION`` line is found (exception type kept
        for backward compatibility with existing callers).
    """
    # CUDA_VERSION is encoded as major*1000 + minor*10, e.g. 11050 -> 11.5
    pattern = re.compile(r"#define CUDA_VERSION\s+(\d+)")

    cuda_version = None
    with open(os.path.join(cuda_include_dir, "cuda.h"), encoding="utf-8") as f:
        # Stream line-by-line instead of readlines(): cuda.h is large and we
        # only need the first matching line.
        for line in f:
            match = pattern.search(line)
            if match is not None:
                cuda_version = int(match.group(1))
                break

    if cuda_version is None:
        raise TypeError("CUDA_VERSION not found in cuda.h")
    return "%d%s%d" % (
        cuda_version // 1000,
        delimeter,
        (cuda_version % 1000) // 10,
    )
def compiler_from_ptx_file(path):
    """Parse a PTX file header and extract the CUDA version used to compile it.

    Here is an example PTX header that this function should parse:

    // Generated by NVIDIA NVVM Compiler
    //
    // Compiler Build ID: CL-30672275
    // Cuda compilation tools, release 11.5, V11.5.119
    // Based on NVVM 7

    Returns
    -------
    tuple of int
        ``(major, minor)`` CUDA toolkit version, e.g. ``(11, 5)``.

    Raises
    ------
    ValueError
        If the file does not contain a recognizable release line (the
        original raised an opaque AttributeError from ``.group`` here).
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original `open(path).read()` relied on the garbage collector).
    with open(path) as f:
        header = f.read()

    match = re.search(r"Cuda compilation tools, release ([0-9\.]+)", header)
    if match is None:
        raise ValueError(f"Could not parse CUDA version from PTX file {path}")
    major, minor = match.group(1).split(".")
    return int(major), int(minor)
+ ) + sms = [ + os.path.basename(f).rstrip(".ptx").lstrip("shim_") for f in files + ] + selected_sm = max(sm for sm in sms if sm < cc) + ptxpath = os.path.join( + os.path.dirname(__file__), f"shim_{selected_sm}.ptx" + ) + + if driver_version >= compiler_from_ptx_file(ptxpath): + ENABLED = True + else: + del ptxpath + +__version__ = _version.get_versions()["version"] diff --git a/python/strings_udf/strings_udf/_lib/CMakeLists.txt b/python/strings_udf/strings_udf/_lib/CMakeLists.txt new file mode 100644 index 00000000000..91069a43891 --- /dev/null +++ b/python/strings_udf/strings_udf/_lib/CMakeLists.txt @@ -0,0 +1,25 @@ +# ============================================================================= +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================= + +set(cython_sources cudf_jit_udf.pyx tables.pyx) +set(linked_libraries cudf::cudf cudf_strings_udf) +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" +) + +foreach(cython_module IN LISTS _RAPIDS_CYTHON_CREATED_TARGETS) + set_target_properties(${cython_module} PROPERTIES INSTALL_RPATH "\$ORIGIN;\$ORIGIN/cpp") +endforeach() diff --git a/python/strings_udf/strings_udf/_lib/__init__.py b/python/strings_udf/strings_udf/_lib/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/strings_udf/strings_udf/_lib/cpp/__init__.pxd b/python/strings_udf/strings_udf/_lib/cpp/__init__.pxd new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/strings_udf/strings_udf/_lib/cpp/strings_udf.pxd b/python/strings_udf/strings_udf/_lib/cpp/strings_udf.pxd new file mode 100644 index 00000000000..fb8e3a949bf --- /dev/null +++ b/python/strings_udf/strings_udf/_lib/cpp/strings_udf.pxd @@ -0,0 +1,20 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
def to_string_view_array(Column strings_col):
    """Copy ``strings_col`` into a device array of cudf::string_view
    structs, returned wrapped in a cudf Buffer.

    NOTE(review): the views presumably alias the column's character data,
    so the column should be kept alive while the Buffer is in use --
    confirm against the C++ implementation.
    """
    cdef column_view source_view = strings_col.view()
    cdef unique_ptr[device_buffer] managed_buffer

    # The C++ call allocates and fills the device array; release the GIL
    # while it runs.
    with nogil:
        managed_buffer = move(cpp_to_string_view_array(source_view))

    owning_buffer = DeviceBuffer.c_from_unique_ptr(move(managed_buffer))
    return Buffer(owning_buffer)
def get_character_flags_table_ptr():
    """Return the device address of libcudf's character-flags lookup table
    as a ``numpy.int64`` (suitable for passing to numba-jitted shims).
    """
    cdef const uint8_t* tbl_ptr = cpp_get_character_flags_table()
    # NOTE(review): the extraction of this hunk appears to have dropped a
    # <uintptr_t> cast here; without it Cython cannot convert a raw pointer
    # to a Python integer. Reconstructed -- confirm against the original.
    return np.int64(<uintptr_t> tbl_ptr)
class StrViewArgHandler:
    """
    Kernel-argument preprocessor for string UDFs.

    During numba's kernel-launch preprocessing (see
    numba.cuda.compiler._prepare_args) incoming arguments are marshalled
    according to the type they were jitted with, but numba only knows how
    to handle its built-in array types natively. String UDF kernels are
    jitted against ``string_view*`` arguments, which numba cannot marshal
    on its own, so this handler rewrites any ``CPointer(StringView)``
    argument into a raw ``uint64`` device pointer that numba does know
    how to pass.
    """

    def prepare_args(self, ty, val, **kwargs):
        # Anything that is not a pointer-to-string_view passes through
        # untouched.
        is_strview_ptr = isinstance(ty, types.CPointer) and isinstance(
            ty.dtype, StringView
        )
        if not is_strview_ptr:
            return ty, val
        return types.uint64, val.ptr


str_view_arg_handler = StrViewArgHandler()
def create_binary_attr(attrname, retty):
    """
    Build a ``resolve_<attrname>`` method for an AttributeTemplate.

    Wraps numba's low level extension API with the boilerplate needed to
    register a binary function of two string objects as a bound-method
    attribute of one of them (``string.func(other)``) returning ``retty``.
    """

    class StringViewBinaryAttr(AbstractTemplate):
        key = f"StringView.{attrname}"

        def generic(self, args, kws):
            return nb_signature(retty, string_view, recvr=self.this)

    def resolver(self, mod):
        return types.BoundFunction(StringViewBinaryAttr, string_view)

    return resolver
+ """ + + class StringViewIdentifierAttr(AbstractTemplate): + key = f"StringView.{attrname}" + + def generic(self, args, kws): + return nb_signature(types.boolean, recvr=self.this) + + def attr(self, mod): + return types.BoundFunction(StringViewIdentifierAttr, string_view) + + return attr + + +class StringViewCount(AbstractTemplate): + key = "StringView.count" + + def generic(self, args, kws): + return nb_signature(size_type, string_view, recvr=self.this) + + +@cuda_decl_registry.register_attr +class StringViewAttrs(AttributeTemplate): + key = string_view + + def resolve_count(self, mod): + return types.BoundFunction(StringViewCount, string_view) + + +# Build attributes for `MaskedType(string_view)` +bool_binary_funcs = ["startswith", "endswith"] +int_binary_funcs = ["find", "rfind"] +id_unary_funcs = [ + "isalpha", + "isalnum", + "isdecimal", + "isdigit", + "isupper", + "islower", + "isspace", + "isnumeric", +] + +for func in bool_binary_funcs: + setattr( + StringViewAttrs, + f"resolve_{func}", + create_binary_attr(func, types.boolean), + ) + +for func in int_binary_funcs: + setattr( + StringViewAttrs, f"resolve_{func}", create_binary_attr(func, size_type) + ) + +for func in id_unary_funcs: + setattr(StringViewAttrs, f"resolve_{func}", create_identifier_attr(func)) + +cuda_decl_registry.register_attr(StringViewAttrs) diff --git a/python/strings_udf/strings_udf/_version.py b/python/strings_udf/strings_udf/_version.py new file mode 100644 index 00000000000..14ff9ec314d --- /dev/null +++ b/python/strings_udf/strings_udf/_version.py @@ -0,0 +1,711 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. 
# This file is released into the public domain. Generated by
# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer)

"""Git implementation of _version.py."""

import errno
import functools
import os
import re
import subprocess
import sys
from typing import Callable, Dict


def get_keywords():
    """Get the keywords needed to look up the version information."""
    # these strings will be replaced by git during git-archive.
    # setup.py/versioneer.py will grep for the variable names, so they must
    # each be defined on a line of their own. _version.py will just call
    # get_keywords().
    git_refnames = "$Format:%d$"
    git_full = "$Format:%H$"
    git_date = "$Format:%ci$"
    return {"refnames": git_refnames, "full": git_full, "date": git_date}


class VersioneerConfig:
    """Container for Versioneer configuration parameters."""


def get_config():
    """Create, populate and return the VersioneerConfig() object."""
    # these strings are filled in when 'setup.py versioneer' creates
    # _version.py
    cfg = VersioneerConfig()
    cfg.VCS = "git"
    cfg.style = "pep440"
    cfg.tag_prefix = "v"
    cfg.parentdir_prefix = "strings_udf-"
    cfg.versionfile_source = "strings_udf/_version.py"
    cfg.verbose = False
    return cfg


class NotThisMethod(Exception):
    """Exception raised if a method is not valid for the current scenario."""


LONG_VERSION_PY: Dict[str, str] = {}
HANDLERS: Dict[str, Dict[str, Callable]] = {}


def register_vcs_handler(vcs, method):  # decorator
    """Create decorator to mark a method as the handler of a VCS."""

    def decorate(func):
        """Store func in HANDLERS[vcs][method]."""
        HANDLERS.setdefault(vcs, {})[method] = func
        return func

    return decorate


def run_command(
    commands, args, cwd=None, verbose=False, hide_stderr=False, env=None
):
    """Try each executable name in `commands` until one launches.

    Returns (stdout, returncode); (None, None) when no executable could
    be found or launched, (None, returncode) when the command ran but
    exited non-zero.
    """
    assert isinstance(commands, list)

    popen_kwargs = {}
    if sys.platform == "win32":
        # This hides the console window if pythonw.exe is used
        startupinfo = subprocess.STARTUPINFO()
        startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
        popen_kwargs["startupinfo"] = startupinfo

    proc = None
    for command in commands:
        dispcmd = str([command] + args)
        try:
            # remember shell=False, so use git.cmd on windows, not just git
            proc = subprocess.Popen(
                [command] + args,
                cwd=cwd,
                env=env,
                stdout=subprocess.PIPE,
                stderr=(subprocess.PIPE if hide_stderr else None),
                **popen_kwargs,
            )
            break
        except OSError:
            err = sys.exc_info()[1]
            if err.errno == errno.ENOENT:
                continue
            if verbose:
                print("unable to run %s" % dispcmd)
                print(err)
            return None, None
    else:
        if verbose:
            print("unable to find command, tried %s" % (commands,))
        return None, None

    stdout = proc.communicate()[0].strip().decode()
    if proc.returncode != 0:
        if verbose:
            print("unable to run %s (error)" % dispcmd)
            print("stdout was %s" % stdout)
        return None, proc.returncode
    return stdout, proc.returncode


def versions_from_parentdir(parentdir_prefix, root, verbose):
    """Try to determine the version from the parent directory name.

    Source tarballs conventionally unpack into a directory that includes
    both the project name and a version string. We will also support
    searching up two directory levels for an appropriately named parent
    directory.
    """
    tried = []
    for _ in range(3):
        dirname = os.path.basename(root)
        if dirname.startswith(parentdir_prefix):
            return {
                "version": dirname[len(parentdir_prefix) :],
                "full-revisionid": None,
                "dirty": False,
                "error": None,
                "date": None,
            }
        tried.append(root)
        root = os.path.dirname(root)  # up a level

    if verbose:
        print(
            "Tried directories %s but none started with prefix %s"
            % (str(tried), parentdir_prefix)
        )
    raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
    """Extract version information from the given file.

    The code embedded in _version.py can just fetch the value of these
    keywords directly. When used from setup.py, we don't want to import
    _version.py, so the values are pulled out with a regexp instead. This
    function is not used from _version.py itself.
    """
    keywords = {}
    # map "<variable> =" line prefixes to the keyword each one carries
    wanted = {
        "git_refnames =": "refnames",
        "git_full =": "full",
        "git_date =": "date",
    }
    try:
        with open(versionfile_abs, "r") as fobj:
            for line in fobj:
                stripped = line.strip()
                for prefix, key in wanted.items():
                    if stripped.startswith(prefix):
                        mo = re.search(r'=\s*"(.*)"', line)
                        if mo:
                            keywords[key] = mo.group(1)
    except OSError:
        # missing/unreadable file: return whatever was collected (nothing)
        pass
    return keywords
+ date = date.splitlines()[-1] + + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r"\d", r)} + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. 
"2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix) :] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r"\d", r): + continue + if verbose: + print("picking %s" % r) + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner( + GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True + ) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner( + GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root + ) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[: git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + if not mo: + # unparsable. Maybe git-describe is misbehaving? + pieces["error"] = ( + "unable to parse git-describe output: '%s'" % describe_out + ) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix) :] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
def plus_or_dot(pieces):
    """Return "." if the version already has a local segment, else "+".

    Robustness fix: ``pieces["closest-tag"]`` may be present but ``None``
    (the no-tags case), in which case ``dict.get``'s default is *not*
    applied; coerce with ``or ""`` so the membership test cannot raise
    TypeError.
    """
    if "+" in (pieces.get("closest-tag") or ""):
        return "."
    return "+"


def render_pep440(pieces):
    """Build up version string, with post-release "local version identifier".

    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty

    Exceptions:
    1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += plus_or_dot(pieces)
            rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                rendered += ".dirty"
    else:
        # exception #1
        rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered


def render_pep440_branch(pieces):
    """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .

    The ".dev0" means not master branch. Note that .dev0 sorts backwards
    (a feature branch will appear "older" than the master branch).

    Exceptions:
    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            if pieces["branch"] != "master":
                rendered += ".dev0"
            rendered += plus_or_dot(pieces)
            rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                rendered += ".dirty"
    else:
        # exception #1
        rendered = "0"
        if pieces["branch"] != "master":
            rendered += ".dev0"
        rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered
+ + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces): + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + if pieces["distance"]: + # update the post release segment + tag_version, post_version = pep440_split_post( + pieces["closest-tag"] + ) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % ( + post_version + 1, + pieces["distance"], + ) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] + else: + # exception #1 + rendered = "0.post0.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. 
+ + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords( + get_keywords(), cfg.tag_prefix, verbose + ) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for _ in cfg.versionfile_source.split("/"): + root = os.path.dirname(root) + except NameError: + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/python/strings_udf/strings_udf/lowering.py b/python/strings_udf/strings_udf/lowering.py new file mode 100644 index 00000000000..fd965a7a187 --- /dev/null +++ b/python/strings_udf/strings_udf/lowering.py @@ -0,0 +1,287 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +import operator +from functools import partial + +from numba import cuda, types +from numba.core import cgutils +from numba.core.typing import signature as nb_signature +from numba.cuda.cudadrv import nvvm +from numba.cuda.cudaimpl import ( + lower as cuda_lower, + registry as cuda_lowering_registry, +) + +from strings_udf._lib.tables import get_character_flags_table_ptr +from strings_udf._typing import size_type, string_view + +character_flags_table_ptr = get_character_flags_table_ptr() + + +# read-only functions +# We will provide only one overload for this set of functions, which will +# expect a string_view. 
When a literal is encountered, numba will promote it to +# a string_view whereas when a dstring is encountered, numba will convert it to +# a view via its native view() method. + +_STR_VIEW_PTR = types.CPointer(string_view) + + +# CUDA function declarations +_string_view_len = cuda.declare_device("len", size_type(_STR_VIEW_PTR)) + + +def _declare_binary_func(lhs, rhs, out, name): + # Declare a binary function + return cuda.declare_device( + name, + out(lhs, rhs), + ) + + +# A binary function of the form f(string, string) -> bool +_declare_bool_str_str_func = partial( + _declare_binary_func, _STR_VIEW_PTR, _STR_VIEW_PTR, types.boolean +) + +_declare_size_type_str_str_func = partial( + _declare_binary_func, _STR_VIEW_PTR, _STR_VIEW_PTR, size_type +) + +_string_view_contains = _declare_bool_str_str_func("contains") +_string_view_eq = _declare_bool_str_str_func("eq") +_string_view_ne = _declare_bool_str_str_func("ne") +_string_view_ge = _declare_bool_str_str_func("ge") +_string_view_le = _declare_bool_str_str_func("le") +_string_view_gt = _declare_bool_str_str_func("gt") +_string_view_lt = _declare_bool_str_str_func("lt") +_string_view_startswith = _declare_bool_str_str_func("startswith") +_string_view_endswith = _declare_bool_str_str_func("endswith") +_string_view_find = _declare_size_type_str_str_func("find") +_string_view_rfind = _declare_size_type_str_str_func("rfind") +_string_view_contains = _declare_bool_str_str_func("contains") + + +# A binary function of the form f(string, int) -> bool +_declare_bool_str_int_func = partial( + _declare_binary_func, _STR_VIEW_PTR, types.int64, types.boolean +) + + +_string_view_isdigit = _declare_bool_str_int_func("pyisdigit") +_string_view_isalnum = _declare_bool_str_int_func("pyisalnum") +_string_view_isalpha = _declare_bool_str_int_func("pyisalpha") +_string_view_isdecimal = _declare_bool_str_int_func("pyisdecimal") +_string_view_isnumeric = _declare_bool_str_int_func("pyisnumeric") +_string_view_isspace = 
_declare_bool_str_int_func("pyisspace") +_string_view_isupper = _declare_bool_str_int_func("pyisupper") +_string_view_islower = _declare_bool_str_int_func("pyislower") + + +_string_view_count = cuda.declare_device( + "pycount", + size_type(_STR_VIEW_PTR, _STR_VIEW_PTR), +) + + +# casts +@cuda_lowering_registry.lower_cast(types.StringLiteral, string_view) +def cast_string_literal_to_string_view(context, builder, fromty, toty, val): + """ + Cast a literal to a string_view + """ + # create an empty string_view + sv = cgutils.create_struct_proxy(string_view)(context, builder) + + # set the empty strview data pointer to point to the literal value + s = context.insert_const_string(builder.module, fromty.literal_value) + sv.data = context.insert_addrspace_conv( + builder, s, nvvm.ADDRSPACE_CONSTANT + ) + sv.length = context.get_constant(size_type, len(fromty.literal_value)) + sv.bytes = context.get_constant( + size_type, len(fromty.literal_value.encode("UTF-8")) + ) + + return sv._getvalue() + + +# String function implementations +def call_len_string_view(st): + return _string_view_len(st) + + +@cuda_lower(len, string_view) +def len_impl(context, builder, sig, args): + sv_ptr = builder.alloca(args[0].type) + builder.store(args[0], sv_ptr) + result = context.compile_internal( + builder, + call_len_string_view, + nb_signature(size_type, _STR_VIEW_PTR), + (sv_ptr,), + ) + + return result + + +def create_binary_string_func(binary_func, retty): + """ + Provide a wrapper around numba's low-level extension API which + produces the boilerplate needed to implement a binary function + of two strings. 
+ """ + + def deco(cuda_func): + @cuda_lower(binary_func, string_view, string_view) + def binary_func_impl(context, builder, sig, args): + lhs_ptr = builder.alloca(args[0].type) + rhs_ptr = builder.alloca(args[1].type) + + builder.store(args[0], lhs_ptr) + builder.store(args[1], rhs_ptr) + result = context.compile_internal( + builder, + cuda_func, + nb_signature(retty, _STR_VIEW_PTR, _STR_VIEW_PTR), + (lhs_ptr, rhs_ptr), + ) + + return result + + return binary_func_impl + + return deco + + +@create_binary_string_func(operator.contains, types.boolean) +def contains_impl(st, substr): + return _string_view_contains(st, substr) + + +@create_binary_string_func(operator.eq, types.boolean) +def eq_impl(st, rhs): + return _string_view_eq(st, rhs) + + +@create_binary_string_func(operator.ne, types.boolean) +def ne_impl(st, rhs): + return _string_view_ne(st, rhs) + + +@create_binary_string_func(operator.ge, types.boolean) +def ge_impl(st, rhs): + return _string_view_ge(st, rhs) + + +@create_binary_string_func(operator.le, types.boolean) +def le_impl(st, rhs): + return _string_view_le(st, rhs) + + +@create_binary_string_func(operator.gt, types.boolean) +def gt_impl(st, rhs): + return _string_view_gt(st, rhs) + + +@create_binary_string_func(operator.lt, types.boolean) +def lt_impl(st, rhs): + return _string_view_lt(st, rhs) + + +@create_binary_string_func("StringView.startswith", types.boolean) +def startswith_impl(sv, substr): + return _string_view_startswith(sv, substr) + + +@create_binary_string_func("StringView.endswith", types.boolean) +def endswith_impl(sv, substr): + return _string_view_endswith(sv, substr) + + +@create_binary_string_func("StringView.count", size_type) +def count_impl(st, substr): + return _string_view_count(st, substr) + + +@create_binary_string_func("StringView.find", size_type) +def find_impl(sv, substr): + return _string_view_find(sv, substr) + + +@create_binary_string_func("StringView.rfind", size_type) +def rfind_impl(sv, substr): + return 
_string_view_rfind(sv, substr) + + +def create_unary_identifier_func(id_func): + """ + Provide a wrapper around numba's low-level extension API which + produces the boilerplate needed to implement a unary function + of a string. + """ + + def deco(cuda_func): + @cuda_lower(id_func, string_view) + def id_func_impl(context, builder, sig, args): + str_ptr = builder.alloca(args[0].type) + builder.store(args[0], str_ptr) + + # Lookup table required for conversion functions + # must be resolved at runtime after context initialization, + # therefore cannot be a global variable + tbl_ptr = context.get_constant( + types.int64, character_flags_table_ptr + ) + result = context.compile_internal( + builder, + cuda_func, + nb_signature(types.boolean, _STR_VIEW_PTR, types.int64), + (str_ptr, tbl_ptr), + ) + + return result + + return id_func_impl + + return deco + + +@create_unary_identifier_func("StringView.isdigit") +def isdigit_impl(st, tbl): + return _string_view_isdigit(st, tbl) + + +@create_unary_identifier_func("StringView.isalnum") +def isalnum_impl(st, tbl): + return _string_view_isalnum(st, tbl) + + +@create_unary_identifier_func("StringView.isalpha") +def isalpha_impl(st, tbl): + return _string_view_isalpha(st, tbl) + + +@create_unary_identifier_func("StringView.isnumeric") +def isnumeric_impl(st, tbl): + return _string_view_isnumeric(st, tbl) + + +@create_unary_identifier_func("StringView.isdecimal") +def isdecimal_impl(st, tbl): + return _string_view_isdecimal(st, tbl) + + +@create_unary_identifier_func("StringView.isspace") +def isspace_impl(st, tbl): + return _string_view_isspace(st, tbl) + + +@create_unary_identifier_func("StringView.isupper") +def isupper_impl(st, tbl): + return _string_view_isupper(st, tbl) + + +@create_unary_identifier_func("StringView.islower") +def islower_impl(st, tbl): + return _string_view_islower(st, tbl) diff --git a/python/strings_udf/strings_udf/tests/test_string_udfs.py b/python/strings_udf/strings_udf/tests/test_string_udfs.py new 
file mode 100644 index 00000000000..9038f4cc79a --- /dev/null +++ b/python/strings_udf/strings_udf/tests/test_string_udfs.py @@ -0,0 +1,249 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +import numba +import numpy as np +import pandas as pd +import pytest +from numba import cuda +from numba.core.typing import signature as nb_signature +from numba.types import CPointer, void + +import cudf +from cudf.testing._utils import assert_eq + +import strings_udf +from strings_udf._lib.cudf_jit_udf import to_string_view_array +from strings_udf._typing import str_view_arg_handler, string_view + +if not strings_udf.ENABLED: + pytest.skip("Strings UDF not enabled.", allow_module_level=True) + + +def get_kernel(func, dtype): + """ + Create a kernel for testing a single scalar string function + Allocates an output vector with a dtype specified by the caller + The returned kernel executes the input function on each data + element of the input and returns the output into the output vector + """ + + func = cuda.jit(device=True)(func) + outty = numba.np.numpy_support.from_dtype(dtype) + sig = nb_signature(void, CPointer(string_view), outty[::1]) + + @cuda.jit( + sig, link=[strings_udf.ptxpath], extensions=[str_view_arg_handler] + ) + def kernel(input_strings, output_col): + id = cuda.grid(1) + if id < len(output_col): + st = input_strings[id] + result = func(st) + output_col[id] = result + + return kernel + + +def run_udf_test(data, func, dtype): + """ + Run a test kernel on a set of input data + Converts the input data to a cuDF column and subsequently + to an array of cudf::string_view objects. 
It then creates + a CUDA kernel using get_kernel which calls the input function, + and then assembles the result back into a cuDF series before + comparing it with the equivalent pandas result + """ + dtype = np.dtype(dtype) + cudf_column = cudf.core.column.as_column(data) + str_view_ary = to_string_view_array(cudf_column) + + output_ary = cudf.core.column.column_empty(len(data), dtype=dtype) + + kernel = get_kernel(func, dtype) + kernel.forall(len(data))(str_view_ary, output_ary) + got = cudf.Series(output_ary, dtype=dtype) + expect = pd.Series(data).apply(func) + assert_eq(expect, got, check_dtype=False) + + +@pytest.fixture(scope="module") +def data(): + return [ + "abc", + "ABC", + "AbC", + "123", + "123aBc", + "123@.!", + "", + "rapids ai", + "gpu", + "True", + "False", + "1.234", + ".123a", + "0.013", + "1.0", + "01", + "20010101", + "cudf", + "cuda", + "gpu", + ] + + +@pytest.fixture(params=["cudf", "cuda", "gpucudf", "abc"]) +def rhs(request): + return request.param + + +@pytest.fixture(params=["c", "cu", "2", "abc", "", "gpu"]) +def substr(request): + return request.param + + +def test_string_udf_eq(data, rhs): + def func(st): + return st == rhs + + run_udf_test(data, func, "bool") + + +def test_string_udf_ne(data, rhs): + def func(st): + return st != rhs + + run_udf_test(data, func, "bool") + + +def test_string_udf_ge(data, rhs): + def func(st): + return st >= rhs + + run_udf_test(data, func, "bool") + + +def test_string_udf_le(data, rhs): + def func(st): + return st <= rhs + + run_udf_test(data, func, "bool") + + +def test_string_udf_gt(data, rhs): + def func(st): + return st > rhs + + run_udf_test(data, func, "bool") + + +def test_string_udf_lt(data, rhs): + def func(st): + return st < rhs + + run_udf_test(data, func, "bool") + + +def test_string_udf_contains(data, substr): + def func(st): + return substr in st + + run_udf_test(data, func, "bool") + + +def test_string_udf_count(data, substr): + def func(st): + return st.count(substr) + + 
run_udf_test(data, func, "int32") + + +def test_string_udf_find(data, substr): + def func(st): + return st.find(substr) + + run_udf_test(data, func, "int32") + + +def test_string_udf_endswith(data, substr): + def func(st): + return st.endswith(substr) + + run_udf_test(data, func, "bool") + + +def test_string_udf_isalnum(data): + def func(st): + return st.isalnum() + + run_udf_test(data, func, "bool") + + +def test_string_udf_isalpha(data): + def func(st): + return st.isalpha() + + run_udf_test(data, func, "bool") + + +def test_string_udf_isdecimal(data): + def func(st): + return st.isdecimal() + + run_udf_test(data, func, "bool") + + +def test_string_udf_isdigit(data): + def func(st): + return st.isdigit() + + run_udf_test(data, func, "bool") + + +def test_string_udf_islower(data): + def func(st): + return st.islower() + + run_udf_test(data, func, "bool") + + +def test_string_udf_isnumeric(data): + def func(st): + return st.isnumeric() + + run_udf_test(data, func, "bool") + + +def test_string_udf_isspace(data): + def func(st): + return st.isspace() + + run_udf_test(data, func, "bool") + + +def test_string_udf_isupper(data): + def func(st): + return st.isupper() + + run_udf_test(data, func, "bool") + + +def test_string_udf_len(data): + def func(st): + return len(st) + + run_udf_test(data, func, "int64") + + +def test_string_udf_rfind(data, substr): + def func(st): + return st.rfind(substr) + + run_udf_test(data, func, "int32") + + +def test_string_udf_startswith(data, substr): + def func(st): + return st.startswith(substr) + + run_udf_test(data, func, "bool") diff --git a/python/strings_udf/versioneer.py b/python/strings_udf/versioneer.py new file mode 100644 index 00000000000..6194b6a5698 --- /dev/null +++ b/python/strings_udf/versioneer.py @@ -0,0 +1,2245 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +# Version: 0.23 + +"""The Versioneer - like a rocketeer, but for versions. + +The Versioneer +============== + +* like a rocketeer, but for versions! 
+* https://github.com/python-versioneer/python-versioneer +* Brian Warner +* License: Public Domain (CC0-1.0) +* Compatible with: Python 3.7, 3.8, 3.9, 3.10 and pypy3 +* [![Latest Version][pypi-image]][pypi-url] +* [![Build Status][travis-image]][travis-url] + +This is a tool for managing a recorded version number in +distutils/setuptools-based python projects. The goal is to +remove the tedious and error-prone "update the embedded version string" +step from your release process. Making a new release should be as easy +as recording a new tag in your version-control +system, and maybe making new tarballs. + + +## Quick Install + +* `pip install versioneer` to somewhere in your $PATH +* add a `[versioneer]` section to your setup.cfg (see [Install](INSTALL.md)) +* run `versioneer install` in your source tree, commit the results +* Verify version information with `python setup.py version` + +## Version Identifiers + +Source trees come from a variety of places: + +* a version-control system checkout (mostly used by developers) +* a nightly tarball, produced by build automation +* a snapshot tarball, produced by a web-based VCS browser, like github's + "tarball from tag" feature +* a release tarball, produced by "setup.py sdist", distributed through PyPI + +Within each source tree, the version identifier (either a string or a number, +this tool is format-agnostic) can come from a variety of places: + +* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows + about recent "tags" and an absolute revision-id +* the name of the directory into which the tarball was unpacked +* an expanded VCS keyword ($Id$, etc) +* a `_version.py` created by some earlier build step + +For released software, the version identifier is closely related to a VCS +tag. Some projects use tag names that include more than just the version +string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool +needs to strip the tag prefix to extract the version identifier. 
For +unreleased software (between tags), the version identifier should provide +enough information to help developers recreate the same tree, while also +giving them an idea of roughly how old the tree is (after version 1.2, before +version 1.3). Many VCS systems can report a description that captures this, +for example `git describe --tags --dirty --always` reports things like +"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the +0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has +uncommitted changes). + +The version identifier is used for multiple purposes: + +* to allow the module to self-identify its version: `myproject.__version__` +* to choose a name and prefix for a 'setup.py sdist' tarball + +## Theory of Operation + +Versioneer works by adding a special `_version.py` file into your source +tree, where your `__init__.py` can import it. This `_version.py` knows how to +dynamically ask the VCS tool for version information at import time. + +`_version.py` also contains `$Revision$` markers, and the installation +process marks `_version.py` to have this marker rewritten with a tag name +during the `git archive` command. As a result, generated tarballs will +contain enough information to get the proper version. + +To allow `setup.py` to compute a version too, a `versioneer.py` is added to +the top level of your source tree, next to `setup.py` and the `setup.cfg` +that configures it. This overrides several distutils/setuptools commands to +compute the version when invoked, and changes `setup.py build` and `setup.py +sdist` to replace `_version.py` with a small static file that contains just +the generated version data. + +## Installation + +See [INSTALL.md](./INSTALL.md) for detailed installation instructions. + +## Version-String Flavors + +Code which uses Versioneer can learn about its version string at runtime by +importing `_version` from your main `__init__.py` file and running the +`get_versions()` function. 
From the "outside" (e.g. in `setup.py`), you can +import the top-level `versioneer.py` and run `get_versions()`. + +Both functions return a dictionary with different flavors of version +information: + +* `['version']`: A condensed version string, rendered using the selected + style. This is the most commonly used value for the project's version + string. The default "pep440" style yields strings like `0.11`, + `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section + below for alternative styles. + +* `['full-revisionid']`: detailed revision identifier. For Git, this is the + full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". + +* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the + commit date in ISO 8601 format. This will be None if the date is not + available. + +* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that + this is only accurate if run in a VCS checkout, otherwise it is likely to + be False or None + +* `['error']`: if the version string could not be computed, this will be set + to a string describing the problem, otherwise it will be None. It may be + useful to throw an exception in setup.py if this is set, to avoid e.g. + creating tarballs with a version string of "unknown". + +Some variants are more useful than others. Including `full-revisionid` in a +bug report should allow developers to reconstruct the exact code being tested +(or indicate the presence of local changes that should be shared with the +developers). `version` is suitable for display in an "about" box or a CLI +`--version` output: it can be easily compared against release notes and lists +of bugs fixed in various releases. 
+ +The installer adds the following text to your `__init__.py` to place a basic +version in `YOURPROJECT.__version__`: + + from ._version import get_versions + __version__ = get_versions()['version'] + del get_versions + +## Styles + +The setup.cfg `style=` configuration controls how the VCS information is +rendered into a version string. + +The default style, "pep440", produces a PEP440-compliant string, equal to the +un-prefixed tag name for actual releases, and containing an additional "local +version" section with more detail for in-between builds. For Git, this is +TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags +--dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the +tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and +that this commit is two revisions ("+2") beyond the "0.11" tag. For released +software (exactly equal to a known tag), the identifier will only contain the +stripped tag, e.g. "0.11". + +Other styles are available. See [details.md](details.md) in the Versioneer +source tree for descriptions. + +## Debugging + +Versioneer tries to avoid fatal errors: if something goes wrong, it will tend +to return a version of "0+unknown". To investigate the problem, run `setup.py +version`, which will run the version-lookup code in a verbose mode, and will +display the full contents of `get_versions()` (including the `error` string, +which may help identify what went wrong). + +## Known Limitations + +Some situations are known to cause problems for Versioneer. This details the +most significant ones. More can be found on Github +[issues page](https://github.com/python-versioneer/python-versioneer/issues). + +### Subprojects + +Versioneer has limited support for source trees in which `setup.py` is not in +the root directory (e.g. `setup.py` and `.git/` are *not* siblings). 
There are +two common reasons why `setup.py` might not be in the root: + +* Source trees which contain multiple subprojects, such as + [Buildbot](https://github.com/buildbot/buildbot), which contains both + "master" and "slave" subprojects, each with their own `setup.py`, + `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI + distributions (and upload multiple independently-installable tarballs). +* Source trees whose main purpose is to contain a C library, but which also + provide bindings to Python (and perhaps other languages) in subdirectories. + +Versioneer will look for `.git` in parent directories, and most operations +should get the right version string. However `pip` and `setuptools` have bugs +and implementation details which frequently cause `pip install .` from a +subproject directory to fail to find a correct version string (so it usually +defaults to `0+unknown`). + +`pip install --editable .` should work correctly. `setup.py install` might +work too. + +Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in +some later version. + +[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) +is tracking this issue. The discussion in +[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) +describes the issue from the Versioneer side in more detail. +[pip PR#3176](https://github.com/pypa/pip/pull/3176) and +[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve +pip to let Versioneer work correctly. + +Versioneer-0.16 and earlier only looked for a `.git` directory next to the +`setup.cfg`, so subprojects were completely unsupported with those releases. + +### Editable installs with setuptools <= 18.5 + +`setup.py develop` and `pip install --editable .` allow you to install a +project into a virtualenv once, then continue editing the source code (and +test) without re-installing after every change.
+ +"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a +convenient way to specify executable scripts that should be installed along +with the python package. + +These both work as expected when using modern setuptools. When using +setuptools-18.5 or earlier, however, certain operations will cause +`pkg_resources.DistributionNotFound` errors when running the entrypoint +script, which must be resolved by re-installing the package. This happens +when the install happens with one version, then the egg_info data is +regenerated while a different version is checked out. Many setup.py commands +cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into +a different virtualenv), so this can be surprising. + +[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) +describes this one, but upgrading to a newer version of setuptools should +probably resolve it. + + +## Updating Versioneer + +To upgrade your project to a new release of Versioneer, do the following: + +* install the new Versioneer (`pip install -U versioneer` or equivalent) +* edit `setup.cfg`, if necessary, to include any new configuration settings + indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install` in your source tree, to replace + `SRC/_version.py` +* commit any changed files + +## Future Directions + +This tool is designed to make it easily extended to other version-control +systems: all VCS-specific components are in separate directories like +src/git/ . The top-level `versioneer.py` script is assembled from these +components by running make-versioneer.py . In the future, make-versioneer.py +will take a VCS name as an argument, and will construct a version of +`versioneer.py` that is specific to the given VCS. It might also take the +configuration arguments that are currently provided manually during +installation by editing setup.py . 
Alternatively, it might go the other +direction and include code from all supported VCS systems, reducing the +number of intermediate scripts. + +## Similar projects + +* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored + build-time dependency +* [minver](https://github.com/jbweston/miniver) - a lightweight + reimplementation of versioneer +* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based + setuptools plugin + +## License + +To make Versioneer easier to embed, all its code is dedicated to the public +domain. The `_version.py` that it creates is also in the public domain. +Specifically, both are released under the Creative Commons "Public Domain +Dedication" license (CC0-1.0), as described in +https://creativecommons.org/publicdomain/zero/1.0/ . + +[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg +[pypi-url]: https://pypi.python.org/pypi/versioneer/ +[travis-image]: +https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg +[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer + +""" +# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring +# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements +# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error +# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with +# pylint:disable=attribute-defined-outside-init,too-many-arguments + +import configparser +import errno +import functools +import json +import os +import re +import subprocess +import sys +from typing import Callable, Dict + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_root(): + """Get the project root directory. + + We require that all commands are run from the project root, i.e. the + directory that contains setup.py, setup.cfg, and versioneer.py . 
+ """ + root = os.path.realpath(os.path.abspath(os.getcwd())) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + # allow 'python path/to/setup.py COMMAND' + root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) + raise VersioneerBadRootError(err) + try: + # Certain runtime workflows (setup.py install/develop in a setuptools + # tree) execute all dependencies in a single python process, so + # "versioneer" may be imported multiple times, and python's shared + # module-import table will cache the first one. So we can't use + # os.path.dirname(__file__), as that will find whichever + # versioneer.py was first imported, even in later projects. + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) + vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) + if me_dir != vsr_dir: + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py) + ) + except NameError: + pass + return root + + +def get_config_from_root(root): + """Read the project setup.cfg file to determine Versioneer config.""" + # This might raise OSError (if setup.cfg is missing), or + # configparser.NoSectionError (if it lacks a [versioneer] section), or + # configparser.NoOptionError (if it lacks "VCS="). 
See the docstring at + # the top of versioneer.py for instructions on writing your setup.cfg . + setup_cfg = os.path.join(root, "setup.cfg") + parser = configparser.ConfigParser() + with open(setup_cfg, "r") as cfg_file: + parser.read_file(cfg_file) + VCS = parser.get("versioneer", "VCS") # mandatory + + # Dict-like interface for non-mandatory entries + section = parser["versioneer"] + + cfg = VersioneerConfig() + cfg.VCS = VCS + cfg.style = section.get("style", "") + cfg.versionfile_source = section.get("versionfile_source") + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = section.get("tag_prefix") + if cfg.tag_prefix in ("''", '""', None): + cfg.tag_prefix = "" + cfg.parentdir_prefix = section.get("parentdir_prefix") + cfg.verbose = section.get("verbose") + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +# these dictionaries contain VCS-specific tools +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + HANDLERS.setdefault(vcs, {})[method] = f + return f + + return decorate + + +def run_command( + commands, args, cwd=None, verbose=False, hide_stderr=False, env=None +): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + + popen_kwargs = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: + try: + dispcmd = str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + 
stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) + break + except OSError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, process.returncode + return stdout, process.returncode + + +LONG_VERSION_PY[ + "git" +] = r''' +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys +from typing import Callable, Dict +import functools + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). 
+ git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" + git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" + git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "%(STYLE)s" + cfg.tag_prefix = "%(TAG_PREFIX)s" + cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" + cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + + popen_kwargs = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: + try: + dispcmd = str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) + break + except OSError: + e = 
sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %%s" %% dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %%s" %% (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: + if verbose: + print("unable to run %%s (error)" %% dispcmd) + print("stdout was %%s" %% stdout) + return None, process.returncode + return stdout, process.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %%s but none started with prefix %%s" %% + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
+ keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. 
The old git %%d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r'\d', r)} + if verbose: + print("discarding '%%s', no digits" %% ",".join(refs - tags)) + if verbose: + print("likely tags: %%s" %% ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue + if verbose: + print("picking %%s" %% r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + # GIT_DIR can interfere with correct operation of Versioneer. 
+ # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %%s not under git control" %% root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. 
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%%s'" + %% describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%%s' doesn't start with prefix '%%s'" + print(fmt %% (full_tag, tag_prefix)) + pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" + %% (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], + cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver): + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces): + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + if pieces["distance"]: + # update the post release segment + tag_version, post_version = pep440_split_post( + pieces["closest-tag"] + ) + rendered = tag_version + if post_version is not None: + rendered += ".post%%d.dev%%d" %% ( + post_version + 1, pieces["distance"] + ) + else: + rendered += ".post0.dev%%d" %% (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] + else: + # exception #1 + rendered = "0.post0.dev%%d" %% pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%%s'" %% style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. 
+ + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for _ in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} +''' + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
+ keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. 
The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r"\d", r)} + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix) :] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r"\d", r): + continue + if verbose: + print("picking %s" % r) + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + # GIT_DIR can interfere with correct operation of Versioneer. 
+ # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner( + GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True + ) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner( + GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root + ) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. 
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[: git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + if not mo: + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ( + "unable to parse git-describe output: '%s'" % describe_out + ) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix) :] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def do_vcs_install(versionfile_source, ipy): + """Git-specific installation logic for Versioneer. + + For Git, this means creating/changing .gitattributes to mark _version.py + for export-subst keyword substitution. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + files = [versionfile_source] + if ipy: + files.append(ipy) + try: + my_path = __file__ + if my_path.endswith(".pyc") or my_path.endswith(".pyo"): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) + present = False + try: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except OSError: + pass + if not present: + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") + files.append(".gitattributes") + run_command(GITS, ["add", "--"] + files) + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +SHORT_VERSION_PY = """ +# This file was generated by 'versioneer.py' (0.23) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. Distribution tarballs contain a pre-generated copy +# of this file. 
+ +import json + +version_json = ''' +%s +''' # END VERSION_JSON + + +def get_versions(): + return json.loads(version_json) +""" + + +def versions_from_file(filename): + """Try to determine the version from _version.py if present.""" + try: + with open(filename) as f: + contents = f.read() + except OSError: + raise NotThisMethod("unable to read _version.py") + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, + re.M | re.S, + ) + if not mo: + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, + re.M | re.S, + ) + if not mo: + raise NotThisMethod("no version_json in _version.py") + return json.loads(mo.group(1)) + + +def write_to_version_file(filename, versions): + """Write the given version number to the given _version.py file.""" + os.unlink(filename) + contents = json.dumps( + versions, sort_keys=True, indent=1, separators=(",", ": ") + ) + with open(filename, "w") as f: + f.write(SHORT_VERSION_PY % contents) + + print("set %s to '%s'" % (filename, versions["version"])) + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver): + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces): + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + if pieces["distance"]: + # update the post release segment + tag_version, post_version = pep440_split_post( + pieces["closest-tag"] + ) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % ( + post_version + 1, + pieces["distance"], + ) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] + else: + # exception #1 + rendered = "0.post0.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0]+gHEX[.dirty]
+    """
+    if pieces["closest-tag"]:
+        rendered = pieces["closest-tag"]
+        if pieces["distance"] or pieces["dirty"]:
+            rendered += ".post%d" % pieces["distance"]
+            if pieces["branch"] != "master":
+                rendered += ".dev0"
+            rendered += plus_or_dot(pieces)
+            rendered += "g%s" % pieces["short"]
+            if pieces["dirty"]:
+                rendered += ".dirty"
+    else:
+        # exception #1
+        rendered = "0.post%d" % pieces["distance"]
+        if pieces["branch"] != "master":
+            rendered += ".dev0"
+        rendered += "+g%s" % pieces["short"]
+        if pieces["dirty"]:
+            rendered += ".dirty"
+    return rendered
+
+
+def render_pep440_old(pieces):
+    """TAG[.postDISTANCE[.dev0]] .
+
+    The ".dev0" means dirty.
+
+    Exceptions:
+    1: no tags. 0.postDISTANCE[.dev0]
+    """
+    if pieces["closest-tag"]:
+        rendered = pieces["closest-tag"]
+        if pieces["distance"] or pieces["dirty"]:
+            rendered += ".post%d" % pieces["distance"]
+            if pieces["dirty"]:
+                rendered += ".dev0"
+    else:
+        # exception #1
+        rendered = "0.post%d" % pieces["distance"]
+        if pieces["dirty"]:
+            rendered += ".dev0"
+    return rendered
+
+
+def render_git_describe(pieces):
+    """TAG[-DISTANCE-gHEX][-dirty].
+
+    Like 'git describe --tags --dirty --always'.
+
+    Exceptions:
+    1: no tags. HEX[-dirty] (note: no 'g' prefix)
+    """
+    if pieces["closest-tag"]:
+        rendered = pieces["closest-tag"]
+        if pieces["distance"]:
+            rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
+    else:
+        # exception #1
+        rendered = pieces["short"]
+    if pieces["dirty"]:
+        rendered += "-dirty"
+    return rendered
+
+
+def render_git_describe_long(pieces):
+    """TAG-DISTANCE-gHEX[-dirty].
+
+    Like 'git describe --tags --dirty --always --long'.
+    The distance/hash is unconditional.
+
+    Exceptions:
+    1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } + + +class VersioneerBadRootError(Exception): + """The project root directory is unknown or missing key files.""" + + +def get_versions(verbose=False): + """Get the project version from whatever source is available. + + Returns dict with two keys: 'version' and 'full'. 
+ """ + if "versioneer" in sys.modules: + # see the discussion in cmdclass.py:get_cmdclass() + del sys.modules["versioneer"] + + root = get_root() + cfg = get_config_from_root(root) + + assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" + handlers = HANDLERS.get(cfg.VCS) + assert handlers, "unrecognized VCS '%s'" % cfg.VCS + verbose = verbose or cfg.verbose + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" + assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" + + versionfile_abs = os.path.join(root, cfg.versionfile_source) + + # extract version from first of: _version.py, VCS command (e.g. 'git + # describe'), parentdir. This is meant to work for developers using a + # source checkout, for users of a tarball created by 'setup.py sdist', + # and for users of a tarball/zipball created by 'git archive' or github's + # download-from-tag feature or the equivalent in other VCSes. + + get_keywords_f = handlers.get("get_keywords") + from_keywords_f = handlers.get("keywords") + if get_keywords_f and from_keywords_f: + try: + keywords = get_keywords_f(versionfile_abs) + ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) + if verbose: + print("got version from expanded keyword %s" % ver) + return ver + except NotThisMethod: + pass + + try: + ver = versions_from_file(versionfile_abs) + if verbose: + print("got version from file %s %s" % (versionfile_abs, ver)) + return ver + except NotThisMethod: + pass + + from_vcs_f = handlers.get("pieces_from_vcs") + if from_vcs_f: + try: + pieces = from_vcs_f(cfg.tag_prefix, root, verbose) + ver = render(pieces, cfg.style) + if verbose: + print("got version from VCS %s" % ver) + return ver + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + if verbose: + print("got version from parentdir %s" % ver) + return ver + except NotThisMethod: + pass + + if verbose: + 
print("unable to compute version")
+
+    return {
+        "version": "0+unknown",
+        "full-revisionid": None,
+        "dirty": None,
+        "error": "unable to compute version",
+        "date": None,
+    }
+
+
+def get_version():
+    """Get the short version string for this project."""
+    return get_versions()["version"]
+
+
+def get_cmdclass(cmdclass=None):
+    """Get the custom setuptools subclasses used by Versioneer.
+
+    If the package uses a different cmdclass (e.g. one from numpy), it
+    should be provided as an argument.
+    """
+    if "versioneer" in sys.modules:
+        del sys.modules["versioneer"]
+    # this fixes the "python setup.py develop" case (also 'install' and
+    # 'easy_install .'), in which subdependencies of the main project are
+    # built (using setup.py bdist_egg) in the same python process. Assume
+    # a main project A and a dependency B, which use different versions
+    # of Versioneer. A's setup.py imports A's Versioneer, leaving it in
+    # sys.modules by the time B's setup.py is executed, causing B to run
+    # with the wrong versioneer. Setuptools wraps the sub-dep builds in a
+    # sandbox that restores sys.modules to its pre-build state, so the
+    # parent is protected against the child's "import versioneer". By
+    # removing ourselves from sys.modules here, before the child build
+    # happens, we protect the child from the parent's versioneer too.
+ # Also see + # https://github.com/python-versioneer/python-versioneer/issues/52 + + cmds = {} if cmdclass is None else cmdclass.copy() + + # we add "version" to setuptools + from setuptools import Command + + class cmd_version(Command): + description = "report generated version string" + user_options = [] + boolean_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + vers = get_versions(verbose=True) + print("Version: %s" % vers["version"]) + print(" full-revisionid: %s" % vers.get("full-revisionid")) + print(" dirty: %s" % vers.get("dirty")) + print(" date: %s" % vers.get("date")) + if vers["error"]: + print(" error: %s" % vers["error"]) + + cmds["version"] = cmd_version + + # we override "build_py" in setuptools + # + # most invocation pathways end up running build_py: + # distutils/build -> build_py + # distutils/install -> distutils/build ->.. + # setuptools/bdist_wheel -> distutils/install ->.. + # setuptools/bdist_egg -> distutils/install_lib -> build_py + # setuptools/install -> bdist_egg ->.. + # setuptools/develop -> ? + # pip install: + # copies source tree to a tempdir before running egg_info/etc + # if .git isn't copied too, 'git describe' will fail + # then does setup.py bdist_wheel, or sometimes setup.py install + # setup.py egg_info -> ? + + # pip install -e . and setuptool/editable_wheel will invoke build_py + # but the build_py command is not expected to copy any files. 
+ + # we override different "build_py" commands for both environments + if "build_py" in cmds: + _build_py = cmds["build_py"] + else: + from setuptools.command.build_py import build_py as _build_py + + class cmd_build_py(_build_py): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_py.run(self) + if getattr(self, "editable_mode", False): + # During editable installs `.py` and data files are + # not copied to build_lib + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if cfg.versionfile_build: + target_versionfile = os.path.join( + self.build_lib, cfg.versionfile_build + ) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + cmds["build_py"] = cmd_build_py + + if "build_ext" in cmds: + _build_ext = cmds["build_ext"] + else: + from setuptools.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. + # As in place builds will already have a _version.py + # in the module dir, we do not need to write one. + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + target_versionfile = os.path.join( + self.build_lib, cfg.versionfile_build + ) + if not os.path.exists(target_versionfile): + print( + f"Warning: {target_versionfile} does not exist, skipping " + "version update. This can happen if you are running " + "build_ext without first running build_py." + ) + return + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + cmds["build_ext"] = cmd_build_ext + + if "cx_Freeze" in sys.modules: # cx_freeze enabled? 
+ from cx_Freeze.dist import build_exe as _build_exe + + # nczeczulin reports that py2exe won't like the pep440-style string + # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. + # setup(console=[{ + # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION + # "product_version": versioneer.get_version(), + # ... + + class cmd_build_exe(_build_exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _build_exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + cmds["build_exe"] = cmd_build_exe + del cmds["build_py"] + + if "py2exe" in sys.modules: # py2exe enabled? 
+ from py2exe.distutils_buildexe import py2exe as _py2exe + + class cmd_py2exe(_py2exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _py2exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + cmds["py2exe"] = cmd_py2exe + + # sdist farms its file list building out to egg_info + if "egg_info" in cmds: + _sdist = cmds["egg_info"] + else: + from setuptools.command.egg_info import egg_info as _egg_info + + class cmd_egg_info(_egg_info): + def find_sources(self): + # egg_info.find_sources builds the manifest list and writes it + # in one shot + super().find_sources() + + # Modify the filelist and normalize it + root = get_root() + cfg = get_config_from_root(root) + self.filelist.append("versioneer.py") + if cfg.versionfile_source: + # There are rare cases where versionfile_source might not be + # included by default, so we must be explicit + self.filelist.append(cfg.versionfile_source) + self.filelist.sort() + self.filelist.remove_duplicates() + + # The write method is hidden in the manifest_maker instance that + # generated the filelist and was thrown away + # We will instead replicate their final normalization (to unicode, + # and POSIX-style paths) + from setuptools import unicode_utils + + normalized = [ + unicode_utils.filesys_decode(f).replace(os.sep, "/") + for f in self.filelist.files + ] + + manifest_filename = os.path.join(self.egg_info, "SOURCES.txt") + with open(manifest_filename, "w") as fobj: + fobj.write("\n".join(normalized)) + + cmds["egg_info"] = cmd_egg_info + + # we override different 
"sdist" commands for both environments + if "sdist" in cmds: + _sdist = cmds["sdist"] + else: + from setuptools.command.sdist import sdist as _sdist + + class cmd_sdist(_sdist): + def run(self): + versions = get_versions() + self._versioneer_generated_versions = versions + # unless we update this, the command will keep using the old + # version + self.distribution.metadata.version = versions["version"] + return _sdist.run(self) + + def make_release_tree(self, base_dir, files): + root = get_root() + cfg = get_config_from_root(root) + _sdist.make_release_tree(self, base_dir, files) + # now locate _version.py in the new base_dir directory + # (remembering that it may be a hardlink) and replace it with an + # updated value + target_versionfile = os.path.join(base_dir, cfg.versionfile_source) + print("UPDATING %s" % target_versionfile) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + + cmds["sdist"] = cmd_sdist + + return cmds + + +CONFIG_ERROR = """ +setup.cfg is missing the necessary Versioneer configuration. You need +a section like: + + [versioneer] + VCS = git + style = pep440 + versionfile_source = src/myproject/_version.py + versionfile_build = myproject/_version.py + tag_prefix = + parentdir_prefix = myproject- + +You will also need to edit your setup.py to use the results: + + import versioneer + setup(version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), ...) + +Please read the docstring in ./versioneer.py for configuration instructions, +edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. +""" + +SAMPLE_CONFIG = """ +# See the docstring in versioneer.py for instructions. Note that you must +# re-run 'versioneer.py setup' after changing this section, and commit the +# resulting files. 
+ +[versioneer] +#VCS = git +#style = pep440 +#versionfile_source = +#versionfile_build = +#tag_prefix = +#parentdir_prefix = + +""" + +OLD_SNIPPET = """ +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions +""" + +INIT_PY_SNIPPET = """ +from . import {0} +__version__ = {0}.get_versions()['version'] +""" + + +def do_setup(): + """Do main VCS-independent setup function for installing Versioneer.""" + root = get_root() + try: + cfg = get_config_from_root(root) + except ( + OSError, + configparser.NoSectionError, + configparser.NoOptionError, + ) as e: + if isinstance(e, (OSError, configparser.NoSectionError)): + print( + "Adding sample versioneer config to setup.cfg", file=sys.stderr + ) + with open(os.path.join(root, "setup.cfg"), "a") as f: + f.write(SAMPLE_CONFIG) + print(CONFIG_ERROR, file=sys.stderr) + return 1 + + print(" creating %s" % cfg.versionfile_source) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") + if os.path.exists(ipy): + try: + with open(ipy, "r") as f: + old = f.read() + except OSError: + old = "" + module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] + snippet = INIT_PY_SNIPPET.format(module) + if OLD_SNIPPET in old: + print(" replacing boilerplate in %s" % ipy) + with open(ipy, "w") as f: + f.write(old.replace(OLD_SNIPPET, snippet)) + elif snippet not in old: + print(" appending to %s" % ipy) + with open(ipy, "a") as f: + f.write(snippet) + else: + print(" %s unmodified" % ipy) + else: + print(" %s doesn't exist, ok" % ipy) + ipy = None + + # Make VCS-specific changes. 
For git, this means creating/changing + # .gitattributes to mark _version.py for export-subst keyword + # substitution. + do_vcs_install(cfg.versionfile_source, ipy) + return 0 + + +def scan_setup_py(): + """Validate the contents of setup.py against Versioneer's expectations.""" + found = set() + setters = False + errors = 0 + with open("setup.py", "r") as f: + for line in f.readlines(): + if "import versioneer" in line: + found.add("import") + if "versioneer.get_cmdclass()" in line: + found.add("cmdclass") + if "versioneer.get_version()" in line: + found.add("get_version") + if "versioneer.VCS" in line: + setters = True + if "versioneer.versionfile_source" in line: + setters = True + if len(found) != 3: + print("") + print("Your setup.py appears to be missing some important items") + print("(but I might be wrong). Please make sure it has something") + print("roughly like the following:") + print("") + print(" import versioneer") + print(" setup( version=versioneer.get_version(),") + print(" cmdclass=versioneer.get_cmdclass(), ...)") + print("") + errors += 1 + if setters: + print("You should remove lines like 'versioneer.VCS = ' and") + print("'versioneer.versionfile_source = ' . This configuration") + print("now lives in setup.cfg, and should be removed from setup.py") + print("") + errors += 1 + return errors + + +if __name__ == "__main__": + cmd = sys.argv[1] + if cmd == "setup": + errors = do_setup() + errors += scan_setup_py() + if errors: + sys.exit(1)