1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/driver/driver_api.h
22 * \brief Compiler driver APIs to drive the compilation.
23 *
24 * This module provides end-to-end utils to drive the compilation process.
25 * We adopt the term "compiler driver" in common compiler infrastructures.
26 * Note that a compiler driver is different from "runtime drivers".
 * Most runtime-related code is defined in the runtime folder instead.
28 */
29#ifndef TVM_DRIVER_DRIVER_API_H_
30#define TVM_DRIVER_DRIVER_API_H_
31
32#include <tvm/ir/global_var_supply.h>
33#include <tvm/ir/module.h>
34#include <tvm/ir/transform.h>
35#include <tvm/runtime/packed_func.h>
36#include <tvm/support/with.h>
37#include <tvm/target/target.h>
38#include <tvm/te/schedule_pass.h>
39#include <tvm/tir/function.h>
40
41#include <string>
42#include <unordered_map>
43#include <unordered_set>
44#include <utility>
45#include <vector>
46
47namespace tvm {
48using tvm::transform::Pass;
49
50/*!
51 * \brief Configures and returns the composite Pass for the fused module (pre split) that contains
52 * device and host code.
53 * \param mixed_mod The original mixed module.
54 * \param target The device Target.
55 * \return The composite Pass for the fused module.
56// */
57TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);
58
59/*!
60 * \brief Configures and returns the composite Pass for the device Target after device/host from
61 * mixed module.
62 * \param mixed_mod The optimized mixed module.
63 * \param target The device Target.
64 * \return The composite Pass for the device module.
65 */
66TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target);
67
68/*!
69 * \brief Configures and returns the composite Pass for the host Target after device/host from mixed
70 * module.
71 * \param mixed_mod The optimized mixed module.
72 * \param target_host The host Target.
73 * \return The composite Pass for the host module.
74 */
75TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host);
76
77/*!
78 * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList)
79 * \param mod The IRmodule to lower
80 * \param simple_mode Disables the loop partition pass. Defaults to false.
81 * \return The result module.
82 */
83TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
84
85/*!
86 * \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list
87 * defined in CreatePassList)
88 * \param func The PrimFunc to lower
89 * \param name The name of the lowered function.
90 * \param simple_mode Disables the loop partition pass. Defaults to false.
91 * \return The result module.
92 */
93TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
94 bool simple_mode = false);
95
96/*!
97 * \brief Build an IRModule given a TE schedule, args and binds. This function also applies
98 * the lowering passes defined in CreatePassList.
99 * \param sch The TE schedule to lower.
100 * \param args The arguments to the function.
101 * \param name The name of the lowered function.
102 * \param binds Buffer assignments.
103 * \param global_var_supply The GlobalVarSupply to be used in the module.
104 * \param simple_mode Disables the loop partition pass. Defaults to false.
105 * \return The result module.
106 */
107
108TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
109 const std::string& name,
110 const std::unordered_map<te::Tensor, tir::Buffer>& binds,
111 GlobalVarSupply global_var_supply, bool simple_mode = false);
112
113/*!
114 * \brief Build an IRModule given a TE schedule, args and binds. This function also applies
115 * the lowering passes defined in CreatePassList.
116 * \param sch The TE schedule to lower.
117 * \param args The arguments to the function (Array of Tensor, Buffer and Vars)
118 * \param name The name of the lowered function.
119 * \param binds Buffer assignments.
120 * \param global_var_supply The GlobalVarSupply to be used in the module.
121 * \param simple_mode Disables the loop partition pass. Defaults to false.
122 * \return The result module.
123 */
124TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
125 const std::string& name,
126 const std::unordered_map<te::Tensor, tir::Buffer>& binds,
127 GlobalVarSupply global_var_supply, bool simple_mode = false);
128
129/*!
130 * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
131 * to apply lowering passes as well, use LowerSchedule.
132 * \param sch The schedule
133 * \param args The arguments to the function.
134 * \param name The name of the lowered function.
135 * \param binds Buffer assignments.
136 * \param global_var_supply The GlobalVarSupply to be used in the module and when creating
137 * GlobalVars.
138 * \return The result module.
139 */
140IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
141 const std::unordered_map<te::Tensor, tir::Buffer>& binds,
142 GlobalVarSupply global_var_supply);
143/*!
144 * \brief Build a device and host module for a specific target from an IRModule.
145 * \param funcs The functions to be built.
146 * \param target The target device to build for.
147 * \param target_host The target for building host code. To use the default, pass Target()
148 * \return The built module.
149 */
150TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target,
151 const Target& target_host);
152
153/*!
154 * \brief Build a device and host module for a specific target from a map
155 * contains target to IRModule. This function is used
156 * for heterogeneous build.
157 * \param input The map contains target to an IRModule.
158 * \param target_host The target for building host code. To use the default,
159 * pass Target().
160 * \return The built module that contains code for different processors.
161 */
162TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host);
163
164/*!
165 * \brief Build a device and host module for a specific target from a map
166 * contains target to IRModule. This function is used
167 * for heterogeneous build.
168 * \param input The map contains target string to an IRModule.
169 * \param target_host The target for building host code. To use the default,
170 * pass Target().
171 * \return The built module that contains code for different processors.
172 */
173TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host);
174} // namespace tvm
175
176#endif // TVM_DRIVER_DRIVER_API_H_
177