1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Base/Image.h"
18#include "glow/Base/Tensor.h"
19#include "glow/Support/Support.h"
20
21#include "llvm/Support/CommandLine.h"
22
23#ifndef WITH_PNG
24#error "Using Glow's PNG library requires installing libpng"
25#endif
26
27using namespace glow;
28
29#include <png.h>
30
31namespace glow {
32
/// Category under which all image-related command line options are grouped.
llvm::cl::OptionCategory imageCat("Image Processing Options");

/// Range of the input image pixel values; storage for -input-values-range.
/// Note: at the moment the PNG/PPM loaders do not set this, so U8 is used as
/// the default when nothing is specified.
/// This is exposed as a user option so that a different range can be given
/// for the input data. For example, float input tensors need to have the
/// range applied. The option also allows the flexibility of using e.g.
/// 16/32/64-bit files that actually represent simple U8 images.
std::vector<ImgDataRange> imageDataRangeOpt;
static llvm::cl::list<ImgDataRange, std::vector<ImgDataRange>> ImgDataRangeF(
    "input-values-range", llvm::cl::CommaSeparated,
    llvm::cl::desc("Specify the input data values range."),
    llvm::cl::cat(imageCat), llvm::cl::location(imageDataRangeOpt),
    llvm::cl::values(
        clEnumValN(ImgDataRange::U8, "U8", "Values range: 0 and 255"),
        clEnumValN(ImgDataRange::S8, "S8", "Values range: -128 and 127"),
        clEnumValN(ImgDataRange::U16, "U16", "Values range: 0 and 65535"),
        clEnumValN(ImgDataRange::S16, "S16",
                   "Values range: -32768 and 32767")));
53
54/// Normalization mode for each input. By default, image mode is not assigned
55/// and is in so-called pass-through mode; it takes input image pixels values
56/// range, thus pixels are not modified.
57std::vector<ImageNormalizationMode> imageNormMode;
58static llvm::cl::list<ImageNormalizationMode,
59 std::vector<ImageNormalizationMode>>
60 imageNormModeF("image-mode", llvm::cl::CommaSeparated,
61 llvm::cl::desc("Specify the image mode:"),
62 llvm::cl::cat(imageCat), llvm::cl::location(imageNormMode),
63 llvm::cl::values(
64 clEnumValN(ImageNormalizationMode::kneg1to1, "neg1to1",
65 "Values are in the range: -1 and 1"),
66 clEnumValN(ImageNormalizationMode::k0to1, "0to1",
67 "Values are in the range: 0 and 1"),
68 clEnumValN(ImageNormalizationMode::k0to255, "0to255",
69 "Values are in the range: 0 and 255"),
70 clEnumValN(ImageNormalizationMode::kneg128to127,
71 "neg128to127",
72 "Values are in the range: -128 .. 127"),
73 clEnumValN(ImageNormalizationMode::U16, "U16",
74 "Values are in the range: 0 and 65535"),
75 clEnumValN(ImageNormalizationMode::S16, "S16",
76 "Values are in the range: -32768 .. 32768")));
77static llvm::cl::alias imageNormModeA("i",
78 llvm::cl::desc("Alias for -image-mode"),
79 llvm::cl::aliasopt(imageNormModeF),
80 llvm::cl::cat(imageCat));
81
/// Channel order for each input; storage for -image-channel-order.
/// When no value is given, processImageCmdArgVars() fills in BGR for every
/// input.
std::vector<ImageChannelOrder> imageChannelOrderOpt;
static llvm::cl::list<ImageChannelOrder, std::vector<ImageChannelOrder>>
    imageChannelOrderF(
        "image-channel-order", llvm::cl::CommaSeparated,
        llvm::cl::desc("Specify the image channel order"), llvm::cl::ZeroOrMore,
        llvm::cl::cat(imageCat), llvm::cl::location(imageChannelOrderOpt),
        llvm::cl::values(clEnumValN(ImageChannelOrder::BGR, "BGR", "Use BGR"),
                         clEnumValN(ImageChannelOrder::RGB, "RGB", "Use RGB")));
90
91std::vector<ImageLayout> imageLayoutOpt;
92static llvm::cl::list<ImageLayout, std::vector<ImageLayout>> imageLayoutOptF(
93 "image-layout", llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
94 llvm::cl::location(imageLayoutOpt), llvm::cl::desc(".\n"),
95 llvm::cl::values(
96 clEnumValN(ImageLayout::Unspecified, "NonImage",
97 "Use NonImage image layout"),
98 clEnumValN(ImageLayout::NCHW, "NCHW", "Use NCHW image layout"),
99 clEnumValN(ImageLayout::NHWC, "NHWC", "Use NHWC image layout")),
100 llvm::cl::cat(imageCat));
101static llvm::cl::alias
102 imageLayoutOptA("l", llvm::cl::desc("Alias for -image-layout"),
103 llvm::cl::aliasopt(imageLayoutOptF),
104 llvm::cl::cat(imageCat));
105
106std::vector<ImageLayout> inputLayoutOpt;
107static llvm::cl::list<ImageLayout, std::vector<ImageLayout>> inputLayoutF1(
108 "input-layout", llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
109 llvm::cl::location(inputLayoutOpt), llvm::cl::desc(".\n"),
110 llvm::cl::values(
111 clEnumValN(ImageLayout::Unspecified, "AsIs",
112 "Use -image-layout setting for this input."),
113 clEnumValN(ImageLayout::NCHW, "NCHW", "Use NCHW image layout"),
114 clEnumValN(ImageLayout::NHWC, "NHWC", "Use NHWC image layout")),
115 llvm::cl::cat(imageCat));
116
/// Storage for -use-imagenet-normalization. When set, the Imagenet mean and
/// stddev constants are used for inputs that have no explicit -mean / -stddev
/// values (see getImageMean() and getImageStdDev()); combining it with
/// explicit values is rejected there.
bool useImagenetNormalization;
static llvm::cl::opt<bool, true> useImagenetNormalizationF(
    "use-imagenet-normalization", llvm::cl::ZeroOrMore,
    llvm::cl::location(useImagenetNormalization),
    llvm::cl::desc("Use Imagenet Normalization. This works in combination "
                   "with the Image Mode normalization."),
    llvm::cl::cat(imageCat));
124
125// LLVM command line made parser subclasing final in 3.7 yet the only cmd
126// line manual still refers to the old data. Also, the change was not clear
127// why it's made. Assigning callbacks is not possible, and subclassing
128// basic_parser is open to future errors. Thus, relying in LLVM parser is
129// minimized - we will just obtain strings and process options.
130
131VecVec<float> meanValuesOpt;
132static std::string meanValues_;
133static llvm::cl::opt<std::string, true> meanValuesF(
134 "mean",
135 llvm::cl::desc("Mean values m1,m2,m3..."
136 "Count must be equal to number of input channels."
137 "Order of values must match specified image channel order."),
138 llvm::cl::location(meanValues_), llvm::cl::value_desc("string"),
139 llvm::cl::cat(imageCat));
140
141VecVec<float> stddevValuesOpt;
142static std::string stddevValues_;
143static llvm::cl::opt<std::string, true> stddevValuesF(
144 "stddev",
145 llvm::cl::desc("Standard deviation values s1,s2,s3..."
146 "Count must be equal to number of input channels."
147 "Order of values must match specified image channel order."),
148 llvm::cl::location(stddevValues_), llvm::cl::value_desc("string"),
149 llvm::cl::cat(imageCat));
150
151} // namespace glow
152
/// Some global options are set from functions that can be called from
/// multiple threads. Lock the access while setting them.
/// In this file it is taken by initImageCmdArgVars(),
/// processImageCmdArgVars() and loadImagesAndPreprocess() while they write
/// the option storage variables. It is a plain (non-recursive) mutex, so it
/// must not be held across a call to another function that locks it.
std::mutex globalOpts;
156
157// Process list of lists command line in the following format:
158/// All elements in a list are comma separated. Lists are double-colon
159/// separated. For example, "-option=1,2,3:4,5,6" defines two lists each with 3
160/// elements. Final destination for the processed command line string \p cmdStr
161/// is double vector \p outVec.
162template <typename T>
163static void processListOfListsCmdOption(size_t numInputs, std::string &cmdStr,
164 VecVec<T> &outVec) {
165 std::vector<std::string> strVec;
166 std::vector<T> typeVec;
167 if (cmdStr.empty()) {
168 outVec.resize(numInputs);
169 return;
170 }
171 outVec.clear();
172 std::stringstream ss(cmdStr);
173 while (ss) {
174 T elem;
175 char sep;
176 ss >> elem >> sep;
177 typeVec.push_back(elem);
178 if (sep == ':') {
179 outVec.push_back(typeVec);
180 typeVec.clear();
181 } else {
182 CHECK_EQ(sep, ',') << "Expected either ',' or ':' as separator";
183 }
184 }
185 if (!typeVec.empty()) {
186 outVec.push_back(typeVec);
187 }
188}
189
190void glow::initImageCmdArgVars() {
191 // clear external storage for all the variables.
192 globalOpts.lock();
193 imageDataRangeOpt.clear();
194 imageNormMode.clear();
195 imageChannelOrderOpt.clear();
196 imageLayoutOpt.clear();
197 inputLayoutOpt.clear();
198 meanValuesOpt.clear();
199 meanValues_.clear();
200 stddevValuesOpt.clear();
201 stddevValues_.clear();
202 globalOpts.unlock();
203}
204
205/// Processes special command line options for Image module.
206void glow::processImageCmdArgVars(size_t numInputs) {
207 globalOpts.lock();
208 processListOfListsCmdOption(numInputs, meanValues_, meanValuesOpt);
209 processListOfListsCmdOption(numInputs, stddevValues_, stddevValuesOpt);
210
211 // Default for pixel values range is U8.
212 if (imageDataRangeOpt.empty()) {
213 for (size_t i = 0, e = numInputs; i < e; i++) {
214 imageDataRangeOpt.push_back(ImgDataRange::U8);
215 }
216 }
217
218 // Default for image normalization is pass-through (keep pixel value the
219 // same).
220 if (imageNormMode.empty()) {
221 for (size_t i = 0, e = numInputs; i < e; i++) {
222 imageNormMode.push_back(ImageNormalizationMode::PassThrough);
223 }
224 }
225 // Default for image layout is NCHW.
226 if (imageLayoutOpt.empty()) {
227 for (size_t i = 0, e = numInputs; i < e; i++) {
228 imageLayoutOpt.push_back(ImageLayout::NCHW);
229 }
230 }
231 // If input-layout is empty just copy image-layout to it.
232 // If input-layout is not empty, and one of the values is "AsIs", copy
233 // the corresponding image-layout value to it.
234 if (inputLayoutOpt.empty()) {
235 inputLayoutOpt = imageLayoutOpt;
236 } else {
237 CHECK_EQ(inputLayoutOpt.size(), imageLayoutOpt.size())
238 << "Expecting the same number of values in -image-layout and "
239 "-input-layout";
240 for (size_t i = 0, e = inputLayoutOpt.size(); i < e; i++) {
241 if (inputLayoutOpt[i] == ImageLayout::Unspecified) {
242 inputLayoutOpt[i] = imageLayoutOpt[i];
243 }
244 }
245 }
246 // Default for channel order is BGR.
247 if (imageChannelOrderOpt.empty()) {
248 for (size_t i = 0, e = numInputs; i < e; i++) {
249 imageChannelOrderOpt.push_back(ImageChannelOrder::BGR);
250 }
251 }
252
253 CHECK_EQ(numInputs, imageDataRangeOpt.size())
254 << "Number of -image-range values must match number of inputs";
255 CHECK_EQ(numInputs, imageNormMode.size())
256 << "Number of -image-mode values must match number of inputs";
257 CHECK_EQ(numInputs, imageLayoutOpt.size())
258 << "Number of -image-layout values must match number of inputs";
259 CHECK_EQ(numInputs, imageChannelOrderOpt.size())
260 << "Number of -image-channel-order values must match number of inputs";
261 CHECK_EQ(numInputs, meanValuesOpt.size())
262 << "Number of -mean values must match number of inputs";
263 CHECK_EQ(numInputs, stddevValuesOpt.size())
264 << "Number of -stddev values must match number of inputs";
265 CHECK_EQ(numInputs, inputLayoutOpt.size())
266 << "Number of -input-mode values must match number of inputs";
267 globalOpts.unlock();
268}
269
270float glow::getPixelValMin(ImgDataRange range) {
271 switch (range) {
272 case (ImgDataRange::S8):
273 return -128.;
274 case (ImgDataRange::S16):
275 return -32768.;
276 default:
277 return 0.;
278 }
279}
280
281float glow::getPixelValMax(ImgDataRange range) {
282 switch (range) {
283 case (ImgDataRange::S8):
284 return 127.;
285 case (ImgDataRange::S16):
286 return 32767.;
287 case (ImgDataRange::U8):
288 return 255.;
289 case (ImgDataRange::U16):
290 return 65535.;
291 default:
292 LOG(FATAL) << "Error";
293 }
294}
295
296/// Convert the normalization to numeric floating poing ranges.
297std::pair<float, float> glow::normModeToRange(ImageNormalizationMode mode,
298 ImgDataRange range) {
299 switch (mode) {
300 case ImageNormalizationMode::PassThrough:
301 return {getPixelValMin(range), getPixelValMax(range)};
302 case ImageNormalizationMode::kneg1to1:
303 return {-1., 1.};
304 case ImageNormalizationMode::k0to1:
305 return {0., 1.0};
306 case ImageNormalizationMode::k0to255:
307 return {0., 255.0};
308 case ImageNormalizationMode::kneg128to127:
309 return {-128., 127.};
310 case ImageNormalizationMode::S16:
311 return {-32768., 32767.};
312 case ImageNormalizationMode::U16:
313 return {0., 65535.};
314 default:
315 LOG(FATAL) << "Image format not defined.";
316 }
317 return {0, 0};
318}
319
320/// Returns whether string \p hdr is recognized as PNG.
321static bool isPngHdrSignature(uint8_t *header) {
322 return png_sig_cmp(header, 0, 8) == 0;
323}
324
325/// Returns whether file \p filename is in png format.
326bool glow::isPngFormat(const std::string &filename) {
327 // open file and test for it being a png.
328 FILE *fp = fopen(filename.c_str(), "rb");
329 CHECK(fp) << "Can't open image file with name: " << filename;
330
331 unsigned char header[8];
332 size_t fread_ret = fread(header, 1, 8, fp);
333 fclose(fp);
334 CHECK_EQ(fread_ret, 8) << "fread failed for file: " << filename;
335 return isPngHdrSignature(header);
336}
337
338bool glow::isPpmFormat(const std::string &filename) {
339 // Open file and test for it being a PPM.
340 char magic[2];
341 FILE *fp = fopen(filename.c_str(), "rb");
342 CHECK(fp) << "Can't open image file with name: " << filename;
343 CHECK_EQ(fread(magic, sizeof(char), sizeof(magic), fp), sizeof(magic))
344 << "Failed to read magic number from file: " << filename;
345 fclose(fp);
346 return magic[0] == 'P' && (magic[1] == '5' || magic[1] == '6');
347}
348
349std::tuple<dim_t, dim_t, bool> glow::getPngInfo(const char *filename) {
350 // open file and test for it being a png.
351 FILE *fp = fopen(filename, "rb");
352 CHECK(fp) << "Can't open image file with name: " << filename;
353
354 unsigned char header[8];
355 size_t fread_ret = fread(header, 1, 8, fp);
356 CHECK_EQ(fread_ret, 8) << "fread failed for file: " << filename;
357 CHECK(isPngHdrSignature(header)) << "Invalid image file signature.";
358
359 // Initialize stuff.
360 png_structp png_ptr =
361 png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
362 CHECK(png_ptr) << "Image initialization failed.";
363
364 png_infop info_ptr = png_create_info_struct(png_ptr);
365 CHECK(info_ptr) << "Could not get png info.";
366
367 int sjmpGetPtr = setjmp(png_jmpbuf(png_ptr));
368 CHECK(!sjmpGetPtr) << "Failed getting png_ptr.";
369
370 png_init_io(png_ptr, fp);
371 png_set_sig_bytes(png_ptr, 8);
372 png_read_info(png_ptr, info_ptr);
373
374 size_t height = png_get_image_height(png_ptr, info_ptr);
375 size_t width = png_get_image_width(png_ptr, info_ptr);
376 png_byte color_type = png_get_color_type(png_ptr, info_ptr);
377
378 const bool isGray = color_type == PNG_COLOR_TYPE_GRAY;
379
380 png_destroy_read_struct(&png_ptr, &info_ptr, (png_infopp)NULL);
381 fclose(fp);
382
383 return std::make_tuple(height, width, isGray);
384}
385
386/// \returns floating-point range for input image based on specified options.
387std::pair<float, float> glow::getImageRange(size_t idx) {
388
389 // Take into account mean and stddev.
390 auto mean = getImageMean(idx);
391 auto stddev = getImageStdDev(idx);
392 const size_t N = mean.size();
393 std::vector<float> minVals(N), maxVals(N);
394 CHECK_EQ(stddev.size(), N) << "Number of mean and stddev values mismatch";
395 for (size_t i = 0; i < N; i++) {
396 minVals[i] = (getPixelValMin(imageDataRangeOpt[i]) - mean[i]) / stddev[i];
397 maxVals[i] = (getPixelValMax(imageDataRangeOpt[i]) - mean[i]) / stddev[i];
398 }
399 float min = *std::min_element(minVals.begin(), minVals.end());
400 float max = *std::max_element(maxVals.begin(), maxVals.end());
401
402 // Take into account normalization mode.
403 auto range = normModeToRange(imageNormMode[idx], imageDataRangeOpt[idx]);
404 float scale =
405 ((range.second - range.first) / getPixelValMax(imageDataRangeOpt[idx]));
406 float bias = range.first;
407 min = min * scale + bias;
408 max = max * scale + bias;
409
410 return {min, max};
411}
412
413/// \returns mean for input image based on specified options.
414llvm::ArrayRef<float> glow::getImageMean(size_t idx, size_t numChannels) {
415 CHECK(!meanValuesOpt.empty()) << "Internal Error: mean values not set.";
416 CHECK_LE(idx, meanValuesOpt.size())
417 << "Request for non-existent mean values.";
418 if (!meanValuesOpt[idx].empty()) {
419 CHECK(!useImagenetNormalization)
420 << "-mean and -use-imagenet-normalization cannot be used together.";
421 if (numChannels) {
422 CHECK_EQ(meanValuesOpt[idx].size(), numChannels)
423 << "Number of mean values != input channels";
424 }
425 return meanValuesOpt[idx];
426 } else if (useImagenetNormalization) {
427 return llvm::ArrayRef<float>(imagenetNormMean, 3);
428 } else {
429 return zeroMean;
430 }
431}
432
433/// \returns stddev for input image based on specified options.
434llvm::ArrayRef<float> glow::getImageStdDev(size_t idx, size_t numChannels) {
435 CHECK(!stddevValuesOpt.empty()) << "Internal Error: mean stddev not set.";
436 CHECK_LE(idx, stddevValuesOpt.size())
437 << "Request for non-existent stddev values.";
438 if (!stddevValuesOpt[idx].empty()) {
439 CHECK(!useImagenetNormalization)
440 << "-stddev and -use-imagenet-normalization cannot be used together.";
441 if (numChannels) {
442 CHECK_EQ(stddevValuesOpt[idx].size(), numChannels)
443 << "Number of stddev values != input channels";
444 }
445 return stddevValuesOpt[idx];
446 } else if (useImagenetNormalization) {
447 return llvm::ArrayRef<float>(imagenetNormStd, 3);
448 } else {
449 return oneStd;
450 }
451}
452
/// Advances \p fp past any run of whitespace and '#' comment lines, leaving
/// the stream positioned at the first other character (or at end of file).
static void skipSpacePPM(FILE *fp) {
  for (int ch = getc(fp); ch != EOF; ch = getc(fp)) {
    if (ch == '#') {
      // Discard the rest of the comment line, including its newline.
      while (ch != EOF && ch != '\n') {
        ch = getc(fp);
      }
      continue;
    }
    if (!isspace(ch)) {
      // First character of the next token: put it back for the caller.
      ungetc(ch, fp);
      return;
    }
  }
}
468
469std::tuple<dim_t, dim_t, bool> glow::getPpmInfo(FILE *fp,
470 const char *filename) {
471 // Open file and test for it being a PPM.
472 char magic[2];
473 CHECK_EQ(fread(magic, sizeof(char), sizeof(magic), fp), sizeof(magic))
474 << "Failed to read magic number from file: " << filename;
475 CHECK(magic[0] == 'P' && (magic[1] == '5' || magic[1] == '6'))
476 << filename << " is not a PPM image";
477
478 // Gray-scale or color is determined by magic number.
479 bool isGray = magic[1] == '5';
480
481 // Read dimensions and color depth.
482 int32_t height, width, depth;
483 skipSpacePPM(fp);
484 CHECK_EQ(fscanf(fp, "%d", &width), 1)
485 << "Can't read width from: " << filename;
486 skipSpacePPM(fp);
487 CHECK_EQ(fscanf(fp, "%d", &height), 1)
488 << "Can't read height from: " << filename;
489 skipSpacePPM(fp);
490 CHECK_EQ(fscanf(fp, "%d", &depth), 1)
491 << "Can't read color depth from: " << filename;
492 CHECK_EQ(depth, 255) << "Unsupported color depth " << depth
493 << " in file: " << filename;
494
495 return std::make_tuple(dim_t(height), dim_t(width), isGray);
496}
497
498std::tuple<dim_t, dim_t, bool> glow::getPpmInfo(const char *filename) {
499 bool isGray;
500 dim_t width, height;
501 FILE *fp = fopen(filename, "rb");
502 CHECK(fp) << "Can't open image file with name: " << filename;
503 std::tie(height, width, isGray) = getPpmInfo(fp, filename);
504 fclose(fp);
505 return std::make_tuple(height, width, isGray);
506}
507
508bool glow::readPngImage(Tensor *T, const char *filename,
509 std::pair<float, float> range,
510 llvm::ArrayRef<float> mean,
511 llvm::ArrayRef<float> stddev) {
512 unsigned char header[8];
513 // open file and test for it being a png.
514 FILE *fp = fopen(filename, "rb");
515 // Can't open the file.
516 if (!fp) {
517 return true;
518 }
519
520 // Validate signature.
521 size_t fread_ret = fread(header, 1, 8, fp);
522 if (fread_ret != 8) {
523 fclose(fp);
524 return true;
525 }
526 if (png_sig_cmp(header, 0, 8)) {
527 fclose(fp);
528 return true;
529 }
530
531 // Initialize stuff.
532 png_structp png_ptr =
533 png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
534 if (!png_ptr) {
535 fclose(fp);
536 return true;
537 }
538
539 png_infop info_ptr = png_create_info_struct(png_ptr);
540 if (!info_ptr) {
541 png_destroy_read_struct(&png_ptr, (png_infopp)NULL, (png_infopp)NULL);
542 fclose(fp);
543 return true;
544 }
545
546 if (setjmp(png_jmpbuf(png_ptr))) {
547 png_destroy_read_struct(&png_ptr, &info_ptr, (png_infopp)NULL);
548 fclose(fp);
549 return true;
550 }
551
552 png_init_io(png_ptr, fp);
553 png_set_sig_bytes(png_ptr, 8);
554 png_read_info(png_ptr, info_ptr);
555
556 dim_t width = png_get_image_width(png_ptr, info_ptr);
557 dim_t height = png_get_image_height(png_ptr, info_ptr);
558 int color_type = png_get_color_type(png_ptr, info_ptr);
559 int bit_depth = png_get_bit_depth(png_ptr, info_ptr);
560
561 const bool isGray = color_type == PNG_COLOR_TYPE_GRAY;
562 const dim_t numChannels = isGray ? 1 : 3;
563
564 (void)bit_depth;
565 DCHECK_EQ(bit_depth, 8) << "Invalid image";
566 DCHECK((color_type == PNG_COLOR_TYPE_RGB_ALPHA ||
567 color_type == PNG_COLOR_TYPE_RGB || isGray))
568 << "Invalid image";
569 bool hasAlpha = (color_type == PNG_COLOR_TYPE_RGB_ALPHA);
570
571 int number_of_passes = png_set_interlace_handling(png_ptr);
572 (void)number_of_passes;
573 DCHECK_EQ(number_of_passes, 1) << "Invalid image";
574
575 png_read_update_info(png_ptr, info_ptr);
576
577 // Error during image read.
578 if (setjmp(png_jmpbuf(png_ptr))) {
579 png_destroy_read_struct(&png_ptr, &info_ptr, (png_infopp)NULL);
580 fclose(fp);
581 return true;
582 }
583
584 auto *row_pointers = (png_bytep *)malloc(sizeof(png_bytep) * height);
585 for (dim_t y = 0; y < height; y++) {
586 row_pointers[y] = (png_byte *)malloc(png_get_rowbytes(png_ptr, info_ptr));
587 }
588
589 png_read_image(png_ptr, row_pointers);
590 png_read_end(png_ptr, info_ptr);
591
592 T->reset(ElemKind::FloatTy, {height, width, numChannels});
593 auto H = T->getHandle<>();
594
595 float scale = ((range.second - range.first) / 255.0);
596 float bias = range.first;
597
598 for (dim_t row_n = 0; row_n < height; row_n++) {
599 png_byte *row = row_pointers[row_n];
600 for (dim_t col_n = 0; col_n < width; col_n++) {
601 png_byte *ptr =
602 &(row[col_n * (hasAlpha ? (numChannels + 1) : numChannels)]);
603 for (dim_t i = 0; i < numChannels; i++) {
604 float val = float(ptr[i]);
605 val = (val - mean[i]) / stddev[i];
606 H.at({row_n, col_n, i}) = val * scale + bias;
607 }
608 }
609 }
610
611 for (dim_t y = 0; y < height; y++) {
612 free(row_pointers[y]);
613 }
614 free(row_pointers);
615 png_destroy_read_struct(&png_ptr, &info_ptr, (png_infopp)NULL);
616 fclose(fp);
617
618 return false;
619}
620
621bool glow::readPpmImage(Tensor *T, const char *filename,
622 std::pair<float, float> range,
623 llvm::ArrayRef<float> mean,
624 llvm::ArrayRef<float> stddev) {
625 bool isGray;
626 dim_t height, width;
627 FILE *fp = fopen(filename, "rb");
628 if (!fp) {
629 return true;
630 }
631
632 // Get PPM info.
633 std::tie(height, width, isGray) = getPpmInfo(fp, filename);
634 const dim_t numChannels = isGray ? 1 : 3;
635 T->reset(ElemKind::FloatTy, {height, width, numChannels});
636
637 // Skip a single byte of space.
638 fgetc(fp);
639
640 // Read pixels and do pre-processing.
641 auto H = T->getHandle<>();
642 float scale = ((range.second - range.first) / 255.0);
643 float bias = range.first;
644 unsigned char *buf =
645 (unsigned char *)malloc(width * numChannels * sizeof(unsigned char));
646 for (dim_t h = 0; h < height; h++) {
647 if (fread(buf, width * numChannels, 1, fp) != 1) {
648 free(buf);
649 fclose(fp);
650 return true;
651 }
652 for (dim_t w = 0; w < width; w++) {
653 for (dim_t c = 0; c < numChannels; c++) {
654 float val = float(buf[w * numChannels + c]);
655 H.at({h, w, c}) = ((val - mean[c]) / stddev[c]) * scale + bias;
656 }
657 }
658 }
659
660 free(buf);
661 fclose(fp);
662 return false;
663}
664
665bool glow::writePngImage(Tensor *T, const char *filename,
666 std::pair<float, float> range,
667 llvm::ArrayRef<float> mean,
668 llvm::ArrayRef<float> stddev) {
669 /* create file */
670 FILE *fp = fopen(filename, "wb");
671 if (!fp) {
672 return true;
673 }
674
675 /* initialize stuff */
676 png_structp png_ptr =
677 png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
678
679 if (!png_ptr) {
680 return true;
681 }
682
683 png_infop info_ptr = png_create_info_struct(png_ptr);
684 if (!info_ptr) {
685 return true;
686 }
687
688 if (setjmp(png_jmpbuf(png_ptr))) {
689 return true;
690 }
691
692 png_init_io(png_ptr, fp);
693
694 if (setjmp(png_jmpbuf(png_ptr))) {
695 return true;
696 }
697
698 auto H = T->getHandle<>();
699
700 auto odim = H.dims();
701 constexpr size_t numChannels = 3;
702 DCHECK_EQ(odim[2], numChannels)
703 << "Currently only supports saving RGB images without alpha.";
704
705 size_t width = odim[0];
706 size_t height = odim[1];
707 int color_type = PNG_COLOR_TYPE_RGB_ALPHA;
708 int bit_depth = 8;
709
710 png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth, color_type,
711 PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_BASE,
712 PNG_FILTER_TYPE_BASE);
713
714 png_write_info(png_ptr, info_ptr);
715
716 if (setjmp(png_jmpbuf(png_ptr))) {
717 return true;
718 }
719
720 auto *row_pointers = (png_bytep *)malloc(sizeof(png_bytep) * height);
721 for (size_t y = 0; y < height; y++) {
722 row_pointers[y] = (png_byte *)malloc(png_get_rowbytes(png_ptr, info_ptr));
723 }
724
725 float scale = ((range.second - range.first) / 255.0);
726 float bias = range.first;
727
728 for (dim_t y = 0; y < height; y++) {
729 png_byte *row = row_pointers[y];
730 for (dim_t x = 0; x < width; x++) {
731 png_byte *ptr = &(row[x * 4]);
732 for (dim_t i = 0; i < numChannels; i++) {
733 float val = (H.at({y, x, i}) - bias) / scale;
734 val = (val * stddev[i]) + mean[i];
735 ptr[i] = val;
736 }
737 ptr[3] = 0xff;
738 }
739 }
740
741 png_write_image(png_ptr, row_pointers);
742
743 if (setjmp(png_jmpbuf(png_ptr))) {
744 return true;
745 }
746
747 png_write_end(png_ptr, nullptr);
748
749 /* cleanup heap allocation */
750 for (size_t y = 0; y < height; y++) {
751 free(row_pointers[y]);
752 }
753 free(row_pointers);
754 png_destroy_write_struct(&png_ptr, &info_ptr);
755 fclose(fp);
756 return false;
757}
758
759Tensor glow::readPngPpmImageAndPreprocess(llvm::StringRef filename,
760 ImageNormalizationMode imageNormMode,
761 ImageChannelOrder imageChannelOrder,
762 ImageLayout imageLayout,
763 llvm::ArrayRef<float> mean,
764 llvm::ArrayRef<float> stddev) {
765 Tensor imageData;
766 readPngPpmImageAndPreprocess(imageData, filename, imageNormMode,
767 imageChannelOrder, imageLayout, mean, stddev);
768 return imageData;
769}
770
771void glow::readPngPpmImageAndPreprocess(Tensor &imageData,
772 llvm::StringRef filename,
773 ImageNormalizationMode imageNormMode,
774 ImageChannelOrder imageChannelOrder,
775 ImageLayout imageLayout,
776 llvm::ArrayRef<float> mean,
777 llvm::ArrayRef<float> stddev) {
778
779 // PNG images are RGB, so shuffle mean and stddev values to be in RGB order
780 // as well, prior applying them to input image.
781 std::vector<float> meanRGB(mean);
782 std::vector<float> stddevRGB(stddev);
783 if (imageChannelOrder == ImageChannelOrder::BGR) {
784 std::reverse(meanRGB.begin(), meanRGB.end());
785 std::reverse(stddevRGB.begin(), stddevRGB.end());
786 }
787
788 bool isPNG = isPngFormat(filename.data());
789 CHECK(isPNG || isPpmFormat(filename.data())) << "Unrecognized image format";
790 auto range = normModeToRange(imageNormMode, ImgDataRange::U8);
791 bool loadSuccess = isPNG ? !readPngImage(&imageData, filename.data(), range,
792 meanRGB, stddevRGB)
793 : !readPpmImage(&imageData, filename.data(), range,
794 meanRGB, stddevRGB);
795 CHECK(loadSuccess) << "Error reading input image from file: "
796 << filename.str();
797 dim_t imgHeight = imageData.dims()[0];
798 dim_t imgWidth = imageData.dims()[1];
799 dim_t numChannels = imageData.dims()[2];
800
801 // PNG/PPM images are NHWC and RGB. Convert if needed.
802 // Convert to requested channel ordering.
803 if (imageChannelOrder == ImageChannelOrder::BGR) {
804 Tensor swizzled(imageData.getType());
805 auto IH = imageData.getHandle();
806 auto SH = swizzled.getHandle();
807 for (unsigned z = 0; z < numChannels; z++) {
808 for (unsigned y = 0; y < imgHeight; y++) {
809 for (unsigned x = 0; x < imgWidth; x++) {
810 SH.at({y, x, numChannels - 1 - z}) = IH.at({y, x, z});
811 }
812 }
813 }
814 imageData = std::move(swizzled);
815 }
816 // Convert to requested layout.
817 if (imageLayout == ImageLayout::NCHW) {
818 Tensor transposed;
819 imageData.transpose(&transposed, {2u, 0u, 1u});
820 imageData = std::move(transposed);
821 }
822}
823
824/// Entry point for the PNG/PPM images loader.
825void glow::readPngPpmImagesAndPreprocess(
826 Tensor &inputImageData, const llvm::ArrayRef<std::string> &filenames,
827 ImageNormalizationMode imageNormMode, ImageChannelOrder imageChannelOrder,
828 ImageLayout imageLayout, llvm::ArrayRef<float> meanRef,
829 llvm::ArrayRef<float> stddevRef) {
830 DCHECK(!filenames.empty())
831 << "There must be at least one filename in filenames.";
832 DCHECK_EQ((dim_t)filenames.size(), filenames.size());
833 dim_t numImages = filenames.size();
834
835 // Get image dimensions and check if grayscale or color.
836 dim_t imgHeight;
837 dim_t imgWidth;
838 bool isGray;
839 bool isPNG = isPngFormat(filenames[0]);
840 CHECK(isPNG || isPpmFormat(filenames[0])) << "Unrecognized image format";
841 std::tie(imgHeight, imgWidth, isGray) =
842 isPNG ? getPngInfo(filenames[0].c_str())
843 : getPpmInfo(filenames[0].c_str());
844 const dim_t numChannels = isGray ? 1 : 3;
845
846 // Assign mean and stddev for input normalization.
847 llvm::ArrayRef<float> mean;
848 llvm::ArrayRef<float> stddev;
849 if (!meanRef.empty()) {
850 CHECK_EQ(meanRef.size(), numChannels)
851 << "Number of mean values != input channels";
852 CHECK(!useImagenetNormalization)
853 << "-mean and -use-imagenet-normalization cannot be used together.";
854 mean = meanRef;
855 } else if (useImagenetNormalization) {
856 mean = imagenetNormMean;
857 } else {
858 mean = zeroMean;
859 }
860
861 if (!stddevRef.empty()) {
862 CHECK_EQ(stddevRef.size(), numChannels)
863 << "Number of stddev values != input channels";
864 CHECK(!useImagenetNormalization)
865 << "-stddev and -use-imagenet-normalization cannot be used together.";
866 stddev = stddevRef;
867 } else if (useImagenetNormalization) {
868 stddev = imagenetNormStd;
869 } else {
870 stddev = oneStd;
871 }
872
873 // Allocate a tensor for the batch.
874 ShapeVector batchDims;
875 switch (imageLayout) {
876 case ImageLayout::NCHW:
877 batchDims = {numImages, numChannels, imgHeight, imgWidth};
878 break;
879 case ImageLayout::NHWC:
880 batchDims = {numImages, imgHeight, imgWidth, numChannels};
881 break;
882 default:
883 LOG(FATAL) << "Unexpected layout\n";
884 }
885 inputImageData.reset(ElemKind::FloatTy, batchDims);
886 auto IIDH = inputImageData.getHandle<>();
887
888 // Read images into local tensors and add to batch.
889 for (size_t n = 0; n < filenames.size(); n++) {
890 Tensor localCopy;
891 readPngPpmImageAndPreprocess(localCopy, filenames[n], imageNormMode,
892 imageChannelOrder, imageLayout, mean, stddev);
893 DCHECK(std::equal(localCopy.dims().begin(), localCopy.dims().end(),
894 inputImageData.dims().begin() + 1))
895 << "All images must have the same dimensions";
896 IIDH.insertSlice(localCopy, n);
897 }
898}
899
900/// Dispatching loading to the format handlers.
901void glow::loadImagesAndPreprocess(
902 VecVecRef<std::string> filenamesList,
903 llvm::ArrayRef<Tensor *> inputImageDataList,
904 llvm::ArrayRef<ImageNormalizationMode> normMode,
905 llvm::ArrayRef<ImageChannelOrder> channelOrder,
906 llvm::ArrayRef<ImageLayout> imageLayout,
907 llvm::ArrayRef<ImageLayout> inputLayout, VecVecRef<float> mean,
908 VecVecRef<float> stddev) {
909
910 globalOpts.lock();
911 if (normMode.size()) {
912 imageNormMode = normMode;
913 }
914 if (channelOrder.size()) {
915 imageChannelOrderOpt = channelOrder;
916 }
917 if (imageLayout.size()) {
918 imageLayoutOpt = imageLayout;
919 }
920 if (inputLayout.size()) {
921 inputLayoutOpt = inputLayout;
922 }
923 if (stddev.size()) {
924 stddevValuesOpt = stddev;
925 }
926 if (mean.size()) {
927 meanValuesOpt = mean;
928 }
929 globalOpts.unlock();
930
931 CHECK(!filenamesList.empty())
932 << "There must be at least one list in filenames.";
933
934 CHECK_EQ(filenamesList.size(), inputImageDataList.size())
935 << "Number of image and tensor lists must match.";
936
937 processImageCmdArgVars(inputImageDataList.size());
938
939 for (size_t i = 0; i < filenamesList.size(); i++) {
940 // Get list of files for an input.
941 auto filenames = filenamesList[i];
942
943 // Get tensor to load for that one selected input.
944 auto inputImageData = inputImageDataList[i];
945 // All files for an input must be of the same type, thus will just check
946 // the first one.
947 if (isPngFormat(filenames[0]) || isPpmFormat(filenames[0])) {
948 readPngPpmImagesAndPreprocess(
949 *inputImageData, filenames, imageNormMode[i], imageChannelOrderOpt[i],
950 imageLayoutOpt[i], meanValuesOpt[i], stddevValuesOpt[i]);
951 } else if (isNumpyNpyFormat(filenames[0])) {
952 loadNumpyImagesAndPreprocess(filenames, *inputImageData, imageNormMode[i],
953 imageChannelOrderOpt[i], imageLayoutOpt[i],
954 inputLayoutOpt[i], meanValuesOpt[i],
955 stddevValuesOpt[i], imageDataRangeOpt[i]);
956 } else {
957 LOG(FATAL) << "Input file format is not recognized: \n" << filenames[0];
958 }
959 }
960}
961