Skip to content

Commit

Permalink
Extend the python_op_factory with an internal schema parameter
Browse files Browse the repository at this point in the history
Refactor PythonFunctionBase class into a base class generator.
Adjust TFRecord to use internal schema correctly.

Signed-off-by: Krzysztof Lecki <[email protected]>
  • Loading branch information
klecki committed Feb 27, 2024
1 parent ef5d304 commit bb65224
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 68 deletions.
28 changes: 24 additions & 4 deletions dali/python/nvidia/dali/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,30 @@ def _check_arg_input(schema, op_name, name):
)


def python_op_factory(name, schema_name=None):
def python_op_factory(name, schema_name, internal_schema_name=None):
"""Generate the ops API class bindings for operator.
Parameters
----------
name : str
The name of the operator (without the module) - this will be the name of the class
schema_name : str
Name of the schema, used for documentation lookups and schema/spec retrieval unless
internal_schema_name is provided
internal_schema_name : str, optional
If provided, this will be the schema used to process the arguments, by default None
Returns
-------
_type_
_description_
"""
class Operator(metaclass=_DaliOperatorMeta):
def __init__(self, *, device="cpu", **kwargs):
schema_name = _schema_name(type(self))
if self._internal_schema_name is None:
schema_name = _schema_name(type(self))
else:
schema_name = self._internal_schema_name
self._spec = _b.OpSpec(schema_name)
self._schema = _b.GetSchema(schema_name)

Expand Down Expand Up @@ -581,7 +601,8 @@ def __call__(self, *inputs, **kwargs):
return result

Operator.__name__ = str(name)
Operator.schema_name = schema_name or Operator.__name__
Operator.schema_name = schema_name
Operator._internal_schema_name = internal_schema_name
Operator._generated = True # The class was generated using python_op_factory
Operator.__call__.__doc__ = _docs._docstring_generator_call(Operator.schema_name)
return Operator
Expand Down Expand Up @@ -733,7 +754,6 @@ def _promote_scalar_constant(value, input_device):

