How To Update Graphsurgeon Converter #15

Open
lowspec-1997 opened this issue Dec 9, 2019 · 2 comments

lowspec-1997 commented Dec 9, 2019

How do I update the graphsurgeon converter? I am confused about where to put the code that has been provided:

```diff
diff --git a/node_manipulation.py b/node_manipulation.py
index d2d012a..1ef30a0 100644
--- a/node_manipulation.py
+++ b/node_manipulation.py
@@ -30,6 +30,7 @@ def create_node(name, op=None, _do_suffix=False, **kwargs):
     node = NodeDef()
     node.name = name
     node.op = op if op else name
+    node.attr["dtype"].type = 1
     for key, val in kwargs.items():
         if key == "dtype":
             node.attr["dtype"].type = val.as_datatype_enu
```
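For reference when reading the diff: in a unified diff, the `diff --git`, `index`, `---`/`+++` and `@@` lines are metadata that tell a tool such as `git apply` or `patch` which file and line range to change. Lines starting with a space are unchanged context, and the single line starting with `+` is the only one to add. The hunk can be sketched as a plain text edit in Python. `add_dtype_default` is a hypothetical helper, not part of graphsurgeon; it works on the file's text only, assumes the context line `node.op = op if op else name` is present, and raises if it is not.

```python
# Hypothetical helper: applies the one-line "+" change from the hunk to the
# *text* of node_manipulation.py. It does not locate, read, or write the file.
ANCHOR = "node.op = op if op else name"
ADDED = 'node.attr["dtype"].type = 1'


def add_dtype_default(source: str) -> str:
    """Insert ADDED right after the first ANCHOR line, reusing its indentation."""
    lines = source.splitlines(keepends=True)
    for i, line in enumerate(lines):
        if line.strip() == ANCHOR:
            # Already patched: return the text unchanged so the edit is idempotent.
            if i + 1 < len(lines) and lines[i + 1].strip() == ADDED:
                return source
            # Make sure the anchor line ends with a newline before inserting after it.
            if not line.endswith("\n"):
                lines[i] = line + "\n"
            indent = line[: len(line) - len(line.lstrip())]
            lines.insert(i + 1, indent + ADDED + "\n")
            return "".join(lines)
    raise ValueError("context line not found; the file differs from the diff's context")


if __name__ == "__main__":
    original = (
        "def create_node(name, op=None, _do_suffix=False, **kwargs):\n"
        "    node = NodeDef()\n"
        "    node.name = name\n"
        "    node.op = op if op else name\n"
        "    for key, val in kwargs.items():\n"
    )
    print(add_dtype_default(original))
```

In practice `git apply` or `patch` does this step for you; the sketch only makes explicit that the header lines are never copied into the Python file.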
lowspec-1997 changed the title from "HOW TO UPDATE GRAPHSURGEON CONVERTER" to "How To Update Graphsurgeon Converter" on Dec 9, 2019.
ZouYunzhe commented Dec 26, 2019

Just add these lines directly to the create_node function.

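```python
# Assumes this lives in graphsurgeon's node_manipulation.py, with
# `import tensorflow as tf` at the top and update_node() defined in the same module.
def create_node(name, op=None, trt_plugin=False, **kwargs):
    if not trt_plugin:
        print("WARNING: To create TensorRT plugin nodes, please use the create_plugin_node function instead.")
    node = tf.NodeDef()
    node.name = name
    node.op = op if op else name
    # Default the dtype attr to 1 (DT_FLOAT) so it is never left at 0 (DT_INVALID).
    node.attr["dtype"].type = 1
    # An explicit dtype=... keyword argument overrides the default.
    for key, val in kwargs.items():
        if key == "dtype":
            node.attr["dtype"].type = val.as_datatype_enum
    return update_node(node, name, op, trt_plugin, **kwargs)
```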
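The added line matters because `dtype` is stored as a TensorFlow `DataType` enum, where `DT_INVALID` is 0 and `DT_FLOAT` is 1 (the value `tf.float32.as_datatype_enum` returns). An enum field that is never set reads as 0, which is presumably where the `Cannot convert value 0 to a TensorFlow DType` error mentioned in this thread comes from. The sketch below uses stdlib stand-ins so it runs without TensorFlow; `NodeDefStub`, `AttrValueStub`, `dtype_name` and `create_node_sketch` are all hypothetical, not TensorFlow or graphsurgeon APIs, and the sketch takes a raw enum value instead of a `tf.DType`, so `val.as_datatype_enum` is skipped.

```python
from collections import defaultdict

# Stand-ins for a few values of TensorFlow's DataType enum.
DT_INVALID = 0  # what an unset enum field reads as
DT_FLOAT = 1
DT_INT32 = 3

# Hypothetical lookup that fails with the same message as the error in this thread.
_KNOWN = {DT_FLOAT: "float32", DT_INT32: "int32"}


def dtype_name(enum_value: int) -> str:
    if enum_value not in _KNOWN:
        raise TypeError(f"Cannot convert value {enum_value} to a TensorFlow DType.")
    return _KNOWN[enum_value]


class AttrValueStub:
    """Hypothetical stand-in for an AttrValue message: `type` starts at 0."""

    def __init__(self) -> None:
        self.type = DT_INVALID


class NodeDefStub:
    """Hypothetical stand-in for tf.NodeDef; attr entries are created on first access."""

    def __init__(self) -> None:
        self.name = ""
        self.op = ""
        self.attr = defaultdict(AttrValueStub)


def create_node_sketch(name, op=None, dtype_enum=None):
    node = NodeDefStub()
    node.name = name
    node.op = op if op else name
    node.attr["dtype"].type = DT_FLOAT  # the added default (1 == DT_FLOAT)
    if dtype_enum is not None:  # an explicit dtype still wins
        node.attr["dtype"].type = dtype_enum
    return node


if __name__ == "__main__":
    print(dtype_name(create_node_sketch("input").attr["dtype"].type))  # float32
    print(dtype_name(create_node_sketch("ids", dtype_enum=DT_INT32).attr["dtype"].type))  # int32
    try:
        dtype_name(NodeDefStub().attr["dtype"].type)  # dtype never set
    except TypeError as err:
        print(err)  # Cannot convert value 0 to a TensorFlow DType.
```

The design choice mirrors the patch: set a safe default first, then let an explicit `dtype` keyword override it.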

omisha-bajoria commented

Hello @ZouYunzhe

> Just add these lines directly to the create_node function.


When we update the node_manipulation.py file, where do we add these lines?

```diff
diff --git a/node_manipulation.py b/node_manipulation.py
index d2d012a..1ef30a0 100644
--- a/node_manipulation.py
+++ b/node_manipulation.py
```

We added the rest to the create_node function. However, adding the lines above at the top of the file causes errors when importing graphsurgeon in our notebook, and adding only the lines inside the create_node function results in `TypeError: Cannot convert value 0 to a TensorFlow DType`.
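On the import errors: the `diff --git` and `index` header lines are patch metadata, not Python, so pasting them into node_manipulation.py makes the whole module fail to compile, and `import graphsurgeon` fails with it. A minimal stdlib-only check of that, assuming the header text is exactly as quoted in this thread (`compiles_as_python` is a hypothetical helper; the text is only compiled, never executed):

```python
# The patch header as quoted in this thread, as it would look if pasted into the file.
DIFF_HEADER = """\
diff --git a/node_manipulation.py b/node_manipulation.py
index d2d012a..1ef30a0 100644
--- a/node_manipulation.py
+++ b/node_manipulation.py
"""


def compiles_as_python(source: str) -> bool:
    """Return True if `source` is syntactically valid Python (nothing is run)."""
    try:
        compile(source, "node_manipulation.py", "exec")
    except SyntaxError:
        return False
    return True


if __name__ == "__main__":
    print(compiles_as_python(DIFF_HEADER))  # False
    print(compiles_as_python('node.attr["dtype"].type = 1\n'))  # True
```

Only the `+` line belongs in the file; the header lines are consumed by `git apply` or `patch` and should not be pasted anywhere.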

It would be a huge help if you could share what your node_manipulation.py file looked like after you made the changes.
