1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Base/TensorSerialization.h" |
18 | #ifdef WITH_PNG |
19 | #include "glow/Base/Image.h" |
20 | #endif |
21 | #include "glow/Graph/Graph.h" |
22 | |
23 | #include "llvm/Support/CommandLine.h" |
24 | #include <fstream> |
25 | |
26 | using namespace glow; |
27 | |
28 | namespace glow { |
29 | |
30 | /// Helper method to dump the tensor content into a text file. |
31 | template <class ElemTy> |
32 | static void dumpTensorToTextFileImpl(Tensor &tensor, llvm::StringRef filename, |
33 | std::ofstream &fs) { |
34 | Handle<ElemTy> handle = tensor.getHandle<ElemTy>(); |
35 | for (dim_t idx = 0, e = handle.actualSize(); idx < e; idx++) { |
36 | fs << (double)handle.raw(idx) << ", " ; |
37 | } |
38 | } |
39 | |
40 | /// Helper method to load the tensor content from a text file. |
41 | template <class ElemTy> |
42 | static void loadTensorFromTextFileImpl(Tensor &tensor, llvm::StringRef filename, |
43 | std::ifstream &fs) { |
44 | Handle<ElemTy> handle = tensor.getHandle<ElemTy>(); |
45 | char ch; |
46 | double val; |
47 | for (dim_t idx = 0, e = handle.actualSize(); idx < e; idx++) { |
48 | // Load tensor value. |
49 | CHECK(fs >> val) << "Error loading text file '" << filename.data() |
50 | << "'! Only " << idx |
51 | << " values were given for loading a tensor " |
52 | << "with " << e << " elements!" ; |
53 | handle.raw(idx) = val; |
54 | // Check delimiter. |
55 | CHECK(fs >> ch) << "Error loading text file '" << filename.data() |
56 | << "'! Delimiter character ',' not found!" ; |
57 | if (idx < e - 1) { |
58 | CHECK(ch == ',') |
59 | << "Error loading text file '" << filename.data() |
60 | << "'! Delimiter character is expected to be ',' but character '" |
61 | << ch << "' was found!" ; |
62 | } |
63 | } |
64 | CHECK(!(fs >> val)) << "Error loading text file '" << filename.data() |
65 | << "'! Too many values given for loading a tensor with " |
66 | << handle.actualSize() << " elements!" ; |
67 | } |
68 | |
69 | #ifdef WITH_PNG |
70 | |
71 | /// Helper method to load tensor files into the model input tensor. |
72 | static void loadTensorFromFileWithType(Tensor &T, llvm::StringRef filename, |
73 | ImageLayout imageLayout) { |
74 | std::ifstream infile(filename.str().c_str()); |
75 | CHECK(infile.is_open()) << "Error opening file '" << filename.data() << "'!" ; |
76 | std::string line; |
77 | ShapeVector dims; |
78 | |
79 | CHECK(std::getline(infile, line)) << "Failed to read 1st line" ; |
80 | std::stringstream ss(line); |
81 | for (dim_t i = 0; i < 4; i++) { |
82 | int val; |
83 | CHECK(ss >> val) << "Failed to read dimension " << i; |
84 | dims.push_back(val); |
85 | } |
86 | T.reset(ElemKind::FloatTy, dims); |
87 | // Now read the tensor. |
88 | CHECK(std::getline(infile, line)) << "Failed to read 2nd line" ; |
89 | auto H = T.getHandle<>(); |
90 | std::stringstream ss2(line); |
91 | for (dim_t i = 0, e = H.size(); i < e; i++) { |
92 | float val; |
93 | CHECK(ss2 >> val) << "Error loading file " << filename.data() |
94 | << " @ element " << i; |
95 | H.raw(i) = val; |
96 | } |
97 | // Convert to requested layout (tensor blob is in NCHW by default). |
98 | if (imageLayout == ImageLayout::NHWC) { |
99 | Tensor transposed; |
100 | T.transpose(&transposed, NCHW2NHWC); |
101 | T = std::move(transposed); |
102 | } |
103 | } |
104 | |
105 | /// Set default tensor loader. |
106 | static InputTensorFileLoaderFn inputTensorFileLoader_ = |
107 | loadTensorFromFileWithType; |
108 | |
109 | #endif // WITH_PNG |
110 | |
111 | } // namespace glow |
112 | |
113 | #ifdef WITH_PNG |
114 | |
115 | void glow::registerInputTensorFileLoader(InputTensorFileLoaderFn loader) { |
116 | inputTensorFileLoader_ = loader; |
117 | } |
118 | |
119 | #endif // WITH_PNG |
120 | |
121 | void glow::dumpTensorToBinaryFile(const Tensor &tensor, |
122 | llvm::StringRef filename, |
123 | const TensorSerializationOptions &opts) { |
124 | std::ofstream fs; |
125 | fs.open(filename.data(), std::ios::binary); |
126 | CHECK(fs.is_open()) << "Error opening file '" << filename.data() << "'!" ; |
127 | dumpTensorToBinaryFile(tensor, fs, opts); |
128 | fs.close(); |
129 | } |
130 | |
131 | void glow::dumpTensorToBinaryFile(const Tensor &tensor, std::ofstream &fs, |
132 | const TensorSerializationOptions &opts) { |
133 | CHECK(tensor.getUnsafePtr()) |
134 | << "Tensor not initialized before dumping to binary file!" ; |
135 | // Dump tensor type. |
136 | if (opts.withType) { |
137 | std::string typeStr = tensor.getType().toString(); |
138 | fs.write(typeStr.c_str(), typeStr.size()); |
139 | } |
140 | // Dump tensor data. |
141 | fs.write(tensor.getUnsafePtr(), tensor.getSizeInBytes()); |
142 | } |
143 | |
144 | void glow::loadTensorFromBinaryFile(Tensor &tensor, llvm::StringRef filename, |
145 | const TensorSerializationOptions &opts) { |
146 | std::ifstream fs; |
147 | fs.open(filename.data(), std::ios::binary); |
148 | CHECK(fs.is_open()) << "Error opening file '" << filename.data() << "'!" ; |
149 | // Load tensor type. |
150 | size_t = 0; |
151 | if (opts.withType) { |
152 | std::string typeStr; |
153 | char ch; |
154 | do { |
155 | CHECK(fs.read(&ch, 1)) |
156 | << "Error loading binary file '" << filename.data() |
157 | << "'! Tensor type delimiter character '>' not found!" ; |
158 | typeStr += ch; |
159 | } while (ch != '>'); |
160 | tensor.reset(Type::fromString(typeStr)); |
161 | headerSize = typeStr.size(); |
162 | } else { |
163 | CHECK(tensor.getUnsafePtr()) |
164 | << "Tensor not initialized before loading from raw binary file!" ; |
165 | } |
166 | // Verify file data size matches tensor size in bytes. |
167 | size_t tensorSize = tensor.getSizeInBytes(); |
168 | fs.seekg(0, std::ios::end); |
169 | size_t fileDataSize = size_t(fs.tellg()) - headerSize; |
170 | CHECK(fileDataSize == tensorSize) |
171 | << "Error loading binary file '" << filename.data() |
172 | << "' with header size " << headerSize << " bytes and data size " |
173 | << fileDataSize << " bytes into " |
174 | << "tensor with size " << tensorSize << " bytes!" ; |
175 | |
176 | // Load tensor data. |
177 | fs.seekg(headerSize, std::ios::beg); |
178 | fs.read(tensor.getUnsafePtr(), tensorSize); |
179 | fs.close(); |
180 | } |
181 | |
182 | void glow::dumpTensorToTextFile(Tensor &tensor, llvm::StringRef filename, |
183 | const TensorSerializationOptions &opts) { |
184 | std::ofstream fs; |
185 | fs.open(filename.data()); |
186 | CHECK(fs.is_open()) << "Error opening file '" << filename.data() << "'!" ; |
187 | CHECK(tensor.getUnsafePtr()) |
188 | << "Tensor not initialized before dumping to text file!" ; |
189 | // Dump tensor type. |
190 | if (opts.withType) { |
191 | fs << tensor.getType().toString() << "\n" ; |
192 | } |
193 | // Dump tensor data. |
194 | switch (tensor.getElementType()) { |
195 | case ElemKind::FloatTy: |
196 | return dumpTensorToTextFileImpl<float>(tensor, filename, fs); |
197 | case ElemKind::Float16Ty: |
198 | return dumpTensorToTextFileImpl<float16_t>(tensor, filename, fs); |
199 | case ElemKind::BFloat16Ty: |
200 | return dumpTensorToTextFileImpl<bfloat16_t>(tensor, filename, fs); |
201 | case ElemKind::Int8QTy: |
202 | return dumpTensorToTextFileImpl<int8_t>(tensor, filename, fs); |
203 | case ElemKind::UInt8QTy: |
204 | return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs); |
205 | case ElemKind::Int16QTy: |
206 | return dumpTensorToTextFileImpl<int16_t>(tensor, filename, fs); |
207 | case ElemKind::Int32QTy: |
208 | return dumpTensorToTextFileImpl<int32_t>(tensor, filename, fs); |
209 | case ElemKind::Int32ITy: |
210 | return dumpTensorToTextFileImpl<int32_t>(tensor, filename, fs); |
211 | case ElemKind::Int64ITy: |
212 | return dumpTensorToTextFileImpl<int64_t>(tensor, filename, fs); |
213 | case ElemKind::UInt8FusedQTy: |
214 | return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs); |
215 | case ElemKind::UInt8FusedFP16QTy: |
216 | return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs); |
217 | case ElemKind::UInt4FusedFP16QTy: |
218 | return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs); |
219 | case ElemKind::UInt4FusedQTy: |
220 | return dumpTensorToTextFileImpl<uint8_t>(tensor, filename, fs); |
221 | case ElemKind::BoolTy: |
222 | return dumpTensorToTextFileImpl<bool>(tensor, filename, fs); |
223 | default: |
224 | llvm_unreachable("Tensor type not supported for dumping to text file!" ); |
225 | } |
226 | fs.close(); |
227 | } |
228 | |
229 | void glow::loadTensorFromTextFile(Tensor &tensor, llvm::StringRef filename, |
230 | const TensorSerializationOptions &opts) { |
231 | std::ifstream fs; |
232 | fs.open(filename.data()); |
233 | CHECK(fs.is_open()) << "Error opening file '" << filename.data() << "'!" ; |
234 | // Load tensor type. |
235 | if (opts.withType) { |
236 | std::string typeStr; |
237 | CHECK(std::getline(fs, typeStr)) |
238 | << "Error loading text file '" << filename.data() |
239 | << "'! Tensor type not found!" ; |
240 | tensor.reset(Type::fromString(typeStr)); |
241 | } else { |
242 | CHECK(tensor.getUnsafePtr()) |
243 | << "Tensor not initialized before loading from raw text file!" ; |
244 | } |
245 | // Load tensor data. |
246 | switch (tensor.getElementType()) { |
247 | case ElemKind::FloatTy: |
248 | return loadTensorFromTextFileImpl<float>(tensor, filename, fs); |
249 | case ElemKind::Float16Ty: |
250 | return loadTensorFromTextFileImpl<float16_t>(tensor, filename, fs); |
251 | case ElemKind::BFloat16Ty: |
252 | return loadTensorFromTextFileImpl<bfloat16_t>(tensor, filename, fs); |
253 | case ElemKind::Int8QTy: |
254 | return loadTensorFromTextFileImpl<int8_t>(tensor, filename, fs); |
255 | case ElemKind::UInt8QTy: |
256 | return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs); |
257 | case ElemKind::Int16QTy: |
258 | return loadTensorFromTextFileImpl<int16_t>(tensor, filename, fs); |
259 | case ElemKind::Int32QTy: |
260 | return loadTensorFromTextFileImpl<int32_t>(tensor, filename, fs); |
261 | case ElemKind::Int32ITy: |
262 | return loadTensorFromTextFileImpl<int32_t>(tensor, filename, fs); |
263 | case ElemKind::Int64ITy: |
264 | return loadTensorFromTextFileImpl<int64_t>(tensor, filename, fs); |
265 | case ElemKind::UInt8FusedQTy: |
266 | return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs); |
267 | case ElemKind::UInt8FusedFP16QTy: |
268 | return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs); |
269 | case ElemKind::UInt4FusedFP16QTy: |
270 | return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs); |
271 | case ElemKind::UInt4FusedQTy: |
272 | return loadTensorFromTextFileImpl<uint8_t>(tensor, filename, fs); |
273 | case ElemKind::BoolTy: |
274 | return loadTensorFromTextFileImpl<bool>(tensor, filename, fs); |
275 | default: |
276 | llvm_unreachable("Tensor type not supported for loading from text file!" ); |
277 | } |
278 | fs.close(); |
279 | } |
280 | |
281 | #ifdef WITH_PNG |
282 | |
283 | void glow::loadInputImageFromFileWithType( |
284 | const llvm::ArrayRef<std::string> &filenames, Tensor *inputData, |
285 | ImageLayout imageLayout) { |
286 | DCHECK(!filenames.empty()) |
287 | << "There must be at least one filename in filenames." ; |
288 | assert((dim_t)filenames.size() == filenames.size()); |
289 | dim_t numImages = filenames.size(); |
290 | |
291 | CHECK(inputTensorFileLoader_) << "tensor loader not assigned!" ; |
292 | |
293 | // Read each tensor file into a vector of tensors. |
294 | std::vector<Tensor> data(numImages); |
295 | dim_t batchSize = 0; |
296 | for (dim_t n = 0; n < numImages; n++) { |
297 | inputTensorFileLoader_(data[n], filenames[n], imageLayout); |
298 | auto dims0 = data[0].dims(); |
299 | auto dims = data[n].dims(); |
300 | CHECK_EQ(dims0[1], dims[1]) << "Non batch dimensions must match" ; |
301 | CHECK_EQ(dims0[2], dims[2]) << "Non batch dimensions must match" ; |
302 | CHECK_EQ(dims0[3], dims[3]) << "Non batch dimensions must match" ; |
303 | batchSize += data[n].dims()[0]; |
304 | } |
305 | |
306 | // Input tensor size is known now. |
307 | inputData->reset(ElemKind::FloatTy, {batchSize, data[0].dims()[1], |
308 | data[0].dims()[2], data[0].dims()[3]}); |
309 | auto IIDH = inputData->getHandle<>(); |
310 | // Insert each loaded file (in data[] tensors) as the input tensor slices. |
311 | for (dim_t n = 0, e = data.size(); n < e; n++) { |
312 | Handle<float> H = data[n].getHandle<>(); |
313 | IIDH.insertTensors(H, {n, 0, 0, 0}); |
314 | } |
315 | } |
316 | |
317 | /// Helper function for loadInputTensorFromFileWithType, to produce blob files. |
318 | void glow::dumpInputTensorToFileWithType( |
319 | const llvm::ArrayRef<std::string> &filenames, const Tensor &T, |
320 | ImageLayout imageLayout) { |
321 | CHECK_EQ(filenames.size(), 1) << "Dumping support single file only" ; |
322 | const std::string &filename = filenames[0]; |
323 | Tensor localTensor = T.clone(); |
324 | // Convert to requested layout (tensor blob is in NCHW by default). |
325 | if (imageLayout == ImageLayout::NHWC) { |
326 | Tensor transposed; |
327 | localTensor.transpose(&transposed, NHWC2NCHW); |
328 | localTensor = std::move(transposed); |
329 | } |
330 | std::ofstream outfile(filename.c_str()); |
331 | CHECK(outfile.is_open()) << "Error opening file '" << filename << "'!" ; |
332 | // write dimensions to 1st line. |
333 | for (dim_t i = 0; i < 4; i++) { |
334 | CHECK(outfile << localTensor.dims()[i] << " " ) |
335 | << "Failed to write dimension " << i; |
336 | } |
337 | outfile << "\n" ; |
338 | // write tensor to 2nd line. |
339 | auto H = localTensor.getHandle<float>(); |
340 | for (auto e : H) { |
341 | outfile << e << " " ; |
342 | } |
343 | } |
344 | |
345 | #endif // WITH_PNG |
346 | |