1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Exporter/ONNXModelWriter.h" |
18 | #include "glow/Graph/Utils.h" |
19 | #include "glow/Runtime/RuntimeTypes.h" |
20 | #include "glow/Support/ZipUtils.h" |
21 | |
22 | #include <stack> |
23 | |
24 | #include "miniz.h" |
25 | #include <google/protobuf/io/coded_stream.h> |
26 | #include <google/protobuf/io/zero_copy_stream_impl_lite.h> |
27 | |
28 | using namespace glow::runtime; |
29 | using google::protobuf::RepeatedPtrField; |
30 | |
31 | namespace glow { |
32 | #ifdef FACEBOOK_INTERNAL |
33 | extern const char *revisionHash; |
34 | #endif /* FACEBOOK_INTERNAL */ |
35 | |
36 | #define NUM_FLOAT_DIGS 30 |
37 | |
38 | namespace { |
39 | template <bool IsInteger, bool IsEnum, typename T> struct AttributeAssigner { |
40 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, const T &container); |
41 | }; |
42 | |
43 | // Specialization for llvm::ArrayRef<T> container types |
44 | template <typename T> |
45 | struct AttributeAssigner<false, false, llvm::ArrayRef<T>> { |
46 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, |
47 | const llvm::ArrayRef<T> &container) { |
48 | attr->set_type(ONNX_NAMESPACE::AttributeProto::INTS); |
49 | for (auto value : container) { |
50 | attr->add_ints(value); |
51 | } |
52 | } |
53 | }; |
54 | |
55 | // Specialization for string type |
56 | template <> struct AttributeAssigner<false, false, std::string> { |
57 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, |
58 | const std::string &container) { |
59 | attr->set_type(ONNX_NAMESPACE::AttributeProto::STRING); |
60 | attr->set_s(container); |
61 | } |
62 | }; |
63 | |
64 | // Specialization for StringRef type |
65 | template <> struct AttributeAssigner<false, false, llvm::StringRef> { |
66 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, |
67 | const llvm::StringRef container) { |
68 | attr->set_type(ONNX_NAMESPACE::AttributeProto::STRING); |
69 | attr->set_s(container.str()); |
70 | } |
71 | }; |
72 | |
73 | // Specialization for vector of strings. |
74 | template <> struct AttributeAssigner<false, false, std::vector<std::string>> { |
75 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, |
76 | const std::vector<std::string> &container) { |
77 | attr->set_type(ONNX_NAMESPACE::AttributeProto::STRINGS); |
78 | for (auto &str : container) { |
79 | attr->add_strings(str); |
80 | } |
81 | } |
82 | }; |
83 | |
84 | // Specialization for float type |
85 | template <> struct AttributeAssigner<false, false, float> { |
86 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, |
87 | const float &container) { |
88 | attr->set_type(ONNX_NAMESPACE::AttributeProto::FLOAT); |
89 | attr->set_f(container); |
90 | } |
91 | }; |
92 | |
93 | // Specialization for NodeValueArrayRef. |
94 | template <> struct AttributeAssigner<false, false, NodeValueArrayRef> { |
95 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, |
96 | const NodeValueArrayRef &container) { |
97 | attr->set_type(ONNX_NAMESPACE::AttributeProto::STRINGS); |
98 | for (size_t i = 0, e = container.size(); i < e; i++) { |
99 | attr->add_strings(container[i].generateNodeOutputName( |
100 | /* stripResNoFor0thInput */ true)); |
101 | } |
102 | } |
103 | }; |
104 | |
105 | // Specialization for llvm::ArrayRef<float>. |
106 | template <> struct AttributeAssigner<false, false, llvm::ArrayRef<float>> { |
107 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, |
108 | const llvm::ArrayRef<float> &container) { |
109 | attr->set_type(ONNX_NAMESPACE::AttributeProto::FLOATS); |
110 | for (auto value : container) { |
111 | attr->add_floats(value); |
112 | } |
113 | } |
114 | }; |
115 | |
// Specialization for integer types.
117 | template <typename T> struct AttributeAssigner<true, false, T> { |
118 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, const T &container) { |
119 | attr->set_type(ONNX_NAMESPACE::AttributeProto::INT); |
120 | attr->set_i(container); |
121 | } |
122 | }; |
123 | |
124 | // Specialization for enums. |
125 | template <typename T> struct AttributeAssigner<false, true, T> { |
126 | static void assign(ONNX_NAMESPACE::AttributeProto *attr, const T &container) { |
127 | attr->set_type(ONNX_NAMESPACE::AttributeProto::STRING); |
128 | std::string storage; |
129 | llvm::raw_string_ostream stream(storage); |
130 | stream << container; |
131 | attr->set_s(stream.str()); |
132 | } |
133 | }; |
134 | |
135 | template <typename T> |
136 | void addValueAttribute(ONNX_NAMESPACE::NodeProto *proto, |
137 | const std::string &name, const T &container) { |
138 | auto *attr = proto->add_attribute(); |
139 | attr->set_name(name); |
140 | AttributeAssigner<std::numeric_limits<T>::is_integer, std::is_enum<T>::value, |
141 | T>::assign(attr, container); |
142 | } |
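
// For illustration (hypothetical values; the writer itself does not emit
// these), the compile-time dispatch above resolves as follows:
//   addValueAttribute(proto, "count", int64_t(4));  // -> AttributeProto::INT
//   addValueAttribute(proto, "scale", 0.5f);        // -> AttributeProto::FLOAT
//   addValueAttribute(proto, "perm",
//                     llvm::ArrayRef<unsigned_t>({0, 2, 3, 1}));
//                                                   // -> AttributeProto::INTS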
143 | |
144 | /// Adds the type attributes from \p ty to \p proto. \p ioNum, \p isInput, and |
145 | /// \p addPrefix are used to format the name of the attribute. |
146 | void addTypeAttributes(ONNX_NAMESPACE::NodeProto *proto, TypeRef ty, |
147 | unsigned ioNum, bool isInput, |
                       const std::string &addPrefix = "") {
149 | // Add ElemKind. |
150 | auto *elemKindAttr = proto->add_attribute(); |
151 | elemKindAttr->set_name( |
152 | getTypeAttrID(ioNum, elemKindSignifier, isInput, addPrefix)); |
153 | AttributeAssigner<false, false, llvm::StringRef>::assign( |
154 | elemKindAttr, ty->getElementName()); |
155 | |
156 | // Add Shape. |
157 | addValueAttribute(proto, |
158 | getTypeAttrID(ioNum, shapeSignifier, isInput, addPrefix), |
159 | ty->dims()); |
160 | |
161 | // Non-standard strides need to be serialized. |
162 | if (!ty->hasStandardStrides()) { |
163 | addValueAttribute( |
164 | proto, getTypeAttrID(ioNum, stridesSignifier, isInput, addPrefix), |
165 | ty->strides()); |
166 | } |
167 | |
168 | // Write out scale/offset if quantized ElemKind. |
169 | if (isQuantizedElemKind(ty->getElementType())) { |
170 | addValueAttribute(proto, |
171 | getTypeAttrID(ioNum, qScaleSignifier, isInput, addPrefix), |
172 | ty->getScale()); |
173 | addValueAttribute( |
174 | proto, getTypeAttrID(ioNum, qOffsetSignifier, isInput, addPrefix), |
175 | ty->getOffset()); |
176 | } |
177 | } |
178 | |
179 | /// Adds the type attributes from \p NV to \p proto. \p ioNum, \p isInput, and |
180 | /// \p addPrefix are used to format the name of the attribute. |
181 | void addTypeAttributes(ONNX_NAMESPACE::NodeProto *proto, const NodeValue &NV, |
182 | unsigned ioNum, bool isInput, |
                       const std::string &addPrefix = "") {
184 | addTypeAttributes(proto, NV.getType(), ioNum, isInput, addPrefix); |
185 | } |
186 | |
187 | /// Add the type attributes from the \p ioNum number input or output (depending |
188 | /// on \p isInput) of \p N to \p proto. This includes the ElemKind, the Shape, |
189 | /// and scale/offset if ElemKind is quantized. Note that 'i' or 'o' along with |
190 | /// \p ioNum is prefixed onto the specific attribute being appended, as ops may |
191 | /// have multiple inputs/outputs. |
192 | void addTypeAttributes(ONNX_NAMESPACE::NodeProto *proto, const Node *N, |
193 | unsigned ioNum, bool isInput) { |
194 | NodeValue NV = isInput ? N->getNthInput(ioNum) : N->getNthResult(ioNum); |
195 | return addTypeAttributes(proto, NV, ioNum, isInput); |
196 | } |
197 | |
/// Helper function to recursively unwind a chain of Tile nodes starting at
/// \p node. Optionally, if \p repeats is provided, fills it with the per-axis
/// repeat counts. \returns the first Tile in the chain of Tiles.
201 | const TileNode *unwindTile(const TileNode *node, std::vector<size_t> *repeats, |
202 | ReportedNodes &reporter) { |
203 | // unwind Tile |
204 | // Keep track of detected <axis, count> pairs. |
205 | std::vector<std::pair<unsigned_t, unsigned_t>> info; |
206 | const TileNode *tile = node; |
207 | while (tile) { |
    // Insert counts and axes in reverse order, because the unwind algorithm
    // navigates from the bottom of the chain to the top.
210 | info.insert(info.begin(), {tile->getAxis(), tile->getCount()}); |
211 | |
212 | if (const auto *TN = llvm::dyn_cast<TileNode>(tile->getInput().getNode())) { |
213 | reporter.insert(TN); |
214 | tile = TN; |
215 | } else { |
216 | break; |
217 | } |
218 | } |
219 | |
220 | if (repeats) { |
221 | unsigned_t numDims = tile->getInput().dims().size(); |
    // In the normal case the axes are [0, 1, ..., numDims - 1]; in the
    // extreme case the list is empty. Find missing indices and insert a
    // count of 1.
224 | auto aB = info.begin(); |
225 | |
226 | for (unsigned_t i = 0; i < numDims; ++i, ++aB) { |
227 | if (aB == info.end() || aB->first != i) { |
228 | aB = info.insert(aB, {i, 1}); |
229 | } |
230 | } |
231 | |
232 | for (size_t b = 0, e = info.size(); b < e; ++b) { |
233 | repeats->push_back(info[b].second); |
234 | } |
235 | } |
236 | return tile; |
237 | } |
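
// For illustration: the importer may expand an ONNX Tile with repeats
// [2, 1, 3] into the chain Tile(axis=0, count=2) -> Tile(axis=2, count=3).
// Calling unwindTile() on the bottom Tile returns the first one and, since
// axis 1 is missing from the chain, fills \p repeats with [2, 1, 3].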
238 | |
/// Finds the output names for Node \p node, invoking \p callback on each name.
/// If a result is saved but the SaveNode is not its only user, an Identity
/// node relaying the result to the save placeholder is added to \p graph.
void findOutputNames(const Node *node, ONNX_TRAITS::GraphProto &graph,
                     std::function<void(const std::string &name)> &&callback) {
242 | // Check if user is SaveNode |
243 | std::set<unsigned> saveResNo; |
244 | std::vector<std::pair<const SaveNode *, unsigned>> saveOutputs; |
245 | std::vector<int> resultUsers(node->getNumResults(), 0); |
246 | for (const auto &use : node->getUsers()) { |
247 | const auto *user = use.getUser(); |
248 | unsigned resNo = 0; |
249 | for (unsigned b = 0, e = user->getNumInputs(); b < e; ++b) { |
250 | auto UNV = user->getNthInput(b); |
251 | if (node == UNV.getNode()) { |
252 | resNo = UNV.getResNo(); |
253 | resultUsers[resNo]++; |
254 | break; |
255 | } |
256 | } |
257 | |
258 | if (user->getKind() == Kinded::Kind::SaveNodeKind) { |
259 | // Use the associated placeholder's name. |
260 | const SaveNode *SN = llvm::cast<SaveNode>(user); |
261 | saveOutputs.emplace_back(SN, resNo); |
262 | } |
263 | } |
264 | |
  // If the SaveNode is the only user of a result, we can just use the save
  // placeholder's name as the output name. Otherwise, we have to insert an
  // Identity node to relay this output to the save output.
268 | for (const auto &p : saveOutputs) { |
269 | if (resultUsers[p.second] == 1) { |
270 | callback(p.first->getPlaceholder()->getName().str()); |
271 | saveResNo.insert(p.second); |
272 | } else { |
273 | auto *proto = graph.add_node(); |
274 | proto->set_name(node->getName().str() + "_copy_" + |
275 | std::to_string(p.second)); |
      proto->set_op_type("Identity");
277 | proto->add_input(p.second == 0 ? node->getName().str() |
278 | : (node->getName().str() + "_out_" + |
279 | std::to_string(p.second))); |
280 | proto->add_output(p.first->getPlaceholder()->getName().str()); |
281 | } |
282 | } |
283 | |
284 | // write the other outputs, if any |
285 | for (unsigned b = 0, e = node->getNumResults(); b < e; ++b) { |
286 | if (saveResNo.count(b)) { |
287 | continue; |
288 | } |
289 | if (b == 0) { |
290 | callback(node->getName().str()); |
291 | } else { |
292 | callback(node->getName().str() + "_out_" + std::to_string(b)); |
293 | } |
294 | } |
295 | } |
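
// For illustration of the naming convention above: result 0 of a node named
// "topk" is referred to simply as "topk" and result 1 as "topk_out_1", while
// a result whose only user is a SaveNode is referred to by the save
// placeholder's name instead.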
296 | |
297 | /// Writes all outputs from Node \p node to protobuf \p proto. |
298 | void outputsToProto(const Node *node, ONNX_TRAITS::GraphProto &graph, |
299 | ONNX_NAMESPACE::NodeProto *proto) { |
300 | findOutputNames(node, graph, |
301 | [&](const std::string &name) { proto->add_output(name); }); |
302 | } |
303 | |
304 | /// Writes all inputs from Node \p node to protobuf \p proto. |
305 | void inputsToProto(const Node *node, ONNX_NAMESPACE::NodeProto *proto) { |
306 | for (unsigned b = 0, e = node->getNumInputs(); b < e; ++b) { |
307 | const auto NV = node->getNthInput(b); |
308 | auto resNo = NV.getResNo(); |
309 | auto name = NV.getNode()->getName().str(); |
    if (resNo) {
      // Refer to the producer's result by its result number, matching the
      // "_out_" output naming convention used in findOutputNames().
      proto->add_input(name + "_out_" + std::to_string(resNo));
    } else {
      proto->add_input(name);
    }
315 | } |
316 | } |
317 | |
/// Writes to \p proto the outputs of \p node's users that are SaveNodes or
/// that match the provided \p kind; for a SaveNode user the placeholder's
/// name is used, otherwise the user's own output names are written.
/// \returns true if any such user was found.
319 | bool outputKindToProto(Kinded::Kind kind, const Node *node, |
320 | ONNX_TRAITS::GraphProto &graph, |
321 | ONNX_NAMESPACE::NodeProto *proto) { |
322 | bool found = false; |
323 | for (const auto &use : node->getUsers()) { |
324 | const auto *user = use.getUser(); |
325 | if (user->getKind() == Kinded::Kind::SaveNodeKind) { |
326 | found = true; |
327 | const SaveNode *SN = llvm::cast<SaveNode>(user); |
328 | proto->add_output(SN->getPlaceholder()->getName().str()); |
329 | break; |
330 | } else if (user->getKind() == kind) { |
331 | found = true; |
332 | outputsToProto(user, graph, proto); |
333 | } |
334 | } |
335 | return found; |
336 | } |
337 | |
/// Writes a MatMul-family operator from Node \p node into the provided graph
/// protobuf \p graph. Depending on \p nodeKind, this writes either MatMul or
/// BatchMatMul. \returns error.
342 | template <typename T> |
343 | Error writeMatMulKind(const T *node, ONNX_TRAITS::GraphProto &graph, |
344 | const std::string &nodeKind) { |
345 | auto *proto = graph.add_node(); |
346 | proto->set_name(node->getName().str()); |
347 | proto->set_op_type(nodeKind); |
348 | |
349 | Node *LHS = node->getLHS().getNode(); |
350 | proto->add_input(LHS->getName().str()); |
351 | Node *RHS = node->getRHS().getNode(); |
352 | proto->add_input(RHS->getName().str()); |
353 | |
354 | outputsToProto(node, graph, proto); |
355 | return Error::success(); |
356 | } |
357 | |
// Creates a Transpose node as the result of \p node.
// Reuses the given \p proto pointer for the Transpose and creates a new proto
// for \p node itself, adding it to \p graph. The permutation argument enables
// use of a different permutation, e.g. for transposing 3D convolution with
// NCTHW2NTHWC.
362 | template <typename T> |
363 | void writeTransposeResult(const T *node, ONNX_NAMESPACE::NodeProto *&proto, |
364 | ONNX_TRAITS::GraphProto &graph, |
365 | llvm::ArrayRef<unsigned_t> permutation = NCHW2NHWC) { |
  // Add dictionary entries.
  llvm::ArrayRef<unsigned_t> container(permutation);
  addValueAttribute(proto, "perm", container);
  // Re-use proto for the Transpose node.
  auto newName = node->getName().str() + "_out_transpose";
  proto->set_name(newName);
  proto->set_op_type("Transpose");
  proto->add_input(newName);

  proto->add_output(node->getName().str());

  // New proto for \p node itself; its output feeds the Transpose.
  proto = graph.add_node();
  proto->add_output(newName);
380 | } |
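
// For illustration: for a node named "conv" this emits
//   conv (rewritten)  -> "conv_out_transpose"
//   Transpose(perm)   : "conv_out_transpose" -> "conv"
// so consumers keep referring to "conv" while the op's raw output is routed
// through the Transpose.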
381 | |
// Creates a Transpose node for the input \p input of \p node.
// Adds a new Transpose proto to \p graph and wires its output in as an input
// of the given \p proto. The permutation argument enables use of a different
// permutation, e.g. for 3D convolution layouts.
386 | void writeTransposeInput(const Node *node, const Node *input, |
387 | ONNX_NAMESPACE::NodeProto *proto, |
388 | ONNX_TRAITS::GraphProto &graph, |
389 | llvm::ArrayRef<unsigned_t> permutation = NHWC2NCHW) { |
  // Write a "mirror" Transpose of the input, i.e. NHWC2NCHW.
  auto newName =
      node->getName().str() + "_" + input->getName().str() + "_in_transpose";
  auto *transformProto = graph.add_node();
  transformProto->set_name(newName);
  transformProto->set_op_type("Transpose");

  // Add dictionary entries.
  llvm::ArrayRef<unsigned_t> container(permutation);
  addValueAttribute(transformProto, "perm", container);
400 | transformProto->add_input(input->getName().str()); |
401 | transformProto->add_output(newName); |
402 | proto->add_input(newName); |
403 | } |
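
// For illustration: for a node named "conv" with input "data" this emits
//   Transpose(perm) : "data" -> "conv_data_in_transpose"
// and wires "conv_data_in_transpose" in as the op's input instead of "data".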
404 | |
/// Writes an arithmetic operator with name \p opName from Node \p node into
/// the provided graph protobuf \p graph. The operands may have been
/// broadcast. \p hasMultidirectionalBroadcast indicates that the op supports
/// multidirectional broadcasting, in which case neither the axis nor the
/// broadcast flag is written to the protobuf. Reports elided Broadcast nodes
/// to \p reporter, signaling that such nodes must be ignored. \returns error.
411 | template <typename T> |
412 | Error writeArithmetic(const std::string &opName, const T *node, |
413 | ONNX_TRAITS::GraphProto &graph, ReportedNodes &reporter, |
414 | bool hasMultidirectionalBroadcast) { |
415 | auto *proto = graph.add_node(); |
416 | proto->set_name(node->getName().str()); |
417 | proto->set_op_type(opName); |
418 | outputsToProto(node, graph, proto); |
419 | |
420 | auto LHS = node->getLHS(); |
421 | if (const BroadcastNode *BN = llvm::dyn_cast<BroadcastNode>(LHS.getNode())) { |
422 | reporter.insert(BN); |
423 | LHS = BN->getInput(); |
424 | } |
425 | |
426 | int axis = -1; |
427 | auto RHS = node->getRHS(); |
428 | if (const BroadcastNode *BN = llvm::dyn_cast<BroadcastNode>(RHS.getNode())) { |
429 | reporter.insert(BN); |
430 | RHS = BN->getInput(); |
431 | axis = BN->getAxis(); |
432 | } |
433 | |
434 | proto->add_input(LHS.getNode()->getName().str()); |
435 | proto->add_input(RHS.getNode()->getName().str()); |
436 | |
  // Check if the shapes of LHS and RHS are different and the broadcast
  // attribute is required.
  if (LHS.dims() != RHS.dims() && !hasMultidirectionalBroadcast) {
    addValueAttribute(proto, "axis", axis);
    addValueAttribute(proto, "broadcast", 1UL);
  }
443 | |
444 | return Error::success(); |
445 | } |
446 | |
447 | void tensorShapeFromInput(const std::string &name, TypeRef ty, |
448 | ONNX_NAMESPACE::ValueInfoProto *valueProto) { |
449 | valueProto->set_name(name); |
450 | auto *type = valueProto->mutable_type(); |
451 | auto *tensorType = type->mutable_tensor_type(); |
452 | tensorType->set_elem_type(ONNXModelWriter::convertType(*ty)); |
453 | auto *tensorShape = tensorType->mutable_shape(); |
454 | const auto &dims = ty->dims(); |
455 | for (unsigned b = 0, e = dims.size(); b < e; ++b) { |
456 | auto *tensorDims = tensorShape->add_dim(); |
457 | tensorDims->set_dim_value(dims[b]); |
458 | } |
459 | } |
460 | |
/// Creates a list of Nodes in the reverse of the order required by ONNX.
462 | class ReverseGraphWalker { |
463 | /// A post-order list of nodes. |
464 | std::vector<const Node *> reverseOrder_; |
465 | /// A set of visited nodes. |
466 | std::unordered_set<const Node *> visited_; |
467 | |
  void visit(Function &F) {
    // Write constants first. Even though they should be placed at the end of
    // the list, we can visit them first, because they will be written into
    // ONNX initializers, not nodes.
472 | for (const auto *C : F.findConstants()) { |
473 | reverseOrder_.push_back(C); |
474 | visited_.insert(C); |
475 | } |
476 | // Start visiting all root nodes, i.e. nodes that do not have any users. |
477 | for (auto &N : F.getNodes()) { |
478 | if (N.getNumUsers() == 0) { |
479 | visitIteratively(&N); |
480 | } |
481 | } |
482 | } |
483 | |
484 | void visitIteratively(const Node *rootNode) { |
485 | std::stack<const Node *> st; |
486 | st.push(rootNode); |
487 | while (st.size()) { |
488 | auto *N = st.top(); |
489 | st.pop(); |
490 | |
      // Check if the node has been visited already.
492 | if (visited_.count(N)) { |
493 | continue; |
494 | } |
495 | |
496 | if (N->getKind() == Kinded::Kind::PlaceholderKind) { |
497 | // For Placeholder don't visit uses. |
498 | visited_.insert(N); |
499 | reverseOrder_.push_back(N); |
500 | continue; |
501 | } |
502 | |
503 | // First visit all users, if any. |
504 | bool continueEarly = false; |
505 | for (const auto &use : N->getUsers()) { |
506 | const auto *UN = use.getUser(); |
        // Skip users that have already been visited.
508 | if (visited_.count(UN)) { |
509 | continue; |
510 | } |
511 | st.push(UN); |
512 | continueEarly = true; |
513 | } |
514 | if (continueEarly) { |
515 | continue; |
516 | } |
517 | |
      // Visit the current node if it hasn't been visited yet.
519 | if (!visited_.count(N)) { |
520 | visited_.insert(N); |
521 | reverseOrder_.push_back(N); |
522 | } |
523 | |
524 | // And then visit inputs of the current node. |
525 | for (unsigned b = 0, e = N->getNumInputs(); b < e; ++b) { |
526 | auto *UN = N->getNthInput(b).getNode(); |
527 | if (visited_.count(UN)) { |
528 | continue; |
529 | } |
530 | st.push(UN); |
531 | } |
532 | |
533 | // Additionally visit the predicate input if it exists. |
534 | if (N->hasPredicate()) { |
535 | auto *UN = N->getPredicate().getNode(); |
536 | if (!visited_.count(UN)) { |
537 | st.push(UN); |
538 | } |
539 | } |
540 | } |
541 | } |
542 | |
543 | public: |
544 | explicit ReverseGraphWalker(Function &F) { visit(F); } |
545 | |
546 | llvm::ArrayRef<const Node *> getNodes() const { return reverseOrder_; } |
547 | }; |
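
// For illustration: for a chain PH -> Relu -> Save, the walker above yields
// roughly [constants..., Save, Relu, PH]; reverseNodesOrder() later flips the
// node protos so that the serialized graph lists producers before consumers,
// as ONNX requires.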
548 | |
549 | template <typename T> |
550 | static void addAttrToDocString(T *proto, const std::string &attrName, |
551 | llvm::StringRef attrVal) { |
552 | *(proto->mutable_doc_string()) += std::string(1, startChar) + attrName + |
553 | std::string(1, sepChar) + attrVal.str(); |
554 | } |
555 | |
556 | } // namespace |
557 | |
Error ONNXModelWriter::insertLoaderNameUniqueOffsetMetadata(
    llvm::StringMap<std::string> &extraMetadataProps,
    const OriginNameToTQPMap &map) {
  RETURN_ERR_IF_NOT(!extraMetadataProps.count("OriginNameToTQPMap"),
                    "Already had OriginNameToTQPMap");
563 | std::string str; |
564 | for (const auto &nameTQP : map) { |
565 | str += nameTQP.first + offsetSepSig + |
566 | std::to_string(nameTQP.second.offset) + offsetEndSig; |
567 | } |
568 | extraMetadataProps.try_emplace(originNameToUniqueOffsetMappingSignifier, str); |
569 | return Error::success(); |
570 | } |
571 | |
572 | bool ONNXModelWriter::isIntermediatePHForDAG(const Placeholder *PH) { |
573 | if (!dagMode_) { |
574 | return false; |
575 | } |
576 | |
577 | bool isInputPH = false, isOutputPH = false; |
578 | for (const auto &use : PH->getUsers()) { |
579 | const auto *userN = use.getUser(); |
580 | // Only consider users from Functions in the DAG. |
581 | const Function *userF = userN->getParent(); |
582 | if (!functionsFromDAG_.count(userF)) { |
583 | continue; |
584 | } |
585 | const bool currIsInputPH = isInput(PH, *userF); |
586 | const bool currIsOutputPH = isOutput(PH, *userF); |
587 | // Check this is not quantization profiling or training cases. |
    assert(
        !(currIsInputPH && currIsOutputPH) &&
        "Do not support PHs that are input and output to a single Function.");
591 | if (currIsInputPH) { |
592 | isInputPH = true; |
593 | } |
594 | if (currIsOutputPH) { |
595 | isOutputPH = true; |
596 | } |
597 | } |
598 | |
599 | // If the PH is both an input and an output for the Functions in the DAG then |
600 | // it must be an intermediate. |
601 | return isInputPH && isOutputPH; |
602 | } |
603 | |
604 | /// Reverses the order of the nodes in \p nodes. |
605 | static void |
606 | reverseNodesOrder(RepeatedPtrField<ONNX_NAMESPACE::NodeProto> &nodes) { |
607 | for (size_t i = 0, n = nodes.size(); i < n / 2; ++i) { |
608 | nodes.SwapElements(i, n - i - 1); |
609 | } |
610 | } |
611 | |
612 | bool ONNXModelWriter::isWritingConstFoldSubgraph() { |
613 | return graphProtoRoot_ != graphProto_; |
614 | } |
615 | |
616 | Error ONNXModelWriter::writeConstantFoldingSubgraph(const Constant *C, |
617 | SaveNode *SN) { |
618 | Function *constFoldFunction = SN->getParent(); |
619 | |
620 | // If we already wrote out this Function we can return early. |
621 | if (!processedConstFoldFunctions_.insert(constFoldFunction).second) { |
622 | return Error::success(); |
623 | } |
624 | |
625 | // Create new constant folding Node, which we add the subgraph to. Always add |
626 | // to the root as these are all loaded before loading any ops. |
627 | auto *constFoldNodeProto = graphProtoRoot_->add_node(); |
628 | constFoldNodeProto->set_op_type(constFoldSubgraphNodeName); |
629 | const char *constFoldNodeName = constFoldFunction->getName().data(); |
630 | constFoldNodeProto->set_name(constFoldNodeName); |
631 | |
632 | // Now add the constant folding subgraph to the node. |
633 | auto *constFoldNodeSubgraph = constFoldNodeProto->add_attribute(); |
  constFoldNodeSubgraph->set_name("ConstFoldSubgraph");
635 | constFoldNodeSubgraph->set_type(ONNX_NAMESPACE::AttributeProto::GRAPH); |
636 | |
637 | // Temporarily swap in the constant folding function and the graph proto from |
638 | // the constant folding subgraph node so we can write into it. |
639 | ONNX_TRAITS::GraphProto *origGraphProto = graphProto_; |
640 | graphProto_ = constFoldNodeSubgraph->mutable_g(); |
641 | Function *origF = F_; |
642 | F_ = constFoldFunction; |
643 | // Make sure to restore original state of the writer when exiting this scope. |
644 | ScopeGuard restoreOrigStateGuard([&]() { |
645 | graphProto_ = origGraphProto; |
646 | F_ = origF; |
647 | }); |
648 | |
  // Now that we have set up the constant folding Function and proto to write
  // into, write the function in.
651 | RETURN_IF_ERR(writeFunction()); |
652 | |
653 | // Now set the output of the ConstFoldSubgraph node. Only output is the |
654 | // Constant it generates. Note that there are no inputs, as Constant inputs |
655 | // are self-contained to this graph as initializers. |
656 | constFoldNodeProto->add_output(C->getName().data()); |
657 | addTypeAttributes(constFoldNodeProto, SN->getOutput(), SaveNode::OutputIdx, |
658 | /* isInput */ false); |
659 | |
660 | // Reverse the order of Nodes since we wrote them in reverse order. |
661 | reverseNodesOrder(*graphProto_->mutable_node()); |
662 | |
663 | return Error::success(); |
664 | } |
665 | |
Error ONNXModelWriter::writeFunction() {
  // Use pre-order graph traversal.
  // If a node is a Constant with uses, it is turned into an "input" and a
  // tensor is created for it. If a node is a Placeholder with uses, it is
  // turned into an "input" with only a tensor shape, except when the
  // Placeholder is used by a SaveNode. Otherwise, call the common or special
  // operator writers, taking protobuf inputs from the node's inputs and
  // protobuf outputs from its uses.
673 | |
674 | ReverseGraphWalker visitor(*F_); |
675 | for (const auto *N : visitor.getNodes()) { |
676 | if (reportedNodes_.count(N)) { |
677 | continue; |
678 | } |
679 | |
680 | const auto kind = N->getKind(); |
681 | // Handle placeholders cases. |
682 | if (kind == Kinded::Kind::PlaceholderKind) { |
683 | const auto *PH = llvm::cast<Placeholder>(N); |
684 | if (!isInput(PH, *F_)) { |
685 | // Storage as an input to SaveNode - ignore it. |
686 | continue; |
687 | } |
688 | // Skip Placeholders that are only intermediates -- these are |
689 | // understood/recreated during reimporting based on an op's partition. |
690 | if (isIntermediatePHForDAG(PH)) { |
691 | continue; |
692 | } |
693 | // Write global input, output only tensor shape. |
694 | RETURN_IF_EXPECTED_IS_ERR(createProtoForIO(PH, /* isInput */ true)); |
695 | } else if (kind == Kinded::Kind::ConstantKind) { |
696 | // Write global initializer, output tensor bytes. |
697 | const auto *C = llvm::cast<Constant>(N); |
698 | |
699 | // Check if this constant came from constant folding that we recorded and |
700 | // want to serialize in the model. If so then we process it before the |
701 | // Constant itself so that it will be loaded first. |
702 | auto constFoldRecordIt = constFoldRecord_.find(C); |
703 | if (constFoldRecordIt != constFoldRecord_.end()) { |
704 | RETURN_IF_ERR( |
705 | writeConstantFoldingSubgraph(C, constFoldRecordIt->second)); |
706 | } |
707 | |
708 | // Note: Always add initializers to the root graph proto. |
709 | auto *tensorProto = addInitializer(*graphProtoRoot_); |
710 | |
711 | // If we added a constant folding Node recording the constant folding that |
712 | // generated this initializer then point to it in the initializer. |
713 | if (constFoldRecordIt != constFoldRecord_.end()) { |
714 | SaveNode *SN = constFoldRecordIt->second; |
715 | // Point the original initializerProto to this Node so that it knows |
716 | // where to find the Function to replay the constant folding, along with |
717 | // the resNo needed from the Function. |
        auto *constFoldNodeNameProto = tensorProto->add_external_data();
        constFoldNodeNameProto->set_key("ConstFoldNodeName");
        constFoldNodeNameProto->set_value(SN->getParent()->getName().data());
        auto *resNoProto = tensorProto->add_external_data();
        resNoProto->set_key("ConstFoldResNo");
        resNoProto->set_value(std::to_string(SN->getInput().getResNo()));
724 | } |
725 | |
726 | // When using useGlowCustomOps we always use generateNodeOutputName for |
727 | // all inputs and outputs. |
728 | tensorProto->set_name(useGlowCustomOps_ |
729 | ? C->getOutput().generateNodeOutputName( |
730 | /* stripResNoFor0thInput */ true) |
731 | : C->getName().str()); |
732 | writeTensor(C->getPayload(), tensorProto, useGlowCustomOps_, |
733 | includeConstantData_); |
734 | if (useGlowCustomOps_) { |
735 | // Also include the layout in the initializer to be loaded later. |
736 | addAttrToDocString(tensorProto, layoutSignifier, C->getLayout()); |
737 | } |
738 | } else if (kind == Kinded::Kind::SaveNodeKind) { |
      // SaveNode case: find the input and use its name as a global output;
      // write only the shape.
741 | const SaveNode *SN = llvm::cast<SaveNode>(N); |
742 | const auto *PH = SN->getPlaceholder(); |
743 | |
744 | // If useGlowCustomOps then we need to add an Identity to map the name |
745 | // from generateNodeOutputName() to the name of the Placeholder. |
746 | if (useGlowCustomOps_) { |
747 | auto *proto = graphProto_->add_node(); |
        proto->set_op_type("Identity");
749 | proto->set_name(SN->getName().data()); |
750 | proto->add_input(SN->getInput().generateNodeOutputName( |
751 | /* stripResNoFor0thInput */ true)); |
752 | proto->add_output(PH->getName().data()); |
753 | addTypeAttributes(proto, SN->getInput(), SaveNode::InputIdx, |
754 | /* isInput */ true); |
755 | addTypeAttributes(proto, SN->getOutput(), SaveNode::OutputIdx, |
756 | /* isInput */ false); |
757 | // If dumping a DAG then add partition names to each op that's written. |
758 | // Also only do so when not writing a const fold subgraph. |
759 | if (dagMode_ && !isWritingConstFoldSubgraph()) { |
          addValueAttribute(proto, "partitionName", F_->getName().str());
761 | // Skip writing Placeholders that are only intermediates -- these are |
762 | // understood/recreated during reimporting based on an op's partition. |
763 | if (isIntermediatePHForDAG(PH)) { |
            addValueAttribute(proto, "isIntermediateOutputForDAG", true);
765 | continue; |
766 | } |
767 | } |
768 | } |
769 | |
770 | ONNX_NAMESPACE::ValueInfoProto *out; |
771 | ASSIGN_VALUE_OR_RETURN_ERR(out, |
772 | createProtoForIO(PH, /* isInput */ false)); |
773 | |
774 | // Use the doc string to specify the name that should be used for the |
775 | // SaveNode to ensure it's kept the same between export and import. |
776 | addAttrToDocString(out, saveNameSignifier, SN->getName()); |
777 | } else if (useGlowCustomOps_) { |
778 | RETURN_IF_ERR(writeGlowCustomOperator(N, *graphProto_)); |
779 | } else { |
780 | RETURN_IF_ERR(writeOperator(N, *graphProto_)); |
781 | } |
782 | reportedNodes_.insert(N); |
783 | } |
784 | |
785 | return Error::success(); |
786 | } |
787 | |
788 | void ONNXModelWriter::setupNewProto() { |
789 | modelProto_.set_ir_version(irVersion_); |
790 | modelProto_.set_producer_name(useGlowCustomOps_ ? "GlowONNXModelWriter" |
791 | : "ONNXModelWriter" ); |
792 | auto *opsetImportProto = modelProto_.add_opset_import(); |
793 | opsetImportProto->set_version(opsetVersion_); |
794 | graphProto_ = modelProto_.mutable_graph(); |
  graphProto_->set_name("glow");
796 | graphProtoRoot_ = graphProto_; |
797 | } |
798 | |
799 | static Error writeModelToString(const ::google::protobuf::Message &modelProto, |
800 | bool textMode, std::string *outputStringPtr) { |
  if (textMode) {
    RETURN_ERR_IF_NOT(google::protobuf::TextFormat::PrintToString(
                          modelProto, outputStringPtr),
                      "Error writing to string");
  } else {
    ::google::protobuf::io::StringOutputStream zeroCopyOutput(outputStringPtr);
    ::google::protobuf::io::CodedOutputStream codedOutput(&zeroCopyOutput);
    modelProto.SerializeToCodedStream(&codedOutput);
    RETURN_ERR_IF_NOT(!codedOutput.HadError(),
                      "Can't write to the output string",
                      ErrorValue::ErrorCode::MODEL_WRITER_SERIALIZATION_ERROR);
  }
813 | return Error::success(); |
814 | } |
815 | |
816 | Error ONNXModelWriter::finalizeAndWriteProto(llvm::StringRef name) { |
  // Nodes have been added in reverse order, from SaveNodes up to the inputs;
  // we need to rearrange all nodes into the correct order before
  // serialization.
819 | auto *nodes = graphProto_->mutable_node(); |
820 | reverseNodesOrder(*nodes); |
821 | |
822 | // useGlowCustomOps uses Identities differently than the normal writer, so |
823 | // we do not want to do this if so. |
824 | if (!useGlowCustomOps_) { |
    // We need to swap each Identity node with the next non-Identity node,
    // since we append Identity nodes to tap out intermediate results.
    for (int i = 0, n = nodes->size(); i < n - 1; ++i) {
      if (nodes->Get(i).op_type() == "Identity") {
        int k = 1;
        while (i + k < n) {
          if (nodes->Get(i + k).op_type() != "Identity") {
            break;
          }
          ++k;
        }
        nodes->SwapElements(i, i + k);
        i += k;
      }
    }
840 | } else { |
841 | #ifdef FACEBOOK_INTERNAL |
    addMetadataProp("GlowRevisionHash", revisionHash);
843 | #endif /* FACEBOOK_INTERNAL */ |
844 | } |
845 | |
846 | // If we have loadedPHNames_, then we buffered the non-static PH IO protobuf |
847 | // in inputValueInfos_ and outputValueInfos_. Now we write it all out in order |
848 | // according to indices provided in loadedPHNames_. |
849 | if (loadedPHNames_) { |
850 | const bool ioNumMismatch = |
851 | (inputValueInfos_.size() + outputValueInfos_.size() != |
852 | loadedPHNames_->size()); |
853 | |
854 | // If total number of inputs and outputs doesn't match the number of |
855 | // placeholders, then log an error message and let the next for-loop find |
856 | // the culprits. |
857 | if (ioNumMismatch) { |
858 | LOG(ERROR) << "Number of buffered inputs and outputs " |
859 | << (inputValueInfos_.size() + outputValueInfos_.size()) |
860 | << " didn't match the number of loadedPHNames " |
861 | << loadedPHNames_->size(); |
862 | } |
863 | |
864 | // If we have the loaded PH names map, then we need to reorder the inputs |
865 | // and outputs to follow the same order as provided in the loadedPHNames_. |
866 | std::vector<const Placeholder *> orderedInputs(inputValueInfos_.size()); |
867 | std::vector<const Placeholder *> orderedOutputs(outputValueInfos_.size()); |
868 | for (const auto &pair : *loadedPHNames_) { |
869 | const Placeholder *PH = pair.first; |
870 | const unsigned orderIdx = pair.second.second; |
871 | if (inputValueInfos_.count(PH)) { |
872 | orderedInputs[orderIdx] = PH; |
873 | } else if (outputValueInfos_.count(PH)) { |
874 | orderedOutputs[orderIdx] = PH; |
875 | } else { |
876 | return MAKE_ERR("PH must either be in inputs or outputs: " + |
877 | PH->getName().str()); |
878 | } |
879 | } |
880 | |
    // If we didn't find any bad placeholders above, then there must be some
    // bad inputs/outputs.
    if (ioNumMismatch) {
      return MAKE_ERR("Found some inputs/outputs that don't have corresponding "
                      "placeholders");
    }
886 | |
887 | // Now have IO in order matching loadedPHNames_, so finally write them out. |
888 | for (const Placeholder *PH : orderedInputs) { |
889 | auto *inputProto = graphProto_->add_input(); |
890 | inputProto->MergeFrom(inputValueInfos_[PH]); |
891 | } |
892 | for (const Placeholder *PH : orderedOutputs) { |
893 | auto *outputProto = graphProto_->add_output(); |
894 | outputProto->MergeFrom(outputValueInfos_[PH]); |
895 | } |
896 | } |
897 | |
898 | if (zipMode_) { |
    RETURN_ERR_IF_NOT(
        outputStringPtr_ == nullptr,
        "ONNXModelWriter write to string for zip mode not supported");
    const bool compressed = false;
    ZipWriter zip(&ff_, name.str());
    std::stringstream ss;
    ss << initializers_.size() << "\n";
    zip.writeRecord("weights", ss.str().c_str(), ss.str().size(), compressed);
    std::string largeBuffer;
    int i = 0;
    // This part is probably quite inefficient, as we serialize the protobuf
    // to a char buffer and then put it into the zip stream. I didn't dig
    // enough to see if we can serialize it into the zip stream directly.
912 | for (const auto &t : initializers_) { |
913 | std::stringstream nm; |
914 | nm << "weight_" << i++; |
915 | t.SerializeToString(&largeBuffer); |
916 | zip.writeRecord(nm.str(), largeBuffer.c_str(), largeBuffer.size(), |
917 | compressed); |
918 | } |
919 | if (textMode_) { |
920 | google::protobuf::TextFormat::PrintToString(modelProto_, &largeBuffer); |
921 | } else { |
922 | modelProto_.SerializeToString(&largeBuffer); |
923 | } |
    zip.writeRecord("model", largeBuffer.c_str(), largeBuffer.size(),
                    compressed);
926 | zip.writeEndOfFile(); |
927 | ff_.flush(); |
928 | ff_.close(); |
929 | return Error::success(); |
930 | } else { |
931 | if (outputStringPtr_ != nullptr) { |
932 | return writeModelToString(modelProto_, textMode_, outputStringPtr_); |
933 | } else { |
934 | return writeModel(modelProto_, textMode_); |
935 | } |
936 | } |
937 | } |
938 | |
939 | ONNXModelWriter::ONNXModelWriter( |
940 | const std::string &modelFilename, Function &F, size_t irVersion, |
941 | size_t opsetVersion, Error *errPtr, bool textMode, bool zipMode, |
942 | bool useGlowCustomOps, bool includeConstantData, |
    const llvm::StringMap<std::string> &extraMetadataProps,
944 | const ConstantFoldingRecordMap &constFoldRecord, |
945 | const BackendSpecificNodeInfo &backendSpecificNodeInfo, |
946 | std::string *outputStringPtr) |
947 | : CommonOperatorWriter(modelFilename, &F, errPtr, |
948 | outputStringPtr == nullptr), |
949 | irVersion_(irVersion), opsetVersion_(opsetVersion), zipMode_(zipMode), |
950 | textMode_(textMode), includeConstantData_(includeConstantData), |
951 | extraMetadataProps_(extraMetadataProps), |
952 | useGlowCustomOps_(useGlowCustomOps), dagMode_(false), |
953 | constFoldRecord_(constFoldRecord), |
954 | backendSpecificNodeInfo_(backendSpecificNodeInfo), |
955 | loadedPHNames_(nullptr), staticPlaceholderTypes_(nullptr), |
956 | outputStringPtr_(outputStringPtr) { |
957 | // If errPtr already contains an error then don't continue with constructor. |
958 | if (errPtr && *errPtr) { |
959 | return; |
960 | } |
961 | |
  // Lambda to set up the ONNXModelWriter and return any Errors that were
  // raised.
963 | auto setup = [&]() -> Error { |
964 | setupNewProto(); |
965 | for (auto &prop : extraMetadataProps_) { |
966 | addMetadataProp(prop.getKey().str(), prop.second); |
967 | } |
968 | |
969 | RETURN_IF_ERR(writeFunction()); |
970 | |
971 | return finalizeAndWriteProto(F_->getName()); |
972 | }; |
973 | |
974 | if (errPtr) { |
975 | *errPtr = setup(); |
976 | } else { |
977 | EXIT_ON_ERR(setup()); |
978 | } |
979 | } |
980 | |
981 | /// Collect nodes from the DAG from \p root in post order in \p postOrder. |
982 | /// Gather visited nodes in \p visited. |
983 | static void collectNodesPostOrder(const DAGNode *root, |
984 | std::unordered_set<const DAGNode *> &visited, |
985 | std::vector<const DAGNode *> &postOrder) { |
986 | if (root == nullptr) { |
987 | return; |
988 | } |
989 | visited.insert(root); |
990 | for (auto &c : root->children) { |
991 | if (visited.count(c) == 0) { |
992 | collectNodesPostOrder(c, visited, postOrder); |
993 | } |
994 | } |
995 | postOrder.push_back(root); |
996 | } |
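
// For illustration: for a DAG root -> {A, B} with A -> C, this collects
// [C, A, B, root]; the DAG-mode constructor below then drops the trailing
// root entry since it does not correspond to a partition Function.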
997 | |
998 | void ONNXModelWriter::addMetadataProp(const std::string &key, |
999 | const std::string &val) { |
1000 | auto *prop = modelProto_.add_metadata_props(); |
1001 | prop->set_key(key); |
1002 | prop->set_value(val); |
1003 | } |
1004 | |
1005 | Error ONNXModelWriter::writePartitionAndMetadataProps( |
1006 | Module &mod, llvm::ArrayRef<const DAGNode *> postOrder) { |
1007 | // Add number of partitions to proto. |
  addMetadataProp("numPartitions", std::to_string(postOrder.size()));
1009 | |
1010 | for (size_t i = 0, e = postOrder.size(); i < e; i++) { |
1011 | const auto *dagNode = postOrder[i]; |
1012 | F_ = mod.getFunction(dagNode->name); |
    RETURN_ERR_IF_NOT(F_, "Function was not valid from DAGList");
1014 | |
1015 | // Write the nodes of the Function. |
1016 | RETURN_IF_ERR(writeFunction()); |
1017 | |
1018 | // Add to the proto the partition name and related meta info: |
1019 | const std::string partIdPrefix = getPartitionIdPrefix(i); |
1020 | |
1021 | // name of partition: |
1022 | addMetadataProp(partIdPrefix + nameSignifier, dagNode->name); |
1023 | |
1024 | // logicalDevices of partition: |
1025 | addMetadataProp(partIdPrefix + numLogicalDevicesSignifier, |
1026 | std::to_string(dagNode->logicalDevices.size())); |
1027 | for (size_t j = 0, f = dagNode->logicalDevices.size(); j < f; j++) { |
1028 | addMetadataProp(partIdPrefix + getLogicalDeviceSignfier(j), |
1029 | std::to_string(dagNode->logicalDevices[j])); |
1030 | } |
1031 | |
1032 | // backendName of partition: |
1033 | addMetadataProp(partIdPrefix + backendNameSignifier, dagNode->backendName); |
1034 | |
1035 | // size of partition: |
1036 | addMetadataProp(partIdPrefix + sizeSignifier, |
1037 | std::to_string(dagNode->size)); |
1038 | |
1039 | // backendHints.executionUnits of partition: |
1040 | addMetadataProp(partIdPrefix + executionUnitsSignifier, |
1041 | std::to_string(dagNode->backendHints.executionUnits)); |
1042 | |
    // backendHints.SRAMPrioritization of partition not supported:
    assert(dagNode->backendHints.SRAMPrioritization.size() == 0 &&
           "Do not support SRAMPrioritization saving from DAGNode");
1046 | |
1047 | // backendSpecificOpts of partition: |
1048 | addMetadataProp(partIdPrefix + numBackendSpecificOptsSignifier, |
1049 | std::to_string(dagNode->backendSpecificOpts.size())); |
1050 | size_t j = 0; |
1051 | for (const auto &keyVal : dagNode->backendSpecificOpts) { |
1052 | addMetadataProp(partIdPrefix + getBackendSpecificOptKeySignifier(j), |
1053 | keyVal.first); |
1054 | addMetadataProp(partIdPrefix + getBackendSpecificOptValSignifier(j), |
1055 | keyVal.second); |
1056 | j += 1; |
1057 | } |
1058 | |
1059 | // replicationCount of partition: |
1060 | addMetadataProp(partIdPrefix + replicationCountSignifier, |
1061 | std::to_string(dagNode->replicationCount)); |
1062 | } |
1063 | |
1064 | return Error::success(); |
1065 | } |
1066 | |
1067 | ONNXModelWriter::ONNXModelWriter( |
1068 | const std::string &modelFilename, DAGListTy &dagList, size_t irVersion, |
1069 | size_t opsetVersion, Error *errPtr, bool textMode, bool zipMode, |
1070 | bool includeConstantData, |
    const llvm::StringMap<std::string> &extraMetadataProps,
1072 | const ConstantFoldingRecordMap &constFoldRecord, |
1073 | const BackendSpecificNodeInfo &backendSpecificNodeInfo, |
1074 | const LoadedPlaceholderNameMap *loadedPHNames, |
1075 | const std::map<std::string, Type> *staticPlaceholderTypes, |
1076 | std::string *outputStringPtr) |
1077 | : CommonOperatorWriter(modelFilename, nullptr, errPtr, |
1078 | outputStringPtr == nullptr), |
1079 | irVersion_(irVersion), opsetVersion_(opsetVersion), zipMode_(zipMode), |
1080 | textMode_(textMode), includeConstantData_(includeConstantData), |
1081 | extraMetadataProps_(extraMetadataProps), useGlowCustomOps_(true), |
1082 | dagMode_(true), constFoldRecord_(constFoldRecord), |
1083 | backendSpecificNodeInfo_(backendSpecificNodeInfo), |
1084 | loadedPHNames_(loadedPHNames), |
1085 | staticPlaceholderTypes_(staticPlaceholderTypes), |
1086 | outputStringPtr_(outputStringPtr) { |
1087 | // If errPtr already contains an error then don't continue with constructor. |
1088 | if (errPtr && *errPtr) { |
1089 | return; |
1090 | } |
1091 | |
  // Lambda to set up the ONNXModelWriter and return any Errors that were
  // raised.
1093 | auto setup = [&]() -> Error { |
1094 | setupNewProto(); |
1095 | for (auto &prop : extraMetadataProps_) { |
1096 | addMetadataProp(prop.getKey().str(), prop.second); |
1097 | } |
1098 | |
    RETURN_ERR_IF_NOT(dagList.size() == 1, "Expect only one DAG.");
1100 | const auto &dag = *dagList.begin(); |
1101 | |
1102 | Module &mod = *dag.root->module; |
1103 | |
1104 | // Iterate over the DAG in post-order; Nodes per Function are written in |
1105 | // reverse order and reversed at the end, so this follows suit. |
1106 | std::unordered_set<const DAGNode *> visited; |
1107 | std::vector<const DAGNode *> postOrder; |
1108 | collectNodesPostOrder(dag.root.get(), visited, postOrder); |
1109 | // Remove the root node from the list as we don't care about it. |
1110 | postOrder.pop_back(); |
1111 | for (const DAGNode *dagNode : postOrder) { |
1112 | functionsFromDAG_.insert(mod.getFunction(dagNode->name)); |
1113 | } |
1114 | |
1115 | RETURN_IF_ERR(writePartitionAndMetadataProps(mod, postOrder)); |
1116 | |
1117 | return finalizeAndWriteProto(dag.root->name); |
1118 | }; |
1119 | |
1120 | if (errPtr) { |
1121 | *errPtr = setup(); |
1122 | } else { |
1123 | EXIT_ON_ERR(setup()); |
1124 | } |
1125 | } |
1126 | |
1127 | ONNXModelWriter::TensorType *ONNXModelWriter::addInitializer(GraphType &g) { |
1128 | if (zipMode_) { |
1129 | initializers_.emplace_back(); |
1130 | return &initializers_.back(); |
1131 | } else { |
1132 | return g.add_initializer(); |
1133 | } |
1134 | } |
1135 | |
1136 | ONNXModelWriter::TensorType::DataType |
1137 | ONNXModelWriter::convertType(const Type &glowType) { |
1138 | switch (glowType.getElementType()) { |
1139 | case ElemKind::FloatTy: |
1140 | return TensorType::FLOAT; |
1141 | case ElemKind::Float16Ty: |
1142 | return TensorType::FLOAT16; |
1143 | case ElemKind::BFloat16Ty: |
1144 | return TensorType::BFLOAT16; |
1145 | case ElemKind::Float64Ty: |
1146 | return TensorType::DOUBLE; |
1147 | case ElemKind::Int8QTy: |
1148 | return TensorType::INT8; |
1149 | case ElemKind::UInt8FusedQTy: |
1150 | case ElemKind::UInt8FusedFP16QTy: |
1151 | case ElemKind::UInt4FusedFP16QTy: |
1152 | case ElemKind::UInt4FusedQTy: |
1153 | case ElemKind::UInt8QTy: |
1154 | case ElemKind::UInt8ITy: |
1155 | return TensorType::UINT8; |
1156 | case ElemKind::Int16QTy: |
1157 | return TensorType::INT16; |
1158 | case ElemKind::Int32QTy: |
1159 | case ElemKind::Int32ITy: |
1160 | return TensorType::INT32; |
1161 | case ElemKind::Int64QTy: |
1162 | case ElemKind::Int64ITy: |
1163 | return TensorType::INT64; |
1164 | case ElemKind::BoolTy: |
1165 | return TensorType::BOOL; |
1166 | } |
  LOG(DFATAL) << "Cannot reach here.";
1168 | return TensorType::UNDEFINED; // Avoids a compilation warning. |
1169 | } |
1170 | |
1171 | /// Add quantization parameters to the doc_string in \p out based on \p type. |
1172 | template <typename T> |
1173 | static void addQuantParamsToDocString(T *out, const Type &type) { |
  addAttrToDocString(out, qScaleSignifier,
                     strFormat("%.*f", NUM_FLOAT_DIGS, type.getScale()));
1176 | addAttrToDocString(out, qOffsetSignifier, std::to_string(type.getOffset())); |
1177 | } |
1178 | |
1179 | /// Add strides to the doc_string in \p out based on \p type. |
1180 | template <typename T> |
1181 | static void addStridesToDocString(T *out, const Type &type) { |
1182 | // Non-standard strides need to be serialized. |
1183 | if (type.hasStandardStrides()) { |
1184 | return; |
1185 | } |
1186 | const auto &strides = type.strides(); |
1187 | std::string stridesStr; |
1188 | std::string delim; |
1189 | for (const auto &stride : strides) { |
1190 | stridesStr.append(delim); |
1191 | stridesStr.append(std::to_string(stride)); |
    delim = ",";
1193 | } |
1194 | addAttrToDocString(out, stridesSignifier, stridesStr); |
1195 | } |
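
// For illustration: a type with non-standard strides {8, 1} is recorded in
// the doc_string under stridesSignifier with the value "8,1".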
1196 | |
1197 | void ONNXModelWriter::writeTensor(const Tensor &T, TensorType *out, |
1198 | bool useGlowCustomOps, bool includeData) { |
1199 | const auto &type = T.getType(); |
1200 | out->set_data_type(convertType(type)); |
1201 | const auto &dims = type.dims(); |
1202 | for (unsigned b = 0, e = dims.size(); b < e; ++b) { |
1203 | out->add_dims(dims[b]); |
1204 | } |
1205 | |
1206 | if (includeData) { |
1207 | out->set_raw_data(T.getUnsafePtr(), type.getSizeInBytes()); |
1208 | } |
1209 | |
1210 | if (useGlowCustomOps) { |
1211 | addAttrToDocString(out, elemKindSignifier, type.getElementName()); |
1212 | addStridesToDocString(out, type); |
1213 | } |
1214 | |
1215 | if (type.isQuantizedType()) { |
1216 | if (useGlowCustomOps) { |
1217 | addQuantParamsToDocString(out, type); |
1218 | } else { |
      // Format is ElemKind:scale:offset.
      out->set_doc_string(strFormat("%s:%.*f:%d", type.getElementName().data(),
                                    NUM_FLOAT_DIGS, T.getType().getScale(),
                                    T.getType().getOffset()));
1223 | } |
1224 | } |
1225 | } |
1226 | |
1227 | Expected<ONNX_NAMESPACE::ValueInfoProto *> |
1228 | ONNXModelWriter::createProtoForIO(const Placeholder *PH, bool isInput) { |
1229 | // If loadedPHNames_ then we have a specific order we need to write out IO |
1230 | // protos. If so, buffer non-static IO that is not part of a constant folding |
1231 | // subgraph into inputValueInfos_/outputValueInfos_ to later be written out in |
1232 | // order inside finalizeAndWriteProto() based on loadedPHNames_. |
1233 | ONNX_NAMESPACE::ValueInfoProto *valueProto; |
1234 | if (!loadedPHNames_ || isWritingConstFoldSubgraph() || PH->isStatic()) { |
1235 | valueProto = isInput ? graphProto_->add_input() : graphProto_->add_output(); |
1236 | } else { |
1237 | valueProto = isInput ? &inputValueInfos_[PH] : &outputValueInfos_[PH]; |
1238 | } |
1239 | |
1240 | tensorShapeFromInput(PH->getName().str(), PH->getType(), valueProto); |
1241 | |
1242 | if (useGlowCustomOps_) { |
1243 | // Write out any meta information we need to for the Placeholder. |
1244 | addStridesToDocString(valueProto, *PH->getType()); |
1245 | addAttrToDocString(valueProto, staticSignifier, |
1246 | std::to_string(PH->isStatic())); |
1247 | addAttrToDocString(valueProto, trainableSignifier, |
1248 | std::to_string(PH->isTraining())); |
1249 | addAttrToDocString(valueProto, layoutSignifier, PH->getLayout()); |
1250 | addAttrToDocString(valueProto, elemKindSignifier, |
1251 | PH->getType()->getElementName()); |
1252 | |
1253 | // If we're writing out a Placeholder from the original input Function, then |
1254 | // expect to find a corresponding input loaded PH name if they are |
1255 | // provided. This is expected when the PH is not static (as otherwise it's |
1256 | // input as a weight), when the Function being written isn't a constant |
1257 | // folding subgraph (then we have PHs that are used just to save the const |
1258 | // folding result), and when the PH isn't intermediate (then it's only |
1259 | // visible/used by Glow when executing partitioned DAGs). |
1260 | if (loadedPHNames_ && !PH->isStatic() && !isWritingConstFoldSubgraph() && |
1261 | !isIntermediatePHForDAG(PH)) { |
1262 | auto it = loadedPHNames_->find(PH); |
1263 | RETURN_ERR_IF_NOT(it != loadedPHNames_->end(), |
1264 | "Did not find associated loader name for " + |
1265 | PH->getName().str() + " while writing Function " + |
1266 | F_->getName().str()); |
1267 | addAttrToDocString(valueProto, loaderNameSignifier, it->second.first); |
1268 | } |
1269 | |
1270 | // If we have a type that was used for loading a static Placeholder, then |
1271 | // serialize that type into a dummy node. |
1272 | if (staticPlaceholderTypes_ && PH->isStatic()) { |
1273 | auto it = staticPlaceholderTypes_->find(PH->getName().data()); |
1274 | RETURN_ERR_IF_NOT(it != staticPlaceholderTypes_->end(), |
1275 | "Did not find associated type for static PH " + |
1276 | PH->getName().str() + " while writing Function " + |
1277 | F_->getName().str()); |
1278 | |
      // Create a new static PH dummy node that carries the type the static
      // PH was loaded with. Note it has no inputs or outputs; however, there
      // is a type appended for the output idx, and the node has the same name
      // as the static PH to use when reloading.
1283 | auto *staticPHDummyNodeProto = graphProto_->add_node(); |
1284 | staticPHDummyNodeProto->set_op_type(staticPHDummyNodeName); |
1285 | staticPHDummyNodeProto->set_name(PH->getName().data()); |
1286 | |
1287 | // Set the output type to be the one we found in staticPlaceholderTypes_. |
1288 | addTypeAttributes(staticPHDummyNodeProto, &it->second, Storage::OutputIdx, |
1289 | /* isInput */ false); |
1290 | } |
1291 | |
1292 | // Also include quantization params if necessary. |
1293 | if (PH->getType()->isQuantizedType()) { |
1294 | addQuantParamsToDocString(valueProto, *PH->getType()); |
1295 | } |
1296 | } |
1297 | return valueProto; |
1298 | } |
1299 | |
1300 | Error ONNXModelWriter::writeAllWithNode(const std::string &opName, |
1301 | const Node *node, GraphType &graph, |
1302 | NodeType *proto) { |
1303 | proto->set_name(node->getName().str()); |
1304 | proto->set_op_type(opName); |
1305 | inputsToProto(node, proto); |
1306 | outputsToProto(node, graph, proto); |
1307 | return Error::success(); |
1308 | } |
1309 | |
1310 | Error ONNXModelWriter::writeAll(const std::string &opName, const Node *node, |
1311 | GraphType &graph) { |
1312 | return writeAllWithNode(opName, node, graph, graph.add_node()); |
1313 | } |
1314 | |
1315 | //===-----------------------------------------------------------------===// |
1316 | // Operators Supported by ONNX |
1317 | //===-----------------------------------------------------------------===// |
1318 | Error ONNXModelWriter::writePad(const PadNode *node, GraphType &graph) { |
1319 | auto *proto = graph.add_node(); |
1320 | // Add dictionary entries. |
  switch (node->getMode()) {
  case PaddingMode::CONSTANT:
    addValueAttribute(proto, "mode", std::string("constant"));
    break;
  case PaddingMode::REFLECT:
    addValueAttribute(proto, "mode", std::string("reflect"));
    break;
  case PaddingMode::EDGE:
    addValueAttribute(proto, "mode", std::string("edge"));
    break;
  default:
    return MAKE_ERR("Pad: Invalid mode",
                    ErrorValue::ErrorCode::MODEL_WRITER_SERIALIZATION_ERROR);
  }
1335 | |
  if (opsetVersion_ <= 10) {
    addValueAttribute(proto, "pads", node->getPads());
    float value = node->getValue();
    if (value != .0f) {
      addValueAttribute(proto, "value", value);
    }
    return writeAllWithNode("Pad", node, graph, proto);
  } else {
    proto->set_name(node->getName().str());
    proto->set_op_type("Pad");
1346 | // Input for data. |
1347 | inputsToProto(node, proto); |
1348 | |
1349 | // Input for pads. |
1350 | auto pads = node->getPads(); |
1351 | Tensor oneDimTensorPads(ElemKind::Int64ITy, {(dim_t)pads.size()}); |
1352 | auto oneDimTensorPadsH = oneDimTensorPads.getHandle<int64_t>(); |
1353 | for (size_t b = 0, e = oneDimTensorPads.size(); b < e; ++b) { |
1354 | oneDimTensorPadsH.raw(b) = pads[b]; |
1355 | } |
    auto *tensorProto = addInitializer(graph);
    tensorProto->set_name(node->getName().str() + "_pads");
    writeTensor(oneDimTensorPads, tensorProto);
    proto->add_input(node->getName().str() + "_pads");
1360 | |
1361 | // Input for value. |
1362 | Tensor value(ElemKind::FloatTy, {1}); |
1363 | auto valueH = value.getHandle(); |
1364 | valueH.raw(0) = node->getValue(); |
    tensorProto = addInitializer(graph);
    tensorProto->set_name(node->getName().str() + "_value");
    writeTensor(value, tensorProto);
    proto->add_input(node->getName().str() + "_value");
1369 | // Output |
1370 | outputsToProto(node, graph, proto); |
1371 | return Error::success(); |
1372 | } |
1373 | } |
1374 | |
1375 | Error ONNXModelWriter::writeConcat(const ConcatNode *node, GraphType &graph) { |
1376 | auto *proto = graph.add_node(); |
1377 | // Add dictionary entries. |
1378 | addValueAttribute(proto, "axis" , node->getDim()); |
1379 | |
1380 | return writeAllWithNode("Concat" , node, graph, proto); |
1381 | } |
1382 | |
1383 | Error ONNXModelWriter::writeTranspose(const TransposeNode *node, |
1384 | GraphType &graph) { |
1385 | // Some nodes create transpose for outputs. |
1386 | auto *input = node->getInput().getNode(); |
1387 | if (llvm::dyn_cast<ConvolutionNode>(input) || |
1388 | llvm::dyn_cast<Convolution3DNode>(input) || |
1389 | llvm::dyn_cast<AvgPoolNode>(input) || |
1390 | llvm::dyn_cast<MaxPoolNode>(input) || |
1391 | llvm::dyn_cast<SpaceToDepthNode>(input)) { |
1392 | return Error::success(); |
1393 | } |
1394 | |
1395 | auto *proto = graph.add_node(); |
1396 | // Add dictionary entries. |
1397 | addValueAttribute(proto, "perm" , node->getShuffle()); |
1398 | |
1399 | return writeAllWithNode("Transpose" , node, graph, proto); |
1400 | } |
1401 | |
1402 | Error ONNXModelWriter::writeCollectRpnProposals( |
1403 | const CollectRpnProposalsNode *node, GraphType &graph) { |
1404 | return writeAllWithNode("CollectRpnProposals" , node, graph, graph.add_node()); |
1405 | } |
1406 | |
1407 | Error ONNXModelWriter::writeFlip(const FlipNode *node, GraphType &graph) { |
1408 | auto *proto = graph.add_node(); |
1409 | // Add dictionary entries. |
1410 | addValueAttribute(proto, "axis" , node->getAxis()); |
1411 | |
1412 | return writeAllWithNode("Flip" , node, graph, proto); |
1413 | } |
1414 | |
1415 | Error ONNXModelWriter::writeAudioSpectrogram(const AudioSpectrogramNode *node, |
1416 | GraphType &graph) { |
1417 | auto *proto = graph.add_node(); |
1418 | |
1419 | addValueAttribute(proto, "window_size" , node->getWindowSize()); |
1420 | addValueAttribute(proto, "stride" , node->getWindowStride()); |
1421 | addValueAttribute(proto, "magnitude_squared" , node->getMagnitudeSquared()); |
1422 | |
1423 | return writeAllWithNode("AudioSpectrogram" , node, graph, proto); |
1424 | } |
1425 | |
1426 | Error ONNXModelWriter::writeMFCC(const MFCCNode *node, GraphType &graph) { |
1427 | auto *proto = graph.add_node(); |
1428 | |
1429 | addValueAttribute(proto, "sample_rate" , node->getSampleRate()); |
1430 | addValueAttribute(proto, "lower_frequency_limit" , node->getLowerFrequency()); |
1431 | addValueAttribute(proto, "upper_frequency_limit" , node->getUpperFrequency()); |
1432 | addValueAttribute(proto, "filterbank_channel_count" , |
1433 | node->getFilterBankCount()); |
1434 | addValueAttribute(proto, "dct_coefficient_count" , node->getNumCoefficients()); |
1435 | |
1436 | return writeAllWithNode("MFCC" , node, graph, proto); |
1437 | } |
1438 | |
1439 | Error ONNXModelWriter::writeROIAlign(const ROIAlignNode *node, |
1440 | GraphType &graph) { |
1441 | auto *proto = graph.add_node(); |
1442 | switch (node->getMode()) { |
1443 | case PoolingMode::AVG: |
1444 | addValueAttribute(proto, "mode" , std::string("avg" )); |
1445 | break; |
1446 | case PoolingMode::MAX: |
1447 | addValueAttribute(proto, "mode" , std::string("max" )); |
1448 | break; |
1449 | } |
1450 | addValueAttribute(proto, "output_height" , node->getOutputHeight()); |
1451 | addValueAttribute(proto, "output_width" , node->getOutputWidth()); |
1452 | addValueAttribute(proto, "sampling_ratio" , node->getSamplingRatio()); |
1453 | addValueAttribute(proto, "spatial_scale" , node->getSpatialScale()); |
1454 | addValueAttribute(proto, "aligned" , node->getAligned()); |
1455 | addValueAttribute(proto, "rotated" , node->getRotated()); |
1456 | return writeAllWithNode("ROIAlign" , node, graph, proto); |
1457 | } |
1458 | |
1459 | Error ONNXModelWriter::writeBBoxTransform(const BBoxTransformNode *node, |
1460 | GraphType &graph) { |
1461 | auto *proto = graph.add_node(); |
1462 | addValueAttribute(proto, "ApplyScale" , node->getApplyScale()); |
1463 | addValueAttribute(proto, "Rotated" , node->getRotated()); |
1464 | addValueAttribute(proto, "AngleBoundOn" , node->getAngleBoundOn()); |
1465 | addValueAttribute(proto, "AngleBoundLo" , node->getAngleBoundLo()); |
1466 | addValueAttribute(proto, "AngleBoundHi" , node->getAngleBoundHi()); |
1467 | addValueAttribute(proto, "ClipAngleThresh" , node->getClipAngleThresh()); |
1468 | return writeAllWithNode("BBoxTransform" , node, graph, proto); |
1469 | } |
1470 | |
1471 | Error ONNXModelWriter::writeConvolution(const ConvolutionNode *node, |
1472 | GraphType &graph) { |
  // Loading a convolution creates a sandwich of Transpose nodes around the
  // Input, Weights, and Result. The lowering algorithm may remove Transpose
  // nodes or replace one set of nodes with another. When saving a graph to
  // ONNX format, keep in mind that the Transpose sandwich will be recreated
  // when the graph is loaded again. The steps are:
  // Remove the Transpose nodes for Input and Weights; if no such Transpose is
  // found (it is supposed to be NCHW2NHWC), then create a "mirror" Transpose,
  // i.e. NHWC2NCHW, for the corresponding Input and/or Weights.
  // A similar algorithm applies to the Result: if a Transpose NHWC2NCHW node
  // is found for the Result user, remove it; otherwise create a "mirror"
  // Transpose, i.e. NCHW2NHWC.
  assert(node->getLayout() == NHWC && "can only write NHWC Convolutions");
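
  // A minimal sketch of the round trip (for a Conv loaded from NCHW ONNX):
  //   Transpose(NCHW2NHWC) -> Conv(NHWC) -> Transpose(NHWC2NCHW)
  // If the optimizer removed the input-side Transpose, the writer emits a
  // "mirror" NHWC2NCHW Transpose so the reloaded model sees NCHW again.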
1485 | |
1486 | // Delegate writing quantized Convs to writeTensorwiseQuantizedConvolution. |
1487 | if (isQuantizedElemKind(node->getInput().getElementType())) { |
1488 | return writeTensorwiseQuantizedConvolution(node, graph); |
1489 | } |
1490 | |
1491 | auto *proto = graph.add_node(); |
1492 | |
1493 | // Use the output of transpose node. |
1494 | if (!outputKindToProto(Kinded::Kind::TransposeNodeKind, node, graph, proto)) { |
1495 | // Apparently Result Transpose has been removed, add NCHW2NHWC Transpose. |
1496 | writeTransposeResult(node, proto, graph); |
1497 | } |
1498 | |
1499 | // Add dictionary entries. |
1500 | addValueAttribute(proto, "strides" , node->getStrides()); |
1501 | addValueAttribute(proto, "pads" , node->getPads()); |
1502 | addValueAttribute(proto, "group" , node->getGroup()); |
1503 | addValueAttribute(proto, "dilations" , node->getDilation()); |
1504 | |
1505 | const Node *input = node->getInput().getNode(); |
1506 | if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(input)) { |
1507 | proto->add_input(TN->getInput().getNode()->getName().str()); |
1508 | reportedNodes_.insert(TN); |
1509 | } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(input)) { |
1510 | proto->add_input(RSN->getInput().getNode()->getName().str()); |
1511 | reportedNodes_.insert(RSN); |
1512 | } else { |
1513 | writeTransposeInput(node, input, proto, graph); |
1514 | } |
1515 | |
1516 | const Node *filter = node->getFilter().getNode(); |
1517 | if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(filter)) { |
1518 | proto->add_input(TN->getInput().getNode()->getName().str()); |
1519 | reportedNodes_.insert(TN); |
1520 | } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(filter)) { |
1521 | proto->add_input(RSN->getInput().getNode()->getName().str()); |
1522 | reportedNodes_.insert(RSN); |
1523 | } else { |
1524 | writeTransposeInput(node, filter, proto, graph); |
1525 | } |
1526 | |
1527 | proto->add_input(node->getBias().getNode()->getName().str()); |
1528 | |
1529 | proto->set_name(node->getName().str()); |
1530 | proto->set_op_type("Conv" ); |
1531 | |
1532 | return Error::success(); |
1533 | } |
1534 | |
1535 | Error ONNXModelWriter::writeTensorwiseQuantizedConvolution( |
1536 | const ConvolutionNode *node, GraphType &graph) { |
1537 | auto *proto = graph.add_node(); |
1538 | |
1539 | // Add dictionary entries. |
1540 | addValueAttribute(proto, "kernel_shape" , node->getKernels()); |
1541 | addValueAttribute(proto, "strides" , node->getStrides()); |
1542 | addValueAttribute(proto, "pads" , node->getPads()); |
1543 | addValueAttribute(proto, "group" , node->getGroup()); |
1544 | addValueAttribute(proto, "dilation" , node->getDilation()); |
1545 | |
1546 | addValueAttribute(proto, "out_scale" , |
1547 | node->getType(ConvolutionNode::ResultIdx)->getScale()); |
1548 | addValueAttribute(proto, "out_offset" , |
1549 | node->getType(ConvolutionNode::ResultIdx)->getOffset()); |
1550 | |
1551 | return writeAllWithNode("Conv" , node, graph, proto); |
1552 | } |
1553 | |
1554 | Error ONNXModelWriter::writeChannelwiseQuantizedConvolution( |
1555 | const ChannelwiseQuantizedConvolutionNode *node, GraphType &graph) { |
1556 | auto *proto = graph.add_node(); |
1557 | |
1558 | // Add dictionary entries. |
1559 | addValueAttribute(proto, "kernel_shape" , node->getKernels()); |
1560 | addValueAttribute(proto, "strides" , node->getStrides()); |
1561 | addValueAttribute(proto, "pads" , node->getPads()); |
1562 | addValueAttribute(proto, "group" , node->getGroup()); |
1563 | |
1564 | addValueAttribute( |
1565 | proto, "out_scale" , |
1566 | node->getType(ChannelwiseQuantizedConvolutionNode::ResultIdx) |
1567 | ->getScale()); |
1568 | addValueAttribute( |
1569 | proto, "out_offset" , |
1570 | node->getType(ChannelwiseQuantizedConvolutionNode::ResultIdx) |
1571 | ->getOffset()); |
1572 | |
1573 | return writeAllWithNode("ChannelwiseQuantizedConvolution" , node, graph, |
1574 | proto); |
1575 | } |
1576 | |
1577 | Error ONNXModelWriter::writeBatchedReduceMean(const BatchedReduceMeanNode *node, |
1578 | GraphType &graph) { |
1579 | auto *proto = graph.add_node(); |
1580 | // Add dictionary entries. |
1581 | addValueAttribute(proto, "axes" , node->getAxes()); |
1582 | |
1583 | proto->set_name(node->getName().str()); |
1584 | proto->set_op_type("ReduceMean" ); |
1585 | inputsToProto(node, proto); |
1586 | |
1587 | addValueAttribute(proto, "keepdims" , 0); |
1588 | outputsToProto(node, graph, proto); |
1589 | |
1590 | return Error::success(); |
1591 | } |
1592 | |
1593 | Error ONNXModelWriter::writeBatchedReduceAdd(const BatchedReduceAddNode *node, |
1594 | GraphType &graph) { |
1595 | auto *proto = graph.add_node(); |
1596 | // Add dictionary entries. |
1597 | unsigned_t axis = node->getAxis(); |
1598 | llvm::ArrayRef<unsigned_t> axes(axis); |
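  // ONNX ReduceSum expects a list of axes, so wrap the single Glow axis.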
1599 | addValueAttribute(proto, "axes" , axes); |
1600 | |
1601 | proto->set_name(node->getName().str()); |
1602 | proto->set_op_type("ReduceSum" ); |
1603 | inputsToProto(node, proto); |
1604 | |
1605 | addValueAttribute(proto, "keepdims" , 0); |
1606 | outputsToProto(node, graph, proto); |
1607 | |
1608 | return Error::success(); |
1609 | } |
1610 | |
1611 | Error ONNXModelWriter::writeBatchedReduceSumSquare( |
1612 | const BatchedReduceSumSquareNode *node, GraphType &graph) { |
1613 | auto *proto = graph.add_node(); |
1614 | // Add dictionary entries. |
1615 | unsigned_t axis = node->getAxis(); |
1616 | llvm::ArrayRef<unsigned_t> axes(axis); |
1617 | addValueAttribute(proto, "axes" , axes); |
1618 | |
1619 | proto->set_name(node->getName().str()); |
1620 | proto->set_op_type("ReduceSum" ); |
1621 | inputsToProto(node, proto); |
1622 | |
1623 | addValueAttribute(proto, "keepdims" , 0); |
1624 | outputsToProto(node, graph, proto); |
1625 | |
1626 | return Error::success(); |
1627 | } |
1628 | |
1629 | Error ONNXModelWriter::writeBatchedReduceMax(const BatchedReduceMaxNode *node, |
1630 | GraphType &graph) { |
1631 | auto *proto = graph.add_node(); |
  // Add dictionary entries.
  addValueAttribute(proto, "axes", node->getAxes());

  return writeAllWithNode("ReduceMax", node, graph, proto);
1636 | } |
1637 | |
1638 | Error ONNXModelWriter::writeBatchedReduceMin(const BatchedReduceMinNode *node, |
1639 | GraphType &graph) { |
1640 | auto *proto = graph.add_node(); |
  // Add dictionary entries.
  addValueAttribute(proto, "axes", node->getAxes());

  return writeAllWithNode("ReduceMin", node, graph, proto);
1645 | } |
1646 | |
1647 | Error ONNXModelWriter::writeBatchedReduceProd(const BatchedReduceProdNode *node, |
1648 | GraphType &graph) { |
1649 | auto *proto = graph.add_node(); |
1650 | // Add dictionary entries. |
1651 | unsigned_t axis = node->getAxis(); |
1652 | llvm::ArrayRef<unsigned_t> axes(axis); |
1653 | addValueAttribute(proto, "axes" , axes); |
1654 | |
1655 | proto->set_name(node->getName().str()); |
1656 | proto->set_op_type("ReduceProd" ); |
1657 | inputsToProto(node, proto); |
1658 | |
1659 | addValueAttribute(proto, "keepdims" , 0); |
1660 | outputsToProto(node, graph, proto); |
1661 | |
1662 | return Error::success(); |
1663 | } |
1664 | |
1665 | Error ONNXModelWriter::writeBatchNormalization( |
1666 | const BatchNormalizationNode *node, GraphType &graph) { |
1667 | auto *proto = graph.add_node(); |
1668 | // Add dictionary entries. |
1669 | addValueAttribute(proto, "epsilon" , node->getEpsilon()); |
1670 | addValueAttribute(proto, "momentum" , node->getMomentum()); |
1671 | |
1672 | proto->set_name(node->getName().str()); |
1673 | proto->set_op_type("BatchNormalization" ); |
1674 | |
1675 | proto->add_input(node->getInput().getNode()->getName().str()); |
1676 | proto->add_input(node->getScale().getNode()->getName().str()); |
1677 | proto->add_input(node->getBias().getNode()->getName().str()); |
1678 | proto->add_input(node->getMean().getNode()->getName().str()); |
1679 | proto->add_input(node->getVar().getNode()->getName().str()); |
1680 | |
1681 | outputsToProto(node, graph, proto); |
1682 | return Error::success(); |
1683 | } |
1684 | |
1685 | Error ONNXModelWriter::writeInstanceNormalization( |
1686 | const InstanceNormalizationNode *node, GraphType &graph) { |
1687 | auto *proto = graph.add_node(); |
1688 | // Add dictionary entries. |
1689 | addValueAttribute(proto, "epsilon" , node->getEpsilon()); |
1690 | |
1691 | proto->set_name(node->getName().str()); |
1692 | proto->set_op_type("InstanceNormalization" ); |
1693 | |
1694 | proto->add_input(node->getInput().getNode()->getName().str()); |
1695 | proto->add_input(node->getScale().getNode()->getName().str()); |
1696 | proto->add_input(node->getBias().getNode()->getName().str()); |
1697 | |
1698 | outputsToProto(node, graph, proto); |
1699 | return Error::success(); |
1700 | } |
1701 | |
1702 | Error ONNXModelWriter::writeLayerNormalization( |
1703 | const LayerNormalizationNode *node, GraphType &graph) { |
1704 | auto *proto = graph.add_node(); |
1705 | // Add dictionary entries. |
1706 | addValueAttribute(proto, "epsilon" , node->getEpsilon()); |
1707 | |
1708 | proto->set_name(node->getName().str()); |
1709 | proto->set_op_type("LayerNormalization" ); |
1710 | |
1711 | proto->add_input(node->getInput().getNode()->getName().str()); |
1712 | proto->add_input(node->getScale().getNode()->getName().str()); |
1713 | proto->add_input(node->getBias().getNode()->getName().str()); |
1714 | |
1715 | outputsToProto(node, graph, proto); |
1716 | return Error::success(); |
1717 | } |
1718 | |
1719 | Error ONNXModelWriter::writeMeanVarNormalization( |
1720 | const MeanVarNormalizationNode *node, GraphType &graph) { |
1721 | auto *proto = graph.add_node(); |
1722 | // Add dictionary entries. |
1723 | addValueAttribute(proto, "channel" , node->getChannelIdx()); |
1724 | addValueAttribute(proto, "momentum" , node->getMomentum()); |
1725 | |
1726 | proto->set_name(node->getName().str()); |
1727 | proto->set_op_type("MeanVarianceNormalization" ); |
1728 | |
1729 | inputsToProto(node, proto); |
1730 | outputsToProto(node, graph, proto); |
1731 | return Error::success(); |
1732 | } |
1733 | |
1734 | Error ONNXModelWriter::writeSlice(const SliceNode *node, GraphType &graph) { |
1735 | auto *proto = graph.add_node(); |
1736 | // Add dictionary entries. |
1737 | auto starts = node->getStart(); |
1738 | auto outs = node->getResult().dims(); |
  RETURN_ERR_IF_NOT(starts.size() == outs.size(),
                    "Mismatch between starts and result dimensions.");

  RETURN_IF_ERR(writeAllWithNode("Slice", node, graph, proto));
1743 | |
1744 | if (opsetVersion_ >= 10) { |
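    // From opset 10 on, ONNX Slice takes starts/ends as tensor inputs rather
    // than attributes; ends are exclusive, so end[d] = start[d] + outDims[d].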
1745 | Tensor oneDimTensorStarts(ElemKind::Int64ITy, {(dim_t)starts.size()}); |
1746 | auto handleStarts = oneDimTensorStarts.getHandle<int64_t>(); |
1747 | Tensor oneDimTensorEnds(ElemKind::Int64ITy, {(dim_t)starts.size()}); |
1748 | auto handleEnds = oneDimTensorEnds.getHandle<int64_t>(); |
1749 | |
1750 | for (size_t b = 0, e = starts.size(); b < e; ++b) { |
1751 | handleStarts.raw(b) = starts[b]; |
1752 | handleEnds.raw(b) = outs[b] + starts[b]; |
1753 | } |
1754 | |
1755 | auto *tensorProto = addInitializer(graph); |
    tensorProto->set_name(node->getName().str() + "_starts");
    writeTensor(oneDimTensorStarts, tensorProto, useGlowCustomOps_);
    proto->add_input(node->getName().str() + "_starts");

    tensorProto = addInitializer(graph);
    tensorProto->set_name(node->getName().str() + "_ends");
    writeTensor(oneDimTensorEnds, tensorProto, useGlowCustomOps_);
    proto->add_input(node->getName().str() + "_ends");
1764 | } else { |
1765 | auto *attrStarts = proto->add_attribute(); |
1766 | attrStarts->set_name("starts" ); |
1767 | attrStarts->set_type(AttrType::INTS); |
1768 | auto *attrEnds = proto->add_attribute(); |
1769 | attrEnds->set_name("ends" ); |
1770 | attrEnds->set_type(AttrType::INTS); |
1771 | |
1772 | for (unsigned b = 0, e = starts.size(); b < e; ++b) { |
1773 | attrStarts->add_ints(starts[b]); |
1774 | attrEnds->add_ints(outs[b] + starts[b]); |
1775 | } |
1776 | } |
1777 | return Error::success(); |
1778 | } |
1779 | |
1780 | Error ONNXModelWriter::writePow(const PowNode *node, GraphType &graph) { |
1781 | auto *proto = graph.add_node(); |
1782 | proto->set_name(node->getName().str()); |
1783 | proto->add_input(node->getLHS().getNode()->getName().str()); |
1784 | outputsToProto(node, graph, proto); |
1785 | |
1786 | // Find exponent from splat node |
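  // A Splat exponent maps to a dedicated unary op, e.g.:
  //   pow(x, 0.5) -> Sqrt(x), pow(x, -1) -> Reciprocal(x), pow(x, 2) -> Sqr(x)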
1787 | const auto *RHSN = node->getRHS().getNode(); |
1788 | switch (RHSN->getKind()) { |
1789 | case Kinded::Kind::SplatNodeKind: { |
1790 | const auto *SN = llvm::cast<SplatNode>(RHSN); |
1791 | float value = SN->getValue(); |
    if (value == 0.5f) {
      proto->set_op_type("Sqrt");
    } else if (value == -1.0f) {
      proto->set_op_type("Reciprocal");
    } else if (value == 2.0f) {
      proto->set_op_type("Sqr");
    } else {
      return MAKE_ERR("Splat Node Value is invalid.");
    }
1800 | } |
1801 | break; |
1802 | } |
  default:
    // Generic exponent: emit a true Pow with the RHS as a second input.
    proto->set_op_type("Pow");
    proto->add_input(RHSN->getName().str());
    break;
1806 | } |
1807 | |
1808 | reportedNodes_.insert(RHSN); |
1809 | return Error::success(); |
1810 | } |
1811 | |
1812 | Error ONNXModelWriter::writeTopK(const TopKNode *node, GraphType &graph) { |
1813 | auto *proto = graph.add_node(); |
1814 | |
1815 | Tensor scalar(ElemKind::Int64ITy, {1}); |
1816 | auto handle = scalar.getHandle<int64_t>(); |
1817 | handle.raw(0) = node->getK(); |
1818 | |
1819 | auto *tensorProto = addInitializer(graph); |
1820 | tensorProto->set_name("k" ); |
1821 | writeTensor(scalar, tensorProto, useGlowCustomOps_); |
1822 | |
1823 | RETURN_IF_ERR(writeAllWithNode("TopK" , node, graph, proto)); |
1824 | |
1825 | proto->add_input("k" ); |
1826 | return Error::success(); |
1827 | } |
1828 | |
1829 | Error ONNXModelWriter::writeArgMax(const ArgMaxNode *node, GraphType &graph) { |
1830 | auto *proto = graph.add_node(); |
1831 | |
1832 | Tensor axis(ElemKind::Int64ITy, {1}); |
1833 | Tensor keepDims(ElemKind::BoolTy, {1}); |
1834 | auto axisH = axis.getHandle<int64_t>(); |
1835 | auto keepDimsH = keepDims.getHandle<int8_t>(); |
1836 | axisH.raw(0) = node->getAxis(); |
1837 | keepDimsH.raw(0) = node->getKeepDims(); |
1838 | |
1839 | auto *tensorProto = addInitializer(graph); |
1840 | tensorProto->set_name("axis" ); |
1841 | writeTensor(axis, tensorProto, useGlowCustomOps_); |
1842 | |
1843 | tensorProto = addInitializer(graph); |
1844 | tensorProto->set_name("keepDims" ); |
1845 | writeTensor(keepDims, tensorProto, useGlowCustomOps_); |
1846 | RETURN_IF_ERR(writeAllWithNode("ArgMax" , node, graph, proto)); |
1847 | |
1848 | return Error::success(); |
1849 | } |
1850 | |
1851 | Error ONNXModelWriter::writeArgMin(const ArgMinNode *node, GraphType &graph) { |
1852 | auto *proto = graph.add_node(); |
1853 | |
1854 | Tensor axis(ElemKind::Int64ITy, {1}); |
1855 | Tensor keepDims(ElemKind::BoolTy, {1}); |
1856 | auto axisH = axis.getHandle<int64_t>(); |
1857 | auto keepDimsH = keepDims.getHandle<int8_t>(); |
1858 | axisH.raw(0) = node->getAxis(); |
1859 | keepDimsH.raw(0) = node->getKeepDims(); |
1860 | |
1861 | auto *tensorProto = addInitializer(graph); |
1862 | tensorProto->set_name("axis" ); |
1863 | writeTensor(axis, tensorProto, useGlowCustomOps_); |
1864 | |
1865 | tensorProto = addInitializer(graph); |
1866 | tensorProto->set_name("keepDims" ); |
1867 | writeTensor(keepDims, tensorProto, useGlowCustomOps_); |
1868 | RETURN_IF_ERR(writeAllWithNode("ArgMin" , node, graph, proto)); |
1869 | |
1870 | return Error::success(); |
1871 | } |
1872 | |
1873 | Error ONNXModelWriter::writePRelu(const PReluNode *node, GraphType &graph) { |
1874 | auto *proto = graph.add_node(); |
1875 | proto->set_name(node->getName().str()); |
1876 | proto->set_op_type("PRelu" ); |
1877 | proto->add_input(node->getInput().getNode()->getName().str()); |
1878 | |
1879 | const auto *slope = node->getSlope().getNode(); |
1880 | if (const auto *BN = llvm::dyn_cast<BroadcastNode>(slope)) { |
1881 | proto->add_input(BN->getInput().getNode()->getName().str()); |
1882 | reportedNodes_.insert(BN); |
1883 | } else if (const SplatNode *SN = llvm::dyn_cast<SplatNode>(slope)) { |
    // Converting the scalar to a tensor is required.
1885 | Tensor scalar = {SN->getValue()}; |
1886 | auto *tensorProto = addInitializer(graph); |
1887 | tensorProto->set_name(SN->getName().str()); |
1888 | writeTensor(scalar, tensorProto, useGlowCustomOps_); |
1889 | proto->add_input(SN->getName().str()); |
1890 | reportedNodes_.insert(SN); |
1891 | } else { |
1892 | return MAKE_ERR("Can't find Splat/Broadcast Node as part of PRelu Node." ); |
1893 | } |
1894 | |
1895 | outputsToProto(node, graph, proto); |
1896 | return Error::success(); |
1897 | } |
1898 | |
1899 | Error ONNXModelWriter::writeGather(const GatherNode *node, GraphType &graph) { |
1900 | auto *proto = graph.add_node(); |
1901 | // Add dictionary entries. |
1902 | auto axis = node->getBatchDims(); |
1903 | |
1904 | if (axis != 0) { |
1905 | addValueAttribute(proto, "axis" , axis); |
1906 | return writeAllWithNode("BatchGather" , node, graph, proto); |
1907 | } else { |
1908 | return writeAllWithNode("Gather" , node, graph, proto); |
1909 | } |
1910 | } |
1911 | |
1912 | Error ONNXModelWriter::writeGatherElements(const GatherElementsNode *node, |
1913 | GraphType &graph) { |
  auto *proto = graph.add_node();
  return writeAllWithNode("GatherElements", node, graph, proto);
1917 | } |
1918 | |
1919 | Error ONNXModelWriter::writeGatherND(const GatherNDNode *node, |
1920 | GraphType &graph) { |
  auto *proto = graph.add_node();
  return writeAllWithNode("GatherND", node, graph, proto);
1924 | } |
1925 | |
1926 | Error ONNXModelWriter::writeMatMul(const MatMulNode *node, GraphType &graph) { |
1927 | return writeMatMulKind(node, graph, "MatMul" ); |
1928 | } |
1929 | |
1930 | Error ONNXModelWriter::writeBatchMatMul(const BatchMatMulNode *node, |
1931 | GraphType &graph) { |
1932 | auto dimSize = node->getLHS().dims().size(); |
1933 | if (dimSize == 2) { |
1934 | return writeMatMulKind(node, graph, "MatMul" ); |
1935 | } else { |
1936 | return writeMatMulKind(node, graph, "BatchMatMul" ); |
1937 | } |
1938 | } |
1939 | |
1940 | Error ONNXModelWriter::writeReshape(const ReshapeNode *node, GraphType &graph) { |
1941 | auto *proto = graph.add_node(); |
1942 | |
  // ONNX Reshape takes the target shape as a tensor input, so convert the
  // dims ArrayRef into a constant initializer.
1944 | auto dims = node->getDims(); |
1945 | Tensor dimsTensor(ElemKind::Int64ITy, {(dim_t)dims.size()}); |
1946 | auto handleDims = dimsTensor.getHandle<int64_t>(); |
1947 | for (size_t b = 0, e = dims.size(); b < e; ++b) { |
1948 | handleDims.raw(b) = dims[b]; |
1949 | } |
1950 | |
1951 | auto *tensorProto = addInitializer(graph); |
  tensorProto->set_name(node->getName().str() + "_shape");
  writeTensor(dimsTensor, tensorProto, useGlowCustomOps_);

  RETURN_IF_ERR(writeAllWithNode("Reshape", node, graph, proto));
  proto->add_input(node->getName().str() + "_shape");
1957 | return Error::success(); |
1958 | } |
1959 | |
1960 | Error ONNXModelWriter::writeBucketize(const BucketizeNode *node, |
1961 | GraphType &graph) { |
1962 | auto *proto = graph.add_node(); |
1963 | // Add dictionary entries. |
1964 | addValueAttribute(proto, "boundaries" , node->getBoundaries()); |
1965 | |
1966 | return writeAllWithNode("Bucketize" , node, graph, proto); |
1967 | } |
1968 | |
1969 | Error ONNXModelWriter::writeResizeNearest(const ResizeNearestNode *node, |
1970 | GraphType &graph) { |
1971 | auto *proto = graph.add_node(); |
  // ONNX Resize takes the scales as a tensor input, so convert the scale
  // ArrayRef into a constant initializer.
1973 | auto scale = node->getScale(); |
1974 | Tensor scaleTensor(ElemKind::FloatTy, {(dim_t)scale.size()}); |
1975 | auto handleScale = scaleTensor.getHandle<float>(); |
1976 | for (size_t b = 0, e = scale.size(); b < e; ++b) { |
1977 | handleScale.raw(b) = scale[b]; |
1978 | } |
1979 | |
1980 | auto *tensorProto = addInitializer(graph); |
  tensorProto->set_name(node->getName().str() + "_scale");
  writeTensor(scaleTensor, tensorProto, useGlowCustomOps_);

  // Add dictionary entries.
  addValueAttribute(proto, "coordinate_transformation_mode",
                    std::string("asymmetric"));
  addValueAttribute(proto, "mode", std::string("nearest"));
  addValueAttribute(proto, "nearest_mode", std::string("floor"));

  RETURN_IF_ERR(writeAllWithNode("Resize", node, graph, proto));
  proto->add_input(node->getName().str() + "_scale");
1992 | return Error::success(); |
1993 | } |
1994 | |
1995 | Error ONNXModelWriter::writeResizeBilinear(const ResizeBilinearNode *node, |
1996 | GraphType &graph) { |
1997 | auto *proto = graph.add_node(); |
  // ONNX Resize takes the scales as a tensor input, so convert the scale
  // ArrayRef into a constant initializer.
1999 | auto scale = node->getScale(); |
2000 | Tensor scaleTensor(ElemKind::FloatTy, {(dim_t)scale.size()}); |
2001 | auto handleScale = scaleTensor.getHandle<float>(); |
2002 | for (size_t b = 0, e = scale.size(); b < e; ++b) { |
2003 | handleScale.raw(b) = scale[b]; |
2004 | } |
2005 | |
2006 | auto *tensorProto = addInitializer(graph); |
  tensorProto->set_name(node->getName().str() + "_scale");
  writeTensor(scaleTensor, tensorProto, useGlowCustomOps_);

  // Add dictionary entries.
  addValueAttribute(proto, "coordinate_transformation_mode",
                    std::string("asymmetric"));
  addValueAttribute(proto, "mode", std::string("linear"));

  RETURN_IF_ERR(writeAllWithNode("Resize", node, graph, proto));
  proto->add_input(node->getName().str() + "_scale");
2017 | return Error::success(); |
2018 | } |
2019 | |
2020 | Error ONNXModelWriter::writeSoftMax(const SoftMaxNode *node, GraphType &graph) { |
2021 | auto *proto = graph.add_node(); |
2022 | proto->set_name(node->getName().str()); |
2023 | proto->set_op_type("Softmax" ); |
2024 | outputsToProto(node, graph, proto); |
2025 | // Find input from Reshape node |
2026 | proto->add_input(node->getInput().getNode()->getName().str()); |
2027 | |
2028 | // Mark selected input as visited. |
2029 | reportedNodes_.insert(node->getSelected().getNode()); |
2030 | return Error::success(); |
2031 | } |
2032 | |
2033 | Error ONNXModelWriter::writeLogSoftMax(const LogSoftMaxNode *node, |
2034 | GraphType &graph) { |
2035 | auto *proto = graph.add_node(); |
2036 | proto->set_name(node->getName().str()); |
2037 | proto->set_op_type("LogSoftmax" ); |
2038 | outputsToProto(node, graph, proto); |
2039 | // Find input from Reshape node |
2040 | proto->add_input(node->getInput().getNode()->getName().str()); |
2041 | |
2042 | // Mark selected input as visited. |
2043 | reportedNodes_.insert(node->getSelected().getNode()); |
2044 | return Error::success(); |
2045 | } |
2046 | |
2047 | Error ONNXModelWriter::writeReplaceNaN(const ReplaceNaNNode *node, |
2048 | GraphType &graph) { |
2049 | auto *proto = graph.add_node(); |
2050 | // Add dictionary entries. |
2051 | float value = node->getValue(); |
2052 | if (value != 0.0f) { |
2053 | addValueAttribute(proto, "value" , value); |
2054 | } |
2055 | return writeAllWithNode("ReplaceNaN" , node, graph, proto); |
2056 | } |
2057 | |
2058 | Error ONNXModelWriter::writeGatherRanges(const GatherRangesNode *node, |
2059 | GraphType &graph) { |
2060 | auto *proto = graph.add_node(); |
2061 | // Add dictionary entries. |
2062 | addValueAttribute(proto, "maxOutputSize" , node->getOutput().dims()[0]); |
2063 | |
2064 | return writeAllWithNode("GatherRanges" , node, graph, proto); |
2065 | } |
2066 | |
2067 | Error ONNXModelWriter::writeSparseToDenseMask(const SparseToDenseMaskNode *node, |
2068 | GraphType &graph) { |
2069 | auto *proto = graph.add_node(); |
2070 | // Add dictionary entries. |
2071 | addValueAttribute(proto, "mask" , node->getMask()); |
2072 | |
2073 | return writeAllWithNode("SparseToDenseMask" , node, graph, proto); |
2074 | } |
2075 | |
2076 | Error ONNXModelWriter::writeAdaptiveAvgPool(const AdaptiveAvgPoolNode *node, |
2077 | GraphType &graph) { |
2078 | auto *proto = graph.add_node(); |
2079 | |
2080 | // Add dictionary entries. |
2081 | const auto outShape = ShapeNHWC(node->getResult().dims()); |
2082 | std::vector<size_t> output_size{outShape.h, outShape.w}; |
2083 | addValueAttribute(proto, "output_size" , llvm::makeArrayRef(output_size)); |
2084 | |
2085 | auto err = writeAllWithNode("AdaptiveAvgPool" , node, graph, proto); |
2086 | return err; |
2087 | } |
2088 | |
2089 | Error ONNXModelWriter::writeLocalResponseNormalization( |
2090 | const LocalResponseNormalizationNode *node, GraphType &graph) { |
2091 | auto *proto = graph.add_node(); |
2092 | proto->set_name(node->getName().str()); |
2093 | proto->set_op_type("LRN" ); |
2094 | outputsToProto(node, graph, proto); |
2095 | // Find input from Transpose node |
2096 | const TransposeNode *TN = |
2097 | llvm::dyn_cast<TransposeNode>(node->getInput().getNode()); |
2098 | RETURN_ERR_IF_NOT( |
2099 | TN, |
2100 | "Can't find Transpose Node as part of LocalResponseNormalization Node." ); |
2101 | proto->add_input(TN->getInput().getNode()->getName().str()); |
2102 | reportedNodes_.insert(TN); |
2103 | // Add dictionary entries. |
2104 | addValueAttribute(proto, "size" , 2 * node->getHalfWindowSize()); |
2105 | addValueAttribute(proto, "alpha" , node->getAlpha()); |
2106 | addValueAttribute(proto, "beta" , node->getBeta()); |
2107 | addValueAttribute(proto, "bias" , node->getK()); |
2108 | |
2109 | return Error::success(); |
2110 | } |
2111 | |
2112 | Error ONNXModelWriter::writeBatchBoxCox(const BatchBoxCoxNode *node, |
2113 | GraphType &graph) { |
2114 | auto *proto = graph.add_node(); |
2115 | addValueAttribute(proto, "epsilon" , node->getEpsilon()); |
2116 | return writeAllWithNode("BatchBoxCox" , node, graph, proto); |
2117 | } |
2118 | |
2119 | //===-----------------------------------------------------------------===// |
2120 | // Operators Supported by Glow only |
2121 | //===-----------------------------------------------------------------===// |
2122 | Error ONNXModelWriter::writeModulo(const ModuloNode *node, GraphType &graph) { |
2123 | auto *proto = graph.add_node(); |
2124 | // Add dictionary entries. |
2125 | addValueAttribute(proto, "divisor" , node->getDivisor()); |
2126 | addValueAttribute(proto, "sign_follow_divisor" , |
2127 | node->getSignFollowDivisor() ? 1 : 0); |
2128 | |
2129 | return writeAllWithNode("Modulo" , node, graph, proto); |
2130 | } |
2131 | |
2132 | namespace { |
2133 | template <typename T> |
2134 | void writeTensorwiseQuantizedPool(const T *node, const std::string &op, |
2135 | ONNX_TRAITS::GraphProto &graph, |
2136 | ReportedNodes &) { |
  assert(node->getLayout() == NHWC && "can only write NHWC Pools");
2138 | |
2139 | auto *proto = graph.add_node(); |
2140 | |
2141 | // Add dictionary entries. |
2142 | addValueAttribute(proto, "kernel_shape" , node->getKernels()); |
2143 | addValueAttribute(proto, "strides" , node->getStrides()); |
2144 | addValueAttribute(proto, "pads" , node->getPads()); |
2145 | |
2146 | if (auto *APN = llvm::dyn_cast<AvgPoolNode>(node)) { |
2147 | addValueAttribute(proto, "count_include_pad" , APN->getCountIncludePads()); |
2148 | addValueAttribute(proto, "out_scale" , |
2149 | APN->getType(AvgPoolNode::ResultIdx)->getScale()); |
2150 | addValueAttribute(proto, "out_offset" , |
2151 | APN->getType(AvgPoolNode::ResultIdx)->getOffset()); |
2152 | } else if (auto *MPN = llvm::dyn_cast<MaxPoolNode>(node)) { |
2153 | addValueAttribute(proto, "out_scale" , |
2154 | MPN->getType(MaxPoolNode::ResultIdx)->getScale()); |
2155 | addValueAttribute(proto, "out_offset" , |
2156 | MPN->getType(MaxPoolNode::ResultIdx)->getOffset()); |
2157 | } |
2158 | |
2159 | proto->add_input(node->getInput().getNode()->getName().str()); |
2160 | outputsToProto(node, graph, proto); |
2161 | |
2162 | proto->set_name(node->getName().str()); |
2163 | proto->set_op_type(op); |
2164 | } |
2165 | |
2166 | template <typename T> |
2167 | void writePool(const T *node, const std::string &op, |
2168 | ONNX_TRAITS::GraphProto &graph, ReportedNodes &reporter) { |
2169 | // Delegate writing quantized pool ops to writeTensorwiseQuantizedPool. |
2170 | if (isQuantizedElemKind(node->getInput().getElementType())) { |
2171 | return writeTensorwiseQuantizedPool(node, op, graph, reporter); |
2172 | } |
2173 | |
  // Loading pools creates a sandwich of Transpose nodes around the Input and
  // Result. The lowering algorithm may remove Transpose nodes or replace one
  // set of nodes with another. When saving a graph to ONNX format, keep in
  // mind that the Transpose sandwich will be recreated when the graph is
  // loaded again. The steps are:
  // Remove the Transpose node for the Input; if no such Transpose is found
  // (it is supposed to be NCHW2NHWC), then create a "mirror" Transpose, i.e.
  // NHWC2NCHW, for the corresponding Input.
  // A similar algorithm applies to the Result: if a Transpose NHWC2NCHW node
  // is found for the Result user, remove it; otherwise create a "mirror"
  // Transpose, i.e. NCHW2NHWC.
  assert((node->getLayout() == NHWC || node->getLayout() == NTHWC) &&
         "can only write NHWC (2D) or NTHWC (3D) Pool ops");
2187 | |
2188 | auto *proto = graph.add_node(); |
2189 | |
2190 | // Use the output of transpose node. |
2191 | if (!outputKindToProto(Kinded::Kind::TransposeNodeKind, node, graph, proto)) { |
2192 | // Apparently Result Transpose has been removed, add NCHW2NHWC Transpose. |
2193 | writeTransposeResult(node, proto, graph); |
2194 | } |
2195 | |
2196 | // Add dictionary entries. |
2197 | addValueAttribute(proto, "kernel_shape" , node->getKernels()); |
2198 | addValueAttribute(proto, "strides" , node->getStrides()); |
2199 | addValueAttribute(proto, "pads" , node->getPads()); |
2200 | |
2201 | if (auto *APN = llvm::dyn_cast<AvgPoolNode>(node)) { |
2202 | addValueAttribute(proto, "count_include_pad" , APN->getCountIncludePads()); |
2203 | } |
2204 | |
2205 | const Node *input = node->getInput().getNode(); |
2206 | if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(input)) { |
2207 | proto->add_input(TN->getInput().getNode()->getName().str()); |
2208 | reporter.insert(TN); |
2209 | } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(input)) { |
2210 | proto->add_input(RSN->getInput().getNode()->getName().str()); |
2211 | reporter.insert(RSN); |
2212 | } else { |
2213 | writeTransposeInput(node, input, proto, graph); |
2214 | } |
2215 | |
2216 | proto->set_name(node->getName().str()); |
2217 | proto->set_op_type(op); |
2218 | } |
2219 | } // namespace |
2220 | |
2221 | Error ONNXModelWriter::writeAvgPool(const AvgPoolNode *node, GraphType &graph) { |
2222 | writePool(node, "AveragePool" , graph, reportedNodes_); |
2223 | return Error::success(); |
2224 | } |
2225 | |
2226 | Error ONNXModelWriter::writeMaxPool(const MaxPoolNode *node, GraphType &graph) { |
2227 | writePool(node, "MaxPool" , graph, reportedNodes_); |
2228 | return Error::success(); |
2229 | } |
2230 | |
2231 | Error ONNXModelWriter::writeConvolution3D(const Convolution3DNode *node, |
2232 | GraphType &graph) { |
  // Loading a convolution creates a sandwich of Transpose nodes around the
  // Input, Weights, and Result. The lowering algorithm may remove Transpose
  // nodes or replace one set of nodes with another. When saving a graph to
  // ONNX format, keep in mind that the Transpose sandwich will be recreated
  // when the graph is loaded again. The steps are:
  // Remove the Transpose nodes for Input and Weights; if no such Transpose is
  // found (it is supposed to be NCTHW2NTHWC), then create a "mirror"
  // Transpose, i.e. NTHWC2NCTHW, for the corresponding Input and/or Weights.
  // A similar algorithm applies to the Result: if a Transpose NTHWC2NCTHW
  // node is found for the Result user, remove it; otherwise create a "mirror"
  // Transpose, i.e. NCTHW2NTHWC.
2244 | // assert(node->getLayout() == NTHWC && "can only write NTHWC Convolutions"); |
2245 | |
2246 | // Delegate writing quantized Convs to writeTensorwiseQuantizedConvolution. |
2247 | if (isQuantizedElemKind(node->getInput().getElementType())) { |
2248 | return MAKE_ERR("Not implemented" ); |
2249 | // return writeTensorwiseQuantizedConvolution(node, graph); |
2250 | } |
2251 | |
2252 | auto *proto = graph.add_node(); |
2253 | |
2254 | // Use the output of transpose node. |
2255 | if (!outputKindToProto(Kinded::Kind::TransposeNodeKind, node, graph, proto)) { |
2256 | // Apparently Result Transpose has been removed, add NCTHW2NTHWC Transpose. |
2257 | writeTransposeResult(node, proto, graph, NCTHW2NTHWC); |
2258 | } |
2259 | |
2260 | // Add dictionary entries. |
2261 | addValueAttribute(proto, "kernel_shape" , node->getKernels()); |
2262 | addValueAttribute(proto, "strides" , node->getStrides()); |
2263 | addValueAttribute(proto, "pads" , node->getPads()); |
2264 | addValueAttribute(proto, "group" , node->getGroup()); |
2265 | // addValueAttribute(proto, "dilations", node->getDilation()); |
2266 | |
2267 | const Node *input = node->getInput().getNode(); |
2268 | if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(input)) { |
2269 | proto->add_input(TN->getInput().getNode()->getName().str()); |
2270 | reportedNodes_.insert(TN); |
2271 | } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(input)) { |
2272 | proto->add_input(RSN->getInput().getNode()->getName().str()); |
2273 | reportedNodes_.insert(RSN); |
2274 | } else { |
2275 | writeTransposeInput(node, input, proto, graph, NTHWC2NCTHW); |
2276 | } |
2277 | |
2278 | const Node *filter = node->getFilter().getNode(); |
2279 | if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(filter)) { |
2280 | proto->add_input(TN->getInput().getNode()->getName().str()); |
2281 | reportedNodes_.insert(TN); |
2282 | } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(filter)) { |
2283 | proto->add_input(RSN->getInput().getNode()->getName().str()); |
2284 | reportedNodes_.insert(RSN); |
2285 | } else { |
2286 | writeTransposeInput(node, filter, proto, graph, NTHWC2NCTHW); |
2287 | } |
2288 | |
2289 | proto->add_input(node->getBias().getNode()->getName().str()); |
2290 | |
2291 | proto->set_name(node->getName().str()); |
2292 | proto->set_op_type("Conv" ); |
2293 | |
2294 | return Error::success(); |
2295 | } |
2296 | |
2297 | Error ONNXModelWriter::writeSpaceToDepth(const SpaceToDepthNode *node, |
2298 | GraphType &graph) { |
2299 | auto *proto = graph.add_node(); |
2300 | |
2301 | // Find input from Transpose node |
2302 | const TransposeNode *TN = |
2303 | llvm::dyn_cast<TransposeNode>(node->getInput().getNode()); |
  RETURN_ERR_IF_NOT(TN,
                    "Can't find Transpose Node as part of SpaceToDepth Node.");
2306 | proto->add_input(TN->getInput().getNode()->getName().str()); |
2307 | reportedNodes_.insert(TN); |
2308 | |
2309 | proto->set_name(node->getName().str()); |
2310 | proto->set_op_type("SpaceToDepth" ); |
2311 | // Add dictionary entries. |
2312 | addValueAttribute(proto, "blocksize" , node->getBlockSize()); |
2313 | |
2314 | // Use the output of transpose node, if any. |
2315 | if (!outputKindToProto(Kinded::Kind::TransposeNodeKind, node, graph, proto)) { |
2316 | outputsToProto(node, graph, proto); |
2317 | } |
2318 | return Error::success(); |
2319 | } |
2320 | |
2321 | Error ONNXModelWriter::writeChannelShuffle(const ChannelShuffleNode *node, |
2322 | GraphType &graph) { |
2323 | auto *proto = graph.add_node(); |
2324 | // Add dictionary entries. |
2325 | addValueAttribute(proto, "group" , node->getGroup()); |
2326 | addValueAttribute(proto, "kernel" , node->getKernel()); |
2327 | |
2328 | return writeAllWithNode("ChannelShuffle" , node, graph, proto); |
2329 | } |
2330 | |
2331 | Error ONNXModelWriter::writeQuantizationProfile( |
2332 | const QuantizationProfileNode *node, GraphType &graph) { |
2333 | auto *proto = graph.add_node(); |
2334 | // Add dictionary entries. |
2335 | addValueAttribute(proto, "name" , node->getProfiledNodeName()); |
2336 | addValueAttribute(proto, "number" , node->getProfiledOutputNumber()); |
2337 | |
2338 | return writeAllWithNode("QuantizationProfile" , node, graph, proto); |
2339 | } |
2340 | |
2341 | Error ONNXModelWriter::writeTraceEvent(const TraceEventNode *node, |
2342 | GraphType &graph) { |
2343 | auto *proto = graph.add_node(); |
2344 | // Add dictionary entries. |
2345 | addValueAttribute(proto, "name" , node->getEventName()); |
2346 | addValueAttribute(proto, "type" , node->getEventType()); |
2347 | addValueAttribute(proto, "index" , node->getIndex()); |
2348 | |
2349 | return writeAllWithNode("TraceEvent" , node, graph, proto); |
2350 | } |
2351 | |
2352 | Error ONNXModelWriter::writeInsertTensor(const InsertTensorNode *node, |
2353 | GraphType &graph) { |
2354 | auto *proto = graph.add_node(); |
2355 | // Add dictionary entries. |
2356 | addValueAttribute(proto, "start" , node->getStart()); |
2357 | addValueAttribute(proto, "count" , node->getCount()); |
2358 | addValueAttribute(proto, "axis" , node->getAxis()); |
2359 | |
2360 | return writeAllWithNode("InsertTensor" , node, graph, proto); |
2361 | } |
2362 | |
2363 | Error ONNXModelWriter::writeSplat(const SplatNode *node, GraphType &graph) { |
  // Converting the scalar to a tensor is required: the splat is materialized
  // as a constant initializer named after the node's output.
2365 | Tensor tensor(ElemKind::FloatTy, node->getResult().dims()); |
2366 | auto handle = tensor.getHandle<>(); |
2367 | float value = node->getValue(); |
2368 | for (size_t b = 0, e = tensor.size(); b < e; ++b) { |
2369 | handle.raw(b) = value; |
2370 | } |
2371 | |
2372 | auto *tensorProto = addInitializer(graph); |
2373 | |
2374 | findOutputNames(node, graph, [&](const std::string &name) { |
2375 | tensorProto->set_name(name); |
2376 | }); |
2377 | |
2378 | writeTensor(tensor, tensorProto, useGlowCustomOps_); |
2379 | reportedNodes_.insert(node); |
2380 | |
2381 | return Error::success(); |
2382 | } |
2383 | |
2384 | Error ONNXModelWriter::writeTouch(const TouchNode *node, GraphType &graph) { |
2385 | auto *proto = graph.add_node(); |
2386 | return writeAllWithNode("Touch" , node, graph, proto); |
2387 | } |
2388 | |
2389 | // Exporting arithmetic node which may involve broadcasting. |
2390 | // Broadcast Node will be unwind. |
2391 | #define ARITHMETIC_NODE_WRITER(ONNXNAME, GLOWNAME) \ |
2392 | Error ONNXModelWriter::write##GLOWNAME(const GLOWNAME##Node *node, \ |
2393 | GraphType &graph) { \ |
2394 | return writeArithmetic(#ONNXNAME, node, graph, reportedNodes_, \ |
2395 | hasMultidirectionalBroadcast(#ONNXNAME)); \ |
2396 | } |
2397 | |
ARITHMETIC_NODE_WRITER(Add, Add)
ARITHMETIC_NODE_WRITER(Sub, Sub)
ARITHMETIC_NODE_WRITER(Mul, Mul)
ARITHMETIC_NODE_WRITER(Div, Div)
2402 | ARITHMETIC_NODE_WRITER(Equal, CmpEQ) |
2403 | ARITHMETIC_NODE_WRITER(And, And) |
2404 | ARITHMETIC_NODE_WRITER(Or, Or) |
2405 | ARITHMETIC_NODE_WRITER(Xor, Xor) |
2406 | ARITHMETIC_NODE_WRITER(Less, CmpLT) |
2407 | |
// Ops that ONNX doesn't have.
ARITHMETIC_NODE_WRITER(CmpLTE, CmpLTE)
ARITHMETIC_NODE_WRITER(FloorDiv, FloorDiv)
2411 | ARITHMETIC_NODE_WRITER(Fmod, Fmod) |
2412 | ARITHMETIC_NODE_WRITER(BitwiseAnd, BitwiseAnd) |
2413 | ARITHMETIC_NODE_WRITER(BitwiseOr, BitwiseOr) |
2414 | ARITHMETIC_NODE_WRITER(BitwiseXor, BitwiseXor) |
2415 | #undef ARITHMETIC_NODE_WRITER |
2416 | |
2417 | // Default exporting algorithm. |
2418 | #define DEF_ALL_WRITER_NODE(NAME) \ |
2419 | Error ONNXModelWriter::write##NAME(const NAME##Node *node, \ |
2420 | GraphType &graph) { \ |
2421 | return writeAll(#NAME, node, graph); \ |
2422 | } |
2423 | |
2424 | // ONNX nodes with default exporting algorithm. |
2425 | DEF_ALL_WRITER_NODE(Not) |
2426 | DEF_ALL_WRITER_NODE(Abs) |
2427 | DEF_ALL_WRITER_NODE(Neg) |
2428 | DEF_ALL_WRITER_NODE(Floor) |
2429 | DEF_ALL_WRITER_NODE(Sign) |
2430 | DEF_ALL_WRITER_NODE(Ceil) |
2431 | DEF_ALL_WRITER_NODE(Round) |
2432 | DEF_ALL_WRITER_NODE(Sqrt) |
2433 | DEF_ALL_WRITER_NODE(Rsqrt) |
2434 | DEF_ALL_WRITER_NODE(Reciprocal) |
2435 | DEF_ALL_WRITER_NODE(Sin) |
2436 | DEF_ALL_WRITER_NODE(Cos) |
2437 | DEF_ALL_WRITER_NODE(LSTMUnit) |
2438 | DEF_ALL_WRITER_NODE(DynamicQuantizedFullyConnected) |
2439 | DEF_ALL_WRITER_NODE(DynamicRowwiseQuantizedFullyConnected) |
2440 | DEF_ALL_WRITER_NODE(Erf) |
2441 | DEF_ALL_WRITER_NODE(Min) |
2442 | DEF_ALL_WRITER_NODE(Max) |
2443 | DEF_ALL_WRITER_NODE(Log) |
2444 | DEF_ALL_WRITER_NODE(Asin) |
2445 | DEF_ALL_WRITER_NODE(Acos) |
2446 | DEF_ALL_WRITER_NODE(Atan) |
2447 | DEF_ALL_WRITER_NODE(Exp) |
2448 | DEF_ALL_WRITER_NODE(Relu) |
2449 | DEF_ALL_WRITER_NODE(LeakyRelu) |
2450 | DEF_ALL_WRITER_NODE(Gelu) |
2451 | DEF_ALL_WRITER_NODE(Tanh) |
2452 | DEF_ALL_WRITER_NODE(IsNaN) |
2453 | DEF_ALL_WRITER_NODE(Sigmoid) |
2454 | DEF_ALL_WRITER_NODE(Swish) |
2455 | DEF_ALL_WRITER_NODE(SoftPlus) |
2456 | DEF_ALL_WRITER_NODE(LengthsSum) |
2457 | DEF_ALL_WRITER_NODE(BatchOneHot) |
2458 | DEF_ALL_WRITER_NODE(LengthsToRanges) |
2459 | DEF_ALL_WRITER_NODE(SparseLengthsSum) |
2460 | DEF_ALL_WRITER_NODE(SparseLengthsWeightedSum) |
2461 | DEF_ALL_WRITER_NODE(EmbeddingBag) |
2462 | DEF_ALL_WRITER_NODE(Embedding) |
2463 | DEF_ALL_WRITER_NODE(BitwiseNot) |
2464 | DEF_ALL_WRITER_NODE(GaussianFill) |
2465 | DEF_ALL_WRITER_NODE(NonZero) |
2466 | DEF_ALL_WRITER_NODE(BatchSparseToDense) |
2467 | DEF_ALL_WRITER_NODE(FillExamplesWithIndicator) |
2468 | |
2469 | // Glow nodes with default exporting algorithm. |
2470 | DEF_ALL_WRITER_NODE(CmpNEQ) |
2471 | DEF_ALL_WRITER_NODE(BatchedAdd) |
2472 | DEF_ALL_WRITER_NODE(BatchedMul) |
2473 | DEF_ALL_WRITER_NODE(Dequantize) |
2474 | DEF_ALL_WRITER_NODE(Regression) |
2475 | DEF_ALL_WRITER_NODE(RowwiseQuantizedSparseLengthsWeightedSum) |
2476 | DEF_ALL_WRITER_NODE(FusedRowwiseQuantizedSparseLengthsSum) |
2477 | DEF_ALL_WRITER_NODE(EmbeddingBagByteRowwiseOffsets) |
2478 | DEF_ALL_WRITER_NODE(FusedRowwiseQuantizedSparseLengthsWeightedSum) |
2479 | DEF_ALL_WRITER_NODE(NonMaxSuppression) |
2480 | DEF_ALL_WRITER_NODE(TFLiteDetectionPostProcess) |
2481 | DEF_ALL_WRITER_NODE(HardSwish) |
2482 | DEF_ALL_WRITER_NODE(ConvTranspose) |
2483 | DEF_ALL_WRITER_NODE(Logit) |
2484 | DEF_ALL_WRITER_NODE(Truncate) |
2485 | DEF_ALL_WRITER_NODE(BatchedUnaryEmbeddingsBags) |
2486 | DEF_ALL_WRITER_NODE(IntNBitSplitEmbeddingBags) |
2487 | DEF_ALL_WRITER_NODE(IntNBitSplitEmbeddingWeightedBags) |
2488 | |
2489 | Error ONNXModelWriter::writeClip(const ClipNode *node, GraphType &graph) { |
2490 | auto *proto = graph.add_node(); |
2491 | addValueAttribute(proto, "min" , node->getMin()); |
2492 | addValueAttribute(proto, "max" , node->getMax()); |
2493 | return writeAllWithNode("Clip" , node, graph, proto); |
2494 | } |
2495 | |
2496 | Error ONNXModelWriter::writeConvertTo(const ConvertToNode *node, |
2497 | GraphType &graph) { |
2498 | auto *proto = graph.add_node(); |
2499 | |
2500 | // Add dictionary entries. |
2501 | TensorType ttype; |
2502 | for (auto d : node->getResult().dims()) { |
2503 | ttype.add_dims(d); |
2504 | } |
2505 | ttype.set_data_type(convertType(*node->getResult().getType())); |
2506 | auto *attr = proto->add_attribute(); |
2507 | attr->set_name("shape" ); |
2508 | attr->mutable_t()->CopyFrom(ttype); |
2509 | |
2510 | return writeAllWithNode("ConvertTo" , node, graph, proto); |
2511 | } |
2512 | |
2513 | Error ONNXModelWriter::writeSelect(const SelectNode *node, GraphType &graph) { |
2514 | auto *proto = graph.add_node(); |
2515 | // Add dictionary entries. |
2516 | addValueAttribute(proto, "shape" , node->getResult().dims()); |
2517 | |
2518 | return writeAllWithNode("Select" , node, graph, proto); |
2519 | } |
2520 | |
2521 | Error ONNXModelWriter::writeQuantize(const QuantizeNode *node, |
2522 | GraphType &graph) { |
2523 | auto *proto = graph.add_node(); |
2524 | auto outTy = node->getResult().getType(); |
2525 | |
2526 | // Add dictionary entries. |
2527 | addValueAttribute(proto, "scale" , outTy->getScale()); |
2528 | addValueAttribute(proto, "offset" , outTy->getOffset()); |
2529 | addValueAttribute(proto, "elem_kind" , outTy->getElementName()); |
2530 | |
2531 | return writeAllWithNode("Quantize" , node, graph, proto); |
2532 | } |
2533 | |
2534 | Error ONNXModelWriter::writeIntLookupTable(const IntLookupTableNode *node, |
2535 | GraphType &graph) { |
2536 | auto *proto = graph.add_node(); |
2537 | // Add dictionary entries. |
2538 | addValueAttribute(proto, "shape" , node->getResult().dims()); |
2539 | NodeValue mapping = node->getMapping(); |
2540 | if (Constant *c = llvm::dyn_cast<Constant>(mapping.getNode())) { |
2541 | auto handle = c->getHandle<int8_t>(); |
2542 | auto begin = &handle.raw(0); |
2543 | addValueAttribute( |
2544 | proto, "values" , |
2545 | llvm::ArrayRef<int8_t>(begin, begin + handle.actualSize())); |
2546 | } else { |
2547 | return MAKE_ERR("Mapping must be a constant type." ); |
2548 | } |
2549 | |
2550 | return writeAllWithNode("IntLookupTable" , node, graph, proto); |
2551 | } |
2552 | |
2553 | Error ONNXModelWriter::writeLookupTable(const LookupTableNode *node, |
2554 | GraphType &graph) { |
2555 | auto *proto = graph.add_node(); |
2556 | // Add dictionary entries. |
2557 | addValueAttribute(proto, "shape" , node->getResult().dims()); |
2558 | NodeValue table = node->getTable(); |
2559 | if (Constant *c = llvm::dyn_cast<Constant>(table.getNode())) { |
2560 | auto handle = c->getHandle<int8_t>(); |
2561 | auto begin = &handle.raw(0); |
2562 | addValueAttribute( |
2563 | proto, "values" , |
2564 | llvm::ArrayRef<int8_t>(begin, begin + handle.actualSize())); |
2565 | } else { |
2566 | return MAKE_ERR("Mapping must be a constant type." ); |
2567 | } |
2568 | |
2569 | return writeAllWithNode("LookupTable" , node, graph, proto); |
2570 | } |
2571 | |
2572 | Error ONNXModelWriter::writeLengthsRangeFill(const LengthsRangeFillNode *node, |
2573 | GraphType &graph) { |
2574 | auto *proto = graph.add_node(); |
2575 | // Add dictionary entries. |
2576 | addValueAttribute(proto, "size" , node->getResult().dims()[0]); |
2577 | |
2578 | return writeAllWithNode("LengthsRangeFill" , node, graph, proto); |
2579 | } |
2580 | |
2581 | Error ONNXModelWriter::writeRescaleQuantized(const RescaleQuantizedNode *node, |
2582 | GraphType &graph) { |
2583 | auto *proto = graph.add_node(); |
2584 | auto outTy = node->getResult().getType(); |
2585 | // Add dictionary entries. |
2586 | addValueAttribute(proto, "scale" , outTy->getScale()); |
2587 | addValueAttribute(proto, "offset" , outTy->getOffset()); |
2588 | |
2589 | return writeAllWithNode("RescaleQuantized" , node, graph, proto); |
2590 | } |
2591 | |
2592 | Error ONNXModelWriter::writeGemm(const GemmNode *node, GraphType &graph) { |
2593 | auto *proto = graph.add_node(); |
2594 | proto->set_name(node->getName().str()); |
2595 | proto->set_op_type("Gemm" ); |
2596 | |
2597 | proto->add_input(node->getA().getNode()->getName().str()); |
2598 | proto->add_input(node->getB().getNode()->getName().str()); |
2599 | if (node->getC().getNode()) { |
2600 | proto->add_input(node->getC().getNode()->getName().str()); |
2601 | } |
2602 | |
2603 | addValueAttribute(proto, "alpha" , node->getAlpha()); |
2604 | addValueAttribute(proto, "beta" , node->getBeta()); |
2605 | addValueAttribute(proto, "transA" , node->getTransposeA()); |
2606 | addValueAttribute(proto, "transB" , node->getTransposeB()); |
2607 | |
2608 | outputsToProto(node, graph, proto); |
2609 | return Error::success(); |
2610 | } |
2611 | |
2612 | Error ONNXModelWriter::writeFullyConnected(const FullyConnectedNode *node, |
2613 | GraphType &graph) { |
2614 | auto *proto = graph.add_node(); |
2615 | proto->set_name(node->getName().str()); |
2616 | proto->set_op_type("FullyConnected" ); |
2617 | |
  if (node->getInput().dims().size() != 2) {
    return MAKE_ERR("Only 2-D inputs are supported.");
  }
2621 | proto->add_input(node->getInput().getNode()->getName().str()); |
2622 | proto->add_input(node->getWeights().getNode()->getName().str()); |
2623 | proto->add_input(node->getBias().getNode()->getName().str()); |
2624 | outputsToProto(node, graph, proto); |
2625 | return Error::success(); |
2626 | } |
2627 | |
2628 | Error ONNXModelWriter::writeVectorNorm(const VectorNormNode *node, |
2629 | GraphType &graph) { |
2630 | auto *proto = graph.add_node(); |
2631 | |
2632 | // Add dictionary entries. |
2633 | addValueAttribute(proto, "axis" , node->getAxis()); |
2634 | |
2635 | proto->set_name(node->getName().str()); |
2636 | proto->set_op_type("VectorNorm" ); |
2637 | inputsToProto(node, proto); |
2638 | |
2639 | // currently support p = 2 (Frobenius or i2) |
2640 | addValueAttribute(proto, "p" , node->getP()); |
2641 | |
2642 | outputsToProto(node, graph, proto); |
2643 | |
2644 | return Error::success(); |
2645 | } |
2646 | |
2647 | Error ONNXModelWriter::writeRowwiseQuantizedFullyConnected( |
2648 | const RowwiseQuantizedFullyConnectedNode *node, GraphType &graph) { |
2649 | auto *proto = graph.add_node(); |
2650 | |
2651 | // Add dictionary entries. |
  addValueAttribute(
      proto, "out_scale",
      node->getType(RowwiseQuantizedFullyConnectedNode::ResultIdx)->getScale());
  addValueAttribute(proto, "out_offset",
                    node->getType(RowwiseQuantizedFullyConnectedNode::ResultIdx)
                        ->getOffset());

  return writeAllWithNode("RowwiseQuantizedFullyConnected", node, graph, proto);
2660 | } |
2661 | |
2662 | Error ONNXModelWriter::writeTile(const TileNode *node, GraphType &graph) { |
2663 | auto *proto = graph.add_node(); |
2664 | |
2665 | // unwind Tile |
2666 | std::vector<size_t> repeats; |
2667 | const TileNode *tile = unwindTile(node, &repeats, reportedNodes_); |
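  // unwindTile is expected to walk a chain of Glow Tile nodes, collecting one
  // repeat count per dimension and marking the inner Tiles as reported.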
2668 | |
2669 | proto->set_name("Tile" ); |
2670 | proto->set_op_type("Tile" ); |
2671 | // Use inputs from top tile. |
2672 | inputsToProto(tile, proto); |
2673 | // Use outputs from bottom tile. |
2674 | outputsToProto(node, graph, proto); |
2675 | |
2676 | // Add node indices |
2677 | auto *indices = graph.add_node(); |
2678 | indices->set_name("Constant" ); |
2679 | indices->set_op_type("Constant" ); |
2680 | indices->add_output(tile->getName().str() + "_indices" ); |
2681 | |
2682 | unsigned_t numDims = tile->getInput().dims().size(); |
2683 | |
2684 | DCHECK(repeats.size() == numDims); |
2685 | |
  // Add the repeats as the Constant's value attribute.
  addValueAttribute(indices, "value", llvm::makeArrayRef(repeats));
  // Add indices as input to the Tile node.
  proto->add_input(tile->getName().str() + "_indices");
2690 | |
2691 | return Error::success(); |
2692 | } |
2693 | |
2694 | Error ONNXModelWriter::writeCumSum(const CumSumNode *node, GraphType &graph) { |
2695 | auto *proto = graph.add_node(); |
2696 | // Add dictionary entries. |
2697 | addValueAttribute(proto, "axis" , 0); |
2698 | addValueAttribute(proto, "exclusive" , node->getExclusive()); |
2699 | addValueAttribute(proto, "reverse" , node->getReverse()); |
2700 | |
2701 | return writeAllWithNode("CumSum" , node, graph, proto); |
2702 | } |
2703 | |
2704 | Error ONNXModelWriter::writeScatterData(const ScatterDataNode *node, |
2705 | GraphType &graph) { |
2706 | auto *proto = graph.add_node(); |
2707 | |
2708 | return writeAllWithNode("ScatterData" , node, graph, proto); |
2709 | } |
2710 | |
// Storage kinds that are not supported for export.
#define DEF_UNSUPPORTED_STORAGE(NAME)                                          \
  Error ONNXModelWriter::write##NAME(const NAME *node, GraphType &) {          \
    return writeUnexpectedKind(node);                                          \
  }

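// For example, DEF_UNSUPPORTED_STORAGE(Placeholder) expands to:
//   Error ONNXModelWriter::writePlaceholder(const Placeholder *node,
//                                           GraphType &) {
//     return writeUnexpectedKind(node);
//   }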
// Helper nodes.
DEF_UNSUPPORTED_STORAGE(Placeholder)
DEF_UNSUPPORTED_STORAGE(Constant)
DEF_UNSUPPORTED_STORAGE(Storage)

// Glow nodes that are not supported for export.
#define DEF_UNSUPPORTED_NODE(NAME)                                             \
  Error ONNXModelWriter::write##NAME(const NAME##Node *node, GraphType &) {    \
    return writeUnexpectedKind(node);                                          \
  }

DEF_UNSUPPORTED_NODE(BatchedPairwiseDotProduct)
DEF_UNSUPPORTED_NODE(Broadcast)
DEF_UNSUPPORTED_NODE(SGD)
DEF_UNSUPPORTED_NODE(SparseLabelSplit)
// Artificial node.
DEF_UNSUPPORTED_NODE(Save)
DEF_UNSUPPORTED_NODE(ExternalFunctionCall)
// Gradient nodes.
DEF_UNSUPPORTED_NODE(AddGrad)
DEF_UNSUPPORTED_NODE(DivGrad)
DEF_UNSUPPORTED_NODE(MulGrad)
DEF_UNSUPPORTED_NODE(SubGrad)
DEF_UNSUPPORTED_NODE(ReluGrad)
DEF_UNSUPPORTED_NODE(TanhGrad)
DEF_UNSUPPORTED_NODE(AvgPoolGrad)
DEF_UNSUPPORTED_NODE(MaxPoolGrad)
DEF_UNSUPPORTED_NODE(SigmoidGrad)
DEF_UNSUPPORTED_NODE(SoftMaxGrad)
DEF_UNSUPPORTED_NODE(LogSoftMaxGrad)
DEF_UNSUPPORTED_NODE(RegressionGrad)
DEF_UNSUPPORTED_NODE(ConvolutionGrad)
DEF_UNSUPPORTED_NODE(CrossEntropyLoss)
DEF_UNSUPPORTED_NODE(Convolution3DGrad)
DEF_UNSUPPORTED_NODE(FullyConnectedGrad)
DEF_UNSUPPORTED_NODE(CrossEntropyLossGrad)
DEF_UNSUPPORTED_NODE(BatchNormalizationGrad)
DEF_UNSUPPORTED_NODE(SparseLengthsSumGrad)
DEF_UNSUPPORTED_NODE(SparseLengthsWeightedSumGrad)
DEF_UNSUPPORTED_NODE(SigmoidCrossEntropyWithLogits)
DEF_UNSUPPORTED_NODE(LocalResponseNormalizationGrad)
DEF_UNSUPPORTED_NODE(AdaptiveAvgPoolGrad)
DEF_UNSUPPORTED_NODE(BatchedPairwiseDotProductGrad)

// Include backend-specific ONNX model writers.
#include "glow/ONNXModelWriterIncludes.h"

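// Fallback writer for any Glow node kind: the switch below is populated by
// the generated AutoGenNodesExport.h with one case per node kind, each of
// which fills in opProto using the Glow custom-op conventions. Partition
// names (in DAG mode) and backend-specific node options are then attached as
// extra attributes.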
Error ONNXModelWriter::writeGlowCustomOperator(const Node *node,
                                               GraphType &graph) {
  ONNX_NAMESPACE::NodeProto *opProto = nullptr;

  switch (node->getKind()) {
#include "glow/AutoGenNodesExport.h"
  default:
    return MAKE_ERR(
        strFormat("Unhandled Node for export: %s", node->getName().data()));
  }
  RETURN_ERR_IF_NOT(opProto, "Did not have valid opProto.");

  // If dumping a DAG then add partition names to each op that's written.
  if (dagMode_ && !isWritingConstFoldSubgraph()) {
    addValueAttribute(opProto, "partitionName", F_->getName().str());
  }

  // Check if there is backendSpecificNodeInfo for node, and if so include it.
  auto itF = backendSpecificNodeInfo_.find(node->getParent());
  if (itF != backendSpecificNodeInfo_.end()) {
    auto itN = itF->second.find(node);
    if (itN != itF->second.end()) {
      // We found backend-specific node info, so add it to the opProto.
      for (const auto &optValPair : itN->second) {
        addValueAttribute(opProto,
                          std::string(nodeOptSignifier) + "_" +
                              optValPair.getKey().data(),
                          optValPair.getValue());
      }
    }
  }

  return Error::success();
}

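// Returns whether the ONNX op named typeName supports multidirectional
// (numpy-style) broadcasting at the current opset. For example,
// hasMultidirectionalBroadcast("Add") is true when opsetVersion_ >= 7,
// while "Concat" is always false.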
bool ONNXModelWriter::hasMultidirectionalBroadcast(
    const llvm::StringRef typeName) {
  // Before opset 7, broadcasting was unidirectional.
  if (opsetVersion_ > 6) {
    // The list of ops that support multidirectional broadcast can be found at
    // https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
    if ((typeName == "Add") || (typeName == "Sub") || (typeName == "Mul") ||
        (typeName == "Div") || (typeName == "Equal") ||
        (typeName == "Greater") || (typeName == "Less") ||
        (typeName == "Max") || (typeName == "Mean") || (typeName == "Min") ||
        (typeName == "Or") || (typeName == "Pow") || (typeName == "Sum") ||
        (typeName == "Xor")) {
      return true;
    }
  }
  return false;
}

} // namespace glow