1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \brief Broadcast helpers: shape inference and input index mapping for broadcasting binary ops.
 * \file topi/detail/broadcast.h
 */
24#ifndef TVM_TOPI_DETAIL_BROADCAST_H_
25#define TVM_TOPI_DETAIL_BROADCAST_H_
26
27#include <tvm/te/operation.h>
28#include <tvm/topi/detail/constant_utils.h>
29
30#include <algorithm>
31#include <deque>
32#include <string>
33
34namespace tvm {
35namespace topi {
36namespace detail {
37
38struct BroadcastHelper {
39 std::deque<tvm::PrimExpr> common_shape;
40 std::deque<tvm::tir::Var> all_vars;
41 std::deque<tvm::tir::Var> vars1;
42 std::deque<tvm::tir::Var> vars2;
43};
44
45static inline DataType CommonType(DataType type1, DataType type2) {
46 ICHECK(type1.is_scalar() && type2.is_scalar());
47 ICHECK(type1.code() == type2.code());
48 return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1);
49}
50
51inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
52 const tvm::Array<tvm::PrimExpr>& shape2) {
53 BroadcastHelper bh;
54 int s1_size = shape1.size();
55 int s2_size = shape2.size();
56 tvm::PrimExpr one(1);
57 int i;
58
59 auto cast_if_needed = [](DataType to_type, PrimExpr expr) {
60 return to_type != expr.dtype() ? cast(to_type, expr) : expr;
61 };
62
63 for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
64 // TODO(@icemelon9): Need to revisit this part
65 const IntImmNode* static_size1 = shape1[s1_size - i].as<IntImmNode>();
66 const IntImmNode* static_size2 = shape2[s2_size - i].as<IntImmNode>();
67 DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype());
68
69 bh.all_vars.push_front(tvm::tir::Var("dim", common_type));
70 if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
71 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
72 bh.vars1.push_front(bh.all_vars[0]);
73 bh.vars2.push_front(bh.all_vars[0]);
74 } else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
75 ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
76 bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
77 bh.vars2.push_front(bh.all_vars[0]);
78 } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
79 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
80 bh.vars1.push_front(bh.all_vars[0]);
81 } else if (!static_size1 && !static_size2) {
82 bh.common_shape.push_front(
83 cast_if_needed(common_type, max(shape1[s1_size - i], shape2[s2_size - i])));
84 bh.vars1.push_front(bh.all_vars[0]);
85 bh.vars2.push_front(bh.all_vars[0]);
86 } else if (!static_size1) {
87 bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
88 bh.vars2.push_front(bh.all_vars[0]);
89 bh.vars1.push_front(bh.all_vars[0]);
90 } else if (!static_size2) {
91 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
92 bh.vars1.push_front(bh.all_vars[0]);
93 bh.vars2.push_front(bh.all_vars[0]);
94 } else {
95 ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
96 << shape2[s2_size - i]
97 << " in: " << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end()) << " and "
98 << tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
99 }
100 }
101 // Remaining dimensions whether on shape1 or shape2 can always be completed
102 auto max_size = std::max(s1_size, s2_size);
103 auto& shape = (s1_size > s2_size) ? shape1 : shape2;
104 auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
105 for (; i <= max_size; ++i) {
106 bh.all_vars.push_front(tvm::tir::Var("v", shape[max_size - 1].dtype()));
107 bh.common_shape.push_front(shape[max_size - i]);
108 vars.push_front(bh.all_vars[0]);
109 }
110 return bh;
111}
112
113inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
114 const tvm::Array<tvm::tir::Var>& ovars, const tvm::te::Tensor& T,
115 const std::deque<tvm::tir::Var>& my_vars, const std::deque<tvm::tir::Var>& all_vars) {
116 tvm::Array<tvm::PrimExpr> ivars;
117 ICHECK_EQ(ovars.size(), all_vars.size());
118 // N^2, could use a map but NBD.
119 size_t expected_dims = T->shape.size();
120 for (size_t i = 0; i < ovars.size(); ++i) {
121 bool found = false;
122 for (size_t j = 0; j < my_vars.size(); ++j) {
123 if (all_vars[i].same_as(my_vars[j])) {
124 ivars.push_back(ovars[i]);
125 found = true;
126 break;
127 }
128 }
129 // Only inject 0 here if we have not yet reached the dimension of I
130 // (i.e. this must be a 1)
131 if (!found && (ovars.size() - i) <= expected_dims) {
132 ivars.push_back(tvm::tir::make_zero(ovars[i].dtype()));
133 }
134 }
135 ICHECK(expected_dims == ivars.size());
136 return ivars;
137}
138
139template <typename FBinaryExpr>
140inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A,
141 const tvm::te::Tensor& B, const std::string& name = "tensor",
142 const std::string& tag = "") {
143 auto bh = BroadcastShape(A->shape, B->shape);
144 auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
145 return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
146 B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
147 };
148 return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
149 l, name, tag);
150}
151
152} // namespace detail
153} // namespace topi
154} // namespace tvm
155
156#endif // TVM_TOPI_DETAIL_BROADCAST_H_
157