1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Registration of TVM schedules
22 * \file schedule.cc
23 */
24
25#include <tvm/ir/expr.h>
26#include <tvm/runtime/module.h>
27#include <tvm/runtime/packed_func.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/target/generic_func.h>
30#include <tvm/topi/cuda/dense.h>
31#include <tvm/topi/cuda/injective.h>
32#include <tvm/topi/cuda/pooling.h>
33#include <tvm/topi/cuda/reduction.h>
34#include <tvm/topi/cuda/softmax.h>
35#include <tvm/topi/detail/tensor_utils.h>
36#include <tvm/topi/generic/default.h>
37#include <tvm/topi/generic/extern.h>
38#include <tvm/topi/generic/injective.h>
39#include <tvm/topi/rocm/dense.h>
40#include <tvm/topi/rocm/injective.h>
41#include <tvm/topi/rocm/pooling.h>
42#include <tvm/topi/rocm/reduction.h>
43#include <tvm/topi/rocm/softmax.h>
44#include <tvm/topi/x86/bnn.h>
45#include <tvm/topi/x86/default.h>
46#include <tvm/topi/x86/injective.h>
47
48namespace tvm {
49namespace topi {
50
51using namespace tvm;
52using namespace tvm::runtime;
53
54TVM_REGISTER_GLOBAL("topi.TEST_create_target").set_body([](TVMArgs args, TVMRetValue* rv) {
55 *rv = tvm::Target(args[0].operator String());
56});
57
58/* Generic schedules */
59TVM_REGISTER_GLOBAL("topi.generic.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) {
60 if (args[2]) {
61 *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]);
62 } else {
63 *rv = topi::generic::default_schedule(args[0], args[1]);
64 }
65});
66
67TVM_REGISTER_GLOBAL("topi.generic.schedule_extern").set_body([](TVMArgs args, TVMRetValue* rv) {
68 *rv = topi::generic::schedule_extern(args[0], args[1]);
69});
70
71TVM_REGISTER_GLOBAL("topi.generic.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
72 *rv = topi::generic::schedule_injective(args[0], args[1]);
73});
74
75TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
76 .set_body([](TVMArgs args, TVMRetValue* rv) {
77 *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
78 });
79
80/* x86 schedules */
81TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) {
82 *rv = topi::x86::schedule_binarize_pack(args[0], args[1]);
83});
84
85TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
86 *rv = topi::x86::schedule_binary_dense(args[0], args[1]);
87});
88
89TVM_REGISTER_GLOBAL("topi.x86.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) {
90 if (args[2]) {
91 *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]);
92 } else {
93 *rv = topi::x86::default_schedule(args[0], args[1]);
94 }
95});
96
97TVM_REGISTER_GLOBAL("topi.x86.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
98 *rv = topi::x86::schedule_injective(args[0], args[1]);
99});
100
101TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
102 .set_body([](TVMArgs args, TVMRetValue* rv) {
103 *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
104 });
105
106/* ROCm schedules */
107TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
108 *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]);
109});
110
111TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
112 *rv = topi::rocm::schedule_dense(args[0], args[1]);
113});
114
115TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
116 *rv = topi::rocm::schedule_injective(args[0], args[1]);
117});
118
119TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing")
120 .set_body([](TVMArgs args, TVMRetValue* rv) {
121 *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
122 });
123
124TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
125 *rv = topi::rocm::schedule_pool(args[0], args[1]);
126});
127
128TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
129 *rv = topi::rocm::schedule_global_pool(args[0], args[1]);
130});
131
132TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) {
133 *rv = topi::rocm::schedule_reduce(args[0], args[1]);
134});
135
136TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
137 *rv = topi::rocm::schedule_softmax(args[0], args[1]);
138});
139
140/* CUDA schedules */
141TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
142 *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]);
143});
144
145TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
146 *rv = topi::cuda::schedule_dense(args[0], args[1]);
147});
148
149TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
150 *rv = topi::cuda::schedule_injective(args[0], args[1]);
151});
152
153TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
154 .set_body([](TVMArgs args, TVMRetValue* rv) {
155 *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
156 });
157
158TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
159 *rv = topi::cuda::schedule_pool(args[0], args[1]);
160});
161
162TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
163 *rv = topi::cuda::schedule_global_pool(args[0], args[1]);
164});
165
166TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) {
167 *rv = topi::cuda::schedule_reduce(args[0], args[1]);
168});
169
170TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
171 *rv = topi::cuda::schedule_softmax(args[0], args[1]);
172});
173
174/* Utility functions */
175TVM_REGISTER_GLOBAL("topi.utils.is_empty_shape").set_body([](TVMArgs args, TVMRetValue* rv) {
176 *rv = topi::detail::is_empty_shape(args[0]);
177});
178
179TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw").set_body([](TVMArgs args, TVMRetValue* rv) {
180 *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
181});
182
183TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc").set_body([](TVMArgs args, TVMRetValue* rv) {
184 *rv = detail::bilinear_sample_nhwc(args[0], args[1], args[2], args[3]);
185});
186
/*! \brief Builder function for instantiating schedules: maps a target and
 * the output tensors of an op to a te::Schedule. */
