1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file stackvm_module.cc
 * \brief Runtime module that stores StackVM functions and saves/loads them to files.
 */
23#include "stackvm_module.h"
24
25#include <dmlc/memory_io.h>
26#include <tvm/runtime/module.h>
27#include <tvm/runtime/registry.h>
28
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
32
33#include "../file_utils.h"
34
35namespace tvm {
36namespace runtime {
37
38class StackVMModuleNode : public runtime::ModuleNode {
39 public:
40 const char* type_key() const final { return "stackvm"; }
41
42 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
43 if (name == runtime::symbol::tvm_module_main) {
44 return GetFunction(entry_func_, sptr_to_self);
45 }
46 auto it = fmap_.find(name);
47 if (it == fmap_.end()) return PackedFunc();
48 const StackVM& vm = it->second;
49 // capture sptr_to_self to keep module node alive.
50 return PackedFunc(
51 [vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { vm.Run(args, this); });
52 }
53
54 std::string GetSource(const std::string& format) final {
55 std::ostringstream os;
56 for (const auto& kv : fmap_) {
57 os << "Function: " << kv.first << '\n';
58 os << kv.second;
59 }
60 return os.str();
61 }
62
63 void SaveToFile(const std::string& file_name, const std::string& format) final {
64 std::string data, mblob;
65 dmlc::MemoryStringStream writer(&data);
66 dmlc::Stream* strm = &writer;
67 strm->Write(fmap_);
68 strm->Write(entry_func_);
69 // also save imports
70 uint64_t num_imports = static_cast<uint64_t>(imports_.size());
71 strm->Write(num_imports);
72
73 for (runtime::Module im : imports_) {
74 ICHECK_EQ(im->imports().size(), 0U) << "Only support simply one-level hierarchy";
75 std::string tkey = im->type_key();
76 strm->Write(tkey);
77 im->SaveToBinary(strm);
78 }
79 SaveBinaryToFile(file_name, data);
80 }
81
82 static Module Create(std::unordered_map<std::string, StackVM> fmap, std::string entry_func) {
83 auto n = make_object<StackVMModuleNode>();
84 n->fmap_ = std::move(fmap);
85 n->entry_func_ = std::move(entry_func);
86 return Module(n);
87 }
88
89 static Module Load(dmlc::Stream* strm) {
90 std::unordered_map<std::string, StackVM> fmap;
91 std::string entry_func, data;
92 strm->Read(&fmap);
93 strm->Read(&entry_func);
94 auto n = make_object<StackVMModuleNode>();
95 n->fmap_ = std::move(fmap);
96 n->entry_func_ = std::move(entry_func);
97 uint64_t num_imports;
98 strm->Read(&num_imports);
99 for (uint64_t i = 0; i < num_imports; ++i) {
100 std::string tkey;
101 ICHECK(strm->Read(&tkey));
102 std::string loadkey = "runtime.module.loadbinary_";
103 std::string fkey = loadkey + tkey;
104 const PackedFunc* f = Registry::Get(fkey);
105 if (f == nullptr) {
106 std::string loaders = "";
107 for (auto name : Registry::ListNames()) {
108 if (name.rfind(loadkey, 0) == 0) {
109 if (loaders.size() > 0) {
110 loaders += ", ";
111 }
112 loaders += name.substr(loadkey.size());
113 }
114 }
115 ICHECK(f != nullptr)
116 << "Binary was created using " << tkey
117 << " but a loader of that name is not registered. Available loaders are " << loaders
118 << ". Perhaps you need to recompile with this runtime enabled.";
119 }
120 Module m = (*f)(static_cast<void*>(strm));
121 n->imports_.emplace_back(std::move(m));
122 }
123 return Module(n);
124 }
125
126 static Module LoadFromFile(std::string file_name, std::string format) {
127 std::string data;
128 LoadBinaryFromFile(file_name, &data);
129 dmlc::MemoryStringStream reader(&data);
130 return Load(&reader);
131 }
132
133 private:
134 // internal function map
135 std::unordered_map<std::string, StackVM> fmap_;
136 // entry function.
137 std::string entry_func_;
138};
139
140Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap, std::string entry_func) {
141 return StackVMModuleNode::Create(fmap, entry_func);
142}
143
144TVM_REGISTER_GLOBAL("runtime.module.loadfile_stackvm")
145 .set_body_typed(StackVMModuleNode::LoadFromFile);
146
147} // namespace runtime
148} // namespace tvm
149