1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file node/serialization.cc
22 * \brief Utilities to serialize TVM AST/IR objects.
23 */
24#include <dmlc/json.h>
25#include <dmlc/memory_io.h>
26#include <tvm/ir/attrs.h>
27#include <tvm/node/reflection.h>
28#include <tvm/node/serialization.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/expr_functor.h>
31#include <tvm/runtime/ndarray.h>
32#include <tvm/runtime/packed_func.h>
33#include <tvm/runtime/registry.h>
34
35#include <cctype>
36#include <map>
37#include <string>
38
39#include "../runtime/object_internal.h"
40#include "../support/base64.h"
41
42namespace tvm {
43
44inline std::string Type2String(const DataType& t) { return runtime::DLDataType2String(t); }
45
46inline DataType String2Type(std::string s) { return DataType(runtime::String2DLDataType(s)); }
47
48inline std::string Base64Decode(std::string s) {
49 dmlc::MemoryStringStream mstrm(&s);
50 support::Base64InStream b64strm(&mstrm);
51 std::string output;
52 b64strm.InitPosition();
53 dmlc::Stream* strm = &b64strm;
54 strm->Read(&output);
55 return output;
56}
57
58inline std::string Base64Encode(std::string s) {
59 std::string blob;
60 dmlc::MemoryStringStream mstrm(&blob);
61 support::Base64OutStream b64strm(&mstrm);
62 dmlc::Stream* strm = &b64strm;
63 strm->Write(s);
64 b64strm.Finish();
65 return blob;
66}
67
68// indexer to index all the nodes
69class NodeIndexer : public AttrVisitor {
70 public:
71 std::unordered_map<Object*, size_t> node_index_{{nullptr, 0}};
72 std::vector<Object*> node_list_{nullptr};
73 std::unordered_map<DLTensor*, size_t> tensor_index_;
74 std::vector<DLTensor*> tensor_list_;
75 ReflectionVTable* reflection_ = ReflectionVTable::Global();
76
77 void Visit(const char* key, double* value) final {}
78 void Visit(const char* key, int64_t* value) final {}
79 void Visit(const char* key, uint64_t* value) final {}
80 void Visit(const char* key, int* value) final {}
81 void Visit(const char* key, bool* value) final {}
82 void Visit(const char* key, std::string* value) final {}
83 void Visit(const char* key, void** value) final {}
84 void Visit(const char* key, DataType* value) final {}
85
86 void Visit(const char* key, runtime::NDArray* value) final {
87 DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
88 if (tensor_index_.count(ptr)) return;
89 ICHECK_EQ(tensor_index_.size(), tensor_list_.size());
90 tensor_index_[ptr] = tensor_list_.size();
91 tensor_list_.push_back(ptr);
92 }
93
94 void Visit(const char* key, ObjectRef* value) final {
95 MakeIndex(const_cast<Object*>(value->get()));
96 }
97
98 void MakeNodeIndex(Object* node) {
99 if (node == nullptr) return;
100 ICHECK(node->IsInstance<Object>());
101
102 if (node_index_.count(node)) {
103 return;
104 }
105 ICHECK_EQ(node_index_.size(), node_list_.size());
106 node_index_[node] = node_list_.size();
107 node_list_.push_back(node);
108 }
109
110 // make index of all the children of node
111 void MakeIndex(Object* node) {
112 if (node == nullptr) return;
113 ICHECK(node->IsInstance<Object>());
114
115 if (node_index_.count(node)) {
116 return;
117 }
118 MakeNodeIndex(node);
119 if (node->IsInstance<ArrayNode>()) {
120 ArrayNode* n = static_cast<ArrayNode*>(node);
121 for (const auto& sp : *n) {
122 MakeIndex(const_cast<Object*>(sp.get()));
123 }
124 } else if (node->IsInstance<MapNode>()) {
125 MapNode* n = static_cast<MapNode*>(node);
126 bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) {
127 return v.first->template IsInstance<StringObj>();
128 });
129 if (is_str_map) {
130 for (const auto& kv : *n) {
131 MakeIndex(const_cast<Object*>(kv.second.get()));
132 }
133 } else {
134 for (const auto& kv : *n) {
135 MakeIndex(const_cast<Object*>(kv.first.get()));
136 MakeIndex(const_cast<Object*>(kv.second.get()));
137 }
138 }
139 } else if (node->IsInstance<relay::LetNode>()) {
140 auto pre_visit = [this](const relay::LetNode* op) {
141 MakeNodeIndex(const_cast<Object*>(static_cast<const Object*>(op)));
142 MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->var.get())));
143 MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->value.get())));
144 MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->span.get())));
145 MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->checked_type_.get())));
146 if (!op->body.as<relay::LetNode>()) {
147 MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->body.get())));
148 }
149 };
150 auto post_visit = [](const relay::LetNode* op) {};
151 if (!reflection_->GetReprBytes(node, nullptr)) {
152 relay::ExpandANormalForm(static_cast<relay::LetNode*>(node), pre_visit, post_visit);
153 }
154 } else {
155 // if the node already have repr bytes, no need to visit Attrs.
156 if (!reflection_->GetReprBytes(node, nullptr)) {
157 reflection_->VisitAttrs(node, this);
158 }
159 }
160 }
161};
162
// Attribute name -> serialized value. std::map is used so that attributes stay ordered by key.
using AttrMap = std::map<std::string, std::string>;
165
166/*! \brief Node structure for json format. */
167struct JSONNode {
168 /*! \brief The type of key of the object. */
169 std::string type_key;
170 /*! \brief The str repr representation. */
171 std::string repr_bytes;
172 /*! \brief the attributes */
173 AttrMap attrs;
174 /*! \brief keys of a map. */
175 std::vector<std::string> keys;
176 /*! \brief values of a map or array. */
177 std::vector<size_t> data;
178 /*!
179 * \brief field member dependency.
180 * NOTE: This is an auxiliary data structure for loading, and it won't be serialized to json.
181 */
182 std::vector<size_t> fields;
183
184 void Save(dmlc::JSONWriter* writer) const {
185 writer->BeginObject();
186 writer->WriteObjectKeyValue("type_key", type_key);
187 if (repr_bytes.size() != 0) {
188 // choose to use str representation or base64, based on whether
189 // the byte representation is printable.
190 if (std::all_of(repr_bytes.begin(), repr_bytes.end(),
191 [](char ch) { return std::isprint(ch); })) {
192 writer->WriteObjectKeyValue("repr_str", repr_bytes);
193 } else {
194 writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes));
195 }
196 }
197 if (attrs.size() != 0) {
198 writer->WriteObjectKeyValue("attrs", attrs);
199 }
200 if (keys.size() != 0) {
201 writer->WriteObjectKeyValue("keys", keys);
202 }
203 if (data.size() != 0) {
204 writer->WriteObjectKeyValue("data", data);
205 }
206 writer->EndObject();
207 }
208
209 void Load(dmlc::JSONReader* reader) {
210 attrs.clear();
211 data.clear();
212 repr_bytes.clear();
213 type_key.clear();
214 std::string repr_b64, repr_str;
215 dmlc::JSONObjectReadHelper helper;
216 helper.DeclareOptionalField("type_key", &type_key);
217 helper.DeclareOptionalField("repr_b64", &repr_b64);
218 helper.DeclareOptionalField("repr_str", &repr_str);
219 helper.DeclareOptionalField("attrs", &attrs);
220 helper.DeclareOptionalField("keys", &keys);
221 helper.DeclareOptionalField("data", &data);
222 helper.ReadAllFields(reader);
223
224 if (repr_str.size() != 0) {
225 ICHECK_EQ(repr_b64.size(), 0U);
226 repr_bytes = std::move(repr_str);
227 } else if (repr_b64.size() != 0) {
228 repr_bytes = Base64Decode(repr_b64);
229 }
230 }
231};
232
233// Helper class to populate the json node
234// using the existing index.
235class JSONAttrGetter : public AttrVisitor {
236 public:
237 const std::unordered_map<Object*, size_t>* node_index_;
238 const std::unordered_map<DLTensor*, size_t>* tensor_index_;
239 JSONNode* node_;
240 ReflectionVTable* reflection_ = ReflectionVTable::Global();
241
242 void Visit(const char* key, double* value) final {
243 std::ostringstream s;
244 // Save 17 decimal digits for type <double> to avoid precision loss during loading JSON
245 s.precision(17);
246 s << (*value);
247 node_->attrs[key] = s.str();
248 }
249 void Visit(const char* key, int64_t* value) final { node_->attrs[key] = std::to_string(*value); }
250 void Visit(const char* key, uint64_t* value) final { node_->attrs[key] = std::to_string(*value); }
251 void Visit(const char* key, int* value) final { node_->attrs[key] = std::to_string(*value); }
252 void Visit(const char* key, bool* value) final { node_->attrs[key] = std::to_string(*value); }
253 void Visit(const char* key, std::string* value) final { node_->attrs[key] = *value; }
254 void Visit(const char* key, void** value) final {
255 LOG(FATAL) << "not allowed to serialize a pointer";
256 }
257 void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); }
258
259 void Visit(const char* key, runtime::NDArray* value) final {
260 node_->attrs[key] =
261 std::to_string(tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
262 }
263
264 void Visit(const char* key, ObjectRef* value) final {
265 node_->attrs[key] = std::to_string(node_index_->at(const_cast<Object*>(value->get())));
266 }
267
268 // Get the node
269 void Get(Object* node) {
270 if (node == nullptr) {
271 node_->type_key.clear();
272 return;
273 }
274 node_->type_key = node->GetTypeKey();
275 // do not need to print additional things once we have repr bytes.
276 if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return;
277
278 // populates the fields.
279 node_->attrs.clear();
280 node_->data.clear();
281
282 if (node->IsInstance<ArrayNode>()) {
283 ArrayNode* n = static_cast<ArrayNode*>(node);
284 for (size_t i = 0; i < n->size(); ++i) {
285 node_->data.push_back(node_index_->at(const_cast<Object*>(n->at(i).get())));
286 }
287 } else if (node->IsInstance<MapNode>()) {
288 MapNode* n = static_cast<MapNode*>(node);
289 bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) {
290 return v.first->template IsInstance<StringObj>();
291 });
292 if (is_str_map) {
293 for (const auto& kv : *n) {
294 node_->keys.push_back(Downcast<String>(kv.first));
295 node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
296 }
297 } else {
298 for (const auto& kv : *n) {
299 node_->data.push_back(node_index_->at(const_cast<Object*>(kv.first.get())));
300 node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
301 }
302 }
303 } else {
304 // recursively index normal object.
305 reflection_->VisitAttrs(node, this);
306 }
307 }
308};
309
310class FieldDependencyFinder : public AttrVisitor {
311 public:
312 JSONNode* jnode_;
313 ReflectionVTable* reflection_ = ReflectionVTable::Global();
314
315 std::string GetValue(const char* key) const {
316 auto it = jnode_->attrs.find(key);
317 if (it == jnode_->attrs.end()) {
318 LOG(FATAL) << "JSONReader: cannot find field " << key;
319 }
320 return it->second;
321 }
322 template <typename T>
323 void ParseValue(const char* key, T* value) const {
324 std::istringstream is(GetValue(key));
325 is >> *value;
326 if (is.fail()) {
327 LOG(FATAL) << "Wrong value format for field " << key;
328 }
329 }
330 void Visit(const char* key, double* value) final {}
331 void Visit(const char* key, int64_t* value) final {}
332 void Visit(const char* key, uint64_t* value) final {}
333 void Visit(const char* key, int* value) final {}
334 void Visit(const char* key, bool* value) final {}
335 void Visit(const char* key, std::string* value) final {}
336 void Visit(const char* key, void** value) final {}
337 void Visit(const char* key, DataType* value) final {}
338 void Visit(const char* key, runtime::NDArray* value) final {}
339 void Visit(const char* key, ObjectRef* value) final {
340 size_t index;
341 ParseValue(key, &index);
342 jnode_->fields.push_back(index);
343 }
344 void Find(Object* node, JSONNode* jnode) {
345 // Skip None
346 if (node == nullptr) {
347 return;
348 }
349 // Skip the objects that have their own string repr
350 if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node, nullptr)) {
351 return;
352 }
353 // Skip containers
354 if (jnode->type_key == ArrayNode::_type_key || jnode->type_key == MapNode::_type_key) {
355 return;
356 }
357 jnode_ = jnode;
358 reflection_->VisitAttrs(node, this);
359 }
360};
361
362// Helper class to set the attributes of a node
363// from given json node.
364class JSONAttrSetter : public AttrVisitor {
365 public:
366 const std::vector<ObjectPtr<Object>>* node_list_;
367 const std::vector<runtime::NDArray>* tensor_list_;
368 JSONNode* jnode_;
369
370 ReflectionVTable* reflection_ = ReflectionVTable::Global();
371
372 std::string GetValue(const char* key) const {
373 auto it = jnode_->attrs.find(key);
374 if (it == jnode_->attrs.end()) {
375 LOG(FATAL) << "JSONReader: cannot find field " << key;
376 }
377 return it->second;
378 }
379
380 void ParseDouble(const char* key, double* value) const {
381 std::istringstream is(GetValue(key));
382 if (is.str() == "inf") {
383 *value = std::numeric_limits<double>::infinity();
384 } else if (is.str() == "-inf") {
385 *value = -std::numeric_limits<double>::infinity();
386 } else {
387 is >> *value;
388 if (is.fail()) {
389 LOG(FATAL) << "Wrong value format for field " << key;
390 }
391 }
392 }
393
394 template <typename T>
395 void ParseValue(const char* key, T* value) const {
396 std::istringstream is(GetValue(key));
397 is >> *value;
398 if (is.fail()) {
399 LOG(FATAL) << "Wrong value format for field " << key;
400 }
401 }
402 void Visit(const char* key, double* value) final { ParseDouble(key, value); }
403 void Visit(const char* key, int64_t* value) final { ParseValue(key, value); }
404 void Visit(const char* key, uint64_t* value) final { ParseValue(key, value); }
405 void Visit(const char* key, int* value) final { ParseValue(key, value); }
406 void Visit(const char* key, bool* value) final { ParseValue(key, value); }
407 void Visit(const char* key, std::string* value) final { *value = GetValue(key); }
408 void Visit(const char* key, void** value) final {
409 LOG(FATAL) << "not allowed to deserialize a pointer";
410 }
411 void Visit(const char* key, DataType* value) final {
412 std::string stype = GetValue(key);
413 *value = String2Type(stype);
414 }
415 void Visit(const char* key, runtime::NDArray* value) final {
416 size_t index;
417 ParseValue(key, &index);
418 ICHECK_LE(index, tensor_list_->size());
419 *value = tensor_list_->at(index);
420 }
421 void Visit(const char* key, ObjectRef* value) final {
422 size_t index;
423 ParseValue(key, &index);
424 ICHECK_LE(index, node_list_->size());
425 *value = ObjectRef(node_list_->at(index));
426 }
427 // set node to be current JSONNode
428 void Set(ObjectPtr<Object>* node, JSONNode* jnode) {
429 // Skip None
430 if (node->get() == nullptr) {
431 return;
432 }
433 // Skip the objects that have their own string repr
434 if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node->get(), nullptr)) {
435 return;
436 }
437 // handling Array
438 if (jnode->type_key == ArrayNode::_type_key) {
439 std::vector<ObjectRef> container;
440 for (auto index : jnode->data) {
441 container.push_back(ObjectRef(node_list_->at(index)));
442 }
443 Array<ObjectRef> array(container);
444 *node = runtime::ObjectInternal::MoveObjectPtr(&array);
445 return;
446 }
447 // handling Map
448 if (jnode->type_key == MapNode::_type_key) {
449 std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> container;
450 if (jnode->keys.empty()) {
451 ICHECK_EQ(jnode->data.size() % 2, 0U);
452 for (size_t i = 0; i < jnode->data.size(); i += 2) {
453 container[ObjectRef(node_list_->at(jnode->data[i]))] =
454 ObjectRef(node_list_->at(jnode->data[i + 1]));
455 }
456 } else {
457 ICHECK_EQ(jnode->data.size(), jnode->keys.size());
458 for (size_t i = 0; i < jnode->data.size(); ++i) {
459 container[String(jnode->keys[i])] = ObjectRef(node_list_->at(jnode->data[i]));
460 }
461 }
462 Map<ObjectRef, ObjectRef> map(container);
463 *node = runtime::ObjectInternal::MoveObjectPtr(&map);
464 return;
465 }
466 jnode_ = jnode;
467 reflection_->VisitAttrs(node->get(), this);
468 }
469};
470
471// json graph structure to store node
472struct JSONGraph {
473 // the root of the graph
474 size_t root;
475 // the nodes of the graph
476 std::vector<JSONNode> nodes;
477 // base64 b64ndarrays of arrays
478 std::vector<std::string> b64ndarrays;
479 // global attributes
480 AttrMap attrs;
481
482 void Save(dmlc::JSONWriter* writer) const {
483 writer->BeginObject();
484 writer->WriteObjectKeyValue("root", root);
485 writer->WriteObjectKeyValue("nodes", nodes);
486 writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays);
487 if (attrs.size() != 0) {
488 writer->WriteObjectKeyValue("attrs", attrs);
489 }
490 writer->EndObject();
491 }
492
493 void Load(dmlc::JSONReader* reader) {
494 attrs.clear();
495 dmlc::JSONObjectReadHelper helper;
496 helper.DeclareField("root", &root);
497 helper.DeclareField("nodes", &nodes);
498 helper.DeclareOptionalField("b64ndarrays", &b64ndarrays);
499 helper.DeclareOptionalField("attrs", &attrs);
500 helper.ReadAllFields(reader);
501 }
502
503 static JSONGraph Create(const ObjectRef& root) {
504 JSONGraph g;
505 NodeIndexer indexer;
506 indexer.MakeIndex(const_cast<Object*>(root.get()));
507 JSONAttrGetter getter;
508 getter.node_index_ = &indexer.node_index_;
509 getter.tensor_index_ = &indexer.tensor_index_;
510 for (Object* n : indexer.node_list_) {
511 JSONNode jnode;
512 getter.node_ = &jnode;
513 getter.Get(n);
514 g.nodes.emplace_back(std::move(jnode));
515 }
516 g.attrs["tvm_version"] = TVM_VERSION;
517 g.root = indexer.node_index_.at(const_cast<Object*>(root.get()));
518 // serialize tensor
519 for (DLTensor* tensor : indexer.tensor_list_) {
520 std::string blob;
521 dmlc::MemoryStringStream mstrm(&blob);
522 support::Base64OutStream b64strm(&mstrm);
523 runtime::SaveDLTensor(&b64strm, tensor);
524 b64strm.Finish();
525 g.b64ndarrays.emplace_back(std::move(blob));
526 }
527 return g;
528 }
529
530 std::vector<size_t> TopoSort() const {
531 size_t n_nodes = nodes.size();
532 std::vector<size_t> topo_order;
533 std::vector<size_t> in_degree(n_nodes, 0);
534 for (const JSONNode& jnode : nodes) {
535 for (size_t i : jnode.data) {
536 ++in_degree[i];
537 }
538 for (size_t i : jnode.fields) {
539 ++in_degree[i];
540 }
541 }
542 for (size_t i = 0; i < n_nodes; ++i) {
543 if (in_degree[i] == 0) {
544 topo_order.push_back(i);
545 }
546 }
547 for (size_t p = 0; p < topo_order.size(); ++p) {
548 const JSONNode& jnode = nodes[topo_order[p]];
549 for (size_t i : jnode.data) {
550 if (--in_degree[i] == 0) {
551 topo_order.push_back(i);
552 }
553 }
554 for (size_t i : jnode.fields) {
555 if (--in_degree[i] == 0) {
556 topo_order.push_back(i);
557 }
558 }
559 }
560 ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file";
561 std::reverse(std::begin(topo_order), std::end(topo_order));
562 return topo_order;
563 }
564};
565
566std::string SaveJSON(const ObjectRef& n) {
567 auto jgraph = JSONGraph::Create(n);
568 std::ostringstream os;
569 dmlc::JSONWriter writer(&os);
570 jgraph.Save(&writer);
571 return os.str();
572}
573
574ObjectRef LoadJSON(std::string json_str) {
575 ReflectionVTable* reflection = ReflectionVTable::Global();
576 JSONGraph jgraph;
577 {
578 // load in json graph.
579 std::istringstream is(json_str);
580 dmlc::JSONReader reader(&is);
581 jgraph.Load(&reader);
582 }
583 size_t n_nodes = jgraph.nodes.size();
584 std::vector<runtime::NDArray> tensors;
585 {
586 // load in tensors
587 for (const std::string& blob : jgraph.b64ndarrays) {
588 dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
589 support::Base64InStream b64strm(&mstrm);
590 b64strm.InitPosition();
591 runtime::NDArray temp;
592 ICHECK(temp.Load(&b64strm));
593 tensors.emplace_back(std::move(temp));
594 }
595 }
596 // Pass 1: create all non-container objects
597 std::vector<ObjectPtr<Object>> nodes(n_nodes, nullptr);
598 for (size_t i = 0; i < n_nodes; ++i) {
599 const JSONNode& jnode = jgraph.nodes[i];
600 if (jnode.type_key.length() != 0) {
601 nodes[i] = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes);
602 }
603 }
604 // Pass 2: figure out all field dependency
605 {
606 FieldDependencyFinder dep_finder;
607 for (size_t i = 0; i < n_nodes; ++i) {
608 dep_finder.Find(nodes[i].get(), &jgraph.nodes[i]);
609 }
610 }
611 // Pass 3: topo sort
612 std::vector<size_t> topo_order = jgraph.TopoSort();
613 // Pass 4: set all values
614 {
615 JSONAttrSetter setter;
616 setter.node_list_ = &nodes;
617 setter.tensor_list_ = &tensors;
618 for (size_t i : topo_order) {
619 setter.Set(&nodes[i], &jgraph.nodes[i]);
620 }
621 }
622 return ObjectRef(nodes.at(jgraph.root));
623}
624
625TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON);
626
627TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON);
628} // namespace tvm
629