1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Broadcast op constructions
22 * \file topi/broadcast.h
23 */
24#ifndef TVM_TOPI_BROADCAST_H_
25#define TVM_TOPI_BROADCAST_H_
26
27#include <tvm/topi/detail/broadcast.h>
28#include <tvm/topi/detail/constant_utils.h>
29#include <tvm/topi/tags.h>
30
31#include <algorithm>
32#include <string>
33
34namespace tvm {
35namespace topi {
36
37/*!
38 * \brief Creates an operation that broadcasts a tensor into a compatible
39 * shape according to numpy's rules
40 *
41 * \param t The input tensor
42 * \param output_shape The target output shape, must be compatible
43 * \param name The name of the operation
44 * \param tag The tag to mark the operation
45 *
46 * \return A Tensor whose op member is a broadcast operation
47 */
48inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
49 const tvm::Array<tvm::PrimExpr>& output_shape,
50 std::string name = "T_broadcast_to",
51 std::string tag = kBroadcast) {
52 ICHECK_GE(output_shape.size(), t->shape.size())
53 << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape
54 << "\nvs\ninput: " << t;
55 auto bh = detail::BroadcastShape(output_shape, t->shape);
56 ICHECK_EQ(output_shape.size(), bh.common_shape.size());
57 Array<PrimExpr> oshape;
58 for (size_t i = 0; i < output_shape.size(); ++i) {
59 if (output_shape[i].as<tir::IntImmNode>() == nullptr) {
60 oshape.push_back(output_shape[i]);
61 } else {
62 ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
63 oshape.push_back(bh.common_shape[i]);
64 }
65 }
66 auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
67 return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
68 };
69 return tvm::te::compute(oshape, l, name, tag);
70}
71
/*!
 * \brief Defines a family of binary operator overloads named \p Name from one scalar rule.
 *
 * Expands to four overloads:
 *  - Name(PrimExpr a, PrimExpr b): evaluates \p ComputeRule on the two expressions.
 *  - Name(Tensor A, Tensor B, name, tag = kBroadcast): applies the rule through
 *    detail::WithBroadcast.
 *  - Name(Tensor A, PrimExpr B, name, tag = kElementWise) and
 *    Name(PrimExpr A, Tensor B, name, tag = kElementWise): computes, over the
 *    tensor operand's shape, the rule between each tensor element and the scalar.
 * The default operation name is "T_" followed by \p Name.
 *
 * \p ComputeRule is a braced statement block that reads its operands as `a` and
 * `b` and returns a tvm::PrimExpr. All rule lambdas are captureless, so the rule
 * must only depend on `a`, `b` and free functions.
 */
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)                                                   \
  inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; }     \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B,                \
                              std::string name = "T_" #Name, std::string tag = kBroadcast) {    \
    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                              \
    return detail::WithBroadcast(l, A, B, name, tag);                                            \
  }                                                                                              \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B,                  \
                              std::string name = "T_" #Name, std::string tag = kElementWise) {  \
    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                              \
    return tvm::te::compute(                                                                     \
        A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \
  }                                                                                              \
  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B,                  \
                              std::string name = "T_" #Name, std::string tag = kElementWise) {  \
    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                              \
    return tvm::te::compute(                                                                     \
        B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \
  }
91
/*!
 * \brief Defines operator-style aliases \p Name that forward to topi::OpName.
 *
 * Generates the (Tensor, Tensor), (Tensor, PrimExpr) and (PrimExpr, Tensor)
 * overloads. Each one returns topi::OpName(A, B), leaving the name and tag
 * arguments at their defaults.
 */
#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName)                                        \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) {  \
    return topi::OpName(A, B);                                                       \
  }                                                                                  \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) {    \
    return topi::OpName(A, B);                                                       \
  }                                                                                  \
  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) {    \
    return topi::OpName(A, B);                                                       \
  }
