1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include "onnx/defs/printer.h" |
6 | #include <iomanip> |
7 | #include "onnx/defs/tensor_proto_util.h" |
8 | |
9 | namespace ONNX_NAMESPACE { |
10 | |
11 | using MetaDataProp = StringStringEntryProto; |
12 | using MetaDataProps = google::protobuf::RepeatedPtrField<StringStringEntryProto>; |
13 | |
14 | class ProtoPrinter { |
15 | public: |
16 | ProtoPrinter(std::ostream& os) : output_(os) {} |
17 | |
18 | void print(const TensorShapeProto_Dimension& dim); |
19 | |
20 | void print(const TensorShapeProto& shape); |
21 | |
22 | void print(const TypeProto_Tensor& tensortype); |
23 | |
24 | void print(const TypeProto& type); |
25 | |
26 | void print(const TypeProto_Sequence& seqType); |
27 | |
28 | void print(const TypeProto_Map& mapType); |
29 | |
30 | void print(const TypeProto_Optional& optType); |
31 | |
32 | void print(const TypeProto_SparseTensor& sparseType); |
33 | |
34 | void print(const TensorProto& tensor); |
35 | |
36 | void print(const ValueInfoProto& value_info); |
37 | |
38 | void print(const ValueInfoList& vilist); |
39 | |
40 | void print(const AttributeProto& attr); |
41 | |
42 | void print(const AttrList& attrlist); |
43 | |
44 | void print(const NodeProto& node); |
45 | |
46 | void print(const NodeList& nodelist); |
47 | |
48 | void print(const GraphProto& graph); |
49 | |
50 | void print(const FunctionProto& fn); |
51 | |
52 | void print(const ModelProto& model); |
53 | |
54 | void print(const OperatorSetIdProto& opset); |
55 | |
56 | void print(const OpsetIdList& opsets); |
57 | |
58 | void print(const MetaDataProps& metadataprops) { |
59 | printSet("[" , ", " , "]" , metadataprops); |
60 | } |
61 | |
62 | void print(const MetaDataProp& metadata) { |
63 | printQuoted(metadata.key()); |
64 | output_ << ": " ; |
65 | printQuoted(metadata.value()); |
66 | } |
67 | |
68 | private: |
69 | template <typename T> |
70 | inline void print(T prim) { |
71 | output_ << prim; |
72 | } |
73 | |
74 | void printQuoted(const std::string& str) { |
75 | output_ << "\"" ; |
76 | for (const char* p = str.c_str(); *p; ++p) { |
77 | if ((*p == '\\') || (*p == '"')) |
78 | output_ << '\\'; |
79 | output_ << *p; |
80 | } |
81 | output_ << "\"" ; |
82 | } |
83 | |
84 | template <typename T> |
85 | inline void printKeyValuePair(KeyWordMap::KeyWord key, const T& val, bool addsep = true) { |
86 | if (addsep) |
87 | output_ << "," << std::endl; |
88 | output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": " ; |
89 | print(val); |
90 | } |
91 | |
92 | inline void printKeyValuePair(KeyWordMap::KeyWord key, const std::string& val) { |
93 | output_ << "," << std::endl; |
94 | output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": " ; |
95 | printQuoted(val); |
96 | } |
97 | |
98 | template <typename Collection> |
99 | inline void printSet(const char* open, const char* separator, const char* close, Collection coll) { |
100 | const char* sep = "" ; |
101 | output_ << open; |
102 | for (auto& elt : coll) { |
103 | output_ << sep; |
104 | print(elt); |
105 | sep = separator; |
106 | } |
107 | output_ << close; |
108 | } |
109 | |
110 | std::ostream& output_; |
111 | int indent_level = 3; |
112 | |
113 | void indent() { |
114 | indent_level += 3; |
115 | } |
116 | |
117 | void outdent() { |
118 | indent_level -= 3; |
119 | } |
120 | }; |
121 | |
122 | void ProtoPrinter::print(const TensorShapeProto_Dimension& dim) { |
123 | if (dim.has_dim_value()) |
124 | output_ << dim.dim_value(); |
125 | else if (dim.has_dim_param()) |
126 | output_ << dim.dim_param(); |
127 | else |
128 | output_ << "?" ; |
129 | } |
130 | |
131 | void ProtoPrinter::print(const TensorShapeProto& shape) { |
132 | printSet("[" , "," , "]" , shape.dim()); |
133 | } |
134 | |
135 | void ProtoPrinter::print(const TypeProto_Tensor& tensortype) { |
136 | output_ << PrimitiveTypeNameMap::ToString(tensortype.elem_type()); |
137 | if (tensortype.has_shape()) { |
138 | if (tensortype.shape().dim_size() > 0) |
139 | print(tensortype.shape()); |
140 | } else |
141 | output_ << "[]" ; |
142 | } |
143 | |
144 | void ProtoPrinter::print(const TypeProto_Sequence& seqType) { |
145 | output_ << "seq(" ; |
146 | print(seqType.elem_type()); |
147 | output_ << ")" ; |
148 | } |
149 | |
150 | void ProtoPrinter::print(const TypeProto_Map& mapType) { |
151 | output_ << "map(" << PrimitiveTypeNameMap::ToString(mapType.key_type()) << ", " ; |
152 | print(mapType.value_type()); |
153 | output_ << ")" ; |
154 | } |
155 | |
156 | void ProtoPrinter::print(const TypeProto_Optional& optType) { |
157 | output_ << "optional(" ; |
158 | print(optType.elem_type()); |
159 | output_ << ")" ; |
160 | } |
161 | |
162 | void ProtoPrinter::print(const TypeProto_SparseTensor& sparseType) { |
163 | output_ << "sparse_tensor(" << PrimitiveTypeNameMap::ToString(sparseType.elem_type()); |
164 | if (sparseType.has_shape()) { |
165 | if (sparseType.shape().dim_size() > 0) |
166 | print(sparseType.shape()); |
167 | } else |
168 | output_ << "[]" ; |
169 | |
170 | output_ << ")" ; |
171 | } |
172 | |
173 | void ProtoPrinter::print(const TypeProto& type) { |
174 | if (type.has_tensor_type()) |
175 | print(type.tensor_type()); |
176 | else if (type.has_sequence_type()) |
177 | print(type.sequence_type()); |
178 | else if (type.has_map_type()) |
179 | print(type.map_type()); |
180 | else if (type.has_optional_type()) |
181 | print(type.optional_type()); |
182 | else if (type.has_sparse_tensor_type()) |
183 | print(type.sparse_tensor_type()); |
184 | } |
185 | |
186 | void ProtoPrinter::print(const TensorProto& tensor) { |
187 | output_ << PrimitiveTypeNameMap::ToString(tensor.data_type()); |
188 | if (tensor.dims_size() > 0) |
189 | printSet("[" , "," , "]" , tensor.dims()); |
190 | |
191 | if (!tensor.name().empty()) { |
192 | output_ << " " << tensor.name(); |
193 | } |
194 | // TODO: does not yet handle all types or externally stored data. |
195 | if (tensor.has_raw_data()) { |
196 | switch (static_cast<TensorProto::DataType>(tensor.data_type())) { |
197 | case TensorProto::DataType::TensorProto_DataType_INT32: |
198 | printSet(" {" , "," , "}" , ParseData<int32_t>(&tensor)); |
199 | break; |
200 | case TensorProto::DataType::TensorProto_DataType_INT64: |
201 | printSet(" {" , "," , "}" , ParseData<int64_t>(&tensor)); |
202 | break; |
203 | case TensorProto::DataType::TensorProto_DataType_FLOAT: |
204 | printSet(" {" , "," , "}" , ParseData<float>(&tensor)); |
205 | break; |
206 | case TensorProto::DataType::TensorProto_DataType_DOUBLE: |
207 | printSet(" {" , "," , "}" , ParseData<double>(&tensor)); |
208 | break; |
209 | default: |
210 | output_ << "..." ; // ParseData not instantiated for other types. |
211 | break; |
212 | } |
213 | } else { |
214 | switch (static_cast<TensorProto::DataType>(tensor.data_type())) { |
215 | case TensorProto::DataType::TensorProto_DataType_INT8: |
216 | case TensorProto::DataType::TensorProto_DataType_INT16: |
217 | case TensorProto::DataType::TensorProto_DataType_INT32: |
218 | case TensorProto::DataType::TensorProto_DataType_UINT8: |
219 | case TensorProto::DataType::TensorProto_DataType_UINT16: |
220 | case TensorProto::DataType::TensorProto_DataType_BOOL: |
221 | printSet(" {" , "," , "}" , tensor.int32_data()); |
222 | break; |
223 | case TensorProto::DataType::TensorProto_DataType_INT64: |
224 | printSet(" {" , "," , "}" , tensor.int64_data()); |
225 | break; |
226 | case TensorProto::DataType::TensorProto_DataType_UINT32: |
227 | case TensorProto::DataType::TensorProto_DataType_UINT64: |
228 | printSet(" {" , "," , "}" , tensor.uint64_data()); |
229 | break; |
230 | case TensorProto::DataType::TensorProto_DataType_FLOAT: |
231 | printSet(" {" , "," , "}" , tensor.float_data()); |
232 | break; |
233 | case TensorProto::DataType::TensorProto_DataType_DOUBLE: |
234 | printSet(" {" , "," , "}" , tensor.double_data()); |
235 | break; |
236 | case TensorProto::DataType::TensorProto_DataType_STRING: { |
237 | const char* sep = "{" ; |
238 | for (auto& elt : tensor.string_data()) { |
239 | output_ << sep; |
240 | printQuoted(elt); |
241 | sep = ", " ; |
242 | } |
243 | output_ << "}" ; |
244 | break; |
245 | } |
246 | default: |
247 | break; |
248 | } |
249 | } |
250 | } |
251 | |
252 | void ProtoPrinter::print(const ValueInfoProto& value_info) { |
253 | print(value_info.type()); |
254 | output_ << " " << value_info.name(); |
255 | } |
256 | |
257 | void ProtoPrinter::print(const ValueInfoList& vilist) { |
258 | printSet("(" , ", " , ")" , vilist); |
259 | } |
260 | |
261 | void ProtoPrinter::print(const AttributeProto& attr) { |
262 | // Special case of attr-ref: |
263 | if (attr.has_ref_attr_name()) { |
264 | output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = @" << attr.ref_attr_name(); |
265 | return; |
266 | } |
267 | // General case: |
268 | output_ << attr.name() << " = " ; |
269 | switch (attr.type()) { |
270 | case AttributeProto_AttributeType_INT: |
271 | output_ << attr.i(); |
272 | break; |
273 | case AttributeProto_AttributeType_INTS: |
274 | printSet("[" , ", " , "]" , attr.ints()); |
275 | break; |
276 | case AttributeProto_AttributeType_FLOAT: |
277 | output_ << attr.f(); |
278 | break; |
279 | case AttributeProto_AttributeType_FLOATS: |
280 | printSet("[" , ", " , "]" , attr.floats()); |
281 | break; |
282 | case AttributeProto_AttributeType_STRING: |
283 | output_ << "\"" << attr.s() << "\"" ; |
284 | break; |
285 | case AttributeProto_AttributeType_STRINGS: { |
286 | const char* sep = "[" ; |
287 | for (auto& elt : attr.strings()) { |
288 | output_ << sep << "\"" << elt << "\"" ; |
289 | sep = ", " ; |
290 | } |
291 | output_ << "]" ; |
292 | break; |
293 | } |
294 | case AttributeProto_AttributeType_GRAPH: |
295 | indent(); |
296 | print(attr.g()); |
297 | outdent(); |
298 | break; |
299 | case AttributeProto_AttributeType_GRAPHS: |
300 | indent(); |
301 | printSet("[" , ", " , "]" , attr.graphs()); |
302 | outdent(); |
303 | break; |
304 | case AttributeProto_AttributeType_TENSOR: |
305 | print(attr.t()); |
306 | break; |
307 | case AttributeProto_AttributeType_TENSORS: |
308 | printSet("[" , ", " , "]" , attr.tensors()); |
309 | break; |
310 | default: |
311 | break; |
312 | } |
313 | } |
314 | |
315 | void ProtoPrinter::print(const AttrList& attrlist) { |
316 | printSet(" <" , ", " , ">" , attrlist); |
317 | } |
318 | |
319 | void ProtoPrinter::print(const NodeProto& node) { |
320 | output_ << std::setw(indent_level) << ' '; |
321 | printSet("" , ", " , "" , node.output()); |
322 | output_ << " = " ; |
323 | if (node.domain() != "" ) |
324 | output_ << node.domain() << "." ; |
325 | output_ << node.op_type(); |
326 | bool has_subgraph = false; |
327 | for (auto attr : node.attribute()) |
328 | if (attr.has_g() || (attr.graphs_size() > 0)) |
329 | has_subgraph = true; |
330 | if ((!has_subgraph) && (node.attribute_size() > 0)) |
331 | print(node.attribute()); |
332 | printSet(" (" , ", " , ")" , node.input()); |
333 | if ((has_subgraph) && (node.attribute_size() > 0)) |
334 | print(node.attribute()); |
335 | output_ << "\n" ; |
336 | } |
337 | |
338 | void ProtoPrinter::print(const NodeList& nodelist) { |
339 | output_ << "{\n" ; |
340 | for (auto& node : nodelist) { |
341 | print(node); |
342 | } |
343 | if (indent_level > 3) |
344 | output_ << std::setw(indent_level - 3) << " " ; |
345 | output_ << "}" ; |
346 | } |
347 | |
348 | void ProtoPrinter::print(const GraphProto& graph) { |
349 | output_ << graph.name() << " " << graph.input() << " => " << graph.output() << " " ; |
350 | print(graph.node()); |
351 | } |
352 | |
353 | void ProtoPrinter::print(const ModelProto& model) { |
354 | output_ << "<\n" ; |
355 | printKeyValuePair(KeyWordMap::KeyWord::IR_VERSION, model.ir_version(), false); |
356 | printKeyValuePair(KeyWordMap::KeyWord::OPSET_IMPORT, model.opset_import()); |
357 | if (model.has_producer_name()) |
358 | printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_NAME, model.producer_name()); |
359 | if (model.has_producer_version()) |
360 | printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_VERSION, model.producer_version()); |
361 | if (model.has_domain()) |
362 | printKeyValuePair(KeyWordMap::KeyWord::DOMAIN_KW, model.domain()); |
363 | if (model.has_model_version()) |
364 | printKeyValuePair(KeyWordMap::KeyWord::MODEL_VERSION, model.model_version()); |
365 | if (model.has_doc_string()) |
366 | printKeyValuePair(KeyWordMap::KeyWord::DOC_STRING, model.doc_string()); |
367 | if (model.metadata_props_size() > 0) |
368 | printKeyValuePair(KeyWordMap::KeyWord::METADATA_PROPS, model.metadata_props()); |
369 | output_ << std::endl << ">" << std::endl; |
370 | |
371 | print(model.graph()); |
372 | for (const auto& fn : model.functions()) { |
373 | output_ << std::endl; |
374 | print(fn); |
375 | } |
376 | } |
377 | |
378 | void ProtoPrinter::print(const OperatorSetIdProto& opset) { |
379 | output_ << "\"" << opset.domain() << "\" : " << opset.version(); |
380 | } |
381 | |
382 | void ProtoPrinter::print(const OpsetIdList& opsets) { |
383 | printSet("[" , ", " , "]" , opsets); |
384 | } |
385 | |
386 | void ProtoPrinter::print(const FunctionProto& fn) { |
387 | output_ << "<\n" ; |
388 | output_ << " " |
389 | << "domain: \"" << fn.domain() << "\",\n" ; |
390 | output_ << " " |
391 | << "opset_import: " ; |
392 | printSet("[" , "," , "]" , fn.opset_import()); |
393 | output_ << "\n>\n" ; |
394 | output_ << fn.name() << " " ; |
395 | if (fn.attribute_size() > 0) |
396 | printSet("<" , "," , ">" , fn.attribute()); |
397 | printSet("(" , ", " , ")" , fn.input()); |
398 | output_ << " => " ; |
399 | printSet("(" , ", " , ")" , fn.output()); |
400 | output_ << "\n" ; |
401 | print(fn.node()); |
402 | } |
403 | |
// Defines `std::ostream& operator<<(std::ostream&, const T&)` by delegating to a
// fresh ProtoPrinter bound to the stream. There is intentionally no ';' after the
// function body: a stray semicolon at namespace scope is an empty declaration that
// pedantic / extra-semi warning flags complain about.
#define DEF_OP(T)                                              \
  std::ostream& operator<<(std::ostream& os, const T& proto) { \
    ProtoPrinter printer(os);                                  \
    printer.print(proto);                                      \
    return os;                                                 \
  }
410 | |
411 | DEF_OP(TensorShapeProto_Dimension) |
412 | |
413 | DEF_OP(TensorShapeProto) |
414 | |
415 | DEF_OP(TypeProto_Tensor) |
416 | |
417 | DEF_OP(TypeProto) |
418 | |
419 | DEF_OP(TensorProto) |
420 | |
421 | DEF_OP(ValueInfoProto) |
422 | |
423 | DEF_OP(ValueInfoList) |
424 | |
425 | DEF_OP(AttributeProto) |
426 | |
427 | DEF_OP(AttrList) |
428 | |
429 | DEF_OP(NodeProto) |
430 | |
431 | DEF_OP(NodeList) |
432 | |
433 | DEF_OP(GraphProto) |
434 | |
435 | DEF_OP(FunctionProto) |
436 | |
437 | DEF_OP(ModelProto) |
438 | |
439 | } // namespace ONNX_NAMESPACE |