Skip to content

Commit

Permalink
properly handle variance op, which has two inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Sep 25, 2021
1 parent 3fc6ed1 commit 194c612
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,26 +175,38 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,

std::string new_layout_string;
Array<Integer> new_r_axes;
Array<Layout> new_input_layouts;

auto check_num_input_layouts = [](Array<Layout> in_layouts) {
// The second case is for variance op
ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
};

if (new_in_layouts.defined() && r_axes.size()) {
// Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
// modified layout axes.
ICHECK_EQ(new_in_layouts.size(), 1);
ICHECK_EQ(old_in_layouts.size(), 1);
check_num_input_layouts(new_in_layouts);
check_num_input_layouts(old_in_layouts);

// Get inferred_in and inferred_out from new_in_layout.
std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
params->axis = new_r_axes;
} else if (old_in_layouts.defined()) {
ICHECK_EQ(old_in_layouts.size(), 1);
check_num_input_layouts(old_in_layouts);

// If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
if (old_in_layouts[0].defined()) {
std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
}
}

return InferCorrectLayoutOutput({inferred_in}, {inferred_out}, Attrs(params));
new_input_layouts.push_back(inferred_in);

if (old_in_layouts.size() == 2) {
new_input_layouts.push_back(inferred_in);
}

return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
}

template <typename F>
Expand Down Expand Up @@ -686,7 +698,7 @@ RELAY_REGISTER_OP("variance")
.add_argument("mean", "Tensor", "The mean tensor.")
.add_type_rel("Variance", VarianceRel)
.set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<VarianceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

} // namespace relay
Expand Down

0 comments on commit 194c612

Please sign in to comment.