1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "glow/Base/TensorSerialization.h"
#ifdef WITH_PNG
#include "glow/Base/Image.h"
#endif
#include "glow/Graph/Graph.h"

#include "llvm/Support/CommandLine.h"

#include <fstream>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
25
26using namespace glow;
27
28namespace glow {
29
30/// Helper method to dump the tensor content into a text file.
31template <class ElemTy>
32static void dumpTensorToTextFileImpl(Tensor &tensor, llvm::StringRef filename,
33 std::ofstream &fs) {
34 Handle<ElemTy> handle = tensor.getHandle<ElemTy>();
35 for (dim_t idx = 0, e = handle.actualSize(); idx < e; idx++) {
36 fs << (double)handle.raw(idx) << ", ";
37 }
38}
39
40/// Helper method to load the tensor content from a text file.
41template <class ElemTy>
42static void loadTensorFromTextFileImpl(Tensor &tensor, llvm::StringRef filename,
43 std::ifstream &fs) {
44 Handle<ElemTy> handle = tensor.getHandle<ElemTy>();
45 char ch;
46 double val;
47 for (dim_t idx = 0, e = handle.actualSize(); idx < e; idx++) {
48 // Load tensor value.
49 CHECK(fs >> val) << "Error loading text file '" << filename.data()
50 << "'! Only " << idx
51 << " values were given for loading a tensor "
52 << "with " << e << " elements!";
53 handle.raw(idx) = val;
54 // Check delimiter.
55 CHECK(fs >> ch) << "Error loading text file '" << filename.data()
56 << "'! Delimiter character ',' not found!";
57 if (idx < e - 1) {
58 CHECK(ch == ',')
59 << "Error loading text file '" << filename.data()
60 << "'! Delimiter character is expected to be ',' but character '"
61 << ch << "' was found!";
62 }
63 }
64 CHECK(!(fs >> val)) << "Error loading text file '" << filename.data()
65 << "'! Too many values given for loading a tensor with "
66 << handle.actualSize() << " elements!";
67}
68
69#ifdef WITH_PNG
70
71/// Helper method to load tensor files into the model input tensor.
72static void loadTensorFromFileWithType(Tensor &T, llvm::StringRef filename,
73 ImageLayout imageLayout) {
74 std::ifstream infile(filename.str().c_str());
75 CHECK(infile.is_open()) << "Error opening file '" << filename.data() << "'!";
76 std::string line;
77 ShapeVector dims;
78
79 CHECK(std::getline(infile, line)) << "Failed to read 1st line";
80 std::stringstream ss(line);
81 for (dim_t i = 0; i < 4; i++) {
82 int val;
83 CHECK(ss >> val) << "Failed to read dimension " << i;
84 dims.push_back(val);
85 }
86 T.reset(ElemKind::FloatTy, dims);
87 // Now read the tensor.
88 CHECK(std::getline(infile, line)) << "Failed to read 2nd line";
89 auto H = T.getHandle<>();
90 std::stringstream ss2(line);
91 for (dim_t i = 0, e = H.size(); i < e; i++) {
92 float val;
93 CHECK(ss2 >> val) << "Error loading file " << filename.data()
94 << " @ element " << i;
95 H.raw(i) = val;
96 }
97 // Convert to requested layout (tensor blob is in NCHW by default).
98 if (imageLayout == ImageLayout::NHWC) {
99 Tensor transposed;
100 T.transpose(&transposed, NCHW2NHWC);
101 T = std::move(transposed);
102 }
103}
104
105/// Set default tensor loader.
106static InputTensorFileLoaderFn inputTensorFileLoader_ =
107 loadTensorFromFileWithType;
108
109#endif // WITH_PNG
110
111} // namespace glow
112
113#ifdef WITH_PNG
114
115void glow::registerInputTensorFileLoader(InputTensorFileLoaderFn loader) {
116 inputTensorFileLoader_ = loader;
117}
118
119#endif // WITH_PNG
120
121void glow::dumpTensorToBinaryFile(const Tensor &tensor,
122 llvm::StringRef filename,
123 const TensorSerializationOptions &opts) {
124 std::ofstream fs;
125 fs.open(filename.data(), std::ios::binary);
126 CHECK(fs.is_open()) << "Error opening file '" << filename.data() << "'!";
127 dumpTensorToBinaryFile(tensor, fs, opts);
128 fs.close();
129}
130
131void glow::dumpTensorToBinaryFile(const Tensor &tensor, std::ofstream &fs,
132 const TensorSerializationOptions &opts) {
133 CHECK(tensor.getUnsafePtr())
134 << "Tensor not initialized before dumping to binary file!";
135 // Dump tensor type.
136 if (opts.withType) {
137 std::string typeStr = tensor.getType().toString();
138 fs.write(typeStr.c_str(), typeStr.size());
139 }
140 // Dump tensor data.
141 fs.write(tensor.getUnsafePtr(), tensor.getSizeInBytes());
142}
143
144void glow::loadTensorFromBinaryFile(Tensor &tensor, llvm::StringRef filename,
145 const TensorSerializationOptions &opts) {
146 std::ifstream fs;
147 fs.open(filename.data(), std::ios::binary);
148 CHECK(fs.is_open()) << "Error opening file '" << filename.data() << "'!";
149 // Load tensor type.
150 size_t headerSize = 0;
151 if (opts.withType) {
152 std::string typeStr;
153 char ch;
154 do {
155 CHECK(fs.read(&ch, 1))
156 << "Error loading binary file '" << filename.data()
157 << "'! Tensor type delimiter character '>' not found!";
158 typeStr += ch;
159 } while (ch != '>');
160 tensor.reset(Type::fromString(typeStr));
161 headerSize = typeStr.size();
162 } else {
163 CHECK(tensor.getUnsafePtr())
164 << "Tensor not initialized before loading from raw binary file!";
165 }
166 // Verify file data size matches tensor size in bytes.
167 size_t tensorSize = tensor.getSizeInBytes();
168 fs.seekg(0, std::ios::end);
169 size_t fileDataSize = size_t(fs.tellg()) - headerSize;
170 CHECK(fileDataSize == tensorSize)
171 << "Error loading binary file '" << filename.data()
172 << "' with header size " << headerSize << " bytes and data size "
173 << fileDataSize << " bytes into "
174 << "tensor with size " << tensorSize << " bytes!";
175
176 // Load tensor data.
177 fs.seekg(headerSize, std::ios::beg);
178 fs.read(tensor.getUnsafePtr(), tensorSize);
179 fs.close();
180}
181
182void glow::dumpTensorToTextFile(Tensor &tensor, llvm::StringRef filename,
183 const TensorSerializationOptions &opts) {
184 std::ofstream fs;
185 fs.open(filename.data());
186 CHECK(fs.is_open()) << "Error opening file '" << filename.data() << "'!";
187 CHECK(tensor.getUnsafePtr())
188 << "Tensor not initialized before dumping to text file!";
189 // Dump tensor type.
190 if (opts.withType) {
191 fs << tensor.getType().toString() << "\n";
192 }
193 // Dump tensor data.
194 switch (tensor.getElementType()) {
195 case ElemKind::FloatTy:
196 return dumpTensorToTextFileImpl<float>(tensor, filename, fs);
197 case ElemKind::Float16Ty:
198 return dumpTensorToTextFileImpl<float16_t>(tensor, filename, fs);
199 case ElemKind::BFloat16Ty:
200 return dumpTensorToTextFileImpl<bfloat16_t>(tensor, filename, fs);
201 case ElemKind::Int8QTy:
202 return dumpTensorToTextFileImpl<int8_t>(tensor, filename, fs);
203 case ElemKind::UInt8QTy:
204 return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs);
205 case ElemKind::Int16QTy:
206 return dumpTensorToTextFileImpl<int16_t>(tensor, filename, fs);
207 case ElemKind::Int32QTy:
208 return dumpTensorToTextFileImpl<int32_t>(tensor, filename, fs);
209 case ElemKind::Int32ITy:
210 return dumpTensorToTextFileImpl<int32_t>(tensor, filename, fs);
211 case ElemKind::Int64ITy:
212 return dumpTensorToTextFileImpl<int64_t>(tensor, filename, fs);
213 case ElemKind::UInt8FusedQTy:
214 return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs);
215 case ElemKind::UInt8FusedFP16QTy:
216 return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs);
217 case ElemKind::UInt4FusedFP16QTy:
218 return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs);
219 case ElemKind::UInt4FusedQTy:
220 return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs);
221 case ElemKind::BoolTy:
222 return dumpTensorToTextFileImpl<bool>(tensor, filename, fs);
223 default:
224 llvm_unreachable("Tensor type not supported for dumping to text file!");
225 }
226 fs.close();
227}
228
229void glow::loadTensorFromTextFile(Tensor &tensor, llvm::StringRef filename,
230 const TensorSerializationOptions &opts) {
231 std::ifstream fs;
232 fs.open(filename.data());
233 CHECK(fs.is_open()) << "Error opening file '" << filename.data() << "'!";
234 // Load tensor type.
235 if (opts.withType) {
236 std::string typeStr;
237 CHECK(std::getline(fs, typeStr))
238 << "Error loading text file '" << filename.data()
239 << "'! Tensor type not found!";
240 tensor.reset(Type::fromString(typeStr));
241 } else {
242 CHECK(tensor.getUnsafePtr())
243 << "Tensor not initialized before loading from raw text file!";
244 }
245 // Load tensor data.
246 switch (tensor.getElementType()) {
247 case ElemKind::FloatTy:
248 return loadTensorFromTextFileImpl<float>(tensor, filename, fs);
249 case ElemKind::Float16Ty:
250 return loadTensorFromTextFileImpl<float16_t>(tensor, filename, fs);
251 case ElemKind::BFloat16Ty:
252 return loadTensorFromTextFileImpl<bfloat16_t>(tensor, filename, fs);
253 case ElemKind::Int8QTy:
254 return loadTensorFromTextFileImpl<int8_t>(tensor, filename, fs);
255 case ElemKind::UInt8QTy:
256 return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs);
257 case ElemKind::Int16QTy:
258 return loadTensorFromTextFileImpl<int16_t>(tensor, filename, fs);
259 case ElemKind::Int32QTy:
260 return loadTensorFromTextFileImpl<int32_t>(tensor, filename, fs);
261 case ElemKind::Int32ITy:
262 return loadTensorFromTextFileImpl<int32_t>(tensor, filename, fs);
263 case ElemKind::Int64ITy:
264 return loadTensorFromTextFileImpl<int64_t>(tensor, filename, fs);
265 case ElemKind::UInt8FusedQTy:
266 return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs);
267 case ElemKind::UInt8FusedFP16QTy:
268 return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs);
269 case ElemKind::UInt4FusedFP16QTy:
270 return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs);
271 case ElemKind::UInt4FusedQTy:
272 return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs);
273 case ElemKind::BoolTy:
274 return loadTensorFromTextFileImpl<bool>(tensor, filename, fs);
275 default:
276 llvm_unreachable("Tensor type not supported for loading from text file!");
277 }
278 fs.close();
279}
280
281#ifdef WITH_PNG
282
283void glow::loadInputImageFromFileWithType(
284 const llvm::ArrayRef<std::string> &filenames, Tensor *inputData,
285 ImageLayout imageLayout) {
286 DCHECK(!filenames.empty())
287 << "There must be at least one filename in filenames.";
288 assert((dim_t)filenames.size() == filenames.size());
289 dim_t numImages = filenames.size();
290
291 CHECK(inputTensorFileLoader_) << "tensor loader not assigned!";
292
293 // Read each tensor file into a vector of tensors.
294 std::vector<Tensor> data(numImages);
295 dim_t batchSize = 0;
296 for (dim_t n = 0; n < numImages; n++) {
297 inputTensorFileLoader_(data[n], filenames[n], imageLayout);
298 auto dims0 = data[0].dims();
299 auto dims = data[n].dims();
300 CHECK_EQ(dims0[1], dims[1]) << "Non batch dimensions must match";
301 CHECK_EQ(dims0[2], dims[2]) << "Non batch dimensions must match";
302 CHECK_EQ(dims0[3], dims[3]) << "Non batch dimensions must match";
303 batchSize += data[n].dims()[0];
304 }
305
306 // Input tensor size is known now.
307 inputData->reset(ElemKind::FloatTy, {batchSize, data[0].dims()[1],
308 data[0].dims()[2], data[0].dims()[3]});
309 auto IIDH = inputData->getHandle<>();
310 // Insert each loaded file (in data[] tensors) as the input tensor slices.
311 for (dim_t n = 0, e = data.size(); n < e; n++) {
312 Handle<float> H = data[n].getHandle<>();
313 IIDH.insertTensors(H, {n, 0, 0, 0});
314 }
315}
316
317/// Helper function for loadInputTensorFromFileWithType, to produce blob files.
318void glow::dumpInputTensorToFileWithType(
319 const llvm::ArrayRef<std::string> &filenames, const Tensor &T,
320 ImageLayout imageLayout) {
321 CHECK_EQ(filenames.size(), 1) << "Dumping support single file only";
322 const std::string &filename = filenames[0];
323 Tensor localTensor = T.clone();
324 // Convert to requested layout (tensor blob is in NCHW by default).
325 if (imageLayout == ImageLayout::NHWC) {
326 Tensor transposed;
327 localTensor.transpose(&transposed, NHWC2NCHW);
328 localTensor = std::move(transposed);
329 }
330 std::ofstream outfile(filename.c_str());
331 CHECK(outfile.is_open()) << "Error opening file '" << filename << "'!";
332 // write dimensions to 1st line.
333 for (dim_t i = 0; i < 4; i++) {
334 CHECK(outfile << localTensor.dims()[i] << " ")
335 << "Failed to write dimension " << i;
336 }
337 outfile << "\n";
338 // write tensor to 2nd line.
339 auto H = localTensor.getHandle<float>();
340 for (auto e : H) {
341 outfile << e << " ";
342 }
343}
344
345#endif // WITH_PNG
346