1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #ifndef GLOW_TOOLS_LOADER_EXECUTOR_CORE_H |
18 | #define GLOW_TOOLS_LOADER_EXECUTOR_CORE_H |
19 | |
20 | #include "Loader.h" |
21 | #include "glow/Graph/Nodes.h" |
22 | |
23 | namespace glow { |
24 | class PostProcessOutputDataExtension { |
25 | public: |
26 | /// Called once per mini-batch after network is executed to post process |
27 | /// output(s). |
28 | virtual int |
29 | processOutputs(const llvm::StringMap<Placeholder *> &PHM, |
30 | PlaceholderBindings &bindings, |
31 | VecVecRef<std::string> inputImageBatchFilenames) = 0; |
32 | virtual ~PostProcessOutputDataExtension(){}; |
33 | }; |
34 | |
35 | using PostProcessExtFuncPtr = |
36 | std::function<std::unique_ptr<PostProcessOutputDataExtension>()>; |
37 | |
38 | class PreProcessInputDataExtension { |
39 | public: |
40 | /// Called once per batch after images are loaded in to Tensor. |
41 | virtual void processInputTensor(llvm::ArrayRef<Tensor *> inputImageData, |
42 | size_t startId, size_t endId, |
43 | size_t batchSz) = 0; |
44 | virtual ~PreProcessInputDataExtension(){}; |
45 | }; |
46 | |
47 | class Executor final { |
48 | public: |
49 | Executor(std::string appName, int argc, char **argv); |
50 | Executor() = delete; |
51 | Executor(const Executor &) = delete; |
52 | Executor &operator=(const Executor &) = delete; |
53 | |
54 | /// Registers a Loader Extension that will be invoked after model is |
55 | /// loaded. If multiple extensions are registered they will be executed in |
56 | /// order they were registered. |
57 | void registerLoaderExtension( |
58 | std::function<std::unique_ptr<LoaderExtension>()> func); |
59 | /// Registers an extension that will be invoked on Tensor containing current |
60 | /// batch of input data. If multiple extensions are registered they will be |
61 | /// executed in order they were registered. |
62 | /// A new instance of the extension will be created for each thread. |
63 | void registerInputDataPreProcessingExtension( |
64 | std::function<std::unique_ptr<PreProcessInputDataExtension>()> func); |
65 | /// Registers extension that will be invoked for each execution of the |
66 | /// network. If multiple extensions are registered they will be executed in |
67 | /// order they were registered. |
68 | /// A new instance of the extension will be created for each thread. |
69 | void registerPostProcessOutputExtension(PostProcessExtFuncPtr func); |
70 | /// This will parse command line, load, build and execute a network. |
71 | /// Returns /p 0 if no errors occured, others none zero value. |
72 | int executeNetwork(); |
73 | |
74 | private: |
75 | /// Iterates over lambda expressions and registers them with each instance of |
76 | /// a loader in main dispatch loop. |
77 | void addLoaderExtensions(Loader &ld); |
78 | |
79 | private: |
80 | std::vector<std::function<std::unique_ptr<PreProcessInputDataExtension>()>> |
81 | ppInputDataExtensions_; |
82 | std::vector<PostProcessExtFuncPtr> ppOutputDataExtensions_; |
83 | std::vector<std::function<std::unique_ptr<LoaderExtension>()>> |
84 | loaderextensions_; |
85 | std::string appName_; |
86 | }; |
87 | |
88 | } // namespace glow |
89 | #endif // GLOW_TOOLS_LOADER_EXECUTOR_CORE_H |
90 | |