1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file file_utils.cc |
22 | */ |
23 | #include "file_utils.h" |
24 | |
25 | #include <dmlc/json.h> |
26 | #include <dmlc/memory_io.h> |
27 | #include <tvm/runtime/logging.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/runtime/serializer.h> |
30 | |
31 | #include <fstream> |
32 | #include <unordered_map> |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | namespace runtime { |
37 | |
38 | void FunctionInfo::Save(dmlc::JSONWriter* writer) const { |
39 | std::vector<std::string> sarg_types(arg_types.size()); |
40 | for (size_t i = 0; i < arg_types.size(); ++i) { |
41 | sarg_types[i] = DLDataType2String(arg_types[i]); |
42 | } |
43 | writer->BeginObject(); |
44 | writer->WriteObjectKeyValue("name" , name); |
45 | writer->WriteObjectKeyValue("arg_types" , sarg_types); |
46 | writer->WriteObjectKeyValue("launch_param_tags" , launch_param_tags); |
47 | writer->EndObject(); |
48 | } |
49 | |
50 | void FunctionInfo::Load(dmlc::JSONReader* reader) { |
51 | dmlc::JSONObjectReadHelper helper; |
52 | std::vector<std::string> sarg_types; |
53 | helper.DeclareField("name" , &name); |
54 | helper.DeclareField("arg_types" , &sarg_types); |
55 | helper.DeclareOptionalField("launch_param_tags" , &launch_param_tags); |
56 | helper.DeclareOptionalField("thread_axis_tags" , |
57 | &launch_param_tags); // for backward compatibility |
58 | helper.ReadAllFields(reader); |
59 | arg_types.resize(sarg_types.size()); |
60 | for (size_t i = 0; i < arg_types.size(); ++i) { |
61 | arg_types[i] = String2DLDataType(sarg_types[i]); |
62 | } |
63 | } |
64 | |
65 | void FunctionInfo::Save(dmlc::Stream* writer) const { |
66 | writer->Write(name); |
67 | writer->Write(arg_types); |
68 | writer->Write(launch_param_tags); |
69 | } |
70 | |
71 | bool FunctionInfo::Load(dmlc::Stream* reader) { |
72 | if (!reader->Read(&name)) return false; |
73 | if (!reader->Read(&arg_types)) return false; |
74 | if (!reader->Read(&launch_param_tags)) return false; |
75 | return true; |
76 | } |
77 | |
/*!
 * \brief Determine the format of a file.
 *
 * An explicitly supplied format always wins. Otherwise the format is the
 * text after the last '.' of the final path component.
 * \param file_name The file name or path.
 * \param format The requested format; may be empty.
 * \return \p format when it is non-empty, else the extension without the
 *         dot, or an empty string when the final component has no '.'.
 */
std::string GetFileFormat(const std::string& file_name, const std::string& format) {
  if (!format.empty()) return format;

  const size_t dot = file_name.find_last_of('.');
  if (dot == std::string::npos) return "";

  // A dot that sits before the last '/' belongs to a directory name
  // (e.g. "./model" or "build.v1/model"), not to an extension.
  const size_t slash = file_name.find_last_of('/');
  if (slash != std::string::npos && slash > dot) return "";

  return file_name.substr(dot + 1);
}
91 | |
/*!
 * \brief Pick the cache directory from the environment.
 *
 * Checked in order: TVM_CACHE_DIR (used as is), XDG_CACHE_HOME (with "/tvm"
 * appended), HOME (with "/.cache/tvm" appended).
 * \return The first match, or "." when none of the variables is set.
 */
std::string GetCacheDir() {
  if (const char* tvm_dir = getenv("TVM_CACHE_DIR")) {
    return tvm_dir;
  }
  if (const char* xdg_dir = getenv("XDG_CACHE_HOME")) {
    return std::string(xdg_dir) + "/tvm";
  }
  if (const char* home_dir = getenv("HOME")) {
    return std::string(home_dir) + "/.cache/tvm";
  }
  return ".";
}
103 | |
/*!
 * \brief Strip the directory part from a path.
 * \param file_name The file name or '/'-separated path.
 * \return Everything after the last '/', or the whole input when it
 *         contains no '/'.
 */
