1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Logics related to cross thread reduction, used by ComputeOpNode. |
22 | * \file cross_thread_reduction.cc |
23 | */ |
24 | #include <tvm/tir/builtin.h> |
25 | |
26 | #include "compute_op.h" |
27 | #include "op_utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace te { |
31 | using namespace tir; |
32 | |
33 | // |
34 | // Cross thread reduction transformation. |
35 | // |
36 | // The input loop nest in generic form (single reduction/thread case) |
37 | // |
38 | // let m be the reduction extent |
39 | // let N be the thread extent |
40 | // let input_pred be the predicate on the reduction |
41 | // |
42 | // B[..] = 0 |
43 | // for (tid, 0, N) |
44 | // for (i, 0, floordiv(m+N-1, N)) |
45 | // if (i + tid * floordiv(m+N-1, N) < m) |
46 | // if (input_pred) |
47 | // B[..] = op(B[..], A[i + tid * floordiv(m+N-1,N)]) |
48 | // |
49 | // The threaded reduction looks like |
50 | // |
51 | // (1) normal reductions (leaves) |
52 | // for (i, 0, floordiv(m+N-1, N)) |
53 | // if (i + tid * floordiv(m+N-1, N) < m) |
54 | // if (input_pred) |
55 | // B_temp[0] = op(B_temp[0], A[i + tid * floordiv(m+N-1,N)]) |
56 | // |
// (2) the threaded reduction does not require predicates, because the
//     identity element is filled in for out-of-bounds entries.
59 | // |
60 | // tvm_thread_allreduce(size, B_temp, (bool)1, tid) |
61 | // |
// The last step is to write back the final reduction result,
// which should be predicated by the existing input_pred, if any.
// As a consequence, input_pred should be independent of the
// reduction axis; otherwise, we would need to separate it into a
// dependent part and an independent one.
67 | // |
68 | // (3) write back |
69 | // if (input_pred) |
70 | // B[..] = B_temp[0] |
71 | // |
72 | // In summary, we are going to need two predicates |
73 | // |
74 | // * the original input_pred from reduction itself |
75 | // |
76 | // * the normal reduction axis predicate |
77 | // normal_pred = (i + tid * floordiv(m+N-1,N)) < m |
78 | // this predicate depends on the normal reduction variable. |
79 | // |
// input_pred is applied to both the normal reduction and
// the write-back step.
82 | // |
83 | Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, |
84 | const std::unordered_map<IterVar, Range>& dom_map, |
85 | bool debug_keep_trivial_loop) { |
86 | Array<PrimExpr> args; |
87 | for (IterVar iv : self->axis) { |
88 | args.push_back(iv->var); |
89 | } |
90 | std::unordered_map<IterVar, PrimExpr> value_map; |
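  // MakeLoopNest materializes one group of loop/attr statements per leaf
  // IterVar; nest[i + 1] below corresponds to stage->leaf_iter_vars[i].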
91 | auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, |
92 | debug_keep_trivial_loop); |
93 | |
94 | size_t size = self->body.size(); |
95 | ICHECK_GT(size, 0); |
96 | std::vector<const ReduceNode*> reduces(size); |
97 | for (size_t i = 0; i < size; ++i) { |
98 | const ReduceNode* reduce = self->body[i].as<ReduceNode>(); |
99 | ICHECK(reduce); |
100 | ICHECK(reduce->init.empty()) |
101 | << "Cannot perform cross_thread_reduction for reductions with init" ; |
102 | reduces[i] = reduce; |
103 | } |
104 | |
105 | // This computes the bound checking predicates in normal reduction. |
106 | auto normal_preds = |
107 | MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set<IterVar>()); |
108 | |
109 | // normal_pred = input_pred && normal_pred |
110 | PrimExpr input_pred = reduces[0]->condition; |
111 | normal_preds.push_back(input_pred); |
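  // Keep only the defined predicates before turning them into an if-nest.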
112 | normal_preds.erase(std::remove_if(normal_preds.begin(), normal_preds.end(), |
113 | [](const PrimExpr& e) { return !e.defined(); }), |
114 | normal_preds.end()); |
115 | |
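  // Partition the materialized loops: non-reduction axes and thread-bound
  // reduction axes go to `common` (emitted around the whole body), while
  // unbound reduction axes go to `normal_red` (emitted only around the
  // per-thread normal reduction).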
116 | std::vector<std::vector<Stmt>> common, normal_red; |
117 | for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) { |
118 | IterVar iv = stage->leaf_iter_vars[i]; |
119 | IterVarAttr attr; |
120 | auto it = stage->iter_var_attrs.find(iv); |
121 | if (it != stage->iter_var_attrs.end()) { |
122 | attr = (*it).second; |
123 | } |
124 | if (iv->iter_type == kCommReduce) { |
125 | if (attr.defined() && attr->bind_thread.defined()) { |
126 | common.emplace_back(nest[i + 1]); |
127 | } else { |
128 | normal_red.emplace_back(nest[i + 1]); |
129 | } |
130 | } else { |
131 | common.emplace_back(nest[i + 1]); |
132 | } |
133 | } |
134 | |
  // Loading from and then storing into the same buffers within the thread_allreduce
  // intrinsic produces incorrect results, so we accumulate the normal reduction into
  // separate temporary buffers instead.
137 | std::vector<Buffer> normal_res_buffers; |
138 | std::vector<Stmt> normal_init, normal_update; |
139 | if (!normal_red.empty()) { |
140 | normal_res_buffers.reserve(size); |
141 | normal_init.reserve(size); |
    normal_update.reserve(size);
143 | const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>(); |
144 | ICHECK(combiner); |
145 | Array<PrimExpr> lhs; |
146 | for (size_t i = 0; i < size; ++i) { |
147 | normal_res_buffers.push_back( |
          decl_buffer({1}, reduces[i]->dtype, "normal_reduce_temp" + std::to_string(i), "local"));
149 | lhs.push_back(BufferLoad(normal_res_buffers[i], {0})); |
150 | } |
151 | Array<PrimExpr> init_value = combiner->identity_element; |
152 | Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source); |
153 | for (size_t i = 0; i < size; ++i) { |
154 | normal_init.emplace_back(BufferStore(normal_res_buffers[i], init_value[i], {0})); |
155 | normal_update.emplace_back(BufferStore(normal_res_buffers[i], update_value[i], {0})); |
156 | } |
157 | } |
158 | |
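  // Assemble the arguments of tvm_thread_allreduce in the order consumed by
  // the LowerThreadAllreduce pass: the number of reduction results, the
  // per-result source values, a predicate, the per-result destination
  // buffers, and finally the reduction thread index variables.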
159 | Array<PrimExpr> freduce_args; |
160 | freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size))); |
161 | for (size_t i = 0; i < size; ++i) { |
162 | if (!normal_red.empty()) { |
163 | freduce_args.push_back(BufferLoad(normal_res_buffers[i], {0})); |
164 | } else { |
165 | freduce_args.push_back(reduces[0]->source[i]); |
166 | } |
167 | } |
168 | |
  // No constraints on the thread reduction step. It may have redundant
  // computation in rare cases. TODO(tvm-team): revisit this.
171 | freduce_args.push_back(const_true(1)); |
172 | std::vector<Buffer> res_buffers(size); |
173 | for (size_t idx = 0; idx < size; ++idx) { |
174 | res_buffers[idx] = |
        decl_buffer({1}, reduces[idx]->dtype, "reduce_temp" + std::to_string(idx), "local");
176 | // Make a BufferLoad object so that we can pass the entire Buffer |
177 | // object through to LowerThreadAllreduce. The index here is |
178 | // unused. |
179 | PrimExpr dummy_load = BufferLoad(res_buffers[idx], {0}); |
180 | freduce_args.push_back(dummy_load); |
181 | } |
182 | |
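  // Append the thread index variables bound to reduction axes; they identify
  // the threads that participate in the allreduce.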
183 | for (IterVar iv : stage->leaf_iter_vars) { |
184 | if (iv->iter_type == kCommReduce) { |
185 | auto it = stage->iter_var_attrs.find(iv); |
186 | if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { |
187 | IterVar tv = (*it).second->bind_thread; |
188 | freduce_args.push_back(tv->var); |
189 | } |
190 | } |
191 | } |
192 | |
  // Predicates guarding the final write-back step.
194 | std::vector<PrimExpr> output_preds; |
195 | if (stage->store_predicate.defined()) { |
196 | output_preds.emplace_back(stage->store_predicate); |
197 | } |
198 | |
199 | // Apply the existing input predicate if any. |
200 | output_preds.push_back(input_pred); |
201 | |
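  // Emit the allreduce intrinsic call and wrap it in a reduce_scope AttrStmt
  // that carries the combiner, from which LowerThreadAllreduce later picks
  // the reduction operator.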
202 | Stmt reduce_body = |
203 | Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), freduce_args)); |
204 | reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope, |
205 | make_zero(DataType::Handle()), reduce_body); |
206 | |
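  // If a serial (normal) reduction part exists, prepend it: initialize the
  // normal_reduce_temp buffers, run the guarded update inside the unbound
  // reduction loops, then do the cross-thread allreduce.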
207 | if (!normal_red.empty()) { |
208 | Stmt init_body = SeqStmt::Flatten(normal_init); |
209 | Stmt update_body = SeqStmt::Flatten(normal_update); |
210 | update_body = MergeNest(MakeIfNest(normal_preds), update_body); |
211 | update_body = MergeNest(normal_red, update_body); |
212 | reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body); |
213 | } |
214 | |
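  // Write the reduced values back to the outputs, guarded by the output
  // predicates collected above.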
215 | std::vector<Stmt> assigns(size); |
216 | for (size_t idx = 0; idx < size; ++idx) { |
217 | assigns[idx] = ProducerStore(stage->op.output(idx), BufferLoad(res_buffers[idx], {0}), args); |
218 | } |
219 | Stmt assign_body = SeqStmt::Flatten(assigns); |
220 | assign_body = MergeNest(MakeIfNest(output_preds), assign_body); |
221 | Stmt body = SeqStmt::Flatten(reduce_body, assign_body); |
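  // Allocate the local temporaries (reduce_temp and, when a normal reduction
  // exists, normal_reduce_temp) around the combined body.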
222 | for (size_t idx = size; idx != 0; --idx) { |
223 | const auto& res_buffer = res_buffers[idx - 1]; |
224 | body = Allocate(res_buffer->data, res_buffer->dtype, res_buffer->shape, const_true(), body); |
225 | if (!normal_red.empty()) { |
226 | const auto& normal_res_buffer = normal_res_buffers[idx - 1]; |
227 | body = Allocate(normal_res_buffer->data, normal_res_buffer->dtype, normal_res_buffer->shape, |
228 | const_true(), body); |
229 | } |
230 | } |
231 | body = Substitute(body, value_map); |
232 | return MergeNest(common, body); |
233 | } |
234 | } // namespace te |
235 | } // namespace tvm |
236 | |