forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
python_ir.cpp
1151 lines (1116 loc) · 40.4 KB
/
python_ir.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <torch/csrc/jit/python/python_ir.h>
#include <ATen/core/jit_type.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/api/include/torch/python.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/python/pybind.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_strings.h>
#include <iostream>
#include <sstream>
#include <utility>
namespace torch::jit {
// Process-wide default for whether Graph.__repr__ (the Python binding) prints
// source ranges. Python can flip it via Graph.set_global_print_source_ranges.
bool global_print_source_ranges{true};
// Node kind used for every ConcretePythonOp.
Symbol ConcretePythonOp::Kind = prim::PythonOp;
using c10::Type;
std::string getPythonName(const PyObject* obj_) {
pybind11::gil_scoped_acquire gil;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
PyObject* obj = const_cast<PyObject*>(obj_);
auto v = py::getattr(obj, "__name__", py::str("<python_value>"));
// if this was a autograd.Function recover the name of the class
return py::str(v);
}
std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
pybind11::gil_scoped_acquire gil;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
if (py::isinstance<py::tuple>(pyobj)) {
// This special-case for printing tuples handles a problem where
// str((2L, 3L)) outputs "(2L, 3L)" in Python 2 but "(2, 3)"
// in Python 3. In order to suppress the L-suffix, we must
// manually print the string ourselves, calling str() on the
// sub-elements.
//
// This is a fairly fragile fix (What if you have nested tuples
// in tuples? What if you have dictionaries?) but it seems to hit
// the cases that are triggered in practice in onnx-pytorch. Revisit
// this code if this is not the case.
//
// By the way, one non-solution for this problem is to monkeypatch
// tuple.__str__; this doesn't work because Python doesn't allow
// monkeypatching methods of built-in types.
auto pytuple = pyobj.cast<py::tuple>();
out << "(";
size_t i = 0;
for (const auto& o : pytuple) {
if (i > 0) {
out << ", ";
}
THPObjectPtr str(py::str(o).release().ptr());
out << THPUtils_unpackString(str.get());
i++;
}
if (i == 1) {
out << ",";
}
out << ")";
return out;
} else {
return out << THPUtils_unpackString(py::str(pyobj).ptr());
}
}
// Depth-first, pre-order search for the first node whose kind equals `kind`.
// Blocks are visited in order; within a block, each node is tested before
// (when `recurse` is true) descending into that node's own sub-blocks.
// Returns nullptr when no node matches.
Node* findNode(
    c10::ArrayRef<torch::jit::Block*> blocks,
    Symbol kind,
    bool recurse = true) {
  for (Block* block : blocks) {
    for (Node* node : block->nodes()) {
      if (node->kind() == kind) {
        return node;
      }
      if (!recurse) {
        continue;
      }
      if (Node* nested = findNode(node->blocks(), kind, recurse)) {
        return nested;
      }
    }
  }
  return nullptr;
}
// Single-block convenience overload of findNode: returns the first node of
// kind `kind` in `block` (and, when `recurse` is true, in nested blocks), or
// nullptr if none exists.
Node* findNode(Block* block, Symbol kind, bool recurse = true) {
  // View the one block as a one-element ArrayRef rather than building a
  // temporary std::vector, which avoids a heap allocation per call. `block` is
  // a parameter of this function, so the viewed storage outlives the call.
  // The explicit ArrayRef is required: passing `block` directly would select
  // this same overload again.
  return findNode(c10::ArrayRef<Block*>(block), kind, recurse);
}
std::string ConcretePythonOp::name() const {
pybind11::gil_scoped_acquire gil;
if (auto autograd = autogradFunction()) {
return getPythonName(autograd->get());
} else {
return getPythonName(pyobj.get());
}
}
// Copies the generic node state plus the PythonOp-specific fields (calling
// convention, callable, scalar arguments) from `other_`, which must be a
// ConcretePythonOp. The Python objects are shared, not copied: each one gets
// a new reference (Py_INCREF) that the THPObjectPtr in this node then owns.
void ConcretePythonOp::cloneFrom(Node* other_) {
  // NOLINTNEXTLINE(bugprone-parent-virtual-call)
  Node::cloneFrom(other_);
  auto* source = other_->cast<ConcretePythonOp>();
  cconv = source->cconv;

  PyObject* callable = source->pyobj.get();
  Py_INCREF(callable);
  pyobj = THPObjectPtr(callable);

  // Non-const reference so that get() yields a mutable PyObject* for INCREF.
  for (auto& arg : source->scalar_args) {
    PyObject* raw = arg.get();
    Py_INCREF(raw);
    scalar_args.emplace_back(raw);
  }
}
// Recovers the autograd.Function this PythonOp came from, when its callable
// was originally `SomeFunction.apply`. Used by ONNX export to discover
// symbolics.
//
// Concretely: if the callable has a non-None `__self__` whose `apply`
// attribute compares equal to the callable, returns a new owning reference to
// that `__self__`; otherwise returns nullopt.
// Throws py::error_already_set if the Python comparison raises.
c10::optional<THPObjectPtr> ConcretePythonOp::autogradFunction() const {
  pybind11::gil_scoped_acquire gil;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  py::handle callable = const_cast<PyObject*>(pyobj.get());

  // For a bound method, __self__ is the object it is bound to.
  auto owner = py::getattr(callable, "__self__", py::none());
  if (owner.is_none()) {
    return c10::nullopt;
  }

  auto owner_apply = py::getattr(owner, "apply", py::none());
  if (owner_apply.is_none()) {
    return c10::nullopt;
  }

  const int differs =
      PyObject_RichCompareBool(owner_apply.ptr(), callable.ptr(), Py_NE);
  if (PyErr_Occurred()) {
    throw py::error_already_set();
  }
  if (differs) {
    return c10::nullopt;
  }

  // Hand ownership of the reference held by `owner` to the THPObjectPtr.
  return THPObjectPtr(owner.release().ptr());
}
// Writes the scalar arguments as a parenthesized, ", "-separated list, each
// element rendered by printPyObject (e.g. "(1, (2, 3))").
void ConcretePythonOp::writeScalars(std::ostream& out) const {
  out << "(";
  bool first = true;
  for (const auto& scalar : scalar_args) {
    if (!first) {
      out << ", ";
    }
    first = false;
    printPyObject(out, scalar);
  }
  out << ")";
}
void ConcretePythonOp::lint_python() const {
size_t n_scalars = 0, n_tensors = 0;
for (auto c : cconv) {
if (c == 'c') {
n_scalars++;
} else if (c == 'd') {
n_tensors++;
} else {
AT_ASSERT(0);
}
AT_ASSERT(static_cast<bool>(pyobj));
}
AT_ASSERT(n_scalars == scalar_args.size());
AT_ASSERT(n_tensors == inputs().size());
}
// Creates (but does not insert) a PythonOp node in this graph, taking
// ownership of the callable and the scalar arguments.
// NOTE(review): the raw `new` follows the usual Node-creation pattern. The
// node is constructed against `this`; presumably the Graph then manages its
// lifetime. Confirm in ir.cpp.
Node* Graph::createPythonOp(
    THPObjectPtr&& pyobj,
    const std::string& cconv,
    pyobj_list&& scalar_args) {
  return (new ConcretePythonOp(this))
      ->init(std::move(pyobj), cconv, std::move(scalar_args));
}
void initPythonIRBindings(PyObject* module_) {
auto m = py::handle(module_).cast<py::module>();
py::class_<AliasDb, std::shared_ptr<AliasDb>>(m, "AliasDb")
.def("dump", &AliasDb::dump)
.def("to_graphviz_str", &AliasDb::toGraphviz)
.def(
"may_contain_alias",
[&](AliasDb& db, Value* v1, Value* v2) {
return db.mayContainAlias(v1, v2);
})
.def(
"has_writers",
[&](AliasDb& db, Value* v1) { return db.hasWriters(v1); })
.def("__str__", &AliasDb::toString)
.def(
"move_after_topologically_valid",
[](AliasDb& db, Node* n, Node* movePoint) {
return db.moveAfterTopologicallyValid(n, movePoint);
})
.def(
"move_before_topologically_valid",
[](AliasDb& db, Node* n, Node* movePoint) {
return db.moveBeforeTopologicallyValid(n, movePoint);
});
#define GS(name) def(#name, &Graph ::name)
py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
.def(py::init<>())
.def(
"__repr__",
[&](Graph& g) { return g.toString(global_print_source_ranges); })
.def("str", &Graph::toString, py::arg("print_source_ranges") = true)
.def_readonly_static(
"global_print_source_ranges", &global_print_source_ranges)
.def_static(
"set_global_print_source_ranges",
[&](const bool enabled) { global_print_source_ranges = enabled; },
py::arg("enabled") = true)
.def(
"alias_db",
[](std::shared_ptr<Graph> g,
bool isFrozen = false,
bool descend_function_calls = false) {
return std::make_shared<AliasDb>(
std::move(g), isFrozen, descend_function_calls);
},
py::arg("isFrozen") = false,
py::arg("descend_function_calls") = false)
.def(
"dump_alias_db",
[](std::shared_ptr<Graph> g) {
AliasDb db(std::move(g));
db.dump();
})
.def(
"_export_onnx",
[](const std::shared_ptr<Graph>& g,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
const std::unordered_map<
std::string,
std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type,
bool strip_doc_string,
bool keep_initializers_as_inputs,
const std::map<std::string, int>& custom_opsets,
bool add_node_names,
const std::string& onnx_file_path,
const NodeAttrNameMap& node_attr_to_name) {
std::string graph;
std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
RawDataExportMap export_map;
SymbolDimMap symbol_map;
bool val_use_external_data_format = false;
NodeNameMap onnx_node_names;
std::tie(
model_proto,
export_map,
symbol_map,
val_use_external_data_format,
onnx_node_names) =
export_onnx(
g,
initializers,
onnx_opset_version,
dynamic_axes,
defer_weight_export,
operator_export_type,
strip_doc_string,
keep_initializers_as_inputs,
custom_opsets,
add_node_names,
val_use_external_data_format,
onnx_file_path,
node_attr_to_name);
std::unordered_map<std::string, py::bytes>
python_serialized_export_map;
for (auto& kv : export_map) {
auto t = kv.second;
size_t copy_bytes = t.element_size() * t.numel();
// TODO: this is an unnecessary copy. In theory we can directly
// return the map from identifier to Tensor, but we need some API
// in Python to get raw `bytes` containing the raw tensor data.
python_serialized_export_map[kv.first] =
py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
}
graph = serialize_model_proto_to_string(model_proto);
return std::make_tuple(
py::bytes(graph),
python_serialized_export_map,
val_use_external_data_format,
onnx_node_names);
},
py::arg("initializers"),
py::arg("onnx_opset_version") = 0,
py::arg("dynamic_axes"),
py::arg("defer_weight_export") = false,
py::arg("operator_export_type") =
::torch::onnx::OperatorExportTypes::ONNX,
py::arg("strip_doc_string") = true,
py::arg("keep_initializers_as_inputs") = true,
py::arg("custom_opsets"),
py::arg("add_node_names") = true,
py::arg("onnx_file_path") = std::string(),
py::arg("node_attr_to_name") = NodeAttrNameMap())
.def(
"_pretty_print_onnx",
[](const std::shared_ptr<Graph>& g,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type,
bool google_printer,
bool keep_initializers_as_inputs,
const std::map<std::string, int>& custom_opsets,
bool add_node_names) {
return pretty_print_onnx(
g,
initializers,
onnx_opset_version,
defer_weight_export,
operator_export_type,
google_printer,
keep_initializers_as_inputs,
custom_opsets,
add_node_names);
},
py::arg("initializers"),
py::arg("onnx_opset_version") = 0,
py::arg("defer_weight_export") = false,
py::arg("operator_export_type") =
::torch::onnx::OperatorExportTypes::ONNX,
py::arg("google_printer") = false,
py::arg("keep_initializers_as_inputs") = true,
py::arg("custom_opsets"),
py::arg("add_node_names") = true)
.def(
"inputs",
[](Graph& g) {
return py::make_iterator(g.inputs().begin(), g.inputs().end());
},
py::keep_alive<0, 1>())
.def(
"outputs",
[](Graph& g) {
return py::make_iterator(g.outputs().begin(), g.outputs().end());
},
py::keep_alive<0, 1>())
// We keep the graph alive while the iterator lives. Destroying
// nodes might still be hazardous.
.def(
"nodes",
[](Graph& g) {
return py::make_iterator(g.nodes().begin(), g.nodes().end());
},
py::keep_alive<0, 1>())
.def(
"findNode",
[](Graph& g, const std::string& kind, bool recurse) {
return findNode(g.block(), Symbol::fromQualString(kind), recurse);
},
"Find Node",
py::arg("kind"),
py::arg("recurse") = true)
.def(
"findAllNodes",
[](Graph& g, const std::string& kind, bool recurse) {
return findAllNodes(g, Symbol::fromQualString(kind), recurse);
},
"Find all nodes",
py::arg("kind"),
py::arg("recurse") = true)
.def(
"addInput",
[](Graph& g, const std::string& name) { return g.addInput(name); },
"Add input to graph with optional name seed",
py::arg("name") = "")
.def("copy", [](Graph& g) { return g.copy(); })
.GS(eraseInput)
.GS(eraseOutput)
.GS(registerOutput)
.def(
"permuteInputs",
[](Graph& g, const std::vector<size_t>& new_inputs) {
g.block()->permuteInputs(new_inputs);
})
.def(
"create",
[](Graph& g, const char* str) {
return g.create(Symbol::fromQualString(str));
})
.def(
"create",
[](Graph& g, const char* str, size_t noutputs) {
return g.create(Symbol::fromQualString(str), noutputs);
})
.def(
"create",
[](Graph& g, const char* str, const std::vector<Value*>& inputs) {
TORCH_CHECK_VALUE(
std::all_of(
inputs.begin(),
inputs.end(),
[](Value* v) { return (v != nullptr); }),
"cannot pass None in inputs");
return g.create(Symbol::fromQualString(str), inputs);
})
.def(
"create",
[](Graph& g,
const char* str,
const std::vector<Value*>& inputs,
size_t noutputs) {
TORCH_CHECK_VALUE(
std::all_of(
inputs.begin(),
inputs.end(),
[](Value* v) { return (v != nullptr); }),
"cannot pass None in inputs");
return g.create(Symbol::fromQualString(str), inputs, noutputs);
})
.def("param_node", [](Graph& g) { return g.block()->param_node(); })
.def("return_node", [](Graph& g) { return g.block()->return_node(); })
.def(
"createFusionGroup",
[](Graph& g) { return g.createWithSubgraph(prim::FusionGroup); })
.def(
"createCudaFusionGroup",
[](Graph& g) { return g.createWithSubgraph(prim::CudaFusionGroup); })
.def(
"createClone",
[](Graph& g, Node* n, py::object fn) {
return g.createClone(
n, [&](Value* e) { return fn(e).cast<Value*>(); });
})
.GS(appendNode)
.GS(prependNode)
// NB: insert_point_guard defined over direct modification of insert point
.def(
"insert_point_guard",
[](Graph& g, Node* n) {
return py::module::import("torch.jit._ir_utils")
.attr("insert_point_guard")(g, n);
})
.def(
"insert_point_guard",
[](Graph& g, Block* b) {
return py::module::import("torch.jit._ir_utils")
.attr("insert_point_guard")(g, b);
})
.GS(insertPoint)
.def("setInsertPoint", [](Graph& g, Node* n) { g.setInsertPoint(n); })
.def("setInsertPoint", [](Graph& g, Block* n) { g.setInsertPoint(n); })
.def(
"insertGraph",
[](Graph& g, Graph& callee, std::vector<Value*> inputs) {
return insertGraph(g, callee, inputs);
})
.def(
"insertGraph",
[](Graph& g,
Graph& callee,
std::vector<Value*> inputs,
std::unordered_map<Value*, Value*> value_map) {
return insertGraph(g, callee, inputs, value_map);
})
.def(
"insert",
[](Graph& g, Symbol opname, std::vector<Value*> args) {
std::vector<NamedValue> args_named;
args_named.reserve(args.size());
for (Value* v : args) {
args_named.emplace_back(v);
}
return g.insert(opname, args_named);
})
.def(
"makeMultiOutputIntoTuple",
[](Graph& g) {
auto tup = g.createTuple(g.outputs());
tup->insertBefore(g.return_node());
for (int64_t i = g.outputs().size() - 1; i >= 0; i--) {
g.eraseOutput(0);
}
g.registerOutput(tup->output());
})
.def(
"insertConstant",
[](Graph& g, const IValue& ival) { return g.insertConstant(ival); })
.GS(lint)
.def("block", [](Graph& g) { return g.block(); })
.GS(insertNode);
#undef GS
#define VS(name) def(#name, &Value ::name)
py::class_<Value, unwrapping_shared_ptr<Value>>(m, "Value")
.def(
"__repr__",
[](Value& n) {
std::stringstream ss;
ss << n.debugName() << " defined in (" << *n.node() << ")";
return ss.str();
})
.VS(type)
.VS(setType)
.def(
"inferTypeFrom",
py::overload_cast<const at::Tensor&>(&Value::inferTypeFrom))
.def(
"inferTypeFrom",
py::overload_cast<const c10::intrusive_ptr<c10::ivalue::Object>&>(
&Value::inferTypeFrom))
// skip owningGraph because it returns a raw pointer to a otherwise
// std::shared_ptr stored graph object, and would cause a double free
.VS(unique)
.VS(debugName)
.VS(setDebugName)
.VS(offset)
.VS(uses)
.VS(replaceAllUsesWith)
.VS(replaceAllUsesAfterNodeWith)
.def("node", [](Value& v) { return v.node(); })
.def(
"setTypeAs",
[](Value* node, Value* other) {
node->setType(other->type());
return node;
})
.VS(copyMetadata)
.VS(isCompleteTensor)
.VS(requires_grad)
.def(
"requiresGrad",
[](Value& n) {
return n.type()->expectRef<TensorType>().requiresGrad();
})
.def("toIValue", [](Value& n) { return toIValue(&n); })
.def("type", [](Value& v) { return v.type(); });
#undef VS
py::class_<Block, unwrapping_shared_ptr<Block>>(m, "Block")
.def(
"nodes",
[](Block& b) {
return py::make_iterator(b.nodes().begin(), b.nodes().end());
})
.def(
"findNode",
[](Block& b, const std::string& kind, bool recurse) {
return findNode(&b, Symbol::fromQualString(kind), recurse);
},
"Find Node",
py::arg("kind"),
py::arg("recurse") = true)
.def(
"findAllNodes",
[](Block& b, const std::string& kind, bool recurse) {
return findAllNodes(b, Symbol::fromQualString(kind), recurse);
},
"Find all nodes",
py::arg("kind"),
py::arg("recurse") = true)
.def(
"inputs",
[](Block& b) {
return py::make_iterator(b.inputs().begin(), b.inputs().end());
})
.def(
"outputs",
[](Block& b) {
return py::make_iterator(b.outputs().begin(), b.outputs().end());
})
.def("returnNode", [](Block& b) { return b.return_node(); })
.def("paramNode", [](Block& b) { return b.param_node(); })
.def("owningNode", [](Block& b) { return b.owningNode(); })
.def(
"addNode",
[](Block& b, const char* str, const std::vector<Value*>& inputs) {
return addNodeToBlock(&b, Symbol::fromQualString(str), inputs);
})
.def("addInputToBlock", [](Block& b) { return addInputToBlock(&b); })
.def("registerOutput", [](Block& b, Value* value) {
return b.registerOutput(value);
});
#define NS(name) def(#name, &Node ::name)
py::class_<Node, unwrapping_shared_ptr<Node>>(m, "Node")
.def(
"__repr__",
[](Node& n) {
std::stringstream ss;
ss << n;
return ss.str();
})
.def("sourceRange", [](Node& n) { return n.sourceRange().str(); })
.def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
.def("inputsSize", [](Node& n) { return n.inputs().size(); })
.def("outputsSize", [](Node& n) { return n.outputs().size(); })
.NS(kind)
.def("prev", [](Node& n) { return n.prev(); })
.def("matches", [](Node& n, const char* s) { return n.matches(s); })
.def("owningBlock", [](Node& n) { return n.owningBlock(); })
.def("inputsAt", [](Node& n, size_t i) { return n.inputs().at(i); })
.def(
"inputs",
[](Node& n) {
return py::make_iterator(n.inputs().begin(), n.inputs().end());
})
.def(
"schema",
[](Node& n) {
std::stringstream ss;
if (n.maybeSchema()) {
ss << n.schema();
} else {
ss << "(no schema)";
}
return ss.str();
})
.def(
"outputs",
[](Node& n) {
return py::make_iterator(n.outputs().begin(), n.outputs().end());
})
.def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
.def(
"findNode",
[](Node& n, const std::string& kind, bool recurse) {
return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
},
"Find Node",
py::arg("kind"),
py::arg("recurse") = true)
.def(
"findAllNodes",
[](Node& n, const std::string& kind, bool recurse) {
return findAllNodes(
n.blocks(), Symbol::fromQualString(kind), recurse);
},
"Find all nodes",
py::arg("kind"),
py::arg("recurse") = true)
.def("input", [](Node& n) { return n.input(); })
.def("output", [](Node& n) { return n.output(); })
.def(
"getModuleHierarchy",
[](Node& n) { return torch::jit::utils::getNodesModuleHierarchy(n); })
.def(
"namedInput",
[](Node& n, const std::string& unqualName) {
return n.namedInput(unqualName);
})
.NS(addInput)
.NS(copyMetadata)
.NS(replaceInput)
.NS(replaceInputWith)
.NS(replaceAllUsesWith)
.NS(insertBefore)
.NS(insertAfter)
.NS(isBefore)
.NS(isAfter)
.NS(moveAfter)
.NS(moveBefore)
.NS(removeInput)
.NS(removeAllInputs)
.NS(destroy)
.NS(hasUses)
.NS(eraseOutput)
.NS(addOutput)
.NS(scopeName)
.NS(isNondeterministic)
.def(
"blocks",
[](Node& n) {
return py::make_iterator(n.blocks().begin(), n.blocks().end());
})
.NS(addBlock)
.NS(mustBeNone)
#define AS(name) def(#name, &Node::name)
// methods from Attributes
.AS(copyAttributes)
.AS(hasAttributes)
#undef AS
#define AS(name) def(#name, &Node::name##S)
// The default method names take Symbol, but the string conversion for
// Symbol you to qualify with attr::. This is not very user friendly
// for attributes, so expose the string variants instead.
.AS(hasAttribute)
.AS(kindOf)
.AS(removeAttribute)
.AS(attributeNames)
#undef AS
#define CREATE_ACCESSOR(Kind, method) \
def(#method "_", [](Node& n, const char* name, Kind##Attr::ValueType v) { \
return n.method##_(Symbol::attr(name), std::move(v)); \
}).def(#method, [](Node& n, const char* name) { \
return n.method(Symbol::attr(name)); \
})
.CREATE_ACCESSOR(Float, f)
.CREATE_ACCESSOR(Floats, fs)
.CREATE_ACCESSOR(Complex, c)
.CREATE_ACCESSOR(String, s)
.CREATE_ACCESSOR(Strings, ss)
.CREATE_ACCESSOR(Int, i)
.CREATE_ACCESSOR(Ints, is)
.CREATE_ACCESSOR(Graph, g)
.CREATE_ACCESSOR(Graphs, gs)
.CREATE_ACCESSOR(IValue, ival)
#undef CREATE_ACCESSOR
// Tensor (t_) -- manually written to unwrap the variable into a tensor.
.def(
"t_",
[](Node& n, const char* name, const torch::autograd::Variable& v) {
AT_ASSERT(!v.requires_grad());
return n.t_(Symbol::attr(name), v);
})
.def(
"t",
[](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
// Tensors (ts_) -- manually written to unwrap variables into tensors.
.def(
"ts_",
[](Node& n,
const char* name,
const std::vector<torch::autograd::Variable>& vs) {
std::vector<at::Tensor> tensors;
tensors.reserve(vs.size());
for (auto& variable : vs) {
AT_ASSERT(!variable.requires_grad());
tensors.push_back(variable);
}
return n.ts_(Symbol::attr(name), std::move(tensors));
})
.def(
"ts",
[](Node& n, const char* name) {
auto tensors = n.ts(Symbol::attr(name));
std::vector<torch::autograd::Variable> variables;
variables.reserve(tensors.size());
for (auto& tensor : tensors) {
variables.emplace_back(std::move(tensor));
}
return variables;
})
.def(
"z_",
[](Node& n, const char* name, const at::Tensor& v) {
return n.t_(
Symbol::attr(name),
autograd::Variable(v.view(std::vector<int64_t>{}))
.set_requires_grad(false));
})
.def(
"z",
[](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
.def(
"ty_",
[](Node& n, const char* name, const TypePtr& type) {
return n.ty_(Symbol::attr(name), type);
})
.def(
"ty",
[](Node& n, const char* name) { return n.ty(Symbol::attr(name)); })
.def(
"tys_",
[](Node& n, const char* name, const std::vector<TypePtr>& types) {
return n.tys_(Symbol::attr(name), types);
})
.def(
"tys",
[](Node& n, const char* name) { return n.tys(Symbol::attr(name)); })
.def(
"zs_",
[](Node& n, const char* name, TensorsAttr::ValueType v) {
for (auto& i : v) {
i = autograd::Variable(i.view(std::vector<int64_t>{}))
.set_requires_grad(false);
}
return n.ts_(Symbol::attr(name), std::move(v));
})
.def(
"zs",
[](Node& n, const char* name) { return n.ts(Symbol::attr(name)); })
.def(
"pyobj",
[](Node& n) {
return py::handle(n.expect<ConcretePythonOp>()->pyobj.get())
.cast<py::object>();
})
.def("cconv", [](Node& n) { return n.expect<ConcretePythonOp>()->cconv; })
.def(
"pyname",
[](Node& n) { return n.expect<ConcretePythonOp>()->name(); })
.def("scalar_args", [](Node& n) {
auto op = n.expect<ConcretePythonOp>();
auto scalars = py::list();
auto append = scalars.attr("append");
for (auto& arg : op->scalar_args) {
append(py::handle(arg.get()));
}
return scalars;
});
using ::c10::Type;
py::class_<Type, TypePtr>(m, "Type")
.def("__repr__", [](Type& t) { return t.annotation_str(); })
.def(
"str",
[](Type& t) {
std::ostringstream s;
s << t;
return s.str();
})
.def(
"containedTypes",
[](Type& self) { return self.containedTypes().vec(); })
.def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
.def(
"dim",
[](Type& t) {
auto vshape = t.expectRef<TensorType>().sizes();
return vshape.size() ? py::cast(*vshape.size())
: py::cast<py::none>(Py_None);
})
.def(
"undefined",
[](Type& t) {
auto undef = t.expectRef<TensorType>().undefined();
return undef.has_value() ? py::cast(*undef)
: py::cast<py::none>(Py_None);
})
.def(
"sizes",
[](Type& t) -> py::object {
if (auto ptt = t.expect<TensorType>()) {
if (auto cs = ptt->sizes().concrete_sizes()) {
return py::cast(*cs);
}
}
return py::none();
})
.def(
"symbolic_sizes",
[](Type& t) -> py::object {
if (auto ptt = t.expect<TensorType>()) {
auto ss = ptt->symbolic_sizes();
if (!ss.rank().has_value()) {
return py::none();
}
std::vector<int64_t> ss_vals;
for (size_t i = 0; i < *ss.rank(); ++i) {
ss_vals.push_back(ss.at(i).value());
}
return py::cast(ss_vals);
}
return py::none();
})
.def(
"with_sizes",
[](Type& t, c10::optional<std::vector<c10::optional<int64_t>>> sizes)
-> py::object {
auto ptt = t.expect<TensorType>();
if (!ptt) {
return py::none();
}
if (!sizes) {
return py::cast(ptt->withSymbolicShapes(c10::SymbolicShape()));
}
return py::cast(ptt->withSymbolicShapes(*sizes));
})
.def(
"varyingSizes",
[](Type& t) -> py::object {
if (auto ptt = t.expect<TensorType>()) {
if (auto s = ptt->sizes().sizes()) {
return py::cast(s.value());
}
}
return py::none();
})
.def(
"strides",
[](Type& t) -> py::object {
if (auto ptt = t.expect<TensorType>()) {
if (auto cs = ptt->strides().concrete_sizes()) {
return py::cast(*cs);
}
}
return py::none();
})
.def(
"contiguous",
[](Type& t) {
return std::static_pointer_cast<Type>(
t.expectRef<TensorType>().contiguous());
})
.def(
"scalarType",
[](Type& t) {
auto scalar_type = t.expectRef<TensorType>().scalarType();
return (scalar_type) ? toString(*scalar_type) : nullptr;
})
.def(
"device",
[](Type& t) -> py::object {
auto device = t.expectRef<TensorType>().device();
if (!device) {
return py::none();
}
PyObject* thp_device = THPDevice_New(device.value());
return py::reinterpret_borrow<py::object>(thp_device);
// return toPyObject(device.value());
})
.def(
"with_device",
[](Type& t, py::object device) -> py::object {
at::Device c_device = python::detail::py_object_to_device(device);
if (auto ptt = t.expect<TensorType>()) {
return py::cast(ptt->withDevice(c_device));
}
return py::none();
})
.def(
"dtype",
[](Type& t) -> py::object {
auto scalar_type = t.expectRef<TensorType>().scalarType();
if (!scalar_type) {
return py::none();
}
THPDtype* thp_dtype = torch::getTHPDtype(*scalar_type);
py::object dtype =
py::reinterpret_borrow<py::object>((PyObject*)thp_dtype);
return dtype;
})
.def(
"with_dtype",
[](Type& t, py::object dtype) -> py::object {
at::ScalarType scalar_type =
python::detail::py_object_to_dtype(dtype);
if (auto ptt = t.expect<TensorType>()) {
// auto scalar_type = dtype->scalar_type;
return py::cast(ptt->withScalarType(scalar_type));
}
return py::none();
})
.def(
"__eq__",
[](const TypePtr& self, const TypePtr& other) {
if (!other) {
return false;
}
return *self == *other;
})
.def(
"isSubtypeOf",
[](const TypePtr& self, const TypePtr& other) {
if (!other) {
return false;
}
return self->isSubtypeOf(other);
})
.def(
"is_interface_type",
[](const TypePtr& self) {
return self->castRaw<InterfaceType>() != nullptr;
})
.def(
"requires_grad",
[](const TypePtr& self) -> bool { return self->requires_grad(); })
.def_property_readonly(
"annotation_str", [](const std::shared_ptr<Type>& self) {
return self->annotation_str();
});
py::class_<AnyType, Type, AnyTypePtr>(m, "AnyType")
.def_static("get", &AnyType::get);
py::class_<NumberType, Type, NumberTypePtr>(m, "NumberType")
.def_static("get", &NumberType::get);
py::class_<IntType, Type, IntTypePtr>(m, "IntType")
.def_static("get", &IntType::get);
py::class_<SymIntType, Type, SymIntTypePtr>(m, "SymIntType")
.def_static("get", &SymIntType::get);
py::class_<FloatType, Type, FloatTypePtr>(m, "FloatType")
.def_static("get", &FloatType::get);
py::class_<ComplexType, Type, ComplexTypePtr>(m, "ComplexType")
.def_static("get", &ComplexType::get);