Skip to content

Commit

Permalink
Fix and removing unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
Raymond Yang committed Feb 26, 2019
1 parent 9a69250 commit 18241bc
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 73 deletions.
119 changes: 53 additions & 66 deletions onnx/checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,59 @@ void check_attribute(
}
}

void check_node(
const NodeProto& node,
const CheckerContext& ctx,
const LexicalScopeContext& lex_ctx) {
enforce_non_empty_field(node, op_type);

if (node.input().empty() && node.output().empty()) {
fail_check(
"NodeProto (name: ",
node.name(),
", type: ",
node.op_type(),
") has zero input and zero output.");
}

// Put the removed experimental ops here
if (node.op_type() == "ConstantFill") {
std::cerr << "Warning: " << node.op_type() << " was a removed "
<< " experimental ops. In the future, we may directly "
<< "reject this operator. Please update your model as soon "
<< "as possible.";
return;
}

// Resolve domain for node
const auto& opset_imports = ctx.get_opset_imports();
auto dit = opset_imports.find(node.domain());
if (dit == opset_imports.end()) {
fail_check("No opset import for domain '" + node.domain() + "'");
}
auto domain_version = dit->second;

for (const auto& attr : node.attribute()) {
check_attribute(attr, ctx, lex_ctx);
}

const auto* schema = ctx.get_schema_registry()->GetSchema(
node.op_type(), domain_version, node.domain());
if (!schema) {
fail_check(
"No Op registered for " + node.op_type() +
" with domain_version of " +
ONNX_NAMESPACE::to_string(domain_version));
} else if (schema->Deprecated()) {
fail_check(
"Op registered for " + node.op_type() + " is depracted in domain_version of " +
ONNX_NAMESPACE::to_string(domain_version));
} else {
schema->Verify(node);
}
}


void check_function(
const FunctionProto& function,
const CheckerContext& ctx,
Expand Down Expand Up @@ -379,58 +432,6 @@ void check_function(
}
}

void check_node(
const NodeProto& node,
const CheckerContext& ctx,
const LexicalScopeContext& lex_ctx) {
enforce_non_empty_field(node, op_type);

if (node.input().empty() && node.output().empty()) {
fail_check(
"NodeProto (name: ",
node.name(),
", type: ",
node.op_type(),
") has zero input and zero output.");
}

// Put the removed experimental ops here
if (node.op_type() == "ConstantFill") {
std::cerr << "Warning: " << node.op_type() << " was a removed "
<< " experimental ops. In the future, we may directly "
<< "reject this operator. Please update your model as soon "
<< "as possible.";
return;
}

// Resolve domain for node
const auto& opset_imports = ctx.get_opset_imports();
auto dit = opset_imports.find(node.domain());
if (dit == opset_imports.end()) {
fail_check("No opset import for domain '" + node.domain() + "'");
}
auto domain_version = dit->second;

for (const auto& attr : node.attribute()) {
check_attribute(attr, ctx, lex_ctx);
}

const auto* schema = ctx.get_schema_registry()->GetSchema(
node.op_type(), domain_version, node.domain());
if (!schema) {
fail_check(
"No Op registered for " + node.op_type() +
" with domain_version of " +
ONNX_NAMESPACE::to_string(domain_version));
} else if (schema->Deprecated()) {
fail_check(
"Op registered for " + node.op_type() + " is depracted in domain_version of " +
ONNX_NAMESPACE::to_string(domain_version));
} else {
schema->Verify(node);
}
}

void check_graph(
const GraphProto& graph,
const CheckerContext& ctx,
Expand Down Expand Up @@ -587,20 +588,6 @@ void check_model(const ModelProto& model) {
check_model(model, ctx);
}

void VerifyFunctionNode(
const NodeProto& node,
const FunctionProto& func,
const CheckerContext& ctx,
const LexicalScopeContext& lex_ctx) {
// Create a temporary graphproto to hold the expanded subgraph
GraphProto g;
g.set_name("func_" + func.name() + "_expanded_subgraph");
// To Generate unique internal tensor names
// while preserving node's input/output names
FunctionExpandHelper(node, func, g);
check_graph(g, ctx, lex_ctx);
}

#undef fail_check
#undef enforce_has_field
#undef enforce_has_repeated_field
Expand Down
5 changes: 0 additions & 5 deletions onnx/checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,6 @@ void check_function(

void check_model(const ModelProto& model);
void check_model(const std::string& model_path);
void VerifyFunctionNode(
const NodeProto& node,
const FunctionProto& func,
const CheckerContext& ctx,
const LexicalScopeContext& lex_ctx);

} // namespace checker
} // namespace ONNX_NAMESPACE
4 changes: 2 additions & 2 deletions onnx/test/cpp/function_verify_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ void VerifyTypeConstraint(
const FunctionProto* function_proto) {
// TC for function nodes should satisfy the definition defined in the opschema
// This is designed to be a best-effort test
// TODO: Revisit to have a more consummate check on it
TENSOR_TYPES_MAP tc_map;
std::set<std::string> primitive_types(
OpSchema::all_numeric_types().begin(),
OpSchema::all_numeric_types().end());
OpSchema::all_tensor_types().begin(), OpSchema::all_tensor_types().end());
for (const auto& input : function_op.inputs()) {
std::string name = input.GetName();
for (const auto& t : input.GetTypes()) {
Expand Down

0 comments on commit 18241bc

Please sign in to comment.