1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file stackvm_module.cc |
22 | */ |
23 | #include "stackvm_module.h" |
24 | |
25 | #include <dmlc/memory_io.h> |
26 | #include <tvm/runtime/module.h> |
27 | #include <tvm/runtime/registry.h> |
28 | |
29 | #include <memory> |
30 | #include <unordered_map> |
31 | #include <utility> |
32 | |
33 | #include "../file_utils.h" |
34 | |
35 | namespace tvm { |
36 | namespace runtime { |
37 | |
38 | class StackVMModuleNode : public runtime::ModuleNode { |
39 | public: |
40 | const char* type_key() const final { return "stackvm" ; } |
41 | |
42 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
43 | if (name == runtime::symbol::tvm_module_main) { |
44 | return GetFunction(entry_func_, sptr_to_self); |
45 | } |
46 | auto it = fmap_.find(name); |
47 | if (it == fmap_.end()) return PackedFunc(); |
48 | const StackVM& vm = it->second; |
49 | // capture sptr_to_self to keep module node alive. |
50 | return PackedFunc( |
51 | [vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { vm.Run(args, this); }); |
52 | } |
53 | |
54 | std::string GetSource(const std::string& format) final { |
55 | std::ostringstream os; |
56 | for (const auto& kv : fmap_) { |
57 | os << "Function: " << kv.first << '\n'; |
58 | os << kv.second; |
59 | } |
60 | return os.str(); |
61 | } |
62 | |
63 | void SaveToFile(const std::string& file_name, const std::string& format) final { |
64 | std::string data, mblob; |
65 | dmlc::MemoryStringStream writer(&data); |
66 | dmlc::Stream* strm = &writer; |
67 | strm->Write(fmap_); |
68 | strm->Write(entry_func_); |
69 | // also save imports |
70 | uint64_t num_imports = static_cast<uint64_t>(imports_.size()); |
71 | strm->Write(num_imports); |
72 | |
73 | for (runtime::Module im : imports_) { |
74 | ICHECK_EQ(im->imports().size(), 0U) << "Only support simply one-level hierarchy" ; |
75 | std::string tkey = im->type_key(); |
76 | strm->Write(tkey); |
77 | im->SaveToBinary(strm); |
78 | } |
79 | SaveBinaryToFile(file_name, data); |
80 | } |
81 | |
82 | static Module Create(std::unordered_map<std::string, StackVM> fmap, std::string entry_func) { |
83 | auto n = make_object<StackVMModuleNode>(); |
84 | n->fmap_ = std::move(fmap); |
85 | n->entry_func_ = std::move(entry_func); |
86 | return Module(n); |
87 | } |
88 | |
89 | static Module Load(dmlc::Stream* strm) { |
90 | std::unordered_map<std::string, StackVM> fmap; |
91 | std::string entry_func, data; |
92 | strm->Read(&fmap); |
93 | strm->Read(&entry_func); |
94 | auto n = make_object<StackVMModuleNode>(); |
95 | n->fmap_ = std::move(fmap); |
96 | n->entry_func_ = std::move(entry_func); |
97 | uint64_t num_imports; |
98 | strm->Read(&num_imports); |
99 | for (uint64_t i = 0; i < num_imports; ++i) { |
100 | std::string tkey; |
101 | ICHECK(strm->Read(&tkey)); |
102 | std::string loadkey = "runtime.module.loadbinary_" ; |
103 | std::string fkey = loadkey + tkey; |
104 | const PackedFunc* f = Registry::Get(fkey); |
105 | if (f == nullptr) { |
106 | std::string loaders = "" ; |
107 | for (auto name : Registry::ListNames()) { |
108 | if (name.rfind(loadkey, 0) == 0) { |
109 | if (loaders.size() > 0) { |
110 | loaders += ", " ; |
111 | } |
112 | loaders += name.substr(loadkey.size()); |
113 | } |
114 | } |
115 | ICHECK(f != nullptr) |
116 | << "Binary was created using " << tkey |
117 | << " but a loader of that name is not registered. Available loaders are " << loaders |
118 | << ". Perhaps you need to recompile with this runtime enabled." ; |
119 | } |
120 | Module m = (*f)(static_cast<void*>(strm)); |
121 | n->imports_.emplace_back(std::move(m)); |
122 | } |
123 | return Module(n); |
124 | } |
125 | |
126 | static Module LoadFromFile(std::string file_name, std::string format) { |
127 | std::string data; |
128 | LoadBinaryFromFile(file_name, &data); |
129 | dmlc::MemoryStringStream reader(&data); |
130 | return Load(&reader); |
131 | } |
132 | |
133 | private: |
134 | // internal function map |
135 | std::unordered_map<std::string, StackVM> fmap_; |
136 | // entry function. |
137 | std::string entry_func_; |
138 | }; |
139 | |
140 | Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap, std::string entry_func) { |
141 | return StackVMModuleNode::Create(fmap, entry_func); |
142 | } |
143 | |
144 | TVM_REGISTER_GLOBAL("runtime.module.loadfile_stackvm" ) |
145 | .set_body_typed(StackVMModuleNode::LoadFromFile); |
146 | |
147 | } // namespace runtime |
148 | } // namespace tvm |
149 | |