forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
python_dict.h
128 lines (108 loc) · 3.27 KB
/
python_dict.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#pragma once
#include <ATen/core/Dict.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/utils/pybind.h>

#include <cstdint>
#include <sstream>
#include <string>
#include <utility>
namespace torch {
namespace jit {
// Declared here and defined out of line (not visible in this header).
// Presumably registers the Python bindings for ScriptDict and its iterator
// types on `module` — confirm against the definition.
void initScriptDictBindings(PyObject* module);
/// An iterator over the keys of ScriptDict. This is used to support
/// .keys() and iteration.
class ScriptDictKeyIterator final {
public:
ScriptDictKeyIterator(
c10::impl::GenericDict::iterator iter,
c10::impl::GenericDict::iterator end)
: iter_(std::move(iter)), end_(std::move(end)) {}
IValue next();
private:
c10::impl::GenericDict::iterator iter_;
c10::impl::GenericDict::iterator end_;
};
/// An iterator over the key-value pairs of ScriptDict. This is used to support
/// .items().
class ScriptDictIterator final {
public:
ScriptDictIterator(
c10::impl::GenericDict::iterator iter,
c10::impl::GenericDict::iterator end)
: iter_(std::move(iter)), end_(std::move(end)) {}
IValue next();
private:
c10::impl::GenericDict::iterator iter_;
c10::impl::GenericDict::iterator end_;
};
/// A wrapper around c10::Dict that can be exposed in Python via pybind
/// with an API identical to the Python dictionary class. This allows
/// dictionaries to have reference semantics across the Python/TorchScript
/// boundary.
class ScriptDict final {
 public:
  // Constructor. `data` must hold a generic dict; this is checked with
  // TORCH_INTERNAL_ASSERT. `dict_` is initialized directly from `data`
  // instead of first being built as an empty Any -> Any dict and then
  // overwritten, which avoided nothing but cost an extra allocation.
  ScriptDict(IValue data) : dict_(toCheckedGenericDict(std::move(data))) {}

  // Get the type of the dictionary, built from the key and value types
  // recorded in the underlying c10::Dict.
  DictTypePtr type() const {
    return DictType::create(dict_.keyType(), dict_.valueType());
  }

  // Return a string representation that can be used to reconstruct the
  // instance, formatted as "{k1: v1, k2: v2}". Keys and values are rendered
  // with IValue's operator<<, in the dictionary's iteration order.
  std::string repr() const {
    std::ostringstream s;
    s << '{';
    bool first = true;
    for (const auto& kv : dict_) {
      if (!first) {
        s << ", ";
      }
      s << kv.key() << ": " << kv.value();
      first = false;
    }
    s << '}';
    return s.str();
  }

  // Return an iterator over the keys of the dictionary.
  ScriptDictKeyIterator iter() const {
    return ScriptDictKeyIterator(dict_.begin(), dict_.end());
  }

  // Return an iterator over the key-value pairs of the dictionary.
  ScriptDictIterator items() const {
    return ScriptDictIterator(dict_.begin(), dict_.end());
  }

  // Interpret the dictionary as a boolean; empty means false, non-empty means
  // true.
  bool toBool() const {
    return !dict_.empty();
  }

  // Get the value for the given key. Throws std::out_of_range if the key does
  // not exist. Read-only, hence const.
  IValue getItem(const IValue& key) const {
    return dict_.at(key);
  }

  // Set the value for the given key, inserting it if absent and overwriting
  // the existing value otherwise.
  void setItem(const IValue& key, const IValue& value) {
    dict_.insert_or_assign(key, value);
  }

  // Check whether the dictionary contains the given key. Read-only, hence
  // const.
  bool contains(const IValue& key) const {
    return dict_.contains(key);
  }

  // Delete the given key from the dictionary. Returns true if erase()
  // reported removing at least one entry, false otherwise.
  bool delItem(const IValue& key) {
    return dict_.erase(key) != 0;
  }

  // Get the size of the dictionary. The explicit cast documents the
  // size_t -> int64_t conversion.
  int64_t len() const {
    return static_cast<int64_t>(dict_.size());
  }

  // A c10::Dict instance that holds the actual data.
  c10::impl::GenericDict dict_;

 private:
  // Checks that `data` holds a generic dict and extracts it, moving out of
  // `data` so the dict handle is not copied. Used only by the constructor.
  static c10::impl::GenericDict toCheckedGenericDict(IValue data) {
    TORCH_INTERNAL_ASSERT(data.isGenericDict());
    return std::move(data).toGenericDict();
  }
};
} // namespace jit
} // namespace torch