std::string GetFileBasename(const std::string& file_name) {
  const size_t slash = file_name.rfind('/');
  return slash == std::string::npos ? file_name : file_name.substr(slash + 1);
}
109 | |
/*!
 * \brief Derive the metadata file path that belongs to a file.
 *
 * The extension of the final path component (the text from its last '.')
 * is replaced by ".tvm_meta.json"; when that component has no extension the
 * suffix is appended instead.
 * \param file_name The file name or path.
 * \return The path of the accompanying metadata file.
 */
std::string GetMetaFilePath(const std::string& file_name) {
  const size_t dot = file_name.find_last_of('.');
  const size_t slash = file_name.find_last_of('/');

  // Only a dot after the last '/' marks an extension; a dot in a directory
  // name (e.g. "./model" or "build.v1/model") must not truncate the path.
  const bool has_extension =
      dot != std::string::npos && (slash == std::string::npos || dot > slash);

  const std::string stem = has_extension ? file_name.substr(0, dot) : file_name;
  return stem + ".tvm_meta.json";
}
118 | |
119 | void LoadBinaryFromFile(const std::string& file_name, std::string* data) { |
120 | std::ifstream fs(file_name, std::ios::in | std::ios::binary); |
121 | ICHECK(!fs.fail()) << "Cannot open " << file_name; |
122 | // get its size: |
123 | fs.seekg(0, std::ios::end); |
124 | size_t size = static_cast<size_t>(fs.tellg()); |
125 | fs.seekg(0, std::ios::beg); |
126 | data->resize(size); |
127 | fs.read(&(*data)[0], size); |
128 | } |
129 | |
130 | void SaveBinaryToFile(const std::string& file_name, const std::string& data) { |
131 | std::ofstream fs(file_name, std::ios::out | std::ios::binary); |
132 | ICHECK(!fs.fail()) << "Cannot open " << file_name; |
133 | fs.write(&data[0], data.length()); |
134 | } |
135 | |
136 | void SaveMetaDataToFile(const std::string& file_name, |
137 | const std::unordered_map<std::string, FunctionInfo>& fmap) { |
138 | std::string version = "0.1.0" ; |
139 | std::ofstream fs(file_name.c_str()); |
140 | ICHECK(!fs.fail()) << "Cannot open file " << file_name; |
141 | dmlc::JSONWriter writer(&fs); |
142 | writer.BeginObject(); |
143 | writer.WriteObjectKeyValue("tvm_version" , version); |
144 | writer.WriteObjectKeyValue("func_info" , fmap); |
145 | writer.EndObject(); |
146 | fs.close(); |
147 | } |
148 | |
149 | void LoadMetaDataFromFile(const std::string& file_name, |
150 | std::unordered_map<std::string, FunctionInfo>* fmap) { |
151 | std::ifstream fs(file_name.c_str()); |
152 | ICHECK(!fs.fail()) << "Cannot open file " << file_name; |
153 | std::string version; |
154 | dmlc::JSONReader reader(&fs); |
155 | dmlc::JSONObjectReadHelper helper; |
156 | helper.DeclareField("tvm_version" , &version); |
157 | helper.DeclareField("func_info" , fmap); |
158 | helper.ReadAllFields(&reader); |
159 | fs.close(); |
160 | } |
161 | |
/*!
 * \brief Delete a file on a best-effort basis.
 * \param file_name The file to delete.
 */
