1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_ONNXIFI_INLINEONNXIFI_H |
17 | #define GLOW_ONNXIFI_INLINEONNXIFI_H |
18 | |
19 | #include "Base.h" |
20 | |
21 | #include "llvm/ADT/SmallVector.h" |
22 | |
23 | namespace glow { |
24 | namespace onnxifi { |
25 | |
26 | /// Onnxifi Graph whose run method just executes the underlying function on the |
27 | /// same thread that calls its setIOAndRun function. |
28 | class InlineGraph : public Graph { |
29 | public: |
30 | InlineGraph(BackendPtr backendPtr, QuantizationMode quantizationMode) |
31 | : Graph(backendPtr), quantizationMode_(quantizationMode) {} |
32 | |
33 | /// Init Glow graph based on the ONNX model \p onnxModel and |
34 | /// static trained weights \p weightDescriptors. Weights can be read in later |
35 | /// by a \p deferedBlobReader. \p loadingGlowAOT specifies if the model has |
36 | /// already been AOT optimized via Glow. |
37 | onnxStatus initGraph(const void *onnxModel, size_t onnxModelSize, |
38 | uint32_t weightCount, |
39 | const onnxTensorDescriptorV1 *weightDescriptors, |
40 | uint32_t maxSeqLength, void *deferedBlobReader, |
41 | bool loadingGlowAOT) override; |
42 | |
43 | onnxStatus run(std::unique_ptr<ExecutionContext> ctx, EventPtr outputEvent, |
44 | onnxTraceEventList *traceEvents) override; |
45 | |
46 | private: |
47 | ExecutionEngine executionEngine_; |
48 | Function *function_; |
49 | QuantizationMode quantizationMode_; |
50 | |
51 | /// A map between quantization profiling names of NodeValues that were lowered |
52 | /// from each other. Maps to a set of NodeValues that were replaced by the |
53 | /// NodeValue (key) that replaced them. |
54 | LoweredInfoMap loweredMap_; |
55 | |
56 | /// Hash of the model, used to find profiling data. |
57 | llvm::SmallString<32> modelHash_; |
58 | }; |
59 | |
60 | } // namespace onnxifi |
61 | } // namespace glow |
62 | |
63 | #endif // GLOW_ONNXIFI_INLINEONNXIFI_H |
64 | |