1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Loader.h"
18
#include <algorithm>
#include <cctype>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
22
23#include "ExecutorCore.h"
24#include "ExecutorCoreHelperFunctions.h"
25
26using namespace glow;
27
28/// Dumps txt files for each output and for each file.
29static int
30processAndPrintResults(const llvm::StringMap<Placeholder *> &PHM,
31 PlaceholderBindings &bindings,
32 llvm::ArrayRef<std::string> inputImageBatchFilenames) {
33 // Print out the object detection results.
34 std::vector<Tensor *> vecOutTensors;
35 std::vector<std::string> VecOutNames;
36 for (auto const &OutEntry : PHM) {
37 VecOutNames.push_back(OutEntry.getKey().str());
38 Placeholder *SMPH = OutEntry.getValue();
39 Tensor *SMT = bindings.get(SMPH);
40 vecOutTensors.push_back(SMT);
41 }
42 for (unsigned i = 0; i < inputImageBatchFilenames.size(); i++) {
43 llvm::outs() << "Input File " << inputImageBatchFilenames[i] << ":\n";
44 for (size_t k = 0; k < VecOutNames.size(); k++) {
45 std::error_code EC;
46 int idx = inputImageBatchFilenames[i].find_last_of("\\/");
47 int idx1 = inputImageBatchFilenames[i].find_last_of(".");
48 std::string basename =
49 inputImageBatchFilenames[i].substr(idx + 1, idx1 - idx - 1);
50
51 // Dump all the output tensors of input file with name <inimage> to
52 // files with name <inimage>_<tensorname>.txt in the working
53 // directory.
54 std::replace_if(
55 VecOutNames[k].begin(), VecOutNames[k].end(),
56 [](char c) { return !isalnum(c); }, '_');
57 std::string filename = basename + "_" + VecOutNames[k] + ".txt";
58 llvm::raw_fd_ostream fd(filename, EC);
59 if (EC) {
60 llvm::outs() << "Error opening file " << filename;
61 llvm::outs().flush();
62 fd.close();
63 return 1;
64 }
65
66 auto getDumpTensor = [](Tensor *t, int slice) {
67 assert(t->dims().size() > 1 && "Tensor dims should be > 2");
68 switch (t->getElementType()) {
69 case ElemKind::Int64ITy:
70 return t->getHandle<int64_t>().extractSlice(slice);
71 case ElemKind::Int32ITy:
72 return t->getHandle<int32_t>().extractSlice(slice);
73 default:
74 return t->getHandle<float>().extractSlice(slice);
75 }
76 };
77 Tensor tensor = getDumpTensor(vecOutTensors[k], i);
78 tensor.dump(fd, tensor.size());
79
80 llvm::outs() << "\t" << filename << ":" << tensor.size() << "\n";
81 fd.close();
82 }
83 }
84 return 0;
85}
86
87/// Given the output PlaceHolder StringMap \p PHM, outputs results in to text
88/// files for each output.
89class ObjectDetectionProcessResult : public PostProcessOutputDataExtension {
90public:
91 int processOutputs(const llvm::StringMap<Placeholder *> &PHM,
92 PlaceholderBindings &bindings,
93 VecVecRef<std::string> imageList) override {
94 processAndPrintResults(PHM, bindings, imageList[0]);
95 return 0;
96 }
97};
98
99int main(int argc, char **argv) {
100 glow::Executor core("ObjectDetector", argc, argv);
101
102 auto printResultCreator =
103 []() -> std::unique_ptr<PostProcessOutputDataExtension> {
104 return std::make_unique<ObjectDetectionProcessResult>();
105 };
106 core.registerPostProcessOutputExtension(printResultCreator);
107
108 int numErrors = core.executeNetwork();
109 return numErrors;
110}
111