-
Notifications
You must be signed in to change notification settings - Fork 1
/
graph.hh
104 lines (97 loc) · 3.65 KB
/
graph.hh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Copyright 2013 Jacob Emmert-Aronson
// This file is part of Tensor Network.
//
// Tensor Network is free software: you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// Tensor Network is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Tensor Network. If not, see
// <http://www.gnu.org/licenses/>.
#pragma once
#include <iterator>
#include <unordered_set>
// Forward declare to avoid dependencies between headers.
class Tensor;
// One edge of a tensor-network graph: it joins `input_tensor` (at index
// `input_num`) to `output_tensor` (at index `output_num`).
// NOTE(review): the *_num fields presumably name the leg/slot on the
// corresponding tensor — confirm against Tensor.
//
// This is a plain aggregate with no default member initializers, so a
// default-initialized GraphEdge holds indeterminate values; build it
// with brace initialization, e.g. `GraphEdge{a, 0, b, 1}`.
struct GraphEdge
{
  Tensor *input_tensor;
  size_t input_num;
  Tensor *output_tensor;
  size_t output_num;
  // Equality that is documented as independent of element order. The
  // definitions are not in this header. NOTE(review): presumably an
  // edge equals the same edge with its (input_tensor, input_num) and
  // (output_tensor, output_num) pairs swapped — confirm.
  bool operator== (const GraphEdge &v) const;
  bool operator!= (const GraphEdge &v) const;
};
// std::hash specialization so GraphEdge can be stored in
// std::unordered_set.
//
// Each endpoint (tensor pointer, index) is first mixed into a single
// value, and the two endpoint values are then combined with the smaller
// one first. The result is therefore unchanged when the input and
// output endpoints are swapped, which keeps it consistent with an
// order-independent operator==. The previous plain XOR of all four
// fields also had that symmetry but collided badly: equal indices
// cancelled ({a,0,b,0} vs {a,1,b,1}) and swapping only the indices
// ({a,0,b,1} vs {a,1,b,0}) did not change the hash.
//
// Declared inside namespace std rather than as `struct std::hash<...>`
// at global scope, because compilers that predate CWG 374 (e.g.
// GCC < 7) reject the qualified form.
namespace std
{
template <>
struct hash<GraphEdge>
{
  size_t operator() (const GraphEdge &v) const noexcept
  {
    // boost::hash_combine-style mixing of `value` into `seed`.
    auto mix = [](size_t seed, size_t value) noexcept -> size_t {
      return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    };
    const std::hash<Tensor*> tensor_hash;
    const std::hash<size_t> index_hash;
    const size_t in = mix(tensor_hash(v.input_tensor),
                          index_hash(v.input_num));
    const size_t out = mix(tensor_hash(v.output_tensor),
                           index_hash(v.output_num));
    // Put the smaller endpoint hash first so the result does not depend
    // on which endpoint is stored as "input" and which as "output".
    return in < out ? mix(in, out) : mix(out, in);
  }
};
}
// A graph maps all the tensors reachable from a specific tensor as
// well as the edges connecting them. If the links connecting a
// tensor are changed, graphs containing that tensor are no longer
// valid.
class Graph
{
public:
  // Polymorphic base: concrete graphs may be destroyed through a
  // Graph pointer.
  virtual ~Graph() = default;

  // Number of tensors (vertices) held by the graph.
  virtual size_t vertices() = 0;
  // Number of connections (edges) held by the graph.
  virtual size_t edges() = 0;

  // [vertex_begin(), vertex_end()) walks every tensor in the graph, and
  // [edge_begin(), edge_end()) walks every edge between them. Intended
  // for code that must visit all tensors or all links; the visiting
  // order is unspecified.
  virtual std::unordered_set<Tensor*>::const_iterator vertex_begin() = 0;
  virtual std::unordered_set<Tensor*>::const_iterator vertex_end() = 0;
  virtual std::unordered_set<GraphEdge>::const_iterator edge_begin() = 0;
  virtual std::unordered_set<GraphEdge>::const_iterator edge_end() = 0;

  // [endpt_begin(), endpt_end()) walks the graph's endpoints (unlinked
  // tensors). Note that the elements are GraphEdge records, not bare
  // tensors. NOTE(review): presumably each record describes a tensor
  // index that is not linked to another tensor — confirm against the
  // implementations.
  virtual std::unordered_set<GraphEdge>::const_iterator endpt_begin() = 0;
  virtual std::unordered_set<GraphEdge>::const_iterator endpt_end() = 0;
};
class DFSGraph : public Graph
{
public:
// Create the graph of all tensors connected to t.
explicit DFSGraph(Tensor *t);
DFSGraph& operator=(const DFSGraph&) = default;
DFSGraph(const DFSGraph&) = default;
DFSGraph& operator=(DFSGraph&&) = default;
DFSGraph(DFSGraph&&) = default;
// From interface Graph.
size_t vertices() override;
size_t edges() override;
std::unordered_set<Tensor*>::const_iterator vertex_begin() override;
std::unordered_set<Tensor*>::const_iterator vertex_end() override;
std::unordered_set<GraphEdge>::const_iterator edge_begin() override;
std::unordered_set<GraphEdge>::const_iterator edge_end() override;
std::unordered_set<GraphEdge>::const_iterator endpt_begin() override;
std::unordered_set<GraphEdge>::const_iterator endpt_end() override;
protected:
// Function which is called recursively to find all connected
// tensors via depth-first search.
void _dfs(Tensor *t);
private:
// Tensors which belong to the graph.
std::unordered_set<Tensor*> _vertices;
// Edges connecting tensors.
std::unordered_set<GraphEdge> _edges;
// Tensors with detached inputs or outputs.
std::unordered_set<GraphEdge> _endpts;
};