1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file node/serialization.cc |
22 | * \brief Utilities to serialize TVM AST/IR objects. |
23 | */ |
24 | #include <dmlc/json.h> |
25 | #include <dmlc/memory_io.h> |
26 | #include <tvm/ir/attrs.h> |
27 | #include <tvm/node/reflection.h> |
28 | #include <tvm/node/serialization.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | #include <tvm/runtime/ndarray.h> |
32 | #include <tvm/runtime/packed_func.h> |
33 | #include <tvm/runtime/registry.h> |
34 | |
35 | #include <cctype> |
36 | #include <map> |
37 | #include <string> |
38 | |
39 | #include "../runtime/object_internal.h" |
40 | #include "../support/base64.h" |
41 | |
42 | namespace tvm { |
43 | |
44 | inline std::string Type2String(const DataType& t) { return runtime::DLDataType2String(t); } |
45 | |
46 | inline DataType String2Type(std::string s) { return DataType(runtime::String2DLDataType(s)); } |
47 | |
48 | inline std::string Base64Decode(std::string s) { |
49 | dmlc::MemoryStringStream mstrm(&s); |
50 | support::Base64InStream b64strm(&mstrm); |
51 | std::string output; |
52 | b64strm.InitPosition(); |
53 | dmlc::Stream* strm = &b64strm; |
54 | strm->Read(&output); |
55 | return output; |
56 | } |
57 | |
58 | inline std::string Base64Encode(std::string s) { |
59 | std::string blob; |
60 | dmlc::MemoryStringStream mstrm(&blob); |
61 | support::Base64OutStream b64strm(&mstrm); |
62 | dmlc::Stream* strm = &b64strm; |
63 | strm->Write(s); |
64 | b64strm.Finish(); |
65 | return blob; |
66 | } |
67 | |
/*!
 * \brief Assigns a dense integer index to every object and NDArray reachable
 *        from a root, in the order they are first discovered.
 *
 * Index 0 of the object table is reserved for nullptr (see the initializers of
 * node_index_ and node_list_), so a null ObjectRef is referenced as index 0.
 * Objects that provide repr bytes are indexed, but their attributes are not
 * traversed.
 */
class NodeIndexer : public AttrVisitor {
 public:
  /*! \brief Object -> index; nullptr is pre-registered at index 0. */
  std::unordered_map<Object*, size_t> node_index_{{nullptr, 0}};
  /*! \brief Index -> object, in discovery order (slot 0 holds nullptr). */
  std::vector<Object*> node_list_{nullptr};
  /*! \brief DLTensor address -> index into tensor_list_. */
  std::unordered_map<DLTensor*, size_t> tensor_index_;
  /*! \brief Index -> DLTensor, in discovery order. */
  std::vector<DLTensor*> tensor_list_;
  ReflectionVTable* reflection_ = ReflectionVTable::Global();

  // Scalar/string/type attributes hold no references, so nothing is indexed.
  void Visit(const char* key, double* value) final {}
  void Visit(const char* key, int64_t* value) final {}
  void Visit(const char* key, uint64_t* value) final {}
  void Visit(const char* key, int* value) final {}
  void Visit(const char* key, bool* value) final {}
  void Visit(const char* key, std::string* value) final {}
  void Visit(const char* key, void** value) final {}
  void Visit(const char* key, DataType* value) final {}

  // Register each distinct tensor once, keyed by its DLTensor address.
  void Visit(const char* key, runtime::NDArray* value) final {
    DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
    if (tensor_index_.count(ptr)) return;
    // The map and list must grow in lockstep so list position == index.
    ICHECK_EQ(tensor_index_.size(), tensor_list_.size());
    tensor_index_[ptr] = tensor_list_.size();
    tensor_list_.push_back(ptr);
  }

  // Object-valued fields are indexed recursively.
  void Visit(const char* key, ObjectRef* value) final {
    MakeIndex(const_cast<Object*>(value->get()));
  }

  /*!
   * \brief Assign an index to `node` itself, without visiting its children.
   *        No-op for nullptr or a node that is already indexed.
   */
  void MakeNodeIndex(Object* node) {
    if (node == nullptr) return;
    ICHECK(node->IsInstance<Object>());

    if (node_index_.count(node)) {
      return;
    }
    // The map and list must grow in lockstep so list position == index.
    ICHECK_EQ(node_index_.size(), node_list_.size());
    node_index_[node] = node_list_.size();
    node_list_.push_back(node);
  }

  /*!
   * \brief Index `node` and, recursively, everything it references.
   *
   * - Arrays: every element is indexed.
   * - Maps: when every key is a StringObj only the values are indexed (the
   *   keys are written inline as strings); otherwise keys and values are.
   * - relay::LetNode: the let chain is walked with relay::ExpandANormalForm
   *   rather than plain recursion. pre_visit indexes each Let in the chain
   *   plus its var/value/span/checked_type_, and indexes the body only when
   *   it is not itself a Let (the expansion reaches the next Let). Presumably
   *   this avoids deep recursion on long let chains -- TODO confirm.
   * - Everything else: attributes are visited via reflection unless the node
   *   supplies repr bytes.
   */
  void MakeIndex(Object* node) {
    if (node == nullptr) return;
    ICHECK(node->IsInstance<Object>());

    if (node_index_.count(node)) {
      return;
    }
    MakeNodeIndex(node);
    if (node->IsInstance<ArrayNode>()) {
      ArrayNode* n = static_cast<ArrayNode*>(node);
      for (const auto& sp : *n) {
        MakeIndex(const_cast<Object*>(sp.get()));
      }
    } else if (node->IsInstance<MapNode>()) {
      MapNode* n = static_cast<MapNode*>(node);
      bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) {
        return v.first->template IsInstance<StringObj>();
      });
      if (is_str_map) {
        // String keys are serialized inline; only values need indices.
        for (const auto& kv : *n) {
          MakeIndex(const_cast<Object*>(kv.second.get()));
        }
      } else {
        for (const auto& kv : *n) {
          MakeIndex(const_cast<Object*>(kv.first.get()));
          MakeIndex(const_cast<Object*>(kv.second.get()));
        }
      }
    } else if (node->IsInstance<relay::LetNode>()) {
      // NOTE(review): only var/value/span/checked_type_/body are indexed here.
      // If LetNode's VisitAttrs exposes further ObjectRef fields, the
      // node_index_->at() lookup in JSONAttrGetter would throw -- confirm this
      // list matches the node's reflected fields.
      auto pre_visit = [this](const relay::LetNode* op) {
        MakeNodeIndex(const_cast<Object*>(static_cast<const Object*>(op)));
        MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->var.get())));
        MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->value.get())));
        MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->span.get())));
        MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->checked_type_.get())));
        if (!op->body.as<relay::LetNode>()) {
          MakeIndex(const_cast<Object*>(static_cast<const Object*>(op->body.get())));
        }
      };
      auto post_visit = [](const relay::LetNode* op) {};
      if (!reflection_->GetReprBytes(node, nullptr)) {
        relay::ExpandANormalForm(static_cast<relay::LetNode*>(node), pre_visit, post_visit);
      }
    } else {
      // If the node already has repr bytes, there is no need to visit its attrs.
      if (!reflection_->GetReprBytes(node, nullptr)) {
        reflection_->VisitAttrs(node, this);
      }
    }
  }
};
162 | |
163 | // use map so attributes are ordered. |
164 | using AttrMap = std::map<std::string, std::string>; |
165 | |
166 | /*! \brief Node structure for json format. */ |
167 | struct JSONNode { |
168 | /*! \brief The type of key of the object. */ |
169 | std::string type_key; |
170 | /*! \brief The str repr representation. */ |
171 | std::string repr_bytes; |
172 | /*! \brief the attributes */ |
173 | AttrMap attrs; |
174 | /*! \brief keys of a map. */ |
175 | std::vector<std::string> keys; |
176 | /*! \brief values of a map or array. */ |
177 | std::vector<size_t> data; |
178 | /*! |
179 | * \brief field member dependency. |
180 | * NOTE: This is an auxiliary data structure for loading, and it won't be serialized to json. |
181 | */ |
182 | std::vector<size_t> fields; |
183 | |
184 | void Save(dmlc::JSONWriter* writer) const { |
185 | writer->BeginObject(); |
186 | writer->WriteObjectKeyValue("type_key" , type_key); |
187 | if (repr_bytes.size() != 0) { |
188 | // choose to use str representation or base64, based on whether |
189 | // the byte representation is printable. |
190 | if (std::all_of(repr_bytes.begin(), repr_bytes.end(), |
191 | [](char ch) { return std::isprint(ch); })) { |
192 | writer->WriteObjectKeyValue("repr_str" , repr_bytes); |
193 | } else { |
194 | writer->WriteObjectKeyValue("repr_b64" , Base64Encode(repr_bytes)); |
195 | } |
196 | } |
197 | if (attrs.size() != 0) { |
198 | writer->WriteObjectKeyValue("attrs" , attrs); |
199 | } |
200 | if (keys.size() != 0) { |
201 | writer->WriteObjectKeyValue("keys" , keys); |
202 | } |
203 | if (data.size() != 0) { |
204 | writer->WriteObjectKeyValue("data" , data); |
205 | } |
206 | writer->EndObject(); |
207 | } |
208 | |
209 | void Load(dmlc::JSONReader* reader) { |
210 | attrs.clear(); |
211 | data.clear(); |
212 | repr_bytes.clear(); |
213 | type_key.clear(); |
214 | std::string repr_b64, repr_str; |
215 | dmlc::JSONObjectReadHelper helper; |
216 | helper.DeclareOptionalField("type_key" , &type_key); |
217 | helper.DeclareOptionalField("repr_b64" , &repr_b64); |
218 | helper.DeclareOptionalField("repr_str" , &repr_str); |
219 | helper.DeclareOptionalField("attrs" , &attrs); |
220 | helper.DeclareOptionalField("keys" , &keys); |
221 | helper.DeclareOptionalField("data" , &data); |
222 | helper.ReadAllFields(reader); |
223 | |
224 | if (repr_str.size() != 0) { |
225 | ICHECK_EQ(repr_b64.size(), 0U); |
226 | repr_bytes = std::move(repr_str); |
227 | } else if (repr_b64.size() != 0) { |
228 | repr_bytes = Base64Decode(repr_b64); |
229 | } |
230 | } |
231 | }; |
232 | |
233 | // Helper class to populate the json node |
234 | // using the existing index. |
235 | class JSONAttrGetter : public AttrVisitor { |
236 | public: |
237 | const std::unordered_map<Object*, size_t>* node_index_; |
238 | const std::unordered_map<DLTensor*, size_t>* tensor_index_; |
239 | JSONNode* node_; |
240 | ReflectionVTable* reflection_ = ReflectionVTable::Global(); |
241 | |
242 | void Visit(const char* key, double* value) final { |
243 | std::ostringstream s; |
244 | // Save 17 decimal digits for type <double> to avoid precision loss during loading JSON |
245 | s.precision(17); |
246 | s << (*value); |
247 | node_->attrs[key] = s.str(); |
248 | } |
249 | void Visit(const char* key, int64_t* value) final { node_->attrs[key] = std::to_string(*value); } |
250 | void Visit(const char* key, uint64_t* value) final { node_->attrs[key] = std::to_string(*value); } |
251 | void Visit(const char* key, int* value) final { node_->attrs[key] = std::to_string(*value); } |
252 | void Visit(const char* key, bool* value) final { node_->attrs[key] = std::to_string(*value); } |
253 | void Visit(const char* key, std::string* value) final { node_->attrs[key] = *value; } |
254 | void Visit(const char* key, void** value) final { |
255 | LOG(FATAL) << "not allowed to serialize a pointer" ; |
256 | } |
257 | void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); } |
258 | |
259 | void Visit(const char* key, runtime::NDArray* value) final { |
260 | node_->attrs[key] = |
261 | std::to_string(tensor_index_->at(const_cast<DLTensor*>((*value).operator->()))); |
262 | } |
263 | |
264 | void Visit(const char* key, ObjectRef* value) final { |
265 | node_->attrs[key] = std::to_string(node_index_->at(const_cast<Object*>(value->get()))); |
266 | } |
267 | |
268 | // Get the node |
269 | void Get(Object* node) { |
270 | if (node == nullptr) { |
271 | node_->type_key.clear(); |
272 | return; |
273 | } |
274 | node_->type_key = node->GetTypeKey(); |
275 | // do not need to print additional things once we have repr bytes. |
276 | if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return; |
277 | |
278 | // populates the fields. |
279 | node_->attrs.clear(); |
280 | node_->data.clear(); |
281 | |
282 | if (node->IsInstance<ArrayNode>()) { |
283 | ArrayNode* n = static_cast<ArrayNode*>(node); |
284 | for (size_t i = 0; i < n->size(); ++i) { |
285 | node_->data.push_back(node_index_->at(const_cast<Object*>(n->at(i).get()))); |
286 | } |
287 | } else if (node->IsInstance<MapNode>()) { |
288 | MapNode* n = static_cast<MapNode*>(node); |
289 | bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { |
290 | return v.first->template IsInstance<StringObj>(); |
291 | }); |
292 | if (is_str_map) { |
293 | for (const auto& kv : *n) { |
294 | node_->keys.push_back(Downcast<String>(kv.first)); |
295 | node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get()))); |
296 | } |
297 | } else { |
298 | for (const auto& kv : *n) { |
299 | node_->data.push_back(node_index_->at(const_cast<Object*>(kv.first.get()))); |
300 | node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get()))); |
301 | } |
302 | } |
303 | } else { |
304 | // recursively index normal object. |
305 | reflection_->VisitAttrs(node, this); |
306 | } |
307 | } |
308 | }; |
309 | |
310 | class FieldDependencyFinder : public AttrVisitor { |
311 | public: |
312 | JSONNode* jnode_; |
313 | ReflectionVTable* reflection_ = ReflectionVTable::Global(); |
314 | |
315 | std::string GetValue(const char* key) const { |
316 | auto it = jnode_->attrs.find(key); |
317 | if (it == jnode_->attrs.end()) { |
318 | LOG(FATAL) << "JSONReader: cannot find field " << key; |
319 | } |
320 | return it->second; |
321 | } |
322 | template <typename T> |
323 | void ParseValue(const char* key, T* value) const { |
324 | std::istringstream is(GetValue(key)); |
325 | is >> *value; |
326 | if (is.fail()) { |
327 | LOG(FATAL) << "Wrong value format for field " << key; |
328 | } |
329 | } |
330 | void Visit(const char* key, double* value) final {} |
331 | void Visit(const char* key, int64_t* value) final {} |
332 | void Visit(const char* key, uint64_t* value) final {} |
333 | void Visit(const char* key, int* value) final {} |
334 | void Visit(const char* key, bool* value) final {} |
335 | void Visit(const char* key, std::string* value) final {} |
336 | void Visit(const char* key, void** value) final {} |
337 | void Visit(const char* key, DataType* value) final {} |
338 | void Visit(const char* key, runtime::NDArray* value) final {} |
339 | void Visit(const char* key, ObjectRef* value) final { |
340 | size_t index; |
341 | ParseValue(key, &index); |
342 | jnode_->fields.push_back(index); |
343 | } |
344 | void Find(Object* node, JSONNode* jnode) { |
345 | // Skip None |
346 | if (node == nullptr) { |
347 | return; |
348 | } |
349 | // Skip the objects that have their own string repr |
350 | if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node, nullptr)) { |
351 | return; |
352 | } |
353 | // Skip containers |
354 | if (jnode->type_key == ArrayNode::_type_key || jnode->type_key == MapNode::_type_key) { |
355 | return; |
356 | } |
357 | jnode_ = jnode; |
358 | reflection_->VisitAttrs(node, this); |
359 | } |
360 | }; |
361 | |
362 | // Helper class to set the attributes of a node |
363 | // from given json node. |
364 | class JSONAttrSetter : public AttrVisitor { |
365 | public: |
366 | const std::vector<ObjectPtr<Object>>* node_list_; |
367 | const std::vector<runtime::NDArray>* tensor_list_; |
368 | JSONNode* jnode_; |
369 | |
370 | ReflectionVTable* reflection_ = ReflectionVTable::Global(); |
371 | |
372 | std::string GetValue(const char* key) const { |
373 | auto it = jnode_->attrs.find(key); |
374 | if (it == jnode_->attrs.end()) { |
375 | LOG(FATAL) << "JSONReader: cannot find field " << key; |
376 | } |
377 | return it->second; |
378 | } |
379 | |
380 | void ParseDouble(const char* key, double* value) const { |
381 | std::istringstream is(GetValue(key)); |
382 | if (is.str() == "inf" ) { |
383 | *value = std::numeric_limits<double>::infinity(); |
384 | } else if (is.str() == "-inf" ) { |
385 | *value = -std::numeric_limits<double>::infinity(); |
386 | } else { |
387 | is >> *value; |
388 | if (is.fail()) { |
389 | LOG(FATAL) << "Wrong value format for field " << key; |
390 | } |
391 | } |
392 | } |
393 | |
394 | template <typename T> |
395 | void ParseValue(const char* key, T* value) const { |
396 | std::istringstream is(GetValue(key)); |
397 | is >> *value; |
398 | if (is.fail()) { |
399 | LOG(FATAL) << "Wrong value format for field " << key; |
400 | } |
401 | } |
402 | void Visit(const char* key, double* value) final { ParseDouble(key, value); } |
403 | void Visit(const char* key, int64_t* value) final { ParseValue(key, value); } |
404 | void Visit(const char* key, uint64_t* value) final { ParseValue(key, value); } |
405 | void Visit(const char* key, int* value) final { ParseValue(key, value); } |
406 | void Visit(const char* key, bool* value) final { ParseValue(key, value); } |
407 | void Visit(const char* key, std::string* value) final { *value = GetValue(key); } |
408 | void Visit(const char* key, void** value) final { |
409 | LOG(FATAL) << "not allowed to deserialize a pointer" ; |
410 | } |
411 | void Visit(const char* key, DataType* value) final { |
412 | std::string stype = GetValue(key); |
413 | *value = String2Type(stype); |
414 | } |
415 | void Visit(const char* key, runtime::NDArray* value) final { |
416 | size_t index; |
417 | ParseValue(key, &index); |
418 | ICHECK_LE(index, tensor_list_->size()); |
419 | *value = tensor_list_->at(index); |
420 | } |
421 | void Visit(const char* key, ObjectRef* value) final { |
422 | size_t index; |
423 | ParseValue(key, &index); |
424 | ICHECK_LE(index, node_list_->size()); |
425 | *value = ObjectRef(node_list_->at(index)); |
426 | } |
427 | // set node to be current JSONNode |
428 | void Set(ObjectPtr<Object>* node, JSONNode* jnode) { |
429 | // Skip None |
430 | if (node->get() == nullptr) { |
431 | return; |
432 | } |
433 | // Skip the objects that have their own string repr |
434 | if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node->get(), nullptr)) { |
435 | return; |
436 | } |
437 | // handling Array |
438 | if (jnode->type_key == ArrayNode::_type_key) { |
439 | std::vector<ObjectRef> container; |
440 | for (auto index : jnode->data) { |
441 | container.push_back(ObjectRef(node_list_->at(index))); |
442 | } |
443 | Array<ObjectRef> array(container); |
444 | *node = runtime::ObjectInternal::MoveObjectPtr(&array); |
445 | return; |
446 | } |
447 | // handling Map |
448 | if (jnode->type_key == MapNode::_type_key) { |
449 | std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> container; |
450 | if (jnode->keys.empty()) { |
451 | ICHECK_EQ(jnode->data.size() % 2, 0U); |
452 | for (size_t i = 0; i < jnode->data.size(); i += 2) { |
453 | container[ObjectRef(node_list_->at(jnode->data[i]))] = |
454 | ObjectRef(node_list_->at(jnode->data[i + 1])); |
455 | } |
456 | } else { |
457 | ICHECK_EQ(jnode->data.size(), jnode->keys.size()); |
458 | for (size_t i = 0; i < jnode->data.size(); ++i) { |
459 | container[String(jnode->keys[i])] = ObjectRef(node_list_->at(jnode->data[i])); |
460 | } |
461 | } |
462 | Map<ObjectRef, ObjectRef> map(container); |
463 | *node = runtime::ObjectInternal::MoveObjectPtr(&map); |
464 | return; |
465 | } |
466 | jnode_ = jnode; |
467 | reflection_->VisitAttrs(node->get(), this); |
468 | } |
469 | }; |
470 | |
471 | // json graph structure to store node |
472 | struct JSONGraph { |
473 | // the root of the graph |
474 | size_t root; |
475 | // the nodes of the graph |
476 | std::vector<JSONNode> nodes; |
477 | // base64 b64ndarrays of arrays |
478 | std::vector<std::string> b64ndarrays; |
479 | // global attributes |
480 | AttrMap attrs; |
481 | |
482 | void Save(dmlc::JSONWriter* writer) const { |
483 | writer->BeginObject(); |
484 | writer->WriteObjectKeyValue("root" , root); |
485 | writer->WriteObjectKeyValue("nodes" , nodes); |
486 | writer->WriteObjectKeyValue("b64ndarrays" , b64ndarrays); |
487 | if (attrs.size() != 0) { |
488 | writer->WriteObjectKeyValue("attrs" , attrs); |
489 | } |
490 | writer->EndObject(); |
491 | } |
492 | |
493 | void Load(dmlc::JSONReader* reader) { |
494 | attrs.clear(); |
495 | dmlc::JSONObjectReadHelper helper; |
496 | helper.DeclareField("root" , &root); |
497 | helper.DeclareField("nodes" , &nodes); |
498 | helper.DeclareOptionalField("b64ndarrays" , &b64ndarrays); |
499 | helper.DeclareOptionalField("attrs" , &attrs); |
500 | helper.ReadAllFields(reader); |
501 | } |
502 | |
503 | static JSONGraph Create(const ObjectRef& root) { |
504 | JSONGraph g; |
505 | NodeIndexer indexer; |
506 | indexer.MakeIndex(const_cast<Object*>(root.get())); |
507 | JSONAttrGetter getter; |
508 | getter.node_index_ = &indexer.node_index_; |
509 | getter.tensor_index_ = &indexer.tensor_index_; |
510 | for (Object* n : indexer.node_list_) { |
511 | JSONNode jnode; |
512 | getter.node_ = &jnode; |
513 | getter.Get(n); |
514 | g.nodes.emplace_back(std::move(jnode)); |
515 | } |
516 | g.attrs["tvm_version" ] = TVM_VERSION; |
517 | g.root = indexer.node_index_.at(const_cast<Object*>(root.get())); |
518 | // serialize tensor |
519 | for (DLTensor* tensor : indexer.tensor_list_) { |
520 | std::string blob; |
521 | dmlc::MemoryStringStream mstrm(&blob); |
522 | support::Base64OutStream b64strm(&mstrm); |
523 | runtime::SaveDLTensor(&b64strm, tensor); |
524 | b64strm.Finish(); |
525 | g.b64ndarrays.emplace_back(std::move(blob)); |
526 | } |
527 | return g; |
528 | } |
529 | |
530 | std::vector<size_t> TopoSort() const { |
531 | size_t n_nodes = nodes.size(); |
532 | std::vector<size_t> topo_order; |
533 | std::vector<size_t> in_degree(n_nodes, 0); |
534 | for (const JSONNode& jnode : nodes) { |
535 | for (size_t i : jnode.data) { |
536 | ++in_degree[i]; |
537 | } |
538 | for (size_t i : jnode.fields) { |
539 | ++in_degree[i]; |
540 | } |
541 | } |
542 | for (size_t i = 0; i < n_nodes; ++i) { |
543 | if (in_degree[i] == 0) { |
544 | topo_order.push_back(i); |
545 | } |
546 | } |
547 | for (size_t p = 0; p < topo_order.size(); ++p) { |
548 | const JSONNode& jnode = nodes[topo_order[p]]; |
549 | for (size_t i : jnode.data) { |
550 | if (--in_degree[i] == 0) { |
551 | topo_order.push_back(i); |
552 | } |
553 | } |
554 | for (size_t i : jnode.fields) { |
555 | if (--in_degree[i] == 0) { |
556 | topo_order.push_back(i); |
557 | } |
558 | } |
559 | } |
560 | ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file" ; |
561 | std::reverse(std::begin(topo_order), std::end(topo_order)); |
562 | return topo_order; |
563 | } |
564 | }; |
565 | |
566 | std::string SaveJSON(const ObjectRef& n) { |
567 | auto jgraph = JSONGraph::Create(n); |
568 | std::ostringstream os; |
569 | dmlc::JSONWriter writer(&os); |
570 | jgraph.Save(&writer); |
571 | return os.str(); |
572 | } |
573 | |
574 | ObjectRef LoadJSON(std::string json_str) { |
575 | ReflectionVTable* reflection = ReflectionVTable::Global(); |
576 | JSONGraph jgraph; |
577 | { |
578 | // load in json graph. |
579 | std::istringstream is(json_str); |
580 | dmlc::JSONReader reader(&is); |
581 | jgraph.Load(&reader); |
582 | } |
583 | size_t n_nodes = jgraph.nodes.size(); |
584 | std::vector<runtime::NDArray> tensors; |
585 | { |
586 | // load in tensors |
587 | for (const std::string& blob : jgraph.b64ndarrays) { |
588 | dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob)); |
589 | support::Base64InStream b64strm(&mstrm); |
590 | b64strm.InitPosition(); |
591 | runtime::NDArray temp; |
592 | ICHECK(temp.Load(&b64strm)); |
593 | tensors.emplace_back(std::move(temp)); |
594 | } |
595 | } |
596 | // Pass 1: create all non-container objects |
597 | std::vector<ObjectPtr<Object>> nodes(n_nodes, nullptr); |
598 | for (size_t i = 0; i < n_nodes; ++i) { |
599 | const JSONNode& jnode = jgraph.nodes[i]; |
600 | if (jnode.type_key.length() != 0) { |
601 | nodes[i] = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); |
602 | } |
603 | } |
604 | // Pass 2: figure out all field dependency |
605 | { |
606 | FieldDependencyFinder dep_finder; |
607 | for (size_t i = 0; i < n_nodes; ++i) { |
608 | dep_finder.Find(nodes[i].get(), &jgraph.nodes[i]); |
609 | } |
610 | } |
611 | // Pass 3: topo sort |
612 | std::vector<size_t> topo_order = jgraph.TopoSort(); |
613 | // Pass 4: set all values |
614 | { |
615 | JSONAttrSetter setter; |
616 | setter.node_list_ = &nodes; |
617 | setter.tensor_list_ = &tensors; |
618 | for (size_t i : topo_order) { |
619 | setter.Set(&nodes[i], &jgraph.nodes[i]); |
620 | } |
621 | } |
622 | return ObjectRef(nodes.at(jgraph.root)); |
623 | } |
624 | |
// Register SaveJSON as the global packed function "node.SaveJSON".
TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON);

// Register LoadJSON as the global packed function "node.LoadJSON".
TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON);
628 | } // namespace tvm |
629 | |