1#include <torch/csrc/lazy/core/ir_dump_util.h>
2
3#include <c10/util/Optional.h>
4#include <c10/util/irange.h>
5#include <torch/csrc/lazy/backend/backend_interface.h>
6#include <torch/csrc/lazy/backend/lowering_context.h>
7#include <torch/csrc/lazy/core/ir_util.h>
8
9#include <regex>
10#include <sstream>
11#include <unordered_map>
12
13namespace torch {
14namespace lazy {
15namespace {
16
// Maps an IR node to the numeric id used to refer to it in the dumps.
// GenerateIdMap() fills it with each node's position in the post order.
using NodeIdMap = std::unordered_map<const Node*, size_t>;
18
// One `name=value` attribute extracted from a node's string form.
struct AttrTag {
  // Attribute name (the text before '=').
  std::string name;
  // Attribute value (the text after '=', up to the next top-level separator).
  std::string value;
  // Index in the parsed string just past the end of `value`.
  std::string::size_type pos{0};
};
24
// Returns the index just past a ", " separator located at `pos`, or `pos`
// itself when no separator starts there.
//
// A `pos` at or beyond the end of `node_string` is returned unchanged:
// std::string::compare() would throw std::out_of_range for pos > size(),
// and there cannot be a separator there anyway.
std::string::size_type SkipTagSeparator(
    const std::string& node_string,
    std::string::size_type pos) {
  if (pos >= node_string.size()) {
    return pos;
  }
  return node_string.compare(pos, 2, ", ") == 0 ? pos + 2 : pos;
}
30
31c10::optional<AttrTag> ParseAttrTag(
32 const std::string& node_string,
33 std::string::size_type pos) {
34 // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
35 const std::regex tag_regex("^([a-zA-Z0-9_]+)=");
36 std::smatch match;
37 // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
38 if (!std::regex_search(
39 node_string.begin() + pos, node_string.end(), match, tag_regex)) {
40 return c10::nullopt;
41 }
42
43 std::string::size_type vpos = match[1].second - node_string.begin() + 1;
44 char nested_open = -1;
45 char nested_close = -1;
46 size_t nest_count = 1;
47 AttrTag tag;
48 tag.name = match[1].str();
49 for (pos = vpos; pos < node_string.size(); ++pos) {
50 if (nested_open < 0) {
51 if (SkipTagSeparator(node_string, pos) != pos) {
52 break;
53 }
54 switch (node_string[pos]) {
55 case '(':
56 nested_open = node_string[pos];
57 nested_close = ')';
58 break;
59 case '[':
60 nested_open = node_string[pos];
61 nested_close = ']';
62 break;
63 case '{':
64 nested_open = node_string[pos];
65 nested_close = '}';
66 break;
67 }
68 } else if (node_string[pos] == nested_close) {
69 --nest_count;
70 if (nest_count == 0) {
71 nest_count = 1;
72 nested_open = nested_close = -1;
73 }
74 } else if (node_string[pos] == nested_open) {
75 ++nest_count;
76 }
77 }
78 tag.value = node_string.substr(vpos, pos - vpos);
79 tag.pos = pos;
80 return tag;
81}
82
83NodeIdMap GenerateIdMap(c10::ArrayRef<const Node*> post_order) {
84 NodeIdMap id_map;
85 for (auto node : post_order) {
86 TORCH_CHECK(id_map.emplace(node, id_map.size()).second, node->ToString());
87 }
88 return id_map;
89}
90
91std::unordered_map<const Node*, size_t> GetRootsIds(
92 c10::ArrayRef<const Node*> roots) {
93 std::unordered_map<const Node*, size_t> roots_ids;
94 for (const auto i : c10::irange(roots.size())) {
95 roots_ids[roots[i]] = i;
96 }
97 return roots_ids;
98}
99
100c10::optional<size_t> GetRootNodeId(
101 const Node* node,
102 const std::unordered_map<const Node*, size_t>& roots_ids) {
103 auto it = roots_ids.find(node);
104 if (it == roots_ids.end()) {
105 return c10::nullopt;
106 }
107 return it->second;
108}
109
110std::vector<AttrTag> GetNodeTags(const Node* node) {
111 std::string node_string = node->ToString();
112 std::string op_string = node->op().ToString();
113 std::string::size_type pos = node_string.find(op_string);
114 TORCH_CHECK(pos != std::string::npos, node_string, " : ", op_string);
115 pos += op_string.size();
116 std::vector<AttrTag> tags;
117 for (;;) {
118 pos = SkipTagSeparator(node_string, pos);
119 auto tag = ParseAttrTag(node_string, pos);
120 if (!tag) {
121 break;
122 }
123 pos = tag->pos;
124 tags.push_back(std::move(*tag));
125 }
126 return tags;
127}
128
129std::string GenerateDotNodeLabel(
130 const Node* node,
131 const std::unordered_map<const Node*, size_t>& roots_ids) {
132 static const size_t kMaxValueSize = 64;
133 std::stringstream ss;
134 ss << node->op() << "\\n" << node->shape();
135 for (auto& tag : GetNodeTags(node)) {
136 ss << "\\n" << tag.name << "=";
137 if (tag.value.size() < kMaxValueSize) {
138 ss << tag.value;
139 } else {
140 ss << tag.value.substr(0, kMaxValueSize) << "...";
141 }
142 }
143 auto opt_root_id = GetRootNodeId(node, roots_ids);
144 if (opt_root_id) {
145 ss << "\\nROOT=" << *opt_root_id;
146 }
147 return ss.str();
148}
149
150std::string GenerateDotNodeSpec(
151 const Node* node,
152 const std::unordered_map<const Node*, size_t>& roots_ids) {
153 std::stringstream ss;
154 ss << "label=\"" << GenerateDotNodeLabel(node, roots_ids) << "\"";
155 return ss.str();
156}
157
158std::string GenerateTextNodeSpec(const Node* node, const NodeIdMap& id_map) {
159 std::stringstream ss;
160 ss << node->shapes() << " " << node->op() << "(";
161 size_t count = 0;
162 for (auto& output : node->operands()) {
163 if (count > 0) {
164 ss << ", ";
165 }
166 ss << "%" << id_map.at(output.node);
167 if (output.node->num_outputs() > 1) {
168 ss << "." << output.index;
169 }
170 ++count;
171 }
172 ss << ")";
173 for (auto& tag : GetNodeTags(node)) {
174 ss << ", " << tag.name << "=" << tag.value;
175 }
176 return ss.str();
177}
178
179} // namespace
180
181std::string DumpUtil::ToDot(c10::ArrayRef<const Node*> nodes) {
182 auto post_order = Util::ComputePostOrder(nodes);
183 return PostOrderToDot(post_order, nodes);
184}
185
186std::string DumpUtil::PostOrderToDot(
187 c10::ArrayRef<const Node*> post_order,
188 c10::ArrayRef<const Node*> roots) {
189 std::unordered_map<const Node*, size_t> roots_ids = GetRootsIds(roots);
190 NodeIdMap id_map = GenerateIdMap(post_order);
191 std::stringstream ss;
192 ss << "digraph G {\n";
193 for (auto node : post_order) {
194 ss << " node" << id_map.at(node) << " ["
195 << GenerateDotNodeSpec(node, roots_ids) << "]\n";
196 }
197 for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) {
198 const Node* node = *it;
199 size_t id = id_map.at(node);
200 for (const auto i : c10::irange(node->operands().size())) {
201 const Output& output = node->operand(i);
202 ss << " node" << id_map.at(output.node) << " -> node" << id;
203 if (node->operands().size() > 1) {
204 ss << " [label=\"i=" << i;
205 if (output.node->num_outputs() > 1) {
206 ss << ",o=" << output.index;
207 }
208 ss << "\"]\n";
209 } else {
210 if (output.node->num_outputs() > 1) {
211 ss << " [label=\"o=" << output.index << "\"]";
212 }
213 ss << "\n";
214 }
215 }
216 }
217 ss << "}\n";
218 return ss.str();
219}
220
221std::string DumpUtil::ToText(c10::ArrayRef<const Node*> nodes) {
222 auto post_order = Util::ComputePostOrder(nodes);
223 return PostOrderToText(post_order, nodes);
224}
225
226std::string DumpUtil::PostOrderToText(
227 c10::ArrayRef<const Node*> post_order,
228 c10::ArrayRef<const Node*> roots) {
229 std::unordered_map<const Node*, size_t> roots_ids = GetRootsIds(roots);
230 NodeIdMap id_map = GenerateIdMap(post_order);
231 std::stringstream ss;
232 ss << "IR {\n";
233 for (auto node : post_order) {
234 auto opt_root_id = GetRootNodeId(node, roots_ids);
235 ss << " %" << id_map.at(node) << " = "
236 << GenerateTextNodeSpec(node, id_map);
237 if (opt_root_id) {
238 ss << ", ROOT=" << *opt_root_id;
239 }
240 ss << ", NodeType=" << typeid(*node).name();
241 ss << "\n";
242 }
243 ss << "}\n";
244 return ss.str();
245}
246
247std::string DumpUtil::ToBackend(
248 c10::ArrayRef<Value> values,
249 const BackendDevice& device) {
250 auto lowering_ctx = LoweringContext::Create("IrToBackend", device);
251 for (auto& ir_value : values) {
252 lowering_ctx->AddResult(ir_value);
253 }
254 auto computation = lowering_ctx->Build();
255 return getBackend()->GetComputationBackendText(computation);
256}
257
258} // namespace lazy
259} // namespace torch
260