using FTVMScheduleBuilder = std::function<tvm::te::Schedule(
    const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
190
191/*!
192 * \brief Helper function for registering generic functions matching the
193 * FTVMScheduleBuilder signature. The schedule builder function is wrapped
194 * with a PackedFunc suitable for passing to a tvm::GenericFunc.
195 *
196 * \param builder The schedule builder to wrap.
197 *
198 * \return The wrapped schedule builder
199 */
200inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
201 return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
202 auto target = Target::Current(false);
203 Array<Tensor> outs;
204 ObjectRef argNodeRef = args[0];
205 if (argNodeRef->type_index() == outs->type_index()) {
206 outs = args[0];
207 } else {
208 outs = Array<Tensor>{args[0]};
209 }
210
211 *ret = builder(target, outs);
212 });
213}
214
// Generic dispatchers for schedule builders: each GenericFunc picks the
// registered implementation matching the current target's keys ("cpu",
// "cuda", "gpu", "rocm", ...) and falls back to the generic default.
TVM_REGISTER_GENERIC_FUNC(schedule_injective)
    .set_default(WrapSchedule(topi::generic::schedule_injective))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_injective))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_injective));

TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_softmax));

TVM_REGISTER_GENERIC_FUNC(schedule_dense)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_dense))
    .register_func({"rocm"}, WrapSchedule(topi::rocm::schedule_dense));

// No specialized implementations registered yet; always uses the default.
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
    .set_default(WrapSchedule(topi::generic::default_schedule));

// No specialized implementations registered yet; always uses the default.
TVM_REGISTER_GENERIC_FUNC(schedule_batch_norm)
    .set_default(WrapSchedule(topi::generic::default_schedule));

TVM_REGISTER_GENERIC_FUNC(schedule_pool)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_pool));

TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_global_pool));

// Reductions use the auto-inlining default schedules on CPU targets.
TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
    .set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule_auto_inline))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_reduce));

TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binarize_pack));

TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binary_dense));
258
/*! \brief Builder function for instantiating schedules from existing schedules:
 * amends an already-created schedule `sch` to cover the output tensor `out`. */
using FTVMScheduleFromExistingBuilder =
    std::function<tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>;
262
263/*!
264 * \brief Helper function for registering generic functions matching the
265 * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped
266 * with a PackedFunc suitable for passing to a tvm::GenericFunc.
267 *
268 * \param builder The schedule builder to wrap.
269 *
270 * \return The wrapped schedule builder
271 */
272inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
273 return PackedFunc(
274 [builder](TVMArgs args, TVMRetValue* ret) { *ret = builder(args[0], args[1]); });
275}
276
// Generic dispatcher for amending an existing schedule with an injective op,
// specialized per target key with a generic fallback.
TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
    .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
    .register_func({"cpu"}, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
    .register_func({"cuda", "gpu"},
                   WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing));
282
/*! \brief Builder function for instantiating dense ops: maps a target plus the
 * dense inputs (data, weight, bias) and output dtype to the result tensor. */
using FTVMDenseOpBuilder = std::function<tvm::te::Tensor(
    const Target& target, const tvm::te::Tensor& data, const tvm::te::Tensor& weight,
    const tvm::te::Tensor& bias, const DataType& out_dtype)>;
287
288/*!
289 * \brief Helper function for registering dense ops matching the
290 * FTVMDenseOpBuilder signature. The op builder function is wrapped
291 * with a PackedFunc suitable for passing to a tvm::GenericFunc.
292 *
293 * \param builder The op builder to wrap.
294 *
295 * \return The wrapped op builder
296 */
297inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
298 return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
299 auto target = Target::Current(false);
300 Tensor data = args[0];
301 Tensor weight = args[1];
302 Tensor bias = args[2];
303 DataType out_dtype = args[3];
304
305 *ret = builder(target, data, weight, bias, out_dtype);
306 });
307}
308
// Generic dispatcher for the dense op itself: the default falls back to the
// target-independent topi::nn::dense (the `target` parameter is unused there);
// CUDA/GPU and ROCm targets use their specialized implementations.
TVM_REGISTER_GENERIC_FUNC(dense)
    .set_default(WrapDenseOp([](const Target& target, const tvm::te::Tensor& data,
                                const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
                                const DataType& out_dtype) {
      return topi::nn::dense(data, weight, bias, out_dtype);
    }))
    .register_func({"cuda", "gpu"}, WrapDenseOp(topi::cuda::dense_cuda))
    .register_func({"rocm"}, WrapDenseOp(topi::rocm::dense_rocm));
317
318} // namespace topi
319} // namespace tvm
320