1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Exporter/ONNXModelWriter.h"
18#include "glow/Graph/Utils.h"
19#include "glow/Runtime/RuntimeTypes.h"
20#include "glow/Support/ZipUtils.h"
21
22#include <stack>
23
24#include "miniz.h"
25#include <google/protobuf/io/coded_stream.h>
26#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
27
28using namespace glow::runtime;
29using google::protobuf::RepeatedPtrField;
30
31namespace glow {
32#ifdef FACEBOOK_INTERNAL
33extern const char *revisionHash;
34#endif /* FACEBOOK_INTERNAL */
35
36#define NUM_FLOAT_DIGS 30
37
38namespace {
39template <bool IsInteger, bool IsEnum, typename T> struct AttributeAssigner {
40 static void assign(ONNX_NAMESPACE::AttributeProto *attr, const T &container);
41};
42
43// Specialization for llvm::ArrayRef<T> container types
44template <typename T>
45struct AttributeAssigner<false, false, llvm::ArrayRef<T>> {
46 static void assign(ONNX_NAMESPACE::AttributeProto *attr,
47 const llvm::ArrayRef<T> &container) {
48 attr->set_type(ONNX_NAMESPACE::AttributeProto::INTS);
49 for (auto value : container) {
50 attr->add_ints(value);
51 }
52 }
53};
54
55// Specialization for string type
56template <> struct AttributeAssigner<false, false, std::string> {
57 static void assign(ONNX_NAMESPACE::AttributeProto *attr,
58 const std::string &container) {
59 attr->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
60 attr->set_s(container);
61 }
62};
63
64// Specialization for StringRef type
65template <> struct AttributeAssigner<false, false, llvm::StringRef> {
66 static void assign(ONNX_NAMESPACE::AttributeProto *attr,
67 const llvm::StringRef container) {
68 attr->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
69 attr->set_s(container.str());
70 }
71};
72
73// Specialization for vector of strings.
74template <> struct AttributeAssigner<false, false, std::vector<std::string>> {
75 static void assign(ONNX_NAMESPACE::AttributeProto *attr,
76 const std::vector<std::string> &container) {
77 attr->set_type(ONNX_NAMESPACE::AttributeProto::STRINGS);
78 for (auto &str : container) {
79 attr->add_strings(str);
80 }
81 }
82};
83
84// Specialization for float type
85template <> struct AttributeAssigner<false, false, float> {
86 static void assign(ONNX_NAMESPACE::AttributeProto *attr,
87 const float &container) {
88 attr->set_type(ONNX_NAMESPACE::AttributeProto::FLOAT);
89 attr->set_f(container);
90 }
91};
92
93// Specialization for NodeValueArrayRef.
94template <> struct AttributeAssigner<false, false, NodeValueArrayRef> {
95 static void assign(ONNX_NAMESPACE::AttributeProto *attr,
96 const NodeValueArrayRef &container) {
97 attr->set_type(ONNX_NAMESPACE::AttributeProto::STRINGS);
98 for (size_t i = 0, e = container.size(); i < e; i++) {
99 attr->add_strings(container[i].generateNodeOutputName(
100 /* stripResNoFor0thInput */ true));
101 }
102 }
103};
104
105// Specialization for llvm::ArrayRef<float>.
106template <> struct AttributeAssigner<false, false, llvm::ArrayRef<float>> {
107 static void assign(ONNX_NAMESPACE::AttributeProto *attr,
108 const llvm::ArrayRef<float> &container) {
109 attr->set_type(ONNX_NAMESPACE::AttributeProto::FLOATS);
110 for (auto value : container) {
111 attr->add_floats(value);
112 }
113 }
114};
115
116// Specialization for int type
117template <typename T> struct AttributeAssigner<true, false, T> {
118 static void assign(ONNX_NAMESPACE::AttributeProto *attr, const T &container) {
119 attr->set_type(ONNX_NAMESPACE::AttributeProto::INT);
120 attr->set_i(container);
121 }
122};
123
124// Specialization for enums.
125template <typename T> struct AttributeAssigner<false, true, T> {
126 static void assign(ONNX_NAMESPACE::AttributeProto *attr, const T &container) {
127 attr->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
128 std::string storage;
129 llvm::raw_string_ostream stream(storage);
130 stream << container;
131 attr->set_s(stream.str());
132 }
133};
134
135template <typename T>
136void addValueAttribute(ONNX_NAMESPACE::NodeProto *proto,
137 const std::string &name, const T &container) {
138 auto *attr = proto->add_attribute();
139 attr->set_name(name);
140 AttributeAssigner<std::numeric_limits<T>::is_integer, std::is_enum<T>::value,
141 T>::assign(attr, container);
142}
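
// To illustrate the dispatch above with (simplified) calls that appear later
// in this file:
//
//   addValueAttribute(proto, "group", node->getGroup());    // unsigned_t
//   addValueAttribute(proto, "perm", node->getShuffle());   // ArrayRef
//
// The first call selects AttributeAssigner<true, false, unsigned_t> and
// writes an INT attribute; the second selects the llvm::ArrayRef
// specialization and writes an INTS attribute.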
143
144/// Adds the type attributes from \p ty to \p proto. \p ioNum, \p isInput, and
145/// \p addPrefix are used to format the name of the attribute.
146void addTypeAttributes(ONNX_NAMESPACE::NodeProto *proto, TypeRef ty,
147 unsigned ioNum, bool isInput,
148 const std::string &addPrefix = "") {
149 // Add ElemKind.
150 auto *elemKindAttr = proto->add_attribute();
151 elemKindAttr->set_name(
152 getTypeAttrID(ioNum, elemKindSignifier, isInput, addPrefix));
153 AttributeAssigner<false, false, llvm::StringRef>::assign(
154 elemKindAttr, ty->getElementName());
155
156 // Add Shape.
157 addValueAttribute(proto,
158 getTypeAttrID(ioNum, shapeSignifier, isInput, addPrefix),
159 ty->dims());
160
161 // Non-standard strides need to be serialized.
162 if (!ty->hasStandardStrides()) {
163 addValueAttribute(
164 proto, getTypeAttrID(ioNum, stridesSignifier, isInput, addPrefix),
165 ty->strides());
166 }
167
168 // Write out scale/offset if quantized ElemKind.
169 if (isQuantizedElemKind(ty->getElementType())) {
170 addValueAttribute(proto,
171 getTypeAttrID(ioNum, qScaleSignifier, isInput, addPrefix),
172 ty->getScale());
173 addValueAttribute(
174 proto, getTypeAttrID(ioNum, qOffsetSignifier, isInput, addPrefix),
175 ty->getOffset());
176 }
177}
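
// As an example of what this produces: for the 0th input of an op with type
// float<1 x 10>, two attributes are added -- a STRING attribute holding the
// element name and an INTS attribute holding the shape {1, 10} -- with their
// names built by getTypeAttrID from the io index, the elemKind/shape
// signifiers, and isInput. Strides and scale/offset attributes are only added
// for non-standard strides and quantized element kinds, respectively.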
178
179/// Adds the type attributes from \p NV to \p proto. \p ioNum, \p isInput, and
180/// \p addPrefix are used to format the name of the attribute.
181void addTypeAttributes(ONNX_NAMESPACE::NodeProto *proto, const NodeValue &NV,
182 unsigned ioNum, bool isInput,
183 const std::string &addPrefix = "") {
184 addTypeAttributes(proto, NV.getType(), ioNum, isInput, addPrefix);
185}
186
187/// Add the type attributes from the \p ioNum number input or output (depending
188/// on \p isInput) of \p N to \p proto. This includes the ElemKind, the Shape,
189/// and scale/offset if ElemKind is quantized. Note that 'i' or 'o' along with
190/// \p ioNum is prefixed onto the specific attribute being appended, as ops may
191/// have multiple inputs/outputs.
192void addTypeAttributes(ONNX_NAMESPACE::NodeProto *proto, const Node *N,
193 unsigned ioNum, bool isInput) {
194 NodeValue NV = isInput ? N->getNthInput(ioNum) : N->getNthResult(ioNum);
195 return addTypeAttributes(proto, NV, ioNum, isInput);
196}
197
/// Helper function to recursively unwind a chain of Tile nodes starting at
/// \p node. Optionally, if \p repeats is provided, fills it with the
/// per-dimension repeat counts. Returns the first Tile in the chain of Tiles.
201const TileNode *unwindTile(const TileNode *node, std::vector<size_t> *repeats,
202 ReportedNodes &reporter) {
  // Unwind the chain of Tiles, keeping track of detected <axis, count> pairs.
205 std::vector<std::pair<unsigned_t, unsigned_t>> info;
206 const TileNode *tile = node;
207 while (tile) {
    // Insert counts and axes in reverse order, because the unwind algorithm
    // navigates from the bottom of the chain to the top.
210 info.insert(info.begin(), {tile->getAxis(), tile->getCount()});
211
212 if (const auto *TN = llvm::dyn_cast<TileNode>(tile->getInput().getNode())) {
213 reporter.insert(TN);
214 tile = TN;
215 } else {
216 break;
217 }
218 }
219
220 if (repeats) {
221 unsigned_t numDims = tile->getInput().dims().size();
    // In the normal case the axes cover [0, 1, ..., numDims - 1]; in the
    // extreme case the list is empty. Find missing axes and insert count 1.
224 auto aB = info.begin();
225
226 for (unsigned_t i = 0; i < numDims; ++i, ++aB) {
227 if (aB == info.end() || aB->first != i) {
228 aB = info.insert(aB, {i, 1});
229 }
230 }
231
232 for (size_t b = 0, e = info.size(); b < e; ++b) {
233 repeats->push_back(info[b].second);
234 }
235 }
236 return tile;
237}
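
// For illustration: given a 3-D input feeding Tile(axis = 0, count = 2) whose
// result feeds Tile(axis = 2, count = 4), calling unwindTile on the outer
// (axis = 2) Tile returns the inner (axis = 0) Tile and, if requested, fills
// repeats with {2, 1, 4} -- count 1 is inserted for the missing axis 1.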
238
/// Finds the output names for Node \p node and invokes \p callback on each of
/// them, adding Identity nodes to \p graph where needed to relay saved
/// outputs to their SaveNode placeholders.
240void findOutputNames(const Node *node, ONNX_TRAITS::GraphProto &graph,
241 std::function<void(const std::string &name)> &&callback) {
242 // Check if user is SaveNode
243 std::set<unsigned> saveResNo;
244 std::vector<std::pair<const SaveNode *, unsigned>> saveOutputs;
245 std::vector<int> resultUsers(node->getNumResults(), 0);
246 for (const auto &use : node->getUsers()) {
247 const auto *user = use.getUser();
248 unsigned resNo = 0;
249 for (unsigned b = 0, e = user->getNumInputs(); b < e; ++b) {
250 auto UNV = user->getNthInput(b);
251 if (node == UNV.getNode()) {
252 resNo = UNV.getResNo();
253 resultUsers[resNo]++;
254 break;
255 }
256 }
257
258 if (user->getKind() == Kinded::Kind::SaveNodeKind) {
259 // Use the associated placeholder's name.
260 const SaveNode *SN = llvm::cast<SaveNode>(user);
261 saveOutputs.emplace_back(SN, resNo);
262 }
263 }
264
  // If the SaveNode is the only user of a result, we can use the save name
  // directly as the output name. Otherwise, we have to insert an Identity
  // node to relay this output to the save output.
268 for (const auto &p : saveOutputs) {
269 if (resultUsers[p.second] == 1) {
270 callback(p.first->getPlaceholder()->getName().str());
271 saveResNo.insert(p.second);
272 } else {
273 auto *proto = graph.add_node();
274 proto->set_name(node->getName().str() + "_copy_" +
275 std::to_string(p.second));
276 proto->set_op_type("Identity");
277 proto->add_input(p.second == 0 ? node->getName().str()
278 : (node->getName().str() + "_out_" +
279 std::to_string(p.second)));
280 proto->add_output(p.first->getPlaceholder()->getName().str());
281 }
282 }
283
284 // write the other outputs, if any
285 for (unsigned b = 0, e = node->getNumResults(); b < e; ++b) {
286 if (saveResNo.count(b)) {
287 continue;
288 }
289 if (b == 0) {
290 callback(node->getName().str());
291 } else {
292 callback(node->getName().str() + "_out_" + std::to_string(b));
293 }
294 }
295}
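
// To summarize the naming scheme above: result 0 of a node "foo" is reported
// as "foo" and result 2 as "foo_out_2"; a result whose only user is a
// SaveNode is reported under that SaveNode's placeholder name instead; and a
// result that is saved but also has other users keeps its "foo"/"foo_out_N"
// name and additionally gets an Identity node (named "foo_copy_N") relaying
// it to the placeholder.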
296
297/// Writes all outputs from Node \p node to protobuf \p proto.
298void outputsToProto(const Node *node, ONNX_TRAITS::GraphProto &graph,
299 ONNX_NAMESPACE::NodeProto *proto) {
300 findOutputNames(node, graph,
301 [&](const std::string &name) { proto->add_output(name); });
302}
303
304/// Writes all inputs from Node \p node to protobuf \p proto.
305void inputsToProto(const Node *node, ONNX_NAMESPACE::NodeProto *proto) {
306 for (unsigned b = 0, e = node->getNumInputs(); b < e; ++b) {
307 const auto NV = node->getNthInput(b);
308 auto resNo = NV.getResNo();
309 auto name = NV.getNode()->getName().str();
310 if (resNo) {
      proto->add_input(name + "_out_" + std::to_string(resNo));
312 } else {
313 proto->add_input(name);
314 }
315 }
316}
317
/// Writes the outputs of \p node, but only through users of the provided
/// \p kind (or through a SaveNode). \returns whether such a user was found.
319bool outputKindToProto(Kinded::Kind kind, const Node *node,
320 ONNX_TRAITS::GraphProto &graph,
321 ONNX_NAMESPACE::NodeProto *proto) {
322 bool found = false;
323 for (const auto &use : node->getUsers()) {
324 const auto *user = use.getUser();
325 if (user->getKind() == Kinded::Kind::SaveNodeKind) {
326 found = true;
327 const SaveNode *SN = llvm::cast<SaveNode>(user);
328 proto->add_output(SN->getPlaceholder()->getName().str());
329 break;
330 } else if (user->getKind() == kind) {
331 found = true;
332 outputsToProto(user, graph, proto);
333 }
334 }
335 return found;
336}
337
/// Writes a MatMul operator from Node \p node into the provided graph
/// protobuf \p graph. Depending on \p nodeKind, this writes either a MatMul
/// or a BatchMatMul. \returns error.
342template <typename T>
343Error writeMatMulKind(const T *node, ONNX_TRAITS::GraphProto &graph,
344 const std::string &nodeKind) {
345 auto *proto = graph.add_node();
346 proto->set_name(node->getName().str());
347 proto->set_op_type(nodeKind);
348
349 Node *LHS = node->getLHS().getNode();
350 proto->add_input(LHS->getName().str());
351 Node *RHS = node->getRHS().getNode();
352 proto->add_input(RHS->getName().str());
353
354 outputsToProto(node, graph, proto);
355 return Error::success();
356}
357
// Creates a Transpose node on the Result of \p node.
// Reuses the given \p proto pointer for the Transpose and creates a new proto
// for the node itself, adding it to \p graph. The permutation argument
// enables use of a different permutation, e.g. NCTHW2NTHWC for transposing a
// 3D convolution.
362template <typename T>
363void writeTransposeResult(const T *node, ONNX_NAMESPACE::NodeProto *&proto,
364 ONNX_TRAITS::GraphProto &graph,
365 llvm::ArrayRef<unsigned_t> permutation = NCHW2NHWC) {
366 // Add dictionary entries.
367 llvm::ArrayRef<unsigned_t> container(permutation);
368 addValueAttribute(proto, "perm", container);
369 // Re-use proto for Transpose node.
370 auto newName = node->getName().str() + "_out_transpose";
371 proto->set_name(newName);
372 proto->set_op_type("Transpose");
373 proto->add_input(newName);
374
375 proto->add_output(node->getName().str());
376
377 // T node proto.
378 proto = graph.add_node();
379 proto->add_output(newName);
380}
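
// To sketch the effect of writeTransposeResult for a node named "conv": the
// proto passed in becomes a Transpose named "conv_out_transpose" whose input
// is "conv_out_transpose" and whose output is "conv" (the name other
// consumers expect), while the freshly added proto -- returned through the
// reference -- outputs "conv_out_transpose" and is then filled in by the
// caller as the op itself.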
381
// Creates a Transpose node feeding the Input \p input of \p node.
// Adds a new proto for the Transpose to \p graph and wires its output in as
// an input of the given \p proto. The permutation argument enables use of a
// different permutation, e.g. for transposing a 3D convolution.
386void writeTransposeInput(const Node *node, const Node *input,
387 ONNX_NAMESPACE::NodeProto *proto,
388 ONNX_TRAITS::GraphProto &graph,
389 llvm::ArrayRef<unsigned_t> permutation = NHWC2NCHW) {
  // Write a "mirror" Transpose for the input, i.e. NHWC2NCHW.
391 auto newName =
392 node->getName().str() + "_" + input->getName().str() + "_in_transpose";
393 auto *transformProto = graph.add_node();
394 transformProto->set_name(newName);
395 transformProto->set_op_type("Transpose");
396
397 // Add dictionary entries.
398 llvm::ArrayRef<unsigned_t> container(permutation);
399 addValueAttribute(transformProto, "perm", container);
400 transformProto->add_input(input->getName().str());
401 transformProto->add_output(newName);
402 proto->add_input(newName);
403}
404
/// Writes an Arithmetic operator with name \p opName from Node \p node into
/// the provided graph protobuf \p graph. The node's operands may have been
/// broadcast; \p hasMultidirectionalBroadcast indicates whether the op
/// supports multidirectional broadcasting, in which case the axis and
/// broadcast flags are not written to the protobuf. Intermediate Broadcast
/// nodes are reported via \p reporter as visited, signaling that such nodes
/// must be ignored. \returns error.
411template <typename T>
412Error writeArithmetic(const std::string &opName, const T *node,
413 ONNX_TRAITS::GraphProto &graph, ReportedNodes &reporter,
414 bool hasMultidirectionalBroadcast) {
415 auto *proto = graph.add_node();
416 proto->set_name(node->getName().str());
417 proto->set_op_type(opName);
418 outputsToProto(node, graph, proto);
419
420 auto LHS = node->getLHS();
421 if (const BroadcastNode *BN = llvm::dyn_cast<BroadcastNode>(LHS.getNode())) {
422 reporter.insert(BN);
423 LHS = BN->getInput();
424 }
425
426 int axis = -1;
427 auto RHS = node->getRHS();
428 if (const BroadcastNode *BN = llvm::dyn_cast<BroadcastNode>(RHS.getNode())) {
429 reporter.insert(BN);
430 RHS = BN->getInput();
431 axis = BN->getAxis();
432 }
433
434 proto->add_input(LHS.getNode()->getName().str());
435 proto->add_input(RHS.getNode()->getName().str());
436
437 // Check if the shapes of LHS and RHS are different and broadcast attribute is
438 // required.
439 if (LHS.dims() != RHS.dims() && !hasMultidirectionalBroadcast) {
440 addValueAttribute(proto, "axis", axis);
441 addValueAttribute(proto, "broadcast", 1UL);
442 }
443
444 return Error::success();
445}
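
// For example, for an Add whose RHS comes from a Broadcast node with axis 1,
// the Broadcast is reported as visited and skipped, its input becomes the RHS
// of the emitted op, and -- when the op does not support multidirectional
// broadcasting and the operand shapes differ -- "axis" = 1 and "broadcast" = 1
// attributes are added.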
446
447void tensorShapeFromInput(const std::string &name, TypeRef ty,
448 ONNX_NAMESPACE::ValueInfoProto *valueProto) {
449 valueProto->set_name(name);
450 auto *type = valueProto->mutable_type();
451 auto *tensorType = type->mutable_tensor_type();
452 tensorType->set_elem_type(ONNXModelWriter::convertType(*ty));
453 auto *tensorShape = tensorType->mutable_shape();
454 const auto &dims = ty->dims();
455 for (unsigned b = 0, e = dims.size(); b < e; ++b) {
456 auto *tensorDims = tensorShape->add_dim();
457 tensorDims->set_dim_value(dims[b]);
458 }
459}
460
/// Creates the list of Nodes in the reverse of the order required for ONNX.
462class ReverseGraphWalker {
463 /// A post-order list of nodes.
464 std::vector<const Node *> reverseOrder_;
465 /// A set of visited nodes.
466 std::unordered_set<const Node *> visited_;
467
468 void visit(Function &F) {
    // Visit constants first. Even though they should be placed at the end of
    // the list, we can visit them first because they will be written into
    // ONNX initializers, not nodes.
472 for (const auto *C : F.findConstants()) {
473 reverseOrder_.push_back(C);
474 visited_.insert(C);
475 }
476 // Start visiting all root nodes, i.e. nodes that do not have any users.
477 for (auto &N : F.getNodes()) {
478 if (N.getNumUsers() == 0) {
479 visitIteratively(&N);
480 }
481 }
482 }
483
484 void visitIteratively(const Node *rootNode) {
485 std::stack<const Node *> st;
486 st.push(rootNode);
487 while (st.size()) {
488 auto *N = st.top();
489 st.pop();
490
      // Check if the node has been visited already.
492 if (visited_.count(N)) {
493 continue;
494 }
495
496 if (N->getKind() == Kinded::Kind::PlaceholderKind) {
497 // For Placeholder don't visit uses.
498 visited_.insert(N);
499 reverseOrder_.push_back(N);
500 continue;
501 }
502
503 // First visit all users, if any.
504 bool continueEarly = false;
505 for (const auto &use : N->getUsers()) {
506 const auto *UN = use.getUser();
        // Skip users that have already been visited.
508 if (visited_.count(UN)) {
509 continue;
510 }
511 st.push(UN);
512 continueEarly = true;
513 }
514 if (continueEarly) {
515 continue;
516 }
517
      // Visit the current node if it has not been visited yet.
519 if (!visited_.count(N)) {
520 visited_.insert(N);
521 reverseOrder_.push_back(N);
522 }
523
524 // And then visit inputs of the current node.
525 for (unsigned b = 0, e = N->getNumInputs(); b < e; ++b) {
526 auto *UN = N->getNthInput(b).getNode();
527 if (visited_.count(UN)) {
528 continue;
529 }
530 st.push(UN);
531 }
532
533 // Additionally visit the predicate input if it exists.
534 if (N->hasPredicate()) {
535 auto *UN = N->getPredicate().getNode();
536 if (!visited_.count(UN)) {
537 st.push(UN);
538 }
539 }
540 }
541 }
542
543public:
544 explicit ReverseGraphWalker(Function &F) { visit(F); }
545
546 llvm::ArrayRef<const Node *> getNodes() const { return reverseOrder_; }
547};
548
549template <typename T>
550static void addAttrToDocString(T *proto, const std::string &attrName,
551 llvm::StringRef attrVal) {
552 *(proto->mutable_doc_string()) += std::string(1, startChar) + attrName +
553 std::string(1, sepChar) + attrVal.str();
554}
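
// Each call appends one record of the form
//   <startChar><attrName><sepChar><attrVal>
// to the proto's doc_string, so repeated calls accumulate a list of
// attributes (element kind, layout, strides, etc.) in a single string that
// can later be split back apart on those separator characters.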
555
556} // namespace
557
558Error ONNXModelWriter::insertLoaderNameUniqueOffsetMetadata(
559 llvm::StringMap<std::string> &extraMetadataProps,
560 const OriginNameToTQPMap &map) {
561 RETURN_ERR_IF_NOT(!extraMetadataProps.count("OriginNameToTQPMap"),
562 "Already had OriginNameToTQPMap");
563 std::string str;
564 for (const auto &nameTQP : map) {
565 str += nameTQP.first + offsetSepSig +
566 std::to_string(nameTQP.second.offset) + offsetEndSig;
567 }
568 extraMetadataProps.try_emplace(originNameToUniqueOffsetMappingSignifier, str);
569 return Error::success();
570}
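
// The serialized value is one concatenated string with a
// "<originName><offsetSepSig><uniqueOffset><offsetEndSig>" record per map
// entry, stored under the originNameToUniqueOffsetMappingSignifier key in
// \p extraMetadataProps.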
571
572bool ONNXModelWriter::isIntermediatePHForDAG(const Placeholder *PH) {
573 if (!dagMode_) {
574 return false;
575 }
576
577 bool isInputPH = false, isOutputPH = false;
578 for (const auto &use : PH->getUsers()) {
579 const auto *userN = use.getUser();
580 // Only consider users from Functions in the DAG.
581 const Function *userF = userN->getParent();
582 if (!functionsFromDAG_.count(userF)) {
583 continue;
584 }
585 const bool currIsInputPH = isInput(PH, *userF);
586 const bool currIsOutputPH = isOutput(PH, *userF);
587 // Check this is not quantization profiling or training cases.
588 assert(
589 !(currIsInputPH && currIsOutputPH) &&
590 "Do not support PHs that are input and output to a single Function.");
591 if (currIsInputPH) {
592 isInputPH = true;
593 }
594 if (currIsOutputPH) {
595 isOutputPH = true;
596 }
597 }
598
599 // If the PH is both an input and an output for the Functions in the DAG then
600 // it must be an intermediate.
601 return isInputPH && isOutputPH;
602}
603
604/// Reverses the order of the nodes in \p nodes.
605static void
606reverseNodesOrder(RepeatedPtrField<ONNX_NAMESPACE::NodeProto> &nodes) {
607 for (size_t i = 0, n = nodes.size(); i < n / 2; ++i) {
608 nodes.SwapElements(i, n - i - 1);
609 }
610}
611
612bool ONNXModelWriter::isWritingConstFoldSubgraph() {
613 return graphProtoRoot_ != graphProto_;
614}
615
616Error ONNXModelWriter::writeConstantFoldingSubgraph(const Constant *C,
617 SaveNode *SN) {
618 Function *constFoldFunction = SN->getParent();
619
620 // If we already wrote out this Function we can return early.
621 if (!processedConstFoldFunctions_.insert(constFoldFunction).second) {
622 return Error::success();
623 }
624
625 // Create new constant folding Node, which we add the subgraph to. Always add
626 // to the root as these are all loaded before loading any ops.
627 auto *constFoldNodeProto = graphProtoRoot_->add_node();
628 constFoldNodeProto->set_op_type(constFoldSubgraphNodeName);
629 const char *constFoldNodeName = constFoldFunction->getName().data();
630 constFoldNodeProto->set_name(constFoldNodeName);
631
632 // Now add the constant folding subgraph to the node.
633 auto *constFoldNodeSubgraph = constFoldNodeProto->add_attribute();
634 constFoldNodeSubgraph->set_name("ConstFoldSubgraph");
635 constFoldNodeSubgraph->set_type(ONNX_NAMESPACE::AttributeProto::GRAPH);
636
637 // Temporarily swap in the constant folding function and the graph proto from
638 // the constant folding subgraph node so we can write into it.
639 ONNX_TRAITS::GraphProto *origGraphProto = graphProto_;
640 graphProto_ = constFoldNodeSubgraph->mutable_g();
641 Function *origF = F_;
642 F_ = constFoldFunction;
643 // Make sure to restore original state of the writer when exiting this scope.
644 ScopeGuard restoreOrigStateGuard([&]() {
645 graphProto_ = origGraphProto;
646 F_ = origF;
647 });
648
  // Now that we have set up the constant folding Function and proto to write
  // into, write the Function in.
651 RETURN_IF_ERR(writeFunction());
652
653 // Now set the output of the ConstFoldSubgraph node. Only output is the
654 // Constant it generates. Note that there are no inputs, as Constant inputs
655 // are self-contained to this graph as initializers.
656 constFoldNodeProto->add_output(C->getName().data());
657 addTypeAttributes(constFoldNodeProto, SN->getOutput(), SaveNode::OutputIdx,
658 /* isInput */ false);
659
660 // Reverse the order of Nodes since we wrote them in reverse order.
661 reverseNodesOrder(*graphProto_->mutable_node());
662
663 return Error::success();
664}
665
666Error ONNXModelWriter::writeFunction() {
  // Use pre-order graph traversal.
  // If a node is a Constant with uses, it is turned into an "input" and a
  // tensor is created for it.
  // If a node is a Placeholder with uses, it is turned into an "input" with a
  // tensor shape, except when the Placeholder is only used as the output of a
  // SaveNode.
  // Otherwise, call the common or special operator writer methods and write
  // protobuf inputs from the node's inputs and protobuf outputs from its uses.
673
674 ReverseGraphWalker visitor(*F_);
675 for (const auto *N : visitor.getNodes()) {
676 if (reportedNodes_.count(N)) {
677 continue;
678 }
679
680 const auto kind = N->getKind();
681 // Handle placeholders cases.
682 if (kind == Kinded::Kind::PlaceholderKind) {
683 const auto *PH = llvm::cast<Placeholder>(N);
684 if (!isInput(PH, *F_)) {
685 // Storage as an input to SaveNode - ignore it.
686 continue;
687 }
688 // Skip Placeholders that are only intermediates -- these are
689 // understood/recreated during reimporting based on an op's partition.
690 if (isIntermediatePHForDAG(PH)) {
691 continue;
692 }
      // Write a global input, outputting only the tensor shape.
694 RETURN_IF_EXPECTED_IS_ERR(createProtoForIO(PH, /* isInput */ true));
695 } else if (kind == Kinded::Kind::ConstantKind) {
696 // Write global initializer, output tensor bytes.
697 const auto *C = llvm::cast<Constant>(N);
698
699 // Check if this constant came from constant folding that we recorded and
700 // want to serialize in the model. If so then we process it before the
701 // Constant itself so that it will be loaded first.
702 auto constFoldRecordIt = constFoldRecord_.find(C);
703 if (constFoldRecordIt != constFoldRecord_.end()) {
704 RETURN_IF_ERR(
705 writeConstantFoldingSubgraph(C, constFoldRecordIt->second));
706 }
707
708 // Note: Always add initializers to the root graph proto.
709 auto *tensorProto = addInitializer(*graphProtoRoot_);
710
711 // If we added a constant folding Node recording the constant folding that
712 // generated this initializer then point to it in the initializer.
713 if (constFoldRecordIt != constFoldRecord_.end()) {
714 SaveNode *SN = constFoldRecordIt->second;
715 // Point the original initializerProto to this Node so that it knows
716 // where to find the Function to replay the constant folding, along with
717 // the resNo needed from the Function.
718 auto *constFoldNodeNameProto = tensorProto->add_external_data();
719 constFoldNodeNameProto->set_key("ConstFoldNodeName");
720 constFoldNodeNameProto->set_value(SN->getParent()->getName().data());
721 auto *resNoProto = tensorProto->add_external_data();
722 resNoProto->set_key("ConstFoldResNo");
723 resNoProto->set_value(std::to_string(SN->getInput().getResNo()));
724 }
725
726 // When using useGlowCustomOps we always use generateNodeOutputName for
727 // all inputs and outputs.
728 tensorProto->set_name(useGlowCustomOps_
729 ? C->getOutput().generateNodeOutputName(
730 /* stripResNoFor0thInput */ true)
731 : C->getName().str());
732 writeTensor(C->getPayload(), tensorProto, useGlowCustomOps_,
733 includeConstantData_);
734 if (useGlowCustomOps_) {
735 // Also include the layout in the initializer to be loaded later.
736 addAttrToDocString(tensorProto, layoutSignifier, C->getLayout());
737 }
738 } else if (kind == Kinded::Kind::SaveNodeKind) {
      // SaveNode case: find the input and use its name as a global output,
      // outputting only the shape.
741 const SaveNode *SN = llvm::cast<SaveNode>(N);
742 const auto *PH = SN->getPlaceholder();
743
744 // If useGlowCustomOps then we need to add an Identity to map the name
745 // from generateNodeOutputName() to the name of the Placeholder.
746 if (useGlowCustomOps_) {
747 auto *proto = graphProto_->add_node();
748 proto->set_op_type("Identity");
749 proto->set_name(SN->getName().data());
750 proto->add_input(SN->getInput().generateNodeOutputName(
751 /* stripResNoFor0thInput */ true));
752 proto->add_output(PH->getName().data());
753 addTypeAttributes(proto, SN->getInput(), SaveNode::InputIdx,
754 /* isInput */ true);
755 addTypeAttributes(proto, SN->getOutput(), SaveNode::OutputIdx,
756 /* isInput */ false);
757 // If dumping a DAG then add partition names to each op that's written.
758 // Also only do so when not writing a const fold subgraph.
759 if (dagMode_ && !isWritingConstFoldSubgraph()) {
760 addValueAttribute(proto, "partitionName", F_->getName().str());
761 // Skip writing Placeholders that are only intermediates -- these are
762 // understood/recreated during reimporting based on an op's partition.
763 if (isIntermediatePHForDAG(PH)) {
764 addValueAttribute(proto, "isIntermediateOutputForDAG", true);
765 continue;
766 }
767 }
768 }
769
770 ONNX_NAMESPACE::ValueInfoProto *out;
771 ASSIGN_VALUE_OR_RETURN_ERR(out,
772 createProtoForIO(PH, /* isInput */ false));
773
774 // Use the doc string to specify the name that should be used for the
775 // SaveNode to ensure it's kept the same between export and import.
776 addAttrToDocString(out, saveNameSignifier, SN->getName());
777 } else if (useGlowCustomOps_) {
778 RETURN_IF_ERR(writeGlowCustomOperator(N, *graphProto_));
779 } else {
780 RETURN_IF_ERR(writeOperator(N, *graphProto_));
781 }
782 reportedNodes_.insert(N);
783 }
784
785 return Error::success();
786}
787
788void ONNXModelWriter::setupNewProto() {
789 modelProto_.set_ir_version(irVersion_);
790 modelProto_.set_producer_name(useGlowCustomOps_ ? "GlowONNXModelWriter"
791 : "ONNXModelWriter");
792 auto *opsetImportProto = modelProto_.add_opset_import();
793 opsetImportProto->set_version(opsetVersion_);
794 graphProto_ = modelProto_.mutable_graph();
795 graphProto_->set_name("glow");
796 graphProtoRoot_ = graphProto_;
797}
798
799static Error writeModelToString(const ::google::protobuf::Message &modelProto,
800 bool textMode, std::string *outputStringPtr) {
801 if (textMode) {
802 RETURN_ERR_IF_NOT(google::protobuf::TextFormat::PrintToString(
803 modelProto, outputStringPtr),
804 "Error writing to string");
805 } else {
806 ::google::protobuf::io::StringOutputStream zeroCopyOutput(outputStringPtr);
807 ::google::protobuf::io::CodedOutputStream codedOutput(&zeroCopyOutput);
808 modelProto.SerializeToCodedStream(&codedOutput);
809 RETURN_ERR_IF_NOT(!codedOutput.HadError(),
810 "Can't write to the output string",
811 ErrorValue::ErrorCode::MODEL_WRITER_SERIALIZATION_ERROR);
812 }
813 return Error::success();
814}
815
816Error ONNXModelWriter::finalizeAndWriteProto(llvm::StringRef name) {
  // Nodes have been added in reverse order, from SaveNodes up to the inputs;
  // we need to reverse the node order before serialization.
819 auto *nodes = graphProto_->mutable_node();
820 reverseNodesOrder(*nodes);
821
  // useGlowCustomOps uses Identities differently than the normal writer, so
  // skip this reordering in that case.
824 if (!useGlowCustomOps_) {
    // We need to swap back each Identity node with the next non-Identity
    // node, since Identity nodes are appended to tap out intermediate results.
827 for (int i = 0, n = nodes->size(); i < n - 1; ++i) {
828 if (nodes->Get(i).op_type() == "Identity") {
829 int k = 1;
830 while (i + k < n) {
831 if (nodes->Get(i + k).op_type() != "Identity") {
832 break;
833 }
834 ++k;
835 }
836 nodes->SwapElements(i, i + k);
837 i += k;
838 }
839 }
840 } else {
841#ifdef FACEBOOK_INTERNAL
842 addMetadataProp("GlowRevisionHash", revisionHash);
843#endif /* FACEBOOK_INTERNAL */
844 }
845
846 // If we have loadedPHNames_, then we buffered the non-static PH IO protobuf
847 // in inputValueInfos_ and outputValueInfos_. Now we write it all out in order
848 // according to indices provided in loadedPHNames_.
849 if (loadedPHNames_) {
850 const bool ioNumMismatch =
851 (inputValueInfos_.size() + outputValueInfos_.size() !=
852 loadedPHNames_->size());
853
854 // If total number of inputs and outputs doesn't match the number of
855 // placeholders, then log an error message and let the next for-loop find
856 // the culprits.
857 if (ioNumMismatch) {
858 LOG(ERROR) << "Number of buffered inputs and outputs "
859 << (inputValueInfos_.size() + outputValueInfos_.size())
860 << " didn't match the number of loadedPHNames "
861 << loadedPHNames_->size();
862 }
863
864 // If we have the loaded PH names map, then we need to reorder the inputs
865 // and outputs to follow the same order as provided in the loadedPHNames_.
866 std::vector<const Placeholder *> orderedInputs(inputValueInfos_.size());
867 std::vector<const Placeholder *> orderedOutputs(outputValueInfos_.size());
868 for (const auto &pair : *loadedPHNames_) {
869 const Placeholder *PH = pair.first;
870 const unsigned orderIdx = pair.second.second;
871 if (inputValueInfos_.count(PH)) {
872 orderedInputs[orderIdx] = PH;
873 } else if (outputValueInfos_.count(PH)) {
874 orderedOutputs[orderIdx] = PH;
875 } else {
876 return MAKE_ERR("PH must either be in inputs or outputs: " +
877 PH->getName().str());
878 }
879 }
880
    // If we didn't find any bad Placeholders, then there must be some
    // buffered inputs/outputs without corresponding entries in loadedPHNames_.
882 if (ioNumMismatch) {
883 return MAKE_ERR("Found some inputs/outputs that don't have corresponding "
884 "placeholders");
885 }
886
887 // Now have IO in order matching loadedPHNames_, so finally write them out.
888 for (const Placeholder *PH : orderedInputs) {
889 auto *inputProto = graphProto_->add_input();
890 inputProto->MergeFrom(inputValueInfos_[PH]);
891 }
892 for (const Placeholder *PH : orderedOutputs) {
893 auto *outputProto = graphProto_->add_output();
894 outputProto->MergeFrom(outputValueInfos_[PH]);
895 }
896 }
897
898 if (zipMode_) {
899 RETURN_ERR_IF_NOT(
900 outputStringPtr_ == nullptr,
901 "OnnxModelWriter write to string for zip mode not supported");
902 const bool compressed = false;
903 ZipWriter zip(&ff_, name.str());
904 std::stringstream ss;
905 ss << initializers_.size() << "\n";
906 zip.writeRecord("weights", ss.str().c_str(), ss.str().size(), compressed);
907 std::string largeBuffer;
908 int i = 0;
    // This part is probably quite inefficient, as we serialize each protobuf
    // into a string buffer and then write it into the zip stream. It is
    // unclear whether it could be serialized into the zip stream directly.
912 for (const auto &t : initializers_) {
913 std::stringstream nm;
914 nm << "weight_" << i++;
915 t.SerializeToString(&largeBuffer);
916 zip.writeRecord(nm.str(), largeBuffer.c_str(), largeBuffer.size(),
917 compressed);
918 }
919 if (textMode_) {
920 google::protobuf::TextFormat::PrintToString(modelProto_, &largeBuffer);
921 } else {
922 modelProto_.SerializeToString(&largeBuffer);
923 }
924 zip.writeRecord("model", largeBuffer.c_str(), largeBuffer.size(),
925 compressed);
926 zip.writeEndOfFile();
927 ff_.flush();
928 ff_.close();
929 return Error::success();
930 } else {
931 if (outputStringPtr_ != nullptr) {
932 return writeModelToString(modelProto_, textMode_, outputStringPtr_);
933 } else {
934 return writeModel(modelProto_, textMode_);
935 }
936 }
937}
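
// In zip mode the resulting archive contains a "weights" record holding the
// number of initializers, one "weight_<i>" record per serialized initializer
// TensorProto, and a "model" record holding the serialized (or text-printed)
// ModelProto.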
938
939ONNXModelWriter::ONNXModelWriter(
940 const std::string &modelFilename, Function &F, size_t irVersion,
941 size_t opsetVersion, Error *errPtr, bool textMode, bool zipMode,
942 bool useGlowCustomOps, bool includeConstantData,
943 const llvm::StringMap<std::string> &extraMetadataProps,
944 const ConstantFoldingRecordMap &constFoldRecord,
945 const BackendSpecificNodeInfo &backendSpecificNodeInfo,
946 std::string *outputStringPtr)
947 : CommonOperatorWriter(modelFilename, &F, errPtr,
948 outputStringPtr == nullptr),
949 irVersion_(irVersion), opsetVersion_(opsetVersion), zipMode_(zipMode),
950 textMode_(textMode), includeConstantData_(includeConstantData),
951 extraMetadataProps_(extraMetadataProps),
952 useGlowCustomOps_(useGlowCustomOps), dagMode_(false),
953 constFoldRecord_(constFoldRecord),
954 backendSpecificNodeInfo_(backendSpecificNodeInfo),
955 loadedPHNames_(nullptr), staticPlaceholderTypes_(nullptr),
956 outputStringPtr_(outputStringPtr) {
957 // If errPtr already contains an error then don't continue with constructor.
958 if (errPtr && *errPtr) {
959 return;
960 }
961
  // Lambda to set up the ONNXModelWriter and return any Errors that were
  // raised.
963 auto setup = [&]() -> Error {
964 setupNewProto();
965 for (auto &prop : extraMetadataProps_) {
966 addMetadataProp(prop.getKey().str(), prop.second);
967 }
968
969 RETURN_IF_ERR(writeFunction());
970
971 return finalizeAndWriteProto(F_->getName());
972 };
973
974 if (errPtr) {
975 *errPtr = setup();
976 } else {
977 EXIT_ON_ERR(setup());
978 }
979}
980
981/// Collect nodes from the DAG from \p root in post order in \p postOrder.
982/// Gather visited nodes in \p visited.
983static void collectNodesPostOrder(const DAGNode *root,
984 std::unordered_set<const DAGNode *> &visited,
985 std::vector<const DAGNode *> &postOrder) {
986 if (root == nullptr) {
987 return;
988 }
989 visited.insert(root);
990 for (auto &c : root->children) {
991 if (visited.count(c) == 0) {
992 collectNodesPostOrder(c, visited, postOrder);
993 }
994 }
995 postOrder.push_back(root);
996}
997
998void ONNXModelWriter::addMetadataProp(const std::string &key,
999 const std::string &val) {
1000 auto *prop = modelProto_.add_metadata_props();
1001 prop->set_key(key);
1002 prop->set_value(val);
1003}
1004
1005Error ONNXModelWriter::writePartitionAndMetadataProps(
1006 Module &mod, llvm::ArrayRef<const DAGNode *> postOrder) {
1007 // Add number of partitions to proto.
1008 addMetadataProp("numPartitions", std::to_string(postOrder.size()));
1009
1010 for (size_t i = 0, e = postOrder.size(); i < e; i++) {
1011 const auto *dagNode = postOrder[i];
1012 F_ = mod.getFunction(dagNode->name);
1013 RETURN_ERR_IF_NOT(F_, "Function was not valid from DAGList");
1014
1015 // Write the nodes of the Function.
1016 RETURN_IF_ERR(writeFunction());
1017
1018 // Add to the proto the partition name and related meta info:
1019 const std::string partIdPrefix = getPartitionIdPrefix(i);
1020
1021 // name of partition:
1022 addMetadataProp(partIdPrefix + nameSignifier, dagNode->name);
1023
1024 // logicalDevices of partition:
1025 addMetadataProp(partIdPrefix + numLogicalDevicesSignifier,
1026 std::to_string(dagNode->logicalDevices.size()));
1027 for (size_t j = 0, f = dagNode->logicalDevices.size(); j < f; j++) {
1028 addMetadataProp(partIdPrefix + getLogicalDeviceSignfier(j),
1029 std::to_string(dagNode->logicalDevices[j]));
1030 }
1031
1032 // backendName of partition:
1033 addMetadataProp(partIdPrefix + backendNameSignifier, dagNode->backendName);
1034
1035 // size of partition:
1036 addMetadataProp(partIdPrefix + sizeSignifier,
1037 std::to_string(dagNode->size));
1038
1039 // backendHints.executionUnits of partition:
1040 addMetadataProp(partIdPrefix + executionUnitsSignifier,
1041 std::to_string(dagNode->backendHints.executionUnits));
1042
1043 // backendHints.SRAMPrioritization of partition not supported:
1044 assert(dagNode->backendHints.SRAMPrioritization.size() == 0 &&
1045 "Do not support SRAMPrioritization saving from DAGNode");
1046
1047 // backendSpecificOpts of partition:
1048 addMetadataProp(partIdPrefix + numBackendSpecificOptsSignifier,
1049 std::to_string(dagNode->backendSpecificOpts.size()));
1050 size_t j = 0;
1051 for (const auto &keyVal : dagNode->backendSpecificOpts) {
1052 addMetadataProp(partIdPrefix + getBackendSpecificOptKeySignifier(j),
1053 keyVal.first);
1054 addMetadataProp(partIdPrefix + getBackendSpecificOptValSignifier(j),
1055 keyVal.second);
1056 j += 1;
1057 }
1058
1059 // replicationCount of partition:
1060 addMetadataProp(partIdPrefix + replicationCountSignifier,
1061 std::to_string(dagNode->replicationCount));
1062 }
1063
1064 return Error::success();
1065}
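
// For each partition i this emits metadata props keyed with the
// getPartitionIdPrefix(i) prefix: the partition name, the number of logical
// devices and each device id, the backend name, the partition size, the
// backendHints execution units, the count and key/value pairs of any
// backend-specific options, and the replication count, in addition to the
// global "numPartitions" prop.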
1066
1067ONNXModelWriter::ONNXModelWriter(
1068 const std::string &modelFilename, DAGListTy &dagList, size_t irVersion,
1069 size_t opsetVersion, Error *errPtr, bool textMode, bool zipMode,
1070 bool includeConstantData,
1071 const llvm::StringMap<std::string> &extraMetadataProps,
1072 const ConstantFoldingRecordMap &constFoldRecord,
1073 const BackendSpecificNodeInfo &backendSpecificNodeInfo,
1074 const LoadedPlaceholderNameMap *loadedPHNames,
1075 const std::map<std::string, Type> *staticPlaceholderTypes,
1076 std::string *outputStringPtr)
1077 : CommonOperatorWriter(modelFilename, nullptr, errPtr,
1078 outputStringPtr == nullptr),
1079 irVersion_(irVersion), opsetVersion_(opsetVersion), zipMode_(zipMode),
1080 textMode_(textMode), includeConstantData_(includeConstantData),
1081 extraMetadataProps_(extraMetadataProps), useGlowCustomOps_(true),
1082 dagMode_(true), constFoldRecord_(constFoldRecord),
1083 backendSpecificNodeInfo_(backendSpecificNodeInfo),
1084 loadedPHNames_(loadedPHNames),
1085 staticPlaceholderTypes_(staticPlaceholderTypes),
1086 outputStringPtr_(outputStringPtr) {
1087 // If errPtr already contains an error then don't continue with constructor.
1088 if (errPtr && *errPtr) {
1089 return;
1090 }
1091
  // Lambda to set up the ONNXModelWriter and return any Errors that were
  // raised.
1093 auto setup = [&]() -> Error {
1094 setupNewProto();
1095 for (auto &prop : extraMetadataProps_) {
1096 addMetadataProp(prop.getKey().str(), prop.second);
1097 }
1098
1099 RETURN_ERR_IF_NOT(dagList.size() == 1, "Expect only one DAG.");
1100 const auto &dag = *dagList.begin();
1101
1102 Module &mod = *dag.root->module;
1103
1104 // Iterate over the DAG in post-order; Nodes per Function are written in
1105 // reverse order and reversed at the end, so this follows suit.
1106 std::unordered_set<const DAGNode *> visited;
1107 std::vector<const DAGNode *> postOrder;
1108 collectNodesPostOrder(dag.root.get(), visited, postOrder);
1109 // Remove the root node from the list as we don't care about it.
1110 postOrder.pop_back();
1111 for (const DAGNode *dagNode : postOrder) {
1112 functionsFromDAG_.insert(mod.getFunction(dagNode->name));
1113 }
1114
1115 RETURN_IF_ERR(writePartitionAndMetadataProps(mod, postOrder));
1116
1117 return finalizeAndWriteProto(dag.root->name);
1118 };
1119
1120 if (errPtr) {
1121 *errPtr = setup();
1122 } else {
1123 EXIT_ON_ERR(setup());
1124 }
1125}
1126
1127ONNXModelWriter::TensorType *ONNXModelWriter::addInitializer(GraphType &g) {
1128 if (zipMode_) {
1129 initializers_.emplace_back();
1130 return &initializers_.back();
1131 } else {
1132 return g.add_initializer();
1133 }
1134}
1135
1136ONNXModelWriter::TensorType::DataType
1137ONNXModelWriter::convertType(const Type &glowType) {
1138 switch (glowType.getElementType()) {
1139 case ElemKind::FloatTy:
1140 return TensorType::FLOAT;
1141 case ElemKind::Float16Ty:
1142 return TensorType::FLOAT16;
1143 case ElemKind::BFloat16Ty:
1144 return TensorType::BFLOAT16;
1145 case ElemKind::Float64Ty:
1146 return TensorType::DOUBLE;
1147 case ElemKind::Int8QTy:
1148 return TensorType::INT8;
1149 case ElemKind::UInt8FusedQTy:
1150 case ElemKind::UInt8FusedFP16QTy:
1151 case ElemKind::UInt4FusedFP16QTy:
1152 case ElemKind::UInt4FusedQTy:
1153 case ElemKind::UInt8QTy:
1154 case ElemKind::UInt8ITy:
1155 return TensorType::UINT8;
1156 case ElemKind::Int16QTy:
1157 return TensorType::INT16;
1158 case ElemKind::Int32QTy:
1159 case ElemKind::Int32ITy:
1160 return TensorType::INT32;
1161 case ElemKind::Int64QTy:
1162 case ElemKind::Int64ITy:
1163 return TensorType::INT64;
1164 case ElemKind::BoolTy:
1165 return TensorType::BOOL;
1166 }
1167 LOG(DFATAL) << "Cannot reach here.";
1168 return TensorType::UNDEFINED; // Avoids a compilation warning.
1169}
1170
1171/// Add quantization parameters to the doc_string in \p out based on \p type.
1172template <typename T>
1173static void addQuantParamsToDocString(T *out, const Type &type) {
1174 addAttrToDocString(out, qScaleSignifier,
1175 strFormat("%.*f", NUM_FLOAT_DIGS, type.getScale()));
1176 addAttrToDocString(out, qOffsetSignifier, std::to_string(type.getOffset()));
1177}
1178
1179/// Add strides to the doc_string in \p out based on \p type.
1180template <typename T>
1181static void addStridesToDocString(T *out, const Type &type) {
1182 // Non-standard strides need to be serialized.
1183 if (type.hasStandardStrides()) {
1184 return;
1185 }
1186 const auto &strides = type.strides();
1187 std::string stridesStr;
1188 std::string delim;
1189 for (const auto &stride : strides) {
1190 stridesStr.append(delim);
1191 stridesStr.append(std::to_string(stride));
1192 delim = ",";
1193 }
1194 addAttrToDocString(out, stridesSignifier, stridesStr);
1195}
1196
1197void ONNXModelWriter::writeTensor(const Tensor &T, TensorType *out,
1198 bool useGlowCustomOps, bool includeData) {
1199 const auto &type = T.getType();
1200 out->set_data_type(convertType(type));
1201 const auto &dims = type.dims();
1202 for (unsigned b = 0, e = dims.size(); b < e; ++b) {
1203 out->add_dims(dims[b]);
1204 }
1205
1206 if (includeData) {
1207 out->set_raw_data(T.getUnsafePtr(), type.getSizeInBytes());
1208 }
1209
1210 if (useGlowCustomOps) {
1211 addAttrToDocString(out, elemKindSignifier, type.getElementName());
1212 addStridesToDocString(out, type);
1213 }
1214
1215 if (type.isQuantizedType()) {
1216 if (useGlowCustomOps) {
1217 addQuantParamsToDocString(out, type);
1218 } else {
1219 // Format is ElemKind:scale:offset.
1220 out->set_doc_string(strFormat("%s:%.*f:%d", type.getElementName().data(),
1221 NUM_FLOAT_DIGS, T.getType().getScale(),
1222 T.getType().getOffset()));
1223 }
1224 }
1225}
1226
1227Expected<ONNX_NAMESPACE::ValueInfoProto *>
1228ONNXModelWriter::createProtoForIO(const Placeholder *PH, bool isInput) {
1229 // If loadedPHNames_ then we have a specific order we need to write out IO
1230 // protos. If so, buffer non-static IO that is not part of a constant folding
1231 // subgraph into inputValueInfos_/outputValueInfos_ to later be written out in
1232 // order inside finalizeAndWriteProto() based on loadedPHNames_.
1233 ONNX_NAMESPACE::ValueInfoProto *valueProto;
1234 if (!loadedPHNames_ || isWritingConstFoldSubgraph() || PH->isStatic()) {
1235 valueProto = isInput ? graphProto_->add_input() : graphProto_->add_output();
1236 } else {
1237 valueProto = isInput ? &inputValueInfos_[PH] : &outputValueInfos_[PH];
1238 }
1239
1240 tensorShapeFromInput(PH->getName().str(), PH->getType(), valueProto);
1241
1242 if (useGlowCustomOps_) {
1243 // Write out any meta information we need to for the Placeholder.
1244 addStridesToDocString(valueProto, *PH->getType());
1245 addAttrToDocString(valueProto, staticSignifier,
1246 std::to_string(PH->isStatic()));
1247 addAttrToDocString(valueProto, trainableSignifier,
1248 std::to_string(PH->isTraining()));
1249 addAttrToDocString(valueProto, layoutSignifier, PH->getLayout());
1250 addAttrToDocString(valueProto, elemKindSignifier,
1251 PH->getType()->getElementName());
1252
1253 // If we're writing out a Placeholder from the original input Function, then
1254 // expect to find a corresponding input loaded PH name if they are
1255 // provided. This is expected when the PH is not static (as otherwise it's
1256 // input as a weight), when the Function being written isn't a constant
1257 // folding subgraph (then we have PHs that are used just to save the const
1258 // folding result), and when the PH isn't intermediate (then it's only
1259 // visible/used by Glow when executing partitioned DAGs).
1260 if (loadedPHNames_ && !PH->isStatic() && !isWritingConstFoldSubgraph() &&
1261 !isIntermediatePHForDAG(PH)) {
1262 auto it = loadedPHNames_->find(PH);
1263 RETURN_ERR_IF_NOT(it != loadedPHNames_->end(),
1264 "Did not find associated loader name for " +
1265 PH->getName().str() + " while writing Function " +
1266 F_->getName().str());
1267 addAttrToDocString(valueProto, loaderNameSignifier, it->second.first);
1268 }
1269
1270 // If we have a type that was used for loading a static Placeholder, then
1271 // serialize that type into a dummy node.
1272 if (staticPlaceholderTypes_ && PH->isStatic()) {
1273 auto it = staticPlaceholderTypes_->find(PH->getName().data());
1274 RETURN_ERR_IF_NOT(it != staticPlaceholderTypes_->end(),
1275 "Did not find associated type for static PH " +
1276 PH->getName().str() + " while writing Function " +
1277 F_->getName().str());
1278
      // Create a new static PH dummy node that carries the type that the
      // static PH was loaded with. Note it has no inputs or outputs; however,
      // there is a type appended for the output idx, and the node has the
      // same name as the static PH to use when reloading.
1283 auto *staticPHDummyNodeProto = graphProto_->add_node();
1284 staticPHDummyNodeProto->set_op_type(staticPHDummyNodeName);
1285 staticPHDummyNodeProto->set_name(PH->getName().data());
1286
1287 // Set the output type to be the one we found in staticPlaceholderTypes_.
1288 addTypeAttributes(staticPHDummyNodeProto, &it->second, Storage::OutputIdx,
1289 /* isInput */ false);
1290 }
1291
1292 // Also include quantization params if necessary.
1293 if (PH->getType()->isQuantizedType()) {
1294 addQuantParamsToDocString(valueProto, *PH->getType());
1295 }
1296 }
1297 return valueProto;
1298}
1299
1300Error ONNXModelWriter::writeAllWithNode(const std::string &opName,
1301 const Node *node, GraphType &graph,
1302 NodeType *proto) {
1303 proto->set_name(node->getName().str());
1304 proto->set_op_type(opName);
1305 inputsToProto(node, proto);
1306 outputsToProto(node, graph, proto);
1307 return Error::success();
1308}
1309
1310Error ONNXModelWriter::writeAll(const std::string &opName, const Node *node,
1311 GraphType &graph) {
1312 return writeAllWithNode(opName, node, graph, graph.add_node());
1313}
1314
1315//===-----------------------------------------------------------------===//
1316// Operators Supported by ONNX
1317//===-----------------------------------------------------------------===//
1318Error ONNXModelWriter::writePad(const PadNode *node, GraphType &graph) {
1319 auto *proto = graph.add_node();
1320 // Add dictionary entries.
1321 switch (node->getMode()) {
1322 case PaddingMode::CONSTANT:
1323 addValueAttribute(proto, "mode", std::string("constant"));
1324 break;
1325 case PaddingMode::REFLECT:
1326 addValueAttribute(proto, "mode", std::string("reflect"));
1327 break;
1328 case PaddingMode::EDGE:
1329 addValueAttribute(proto, "mode", std::string("edge"));
1330 break;
1331 default:
1332 return MAKE_ERR("Pad: Invalid mode",
1333 ErrorValue::ErrorCode::MODEL_WRITER_SERIALIZATION_ERROR);
1334 }
1335
1336 if (opsetVersion_ <= 10) {
1337 addValueAttribute(proto, "pads", node->getPads());
1338 float value = node->getValue();
1339 if (value != .0f) {
1340 addValueAttribute(proto, "value", value);
1341 }
1342 return writeAllWithNode("Pad", node, graph, proto);
1343 } else {
1344 proto->set_name(node->getName().str());
1345 proto->set_op_type("Pad");
1346 // Input for data.
1347 inputsToProto(node, proto);
1348
1349 // Input for pads.
1350 auto pads = node->getPads();
1351 Tensor oneDimTensorPads(ElemKind::Int64ITy, {(dim_t)pads.size()});
1352 auto oneDimTensorPadsH = oneDimTensorPads.getHandle<int64_t>();
1353 for (size_t b = 0, e = oneDimTensorPads.size(); b < e; ++b) {
1354 oneDimTensorPadsH.raw(b) = pads[b];
1355 }
1356 auto *tensorProto = addInitializer(graph);
1357 tensorProto->set_name(node->getName().str() + "_pads");
1358 writeTensor(oneDimTensorPads, tensorProto);
1359 proto->add_input(node->getName().str() + "_pads");
1360
1361 // Input for value.
1362 Tensor value(ElemKind::FloatTy, {1});
1363 auto valueH = value.getHandle();
1364 valueH.raw(0) = node->getValue();
1365 tensorProto = addInitializer(graph);
1366 tensorProto->set_name(node->getName().str() + "_value");
1367 writeTensor(value, tensorProto);
1368 proto->add_input(node->getName().str() + "_value");
1369 // Output
1370 outputsToProto(node, graph, proto);
1371 return Error::success();
1372 }
1373}
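
// Note the opset split above: for opset <= 10 the pads (and any non-zero pad
// value) are written as node attributes, whereas for newer opsets they are
// written as extra inputs backed by initializers named "<node>_pads" and
// "<node>_value".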
1374
1375Error ONNXModelWriter::writeConcat(const ConcatNode *node, GraphType &graph) {
1376 auto *proto = graph.add_node();
1377 // Add dictionary entries.
1378 addValueAttribute(proto, "axis", node->getDim());
1379
1380 return writeAllWithNode("Concat", node, graph, proto);
1381}
1382
1383Error ONNXModelWriter::writeTranspose(const TransposeNode *node,
1384 GraphType &graph) {
  // Some nodes (convolutions, pools, SpaceToDepth) write their own output
  // Transpose; skip such Transpose nodes here.
1386 auto *input = node->getInput().getNode();
1387 if (llvm::dyn_cast<ConvolutionNode>(input) ||
1388 llvm::dyn_cast<Convolution3DNode>(input) ||
1389 llvm::dyn_cast<AvgPoolNode>(input) ||
1390 llvm::dyn_cast<MaxPoolNode>(input) ||
1391 llvm::dyn_cast<SpaceToDepthNode>(input)) {
1392 return Error::success();
1393 }
1394
1395 auto *proto = graph.add_node();
1396 // Add dictionary entries.
1397 addValueAttribute(proto, "perm", node->getShuffle());
1398
1399 return writeAllWithNode("Transpose", node, graph, proto);
1400}
1401
1402Error ONNXModelWriter::writeCollectRpnProposals(
1403 const CollectRpnProposalsNode *node, GraphType &graph) {
1404 return writeAllWithNode("CollectRpnProposals", node, graph, graph.add_node());
1405}
1406
1407Error ONNXModelWriter::writeFlip(const FlipNode *node, GraphType &graph) {
1408 auto *proto = graph.add_node();
1409 // Add dictionary entries.
1410 addValueAttribute(proto, "axis", node->getAxis());
1411
1412 return writeAllWithNode("Flip", node, graph, proto);
1413}
1414
1415Error ONNXModelWriter::writeAudioSpectrogram(const AudioSpectrogramNode *node,
1416 GraphType &graph) {
1417 auto *proto = graph.add_node();
1418
1419 addValueAttribute(proto, "window_size", node->getWindowSize());
1420 addValueAttribute(proto, "stride", node->getWindowStride());
1421 addValueAttribute(proto, "magnitude_squared", node->getMagnitudeSquared());
1422
1423 return writeAllWithNode("AudioSpectrogram", node, graph, proto);
1424}
1425
1426Error ONNXModelWriter::writeMFCC(const MFCCNode *node, GraphType &graph) {
1427 auto *proto = graph.add_node();
1428
1429 addValueAttribute(proto, "sample_rate", node->getSampleRate());
1430 addValueAttribute(proto, "lower_frequency_limit", node->getLowerFrequency());
1431 addValueAttribute(proto, "upper_frequency_limit", node->getUpperFrequency());
1432 addValueAttribute(proto, "filterbank_channel_count",
1433 node->getFilterBankCount());
1434 addValueAttribute(proto, "dct_coefficient_count", node->getNumCoefficients());
1435
1436 return writeAllWithNode("MFCC", node, graph, proto);
1437}
1438
1439Error ONNXModelWriter::writeROIAlign(const ROIAlignNode *node,
1440 GraphType &graph) {
1441 auto *proto = graph.add_node();
1442 switch (node->getMode()) {
1443 case PoolingMode::AVG:
1444 addValueAttribute(proto, "mode", std::string("avg"));
1445 break;
1446 case PoolingMode::MAX:
1447 addValueAttribute(proto, "mode", std::string("max"));
1448 break;
1449 }
1450 addValueAttribute(proto, "output_height", node->getOutputHeight());
1451 addValueAttribute(proto, "output_width", node->getOutputWidth());
1452 addValueAttribute(proto, "sampling_ratio", node->getSamplingRatio());
1453 addValueAttribute(proto, "spatial_scale", node->getSpatialScale());
1454 addValueAttribute(proto, "aligned", node->getAligned());
1455 addValueAttribute(proto, "rotated", node->getRotated());
1456 return writeAllWithNode("ROIAlign", node, graph, proto);
1457}
1458
1459Error ONNXModelWriter::writeBBoxTransform(const BBoxTransformNode *node,
1460 GraphType &graph) {
1461 auto *proto = graph.add_node();
1462 addValueAttribute(proto, "ApplyScale", node->getApplyScale());
1463 addValueAttribute(proto, "Rotated", node->getRotated());
1464 addValueAttribute(proto, "AngleBoundOn", node->getAngleBoundOn());
1465 addValueAttribute(proto, "AngleBoundLo", node->getAngleBoundLo());
1466 addValueAttribute(proto, "AngleBoundHi", node->getAngleBoundHi());
1467 addValueAttribute(proto, "ClipAngleThresh", node->getClipAngleThresh());
1468 return writeAllWithNode("BBoxTransform", node, graph, proto);
1469}
1470
1471Error ONNXModelWriter::writeConvolution(const ConvolutionNode *node,
1472 GraphType &graph) {
  // Loading a convolution creates a sandwich of Transpose nodes around the
  // Input, Weights, and Result. The lowering algorithm can remove Transpose
  // nodes and replace one set of nodes with another. When saving a graph to
  // ONNX format, keep in mind that when it is loaded again the Transpose
  // sandwich will be recreated. The steps are:
  // Remove the Transpose nodes for Input and Weights; if no such Transpose is
  // found (they are supposed to be NCHW2NHWC), then create a "mirror"
  // Transpose, i.e. NHWC2NCHW, for the corresponding Input and/or Weights.
  // A similar algorithm is applied to the Result: if an NHWC2NCHW Transpose
  // node is found as the Result's user then remove it, otherwise create a
  // "mirror" Transpose, i.e. NCHW2NHWC.
1484 assert(node->getLayout() == NHWC && "can only write NHWC Convolutions");
1485
1486 // Delegate writing quantized Convs to writeTensorwiseQuantizedConvolution.
1487 if (isQuantizedElemKind(node->getInput().getElementType())) {
1488 return writeTensorwiseQuantizedConvolution(node, graph);
1489 }
1490
1491 auto *proto = graph.add_node();
1492
1493 // Use the output of transpose node.
1494 if (!outputKindToProto(Kinded::Kind::TransposeNodeKind, node, graph, proto)) {
1495 // Apparently Result Transpose has been removed, add NCHW2NHWC Transpose.
1496 writeTransposeResult(node, proto, graph);
1497 }
1498
1499 // Add dictionary entries.
1500 addValueAttribute(proto, "strides", node->getStrides());
1501 addValueAttribute(proto, "pads", node->getPads());
1502 addValueAttribute(proto, "group", node->getGroup());
1503 addValueAttribute(proto, "dilations", node->getDilation());
1504
1505 const Node *input = node->getInput().getNode();
1506 if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(input)) {
1507 proto->add_input(TN->getInput().getNode()->getName().str());
1508 reportedNodes_.insert(TN);
1509 } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(input)) {
1510 proto->add_input(RSN->getInput().getNode()->getName().str());
1511 reportedNodes_.insert(RSN);
1512 } else {
1513 writeTransposeInput(node, input, proto, graph);
1514 }
1515
1516 const Node *filter = node->getFilter().getNode();
1517 if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(filter)) {
1518 proto->add_input(TN->getInput().getNode()->getName().str());
1519 reportedNodes_.insert(TN);
1520 } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(filter)) {
1521 proto->add_input(RSN->getInput().getNode()->getName().str());
1522 reportedNodes_.insert(RSN);
1523 } else {
1524 writeTransposeInput(node, filter, proto, graph);
1525 }
1526
1527 proto->add_input(node->getBias().getNode()->getName().str());
1528
1529 proto->set_name(node->getName().str());
1530 proto->set_op_type("Conv");
1531
1532 return Error::success();
1533}
1534
1535Error ONNXModelWriter::writeTensorwiseQuantizedConvolution(
1536 const ConvolutionNode *node, GraphType &graph) {
1537 auto *proto = graph.add_node();
1538
1539 // Add dictionary entries.
1540 addValueAttribute(proto, "kernel_shape", node->getKernels());
1541 addValueAttribute(proto, "strides", node->getStrides());
1542 addValueAttribute(proto, "pads", node->getPads());
1543 addValueAttribute(proto, "group", node->getGroup());
1544 addValueAttribute(proto, "dilation", node->getDilation());
1545
1546 addValueAttribute(proto, "out_scale",
1547 node->getType(ConvolutionNode::ResultIdx)->getScale());
1548 addValueAttribute(proto, "out_offset",
1549 node->getType(ConvolutionNode::ResultIdx)->getOffset());
1550
1551 return writeAllWithNode("Conv", node, graph, proto);
1552}
1553
1554Error ONNXModelWriter::writeChannelwiseQuantizedConvolution(
1555 const ChannelwiseQuantizedConvolutionNode *node, GraphType &graph) {
1556 auto *proto = graph.add_node();
1557
1558 // Add dictionary entries.
1559 addValueAttribute(proto, "kernel_shape", node->getKernels());
1560 addValueAttribute(proto, "strides", node->getStrides());
1561 addValueAttribute(proto, "pads", node->getPads());
1562 addValueAttribute(proto, "group", node->getGroup());
1563
1564 addValueAttribute(
1565 proto, "out_scale",
1566 node->getType(ChannelwiseQuantizedConvolutionNode::ResultIdx)
1567 ->getScale());
1568 addValueAttribute(
1569 proto, "out_offset",
1570 node->getType(ChannelwiseQuantizedConvolutionNode::ResultIdx)
1571 ->getOffset());
1572
1573 return writeAllWithNode("ChannelwiseQuantizedConvolution", node, graph,
1574 proto);
1575}
1576
1577Error ONNXModelWriter::writeBatchedReduceMean(const BatchedReduceMeanNode *node,
1578 GraphType &graph) {
1579 auto *proto = graph.add_node();
1580 // Add dictionary entries.
1581 addValueAttribute(proto, "axes", node->getAxes());
1582
1583 proto->set_name(node->getName().str());
1584 proto->set_op_type("ReduceMean");
1585 inputsToProto(node, proto);
1586
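  // Glow's reduce nodes squeeze the reduced dimensions out of the result
  // shape, so keepdims is always exported as 0.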
1587 addValueAttribute(proto, "keepdims", 0);
1588 outputsToProto(node, graph, proto);
1589
1590 return Error::success();
1591}
1592
1593Error ONNXModelWriter::writeBatchedReduceAdd(const BatchedReduceAddNode *node,
1594 GraphType &graph) {
1595 auto *proto = graph.add_node();
1596 // Add dictionary entries.
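  // ONNX ReduceSum expects a list of axes, while Glow's BatchedReduceAdd
  // reduces a single axis, so wrap the scalar axis in a one-element ArrayRef.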
1597 unsigned_t axis = node->getAxis();
1598 llvm::ArrayRef<unsigned_t> axes(axis);
1599 addValueAttribute(proto, "axes", axes);
1600
1601 proto->set_name(node->getName().str());
1602 proto->set_op_type("ReduceSum");
1603 inputsToProto(node, proto);
1604
1605 addValueAttribute(proto, "keepdims", 0);
1606 outputsToProto(node, graph, proto);
1607
1608 return Error::success();
1609}
1610
1611Error ONNXModelWriter::writeBatchedReduceSumSquare(
1612 const BatchedReduceSumSquareNode *node, GraphType &graph) {
1613 auto *proto = graph.add_node();
1614 // Add dictionary entries.
1615 unsigned_t axis = node->getAxis();
1616 llvm::ArrayRef<unsigned_t> axes(axis);
1617 addValueAttribute(proto, "axes", axes);
1618
1619 proto->set_name(node->getName().str());
  proto->set_op_type("ReduceSumSquare");
1621 inputsToProto(node, proto);
1622
1623 addValueAttribute(proto, "keepdims", 0);
1624 outputsToProto(node, graph, proto);
1625
1626 return Error::success();
1627}
1628
1629Error ONNXModelWriter::writeBatchedReduceMax(const BatchedReduceMaxNode *node,
1630 GraphType &graph) {
1631 auto *proto = graph.add_node();
  // Add dictionary entries.
1633 addValueAttribute(proto, "axes", node->getAxes());
1634
1635 return writeAllWithNode("ReduceMax", node, graph, proto);
1636}
1637
1638Error ONNXModelWriter::writeBatchedReduceMin(const BatchedReduceMinNode *node,
1639 GraphType &graph) {
1640 auto *proto = graph.add_node();
  // Add dictionary entries.
1642 addValueAttribute(proto, "axes", node->getAxes());
1643
1644 return writeAllWithNode("ReduceMin", node, graph, proto);
1645}
1646
1647Error ONNXModelWriter::writeBatchedReduceProd(const BatchedReduceProdNode *node,
1648 GraphType &graph) {
1649 auto *proto = graph.add_node();
1650 // Add dictionary entries.
1651 unsigned_t axis = node->getAxis();
1652 llvm::ArrayRef<unsigned_t> axes(axis);
1653 addValueAttribute(proto, "axes", axes);
1654
1655 proto->set_name(node->getName().str());
1656 proto->set_op_type("ReduceProd");
1657 inputsToProto(node, proto);
1658
1659 addValueAttribute(proto, "keepdims", 0);
1660 outputsToProto(node, graph, proto);
1661
1662 return Error::success();
1663}
1664
1665Error ONNXModelWriter::writeBatchNormalization(
1666 const BatchNormalizationNode *node, GraphType &graph) {
1667 auto *proto = graph.add_node();
1668 // Add dictionary entries.
1669 addValueAttribute(proto, "epsilon", node->getEpsilon());
1670 addValueAttribute(proto, "momentum", node->getMomentum());
1671
1672 proto->set_name(node->getName().str());
1673 proto->set_op_type("BatchNormalization");
1674
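  // ONNX BatchNormalization expects its inputs in the order
  // X, scale, B (bias), mean, var.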
1675 proto->add_input(node->getInput().getNode()->getName().str());
1676 proto->add_input(node->getScale().getNode()->getName().str());
1677 proto->add_input(node->getBias().getNode()->getName().str());
1678 proto->add_input(node->getMean().getNode()->getName().str());
1679 proto->add_input(node->getVar().getNode()->getName().str());
1680
1681 outputsToProto(node, graph, proto);
1682 return Error::success();
1683}
1684
1685Error ONNXModelWriter::writeInstanceNormalization(
1686 const InstanceNormalizationNode *node, GraphType &graph) {
1687 auto *proto = graph.add_node();
1688 // Add dictionary entries.
1689 addValueAttribute(proto, "epsilon", node->getEpsilon());
1690
1691 proto->set_name(node->getName().str());
1692 proto->set_op_type("InstanceNormalization");
1693
1694 proto->add_input(node->getInput().getNode()->getName().str());
1695 proto->add_input(node->getScale().getNode()->getName().str());
1696 proto->add_input(node->getBias().getNode()->getName().str());
1697
1698 outputsToProto(node, graph, proto);
1699 return Error::success();
1700}
1701
1702Error ONNXModelWriter::writeLayerNormalization(
1703 const LayerNormalizationNode *node, GraphType &graph) {
1704 auto *proto = graph.add_node();
1705 // Add dictionary entries.
1706 addValueAttribute(proto, "epsilon", node->getEpsilon());
1707
1708 proto->set_name(node->getName().str());
1709 proto->set_op_type("LayerNormalization");
1710
1711 proto->add_input(node->getInput().getNode()->getName().str());
1712 proto->add_input(node->getScale().getNode()->getName().str());
1713 proto->add_input(node->getBias().getNode()->getName().str());
1714
1715 outputsToProto(node, graph, proto);
1716 return Error::success();
1717}
1718
1719Error ONNXModelWriter::writeMeanVarNormalization(
1720 const MeanVarNormalizationNode *node, GraphType &graph) {
1721 auto *proto = graph.add_node();
1722 // Add dictionary entries.
1723 addValueAttribute(proto, "channel", node->getChannelIdx());
1724 addValueAttribute(proto, "momentum", node->getMomentum());
1725
1726 proto->set_name(node->getName().str());
1727 proto->set_op_type("MeanVarianceNormalization");
1728
1729 inputsToProto(node, proto);
1730 outputsToProto(node, graph, proto);
1731 return Error::success();
1732}
1733
1734Error ONNXModelWriter::writeSlice(const SliceNode *node, GraphType &graph) {
1735 auto *proto = graph.add_node();
1736 // Add dictionary entries.
1737 auto starts = node->getStart();
1738 auto outs = node->getResult().dims();
1739 RETURN_ERR_IF_NOT(starts.size() == outs.size(),
                    "Mismatched 'starts' and result dimensions.");
1741
1742 RETURN_IF_ERR(writeAllWithNode("Slice", node, graph, proto));
1743
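  // From opset 10 onward, ONNX Slice takes 'starts' and 'ends' as tensor
  // inputs; before that they are attributes. In both cases
  // ends[i] = starts[i] + outs[i]. For example (hypothetical values), a Slice
  // with starts = {1, 2} producing a {3, 4} result is exported with
  // ends = {4, 6}.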
1744 if (opsetVersion_ >= 10) {
1745 Tensor oneDimTensorStarts(ElemKind::Int64ITy, {(dim_t)starts.size()});
1746 auto handleStarts = oneDimTensorStarts.getHandle<int64_t>();
1747 Tensor oneDimTensorEnds(ElemKind::Int64ITy, {(dim_t)starts.size()});
1748 auto handleEnds = oneDimTensorEnds.getHandle<int64_t>();
1749
1750 for (size_t b = 0, e = starts.size(); b < e; ++b) {
1751 handleStarts.raw(b) = starts[b];
1752 handleEnds.raw(b) = outs[b] + starts[b];
1753 }
1754
1755 auto *tensorProto = addInitializer(graph);
1756 tensorProto->set_name(node->getName().str() + "_starts");
1757 writeTensor(oneDimTensorStarts, tensorProto, useGlowCustomOps_);
1758 proto->add_input(node->getName().str() + "_starts");
1759
1760 tensorProto = addInitializer(graph);
1761 tensorProto->set_name(node->getName().str() + "_ends");
1762 writeTensor(oneDimTensorEnds, tensorProto, useGlowCustomOps_);
1763 proto->add_input(node->getName().str() + "_ends");
1764 } else {
1765 auto *attrStarts = proto->add_attribute();
1766 attrStarts->set_name("starts");
1767 attrStarts->set_type(AttrType::INTS);
1768 auto *attrEnds = proto->add_attribute();
1769 attrEnds->set_name("ends");
1770 attrEnds->set_type(AttrType::INTS);
1771
1772 for (unsigned b = 0, e = starts.size(); b < e; ++b) {
1773 attrStarts->add_ints(starts[b]);
1774 attrEnds->add_ints(outs[b] + starts[b]);
1775 }
1776 }
1777 return Error::success();
1778}
1779
1780Error ONNXModelWriter::writePow(const PowNode *node, GraphType &graph) {
1781 auto *proto = graph.add_node();
1782 proto->set_name(node->getName().str());
1783 proto->add_input(node->getLHS().getNode()->getName().str());
1784 outputsToProto(node, graph, proto);
1785
1786 // Find exponent from splat node
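  // When the exponent is a compile-time Splat, the Pow is folded into a unary
  // op: e.g. Pow(X, Splat(0.5)) exports as Sqrt(X), Pow(X, Splat(-1)) as
  // Reciprocal(X), and Pow(X, Splat(2)) as Sqr(X) (not a standard ONNX op).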
1787 const auto *RHSN = node->getRHS().getNode();
1788 switch (RHSN->getKind()) {
1789 case Kinded::Kind::SplatNodeKind: {
1790 const auto *SN = llvm::cast<SplatNode>(RHSN);
1791 float value = SN->getValue();
1792 if (value == 0.5f) {
1793 proto->set_op_type("Sqrt");
1794 } else if (value == -1.0f) {
1795 proto->set_op_type("Reciprocal");
1796 } else if (value == 2.0f) {
1797 proto->set_op_type("Sqr");
1798 } else {
1799 return MAKE_ERR("Splat Node Value is invalid.");
1800 }
1801 break;
1802 }
  default:
    proto->set_op_type("Pow");
    proto->add_input(RHSN->getName().str());
    break;
1806 }
1807
1808 reportedNodes_.insert(RHSN);
1809 return Error::success();
1810}
1811
1812Error ONNXModelWriter::writeTopK(const TopKNode *node, GraphType &graph) {
1813 auto *proto = graph.add_node();
1814
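  // From opset 10 onward, ONNX TopK takes k as a second int64 tensor input
  // rather than an attribute, so serialize it as an initializer.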
1815 Tensor scalar(ElemKind::Int64ITy, {1});
1816 auto handle = scalar.getHandle<int64_t>();
1817 handle.raw(0) = node->getK();
1818
1819 auto *tensorProto = addInitializer(graph);
1820 tensorProto->set_name("k");
1821 writeTensor(scalar, tensorProto, useGlowCustomOps_);
1822
1823 RETURN_IF_ERR(writeAllWithNode("TopK", node, graph, proto));
1824
1825 proto->add_input("k");
1826 return Error::success();
1827}
1828
1829Error ONNXModelWriter::writeArgMax(const ArgMaxNode *node, GraphType &graph) {
1830 auto *proto = graph.add_node();
1831
1832 Tensor axis(ElemKind::Int64ITy, {1});
1833 Tensor keepDims(ElemKind::BoolTy, {1});
1834 auto axisH = axis.getHandle<int64_t>();
1835 auto keepDimsH = keepDims.getHandle<int8_t>();
1836 axisH.raw(0) = node->getAxis();
1837 keepDimsH.raw(0) = node->getKeepDims();
1838
1839 auto *tensorProto = addInitializer(graph);
1840 tensorProto->set_name("axis");
1841 writeTensor(axis, tensorProto, useGlowCustomOps_);
1842
1843 tensorProto = addInitializer(graph);
1844 tensorProto->set_name("keepDims");
1845 writeTensor(keepDims, tensorProto, useGlowCustomOps_);
1846 RETURN_IF_ERR(writeAllWithNode("ArgMax", node, graph, proto));
1847
1848 return Error::success();
1849}
1850
1851Error ONNXModelWriter::writeArgMin(const ArgMinNode *node, GraphType &graph) {
1852 auto *proto = graph.add_node();
1853
1854 Tensor axis(ElemKind::Int64ITy, {1});
1855 Tensor keepDims(ElemKind::BoolTy, {1});
1856 auto axisH = axis.getHandle<int64_t>();
1857 auto keepDimsH = keepDims.getHandle<int8_t>();
1858 axisH.raw(0) = node->getAxis();
1859 keepDimsH.raw(0) = node->getKeepDims();
1860
1861 auto *tensorProto = addInitializer(graph);
1862 tensorProto->set_name("axis");
1863 writeTensor(axis, tensorProto, useGlowCustomOps_);
1864
1865 tensorProto = addInitializer(graph);
1866 tensorProto->set_name("keepDims");
1867 writeTensor(keepDims, tensorProto, useGlowCustomOps_);
1868 RETURN_IF_ERR(writeAllWithNode("ArgMin", node, graph, proto));
1869
1870 return Error::success();
1871}
1872
1873Error ONNXModelWriter::writePRelu(const PReluNode *node, GraphType &graph) {
1874 auto *proto = graph.add_node();
1875 proto->set_name(node->getName().str());
1876 proto->set_op_type("PRelu");
1877 proto->add_input(node->getInput().getNode()->getName().str());
1878
1879 const auto *slope = node->getSlope().getNode();
1880 if (const auto *BN = llvm::dyn_cast<BroadcastNode>(slope)) {
1881 proto->add_input(BN->getInput().getNode()->getName().str());
1882 reportedNodes_.insert(BN);
1883 } else if (const SplatNode *SN = llvm::dyn_cast<SplatNode>(slope)) {
    // Converting the scalar to a tensor is required.
1885 Tensor scalar = {SN->getValue()};
1886 auto *tensorProto = addInitializer(graph);
1887 tensorProto->set_name(SN->getName().str());
1888 writeTensor(scalar, tensorProto, useGlowCustomOps_);
1889 proto->add_input(SN->getName().str());
1890 reportedNodes_.insert(SN);
1891 } else {
1892 return MAKE_ERR("Can't find Splat/Broadcast Node as part of PRelu Node.");
1893 }
1894
1895 outputsToProto(node, graph, proto);
1896 return Error::success();
1897}
1898
1899Error ONNXModelWriter::writeGather(const GatherNode *node, GraphType &graph) {
1900 auto *proto = graph.add_node();
1901 // Add dictionary entries.
1902 auto axis = node->getBatchDims();
1903
1904 if (axis != 0) {
1905 addValueAttribute(proto, "axis", axis);
1906 return writeAllWithNode("BatchGather", node, graph, proto);
1907 } else {
1908 return writeAllWithNode("Gather", node, graph, proto);
1909 }
1910}
1911
1912Error ONNXModelWriter::writeGatherElements(const GatherElementsNode *node,
1913 GraphType &graph) {
1914 auto *proto = graph.add_node();
1915 // Add dictionary entries.
1916 return writeAllWithNode("GatherElements", node, graph, proto);
1917}
1918
1919Error ONNXModelWriter::writeGatherND(const GatherNDNode *node,
1920 GraphType &graph) {
1921 auto *proto = graph.add_node();
1922 // Add dictionary entries.
1923 return writeAllWithNode("GatherND", node, graph, proto);
1924}
1925
1926Error ONNXModelWriter::writeMatMul(const MatMulNode *node, GraphType &graph) {
1927 return writeMatMulKind(node, graph, "MatMul");
1928}
1929
1930Error ONNXModelWriter::writeBatchMatMul(const BatchMatMulNode *node,
1931 GraphType &graph) {
1932 auto dimSize = node->getLHS().dims().size();
1933 if (dimSize == 2) {
1934 return writeMatMulKind(node, graph, "MatMul");
1935 } else {
1936 return writeMatMulKind(node, graph, "BatchMatMul");
1937 }
1938}
1939
1940Error ONNXModelWriter::writeReshape(const ReshapeNode *node, GraphType &graph) {
1941 auto *proto = graph.add_node();
1942
  // Convert the dims ArrayRef into an int64 shape initializer.
1944 auto dims = node->getDims();
1945 Tensor dimsTensor(ElemKind::Int64ITy, {(dim_t)dims.size()});
1946 auto handleDims = dimsTensor.getHandle<int64_t>();
1947 for (size_t b = 0, e = dims.size(); b < e; ++b) {
1948 handleDims.raw(b) = dims[b];
1949 }
1950
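  // ONNX Reshape (opset 5 and later) takes the target shape as a second int64
  // tensor input, so emit it as an initializer named "<node>_shape" and wire
  // it in below.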
1951 auto *tensorProto = addInitializer(graph);
1952 tensorProto->set_name(node->getName().str() + "_shape");
1953 writeTensor(dimsTensor, tensorProto, useGlowCustomOps_);
1954
1955 RETURN_IF_ERR(writeAllWithNode("Reshape", node, graph, proto));
1956 proto->add_input(node->getName().str() + "_shape");
1957 return Error::success();
1958}
1959
1960Error ONNXModelWriter::writeBucketize(const BucketizeNode *node,
1961 GraphType &graph) {
1962 auto *proto = graph.add_node();
1963 // Add dictionary entries.
1964 addValueAttribute(proto, "boundaries", node->getBoundaries());
1965
1966 return writeAllWithNode("Bucketize", node, graph, proto);
1967}
1968
1969Error ONNXModelWriter::writeResizeNearest(const ResizeNearestNode *node,
1970 GraphType &graph) {
1971 auto *proto = graph.add_node();
1972 // Converting arrayRef scale to a constant node
1973 auto scale = node->getScale();
1974 Tensor scaleTensor(ElemKind::FloatTy, {(dim_t)scale.size()});
1975 auto handleScale = scaleTensor.getHandle<float>();
1976 for (size_t b = 0, e = scale.size(); b < e; ++b) {
1977 handleScale.raw(b) = scale[b];
1978 }
1979
1980 auto *tensorProto = addInitializer(graph);
1981 tensorProto->set_name(node->getName().str() + "_scale");
1982 writeTensor(scaleTensor, tensorProto, useGlowCustomOps_);
1983
1984 // Add dictionary entries.
1985 addValueAttribute(proto, "coordinate_transformation_mode",
1986 std::string("asymmetric"));
1987 addValueAttribute(proto, "mode", std::string("nearest"));
1988 addValueAttribute(proto, "nearest_mode", std::string("floor"));
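  // ONNX Resize (opset 10 and later) takes the per-dimension scales as a
  // tensor input, so they are emitted as an initializer and wired in after the
  // node is written.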
1989
1990 RETURN_IF_ERR(writeAllWithNode("Resize", node, graph, proto));
1991 proto->add_input(node->getName().str() + "_scale");
1992 return Error::success();
1993}
1994
1995Error ONNXModelWriter::writeResizeBilinear(const ResizeBilinearNode *node,
1996 GraphType &graph) {
1997 auto *proto = graph.add_node();
1998 // Converting arrayRef scale to a constant node
1999 auto scale = node->getScale();
2000 Tensor scaleTensor(ElemKind::FloatTy, {(dim_t)scale.size()});
2001 auto handleScale = scaleTensor.getHandle<float>();
2002 for (size_t b = 0, e = scale.size(); b < e; ++b) {
2003 handleScale.raw(b) = scale[b];
2004 }
2005
2006 auto *tensorProto = addInitializer(graph);
2007 tensorProto->set_name(node->getName().str() + "_scale");
2008 writeTensor(scaleTensor, tensorProto, useGlowCustomOps_);
2009
2010 // Add dictionary entries.
2011 addValueAttribute(proto, "coordinate_transformation_mode",
2012 std::string("asymmetric"));
2013 addValueAttribute(proto, "mode", std::string("linear"));
2014
2015 RETURN_IF_ERR(writeAllWithNode("Resize", node, graph, proto));
2016 proto->add_input(node->getName().str() + "_scale");
2017 return Error::success();
2018}
2019
2020Error ONNXModelWriter::writeSoftMax(const SoftMaxNode *node, GraphType &graph) {
2021 auto *proto = graph.add_node();
2022 proto->set_name(node->getName().str());
2023 proto->set_op_type("Softmax");
2024 outputsToProto(node, graph, proto);
  // Write the input directly (it may come from a Reshape added on import).
2026 proto->add_input(node->getInput().getNode()->getName().str());
2027
2028 // Mark selected input as visited.
2029 reportedNodes_.insert(node->getSelected().getNode());
2030 return Error::success();
2031}
2032
2033Error ONNXModelWriter::writeLogSoftMax(const LogSoftMaxNode *node,
2034 GraphType &graph) {
2035 auto *proto = graph.add_node();
2036 proto->set_name(node->getName().str());
2037 proto->set_op_type("LogSoftmax");
2038 outputsToProto(node, graph, proto);
  // Write the input directly (it may come from a Reshape added on import).
2040 proto->add_input(node->getInput().getNode()->getName().str());
2041
2042 // Mark selected input as visited.
2043 reportedNodes_.insert(node->getSelected().getNode());
2044 return Error::success();
2045}
2046
2047Error ONNXModelWriter::writeReplaceNaN(const ReplaceNaNNode *node,
2048 GraphType &graph) {
2049 auto *proto = graph.add_node();
2050 // Add dictionary entries.
2051 float value = node->getValue();
2052 if (value != 0.0f) {
2053 addValueAttribute(proto, "value", value);
2054 }
2055 return writeAllWithNode("ReplaceNaN", node, graph, proto);
2056}
2057
2058Error ONNXModelWriter::writeGatherRanges(const GatherRangesNode *node,
2059 GraphType &graph) {
2060 auto *proto = graph.add_node();
2061 // Add dictionary entries.
2062 addValueAttribute(proto, "maxOutputSize", node->getOutput().dims()[0]);
2063
2064 return writeAllWithNode("GatherRanges", node, graph, proto);
2065}
2066
2067Error ONNXModelWriter::writeSparseToDenseMask(const SparseToDenseMaskNode *node,
2068 GraphType &graph) {
2069 auto *proto = graph.add_node();
2070 // Add dictionary entries.
2071 addValueAttribute(proto, "mask", node->getMask());
2072
2073 return writeAllWithNode("SparseToDenseMask", node, graph, proto);
2074}
2075
2076Error ONNXModelWriter::writeAdaptiveAvgPool(const AdaptiveAvgPoolNode *node,
2077 GraphType &graph) {
2078 auto *proto = graph.add_node();
2079
2080 // Add dictionary entries.
2081 const auto outShape = ShapeNHWC(node->getResult().dims());
2082 std::vector<size_t> output_size{outShape.h, outShape.w};
2083 addValueAttribute(proto, "output_size", llvm::makeArrayRef(output_size));
2084
  return writeAllWithNode("AdaptiveAvgPool", node, graph, proto);
2087}
2088
2089Error ONNXModelWriter::writeLocalResponseNormalization(
2090 const LocalResponseNormalizationNode *node, GraphType &graph) {
2091 auto *proto = graph.add_node();
2092 proto->set_name(node->getName().str());
2093 proto->set_op_type("LRN");
2094 outputsToProto(node, graph, proto);
2095 // Find input from Transpose node
2096 const TransposeNode *TN =
2097 llvm::dyn_cast<TransposeNode>(node->getInput().getNode());
2098 RETURN_ERR_IF_NOT(
2099 TN,
2100 "Can't find Transpose Node as part of LocalResponseNormalization Node.");
2101 proto->add_input(TN->getInput().getNode()->getName().str());
2102 reportedNodes_.insert(TN);
2103 // Add dictionary entries.
2104 addValueAttribute(proto, "size", 2 * node->getHalfWindowSize());
2105 addValueAttribute(proto, "alpha", node->getAlpha());
2106 addValueAttribute(proto, "beta", node->getBeta());
2107 addValueAttribute(proto, "bias", node->getK());
2108
2109 return Error::success();
2110}
2111
2112Error ONNXModelWriter::writeBatchBoxCox(const BatchBoxCoxNode *node,
2113 GraphType &graph) {
2114 auto *proto = graph.add_node();
2115 addValueAttribute(proto, "epsilon", node->getEpsilon());
2116 return writeAllWithNode("BatchBoxCox", node, graph, proto);
2117}
2118
2119//===-----------------------------------------------------------------===//
2120// Operators Supported by Glow only
2121//===-----------------------------------------------------------------===//
2122Error ONNXModelWriter::writeModulo(const ModuloNode *node, GraphType &graph) {
2123 auto *proto = graph.add_node();
2124 // Add dictionary entries.
2125 addValueAttribute(proto, "divisor", node->getDivisor());
2126 addValueAttribute(proto, "sign_follow_divisor",
2127 node->getSignFollowDivisor() ? 1 : 0);
2128
2129 return writeAllWithNode("Modulo", node, graph, proto);
2130}
2131
2132namespace {
2133template <typename T>
2134void writeTensorwiseQuantizedPool(const T *node, const std::string &op,
2135 ONNX_TRAITS::GraphProto &graph,
2136 ReportedNodes &) {
2137 assert(node->getLayout() == NHWC && "can only write NHWC Pools");
2138
2139 auto *proto = graph.add_node();
2140
2141 // Add dictionary entries.
2142 addValueAttribute(proto, "kernel_shape", node->getKernels());
2143 addValueAttribute(proto, "strides", node->getStrides());
2144 addValueAttribute(proto, "pads", node->getPads());
2145
2146 if (auto *APN = llvm::dyn_cast<AvgPoolNode>(node)) {
2147 addValueAttribute(proto, "count_include_pad", APN->getCountIncludePads());
2148 addValueAttribute(proto, "out_scale",
2149 APN->getType(AvgPoolNode::ResultIdx)->getScale());
2150 addValueAttribute(proto, "out_offset",
2151 APN->getType(AvgPoolNode::ResultIdx)->getOffset());
2152 } else if (auto *MPN = llvm::dyn_cast<MaxPoolNode>(node)) {
2153 addValueAttribute(proto, "out_scale",
2154 MPN->getType(MaxPoolNode::ResultIdx)->getScale());
2155 addValueAttribute(proto, "out_offset",
2156 MPN->getType(MaxPoolNode::ResultIdx)->getOffset());
2157 }
2158
2159 proto->add_input(node->getInput().getNode()->getName().str());
2160 outputsToProto(node, graph, proto);
2161
2162 proto->set_name(node->getName().str());
2163 proto->set_op_type(op);
2164}
2165
2166template <typename T>
2167void writePool(const T *node, const std::string &op,
2168 ONNX_TRAITS::GraphProto &graph, ReportedNodes &reporter) {
2169 // Delegate writing quantized pool ops to writeTensorwiseQuantizedPool.
2170 if (isQuantizedElemKind(node->getInput().getElementType())) {
2171 return writeTensorwiseQuantizedPool(node, op, graph, reporter);
2172 }
2173
  // Loading a pooling op from ONNX wraps it in a "sandwich" of Transpose
  // nodes for the Input and Result. Graph optimizations may later remove or
  // replace those Transpose nodes. When saving the graph back to ONNX, keep in
  // mind that loading it again will recreate the Transpose sandwich. The steps
  // are:
  // * Input: if the expected NCHW2NHWC Transpose is found, drop it; otherwise
  //   create a "mirror" Transpose (NHWC2NCHW) for the Input.
  // * Result: if an NHWC2NCHW Transpose user is found, drop it; otherwise
  //   create a "mirror" Transpose (NCHW2NHWC).
2185 assert((node->getLayout() == NHWC || node->getLayout() == NTHWC) &&
2186 "can only write NHWC (2D) or NTHWC (3D) Pool ops");
2187
2188 auto *proto = graph.add_node();
2189
2190 // Use the output of transpose node.
2191 if (!outputKindToProto(Kinded::Kind::TransposeNodeKind, node, graph, proto)) {
2192 // Apparently Result Transpose has been removed, add NCHW2NHWC Transpose.
2193 writeTransposeResult(node, proto, graph);
2194 }
2195
2196 // Add dictionary entries.
2197 addValueAttribute(proto, "kernel_shape", node->getKernels());
2198 addValueAttribute(proto, "strides", node->getStrides());
2199 addValueAttribute(proto, "pads", node->getPads());
2200
2201 if (auto *APN = llvm::dyn_cast<AvgPoolNode>(node)) {
2202 addValueAttribute(proto, "count_include_pad", APN->getCountIncludePads());
2203 }
2204
2205 const Node *input = node->getInput().getNode();
2206 if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(input)) {
2207 proto->add_input(TN->getInput().getNode()->getName().str());
2208 reporter.insert(TN);
2209 } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(input)) {
2210 proto->add_input(RSN->getInput().getNode()->getName().str());
2211 reporter.insert(RSN);
2212 } else {
2213 writeTransposeInput(node, input, proto, graph);
2214 }
2215
2216 proto->set_name(node->getName().str());
2217 proto->set_op_type(op);
2218}
2219} // namespace
2220
2221Error ONNXModelWriter::writeAvgPool(const AvgPoolNode *node, GraphType &graph) {
2222 writePool(node, "AveragePool", graph, reportedNodes_);
2223 return Error::success();
2224}
2225
2226Error ONNXModelWriter::writeMaxPool(const MaxPoolNode *node, GraphType &graph) {
2227 writePool(node, "MaxPool", graph, reportedNodes_);
2228 return Error::success();
2229}
2230
2231Error ONNXModelWriter::writeConvolution3D(const Convolution3DNode *node,
2232 GraphType &graph) {
  // Loading a 3D convolution from ONNX wraps it in a "sandwich" of Transpose
  // nodes for the Input, Weights, and Result. Graph optimizations may later
  // remove or replace those Transpose nodes. When saving the graph back to
  // ONNX, keep in mind that loading it again will recreate the Transpose
  // sandwich. The steps are:
  // * Input and Weights: if the expected NCTHW2NTHWC Transpose is found, drop
  //   it; otherwise create a "mirror" Transpose (NTHWC2NCTHW) for the
  //   corresponding Input and/or Weights.
  // * Result: if an NTHWC2NCTHW Transpose user is found, drop it; otherwise
  //   create a "mirror" Transpose (NCTHW2NTHWC).
2244 // assert(node->getLayout() == NTHWC && "can only write NTHWC Convolutions");
2245
2246 // Delegate writing quantized Convs to writeTensorwiseQuantizedConvolution.
2247 if (isQuantizedElemKind(node->getInput().getElementType())) {
2248 return MAKE_ERR("Not implemented");
2249 // return writeTensorwiseQuantizedConvolution(node, graph);
2250 }
2251
2252 auto *proto = graph.add_node();
2253
2254 // Use the output of transpose node.
2255 if (!outputKindToProto(Kinded::Kind::TransposeNodeKind, node, graph, proto)) {
2256 // Apparently Result Transpose has been removed, add NCTHW2NTHWC Transpose.
2257 writeTransposeResult(node, proto, graph, NCTHW2NTHWC);
2258 }
2259
2260 // Add dictionary entries.
2261 addValueAttribute(proto, "kernel_shape", node->getKernels());
2262 addValueAttribute(proto, "strides", node->getStrides());
2263 addValueAttribute(proto, "pads", node->getPads());
2264 addValueAttribute(proto, "group", node->getGroup());
2265 // addValueAttribute(proto, "dilations", node->getDilation());
2266
2267 const Node *input = node->getInput().getNode();
2268 if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(input)) {
2269 proto->add_input(TN->getInput().getNode()->getName().str());
2270 reportedNodes_.insert(TN);
2271 } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(input)) {
2272 proto->add_input(RSN->getInput().getNode()->getName().str());
2273 reportedNodes_.insert(RSN);
2274 } else {
2275 writeTransposeInput(node, input, proto, graph, NTHWC2NCTHW);
2276 }
2277
2278 const Node *filter = node->getFilter().getNode();
2279 if (const TransposeNode *TN = llvm::dyn_cast<TransposeNode>(filter)) {
2280 proto->add_input(TN->getInput().getNode()->getName().str());
2281 reportedNodes_.insert(TN);
2282 } else if (const ReshapeNode *RSN = llvm::dyn_cast<ReshapeNode>(filter)) {
2283 proto->add_input(RSN->getInput().getNode()->getName().str());
2284 reportedNodes_.insert(RSN);
2285 } else {
2286 writeTransposeInput(node, filter, proto, graph, NTHWC2NCTHW);
2287 }
2288
2289 proto->add_input(node->getBias().getNode()->getName().str());
2290
2291 proto->set_name(node->getName().str());
2292 proto->set_op_type("Conv");
2293
2294 return Error::success();
2295}
2296
2297Error ONNXModelWriter::writeSpaceToDepth(const SpaceToDepthNode *node,
2298 GraphType &graph) {
2299 auto *proto = graph.add_node();
2300
2301 // Find input from Transpose node
2302 const TransposeNode *TN =
2303 llvm::dyn_cast<TransposeNode>(node->getInput().getNode());
2304 RETURN_ERR_IF_NOT(TN,
2305 "Can't find Transpose Node as part of SpaceToDepth Node.");
2306 proto->add_input(TN->getInput().getNode()->getName().str());
2307 reportedNodes_.insert(TN);
2308
2309 proto->set_name(node->getName().str());
2310 proto->set_op_type("SpaceToDepth");
2311 // Add dictionary entries.
2312 addValueAttribute(proto, "blocksize", node->getBlockSize());
2313
2314 // Use the output of transpose node, if any.
2315 if (!outputKindToProto(Kinded::Kind::TransposeNodeKind, node, graph, proto)) {
2316 outputsToProto(node, graph, proto);
2317 }
2318 return Error::success();
2319}
2320
2321Error ONNXModelWriter::writeChannelShuffle(const ChannelShuffleNode *node,
2322 GraphType &graph) {
2323 auto *proto = graph.add_node();
2324 // Add dictionary entries.
2325 addValueAttribute(proto, "group", node->getGroup());
2326 addValueAttribute(proto, "kernel", node->getKernel());
2327
2328 return writeAllWithNode("ChannelShuffle", node, graph, proto);
2329}
2330
2331Error ONNXModelWriter::writeQuantizationProfile(
2332 const QuantizationProfileNode *node, GraphType &graph) {
2333 auto *proto = graph.add_node();
2334 // Add dictionary entries.
2335 addValueAttribute(proto, "name", node->getProfiledNodeName());
2336 addValueAttribute(proto, "number", node->getProfiledOutputNumber());
2337
2338 return writeAllWithNode("QuantizationProfile", node, graph, proto);
2339}
2340
2341Error ONNXModelWriter::writeTraceEvent(const TraceEventNode *node,
2342 GraphType &graph) {
2343 auto *proto = graph.add_node();
2344 // Add dictionary entries.
2345 addValueAttribute(proto, "name", node->getEventName());
2346 addValueAttribute(proto, "type", node->getEventType());
2347 addValueAttribute(proto, "index", node->getIndex());
2348
2349 return writeAllWithNode("TraceEvent", node, graph, proto);
2350}
2351
2352Error ONNXModelWriter::writeInsertTensor(const InsertTensorNode *node,
2353 GraphType &graph) {
2354 auto *proto = graph.add_node();
2355 // Add dictionary entries.
2356 addValueAttribute(proto, "start", node->getStart());
2357 addValueAttribute(proto, "count", node->getCount());
2358 addValueAttribute(proto, "axis", node->getAxis());
2359
2360 return writeAllWithNode("InsertTensor", node, graph, proto);
2361}
2362
2363Error ONNXModelWriter::writeSplat(const SplatNode *node, GraphType &graph) {
  // Converting the scalar to a tensor is required.
2365 Tensor tensor(ElemKind::FloatTy, node->getResult().dims());
2366 auto handle = tensor.getHandle<>();
2367 float value = node->getValue();
2368 for (size_t b = 0, e = tensor.size(); b < e; ++b) {
2369 handle.raw(b) = value;
2370 }
2371
2372 auto *tensorProto = addInitializer(graph);
2373
2374 findOutputNames(node, graph, [&](const std::string &name) {
2375 tensorProto->set_name(name);
2376 });
2377
2378 writeTensor(tensor, tensorProto, useGlowCustomOps_);
2379 reportedNodes_.insert(node);
2380
2381 return Error::success();
2382}
2383
2384Error ONNXModelWriter::writeTouch(const TouchNode *node, GraphType &graph) {
2385 auto *proto = graph.add_node();
2386 return writeAllWithNode("Touch", node, graph, proto);
2387}
2388
// Exports arithmetic nodes which may involve broadcasting;
// any Broadcast node feeding them is unwound.
2391#define ARITHMETIC_NODE_WRITER(ONNXNAME, GLOWNAME) \
2392 Error ONNXModelWriter::write##GLOWNAME(const GLOWNAME##Node *node, \
2393 GraphType &graph) { \
2394 return writeArithmetic(#ONNXNAME, node, graph, reportedNodes_, \
2395 hasMultidirectionalBroadcast(#ONNXNAME)); \
2396 }
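// For reference, a sketch of one expansion of the macro above (for Add):
//
//   Error ONNXModelWriter::writeAdd(const AddNode *node, GraphType &graph) {
//     return writeArithmetic("Add", node, graph, reportedNodes_,
//                            hasMultidirectionalBroadcast("Add"));
//   }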
2397
2398ARITHMETIC_NODE_WRITER(Add, Add);
2399ARITHMETIC_NODE_WRITER(Sub, Sub);
2400ARITHMETIC_NODE_WRITER(Mul, Mul);
2401ARITHMETIC_NODE_WRITER(Div, Div);
2402ARITHMETIC_NODE_WRITER(Equal, CmpEQ)
2403ARITHMETIC_NODE_WRITER(And, And)
2404ARITHMETIC_NODE_WRITER(Or, Or)
2405ARITHMETIC_NODE_WRITER(Xor, Xor)
2406ARITHMETIC_NODE_WRITER(Less, CmpLT)
2407
// Ops that ONNX doesn't have.
2409ARITHMETIC_NODE_WRITER(CmpLTE, CmpLTE)
2410ARITHMETIC_NODE_WRITER(FloorDiv, FloorDiv);
2411ARITHMETIC_NODE_WRITER(Fmod, Fmod)
2412ARITHMETIC_NODE_WRITER(BitwiseAnd, BitwiseAnd)
2413ARITHMETIC_NODE_WRITER(BitwiseOr, BitwiseOr)
2414ARITHMETIC_NODE_WRITER(BitwiseXor, BitwiseXor)
2415#undef ARITHMETIC_NODE_WRITER
2416
2417// Default exporting algorithm.
2418#define DEF_ALL_WRITER_NODE(NAME) \
2419 Error ONNXModelWriter::write##NAME(const NAME##Node *node, \
2420 GraphType &graph) { \
2421 return writeAll(#NAME, node, graph); \
2422 }
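// For reference, a sketch of one expansion of the macro above (for Not):
//
//   Error ONNXModelWriter::writeNot(const NotNode *node, GraphType &graph) {
//     return writeAll("Not", node, graph);
//   }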
2423
2424// ONNX nodes with default exporting algorithm.
2425DEF_ALL_WRITER_NODE(Not)
2426DEF_ALL_WRITER_NODE(Abs)
2427DEF_ALL_WRITER_NODE(Neg)
2428DEF_ALL_WRITER_NODE(Floor)
2429DEF_ALL_WRITER_NODE(Sign)
2430DEF_ALL_WRITER_NODE(Ceil)
2431DEF_ALL_WRITER_NODE(Round)
2432DEF_ALL_WRITER_NODE(Sqrt)
2433DEF_ALL_WRITER_NODE(Rsqrt)
2434DEF_ALL_WRITER_NODE(Reciprocal)
2435DEF_ALL_WRITER_NODE(Sin)
2436DEF_ALL_WRITER_NODE(Cos)
2437DEF_ALL_WRITER_NODE(LSTMUnit)
2438DEF_ALL_WRITER_NODE(DynamicQuantizedFullyConnected)
2439DEF_ALL_WRITER_NODE(DynamicRowwiseQuantizedFullyConnected)
2440DEF_ALL_WRITER_NODE(Erf)
2441DEF_ALL_WRITER_NODE(Min)
2442DEF_ALL_WRITER_NODE(Max)
2443DEF_ALL_WRITER_NODE(Log)
2444DEF_ALL_WRITER_NODE(Asin)
2445DEF_ALL_WRITER_NODE(Acos)
2446DEF_ALL_WRITER_NODE(Atan)
2447DEF_ALL_WRITER_NODE(Exp)
2448DEF_ALL_WRITER_NODE(Relu)
2449DEF_ALL_WRITER_NODE(LeakyRelu)
2450DEF_ALL_WRITER_NODE(Gelu)
2451DEF_ALL_WRITER_NODE(Tanh)
2452DEF_ALL_WRITER_NODE(IsNaN)
2453DEF_ALL_WRITER_NODE(Sigmoid)
2454DEF_ALL_WRITER_NODE(Swish)
2455DEF_ALL_WRITER_NODE(SoftPlus)
2456DEF_ALL_WRITER_NODE(LengthsSum)
2457DEF_ALL_WRITER_NODE(BatchOneHot)
2458DEF_ALL_WRITER_NODE(LengthsToRanges)
2459DEF_ALL_WRITER_NODE(SparseLengthsSum)
2460DEF_ALL_WRITER_NODE(SparseLengthsWeightedSum)
2461DEF_ALL_WRITER_NODE(EmbeddingBag)
2462DEF_ALL_WRITER_NODE(Embedding)
2463DEF_ALL_WRITER_NODE(BitwiseNot)
2464DEF_ALL_WRITER_NODE(GaussianFill)
2465DEF_ALL_WRITER_NODE(NonZero)
2466DEF_ALL_WRITER_NODE(BatchSparseToDense)
2467DEF_ALL_WRITER_NODE(FillExamplesWithIndicator)
2468
2469// Glow nodes with default exporting algorithm.
2470DEF_ALL_WRITER_NODE(CmpNEQ)
2471DEF_ALL_WRITER_NODE(BatchedAdd)
2472DEF_ALL_WRITER_NODE(BatchedMul)
2473DEF_ALL_WRITER_NODE(Dequantize)
2474DEF_ALL_WRITER_NODE(Regression)
2475DEF_ALL_WRITER_NODE(RowwiseQuantizedSparseLengthsWeightedSum)
2476DEF_ALL_WRITER_NODE(FusedRowwiseQuantizedSparseLengthsSum)
2477DEF_ALL_WRITER_NODE(EmbeddingBagByteRowwiseOffsets)
2478DEF_ALL_WRITER_NODE(FusedRowwiseQuantizedSparseLengthsWeightedSum)
2479DEF_ALL_WRITER_NODE(NonMaxSuppression)
2480DEF_ALL_WRITER_NODE(TFLiteDetectionPostProcess)
2481DEF_ALL_WRITER_NODE(HardSwish)
2482DEF_ALL_WRITER_NODE(ConvTranspose)
2483DEF_ALL_WRITER_NODE(Logit)
2484DEF_ALL_WRITER_NODE(Truncate)
2485DEF_ALL_WRITER_NODE(BatchedUnaryEmbeddingsBags)
2486DEF_ALL_WRITER_NODE(IntNBitSplitEmbeddingBags)
2487DEF_ALL_WRITER_NODE(IntNBitSplitEmbeddingWeightedBags)
2488
2489Error ONNXModelWriter::writeClip(const ClipNode *node, GraphType &graph) {
2490 auto *proto = graph.add_node();
2491 addValueAttribute(proto, "min", node->getMin());
2492 addValueAttribute(proto, "max", node->getMax());
2493 return writeAllWithNode("Clip", node, graph, proto);
2494}
2495
2496Error ONNXModelWriter::writeConvertTo(const ConvertToNode *node,
2497 GraphType &graph) {
2498 auto *proto = graph.add_node();
2499
2500 // Add dictionary entries.
2501 TensorType ttype;
2502 for (auto d : node->getResult().dims()) {
2503 ttype.add_dims(d);
2504 }
2505 ttype.set_data_type(convertType(*node->getResult().getType()));
2506 auto *attr = proto->add_attribute();
2507 attr->set_name("shape");
2508 attr->mutable_t()->CopyFrom(ttype);
2509
2510 return writeAllWithNode("ConvertTo", node, graph, proto);
2511}
2512
2513Error ONNXModelWriter::writeSelect(const SelectNode *node, GraphType &graph) {
2514 auto *proto = graph.add_node();
2515 // Add dictionary entries.
2516 addValueAttribute(proto, "shape", node->getResult().dims());
2517
2518 return writeAllWithNode("Select", node, graph, proto);
2519}
2520
2521Error ONNXModelWriter::writeQuantize(const QuantizeNode *node,
2522 GraphType &graph) {
2523 auto *proto = graph.add_node();
2524 auto outTy = node->getResult().getType();
2525
2526 // Add dictionary entries.
2527 addValueAttribute(proto, "scale", outTy->getScale());
2528 addValueAttribute(proto, "offset", outTy->getOffset());
2529 addValueAttribute(proto, "elem_kind", outTy->getElementName());
2530
2531 return writeAllWithNode("Quantize", node, graph, proto);
2532}
2533
2534Error ONNXModelWriter::writeIntLookupTable(const IntLookupTableNode *node,
2535 GraphType &graph) {
2536 auto *proto = graph.add_node();
2537 // Add dictionary entries.
2538 addValueAttribute(proto, "shape", node->getResult().dims());
2539 NodeValue mapping = node->getMapping();
2540 if (Constant *c = llvm::dyn_cast<Constant>(mapping.getNode())) {
2541 auto handle = c->getHandle<int8_t>();
2542 auto begin = &handle.raw(0);
2543 addValueAttribute(
2544 proto, "values",
2545 llvm::ArrayRef<int8_t>(begin, begin + handle.actualSize()));
2546 } else {
2547 return MAKE_ERR("Mapping must be a constant type.");
2548 }
2549
2550 return writeAllWithNode("IntLookupTable", node, graph, proto);
2551}
2552
2553Error ONNXModelWriter::writeLookupTable(const LookupTableNode *node,
2554 GraphType &graph) {
2555 auto *proto = graph.add_node();
2556 // Add dictionary entries.
2557 addValueAttribute(proto, "shape", node->getResult().dims());
2558 NodeValue table = node->getTable();
2559 if (Constant *c = llvm::dyn_cast<Constant>(table.getNode())) {
2560 auto handle = c->getHandle<int8_t>();
2561 auto begin = &handle.raw(0);
2562 addValueAttribute(
2563 proto, "values",
2564 llvm::ArrayRef<int8_t>(begin, begin + handle.actualSize()));
2565 } else {
    return MAKE_ERR("Table must be a constant.");
2567 }
2568
2569 return writeAllWithNode("LookupTable", node, graph, proto);
2570}
2571
2572Error ONNXModelWriter::writeLengthsRangeFill(const LengthsRangeFillNode *node,
2573 GraphType &graph) {
2574 auto *proto = graph.add_node();
2575 // Add dictionary entries.
2576 addValueAttribute(proto, "size", node->getResult().dims()[0]);
2577
2578 return writeAllWithNode("LengthsRangeFill", node, graph, proto);
2579}
2580
2581Error ONNXModelWriter::writeRescaleQuantized(const RescaleQuantizedNode *node,
2582 GraphType &graph) {
2583 auto *proto = graph.add_node();
2584 auto outTy = node->getResult().getType();
2585 // Add dictionary entries.
2586 addValueAttribute(proto, "scale", outTy->getScale());
2587 addValueAttribute(proto, "offset", outTy->getOffset());
2588
2589 return writeAllWithNode("RescaleQuantized", node, graph, proto);
2590}
2591
2592Error ONNXModelWriter::writeGemm(const GemmNode *node, GraphType &graph) {
2593 auto *proto = graph.add_node();
2594 proto->set_name(node->getName().str());
2595 proto->set_op_type("Gemm");
2596
2597 proto->add_input(node->getA().getNode()->getName().str());
2598 proto->add_input(node->getB().getNode()->getName().str());
2599 if (node->getC().getNode()) {
2600 proto->add_input(node->getC().getNode()->getName().str());
2601 }
2602
2603 addValueAttribute(proto, "alpha", node->getAlpha());
2604 addValueAttribute(proto, "beta", node->getBeta());
2605 addValueAttribute(proto, "transA", node->getTransposeA());
2606 addValueAttribute(proto, "transB", node->getTransposeB());
2607
2608 outputsToProto(node, graph, proto);
2609 return Error::success();
2610}
2611
2612Error ONNXModelWriter::writeFullyConnected(const FullyConnectedNode *node,
2613 GraphType &graph) {
2614 auto *proto = graph.add_node();
2615 proto->set_name(node->getName().str());
2616 proto->set_op_type("FullyConnected");
2617
2618 if (node->getInput().dims().size() != 2) {
    return MAKE_ERR("Only 2D inputs are supported.");
2620 }
2621 proto->add_input(node->getInput().getNode()->getName().str());
2622 proto->add_input(node->getWeights().getNode()->getName().str());
2623 proto->add_input(node->getBias().getNode()->getName().str());
2624 outputsToProto(node, graph, proto);
2625 return Error::success();
2626}
2627
2628Error ONNXModelWriter::writeVectorNorm(const VectorNormNode *node,
2629 GraphType &graph) {
2630 auto *proto = graph.add_node();
2631
2632 // Add dictionary entries.
2633 addValueAttribute(proto, "axis", node->getAxis());
2634
2635 proto->set_name(node->getName().str());
2636 proto->set_op_type("VectorNorm");
2637 inputsToProto(node, proto);
2638
  // Currently only p = 2 is supported (the Frobenius / l2 norm).
2640 addValueAttribute(proto, "p", node->getP());
2641
2642 outputsToProto(node, graph, proto);
2643
2644 return Error::success();
2645}
2646
2647Error ONNXModelWriter::writeRowwiseQuantizedFullyConnected(
2648 const RowwiseQuantizedFullyConnectedNode *node, GraphType &graph) {
2649 auto *proto = graph.add_node();
2650
2651 // Add dictionary entries.
2652 addValueAttribute(
2653 proto, "out_scale",
2654 node->getType(RowwiseQuantizedFullyConnectedNode::ResultIdx)->getScale());
2655 addValueAttribute(proto, "out_offset",
2656 node->getType(RowwiseQuantizedFullyConnectedNode::ResultIdx)
2657 ->getOffset());
2658
2659 return writeAllWithNode("RowwiseQuantizedFullyConnected", node, graph, proto);
2660}
2661
2662Error ONNXModelWriter::writeTile(const TileNode *node, GraphType &graph) {
2663 auto *proto = graph.add_node();
2664
2665 // unwind Tile
2666 std::vector<size_t> repeats;
2667 const TileNode *tile = unwindTile(node, &repeats, reportedNodes_);
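  // Glow expresses multi-axis tiling as a chain of Tile nodes (one per axis);
  // unwindTile collapses that chain and collects the per-axis repeat counts so
  // that a single ONNX Tile driven by a "repeats" Constant can be emitted.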
2668
2669 proto->set_name("Tile");
2670 proto->set_op_type("Tile");
2671 // Use inputs from top tile.
2672 inputsToProto(tile, proto);
2673 // Use outputs from bottom tile.
2674 outputsToProto(node, graph, proto);
2675
2676 // Add node indices
2677 auto *indices = graph.add_node();
2678 indices->set_name("Constant");
2679 indices->set_op_type("Constant");
2680 indices->add_output(tile->getName().str() + "_indices");
2681
2682 unsigned_t numDims = tile->getInput().dims().size();
2683
2684 DCHECK(repeats.size() == numDims);
2685
2686 // Add Tensor type attribute.
2687 addValueAttribute(indices, "value", llvm::makeArrayRef(repeats));
2688 // Add indices as input to the Tile node
2689 proto->add_input(tile->getName().str() + "_indices");
2690
2691 return Error::success();
2692}
2693
2694Error ONNXModelWriter::writeCumSum(const CumSumNode *node, GraphType &graph) {
2695 auto *proto = graph.add_node();
2696 // Add dictionary entries.
2697 addValueAttribute(proto, "axis", 0);
2698 addValueAttribute(proto, "exclusive", node->getExclusive());
2699 addValueAttribute(proto, "reverse", node->getReverse());
2700
2701 return writeAllWithNode("CumSum", node, graph, proto);
2702}
2703
2704Error ONNXModelWriter::writeScatterData(const ScatterDataNode *node,
2705 GraphType &graph) {
2706 auto *proto = graph.add_node();
2707
2708 return writeAllWithNode("ScatterData", node, graph, proto);
2709}
2710
// Glow nodes that are not supported for export.
2712#define DEF_UNSUPPORTED_STORAGE(NAME) \
2713 Error ONNXModelWriter::write##NAME(const NAME *node, GraphType &) { \
2714 return writeUnexpectedKind(node); \
2715 }
2716
2717// Helper nodes.
2718DEF_UNSUPPORTED_STORAGE(Placeholder)
2719DEF_UNSUPPORTED_STORAGE(Constant)
2720DEF_UNSUPPORTED_STORAGE(Storage)
2721
// Glow nodes that are not supported for export.
2723#define DEF_UNSUPPORTED_NODE(NAME) \
2724 Error ONNXModelWriter::write##NAME(const NAME##Node *node, GraphType &) { \
2725 return writeUnexpectedKind(node); \
2726 }
2727
2728DEF_UNSUPPORTED_NODE(BatchedPairwiseDotProduct)
2729DEF_UNSUPPORTED_NODE(Broadcast)
2730DEF_UNSUPPORTED_NODE(SGD)
2731DEF_UNSUPPORTED_NODE(SparseLabelSplit)
2732// Artificial node.
2733DEF_UNSUPPORTED_NODE(Save)
2734DEF_UNSUPPORTED_NODE(ExternalFunctionCall)
2735// Gradient nodes.
2736DEF_UNSUPPORTED_NODE(AddGrad)
2737DEF_UNSUPPORTED_NODE(DivGrad)
2738DEF_UNSUPPORTED_NODE(MulGrad)
2739DEF_UNSUPPORTED_NODE(SubGrad)
2740DEF_UNSUPPORTED_NODE(ReluGrad)
2741DEF_UNSUPPORTED_NODE(TanhGrad)
2742DEF_UNSUPPORTED_NODE(AvgPoolGrad)
2743DEF_UNSUPPORTED_NODE(MaxPoolGrad)
2744DEF_UNSUPPORTED_NODE(SigmoidGrad)
2745DEF_UNSUPPORTED_NODE(SoftMaxGrad)
2746DEF_UNSUPPORTED_NODE(LogSoftMaxGrad)
2747DEF_UNSUPPORTED_NODE(RegressionGrad)
2748DEF_UNSUPPORTED_NODE(ConvolutionGrad)
2749DEF_UNSUPPORTED_NODE(CrossEntropyLoss)
2750DEF_UNSUPPORTED_NODE(Convolution3DGrad)
2751DEF_UNSUPPORTED_NODE(FullyConnectedGrad)
2752DEF_UNSUPPORTED_NODE(CrossEntropyLossGrad)
2753DEF_UNSUPPORTED_NODE(BatchNormalizationGrad)
2754DEF_UNSUPPORTED_NODE(SparseLengthsSumGrad)
2755DEF_UNSUPPORTED_NODE(SparseLengthsWeightedSumGrad)
2756DEF_UNSUPPORTED_NODE(SigmoidCrossEntropyWithLogits)
2757DEF_UNSUPPORTED_NODE(LocalResponseNormalizationGrad)
2758DEF_UNSUPPORTED_NODE(AdaptiveAvgPoolGrad)
2759DEF_UNSUPPORTED_NODE(BatchedPairwiseDotProductGrad)
2760
2761// Include backend-specific ONNX model writers.
2762#include "glow/ONNXModelWriterIncludes.h"
2763
2764Error ONNXModelWriter::writeGlowCustomOperator(const Node *node,
2765 GraphType &graph) {
2766 ONNX_NAMESPACE::NodeProto *opProto = nullptr;
2767
2768 switch (node->getKind()) {
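  // The generated header included below is expected to provide one `case` per
  // exportable node kind, each filling in `opProto` in the Glow custom-op
  // format.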
2769#include "glow/AutoGenNodesExport.h"
2770 default:
2771 return MAKE_ERR(
2772 strFormat("Unhandled Node for export: %s", node->getName().data()));
2773 }
2774 RETURN_ERR_IF_NOT(opProto, "Did not have valid opProto.");
2775
2776 // If dumping a DAG then add partition names to each op that's written.
2777 if (dagMode_ && !isWritingConstFoldSubgraph()) {
2778 addValueAttribute(opProto, "partitionName", F_->getName().str());
2779 }
2780
2781 // Check if there is backendSpecificNodeInfo for node, and if so include it.
2782 auto itF = backendSpecificNodeInfo_.find(node->getParent());
2783 if (itF != backendSpecificNodeInfo_.end()) {
2784 auto itN = itF->second.find(node);
2785 if (itN != itF->second.end()) {
2786 // We found backend-specific node info, so add it to the opProto.
2787 for (const auto &optValPair : itN->second) {
2788 addValueAttribute(opProto,
2789 std::string(nodeOptSignifier) + "_" +
2790 optValPair.getKey().data(),
2791 optValPair.getValue());
2792 }
2793 }
2794 }
2795
2796 return Error::success();
2797}
2798
2799bool ONNXModelWriter::hasMultidirectionalBroadcast(
2800 const llvm::StringRef typeName) {
2801 // Before opset 7, broadcasting was unidirectional.
2802 if (opsetVersion_ > 6) {
2803 // List of ops that support multidirectional broadcast can be found at
2804 // https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
2805 if ((typeName == "Add") || (typeName == "Sub") || (typeName == "Mul") ||
2806 (typeName == "Div") || (typeName == "Equal") ||
2807 (typeName == "Greater") || (typeName == "Less") ||
2808 (typeName == "Max") || (typeName == "Mean") || (typeName == "Min") ||
2809 (typeName == "Or") || (typeName == "Pow") || (typeName == "Sum") ||
2810 (typeName == "Xor")) {
2811 return true;
2812 }
2813 }
2814 return false;
2815}
2816
2817} // namespace glow
2818