Skip to content

Commit

Permalink
Make Workspace::Input return const reference (NVIDIA#3452)
Browse files Browse the repository at this point in the history
Input (previously InputRef) was returning
mutable reference to data by mistake.

Fix the constness of the returned reference,
adjust misuses of the mutable accessors.

Introduce UnsafeMutableInput for access
in the executor and related utilities,
where inputs needs to be adjusted additionally. 

Add docs.

Adjust ShareData to accept const ref,
which is needed after this change.

Use UnsafeMutableInput for DLPack function
implementation.

Signed-off-by: Krzysztof Lecki <[email protected]>
  • Loading branch information
klecki authored and cyyever committed Jan 23, 2022
1 parent 417e84d commit c6ccf70
Show file tree
Hide file tree
Showing 12 changed files with 83 additions and 34 deletions.
2 changes: 1 addition & 1 deletion dali/operators/decoder/audio/audio_decoder_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ AudioDecoderCpu::SetupImpl(std::vector<OutputDesc> &output_desc, const workspace

for (int i = 0; i < batch_size; i++) {
auto &meta = sample_meta_[i] =
decoders_[i]->Open({reinterpret_cast<const char *>(input[i].raw_mutable_data()),
decoders_[i]->Open({static_cast<const char *>(input[i].raw_data()),
input[i].shape().num_elements()});
TensorShape<> data_sample_shape = DecodedAudioShape(
meta, use_resampling_ ? target_sample_rates_[i] : -1.0f, downmix_);
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/generic/join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void TensorJoin<Backend, new_axis>::SetupTyped(

copy_idx_ = 0;
for (int i = 0; i < ninp; i++) {
auto tlv = view<T>(ws.template Input<Backend>(i));
auto tlv = view<const T>(ws.template Input<Backend>(i));
if (new_axis || tlv.num_elements() > 0) { // when concatenating, we can skip empty inputs
if (inputs.empty())
copy_idx_ = i;
Expand All @@ -109,7 +109,7 @@ void TensorJoin<Backend, new_axis>::SetupTyped(

// No non-empty inputs? Use the first one, even if it's empty.
if (inputs.empty()) {
inputs.push_back(view<T>(ws.template Input<Backend>(0)));
inputs.push_back(view<const T>(ws.template Input<Backend>(0)));
}

kernels::tensor_join::JoinedShape(output_shape, [&](int index) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class CombineTransformsCPU : public Operator<CPUBackend> {
in_views.reserve(ws.NumInput());
for (int input_idx = 0; input_idx < ws.NumInput(); input_idx++) {
auto &in = ws.template Input<CPUBackend>(input_idx);
in_views.push_back(view<T, 2>(in));
in_views.push_back(view<const T, 2>(in));
}
auto out_view = view<T, 2>(out);
auto read_mat = [](affine_mat_t<T, mat_dim> &next_mat,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class TransformBaseOp : public Operator<Backend> {
auto out_view = view<T>(out);
if (has_input_) {
auto &in = ws.template Input<Backend>(0);
auto in_view = view<T>(in);
auto in_view = view<const T>(in);
for (int i = 0; i < nsamples_; i++) {
int mat_idx = num_mats == 1 ? 0 : i;
ApplyTransform(out_view[i].data, in_view[i].data, matrices[mat_idx]);
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/numba_function/numba_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ void NumbaFuncImpl<CPUBackend>::RunImpl(workspace_t<CPUBackend> &ws) {
for (size_t in_id = 0; in_id < in_types_.size(); in_id++) {
auto& in = ws.Input<CPUBackend>(in_id);
for (int i = 0; i < N; i++) {
in_ptrs[N * in_id + i] = reinterpret_cast<uint64_t>(in[i].raw_mutable_data());
in_ptrs[N * in_id + i] = reinterpret_cast<uint64_t>(in[i].raw_data());
}
}

Expand Down
12 changes: 6 additions & 6 deletions dali/operators/python_function/dltensor_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ py::list PrepareDLTensorInputs<CPUBackend>(HostWorkspace &ws) {
for (Index idx = 0; idx < ws.NumInput(); ++idx) {
py::list dl_tensor_list;
for (Index i = 0; i < ws.GetInputBatchSize(idx); ++i) {
auto &t = ws.Input<CPUBackend>(idx)[i];
auto dl_capsule = TensorToDLPackView(const_cast<Tensor<CPUBackend>&>(t));
auto &t = ws.UnsafeMutableInput<CPUBackend>(idx)[i];
auto dl_capsule = TensorToDLPackView(t);
dl_tensor_list.append(dl_capsule);
}
input_tuple.append(dl_tensor_list);
Expand All @@ -91,7 +91,7 @@ template <>
py::list PrepareDLTensorInputs<GPUBackend>(DeviceWorkspace &ws) {
py::list input_tuple;
for (Index idx = 0; idx < ws.NumInput(); ++idx) {
auto &tlist = ws.Input<GPUBackend>(idx);
auto &tlist = ws.UnsafeMutableInput<GPUBackend>(idx);
py::list dl_tensor_list = TensorListToDLPackView(tlist);
input_tuple.append(dl_tensor_list);
}
Expand All @@ -106,8 +106,8 @@ py::list PrepareDLTensorInputsPerSample<CPUBackend>(HostWorkspace &ws) {
for (Index s = 0; s < batch_size; ++s) {
py::list tuple;
for (Index idx = 0; idx < ws.NumInput(); ++idx) {
auto &t = ws.Input<CPUBackend>(idx)[s];
auto dl_capsule = TensorToDLPackView(const_cast<Tensor<CPUBackend>&>(t));
auto &t = ws.UnsafeMutableInput<CPUBackend>(idx)[s];
auto dl_capsule = TensorToDLPackView(t);
tuple.append(dl_capsule);
}
input_tuples.append(tuple);
Expand All @@ -122,7 +122,7 @@ py::list PrepareDLTensorInputsPerSample<GPUBackend>(DeviceWorkspace &ws) {
Index batch_size = ws.Input<GPUBackend>(0).num_samples();
input_tuples.resize(batch_size);
for (Index idx = 0; idx < ws.NumInput(); ++idx) {
py::list dl_tensor_list = TensorListToDLPackView(ws.Input<GPUBackend>(idx));
py::list dl_tensor_list = TensorListToDLPackView(ws.UnsafeMutableInput<GPUBackend>(idx));
for (Index s = 0; s < batch_size; ++s) {
input_tuples[s].append(dl_tensor_list[s]);
}
Expand Down
2 changes: 1 addition & 1 deletion dali/pipeline/data/tensor_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class DLL_PUBLIC TensorList : private Buffer<Backend> {
* shared data or the call will fail.
* Size can be set to 0 and type to NoType as intermediate step.
*/
DLL_PUBLIC inline void ShareData(TensorList<Backend> &other) {
DLL_PUBLIC inline void ShareData(const TensorList<Backend> &other) {
DALI_ENFORCE(IsValidType(other.type_), "To share data, "
"the input TensorList must have a valid data type");

Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/data/tensor_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ void TensorVector<Backend>::Copy(const TensorVector<SrcBackend> &in_tv, cudaStre


template <typename Backend>
void TensorVector<Backend>::ShareData(TensorList<Backend> &in_tl) {
void TensorVector<Backend>::ShareData(const TensorList<Backend> &in_tl) {
SetContiguous(true);
type_ = in_tl.type_info();
pinned_ = in_tl.is_pinned();
Expand All @@ -331,7 +331,7 @@ void TensorVector<Backend>::ShareData(TensorList<Backend> &in_tl) {
}

template <typename Backend>
void TensorVector<Backend>::ShareData(TensorVector<Backend> &tv) {
void TensorVector<Backend>::ShareData(const TensorVector<Backend> &tv) {
type_ = tv.type_;
state_ = tv.state_;
pinned_ = tv.is_pinned();
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/data/tensor_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ class DLL_PUBLIC TensorVector {
template <typename SrcBackend>
void Copy(const TensorVector<SrcBackend> &in_tv, cudaStream_t stream);

void ShareData(TensorList<Backend> &in_tl);
void ShareData(const TensorList<Backend> &in_tl);

void ShareData(TensorVector<Backend> &tv);
void ShareData(const TensorVector<Backend> &tv);

TensorVector<Backend> &operator=(TensorVector<Backend> &&other) noexcept;

Expand Down
10 changes: 6 additions & 4 deletions dali/pipeline/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,11 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunHelper(OpNode &op_node, Workspac
for (int i = 0; i < spec.NumRegularInput(); i++) {
bool had_empty_layout = false;
if (ws.template InputIsType<CPUBackend>(i)) {
had_empty_layout = SetDefaultLayoutIfNeeded(ws.template Input<CPUBackend>(i), schema, i);
had_empty_layout =
SetDefaultLayoutIfNeeded(ws.template UnsafeMutableInput<CPUBackend>(i), schema, i);
} else {
had_empty_layout = SetDefaultLayoutIfNeeded(ws.template Input<GPUBackend>(i), schema, i);
had_empty_layout =
SetDefaultLayoutIfNeeded(ws.template UnsafeMutableInput<GPUBackend>(i), schema, i);
}
if (had_empty_layout) empty_layout_in_idxs.push_back(i);
}
Expand Down Expand Up @@ -334,10 +336,10 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunHelper(OpNode &op_node, Workspac

for (int i : empty_layout_in_idxs) {
if (ws.template InputIsType<CPUBackend>(i)) {
auto &in = ws.template Input<CPUBackend>(i);
auto &in = ws.template UnsafeMutableInput<CPUBackend>(i);
in.SetLayout({});
} else {
auto &in = ws.template Input<GPUBackend>(i);
auto &in = ws.template UnsafeMutableInput<GPUBackend>(i);
in.SetLayout({});
}
}
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/workspace/sample_workspace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ void MakeSampleView(SampleWorkspace& sample, HostWorkspace& batch, int data_idx,
int num_inputs = batch.NumInput();
for (int i = 0; i < num_inputs; i++) {
if (batch.InputIsType<CPUBackend>(i)) {
auto &input_ref = batch.Input<CPUBackend>(i);
auto &input_ref = batch.UnsafeMutableInput<CPUBackend>(i);
sample.AddInput(&input_ref[data_idx]);
} else {
auto &input_ref = batch.Input<GPUBackend>(i);
auto &input_ref = batch.UnsafeMutableInput<GPUBackend>(i);
sample.AddInput(&input_ref[data_idx]);
}
}
Expand Down
69 changes: 58 additions & 11 deletions dali/pipeline/workspace/workspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,32 @@ class WorkspaceBase : public ArgumentWorkspace {
gpu_outputs_index_.clear();
}

/** @defgroup InputOutput Input and output APIs
* Functions used to access inputs and outputs of the operator in its implementation.
* The inputs are read-only while outputs can be modified.
* @{
*/

/**
* @brief Returns the const reference to the input batch at the position `idx`.
*
* The operator implementation can use this function to access its inputs.
*/
template <typename Backend>
auto& Input(int idx) const {
const auto& Input(int idx) const {
return *InputHandle(idx, Backend{});
}

/**
* @brief Returns the mutable reference to the output batch at the position `idx`.
*
* The operator implementation can use this function to access its outputs.
*/
template <typename Backend>
auto& Output(int idx) const {
return *OutputHandle(idx, Backend{});
}

template <typename Backend>
const InputType<Backend>& InputPtr(int idx) const {
return InputHandle(idx, Backend{});
}

template <typename Backend>
const OutputType<Backend>& OutputPtr(int idx) const {
return OutputHandle(idx, Backend{});
}

/**
* @brief Returns the number of inputs.
*/
Expand All @@ -175,6 +181,47 @@ class WorkspaceBase : public ArgumentWorkspace {
return output_index_map_.size();
}


/** @} */ // end of InputOutput

/** @defgroup InputOutputInternal Internal API for input and output access
* Functions allowing mutable access to both inputs and outputs that should not be used in
* operator implementation.
* @{
*/

/**
* @brief Returns the mutable reference to the input batch at the position `idx`.
*
* Intended only for executor and other internal APIs.
*/
template <typename Backend>
auto& UnsafeMutableInput(int idx) const {
return *InputHandle(idx, Backend{});
}

/**
* @brief Returns the underlying handle to the input batch at the position `idx`.
*
* Intended only for executor and other internal APIs.
*/
template <typename Backend>
const InputType<Backend>& InputPtr(int idx) const {
return InputHandle(idx, Backend{});
}

/**
* @brief Returns the underlying handle to the output batch at the position `idx`.
*
* Intended only for executor and other internal APIs.
*/
template <typename Backend>
const OutputType<Backend>& OutputPtr(int idx) const {
return OutputHandle(idx, Backend{});
}

/** @} */ // end of InputOutputInternal

/**
* Returns shape of input at given index
* @return TensorShape<> for SampleWorkspace, TensorListShape<> for other Workspaces
Expand Down

0 comments on commit c6ccf70

Please sign in to comment.