1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Broadcast op constructions |
22 | * \file topi/broadcast.h |
23 | */ |
24 | #ifndef TVM_TOPI_BROADCAST_H_ |
25 | #define TVM_TOPI_BROADCAST_H_ |
26 | |
27 | #include <tvm/topi/detail/broadcast.h> |
28 | #include <tvm/topi/detail/constant_utils.h> |
29 | #include <tvm/topi/tags.h> |
30 | |
31 | #include <algorithm> |
32 | #include <string> |
33 | |
34 | namespace tvm { |
35 | namespace topi { |
36 | |
37 | /*! |
38 | * \brief Creates an operation that broadcasts a tensor into a compatible |
39 | * shape according to numpy's rules |
40 | * |
41 | * \param t The input tensor |
42 | * \param output_shape The target output shape, must be compatible |
43 | * \param name The name of the operation |
44 | * \param tag The tag to mark the operation |
45 | * |
46 | * \return A Tensor whose op member is a broadcast operation |
47 | */ |
48 | inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, |
49 | const tvm::Array<tvm::PrimExpr>& output_shape, |
50 | std::string name = "T_broadcast_to" , |
51 | std::string tag = kBroadcast) { |
52 | ICHECK_GE(output_shape.size(), t->shape.size()) |
53 | << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape |
54 | << "\nvs\ninput: " << t; |
55 | auto bh = detail::BroadcastShape(output_shape, t->shape); |
56 | ICHECK_EQ(output_shape.size(), bh.common_shape.size()); |
57 | Array<PrimExpr> oshape; |
58 | for (size_t i = 0; i < output_shape.size(); ++i) { |
59 | if (output_shape[i].as<tir::IntImmNode>() == nullptr) { |
60 | oshape.push_back(output_shape[i]); |
61 | } else { |
62 | ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i])); |
63 | oshape.push_back(bh.common_shape[i]); |
64 | } |
65 | } |
66 | auto l = [&](tvm::Array<tvm::tir::Var> ovars) { |
67 | return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); |
68 | }; |
69 | return tvm::te::compute(oshape, l, name, tag); |
70 | } |
71 | |
/*!
 * \brief Defines an overload set named \p Name for a binary element-wise rule.
 *
 * Expands to four inline functions:
 *  - Name(PrimExpr a, PrimExpr b): evaluates ComputeRule on two scalar expressions.
 *  - Name(Tensor A, Tensor B, name, tag): applies the rule with broadcasting via
 *    detail::WithBroadcast; tag defaults to kBroadcast.
 *  - Name(Tensor A, PrimExpr B, name, tag) and Name(PrimExpr A, Tensor B, name, tag):
 *    apply the rule over the tensor operand's shape with the scalar held fixed;
 *    tag defaults to kElementWise.
 * In every tensor overload the operation name defaults to "T_" #Name.
 *
 * \param Name The identifier of the overload set to define.
 * \param ComputeRule A braced statement that must `return` a tvm::PrimExpr computed
 *        only from the scalar operands `a` and `b` (the rule lambdas capture nothing).
 */
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)                                               \
  inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; }  \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B,             \
                              std::string name = "T_" #Name, std::string tag = kBroadcast) {  \
    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                           \
    return detail::WithBroadcast(l, A, B, name, tag);                                          \
  }                                                                                            \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B,               \
                              std::string name = "T_" #Name, std::string tag = kElementWise) { \
    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                           \
    return tvm::te::compute(                                                                   \
        A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name,   \
        tag);                                                                                  \
  }                                                                                            \
  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B,               \
                              std::string name = "T_" #Name, std::string tag = kElementWise) { \
    /* Capture-less, like the other overloads: the rule only reads a and b. */                 \
    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                           \
    return tvm::te::compute(                                                                   \
        B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name,   \
        tag);                                                                                  \
  }
91 | |
/*!
 * \brief Defines operator-style overloads \p Name that forward to topi::OpName.
 *
 * Generates (Tensor, Tensor), (Tensor, PrimExpr) and (PrimExpr, Tensor) overloads.
 * Only the two operands are forwarded, so OpName's default name and tag apply.
 *
 * \param Name The function (typically an `operatorX`) to define.
 * \param OpName The topi function each overload delegates to.
 */
#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName)                                       \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \
    return topi::OpName(A, B);                                                      \
  }                                                                                 \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) {   \
    return topi::OpName(A, B);                                                      \
  }                                                                                 \
  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) {   \
    return topi::OpName(A, B);                                                      \
  }
