-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
generic: deconvolution: move implementation from ocl folder
- Loading branch information
Showing
3 changed files
with
386 additions
and
293 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,381 @@ | ||
/******************************************************************************* | ||
* Copyright 2024 Intel Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*******************************************************************************/ | ||
|
||
#ifndef GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP | ||
#define GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP | ||
|
||
#include "common/c_types_map.hpp" | ||
#include "common/primitive.hpp" | ||
#include "common/type_helpers.hpp" | ||
#include "common/utils.hpp" | ||
#include "gpu/gpu_deconvolution_pd.hpp" | ||
#include "gpu/intel/compute/utils.hpp" | ||
#include "gpu/intel/gpu_primitive.hpp" | ||
#include "gpu/intel/primitive_conf.hpp" | ||
|
||
namespace dnnl { | ||
namespace impl { | ||
namespace gpu { | ||
namespace generic { | ||
|
||
static status_t weights_axes_permutation( | ||
memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) { | ||
int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation | ||
for (int d = 0; d < DNNL_MAX_NDIMS; ++d) | ||
perm[d] = d; | ||
nstl::swap(perm[0 + with_groups], perm[1 + with_groups]); | ||
|
||
return memory_desc_permute_axes(*o_md, *i_md, perm); | ||
} | ||
|
||
static status_t conv_descr_create( | ||
const deconvolution_desc_t *dd, convolution_desc_t *cd) { | ||
using namespace prop_kind; | ||
alg_kind_t alg_kind = alg_kind::convolution_direct; | ||
|
||
const memory_desc_t *src_md, *dst_md, *d_weights_d; | ||
prop_kind_t prop_kind; | ||
|
||
switch (dd->prop_kind) { | ||
case forward: | ||
case forward_inference: | ||
prop_kind = backward_data; | ||
src_md = &dd->dst_desc; | ||
dst_md = &dd->src_desc; | ||
d_weights_d = &dd->weights_desc; | ||
break; | ||
case backward_data: | ||
prop_kind = forward_training; | ||
src_md = &dd->diff_dst_desc; | ||
dst_md = &dd->diff_src_desc; | ||
d_weights_d = &dd->weights_desc; | ||
break; | ||
case backward_weights: | ||
prop_kind = dd->prop_kind; | ||
src_md = &dd->diff_dst_desc; | ||
dst_md = &dd->src_desc; | ||
d_weights_d = &dd->diff_weights_desc; | ||
break; | ||
default: assert(!"unknown prop kind"); return status::invalid_arguments; | ||
} | ||
|
||
// Create weights desc for convolution | ||
memory_desc_t c_weights_d; | ||
const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; | ||
CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups)); | ||
|
||
return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, | ||
prop_kind != backward_weights ? &dd->bias_desc : nullptr, dst_md, | ||
dd->strides, dd->dilates, dd->padding[0], dd->padding[1]); | ||
} | ||
|
||
struct convolution_deconvolution_fwd_t : public gpu::primitive_t { | ||
using gpu::primitive_t::primitive_t; | ||
struct pd_t : public gpu_deconvolution_fwd_pd_t { | ||
pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, | ||
const deconvolution_fwd_pd_t *hint_fwd_pd) | ||
: gpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {} | ||
|
||
pd_t(const pd_t &other) = default; | ||
|
||
~pd_t() = default; | ||
|
||
DECLARE_COMMON_PD_T(name_.c_str(), convolution_deconvolution_fwd_t); | ||
status_t init_convolution(impl::engine_t *engine) { | ||
convolution_desc_t cd; | ||
CHECK(conv_descr_create(desc(), &cd)); | ||
primitive_attr_t conv_attr(*attr()); | ||
if (!conv_attr.is_initialized()) return status::out_of_memory; | ||
primitive_desc_iterator_t it( | ||
engine, (op_desc_t *)&cd, &conv_attr, nullptr); | ||
if (!it.is_initialized()) return status::out_of_memory; | ||
conv_pd_ = *(++it); | ||
|
||
return (conv_pd_) ? status::success : status::unimplemented; | ||
} | ||
|
||
status_t init(impl::engine_t *engine) { | ||
using namespace format_tag; | ||
using sm = primitive_attr_t::skip_mask_t; | ||
|
||
const auto attr_skip_mask = sm::post_ops | sm::zero_points_runtime | ||
| sm::scales_runtime; | ||
|
||
VDISPATCH_DECONVOLUTION(is_fwd(), VERBOSE_BAD_PROPKIND); | ||
VDISPATCH_DECONVOLUTION( | ||
desc()->alg_kind == alg_kind::deconvolution_direct, | ||
VERBOSE_BAD_ALGORITHM); | ||
VDISPATCH_DECONVOLUTION(attr()->has_default_values(attr_skip_mask), | ||
VERBOSE_UNSUPPORTED_ATTR); | ||
VDISPATCH_DECONVOLUTION( | ||
intel::post_ops_with_binary_ok( | ||
attr(), desc()->dst_desc.data_type, ndims()), | ||
VERBOSE_UNSUPPORTED_POSTOP); | ||
VDISPATCH_DECONVOLUTION( | ||
(utils::everyone_is(data_type::f32, | ||
desc()->src_desc.data_type, | ||
desc()->weights_desc.data_type, | ||
desc()->dst_desc.data_type) | ||
|| (utils::everyone_is(data_type::f64, | ||
desc()->src_desc.data_type, | ||
desc()->weights_desc.data_type, | ||
desc()->dst_desc.data_type)) | ||
|| ((utils::everyone_is(data_type::f16, | ||
desc()->src_desc.data_type, | ||
desc()->weights_desc.data_type) | ||
|| utils::everyone_is(data_type::f32, | ||
desc()->src_desc.data_type, | ||
desc()->weights_desc.data_type) | ||
|| utils::everyone_is(data_type::bf16, | ||
desc()->src_desc.data_type, | ||
desc()->weights_desc.data_type)) | ||
&& utils::one_of(desc()->dst_desc.data_type, | ||
data_type::f16, data_type::u8, | ||
data_type::s8)) | ||
|| (utils::everyone_is(data_type::bf16, | ||
desc()->src_desc.data_type, | ||
desc()->weights_desc.data_type) | ||
&& utils::one_of(desc()->dst_desc.data_type, | ||
data_type::f32, data_type::bf16)) | ||
|| (utils::everyone_is(data_type::f16, | ||
desc()->src_desc.data_type, | ||
desc()->weights_desc.data_type) | ||
&& utils::one_of(desc()->dst_desc.data_type, | ||
data_type::f32, data_type::f16)) | ||
|| (desc()->weights_desc.data_type == data_type::s8 | ||
&& utils::one_of(desc()->src_desc.data_type, | ||
data_type::u8, data_type::s8) | ||
&& desc()->dst_desc.data_type | ||
!= data_type::f64)), | ||
VERBOSE_UNSUPPORTED_DT); | ||
|
||
VDISPATCH_DECONVOLUTION_SC( | ||
init_convolution(engine), "init_convolution()"); | ||
if (weights_md_.format_kind == format_kind::any) { | ||
VDISPATCH_DECONVOLUTION_SC( | ||
weights_axes_permutation(&weights_md_, | ||
conv_pd_->weights_md(), with_groups()), | ||
"weights_axes_permutation()"); | ||
} | ||
if (src_md_.format_kind == format_kind::any) | ||
src_md_ = *conv_pd_->diff_dst_md(); | ||
if (dst_md_.format_kind == format_kind::any) | ||
dst_md_ = *conv_pd_->diff_src_md(); | ||
if (bias_md_.format_kind == format_kind::any) { | ||
VDISPATCH_DECONVOLUTION_SC(memory_desc_init_by_tag(bias_md_, x), | ||
VERBOSE_UNSUPPORTED_TAG); | ||
} | ||
init_name(); | ||
init_scratchpad(); | ||
VDISPATCH_DECONVOLUTION_SC(attr_.set_default_formats(dst_md(0)), | ||
VERBOSE_UNSUPPORTED_ATTR); | ||
|
||
return status::success; | ||
} | ||
|
||
std::shared_ptr<primitive_desc_t> conv_pd_; | ||
|
||
private: | ||
std::string name_ = "conv:any"; | ||
|
||
void init_name() { | ||
name_.append("+"); | ||
name_.append(conv_pd_->name()); | ||
} | ||
|
||
void init_scratchpad() { | ||
auto scratchpad = scratchpad_registry().registrar(); | ||
scratchpad.book(memory_tracking::names::key_nested, | ||
conv_pd_->scratchpad_registry()); | ||
} | ||
}; | ||
|
||
status_t init(impl::engine_t *engine) override { | ||
return create_nested_primitive(conv_p_, pd()->conv_pd_, engine); | ||
} | ||
|
||
status_t execute(const exec_ctx_t &ctx) const override { | ||
using namespace memory_tracking::names; | ||
const auto &args = ctx.args(); | ||
exec_args_t conv_args; | ||
conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC); | ||
conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); | ||
conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST); | ||
if (pd()->with_bias()) | ||
conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS); | ||
|
||
for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) { | ||
if (pd()->attr()->post_ops_.entry_[idx].is_binary()) { | ||
conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1] | ||
= args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | ||
| DNNL_ARG_SRC_1); | ||
} else if (pd()->attr()->post_ops_.entry_[idx].is_prelu()) { | ||
conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | ||
| DNNL_ARG_WEIGHTS] | ||
= args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | ||
| DNNL_ARG_WEIGHTS); | ||
} | ||
} | ||
const auto z_src = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC; | ||
const auto z_dst = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST; | ||
if (args.find(z_src) != args.end()) conv_args[z_src] = args.at(z_src); | ||
if (args.find(z_dst) != args.end()) conv_args[z_dst] = args.at(z_dst); | ||
|
||
for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { | ||
int key = DNNL_ARG_ATTR_SCALES | arg; | ||
if (args.find(key) != args.end()) conv_args[key] = args.at(key); | ||
} | ||
|
||
exec_ctx_t conv_ctx(ctx, std::move(conv_args)); | ||
|
||
nested_scratchpad_t ns(ctx, key_nested, conv_p_); | ||
conv_ctx.set_scratchpad_grantor(ns.grantor()); | ||
// Executing the convolution kernel | ||
return conv_p_->execute(conv_ctx); | ||
} | ||
|
||
private: | ||
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } | ||
std::shared_ptr<impl::primitive_t> conv_p_; | ||
}; | ||
|
||
struct convolution_deconvolution_bwd_data_t : public gpu::primitive_t { | ||
using gpu::primitive_t::primitive_t; | ||
struct pd_t : public gpu_deconvolution_bwd_data_pd_t { | ||
pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, | ||
const deconvolution_fwd_pd_t *hint_fwd_pd) | ||
: gpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) | ||
, conv_pd_(nullptr) {} | ||
|
||
pd_t(const pd_t &other) = default; | ||
|
||
~pd_t() = default; | ||
|
||
DECLARE_COMMON_PD_T( | ||
name_.c_str(), convolution_deconvolution_bwd_data_t); | ||
|
||
status_t init_convolution(impl::engine_t *engine) { | ||
convolution_desc_t cd; | ||
CHECK(conv_descr_create(desc(), &cd)); | ||
primitive_attr_t conv_attr(*attr()); | ||
if (!conv_attr.is_initialized()) return status::out_of_memory; | ||
primitive_desc_iterator_t it( | ||
engine, (op_desc_t *)&cd, &conv_attr, nullptr); | ||
if (!it.is_initialized()) return status::out_of_memory; | ||
conv_pd_ = *(++it); | ||
return (conv_pd_) ? status::success : status::unimplemented; | ||
} | ||
|
||
status_t init(impl::engine_t *engine) { | ||
VDISPATCH_DECONVOLUTION( | ||
desc()->prop_kind == prop_kind::backward_data, | ||
VERBOSE_BAD_PROPKIND); | ||
|
||
VDISPATCH_DECONVOLUTION( | ||
(utils::everyone_is(data_type::f32, | ||
desc()->diff_src_desc.data_type, | ||
desc()->weights_desc.data_type, | ||
desc()->diff_dst_desc.data_type) | ||
|| (utils::everyone_is(data_type::f64, | ||
desc()->diff_src_desc.data_type, | ||
desc()->weights_desc.data_type, | ||
desc()->diff_dst_desc.data_type)) | ||
|| utils::everyone_is(data_type::f16, | ||
desc()->weights_desc.data_type, | ||
desc()->diff_dst_desc.data_type) | ||
|| utils::everyone_is(data_type::bf16, | ||
desc()->weights_desc.data_type, | ||
desc()->diff_dst_desc.data_type)), | ||
VERBOSE_UNSUPPORTED_DT); | ||
|
||
VDISPATCH_DECONVOLUTION( | ||
utils::one_of(desc()->diff_src_desc.data_type, | ||
data_type::bf16, data_type::f16, data_type::f32, | ||
data_type::f64), | ||
VERBOSE_UNSUPPORTED_DT); | ||
VDISPATCH_DECONVOLUTION( | ||
desc()->alg_kind == alg_kind::deconvolution_direct, | ||
VERBOSE_BAD_ALGORITHM); | ||
VDISPATCH_DECONVOLUTION( | ||
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); | ||
|
||
VDISPATCH_DECONVOLUTION_SC( | ||
init_convolution(engine), "init_convolution()"); | ||
if (weights_md_.format_kind == format_kind::any) | ||
VDISPATCH_DECONVOLUTION_SC( | ||
weights_axes_permutation(&weights_md_, | ||
conv_pd_->weights_md(), with_groups()), | ||
"weights_axes_permutation()"); | ||
if (diff_src_md_.format_kind == format_kind::any) | ||
diff_src_md_ = *conv_pd_->dst_md(); | ||
if (diff_dst_md_.format_kind == format_kind::any) | ||
diff_dst_md_ = *conv_pd_->src_md(); | ||
|
||
init_name(); | ||
init_scratchpad(); | ||
|
||
return status::success; | ||
} | ||
|
||
std::shared_ptr<primitive_desc_t> conv_pd_; | ||
|
||
private: | ||
std::string name_ = "conv:any"; | ||
|
||
void init_name() { | ||
name_.append("+"); | ||
name_.append(conv_pd_->name()); | ||
} | ||
|
||
void init_scratchpad() { | ||
auto scratchpad = scratchpad_registry().registrar(); | ||
scratchpad.book(memory_tracking::names::key_nested, | ||
conv_pd_->scratchpad_registry()); | ||
} | ||
}; | ||
|
||
status_t init(impl::engine_t *engine) override { | ||
return create_nested_primitive(conv_p_, pd()->conv_pd_, engine); | ||
} | ||
|
||
status_t execute(const exec_ctx_t &ctx) const override { | ||
using namespace memory_tracking::names; | ||
const auto &args = ctx.args(); | ||
exec_args_t conv_args; | ||
conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST); | ||
conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); | ||
conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC); | ||
if (!types::is_zero_md(pd()->scratchpad_md())) | ||
conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD); | ||
exec_ctx_t conv_ctx(ctx, std::move(conv_args)); | ||
|
||
nested_scratchpad_t ns(ctx, key_nested, conv_p_); | ||
conv_ctx.set_scratchpad_grantor(ns.grantor()); | ||
// Executing the convolution kernel | ||
return conv_p_->execute(conv_ctx); | ||
} | ||
|
||
private: | ||
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } | ||
std::shared_ptr<impl::primitive_t> conv_p_; | ||
}; | ||
|
||
} // namespace generic | ||
} // namespace gpu | ||
} // namespace impl | ||
} // namespace dnnl | ||
|
||
#endif |
Oops, something went wrong.