1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file file_utils.cc
22 */
23#include "file_utils.h"
24
25#include <dmlc/json.h>
26#include <dmlc/memory_io.h>
27#include <tvm/runtime/logging.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/runtime/serializer.h>
30
31#include <fstream>
32#include <unordered_map>
33#include <vector>
34
35namespace tvm {
36namespace runtime {
37
38void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
39 std::vector<std::string> sarg_types(arg_types.size());
40 for (size_t i = 0; i < arg_types.size(); ++i) {
41 sarg_types[i] = DLDataType2String(arg_types[i]);
42 }
43 writer->BeginObject();
44 writer->WriteObjectKeyValue("name", name);
45 writer->WriteObjectKeyValue("arg_types", sarg_types);
46 writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
47 writer->EndObject();
48}
49
50void FunctionInfo::Load(dmlc::JSONReader* reader) {
51 dmlc::JSONObjectReadHelper helper;
52 std::vector<std::string> sarg_types;
53 helper.DeclareField("name", &name);
54 helper.DeclareField("arg_types", &sarg_types);
55 helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
56 helper.DeclareOptionalField("thread_axis_tags",
57 &launch_param_tags); // for backward compatibility
58 helper.ReadAllFields(reader);
59 arg_types.resize(sarg_types.size());
60 for (size_t i = 0; i < arg_types.size(); ++i) {
61 arg_types[i] = String2DLDataType(sarg_types[i]);
62 }
63}
64
65void FunctionInfo::Save(dmlc::Stream* writer) const {
66 writer->Write(name);
67 writer->Write(arg_types);
68 writer->Write(launch_param_tags);
69}
70
71bool FunctionInfo::Load(dmlc::Stream* reader) {
72 if (!reader->Read(&name)) return false;
73 if (!reader->Read(&arg_types)) return false;
74 if (!reader->Read(&launch_param_tags)) return false;
75 return true;
76}
77
/*!
 * \brief Determine the format of a file.
 *
 * \param file_name The file path.
 * \param format An explicitly requested format. When non-empty it is returned unchanged.
 * \return `format` if it is non-empty. Otherwise, the text after the last '.' in the
 *         final path component of `file_name`, or "" if that component has no '.'.
 */
std::string GetFileFormat(const std::string& file_name, const std::string& format) {
  if (!format.empty()) return format;

  // Only the last path component may carry the extension. A dot inside a
  // directory name (e.g. "build.v1/lib") must not be mistaken for one.
  const size_t last_slash = file_name.find_last_of('/');
  const size_t dot = file_name.find_last_of('.');
  if (dot == std::string::npos || (last_slash != std::string::npos && dot < last_slash)) {
    return "";
  }
  return file_name.substr(dot + 1);
}
91
/*!
 * \brief Locate the TVM cache directory.
 *
 * Resolution order:
 *  1. $TVM_CACHE_DIR
 *  2. $XDG_CACHE_HOME/tvm
 *  3. $HOME/.cache/tvm
 *  4. "." (the current directory)
 *
 * A variable that is set but empty is treated as unset. This follows the XDG Base
 * Directory spec for XDG_CACHE_HOME and avoids paths like "/.cache/tvm" when HOME="".
 * \return The cache directory path (not created by this function).
 */
std::string GetCacheDir() {
  // Returns the variable's value when it is set and non-empty, else nullptr.
  auto non_empty_env = [](const char* var) -> const char* {
    const char* value = getenv(var);
    return (value != nullptr && value[0] != '\0') ? value : nullptr;
  };

  if (const char* dir = non_empty_env("TVM_CACHE_DIR")) return dir;
  if (const char* dir = non_empty_env("XDG_CACHE_HOME")) return std::string(dir) + "/tvm";
  if (const char* dir = non_empty_env("HOME")) return std::string(dir) + "/.cache/tvm";
  return ".";
}
103
/*!
 * \brief Strip any leading directory part from a '/'-separated path.
 * \param file_name The file path.
 * \return Everything after the last '/', or `file_name` itself if it contains no '/'.
 */
std::string GetFileBasename(const std::string& file_name) {
  const size_t slash = file_name.rfind('/');
  return slash == std::string::npos ? file_name : file_name.substr(slash + 1);
}
109
/*!
 * \brief Compute the path of the JSON metadata file that accompanies a module file.
 *
 * If the final path component has an extension, it is replaced by ".tvm_meta.json".
 * Otherwise ".tvm_meta.json" is appended. Dots in directory names are ignored, so
 * "out.v1/mod" maps to "out.v1/mod.tvm_meta.json".
 * \param file_name The module file path.
 * \return The metadata file path.
 */
std::string GetMetaFilePath(const std::string& file_name) {
  const size_t last_slash = file_name.find_last_of('/');
  const size_t dot = file_name.find_last_of('.');
  const bool has_extension =
      dot != std::string::npos && (last_slash == std::string::npos || dot > last_slash);
  return (has_extension ? file_name.substr(0, dot) : file_name) + ".tvm_meta.json";
}
118
119void LoadBinaryFromFile(const std::string& file_name, std::string* data) {
120 std::ifstream fs(file_name, std::ios::in | std::ios::binary);
121 ICHECK(!fs.fail()) << "Cannot open " << file_name;
122 // get its size:
123 fs.seekg(0, std::ios::end);
124 size_t size = static_cast<size_t>(fs.tellg());
125 fs.seekg(0, std::ios::beg);
126 data->resize(size);
127 fs.read(&(*data)[0], size);
128}
129
130void SaveBinaryToFile(const std::string& file_name, const std::string& data) {
131 std::ofstream fs(file_name, std::ios::out | std::ios::binary);
132 ICHECK(!fs.fail()) << "Cannot open " << file_name;
133 fs.write(&data[0], data.length());
134}
135
136void SaveMetaDataToFile(const std::string& file_name,
137 const std::unordered_map<std::string, FunctionInfo>& fmap) {
138 std::string version = "0.1.0";
139 std::ofstream fs(file_name.c_str());
140 ICHECK(!fs.fail()) << "Cannot open file " << file_name;
141 dmlc::JSONWriter writer(&fs);
142 writer.BeginObject();
143 writer.WriteObjectKeyValue("tvm_version", version);
144 writer.WriteObjectKeyValue("func_info", fmap);
145 writer.EndObject();
146 fs.close();
147}
148
149void LoadMetaDataFromFile(const std::string& file_name,
150 std::unordered_map<std::string, FunctionInfo>* fmap) {
151 std::ifstream fs(file_name.c_str());
152 ICHECK(!fs.fail()) << "Cannot open file " << file_name;
153 std::string version;
154 dmlc::JSONReader reader(&fs);
155 dmlc::JSONObjectReadHelper helper;
156 helper.DeclareField("tvm_version", &version);
157 helper.DeclareField("func_info", fmap);
158 helper.ReadAllFields(&reader);
159 fs.close();
160}
161
/*!
 * \brief Delete a file, best-effort.
 * \param file_name The file to delete.
 *
 * FIXME: the result of std::remove is discarded, so a failed deletion
 * (missing file, permission error, ...) is not reported to the caller.
 */
void RemoveFile(const std::string& file_name) {
  static_cast<void>(std::remove(file_name.c_str()));
}
166
167void CopyFile(const std::string& src_file_name, const std::string& dest_file_name) {
168 std::ifstream src(src_file_name, std::ios::binary);
169 ICHECK(src) << "Unable to open source file '" << src_file_name << "'";
170
171 std::ofstream dest(dest_file_name, std::ios::binary | std::ios::trunc);
172 ICHECK(dest) << "Unable to destination source file '" << src_file_name << "'";
173
174 dest << src.rdbuf();
175
176 src.close();
177 dest.close();
178
179 ICHECK(dest) << "File-copy operation failed."
180 << " src='" << src_file_name << "'"
181 << " dest='" << dest_file_name << "'";
182}
183
184Map<String, NDArray> LoadParams(const std::string& param_blob) {
185 dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
186 return LoadParams(&strm);
187}
188Map<String, NDArray> LoadParams(dmlc::Stream* strm) {
189 Map<String, NDArray> params;
190 uint64_t header, reserved;
191 ICHECK(strm->Read(&header)) << "Invalid parameters file format";
192 ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
193 ICHECK(strm->Read(&reserved)) << "Invalid parameters file format";
194
195 std::vector<std::string> names;
196 ICHECK(strm->Read(&names)) << "Invalid parameters file format";
197 uint64_t sz;
198 strm->Read(&sz);
199 size_t size = static_cast<size_t>(sz);
200 ICHECK(size == names.size()) << "Invalid parameters file format";
201 for (size_t i = 0; i < size; ++i) {
202 // The data_entry is allocated on device, NDArray.load always load the array into CPU.
203 NDArray temp;
204 temp.Load(strm);
205 params.Set(names[i], temp);
206 }
207 return params;
208}
209
210void SaveParams(dmlc::Stream* strm, const Map<String, NDArray>& params) {
211 std::vector<std::string> names;
212 std::vector<const DLTensor*> arrays;
213 for (auto& p : params) {
214 names.push_back(p.first);
215 arrays.push_back(p.second.operator->());
216 }
217
218 uint64_t header = kTVMNDArrayListMagic, reserved = 0;
219 strm->Write(header);
220 strm->Write(reserved);
221 strm->Write(names);
222 {
223 uint64_t sz = static_cast<uint64_t>(arrays.size());
224 strm->Write(sz);
225 for (size_t i = 0; i < sz; ++i) {
226 tvm::runtime::SaveDLTensor(strm, arrays[i]);
227 }
228 }
229}
230
231std::string SaveParams(const Map<String, NDArray>& params) {
232 std::string bytes;
233 dmlc::MemoryStringStream strm(&bytes);
234 dmlc::Stream* fo = &strm;
235 SaveParams(fo, params);
236 return bytes;
237}
238
239TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map<String, NDArray>& params) {
240 std::string s = ::tvm::runtime::SaveParams(params);
241 // copy return array so it is owned by the ret value
242 TVMRetValue rv;
243 rv = TVMByteArray{s.data(), s.size()};
244 return rv;
245});
246TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) {
247 return ::tvm::runtime::LoadParams(s);
248});
249
250} // namespace runtime
251} // namespace tvm
252