void RemoveFile(const std::string& file_name) {
  // FIXME: This doesn't check the return code; a failed removal is silently ignored.
  static_cast<void>(std::remove(file_name.c_str()));
}
166 | |
167 | void CopyFile(const std::string& src_file_name, const std::string& dest_file_name) { |
168 | std::ifstream src(src_file_name, std::ios::binary); |
169 | ICHECK(src) << "Unable to open source file '" << src_file_name << "'" ; |
170 | |
171 | std::ofstream dest(dest_file_name, std::ios::binary | std::ios::trunc); |
172 | ICHECK(dest) << "Unable to destination source file '" << src_file_name << "'" ; |
173 | |
174 | dest << src.rdbuf(); |
175 | |
176 | src.close(); |
177 | dest.close(); |
178 | |
179 | ICHECK(dest) << "File-copy operation failed." |
180 | << " src='" << src_file_name << "'" |
181 | << " dest='" << dest_file_name << "'" ; |
182 | } |
183 | |
184 | Map<String, NDArray> LoadParams(const std::string& param_blob) { |
185 | dmlc::MemoryStringStream strm(const_cast<std::string*>(¶m_blob)); |
186 | return LoadParams(&strm); |
187 | } |
188 | Map<String, NDArray> LoadParams(dmlc::Stream* strm) { |
189 | Map<String, NDArray> params; |
190 | uint64_t , reserved; |
191 | ICHECK(strm->Read(&header)) << "Invalid parameters file format" ; |
192 | ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format" ; |
193 | ICHECK(strm->Read(&reserved)) << "Invalid parameters file format" ; |
194 | |
195 | std::vector<std::string> names; |
196 | ICHECK(strm->Read(&names)) << "Invalid parameters file format" ; |
197 | uint64_t sz; |
198 | strm->Read(&sz); |
199 | size_t size = static_cast<size_t>(sz); |
200 | ICHECK(size == names.size()) << "Invalid parameters file format" ; |
201 | for (size_t i = 0; i < size; ++i) { |
202 | // The data_entry is allocated on device, NDArray.load always load the array into CPU. |
203 | NDArray temp; |
204 | temp.Load(strm); |
205 | params.Set(names[i], temp); |
206 | } |
207 | return params; |
208 | } |
209 | |
210 | void SaveParams(dmlc::Stream* strm, const Map<String, NDArray>& params) { |
211 | std::vector<std::string> names; |
212 | std::vector<const DLTensor*> arrays; |
213 | for (auto& p : params) { |
214 | names.push_back(p.first); |
215 | arrays.push_back(p.second.operator->()); |
216 | } |
217 | |
218 | uint64_t = kTVMNDArrayListMagic, reserved = 0; |
219 | strm->Write(header); |
220 | strm->Write(reserved); |
221 | strm->Write(names); |
222 | { |
223 | uint64_t sz = static_cast<uint64_t>(arrays.size()); |
224 | strm->Write(sz); |
225 | for (size_t i = 0; i < sz; ++i) { |
226 | tvm::runtime::SaveDLTensor(strm, arrays[i]); |
227 | } |
228 | } |
229 | } |
230 | |
231 | std::string SaveParams(const Map<String, NDArray>& params) { |
232 | std::string bytes; |
233 | dmlc::MemoryStringStream strm(&bytes); |
234 | dmlc::Stream* fo = &strm; |
235 | SaveParams(fo, params); |
236 | return bytes; |
237 | } |
238 | |
239 | TVM_REGISTER_GLOBAL("runtime.SaveParams" ).set_body_typed([](const Map<String, NDArray>& params) { |
240 | std::string s = ::tvm::runtime::SaveParams(params); |
241 | // copy return array so it is owned by the ret value |
242 | TVMRetValue rv; |
243 | rv = TVMByteArray{s.data(), s.size()}; |
244 | return rv; |
245 | }); |
246 | TVM_REGISTER_GLOBAL("runtime.LoadParams" ).set_body_typed([](const String& s) { |
247 | return ::tvm::runtime::LoadParams(s); |
248 | }); |
249 | |
250 | } // namespace runtime |
251 | } // namespace tvm |
252 | |