1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5// ATTENTION: The code in this file is highly EXPERIMENTAL.
6// Adventurous users should note that the APIs will probably change.
7
8#include "onnx/common/ir_pb_converter.h"
9#include <sstream>
10
11namespace ONNX_NAMESPACE {
12
// Part 1: convert ONNX Protobuf to IR

// Forward declaration (defined below). `nested` marks gp as an attribute
// subgraph, where undefined value references may be outer-scope captures.
std::unique_ptr<Graph> graphProtoToGraph(const GraphProto& gp, bool nested, const int ir_version = IR_VERSION);
15
// Deserialize a TensorProto into an IR Tensor: dims, element type, the typed
// data field matching the element type, raw_data (if set), and the optional
// name and segment metadata. Throws (fail_convert) on UNDEFINED data type.
Tensor tensorProtoToTensor(const ONNX_NAMESPACE::TensorProto& tp) {
  Tensor ret;

  ret.sizes().reserve(tp.dims_size());
  for (int i = 0; i < tp.dims_size(); i++) {
    ret.sizes().push_back(tp.dims(i));
  }

  ret.elem_type() = tp.data_type();
  // Copy the payload out of whichever typed repeated field corresponds to
  // the element type. A field may legitimately be empty when the payload is
  // stored in raw_data instead (carried over below).
  switch (tp.data_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
    case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: {
      ret.floats().reserve(tp.float_data_size());
      for (int i = 0; i < tp.float_data_size(); i++) {
        ret.floats().push_back(tp.float_data(i));
      }
      break;
    }
    // All sub-32-bit types (plus int32) are stored widened in int32_data,
    // per the TensorProto specification.
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
      ret.int32s().reserve(tp.int32_data_size());
      for (int i = 0; i < tp.int32_data_size(); i++) {
        ret.int32s().push_back(tp.int32_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
      ret.int64s().reserve(tp.int64_data_size());
      for (int i = 0; i < tp.int64_data_size(); i++) {
        ret.int64s().push_back(tp.int64_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
      ret.uint64s().reserve(tp.uint64_data_size());
      for (int i = 0; i < tp.uint64_data_size(); i++) {
        ret.uint64s().push_back(tp.uint64_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
    case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: {
      ret.doubles().reserve(tp.double_data_size());
      for (int i = 0; i < tp.double_data_size(); i++) {
        ret.doubles().push_back(tp.double_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_STRING: {
      ret.strings().reserve(tp.string_data_size());
      for (int i = 0; i < tp.string_data_size(); i++) {
        ret.strings().push_back(tp.string_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
      fail_convert("Unknown tensor data type");
  }

  // The payload may alternatively be serialized in raw_data rather than the
  // typed <type>_data fields copied above; preserve it verbatim when set.
  if (tp.has_raw_data()) {
    ret.set_raw_data(tp.raw_data());
  }

  if (tp.has_name()) {
    ret.setName(tp.name());
  }
  if (tp.has_segment()) {
    ret.set_segment_begin_and_end(tp.segment().begin(), tp.segment().end());
  }
  return ret;
}
96
97void convertAttribute(const ONNX_NAMESPACE::AttributeProto& ap, Node* n, const int ir_version = IR_VERSION) {
98 Symbol sym = Symbol(ap.name());
99 switch (ap.type()) {
100 case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
101 n->f_(sym, ap.f());
102 break;
103 case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS: {
104 std::vector<double> floats;
105 floats.reserve(ap.floats_size());
106 for (int i = 0; i < ap.floats_size(); i++) {
107 floats.push_back(ap.floats(i));
108 }
109 n->fs_(sym, std::move(floats));
110 break;
111 }
112 case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
113 n->i_(sym, ap.i());
114 break;
115 case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS: {
116 std::vector<int64_t> ints;
117 ints.reserve(ap.ints_size());
118 for (int i = 0; i < ap.ints_size(); i++) {
119 ints.push_back(ap.ints(i));
120 }
121 n->is_(sym, std::move(ints));
122 break;
123 }
124 case ONNX_NAMESPACE::AttributeProto_AttributeType_STRING:
125 n->s_(sym, ap.s());
126 break;
127 case ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS: {
128 std::vector<std::string> strings;
129 strings.reserve(ap.strings_size());
130 for (int i = 0; i < ap.strings_size(); i++) {
131 strings.push_back(ap.strings(i));
132 }
133 n->ss_(sym, std::move(strings));
134 break;
135 }
136 case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR:
137 n->t_(sym, tensorProtoToTensor(ap.t()));
138 break;
139 case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS: {
140 std::vector<Tensor> tensors;
141 tensors.reserve(ap.tensors_size());
142 for (int i = 0; i < ap.tensors_size(); i++) {
143 tensors.push_back(tensorProtoToTensor(ap.tensors(i)));
144 }
145 n->ts_(sym, std::move(tensors));
146 break;
147 }
148 case ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTO:
149 n->tp_(sym, ap.tp());
150 break;
151 case ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTOS: {
152 std::vector<TypeProto> types;
153 types.reserve(ap.type_protos_size());
154 for (int i = 0; i < ap.type_protos_size(); i++) {
155 types.push_back(ap.type_protos(i));
156 }
157 n->tps_(sym, std::move(types));
158 break;
159 }
160 case ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH:
161 n->g_(sym, graphProtoToGraph(ap.g(), true, ir_version));
162 break;
163 case ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS: {
164 std::vector<std::shared_ptr<Graph>> graphs;
165 graphs.reserve(ap.graphs_size());
166 for (int i = 0; i < ap.graphs_size(); i++) {
167 graphs.push_back(graphProtoToGraph(ap.graphs(i), true, ir_version));
168 }
169 n->gs_(sym, std::move(graphs));
170 break;
171 }
172 case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR:
173 case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSORS:
174 fail_convert("Sparse tensors not supported.");
175 break;
176 case ONNX_NAMESPACE::AttributeProto_AttributeType_UNDEFINED:
177 fail_convert("Unknown tensor data type");
178 break;
179 }
180}
181
182void convertAttributes(ONNX_NAMESPACE::NodeProto& np, Node* n, const int ir_version = IR_VERSION) {
183 for (int i = 0; i < np.attribute_size(); i++) {
184 convertAttribute(np.attribute(i), n, ir_version);
185 }
186}
187
188std::vector<Dimension> tensorShapeProtoToDimensions(const ONNX_NAMESPACE::TensorShapeProto& tsp) {
189 std::vector<Dimension> dims;
190 dims.reserve(tsp.dim_size());
191 for (int i = 0; i < tsp.dim_size(); i++) {
192 if (tsp.dim(i).has_dim_value()) {
193 dims.emplace_back(tsp.dim(i).dim_value());
194 } else if (tsp.dim(i).has_dim_param()) {
195 dims.emplace_back(tsp.dim(i).dim_param());
196 } else {
197 // a dimension that has neither dim_value nor dim_param set
198 // represents an unknown dimension unrelated to other unknown dimensions.
199 dims.emplace_back();
200 }
201 }
202 return dims;
203}
204
205void createDummyValue(
206 std::unique_ptr<Graph>& g,
207 const std::string& name,
208 std::unordered_map<std::string, Value*>& value_by_name_of) {
209 auto* undef = g->create(kCaptured, 1);
210 g->appendNode(undef);
211 undef->outputs()[0]->setUniqueName(name);
212 value_by_name_of[name] = undef->outputs()[0];
213}
214
// Convert a GraphProto into an IR Graph. `nested` is true when gp is an
// attribute subgraph; undefined value references are then treated as
// captures from an enclosing scope (dummy Captured nodes) instead of
// errors. `ir_version` decides whether initializers may stand alone as
// values (IR >= 4) or must shadow a declared graph input.
std::unique_ptr<Graph> graphProtoToGraph(const ONNX_NAMESPACE::GraphProto& gp, bool nested, const int ir_version) {
  std::unique_ptr<Graph> g(new Graph());

  if (gp.has_name()) {
    g->setName(gp.name());
  }
  if (gp.has_doc_string()) {
    g->setDocString(gp.doc_string());
  }

  // Values are created (as in `new Value(..)`) by the Node that
  // outputs them. Therefore we initialize the Nodes and Values in
  // several stages.
  //
  // 1) add all input (to the graph) Values, owned by the sentinel Param node
  // 2) add all Nodes and their output Values, but don't initialize inputs
  // 3) initialize inputs of all Nodes
  // 4) initialize inputs of the Return sentinel node
  // 5) fill in type info for graph outputs, and register them as outputs
  // 6) fill in type info for Values from the value_info list in the graph

  // In ONNX proto land, Values are just strings. We are going to make
  // objects out of them, and equal strings must be mapped to the same
  // Value object.
  std::unordered_map<std::string, Value*> value_by_name_of;

  // We initialize Node inputs in a separate pass from the Nodes
  // themselves. To do so, we need to have access to the names of the
  // inputs.
  std::unordered_map<Node*, std::vector<std::string>> inputs_by_node;

  {
    // ONNX represents optional arguments in two ways
    // - they are simply not provided
    // - OR the empty string is passed as the input name
    // This is to handle that second case, which needs a dummy node to
    // be representable in the graph IR.
    auto* n = g->create(kUndefined, 1);
    g->appendNode(n);
    n->outputs()[0]->setUniqueName("");
    value_by_name_of[""] = n->outputs()[0];
  }

  // Stage 1: declared graph inputs become Values, carrying whatever
  // element type / shape info the proto provides.
  for (int i = 0; i < gp.input_size(); i++) {
    const auto& vip = gp.input(i);
    auto v = g->addInput();
    const auto& tensor_type = vip.type().tensor_type();
    if (tensor_type.has_elem_type()) {
      v->setElemType(tensor_type.elem_type());
    }
    if (tensor_type.has_shape()) {
      v->setSizes(tensorShapeProtoToDimensions(tensor_type.shape()));
    }
    v->setUniqueName(vip.name());
    value_by_name_of[vip.name()] = v;
  }

  // initializers should be added before all nodes,
  // otherwise getNextUnique() may conflicts with an existing initializer name.
  for (int i = 0; i < gp.initializer_size(); ++i) {
    auto init = tensorProtoToTensor(gp.initializer(i));
    // If ir_version >= 4, initializer does not have to be included in input
    // Create a Value from initializer by addInitializerNode if name does not exist in input
    // and save it into value_by_name_of for later use (node input)
    if (ir_version >= 4 && value_by_name_of.count(init.name()) == 0) {
      value_by_name_of[init.name()] = g->addInitializerAndCreateValue(init);
    } else {
      // If ir_version < 4 or the initializer exists in input
      // Simply add initializer without creating new value
      // which means it will prioritize input value over initializer value if both exist
      g->addInitializer(init);
    }
  }

  // Stage 2: create every Node and its output Values; inputs are only
  // recorded by name (in inputs_by_node) and wired up in stage 3, because a
  // node may refer to a value produced later in the proto's node list.
  for (int i = 0; i < gp.node_size(); i++) {
    auto np = gp.node(i);
    auto* n = g->create(Symbol(np.op_type()), /* num_outputs = */ np.output_size());
    g->appendNode(n);
    for (int j = 0; j < np.output_size(); j++) {
      auto out = n->outputs()[j];
      // we don't know the real type here, so that's done in a later pass
      out->setElemType(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
      out->setUniqueName(np.output(j));
      value_by_name_of[np.output(j)] = out;
    }
    convertAttributes(np, n, ir_version);
    std::vector<std::string> inputs;
    inputs.reserve(np.input_size());
    for (int j = 0; j < np.input_size(); j++) {
      inputs.push_back(np.input(j));
    }
    inputs_by_node[n] = inputs;
    if (np.has_doc_string()) {
      n->setDocString(np.doc_string());
    }
    if (np.has_name()) {
      n->setName(np.name());
    }
    if (np.has_domain()) {
      n->setDomain(np.domain());
    }
  }

  // Stage 3: resolve each node's recorded input names to Values.
  for (auto n : g->nodes()) {
    auto search = inputs_by_node.find(n);
    if (search == inputs_by_node.end()) {
      continue;
    }
    for (const auto& input : search->second) {
      if (!value_by_name_of.count(input) && nested) {
        // Undefined reference to an input in a nested block. This may be a
        // captured value. Create a dummy node that we ignore later.
        createDummyValue(g, input, value_by_name_of);
      }

      // In a non-nested graph an unresolved input name is a hard error.
      if (!value_by_name_of.count(input)) {
        std::ostringstream msg;
        msg << "Input " << input << " is undefined!";
        ONNX_THROW_EX(std::out_of_range(msg.str()));
      }
      n->addInput(value_by_name_of.at(input));
    }
  }

  // Stages 4-5: fill in graph-output type info and register the outputs.
  for (int i = 0; i < gp.output_size(); i++) {
    if (!value_by_name_of.count(gp.output(i).name()) && nested) {
      // Same captured value logic as above. We can consider outputs of a
      // graph to be "inputs" of a dummy "output" node. The same lexical
      // scoping rules are valid here, thus we need to add a dummy node
      // in the case of an undefined reference
      createDummyValue(g, gp.output(i).name(), value_by_name_of);
    }
    const auto& output_tensor_type = gp.output(i).type().tensor_type();
    if (output_tensor_type.has_elem_type()) {
      value_by_name_of[gp.output(i).name()]->setElemType(output_tensor_type.elem_type());
    }
    if (output_tensor_type.has_shape()) {
      value_by_name_of[gp.output(i).name()]->setSizes(tensorShapeProtoToDimensions(output_tensor_type.shape()));
    }
    g->registerOutput(value_by_name_of[gp.output(i).name()]);
  }

  // Stage 6: propagate type info from the value_info list onto Values.
  for (int i = 0; i < gp.value_info_size(); i++) {
    const auto& tensor_type = gp.value_info(i).type().tensor_type();
    if (!value_by_name_of.count(gp.value_info(i).name())) {
      // Ideally the model should not have a value_info whose name does not exist in the graph (unused); simply skip it
      continue;
    }
    if (tensor_type.has_elem_type()) {
      value_by_name_of[gp.value_info(i).name()]->setElemType(tensor_type.elem_type());
    }
    if (tensor_type.has_shape()) {
      value_by_name_of[gp.value_info(i).name()]->setSizes(tensorShapeProtoToDimensions(tensor_type.shape()));
    }
  }

  return g;
}
373
374std::unique_ptr<Graph> ImportModelProto(const ModelProto& mp) {
375 if (!mp.has_ir_version()) {
376 return nullptr;
377 } else if (mp.ir_version() <= 1) {
378 // ir_version=1 is not supported and ir_version=0 is illegal
379 return nullptr;
380 }
381
382 std::unique_ptr<Graph> g(graphProtoToGraph(mp.graph(), false, mp.ir_version()));
383 for (int i = 0; i < mp.opset_import_size(); i++) {
384 OpSetID new_opset_version(mp.opset_import(i).domain(), mp.opset_import(i).version());
385 g->forSelfAndEachSubGraph(
386 [&new_opset_version](Graph* graph) { graph->opset_versions_mutable().emplace_back(new_opset_version); });
387 }
388 return g;
389}
390
// Part 2: convert IR to ONNX Protobuf

// Returns the graph-wide unique name of a Value; used as its proto name.
std::string value_name(Value* n) {
  return n->uniqueName();
}

// Forward declaration: addAttribute below recurses into subgraph attributes.
void encodeGraph(GraphProto* p_g, const std::shared_ptr<Graph>& g);
397
398void encodeTensor(ONNX_NAMESPACE::TensorProto* p, const Tensor& tensor) {
399 if (tensor.hasName()) {
400 p->set_name(tensor.name());
401 }
402 if (tensor.is_segment()) {
403 ONNX_NAMESPACE::TensorProto_Segment segment;
404 segment.set_begin(tensor.segment_begin());
405 segment.set_end(tensor.segment_end());
406 p->mutable_segment()->CopyFrom(segment);
407 }
408 for (auto d : tensor.sizes()) {
409 p->add_dims(d);
410 }
411 p->set_data_type(tensor.elem_type());
412 switch (tensor.elem_type()) {
413 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
414 case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: {
415 for (float x : tensor.floats()) {
416 p->add_float_data(x);
417 }
418 break;
419 }
420 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
421 case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
422 case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
423 case ONNX_NAMESPACE::TensorProto_DataType_INT8:
424 case ONNX_NAMESPACE::TensorProto_DataType_INT16:
425 case ONNX_NAMESPACE::TensorProto_DataType_INT32:
426 case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
427 case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
428 for (int32_t x : tensor.int32s()) {
429 p->add_int32_data(x);
430 }
431 break;
432 }
433 case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
434 for (int64_t x : tensor.int64s()) {
435 p->add_int64_data(x);
436 }
437 break;
438 }
439 case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
440 case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
441 for (uint64_t x : tensor.uint64s()) {
442 p->add_uint64_data(x);
443 }
444 break;
445 }
446 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
447 case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: {
448 for (double x : tensor.doubles()) {
449 p->add_double_data(x);
450 }
451 break;
452 }
453 case ONNX_NAMESPACE::TensorProto_DataType_STRING: {
454 for (const std::string& x : tensor.strings()) {
455 p->add_string_data(x);
456 }
457 break;
458 }
459 case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
460 fail_convert("Unknown tensor data type");
461 }
462 if (tensor.is_raw_data()) {
463 p->set_raw_data(tensor.raw());
464 }
465}
466
// Append one attribute of node `n` (identified by `name`) onto NodeProto
// `n_p`, dispatching on the IR's stored attribute kind. Graph-valued
// attributes recurse through encodeGraph.
void addAttribute(ONNX_NAMESPACE::NodeProto* n_p, Node* n, Symbol name) {
  auto attr = n_p->add_attribute();
  attr->set_name(name.toString());
  switch (n->kindOf(name)) {
    case AttributeKind::f: {
      // IR stores floats widened as double; narrow back to proto float.
      attr->set_f(static_cast<float>(n->f(name)));
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT);
    } break;
    case AttributeKind::fs: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS);
      for (auto& v : n->fs(name))
        attr->add_floats(static_cast<float>(v));
    } break;
    case AttributeKind::i: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
      attr->set_i(n->i(name));
    } break;
    case AttributeKind::is: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INTS);
      for (auto& v : n->is(name))
        attr->add_ints(v);
    } break;
    case AttributeKind::s: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING);
      attr->set_s(n->s(name));
    } break;
    case AttributeKind::ss: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS);
      for (auto& v : n->ss(name))
        attr->add_strings(v);
    } break;
    case AttributeKind::t: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR);
      auto t = attr->mutable_t();
      encodeTensor(t, n->t(name));
    } break;
    case AttributeKind::ts: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS);
      for (auto& v : n->ts(name)) {
        auto t = attr->add_tensors();
        encodeTensor(t, v);
      }
    } break;
    case AttributeKind::g: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH);
      auto g = attr->mutable_g();
      encodeGraph(g, n->g(name));
    } break;
    case AttributeKind::gs: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS);
      for (auto& v : n->gs(name)) {
        auto g = attr->add_graphs();
        encodeGraph(g, v);
      }
    } break;
    case AttributeKind::tp: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTO);
      auto tp = attr->mutable_tp();
      tp->CopyFrom(n->tp(name));
    } break;
    case AttributeKind::tps: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTOS);
      for (auto& v : n->tps(name)) {
        auto tp = attr->add_type_protos();
        tp->CopyFrom(v);
      }
    } break;
  }
}
536
537void encodeTypeProtoTensorType(ONNX_NAMESPACE::TypeProto_Tensor* tensor_type, Value* n) {
538 if (n->elemType() != 0) {
539 tensor_type->set_elem_type(n->elemType());
540 }
541 if (n->has_sizes()) {
542 ONNX_NAMESPACE::TensorShapeProto* shape = tensor_type->mutable_shape();
543 for (const Dimension& d : n->sizes()) {
544 auto dim = shape->add_dim();
545 if (!d.is_unknown) {
546 if (d.is_int) {
547 dim->set_dim_value(d.dim);
548 } else {
549 dim->set_dim_param(d.param);
550 }
551 }
552 }
553 }
554}
555
556void encodeValueInfo(ONNX_NAMESPACE::ValueInfoProto* v, Value* n) {
557 v->set_name(value_name(n));
558 if (n->elemType() != 0 || n->has_sizes()) {
559 ONNX_NAMESPACE::TypeProto* t = v->mutable_type();
560 ONNX_NAMESPACE::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
561 encodeTypeProtoTensorType(tensor_type, n);
562 }
563}
564
// Serialize an IR Graph into a GraphProto: name/doc string, input/output
// value infos, nodes (skipping the Undefined/Captured sentinels), value_info
// entries for informative intermediate values, and initializers.
void encodeGraph(GraphProto* p_g, const std::shared_ptr<Graph>& g) {
  ONNX_ASSERT(p_g != nullptr);

  if (g->has_name()) {
    p_g->set_name(g->name());
  }

  if (g->has_doc_string()) {
    p_g->set_doc_string(g->docString());
  }

  for (auto input : g->inputs()) {
    ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_input();
    encodeValueInfo(v, input);
  }
  for (auto output : g->outputs()) {
    ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_output();
    encodeValueInfo(v, output);
  }

  // Used below to avoid emitting a value_info entry for values already
  // described in the graph's output list.
  std::unordered_set<Value*> graph_outputs(g->outputs().begin(), g->outputs().end());

  for (auto node : g->nodes()) {
    if (node->kind() == kUndefined || node->kind() == kCaptured) {
      // Undefined nodes are used to represent optional inputs that are not
      // provided.
      continue;
    }
    auto p_n = p_g->add_node();
    for (auto input : node->inputs()) {
      // An input produced by the Undefined sentinel encodes a missing
      // optional input, which ONNX spells as the empty string.
      if (input->node()->kind() == kUndefined) {
        p_n->add_input("");
      } else {
        p_n->add_input(value_name(input));
      }
    }
    for (auto output : node->outputs()) {
      p_n->add_output(value_name(output));
      // only save it if
      // - it has actual information worth saving
      // - it's not already saved in the graph outputs value info
      if (graph_outputs.find(output) != graph_outputs.end()) {
        continue;
      }
      if (output->elemType() == TensorProto_DataType_UNDEFINED && output->sizes().empty()) {
        continue;
      }
      ValueInfoProto* v = p_g->add_value_info();
      encodeValueInfo(v, output);
    }
    p_n->set_op_type(node->kind().toString());
    for (auto attr_name : node->attributeNames()) {
      addAttribute(p_n, node, attr_name);
    }
    if (node->has_doc_string()) {
      p_n->set_doc_string(node->docString());
    }
    if (node->has_name()) {
      p_n->set_name(node->name());
    }
    if (node->has_domain()) {
      p_n->set_domain(node->domain());
    }
  }

  // Initializer names and tensors are stored as parallel arrays on the
  // Graph; re-pair them by index when emitting.
  auto num_initializers = g->initializers().size();
  for (unsigned int i = 0; i < num_initializers; i++) {
    auto p = p_g->add_initializer();
    p->set_name(g->initializer_names()[i]);
    encodeTensor(p, g->initializers()[i]);
  }
}
637
638void ExportModelProto(ModelProto* p_m, const std::shared_ptr<Graph>& g) {
639 GraphProto* p_g = p_m->mutable_graph();
640 encodeGraph(p_g, g);
641 // Add new opset_versions
642 p_m->clear_opset_import();
643 for (const OpSetID& opset : g->opset_versions_mutable()) {
644 OperatorSetIdProto* opset_version_output = p_m->add_opset_import();
645 opset_version_output->set_domain(opset.domain());
646 opset_version_output->set_version(opset.version());
647 }
648}
649
650ModelProto PrepareOutput(const ModelProto& mp_in) {
651 ModelProto mp_out{};
652
653 if (mp_in.has_ir_version()) {
654 mp_out.set_ir_version(mp_in.ir_version());
655 }
656 if (mp_in.has_producer_name()) {
657 mp_out.set_producer_name(mp_in.producer_name());
658 }
659 if (mp_in.has_producer_version()) {
660 mp_out.set_producer_version(mp_in.producer_version());
661 }
662 if (mp_in.has_domain()) {
663 mp_out.set_domain(mp_in.domain());
664 }
665 if (mp_in.has_model_version()) {
666 mp_out.set_model_version(mp_in.model_version());
667 }
668 if (mp_in.has_doc_string()) {
669 mp_out.set_doc_string(mp_in.doc_string());
670 }
671 for (int i = 0; i < mp_in.opset_import_size(); i++) {
672 auto& oi_in = mp_in.opset_import(i);
673 auto* oi_out = mp_out.add_opset_import();
674 if (oi_in.has_domain()) {
675 oi_out->set_domain(oi_in.domain());
676 }
677 if (oi_in.has_version()) {
678 oi_out->set_version(oi_in.version());
679 }
680 }
681 for (int i = 0; i < mp_in.metadata_props_size(); i++) {
682 auto& pp_in = mp_in.metadata_props(i);
683 auto* pp_out = mp_out.add_metadata_props();
684 if (pp_in.has_key()) {
685 pp_out->set_key(pp_in.key());
686 }
687 if (pp_in.has_value()) {
688 pp_out->set_value(pp_in.value());
689 }
690 }
691
692 return mp_out;
693}
694
// Abort (via ONNX_ASSERTM) with a descriptive message when graph parsing
// produced null, i.e. ImportModelProto rejected the model's IR version.
void assertNonNull(const std::shared_ptr<Graph>& g) {
  ONNX_ASSERTM(
      g.get() != nullptr,
      "Warning: onnx version converter is unable to parse input model. "
      "(The IR version of the ONNX model may be too old.)");
}
701
702} // namespace ONNX_NAMESPACE
703