1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Defines an implementation of Module-based Model Runtime Interface that works with |
22 | * Ahead-of-Time compilation. |
23 | * \file aot_executor.h |
24 | */ |
25 | #ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ |
26 | #define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ |
27 | |
28 | #include <tvm/runtime/metadata.h> |
29 | #include <tvm/runtime/module.h> |
30 | #include <tvm/runtime/object.h> |
31 | #include <tvm/runtime/packed_func.h> |
32 | |
33 | #include <string> |
34 | #include <vector> |
35 | |
36 | namespace tvm { |
37 | namespace runtime { |
38 | |
39 | class TVM_DLL AotExecutor : public ModuleNode { |
40 | public: |
41 | /*! |
42 | * \brief Implements member function lookup for this Module for the frontend. |
43 | * \param name The name of the function. |
44 | * \param sptr_to_self The pointer to the module node. |
45 | * \return The corresponding member function. |
46 | */ |
47 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override; |
48 | |
49 | /*! |
50 | * \return The type key of the executor. |
51 | */ |
52 | const char* type_key() const final { return "AotExecutor" ; } |
53 | |
54 | void Run(); |
55 | |
56 | /*! |
57 | * \brief Initialize the AOT executor with metadata, runtime::Module, and device. |
58 | * \param module The module containing the compiled functions for the host |
59 | * processor. |
60 | * \param devs A 1-element vector. The Device which AOT compute will run on. Currently, only |
61 | * Device(kDLCPU, 0) is supported. |
62 | */ |
63 | AotExecutor(tvm::runtime::Module module, const std::vector<Device>& devs); |
64 | |
65 | /*! |
66 | * \brief Get the input index given the name of input. |
67 | * \param name The name of the input. |
68 | * \return The index of input. |
69 | */ |
70 | int GetInputIndex(const std::string& name); |
71 | |
72 | /*! |
73 | * \brief Get the output index given the name of output. |
74 | * \param name The name of the output. |
75 | * \return The index of output. |
76 | */ |
77 | int GetOutputIndex(const std::string& name); |
78 | |
79 | /*! |
80 | * \brief set index-th input to the graph. |
81 | * \param index The input index. |
82 | * \param data_in The input data. |
83 | */ |
84 | void SetInput(int index, DLTensor* data_in); |
85 | /*! |
86 | * \brief set index-th input to the graph without copying the data |
87 | * \param index The input index. |
88 | * \param data_ref The input data that is referred. |
89 | */ |
90 | void SetInputZeroCopy(int index, DLTensor* data_ref); |
91 | /*! |
92 | * \brief set index-th output to the graph without copying the data. |
93 | * \param index The output index. |
94 | * \param data_ref The output data that is referred. |
95 | */ |
96 | void SetOutputZeroCopy(int index, DLTensor* data_ref); |
97 | /*! |
98 | * \brief Get the number of outputs |
99 | * |
100 | * \return The number of outputs from graph. |
101 | */ |
102 | int NumOutputs() const; |
103 | /*! |
104 | * \brief Get the number of inputs |
105 | * |
106 | * \return The number of inputs to the graph. |
107 | */ |
108 | int NumInputs() const; |
109 | /*! |
110 | * \brief Return NDArray for given input index. |
111 | * \param index The input index. |
112 | * |
113 | * \return NDArray corresponding to given input node index. |
114 | */ |
115 | NDArray GetInput(int index) const; |
116 | /*! |
117 | * \brief Return NDArray for given output index. |
118 | * \param index The output index. |
119 | * |
120 | * \return NDArray corresponding to given output node index. |
121 | */ |
122 | NDArray GetOutput(int index) const; |
123 | /*! |
124 | * \brief Copy index-th output to data_out. |
125 | * \param index The output index. |
126 | * \param data_out the output data. |
127 | */ |
128 | void CopyOutputTo(int index, DLTensor* data_out); |
129 | |
130 | private: |
131 | /*! \brief Metadata provided to the runtime from the compiler. */ |
132 | metadata::Metadata metadata_; |
133 | |
134 | /*! \brief Runtime module which contains the AOT top-level function. */ |
135 | Module module_; |
136 | |
137 | /*! \brief The devices which should be used to execute the computations. */ |
138 | std::vector<Device> devices_; |
139 | |
140 | /*! \brief Holds one NDArray per function argument in the same order. */ |
141 | std::vector<NDArray> args_; |
142 | }; |
143 | |
144 | } // namespace runtime |
145 | } // namespace tvm |
146 | |
147 | #endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ |
148 | |