1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "Loader.h" |
18 | |
19 | #include <fstream> |
20 | #include <iostream> |
21 | #include <sstream> |
22 | |
23 | #include "ExecutorCore.h" |
24 | #include "ExecutorCoreHelperFunctions.h" |
25 | |
26 | using namespace glow; |
27 | |
28 | /// Dumps txt files for each output and for each file. |
29 | static int |
30 | processAndPrintResults(const llvm::StringMap<Placeholder *> &PHM, |
31 | PlaceholderBindings &bindings, |
32 | llvm::ArrayRef<std::string> inputImageBatchFilenames) { |
33 | // Print out the object detection results. |
34 | std::vector<Tensor *> vecOutTensors; |
35 | std::vector<std::string> VecOutNames; |
36 | for (auto const &OutEntry : PHM) { |
37 | VecOutNames.push_back(OutEntry.getKey().str()); |
38 | Placeholder *SMPH = OutEntry.getValue(); |
39 | Tensor *SMT = bindings.get(SMPH); |
40 | vecOutTensors.push_back(SMT); |
41 | } |
42 | for (unsigned i = 0; i < inputImageBatchFilenames.size(); i++) { |
43 | llvm::outs() << "Input File " << inputImageBatchFilenames[i] << ":\n" ; |
44 | for (size_t k = 0; k < VecOutNames.size(); k++) { |
45 | std::error_code EC; |
46 | int idx = inputImageBatchFilenames[i].find_last_of("\\/" ); |
47 | int idx1 = inputImageBatchFilenames[i].find_last_of("." ); |
48 | std::string basename = |
49 | inputImageBatchFilenames[i].substr(idx + 1, idx1 - idx - 1); |
50 | |
51 | // Dump all the output tensors of input file with name <inimage> to |
52 | // files with name <inimage>_<tensorname>.txt in the working |
53 | // directory. |
54 | std::replace_if( |
55 | VecOutNames[k].begin(), VecOutNames[k].end(), |
56 | [](char c) { return !isalnum(c); }, '_'); |
57 | std::string filename = basename + "_" + VecOutNames[k] + ".txt" ; |
58 | llvm::raw_fd_ostream fd(filename, EC); |
59 | if (EC) { |
60 | llvm::outs() << "Error opening file " << filename; |
61 | llvm::outs().flush(); |
62 | fd.close(); |
63 | return 1; |
64 | } |
65 | |
66 | auto getDumpTensor = [](Tensor *t, int slice) { |
67 | assert(t->dims().size() > 1 && "Tensor dims should be > 2" ); |
68 | switch (t->getElementType()) { |
69 | case ElemKind::Int64ITy: |
70 | return t->getHandle<int64_t>().extractSlice(slice); |
71 | case ElemKind::Int32ITy: |
72 | return t->getHandle<int32_t>().extractSlice(slice); |
73 | default: |
74 | return t->getHandle<float>().extractSlice(slice); |
75 | } |
76 | }; |
77 | Tensor tensor = getDumpTensor(vecOutTensors[k], i); |
78 | tensor.dump(fd, tensor.size()); |
79 | |
80 | llvm::outs() << "\t" << filename << ":" << tensor.size() << "\n" ; |
81 | fd.close(); |
82 | } |
83 | } |
84 | return 0; |
85 | } |
86 | |
87 | /// Given the output PlaceHolder StringMap \p PHM, outputs results in to text |
88 | /// files for each output. |
89 | class ObjectDetectionProcessResult : public PostProcessOutputDataExtension { |
90 | public: |
91 | int processOutputs(const llvm::StringMap<Placeholder *> &PHM, |
92 | PlaceholderBindings &bindings, |
93 | VecVecRef<std::string> imageList) override { |
94 | processAndPrintResults(PHM, bindings, imageList[0]); |
95 | return 0; |
96 | } |
97 | }; |
98 | |
99 | int main(int argc, char **argv) { |
100 | glow::Executor core("ObjectDetector" , argc, argv); |
101 | |
102 | auto printResultCreator = |
103 | []() -> std::unique_ptr<PostProcessOutputDataExtension> { |
104 | return std::make_unique<ObjectDetectionProcessResult>(); |
105 | }; |
106 | core.registerPostProcessOutputExtension(printResultCreator); |
107 | |
108 | int numErrors = core.executeNetwork(); |
109 | return numErrors; |
110 | } |
111 | |