Skip to content

Commit

Permalink
remove other mentions of cumbinop -> scanop
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Zhao Luo authored and Andrew Zhao Luo committed Mar 24, 2021
1 parent 23d4325 commit 4137c15
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,11 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
}; // struct MatrixSetDiagAttrs

/*! \brief Attributes used in cumsum and cumprod operator */
struct CumbinopAttrs : public tvm::AttrsNode<CumbinopAttrs> {
struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
Integer axis;
DataType dtype;
Bool exclusive = Bool(false);
TVM_DECLARE_ATTRS(CumbinopAttrs, "relay.attrs.CumbinopAttrs") {
TVM_DECLARE_ATTRS(ScanopAttrs, "relay.attrs.ScanopAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to operate over").set_default(NullValue<Integer>());
TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target):
"""cumsum cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumbinop(topi.cuda.cumsum),
wrap_compute_scanop(topi.cuda.cumsum),
wrap_topi_schedule(topi.cuda.schedule_scan),
name="cumsum.cuda",
)
Expand All @@ -1030,7 +1030,7 @@ def cumprod_strategy_cuda(attrs, inputs, out_type, target):
"""cumprod cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumbinop(topi.cuda.cumprod),
wrap_compute_scanop(topi.cuda.cumprod),
wrap_topi_schedule(topi.cuda.schedule_scan),
name="cumprod.cuda",
)
Expand Down
16 changes: 8 additions & 8 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3772,20 +3772,20 @@ RELAY_REGISTER_OP("adv_index")
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);

TVM_REGISTER_NODE_TYPE(CumbinopAttrs);
TVM_REGISTER_NODE_TYPE(ScanopAttrs);

bool CumbinopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
bool ScanopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, output]
ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
ICHECK(types[0].as<IncompleteTypeNode>())
<< "cumbinop: expect input type to be TensorType but get " << types[0];
<< "Scanop: expect input type to be TensorType but get " << types[0];
return false;
}

const auto* param = attrs.as<CumbinopAttrs>();
const auto* param = attrs.as<ScanopAttrs>();

auto dtype = param->dtype;
if (dtype.is_void()) {
Expand All @@ -3806,7 +3806,7 @@ bool CumbinopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}

Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) {
auto attrs = make_object<CumbinopAttrs>();
auto attrs = make_object<ScanopAttrs>();
attrs->dtype = dtype;
attrs->axis = axis;
attrs->exclusive = exclusive;
Expand All @@ -3822,11 +3822,11 @@ RELAY_REGISTER_OP("cumsum")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Cumsum", CumbinopRel)
.add_type_rel("Cumsum", ScanopRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque);

Expr MakeCumprod(Expr data, Integer axis, DataType dtype, Bool exclusive) {
auto attrs = make_object<CumbinopAttrs>();
auto attrs = make_object<ScanopAttrs>();
attrs->dtype = dtype;
attrs->axis = axis;
attrs->exclusive = exclusive;
Expand All @@ -3842,7 +3842,7 @@ RELAY_REGISTER_OP("cumprod")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Cumprod", CumbinopRel)
.add_type_rel("Cumprod", ScanopRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque);

TVM_REGISTER_NODE_TYPE(UniqueAttrs);
Expand Down

0 comments on commit 4137c15

Please sign in to comment.