1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Registration of TVM schedules |
22 | * \file schedule.cc |
23 | */ |
24 | |
25 | #include <tvm/ir/expr.h> |
26 | #include <tvm/runtime/module.h> |
27 | #include <tvm/runtime/packed_func.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/target/generic_func.h> |
30 | #include <tvm/topi/cuda/dense.h> |
31 | #include <tvm/topi/cuda/injective.h> |
32 | #include <tvm/topi/cuda/pooling.h> |
33 | #include <tvm/topi/cuda/reduction.h> |
34 | #include <tvm/topi/cuda/softmax.h> |
35 | #include <tvm/topi/detail/tensor_utils.h> |
36 | #include <tvm/topi/generic/default.h> |
37 | #include <tvm/topi/generic/extern.h> |
38 | #include <tvm/topi/generic/injective.h> |
39 | #include <tvm/topi/rocm/dense.h> |
40 | #include <tvm/topi/rocm/injective.h> |
41 | #include <tvm/topi/rocm/pooling.h> |
42 | #include <tvm/topi/rocm/reduction.h> |
43 | #include <tvm/topi/rocm/softmax.h> |
44 | #include <tvm/topi/x86/bnn.h> |
45 | #include <tvm/topi/x86/default.h> |
46 | #include <tvm/topi/x86/injective.h> |
47 | |
48 | namespace tvm { |
49 | namespace topi { |
50 | |
51 | using namespace tvm; |
52 | using namespace tvm::runtime; |
53 | |
54 | TVM_REGISTER_GLOBAL("topi.TEST_create_target" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
55 | *rv = tvm::Target(args[0].operator String()); |
56 | }); |
57 | |
58 | /* Generic schedules */ |
59 | TVM_REGISTER_GLOBAL("topi.generic.default_schedule" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
60 | if (args[2]) { |
61 | *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]); |
62 | } else { |
63 | *rv = topi::generic::default_schedule(args[0], args[1]); |
64 | } |
65 | }); |
66 | |
67 | TVM_REGISTER_GLOBAL("topi.generic.schedule_extern" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
68 | *rv = topi::generic::schedule_extern(args[0], args[1]); |
69 | }); |
70 | |
71 | TVM_REGISTER_GLOBAL("topi.generic.schedule_injective" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
72 | *rv = topi::generic::schedule_injective(args[0], args[1]); |
73 | }); |
74 | |
75 | TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing" ) |
76 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
77 | *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); |
78 | }); |
79 | |
80 | /* x86 schedules */ |
81 | TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
82 | *rv = topi::x86::schedule_binarize_pack(args[0], args[1]); |
83 | }); |
84 | |
85 | TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
86 | *rv = topi::x86::schedule_binary_dense(args[0], args[1]); |
87 | }); |
88 | |
89 | TVM_REGISTER_GLOBAL("topi.x86.default_schedule" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
90 | if (args[2]) { |
91 | *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]); |
92 | } else { |
93 | *rv = topi::x86::default_schedule(args[0], args[1]); |
94 | } |
95 | }); |
96 | |
97 | TVM_REGISTER_GLOBAL("topi.x86.schedule_injective" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
98 | *rv = topi::x86::schedule_injective(args[0], args[1]); |
99 | }); |
100 | |
101 | TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing" ) |
102 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
103 | *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); |
104 | }); |
105 | |
106 | /* ROCm schedules */ |
107 | TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
108 | *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]); |
109 | }); |
110 | |
111 | TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
112 | *rv = topi::rocm::schedule_dense(args[0], args[1]); |
113 | }); |
114 | |
115 | TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
116 | *rv = topi::rocm::schedule_injective(args[0], args[1]); |
117 | }); |
118 | |
119 | TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing" ) |
120 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
121 | *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); |
122 | }); |
123 | |
124 | TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
125 | *rv = topi::rocm::schedule_pool(args[0], args[1]); |
126 | }); |
127 | |
128 | TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
129 | *rv = topi::rocm::schedule_global_pool(args[0], args[1]); |
130 | }); |
131 | |
132 | TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
133 | *rv = topi::rocm::schedule_reduce(args[0], args[1]); |
134 | }); |
135 | |
136 | TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
137 | *rv = topi::rocm::schedule_softmax(args[0], args[1]); |
138 | }); |
139 | |
140 | /* CUDA schedules */ |
141 | TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
142 | *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]); |
143 | }); |
144 | |
145 | TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
146 | *rv = topi::cuda::schedule_dense(args[0], args[1]); |
147 | }); |
148 | |
149 | TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
150 | *rv = topi::cuda::schedule_injective(args[0], args[1]); |
151 | }); |
152 | |
153 | TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing" ) |
154 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
155 | *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); |
156 | }); |
157 | |
158 | TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
159 | *rv = topi::cuda::schedule_pool(args[0], args[1]); |
160 | }); |
161 | |
162 | TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
163 | *rv = topi::cuda::schedule_global_pool(args[0], args[1]); |
164 | }); |
165 | |
166 | TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
167 | *rv = topi::cuda::schedule_reduce(args[0], args[1]); |
168 | }); |
169 | |
170 | TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
171 | *rv = topi::cuda::schedule_softmax(args[0], args[1]); |
172 | }); |
173 | |
174 | /* Utility functions */ |
175 | TVM_REGISTER_GLOBAL("topi.utils.is_empty_shape" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
176 | *rv = topi::detail::is_empty_shape(args[0]); |
177 | }); |
178 | |
179 | TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
180 | *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]); |
181 | }); |
182 | |
183 | TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
184 | *rv = detail::bilinear_sample_nhwc(args[0], args[1], args[2], args[3]); |
185 | }); |
186 | |
187 | /*! \brief Builder function for instantiating schedules. */ |
188 | using FTVMScheduleBuilder = std::function<tvm::te::Schedule( |
189 | const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>; |
190 | |
191 | /*! |
192 | * \brief Helper function for registering generic functions matching the |
193 | * FTVMScheduleBuilder signature. The schedule builder function is wrapped |
194 | * with a PackedFunc suitable for passing to a tvm::GenericFunc. |
195 | * |
196 | * \param builder The schedule builder to wrap. |
197 | * |
198 | * \return The wrapped schedule builder |
199 | */ |
200 | inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { |
201 | return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { |
202 | auto target = Target::Current(false); |
203 | Array<Tensor> outs; |
204 | ObjectRef argNodeRef = args[0]; |
205 | if (argNodeRef->type_index() == outs->type_index()) { |
206 | outs = args[0]; |
207 | } else { |
208 | outs = Array<Tensor>{args[0]}; |
209 | } |
210 | |
211 | *ret = builder(target, outs); |
212 | }); |
213 | } |
214 | |
// Target-dispatched schedule tables. `set_default` supplies the fallback
// implementation; each `register_func` binds an override to the listed target
// keys. Every entry is adapted to the PackedFunc convention via WrapSchedule.