# Expose the PythonFunction family of classes and generate the fn bindings for them
from nvidia.dali.ops._operators.python_function import ( # noqa: E402, F401
PythonFunctionBase, # noqa: F401
PythonFunction,
DLTensorPythonFunction,
_dlpack_to_array, # noqa: F401
Expand Down
83 changes: 28 additions & 55 deletions dali/python/nvidia/dali/ops/_operators/python_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,63 +31,38 @@ def _setup_cupy():
import cupy as cupy


class PythonFunctionBase(metaclass=ops._DaliOperatorMeta):
def __init__(self, impl_name, function, num_outputs=1, device="cpu", **kwargs):
self._schema = _b.GetSchema(impl_name)
self._spec = _b.OpSpec(impl_name)
self._device = device
self._impl_name = impl_name
def _get_base_impl(name, impl_name):

self._init_args, self._call_args = ops._separate_kwargs(kwargs)
self._name = self._init_args.pop("name", None)
class PythonFunctionBase(ops.python_op_factory(impl_name, name, impl_name)):

for key, value in self._init_args.items():
self._spec.AddArg(key, value)
def __init__(self, function, num_outputs=1, **kwargs):

self.function = function
self.num_outputs = num_outputs
self._preserve = True
super().__init__(**kwargs)

@property
def spec(self):
return self._spec
self.function = function
self.num_outputs = num_outputs
self._preserve = True

@property
def schema(self):
return self._schema
def __call__(self, *inputs, **kwargs):
inputs = ops._preprocess_inputs(inputs, impl_name, self._device, None)
self.pipeline = _Pipeline.current()
if self.pipeline is None:
_Pipeline._raise_pipeline_required("PythonFunction operator")

@property
def device(self):
return self._device
for inp in inputs:
if not isinstance(inp, _DataNode):
raise TypeError(
f"Expected inputs of type `DataNode`. "
f"Received input of type '{type(inp).__name__}'. "
f"Python Operators do not support Multiple Input Sets."
)

@property
def preserve(self):
return self._preserve
kwargs.update({"function_id": id(self.function), "num_outputs": self.num_outputs})

def __call__(self, *inputs, **kwargs):
inputs = ops._preprocess_inputs(inputs, self._impl_name, self._device, None)
self.pipeline = _Pipeline.current()
if self.pipeline is None:
_Pipeline._raise_pipeline_required("PythonFunction operator")
return super().__call__(*inputs, **kwargs)

for inp in inputs:
if not isinstance(inp, _DataNode):
raise TypeError(
f"Expected inputs of type `DataNode`. "
f"Received input of type '{type(inp).__name__}'. "
f"Python Operators do not support Multiple Input Sets."
)

args, arg_inputs = ops._separate_kwargs(kwargs)
args.update({"function_id": id(self.function), "num_outputs": self.num_outputs})

args = ops._resolve_double_definitions(args, self._init_args, keep_old=False)
if self._name is not None:
args = ops._resolve_double_definitions(args, {"name": self._name}) # restore the name

op_instance = ops._OperatorInstance(inputs, arg_inputs, args, self._init_args, self)
op_instance.spec.AddArg("device", self.device)
return op_instance.unwrapped_outputs
PythonFunctionBase._generated = False
return PythonFunctionBase


def _dlpack_to_array(dlpack):
Expand All @@ -98,8 +73,7 @@ def _dlpack_from_array(array):
return nvidia.dali.python_function_plugin.ArrayToDLTensor(array)


class PythonFunction(PythonFunctionBase):
schema_name = "PythonFunction"
class PythonFunction(_get_base_impl("PythonFunction", "DLTensorPythonFunctionImpl")):
_registry.register_cpu_op("PythonFunction")
_registry.register_gpu_op("PythonFunction")

Expand Down Expand Up @@ -227,7 +201,6 @@ def func(*ts):
return self._function_wrapper_gpu(batch_processing, function, num_outputs, *ts)

super(PythonFunction, self).__init__(
impl_name="DLTensorPythonFunctionImpl",
function=func,
num_outputs=num_outputs,
device=device,
Expand All @@ -237,8 +210,9 @@ def func(*ts):
)


class DLTensorPythonFunction(PythonFunctionBase):
schema_name = "DLTensorPythonFunction"
class DLTensorPythonFunction(
_get_base_impl("DLTensorPythonFunction", "DLTensorPythonFunctionImpl")
):
_registry.register_cpu_op("DLTensorPythonFunction")
_registry.register_gpu_op("DLTensorPythonFunction")

Expand All @@ -265,7 +239,6 @@ def func(*ts):
return self._function_wrapper_dlpack(batch_processing, function, num_outputs, *ts)

super(DLTensorPythonFunction, self).__init__(
impl_name="DLTensorPythonFunctionImpl",
function=func,
num_outputs=num_outputs,
device=device,
Expand Down
8 changes: 4 additions & 4 deletions dali/python/nvidia/dali/ops/_operators/tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def tfrecord_enabled():
return False


def _get_impl(name, schema_name):
class _TFRecordReaderImpl(ops.python_op_factory(name, schema_name)):
def _get_impl(name, schema_name, internal_schema_name):
class _TFRecordReaderImpl(ops.python_op_factory(name, schema_name, internal_schema_name)):
"""custom wrappers around ops"""

def __init__(self, path, index_path, features, **kwargs):
Expand Down Expand Up @@ -80,9 +80,9 @@ def __call__(self, *inputs, **kwargs):
return _TFRecordReaderImpl


class TFRecordReader(_get_impl("_TFRecordReader", "_TFRecordReader")):
class TFRecordReader(_get_impl("_TFRecordReader", "TFRecordReader", "_TFRecordReader")):
pass


class TFRecord(_get_impl("_TFRecord", "readers___TFRecord")):
class TFRecord(_get_impl("_TFRecord", "readers__TFRecord", "readers___TFRecord")):
pass
9 changes: 5 additions & 4 deletions dali/python/nvidia/dali/plugin/pytorch/_torch_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,10 +17,12 @@

from nvidia.dali import ops
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.ops._operators import python_function


class TorchPythonFunction(ops.PythonFunctionBase):
schema_name = "TorchPythonFunction"
class TorchPythonFunction(
python_function._get_base_impl("TorchPythonFunction", "DLTensorPythonFunctionImpl")
):
ops.register_cpu_op("TorchPythonFunction")
ops.register_gpu_op("TorchPythonFunction")

Expand Down Expand Up @@ -64,7 +66,6 @@ def __call__(self, *inputs, **kwargs):
def __init__(self, function, num_outputs=1, device="cpu", batch_processing=False, **kwargs):
self.stream = None
super(TorchPythonFunction, self).__init__(
impl_name="DLTensorPythonFunctionImpl",
function=lambda *ins: self.torch_wrapper(batch_processing, function, device, *ins),
num_outputs=num_outputs,
device=device,
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/reader/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def get_dali_pipeline(tfrec_filenames, tfrec_idx_filenames, shard_id, num_gpus):

@raises(
TypeError,
glob="Expected `nvidia.dali.tfrecord.Feature` for the image/encoded, "
glob='Expected `nvidia.dali.tfrecord.Feature` for the "image/encoded", '
"but got <class 'int'>. Use `nvidia.dali.tfrecord.FixedLenFeature` "
"or `nvidia.dali.tfrecord.VarLenFeature` to define the features to extract.",
)
Expand Down

0 comments on commit bb65224

Please sign in to comment.