102
103/*!
104 * \fn logical_and
105 * \brief Compute A && B with auto-broadcasting.
106 *
107 * \param A The first tensor, or Expr
108 * \param B The second tensor, or Expr
109 * \param name The name of the operation
110 * \param tag The tag to mark the operation
111 *
112 * \return The result.
113 */
114TOPI_DEFINE_BCAST_OP(logical_and, { return a && b; });
115TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and);
116
117/*!
118 * \fn logical_or
119 * \brief Compute A || B with auto-broadcasting.
120 *
121 * \param A The first tensor, or Expr
122 * \param B The second tensor, or Expr
123 * \param name The name of the operation
124 * \param tag The tag to mark the operation
125 *
126 * \return The result.
127 */
128TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
129TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);
130
131/*!
132 * \fn logical_xor
133 * \brief Compute A ^ B with auto-broadcasting.
134 *
135 * \param A The first tensor, or Expr
136 * \param B The second tensor, or Expr
137 * \param name The name of the operation
138 * \param tag The tag to mark the operation
139 *
140 * \return The result.
141 */
142TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; });
143
144/*!
145 * \fn bitwise_and
146 * \brief Compute A & B with auto-broadcasting.
147 *
148 * \param A The first tensor, or Expr
149 * \param B The second tensor, or Expr
150 * \param name The name of the operation
151 * \param tag The tag to mark the operation
152 *
153 * \return The result.
154 */
155TOPI_DEFINE_BCAST_OP(bitwise_and, { return a & b; });
156TOPI_DEFINE_OP_OVERLOAD(operator&, bitwise_and);
157
158/*!
159 * \fn bitwise_or
160 * \brief Compute A | B with auto-broadcasting.
161 *
162 * \param A The first tensor, or Expr
163 * \param B The second tensor, or Expr
164 * \param name The name of the operation
165 * \param tag The tag to mark the operation
166 *
167 * \return The result.
168 */
169TOPI_DEFINE_BCAST_OP(bitwise_or, { return a | b; });
170TOPI_DEFINE_OP_OVERLOAD(operator|, bitwise_or);
171
172/*!
173 * \fn bitwise_xor
174 * \brief Compute A ^ B with auto-broadcasting.
175 *
176 * \param A The first tensor, or Expr
177 * \param B The second tensor, or Expr
178 * \param name The name of the operation
179 * \param tag The tag to mark the operation
180 *
181 * \return The result.
182 */
183TOPI_DEFINE_BCAST_OP(bitwise_xor, { return a ^ b; });
184TOPI_DEFINE_OP_OVERLOAD(operator^, bitwise_xor);
185
186/*!
187 * \fn add
188 * \brief Compute A + B with auto-broadcasting.
189 *
190 * \param A The first tensor, or Expr
191 * \param B The second tensor, or Expr
192 * \param name The name of the operation
193 * \param tag The tag to mark the operation
194 *
195 * \return The result.
196 */
197TOPI_DEFINE_BCAST_OP(add, { return a + b; });
198TOPI_DEFINE_OP_OVERLOAD(operator+, add);
199
200/*!
201 * \fn subtract
202 * \brief Compute A - B with auto-broadcasting.
203 *
204 * \param A The first tensor, or Expr
205 * \param B The second tensor, or Expr
206 * \param name The name of the operation
207 * \param tag The tag to mark the operation
208 *
209 * \return The result.
210 */
211TOPI_DEFINE_BCAST_OP(subtract, { return a - b; });
212TOPI_DEFINE_OP_OVERLOAD(operator-, subtract);
213
214/*!
215 * \fn multiply
216 * \brief Compute A * B with auto-broadcasting.
217 *
218 * \param A The first tensor, or Expr
219 * \param B The second tensor, or Expr
220 * \param name The name of the operation
221 * \param tag The tag to mark the operation
222 *
223 * \return The result.
224 */
225TOPI_DEFINE_BCAST_OP(multiply, { return a * b; });
226TOPI_DEFINE_OP_OVERLOAD(operator*, multiply);
227
228/*!
229 * \fn divide
230 * \brief Compute A / B with auto-broadcasting.
231 *
232 * \param A The first tensor, or Expr
233 * \param B The second tensor, or Expr
234 * \param name The name of the operation
235 * \param tag The tag to mark the operation
236 *
237 * \return The result.
238 */
239TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); });
240
241/*!
242 * \fn floor divide
243 * \brief Compute floor(A / B) with auto-broadcasting.
244 *
245 * \param A The first tensor, or Expr
246 * \param B The second tensor, or Expr
247 * \param name The name of the operation
248 * \param tag The tag to mark the operation
249 *
250 * \return The result.
251 */
252TOPI_DEFINE_BCAST_OP(floor_divide, {
253 if (a.dtype().is_int() || a.dtype().is_uint()) {
254 return floordiv(a, b);
255 } else {
256 return floor(div(a, b));
257 }
258});
259
260/*!
261 * \fn trunc divide
262 * \brief Compute trunc(A / B) with auto-broadcasting.
263 *
264 * \param A The first tensor, or Expr
265 * \param B The second tensor, or Expr
266 * \param name The name of the operation
267 * \param tag The tag to mark the operation
268 *
269 * \return The result.
270 */
271TOPI_DEFINE_BCAST_OP(trunc_divide, {
272 if (a.dtype().is_int() || a.dtype().is_uint()) {
273 return truncdiv(a, b);
274 } else {
275 return trunc(div(a, b));
276 }
277});
278
279/*!
280 * \fn mod
281 * \brief Compute A % B with auto-broadcasting.
282 *
283 * \param A The first tensor, or Expr
284 * \param B The second tensor, or Expr
285 * \param name The name of the operation
286 * \param tag The tag to mark the operation
287 *
288 * \return The result.
289 */
290TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); });
291
292/*!
293 * \fn floor mod
294 * \brief Compute A - floor_div(A, B) * B with auto-broadcasting.
295 *
296 * \param A The first tensor, or Expr
297 * \param B The second tensor, or Expr
298 * \param name The name of the operation
299 * \param tag The tag to mark the operation
300 *
301 * \return The result.
302 */
303TOPI_DEFINE_BCAST_OP(floor_mod, {
304 if (a.dtype().is_int() || a.dtype().is_uint()) {
305 return floormod(a, b);
306 } else {
307 return a - floor_divide(a, b) * b;
308 }
309});
310
311/*!
312 * \fn trunc mod
313 * \brief Compute A - trunc_div(A, B) * B with auto-broadcasting.
314 *
315 * \param A The first tensor, or Expr
316 * \param B The second tensor, or Expr
317 * \param name The name of the operation
318 * \param tag The tag to mark the operation
319 *
320 * \return The result.
321 */
322TOPI_DEFINE_BCAST_OP(trunc_mod, {
323 if (a.dtype().is_int() || a.dtype().is_uint()) {
324 return truncmod(a, b);
325 } else {
326 return a - trunc_divide(a, b) * b;
327 }
328});
329
330/*!
331 * \fn maximum
332 * \brief Compute maximum(A, B) with auto-broadcasting.
333 *
334 * \param A The first tensor, or Expr
335 * \param B The second tensor, or Expr
336 * \param name The name of the operation
337 * \param tag The tag to mark the operation
338 *
339 * \return The result.
340 */
341TOPI_DEFINE_BCAST_OP(maximum, { return tvm::max(a, b); });
342
343/*!
344 * \fn minimum
345 * \brief Compute minimum(A, B) with auto-broadcasting.
346 *
347 * \param A The first tensor, or Expr
348 * \param B The second tensor, or Expr
349 * \param name The name of the operation
350 * \param tag The tag to mark the operation
351 *
352 * \return The result.
353 */
354TOPI_DEFINE_BCAST_OP(minimum, { return tvm::min(a, b); });
355
356/*!
357 * \fn power
358 * \brief Compute power(A, B) with auto-broadcasting.
359 *
360 * \param A The first tensor, or Expr
361 * \param B The second tensor, or Expr
362 * \param name The name of the operation
363 * \param tag The tag to mark the operation
364 *
365 * \return The result.
366 */
367TOPI_DEFINE_BCAST_OP(power, { return tvm::pow(a, b); });
368
369/*!
370 * \fn left_shift
371 * \brief Compute A << B with auto-broadcasting.
372 *
373 * \param A The first tensor, or Expr
374 * \param B The second tensor, or Expr
375 * \param name The name of the operation
376 * \param tag The tag to mark the operation
377 *
378 * \return The result.
379 */
380TOPI_DEFINE_BCAST_OP(left_shift, { return a << b; });
381TOPI_DEFINE_OP_OVERLOAD(operator<<, left_shift);
382
383/*!
384 * \fn right_shift
385 * \brief Compute A >> B with auto-broadcasting.
386 *
387 * \param A The first tensor, or Expr
388 * \param B The second tensor, or Expr
389 * \param name The name of the operation
390 * \param tag The tag to mark the operation
391 *
392 * \return The result.
393 */
394TOPI_DEFINE_BCAST_OP(right_shift, { return a >> b; });
395TOPI_DEFINE_OP_OVERLOAD(operator>>, right_shift);
396
397/*!
398 * \fn greater
399 * \brief Compute (A > B) with auto-broadcasting.
400 *
401 * \param A The first tensor, or Expr
402 * \param B The second tensor, or Expr
403 * \param name The name of the operation
404 * \param tag The tag to mark the operation
405 *
406 * \return The result.
407 */
408TOPI_DEFINE_BCAST_OP(greater, { return (a > b); });
409
410/*!
411 * \fn less
412 * \brief Compute (A < B) with auto-broadcasting.
413 *
414 * \param A The first tensor, or Expr
415 * \param B The second tensor, or Expr
416 * \param name The name of the operation
417 * \param tag The tag to mark the operation
418 *
419 * \return The result.
420 */
421TOPI_DEFINE_BCAST_OP(less, { return (a < b); });
422
423/*!
424 * \fn equal
425 * \brief Compute (A == B) with auto-broadcasting.
426 *
427 * \param A The first tensor, or Expr
428 * \param B The second tensor, or Expr
429 * \param name The name of the operation
430 * \param tag The tag to mark the operation
431 *
432 * \return The result.
433 */
434TOPI_DEFINE_BCAST_OP(equal, { return (a == b); });
435
436/*!
437 * \fn not_equal
438 * \brief Compute (A != B) with auto-broadcasting.
439 *
440 * \param A The first tensor, or Expr
441 * \param B The second tensor, or Expr
442 * \param name The name of the operation
443 * \param tag The tag to mark the operation
444 *
445 * \return The result.
446 */
447TOPI_DEFINE_BCAST_OP(not_equal, { return (a != b); });
448
449/*!
450 * \fn greater_equal
451 * \brief Compute (A >= B) with auto-broadcasting.
452 *
453 * \param A The first tensor, or Expr
454 * \param B The second tensor, or Expr
455 * \param name The name of the operation
456 * \param tag The tag to mark the operation
457 *
458 * \return The result.
459 */
460TOPI_DEFINE_BCAST_OP(greater_equal, { return (a >= b); });
461
462/*!
463 * \fn less_equal
464 * \brief Compute (A <= B) with auto-broadcasting.
465 *
466 * \param A The first tensor, or Expr
467 * \param B The second tensor, or Expr
468 * \param name The name of the operation
469 * \param tag The tag to mark the operation
470 *
471 * \return The result.
472 */
473TOPI_DEFINE_BCAST_OP(less_equal, { return (a <= b); });
474
475} // namespace topi
476} // namespace tvm
477
478#endif // TVM_TOPI_BROADCAST_H_
479