// Injective: generic fallback, x86 for "cpu", CUDA for "cuda"/"gpu".
TVM_REGISTER_GENERIC_FUNC(schedule_injective)
    .set_default(WrapSchedule(topi::generic::schedule_injective))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_injective))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_injective));

// Softmax: default schedules except on "cuda"/"gpu".
TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_softmax));

// Dense: dedicated schedules for "cuda"/"gpu" and "rocm"; no "cpu" override.
TVM_REGISTER_GENERIC_FUNC(schedule_dense)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_dense))
    .register_func({"rocm"}, WrapSchedule(topi::rocm::schedule_dense));

// Batch matmul and batch norm: only the generic default schedule.
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
    .set_default(WrapSchedule(topi::generic::default_schedule));

TVM_REGISTER_GENERIC_FUNC(schedule_batch_norm)
    .set_default(WrapSchedule(topi::generic::default_schedule));

// Pooling (windowed and global): default schedules except on "cuda"/"gpu".
TVM_REGISTER_GENERIC_FUNC(schedule_pool)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_pool));

TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_global_pool));

// Reduce: the non-GPU paths use the auto-inlining default schedules.
TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
    .set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule_auto_inline))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_reduce));

// Binary ops: x86-specific schedules for "cpu", generic default elsewhere.
TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binarize_pack));

TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binary_dense));
258 | |
259 | /*! \brief Builder function for instantiating schedules from existing schedules. */ |
260 | using FTVMScheduleFromExistingBuilder = |
261 | std::function<tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>; |
262 | |
263 | /*! |
264 | * \brief Helper function for registering generic functions matching the |
265 | * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped |
266 | * with a PackedFunc suitable for passing to a tvm::GenericFunc. |
267 | * |
268 | * \param builder The schedule builder to wrap. |
269 | * |
270 | * \return The wrapped schedule builder |
271 | */ |
272 | inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) { |
273 | return PackedFunc( |
274 | [builder](TVMArgs args, TVMRetValue* ret) { *ret = builder(args[0], args[1]); }); |
275 | } |
276 | |
// Extends an existing schedule for an injective output: generic fallback,
// x86 for "cpu", CUDA for "cuda"/"gpu".
TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
    .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
    .register_func({"cpu"},
                   WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
    .register_func({"cuda", "gpu"},
                   WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing));
282 | |
283 | /*! \brief Builder function for instantiating dense ops. */ |
284 | using FTVMDenseOpBuilder = std::function<tvm::te::Tensor( |
285 | const Target& target, const tvm::te::Tensor& data, const tvm::te::Tensor& weight, |
286 | const tvm::te::Tensor& bias, const DataType& out_dtype)>; |
287 | |
288 | /*! |
289 | * \brief Helper function for registering dense ops matching the |
290 | * FTVMDenseOpBuilder signature. The op builder function is wrapped |
291 | * with a PackedFunc suitable for passing to a tvm::GenericFunc. |
292 | * |
293 | * \param builder The op builder to wrap. |
294 | * |
295 | * \return The wrapped op builder |
296 | */ |
297 | inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { |
298 | return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { |
299 | auto target = Target::Current(false); |
300 | Tensor data = args[0]; |
301 | Tensor weight = args[1]; |
302 | Tensor bias = args[2]; |
303 | DataType out_dtype = args[3]; |
304 | |
305 | *ret = builder(target, data, weight, bias, out_dtype); |
306 | }); |
307 | } |
308 | |
// Generic `dense` op. The default ignores the target and calls topi::nn::dense;
// "cuda"/"gpu" dispatch to topi::cuda::dense_cuda and "rocm" dispatches to
// topi::rocm::dense_rocm.
TVM_REGISTER_GENERIC_FUNC(dense)
    .set_default(WrapDenseOp([](const Target& /*target*/, const tvm::te::Tensor& data,
                                const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
                                const DataType& out_dtype) {
      return topi::nn::dense(data, weight, bias, out_dtype);
    }))
    .register_func({"cuda", "gpu"}, WrapDenseOp(topi::cuda::dense_cuda))
    .register_func({"rocm"}, WrapDenseOp(topi::rocm::dense_rocm));
317 | |
318 | } // namespace topi |
319 | } // namespace tvm |
320 | |