102 | |
103 | /*! |
104 | * \fn logical_and |
105 | * \brief Compute A && B with auto-broadcasting. |
106 | * |
107 | * \param A The first tensor, or Expr |
108 | * \param B The second tensor, or Expr |
109 | * \param name The name of the operation |
110 | * \param tag The tag to mark the operation |
111 | * |
112 | * \return The result. |
113 | */ |
114 | TOPI_DEFINE_BCAST_OP(logical_and, { return a && b; }); |
115 | TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and); |
116 | |
117 | /*! |
118 | * \fn logical_or |
119 | * \brief Compute A || B with auto-broadcasting. |
120 | * |
121 | * \param A The first tensor, or Expr |
122 | * \param B The second tensor, or Expr |
123 | * \param name The name of the operation |
124 | * \param tag The tag to mark the operation |
125 | * |
126 | * \return The result. |
127 | */ |
128 | TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; }); |
129 | TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or); |
130 | |
131 | /*! |
132 | * \fn logical_xor |
133 | * \brief Compute A ^ B with auto-broadcasting. |
134 | * |
135 | * \param A The first tensor, or Expr |
136 | * \param B The second tensor, or Expr |
137 | * \param name The name of the operation |
138 | * \param tag The tag to mark the operation |
139 | * |
140 | * \return The result. |
141 | */ |
142 | TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; }); |
143 | |
144 | /*! |
145 | * \fn bitwise_and |
146 | * \brief Compute A & B with auto-broadcasting. |
147 | * |
148 | * \param A The first tensor, or Expr |
149 | * \param B The second tensor, or Expr |
150 | * \param name The name of the operation |
151 | * \param tag The tag to mark the operation |
152 | * |
153 | * \return The result. |
154 | */ |
155 | TOPI_DEFINE_BCAST_OP(bitwise_and, { return a & b; }); |
156 | TOPI_DEFINE_OP_OVERLOAD(operator&, bitwise_and); |
157 | |
158 | /*! |
159 | * \fn bitwise_or |
160 | * \brief Compute A | B with auto-broadcasting. |
161 | * |
162 | * \param A The first tensor, or Expr |
163 | * \param B The second tensor, or Expr |
164 | * \param name The name of the operation |
165 | * \param tag The tag to mark the operation |
166 | * |
167 | * \return The result. |
168 | */ |
169 | TOPI_DEFINE_BCAST_OP(bitwise_or, { return a | b; }); |
170 | TOPI_DEFINE_OP_OVERLOAD(operator|, bitwise_or); |
171 | |
172 | /*! |
173 | * \fn bitwise_xor |
174 | * \brief Compute A ^ B with auto-broadcasting. |
175 | * |
176 | * \param A The first tensor, or Expr |
177 | * \param B The second tensor, or Expr |
178 | * \param name The name of the operation |
179 | * \param tag The tag to mark the operation |
180 | * |
181 | * \return The result. |
182 | */ |
183 | TOPI_DEFINE_BCAST_OP(bitwise_xor, { return a ^ b; }); |
184 | TOPI_DEFINE_OP_OVERLOAD(operator^, bitwise_xor); |
185 | |
186 | /*! |
187 | * \fn add |
188 | * \brief Compute A + B with auto-broadcasting. |
189 | * |
190 | * \param A The first tensor, or Expr |
191 | * \param B The second tensor, or Expr |
192 | * \param name The name of the operation |
193 | * \param tag The tag to mark the operation |
194 | * |
195 | * \return The result. |
196 | */ |
197 | TOPI_DEFINE_BCAST_OP(add, { return a + b; }); |
198 | TOPI_DEFINE_OP_OVERLOAD(operator+, add); |
199 | |
200 | /*! |
201 | * \fn subtract |
202 | * \brief Compute A - B with auto-broadcasting. |
203 | * |
204 | * \param A The first tensor, or Expr |
205 | * \param B The second tensor, or Expr |
206 | * \param name The name of the operation |
207 | * \param tag The tag to mark the operation |
208 | * |
209 | * \return The result. |
210 | */ |
211 | TOPI_DEFINE_BCAST_OP(subtract, { return a - b; }); |
212 | TOPI_DEFINE_OP_OVERLOAD(operator-, subtract); |
213 | |
214 | /*! |
215 | * \fn multiply |
216 | * \brief Compute A * B with auto-broadcasting. |
217 | * |
218 | * \param A The first tensor, or Expr |
219 | * \param B The second tensor, or Expr |
220 | * \param name The name of the operation |
221 | * \param tag The tag to mark the operation |
222 | * |
223 | * \return The result. |
224 | */ |
225 | TOPI_DEFINE_BCAST_OP(multiply, { return a * b; }); |
226 | TOPI_DEFINE_OP_OVERLOAD(operator*, multiply); |
227 | |
228 | /*! |
229 | * \fn divide |
230 | * \brief Compute A / B with auto-broadcasting. |
231 | * |
232 | * \param A The first tensor, or Expr |
233 | * \param B The second tensor, or Expr |
234 | * \param name The name of the operation |
235 | * \param tag The tag to mark the operation |
236 | * |
237 | * \return The result. |
238 | */ |
239 | TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); }); |
240 | |
241 | /*! |
242 | * \fn floor divide |
243 | * \brief Compute floor(A / B) with auto-broadcasting. |
244 | * |
245 | * \param A The first tensor, or Expr |
246 | * \param B The second tensor, or Expr |
247 | * \param name The name of the operation |
248 | * \param tag The tag to mark the operation |
249 | * |
250 | * \return The result. |
251 | */ |
252 | TOPI_DEFINE_BCAST_OP(floor_divide, { |
253 | if (a.dtype().is_int() || a.dtype().is_uint()) { |
254 | return floordiv(a, b); |
255 | } else { |
256 | return floor(div(a, b)); |
257 | } |
258 | }); |
259 | |
260 | /*! |
261 | * \fn trunc divide |
262 | * \brief Compute trunc(A / B) with auto-broadcasting. |
263 | * |
264 | * \param A The first tensor, or Expr |
265 | * \param B The second tensor, or Expr |
266 | * \param name The name of the operation |
267 | * \param tag The tag to mark the operation |
268 | * |
269 | * \return The result. |
270 | */ |
271 | TOPI_DEFINE_BCAST_OP(trunc_divide, { |
272 | if (a.dtype().is_int() || a.dtype().is_uint()) { |
273 | return truncdiv(a, b); |
274 | } else { |
275 | return trunc(div(a, b)); |
276 | } |
277 | }); |
278 | |
279 | /*! |
280 | * \fn mod |
281 | * \brief Compute A % B with auto-broadcasting. |
282 | * |
283 | * \param A The first tensor, or Expr |
284 | * \param B The second tensor, or Expr |
285 | * \param name The name of the operation |
286 | * \param tag The tag to mark the operation |
287 | * |
288 | * \return The result. |
289 | */ |
290 | TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); }); |
291 | |
292 | /*! |
293 | * \fn floor mod |
294 | * \brief Compute A - floor_div(A, B) * B with auto-broadcasting. |
295 | * |
296 | * \param A The first tensor, or Expr |
297 | * \param B The second tensor, or Expr |
298 | * \param name The name of the operation |
299 | * \param tag The tag to mark the operation |
300 | * |
301 | * \return The result. |
302 | */ |
303 | TOPI_DEFINE_BCAST_OP(floor_mod, { |
304 | if (a.dtype().is_int() || a.dtype().is_uint()) { |
305 | return floormod(a, b); |
306 | } else { |
307 | return a - floor_divide(a, b) * b; |
308 | } |
309 | }); |
310 | |
311 | /*! |
312 | * \fn trunc mod |
313 | * \brief Compute A - trunc_div(A, B) * B with auto-broadcasting. |
314 | * |
315 | * \param A The first tensor, or Expr |
316 | * \param B The second tensor, or Expr |
317 | * \param name The name of the operation |
318 | * \param tag The tag to mark the operation |
319 | * |
320 | * \return The result. |
321 | */ |
322 | TOPI_DEFINE_BCAST_OP(trunc_mod, { |
323 | if (a.dtype().is_int() || a.dtype().is_uint()) { |
324 | return truncmod(a, b); |
325 | } else { |
326 | return a - trunc_divide(a, b) * b; |
327 | } |
328 | }); |
329 | |
330 | /*! |
331 | * \fn maximum |
332 | * \brief Compute maximum(A, B) with auto-broadcasting. |
333 | * |
334 | * \param A The first tensor, or Expr |
335 | * \param B The second tensor, or Expr |
336 | * \param name The name of the operation |
337 | * \param tag The tag to mark the operation |
338 | * |
339 | * \return The result. |
340 | */ |
341 | TOPI_DEFINE_BCAST_OP(maximum, { return tvm::max(a, b); }); |
342 | |
343 | /*! |
344 | * \fn minimum |
345 | * \brief Compute minimum(A, B) with auto-broadcasting. |
346 | * |
347 | * \param A The first tensor, or Expr |
348 | * \param B The second tensor, or Expr |
349 | * \param name The name of the operation |
350 | * \param tag The tag to mark the operation |
351 | * |
352 | * \return The result. |
353 | */ |
354 | TOPI_DEFINE_BCAST_OP(minimum, { return tvm::min(a, b); }); |
355 | |
356 | /*! |
357 | * \fn power |
358 | * \brief Compute power(A, B) with auto-broadcasting. |
359 | * |
360 | * \param A The first tensor, or Expr |
361 | * \param B The second tensor, or Expr |
362 | * \param name The name of the operation |
363 | * \param tag The tag to mark the operation |
364 | * |
365 | * \return The result. |
366 | */ |
367 | TOPI_DEFINE_BCAST_OP(power, { return tvm::pow(a, b); }); |
368 | |
369 | /*! |
370 | * \fn left_shift |
371 | * \brief Compute A << B with auto-broadcasting. |
372 | * |
373 | * \param A The first tensor, or Expr |
374 | * \param B The second tensor, or Expr |
375 | * \param name The name of the operation |
376 | * \param tag The tag to mark the operation |
377 | * |
378 | * \return The result. |
379 | */ |
380 | TOPI_DEFINE_BCAST_OP(left_shift, { return a << b; }); |
381 | TOPI_DEFINE_OP_OVERLOAD(operator<<, left_shift); |
382 | |
383 | /*! |
384 | * \fn right_shift |
385 | * \brief Compute A >> B with auto-broadcasting. |
386 | * |
387 | * \param A The first tensor, or Expr |
388 | * \param B The second tensor, or Expr |
389 | * \param name The name of the operation |
390 | * \param tag The tag to mark the operation |
391 | * |
392 | * \return The result. |
393 | */ |
394 | TOPI_DEFINE_BCAST_OP(right_shift, { return a >> b; }); |
395 | TOPI_DEFINE_OP_OVERLOAD(operator>>, right_shift); |
396 | |
397 | /*! |
398 | * \fn greater |
399 | * \brief Compute (A > B) with auto-broadcasting. |
400 | * |
401 | * \param A The first tensor, or Expr |
402 | * \param B The second tensor, or Expr |
403 | * \param name The name of the operation |
404 | * \param tag The tag to mark the operation |
405 | * |
406 | * \return The result. |
407 | */ |
408 | TOPI_DEFINE_BCAST_OP(greater, { return (a > b); }); |
409 | |
410 | /*! |
411 | * \fn less |
412 | * \brief Compute (A < B) with auto-broadcasting. |
413 | * |
414 | * \param A The first tensor, or Expr |
415 | * \param B The second tensor, or Expr |
416 | * \param name The name of the operation |
417 | * \param tag The tag to mark the operation |
418 | * |
419 | * \return The result. |
420 | */ |
421 | TOPI_DEFINE_BCAST_OP(less, { return (a < b); }); |
422 | |
423 | /*! |
424 | * \fn equal |
425 | * \brief Compute (A == B) with auto-broadcasting. |
426 | * |
427 | * \param A The first tensor, or Expr |
428 | * \param B The second tensor, or Expr |
429 | * \param name The name of the operation |
430 | * \param tag The tag to mark the operation |
431 | * |
432 | * \return The result. |
433 | */ |
434 | TOPI_DEFINE_BCAST_OP(equal, { return (a == b); }); |
435 | |
436 | /*! |
437 | * \fn not_equal |
438 | * \brief Compute (A != B) with auto-broadcasting. |
439 | * |
440 | * \param A The first tensor, or Expr |
441 | * \param B The second tensor, or Expr |
442 | * \param name The name of the operation |
443 | * \param tag The tag to mark the operation |
444 | * |
445 | * \return The result. |
446 | */ |
447 | TOPI_DEFINE_BCAST_OP(not_equal, { return (a != b); }); |
448 | |
449 | /*! |
450 | * \fn greater_equal |
451 | * \brief Compute (A >= B) with auto-broadcasting. |
452 | * |
453 | * \param A The first tensor, or Expr |
454 | * \param B The second tensor, or Expr |
455 | * \param name The name of the operation |
456 | * \param tag The tag to mark the operation |
457 | * |
458 | * \return The result. |
459 | */ |
460 | TOPI_DEFINE_BCAST_OP(greater_equal, { return (a >= b); }); |
461 | |
462 | /*! |
463 | * \fn less_equal |
464 | * \brief Compute (A <= B) with auto-broadcasting. |
465 | * |
466 | * \param A The first tensor, or Expr |
467 | * \param B The second tensor, or Expr |
468 | * \param name The name of the operation |
469 | * \param tag The tag to mark the operation |
470 | * |
471 | * \return The result. |
472 | */ |
473 | TOPI_DEFINE_BCAST_OP(less_equal, { return (a <= b); }); |
474 | |
475 | } // namespace topi |
476 | } // namespace tvm |
477 | |
478 | #endif // TVM_TOPI_BROADCAST_H_ |
479 | |