/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Logic related to cross thread reduction, used by ComputeOpNode.
 * \file cross_thread_reduction.cc
 */
#include <tvm/tir/builtin.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "compute_op.h"
#include "op_utils.h"

namespace tvm {
namespace te {
using namespace tir;

//
// Cross thread reduction transformation.
//
// The input loop nest in generic form (single reduction/thread case)
//
// let m be the reduction extent
// let N be the thread extent
// let input_pred be the predicate on the reduction
//
// B[..] = 0
// for (tid, 0, N)
//   for (i, 0, floordiv(m+N-1, N))
//     if (i + tid * floordiv(m+N-1, N) < m)
//       if (input_pred)
//         B[..] = op(B[..], A[i + tid * floordiv(m+N-1,N)])
//
// The threaded reduction looks like
//
// (1) normal reductions (leaves)
//     for (i, 0, floordiv(m+N-1, N))
//       if (i + tid * floordiv(m+N-1, N) < m)
//         if (input_pred)
//           B_temp[0] = op(B_temp[0], A[i + tid * floordiv(m+N-1,N)])
//
// (2) threaded reduction does not require predicates as an identity
//     element will be filled if out of bounds.
//
//     tvm_thread_allreduce(size, B_temp, (bool)1, tid)
//
// The last step is to write back the final reduction result, which should
// be predicated by the existing input_pred if any. As a consequence,
// input_pred should be independent of the reduction axis; otherwise, we
// need to separate it into a dependent part and an independent one.
//
// (3) write back
//     if (input_pred)
//       B[..] = B_temp[0]
//
// In summary, we are going to need two predicates:
//
// * the original input_pred from the reduction itself
//
// * the normal reduction axis predicate
//   normal_pred = (i + tid * floordiv(m+N-1,N)) < m
//   this predicate depends on the normal reduction variable.
//
// input_pred will be applied to both the normal reduction and
// the writeback step.
//
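// Note on multi-valued reductions: self->body may contain several ReduceNodes
// (e.g. an argmax computing both index and value). The code below assumes they
// all describe the same reduction (only reduces[0]->combiner/source/condition
// are consulted) and reduces all `size` values in a single
// tvm_thread_allreduce call.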
Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
                              const std::unordered_map<IterVar, Range>& dom_map,
                              bool debug_keep_trivial_loop) {
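  // Store indices for the final write-back: one loop variable per output axis.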
  Array<PrimExpr> args;
  for (IterVar iv : self->axis) {
    args.push_back(iv->var);
  }
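  // Build the loop nest of this stage. value_map records the expression each
  // IterVar evaluates to, so the body can be rewritten with Substitute at the
  // end.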
  std::unordered_map<IterVar, PrimExpr> value_map;
  auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map,
                           debug_keep_trivial_loop);

  size_t size = self->body.size();
  ICHECK_GT(size, 0);
  std::vector<const ReduceNode*> reduces(size);
  for (size_t i = 0; i < size; ++i) {
    const ReduceNode* reduce = self->body[i].as<ReduceNode>();
    ICHECK(reduce);
    ICHECK(reduce->init.empty())
        << "Cannot perform cross_thread_reduction for reductions with init";
    reduces[i] = reduce;
  }

  // This computes the bound checking predicates in normal reduction.
  auto normal_preds =
      MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set<IterVar>());

  // normal_pred = input_pred && normal_pred
  PrimExpr input_pred = reduces[0]->condition;
  normal_preds.push_back(input_pred);
  normal_preds.erase(std::remove_if(normal_preds.begin(), normal_preds.end(),
                                    [](const PrimExpr& e) { return !e.defined(); }),
                     normal_preds.end());

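  // Split the generated loop nest: non-reduction loops and thread-bound
  // reduction loops stay outside the allreduce (`common`), while unbound
  // reduction loops form the per-thread "normal" reduction (`normal_red`).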
  std::vector<std::vector<Stmt>> common, normal_red;
  for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
    IterVar iv = stage->leaf_iter_vars[i];
    IterVarAttr attr;
    auto it = stage->iter_var_attrs.find(iv);
    if (it != stage->iter_var_attrs.end()) {
      attr = (*it).second;
    }
    if (iv->iter_type == kCommReduce) {
      if (attr.defined() && attr->bind_thread.defined()) {
        common.emplace_back(nest[i + 1]);
      } else {
        normal_red.emplace_back(nest[i + 1]);
      }
    } else {
      common.emplace_back(nest[i + 1]);
    }
  }

  // If we load from and then store into the same result buffers in the
  // thread_allreduce intrinsic, something goes wrong, so we use separate
  // temporary buffers for the normal reduction.
  std::vector<Buffer> normal_res_buffers;
  std::vector<Stmt> normal_init, normal_update;
  if (!normal_red.empty()) {
    normal_res_buffers.reserve(size);
    normal_init.reserve(size);
    normal_update.reserve(size);
    const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>();
    ICHECK(combiner);
    Array<PrimExpr> lhs;
    for (size_t i = 0; i < size; ++i) {
      normal_res_buffers.push_back(
          decl_buffer({1}, reduces[i]->dtype, "normal_reduce_temp" + std::to_string(i), "local"));
      lhs.push_back(BufferLoad(normal_res_buffers[i], {0}));
    }
    Array<PrimExpr> init_value = combiner->identity_element;
    Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
    for (size_t i = 0; i < size; ++i) {
      normal_init.emplace_back(BufferStore(normal_res_buffers[i], init_value[i], {0}));
      normal_update.emplace_back(BufferStore(normal_res_buffers[i], update_value[i], {0}));
    }
  }

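  // Assemble the arguments of the tvm_thread_allreduce intrinsic:
  //   (num_values, value_0, ..., cond, result_buffer_0, ..., thread_var_0, ...)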
  Array<PrimExpr> freduce_args;
  freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
  for (size_t i = 0; i < size; ++i) {
    if (!normal_red.empty()) {
      freduce_args.push_back(BufferLoad(normal_res_buffers[i], {0}));
    } else {
      freduce_args.push_back(reduces[0]->source[i]);
    }
  }

  // No constraints on the thread reduction step. It may have redundant
  // computation in rare cases. TODO(tvm-team): revisit this.
  freduce_args.push_back(const_true(1));
  std::vector<Buffer> res_buffers(size);
  for (size_t idx = 0; idx < size; ++idx) {
    res_buffers[idx] =
        decl_buffer({1}, reduces[idx]->dtype, "reduce_temp" + std::to_string(idx), "local");
    // Make a BufferLoad object so that we can pass the entire Buffer
    // object through to LowerThreadAllreduce. The index here is
    // unused.
    PrimExpr dummy_load = BufferLoad(res_buffers[idx], {0});
    freduce_args.push_back(dummy_load);
  }

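  // Append the thread IterVars that reduction axes are bound to (e.g.
  // threadIdx.x); they tell LowerThreadAllreduce which threads cooperate in
  // the reduction.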
  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type == kCommReduce) {
      auto it = stage->iter_var_attrs.find(iv);
      if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
        IterVar tv = (*it).second->bind_thread;
        freduce_args.push_back(tv->var);
      }
    }
  }

  // Predicates guarding the final write-back of the result.
  std::vector<PrimExpr> output_preds;
  if (stage->store_predicate.defined()) {
    output_preds.emplace_back(stage->store_predicate);
  }

  // Apply the existing input predicate if any.
  output_preds.push_back(input_pred);

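  // Step (2): the cross-thread allreduce call, wrapped in a reduce_scope
  // attribute that carries the combiner for LowerThreadAllreduce.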
  Stmt reduce_body =
      Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), freduce_args));
  reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope,
                         make_zero(DataType::Handle()), reduce_body);

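  // Step (1): if there are unbound reduction loops, run the per-thread normal
  // reduction (init, then the guarded update inside those loops) before the
  // allreduce.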
  if (!normal_red.empty()) {
    Stmt init_body = SeqStmt::Flatten(normal_init);
    Stmt update_body = SeqStmt::Flatten(normal_update);
    update_body = MergeNest(MakeIfNest(normal_preds), update_body);
    update_body = MergeNest(normal_red, update_body);
    reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
  }

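  // Step (3): write the reduced values back to the outputs under the output
  // predicates, then allocate the temporary buffers around the whole body.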
  std::vector<Stmt> assigns(size);
  for (size_t idx = 0; idx < size; ++idx) {
    assigns[idx] = ProducerStore(stage->op.output(idx), BufferLoad(res_buffers[idx], {0}), args);
  }
  Stmt assign_body = SeqStmt::Flatten(assigns);
  assign_body = MergeNest(MakeIfNest(output_preds), assign_body);
  Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
  for (size_t idx = size; idx != 0; --idx) {
    const auto& res_buffer = res_buffers[idx - 1];
    body = Allocate(res_buffer->data, res_buffer->dtype, res_buffer->shape, const_true(), body);
    if (!normal_red.empty()) {
      const auto& normal_res_buffer = normal_res_buffers[idx - 1];
      body = Allocate(normal_res_buffer->data, normal_res_buffer->dtype, normal_res_buffer->shape,
                      const_true(), body);
    }
  }
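  // Substitute loop variables with their values and wrap the result in the
  // outer (common) loops.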
  body = Substitute(body, value_map);
  return MergeNest(common, body);
}
}  // namespace te
}  // namespace tvm