1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file tvm/driver/driver_api.h
 * \brief Compiler driver APIs to drive the compilation.
 *
 * This module provides end-to-end utilities that drive the compilation process.
 * We adopt the term "compiler driver" as it is used in common compiler infrastructures.
 * Note that a compiler driver is different from "runtime drivers";
 * most runtime-related code is defined in the runtime folder instead.
 */
29 | #ifndef TVM_DRIVER_DRIVER_API_H_ |
30 | #define TVM_DRIVER_DRIVER_API_H_ |
31 | |
32 | #include <tvm/ir/global_var_supply.h> |
33 | #include <tvm/ir/module.h> |
34 | #include <tvm/ir/transform.h> |
35 | #include <tvm/runtime/packed_func.h> |
36 | #include <tvm/support/with.h> |
37 | #include <tvm/target/target.h> |
38 | #include <tvm/te/schedule_pass.h> |
39 | #include <tvm/tir/function.h> |
40 | |
41 | #include <string> |
42 | #include <unordered_map> |
43 | #include <unordered_set> |
44 | #include <utility> |
45 | #include <vector> |
46 | |
47 | namespace tvm { |
48 | using tvm::transform::Pass; |
49 | |
50 | /*! |
51 | * \brief Configures and returns the composite Pass for the fused module (pre split) that contains |
52 | * device and host code. |
53 | * \param mixed_mod The original mixed module. |
54 | * \param target The device Target. |
55 | * \return The composite Pass for the fused module. |
56 | // */ |
57 | TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); |
58 | |
59 | /*! |
60 | * \brief Configures and returns the composite Pass for the device Target after device/host from |
61 | * mixed module. |
62 | * \param mixed_mod The optimized mixed module. |
63 | * \param target The device Target. |
64 | * \return The composite Pass for the device module. |
65 | */ |
66 | TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target); |
67 | |
68 | /*! |
69 | * \brief Configures and returns the composite Pass for the host Target after device/host from mixed |
70 | * module. |
71 | * \param mixed_mod The optimized mixed module. |
72 | * \param target_host The host Target. |
73 | * \return The composite Pass for the host module. |
74 | */ |
75 | TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host); |
76 | |
77 | /*! |
78 | * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) |
79 | * \param mod The IRmodule to lower |
80 | * \param simple_mode Disables the loop partition pass. Defaults to false. |
81 | * \return The result module. |
82 | */ |
83 | TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); |
84 | |
85 | /*! |
86 | * \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list |
87 | * defined in CreatePassList) |
88 | * \param func The PrimFunc to lower |
89 | * \param name The name of the lowered function. |
90 | * \param simple_mode Disables the loop partition pass. Defaults to false. |
91 | * \return The result module. |
92 | */ |
93 | TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, |
94 | bool simple_mode = false); |
95 | |
96 | /*! |
97 | * \brief Build an IRModule given a TE schedule, args and binds. This function also applies |
98 | * the lowering passes defined in CreatePassList. |
99 | * \param sch The TE schedule to lower. |
100 | * \param args The arguments to the function. |
101 | * \param name The name of the lowered function. |
102 | * \param binds Buffer assignments. |
103 | * \param global_var_supply The GlobalVarSupply to be used in the module. |
104 | * \param simple_mode Disables the loop partition pass. Defaults to false. |
105 | * \return The result module. |
106 | */ |
107 | |
108 | TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, |
109 | const std::string& name, |
110 | const std::unordered_map<te::Tensor, tir::Buffer>& binds, |
111 | GlobalVarSupply global_var_supply, bool simple_mode = false); |
112 | |
113 | /*! |
114 | * \brief Build an IRModule given a TE schedule, args and binds. This function also applies |
115 | * the lowering passes defined in CreatePassList. |
116 | * \param sch The TE schedule to lower. |
117 | * \param args The arguments to the function (Array of Tensor, Buffer and Vars) |
118 | * \param name The name of the lowered function. |
119 | * \param binds Buffer assignments. |
120 | * \param global_var_supply The GlobalVarSupply to be used in the module. |
121 | * \param simple_mode Disables the loop partition pass. Defaults to false. |
122 | * \return The result module. |
123 | */ |
124 | TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, |
125 | const std::string& name, |
126 | const std::unordered_map<te::Tensor, tir::Buffer>& binds, |
127 | GlobalVarSupply global_var_supply, bool simple_mode = false); |
128 | |
129 | /*! |
130 | * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want |
131 | * to apply lowering passes as well, use LowerSchedule. |
132 | * \param sch The schedule |
133 | * \param args The arguments to the function. |
134 | * \param name The name of the lowered function. |
135 | * \param binds Buffer assignments. |
136 | * \param global_var_supply The GlobalVarSupply to be used in the module and when creating |
137 | * GlobalVars. |
138 | * \return The result module. |
139 | */ |
140 | IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name, |
141 | const std::unordered_map<te::Tensor, tir::Buffer>& binds, |
142 | GlobalVarSupply global_var_supply); |
143 | /*! |
144 | * \brief Build a device and host module for a specific target from an IRModule. |
145 | * \param funcs The functions to be built. |
146 | * \param target The target device to build for. |
147 | * \param target_host The target for building host code. To use the default, pass Target() |
148 | * \return The built module. |
149 | */ |
150 | TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target, |
151 | const Target& target_host); |
152 | |
153 | /*! |
154 | * \brief Build a device and host module for a specific target from a map |
155 | * contains target to IRModule. This function is used |
156 | * for heterogeneous build. |
157 | * \param input The map contains target to an IRModule. |
158 | * \param target_host The target for building host code. To use the default, |
159 | * pass Target(). |
160 | * \return The built module that contains code for different processors. |
161 | */ |
162 | TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host); |
163 | |
164 | /*! |
165 | * \brief Build a device and host module for a specific target from a map |
166 | * contains target to IRModule. This function is used |
167 | * for heterogeneous build. |
168 | * \param input The map contains target string to an IRModule. |
169 | * \param target_host The target for building host code. To use the default, |
170 | * pass Target(). |
171 | * \return The built module that contains code for different processors. |
172 | */ |
173 | TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host); |
174 | } // namespace tvm |
175 | |
176 | #endif // TVM_DRIVER_DRIVER_API_H_ |
177 | |