1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#include "onnx/defs/printer.h"
6#include <iomanip>
7#include "onnx/defs/tensor_proto_util.h"
8
9namespace ONNX_NAMESPACE {
10
11using MetaDataProp = StringStringEntryProto;
12using MetaDataProps = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
13
14class ProtoPrinter {
15 public:
16 ProtoPrinter(std::ostream& os) : output_(os) {}
17
18 void print(const TensorShapeProto_Dimension& dim);
19
20 void print(const TensorShapeProto& shape);
21
22 void print(const TypeProto_Tensor& tensortype);
23
24 void print(const TypeProto& type);
25
26 void print(const TypeProto_Sequence& seqType);
27
28 void print(const TypeProto_Map& mapType);
29
30 void print(const TypeProto_Optional& optType);
31
32 void print(const TypeProto_SparseTensor& sparseType);
33
34 void print(const TensorProto& tensor);
35
36 void print(const ValueInfoProto& value_info);
37
38 void print(const ValueInfoList& vilist);
39
40 void print(const AttributeProto& attr);
41
42 void print(const AttrList& attrlist);
43
44 void print(const NodeProto& node);
45
46 void print(const NodeList& nodelist);
47
48 void print(const GraphProto& graph);
49
50 void print(const FunctionProto& fn);
51
52 void print(const ModelProto& model);
53
54 void print(const OperatorSetIdProto& opset);
55
56 void print(const OpsetIdList& opsets);
57
58 void print(const MetaDataProps& metadataprops) {
59 printSet("[", ", ", "]", metadataprops);
60 }
61
62 void print(const MetaDataProp& metadata) {
63 printQuoted(metadata.key());
64 output_ << ": ";
65 printQuoted(metadata.value());
66 }
67
68 private:
69 template <typename T>
70 inline void print(T prim) {
71 output_ << prim;
72 }
73
74 void printQuoted(const std::string& str) {
75 output_ << "\"";
76 for (const char* p = str.c_str(); *p; ++p) {
77 if ((*p == '\\') || (*p == '"'))
78 output_ << '\\';
79 output_ << *p;
80 }
81 output_ << "\"";
82 }
83
84 template <typename T>
85 inline void printKeyValuePair(KeyWordMap::KeyWord key, const T& val, bool addsep = true) {
86 if (addsep)
87 output_ << "," << std::endl;
88 output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
89 print(val);
90 }
91
92 inline void printKeyValuePair(KeyWordMap::KeyWord key, const std::string& val) {
93 output_ << "," << std::endl;
94 output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
95 printQuoted(val);
96 }
97
98 template <typename Collection>
99 inline void printSet(const char* open, const char* separator, const char* close, Collection coll) {
100 const char* sep = "";
101 output_ << open;
102 for (auto& elt : coll) {
103 output_ << sep;
104 print(elt);
105 sep = separator;
106 }
107 output_ << close;
108 }
109
110 std::ostream& output_;
111 int indent_level = 3;
112
113 void indent() {
114 indent_level += 3;
115 }
116
117 void outdent() {
118 indent_level -= 3;
119 }
120};
121
122void ProtoPrinter::print(const TensorShapeProto_Dimension& dim) {
123 if (dim.has_dim_value())
124 output_ << dim.dim_value();
125 else if (dim.has_dim_param())
126 output_ << dim.dim_param();
127 else
128 output_ << "?";
129}
130
131void ProtoPrinter::print(const TensorShapeProto& shape) {
132 printSet("[", ",", "]", shape.dim());
133}
134
135void ProtoPrinter::print(const TypeProto_Tensor& tensortype) {
136 output_ << PrimitiveTypeNameMap::ToString(tensortype.elem_type());
137 if (tensortype.has_shape()) {
138 if (tensortype.shape().dim_size() > 0)
139 print(tensortype.shape());
140 } else
141 output_ << "[]";
142}
143
144void ProtoPrinter::print(const TypeProto_Sequence& seqType) {
145 output_ << "seq(";
146 print(seqType.elem_type());
147 output_ << ")";
148}
149
150void ProtoPrinter::print(const TypeProto_Map& mapType) {
151 output_ << "map(" << PrimitiveTypeNameMap::ToString(mapType.key_type()) << ", ";
152 print(mapType.value_type());
153 output_ << ")";
154}
155
156void ProtoPrinter::print(const TypeProto_Optional& optType) {
157 output_ << "optional(";
158 print(optType.elem_type());
159 output_ << ")";
160}
161
162void ProtoPrinter::print(const TypeProto_SparseTensor& sparseType) {
163 output_ << "sparse_tensor(" << PrimitiveTypeNameMap::ToString(sparseType.elem_type());
164 if (sparseType.has_shape()) {
165 if (sparseType.shape().dim_size() > 0)
166 print(sparseType.shape());
167 } else
168 output_ << "[]";
169
170 output_ << ")";
171}
172
173void ProtoPrinter::print(const TypeProto& type) {
174 if (type.has_tensor_type())
175 print(type.tensor_type());
176 else if (type.has_sequence_type())
177 print(type.sequence_type());
178 else if (type.has_map_type())
179 print(type.map_type());
180 else if (type.has_optional_type())
181 print(type.optional_type());
182 else if (type.has_sparse_tensor_type())
183 print(type.sparse_tensor_type());
184}
185
186void ProtoPrinter::print(const TensorProto& tensor) {
187 output_ << PrimitiveTypeNameMap::ToString(tensor.data_type());
188 if (tensor.dims_size() > 0)
189 printSet("[", ",", "]", tensor.dims());
190
191 if (!tensor.name().empty()) {
192 output_ << " " << tensor.name();
193 }
194 // TODO: does not yet handle all types or externally stored data.
195 if (tensor.has_raw_data()) {
196 switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
197 case TensorProto::DataType::TensorProto_DataType_INT32:
198 printSet(" {", ",", "}", ParseData<int32_t>(&tensor));
199 break;
200 case TensorProto::DataType::TensorProto_DataType_INT64:
201 printSet(" {", ",", "}", ParseData<int64_t>(&tensor));
202 break;
203 case TensorProto::DataType::TensorProto_DataType_FLOAT:
204 printSet(" {", ",", "}", ParseData<float>(&tensor));
205 break;
206 case TensorProto::DataType::TensorProto_DataType_DOUBLE:
207 printSet(" {", ",", "}", ParseData<double>(&tensor));
208 break;
209 default:
210 output_ << "..."; // ParseData not instantiated for other types.
211 break;
212 }
213 } else {
214 switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
215 case TensorProto::DataType::TensorProto_DataType_INT8:
216 case TensorProto::DataType::TensorProto_DataType_INT16:
217 case TensorProto::DataType::TensorProto_DataType_INT32:
218 case TensorProto::DataType::TensorProto_DataType_UINT8:
219 case TensorProto::DataType::TensorProto_DataType_UINT16:
220 case TensorProto::DataType::TensorProto_DataType_BOOL:
221 printSet(" {", ",", "}", tensor.int32_data());
222 break;
223 case TensorProto::DataType::TensorProto_DataType_INT64:
224 printSet(" {", ",", "}", tensor.int64_data());
225 break;
226 case TensorProto::DataType::TensorProto_DataType_UINT32:
227 case TensorProto::DataType::TensorProto_DataType_UINT64:
228 printSet(" {", ",", "}", tensor.uint64_data());
229 break;
230 case TensorProto::DataType::TensorProto_DataType_FLOAT:
231 printSet(" {", ",", "}", tensor.float_data());
232 break;
233 case TensorProto::DataType::TensorProto_DataType_DOUBLE:
234 printSet(" {", ",", "}", tensor.double_data());
235 break;
236 case TensorProto::DataType::TensorProto_DataType_STRING: {
237 const char* sep = "{";
238 for (auto& elt : tensor.string_data()) {
239 output_ << sep;
240 printQuoted(elt);
241 sep = ", ";
242 }
243 output_ << "}";
244 break;
245 }
246 default:
247 break;
248 }
249 }
250}
251
252void ProtoPrinter::print(const ValueInfoProto& value_info) {
253 print(value_info.type());
254 output_ << " " << value_info.name();
255}
256
257void ProtoPrinter::print(const ValueInfoList& vilist) {
258 printSet("(", ", ", ")", vilist);
259}
260
261void ProtoPrinter::print(const AttributeProto& attr) {
262 // Special case of attr-ref:
263 if (attr.has_ref_attr_name()) {
264 output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = @" << attr.ref_attr_name();
265 return;
266 }
267 // General case:
268 output_ << attr.name() << " = ";
269 switch (attr.type()) {
270 case AttributeProto_AttributeType_INT:
271 output_ << attr.i();
272 break;
273 case AttributeProto_AttributeType_INTS:
274 printSet("[", ", ", "]", attr.ints());
275 break;
276 case AttributeProto_AttributeType_FLOAT:
277 output_ << attr.f();
278 break;
279 case AttributeProto_AttributeType_FLOATS:
280 printSet("[", ", ", "]", attr.floats());
281 break;
282 case AttributeProto_AttributeType_STRING:
283 output_ << "\"" << attr.s() << "\"";
284 break;
285 case AttributeProto_AttributeType_STRINGS: {
286 const char* sep = "[";
287 for (auto& elt : attr.strings()) {
288 output_ << sep << "\"" << elt << "\"";
289 sep = ", ";
290 }
291 output_ << "]";
292 break;
293 }
294 case AttributeProto_AttributeType_GRAPH:
295 indent();
296 print(attr.g());
297 outdent();
298 break;
299 case AttributeProto_AttributeType_GRAPHS:
300 indent();
301 printSet("[", ", ", "]", attr.graphs());
302 outdent();
303 break;
304 case AttributeProto_AttributeType_TENSOR:
305 print(attr.t());
306 break;
307 case AttributeProto_AttributeType_TENSORS:
308 printSet("[", ", ", "]", attr.tensors());
309 break;
310 default:
311 break;
312 }
313}
314
315void ProtoPrinter::print(const AttrList& attrlist) {
316 printSet(" <", ", ", ">", attrlist);
317}
318
319void ProtoPrinter::print(const NodeProto& node) {
320 output_ << std::setw(indent_level) << ' ';
321 printSet("", ", ", "", node.output());
322 output_ << " = ";
323 if (node.domain() != "")
324 output_ << node.domain() << ".";
325 output_ << node.op_type();
326 bool has_subgraph = false;
327 for (auto attr : node.attribute())
328 if (attr.has_g() || (attr.graphs_size() > 0))
329 has_subgraph = true;
330 if ((!has_subgraph) && (node.attribute_size() > 0))
331 print(node.attribute());
332 printSet(" (", ", ", ")", node.input());
333 if ((has_subgraph) && (node.attribute_size() > 0))
334 print(node.attribute());
335 output_ << "\n";
336}
337
338void ProtoPrinter::print(const NodeList& nodelist) {
339 output_ << "{\n";
340 for (auto& node : nodelist) {
341 print(node);
342 }
343 if (indent_level > 3)
344 output_ << std::setw(indent_level - 3) << " ";
345 output_ << "}";
346}
347
348void ProtoPrinter::print(const GraphProto& graph) {
349 output_ << graph.name() << " " << graph.input() << " => " << graph.output() << " ";
350 print(graph.node());
351}
352
353void ProtoPrinter::print(const ModelProto& model) {
354 output_ << "<\n";
355 printKeyValuePair(KeyWordMap::KeyWord::IR_VERSION, model.ir_version(), false);
356 printKeyValuePair(KeyWordMap::KeyWord::OPSET_IMPORT, model.opset_import());
357 if (model.has_producer_name())
358 printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_NAME, model.producer_name());
359 if (model.has_producer_version())
360 printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_VERSION, model.producer_version());
361 if (model.has_domain())
362 printKeyValuePair(KeyWordMap::KeyWord::DOMAIN_KW, model.domain());
363 if (model.has_model_version())
364 printKeyValuePair(KeyWordMap::KeyWord::MODEL_VERSION, model.model_version());
365 if (model.has_doc_string())
366 printKeyValuePair(KeyWordMap::KeyWord::DOC_STRING, model.doc_string());
367 if (model.metadata_props_size() > 0)
368 printKeyValuePair(KeyWordMap::KeyWord::METADATA_PROPS, model.metadata_props());
369 output_ << std::endl << ">" << std::endl;
370
371 print(model.graph());
372 for (const auto& fn : model.functions()) {
373 output_ << std::endl;
374 print(fn);
375 }
376}
377
378void ProtoPrinter::print(const OperatorSetIdProto& opset) {
379 output_ << "\"" << opset.domain() << "\" : " << opset.version();
380}
381
382void ProtoPrinter::print(const OpsetIdList& opsets) {
383 printSet("[", ", ", "]", opsets);
384}
385
386void ProtoPrinter::print(const FunctionProto& fn) {
387 output_ << "<\n";
388 output_ << " "
389 << "domain: \"" << fn.domain() << "\",\n";
390 output_ << " "
391 << "opset_import: ";
392 printSet("[", ",", "]", fn.opset_import());
393 output_ << "\n>\n";
394 output_ << fn.name() << " ";
395 if (fn.attribute_size() > 0)
396 printSet("<", ",", ">", fn.attribute());
397 printSet("(", ", ", ")", fn.input());
398 output_ << " => ";
399 printSet("(", ", ", ")", fn.output());
400 output_ << "\n";
401 print(fn.node());
402}
403
404#define DEF_OP(T) \
405 std::ostream& operator<<(std::ostream& os, const T& proto) { \
406 ProtoPrinter printer(os); \
407 printer.print(proto); \
408 return os; \
409 };
410
411DEF_OP(TensorShapeProto_Dimension)
412
413DEF_OP(TensorShapeProto)
414
415DEF_OP(TypeProto_Tensor)
416
417DEF_OP(TypeProto)
418
419DEF_OP(TensorProto)
420
421DEF_OP(ValueInfoProto)
422
423DEF_OP(ValueInfoList)
424
425DEF_OP(AttributeProto)
426
427DEF_OP(AttrList)
428
429DEF_OP(NodeProto)
430
431DEF_OP(NodeList)
432
433DEF_OP(GraphProto)
434
435DEF_OP(FunctionProto)
436
437DEF_OP(ModelProto)
438
439} // namespace ONNX_NAMESPACE