1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Defines an implementation of Module-based Model Runtime Interface that works with
22 * Ahead-of-Time compilation.
23 * \file aot_executor.h
24 */
25#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_
26#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_
27
28#include <tvm/runtime/metadata.h>
29#include <tvm/runtime/module.h>
30#include <tvm/runtime/object.h>
31#include <tvm/runtime/packed_func.h>
32
33#include <string>
34#include <vector>
35
36namespace tvm {
37namespace runtime {
38
39class TVM_DLL AotExecutor : public ModuleNode {
40 public:
41 /*!
42 * \brief Implements member function lookup for this Module for the frontend.
43 * \param name The name of the function.
44 * \param sptr_to_self The pointer to the module node.
45 * \return The corresponding member function.
46 */
47 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override;
48
49 /*!
50 * \return The type key of the executor.
51 */
52 const char* type_key() const final { return "AotExecutor"; }
53
54 void Run();
55
56 /*!
57 * \brief Initialize the AOT executor with metadata, runtime::Module, and device.
58 * \param module The module containing the compiled functions for the host
59 * processor.
60 * \param devs A 1-element vector. The Device which AOT compute will run on. Currently, only
61 * Device(kDLCPU, 0) is supported.
62 */
63 AotExecutor(tvm::runtime::Module module, const std::vector<Device>& devs);
64
65 /*!
66 * \brief Get the input index given the name of input.
67 * \param name The name of the input.
68 * \return The index of input.
69 */
70 int GetInputIndex(const std::string& name);
71
72 /*!
73 * \brief Get the output index given the name of output.
74 * \param name The name of the output.
75 * \return The index of output.
76 */
77 int GetOutputIndex(const std::string& name);
78
79 /*!
80 * \brief set index-th input to the graph.
81 * \param index The input index.
82 * \param data_in The input data.
83 */
84 void SetInput(int index, DLTensor* data_in);
85 /*!
86 * \brief set index-th input to the graph without copying the data
87 * \param index The input index.
88 * \param data_ref The input data that is referred.
89 */
90 void SetInputZeroCopy(int index, DLTensor* data_ref);
91 /*!
92 * \brief set index-th output to the graph without copying the data.
93 * \param index The output index.
94 * \param data_ref The output data that is referred.
95 */
96 void SetOutputZeroCopy(int index, DLTensor* data_ref);
97 /*!
98 * \brief Get the number of outputs
99 *
100 * \return The number of outputs from graph.
101 */
102 int NumOutputs() const;
103 /*!
104 * \brief Get the number of inputs
105 *
106 * \return The number of inputs to the graph.
107 */
108 int NumInputs() const;
109 /*!
110 * \brief Return NDArray for given input index.
111 * \param index The input index.
112 *
113 * \return NDArray corresponding to given input node index.
114 */
115 NDArray GetInput(int index) const;
116 /*!
117 * \brief Return NDArray for given output index.
118 * \param index The output index.
119 *
120 * \return NDArray corresponding to given output node index.
121 */
122 NDArray GetOutput(int index) const;
123 /*!
124 * \brief Copy index-th output to data_out.
125 * \param index The output index.
126 * \param data_out the output data.
127 */
128 void CopyOutputTo(int index, DLTensor* data_out);
129
130 private:
131 /*! \brief Metadata provided to the runtime from the compiler. */
132 metadata::Metadata metadata_;
133
134 /*! \brief Runtime module which contains the AOT top-level function. */
135 Module module_;
136
137 /*! \brief The devices which should be used to execute the computations. */
138 std::vector<Device> devices_;
139
140 /*! \brief Holds one NDArray per function argument in the same order. */
141 std::vector<NDArray> args_;
142};
143
144} // namespace runtime
145} // namespace tvm
146
147#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_
148