1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file schedule_lang.cc |
22 | */ |
23 | #include <dmlc/thread_local.h> |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/te/operation.h> |
26 | #include <tvm/te/schedule.h> |
27 | |
28 | #include <algorithm> |
29 | #include <stack> |
30 | #include <unordered_set> |
31 | #include <vector> |
32 | |
33 | #include "graph.h" |
34 | |
35 | namespace tvm { |
36 | namespace te { |
37 | |
38 | // find first occurance location in leaf |
39 | template <typename T> |
40 | size_t FindNodeRef(ArrayNode* array_node, const T& v) { |
41 | const Object* n = v.get(); |
42 | for (size_t i = 0; i < array_node->size(); ++i) { |
43 | if (array_node->at(i).get() == n) return i; |
44 | } |
45 | return array_node->size(); |
46 | } |
47 | |
48 | size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) { |
49 | size_t pos = FindNodeRef(leaf_vars, v); |
50 | if (pos < leaf_vars->size()) return pos; |
51 | |
52 | if (FindNodeRef(all_vars, v) < all_vars->size()) { |
53 | LOG(FATAL) << "Operate on iter var " << v << "that has already been split" ; |
54 | } else { |
55 | LOG(FATAL) << "Operate on iter var " << v << "that is not part of the schedule" ; |
56 | } |
57 | return 0; |
58 | } |
59 | |
60 | DataType MatchDataType(std::vector<DataType> dtypes) { |
61 | int max_bits = -1; |
62 | for (const auto& dtype : dtypes) { |
63 | ICHECK(dtype.is_int()); |
64 | ICHECK(dtype.is_scalar()); |
65 | max_bits = std::max(max_bits, dtype.bits()); |
66 | } |
67 | return DataType::Int(max_bits); |
68 | } |
69 | |
/*!
 * \brief Shared implementation of split and split_by_nparts.
 *
 * Creates fresh outer/inner iter vars derived from parent, records a Split
 * relation, and replaces parent with (outer, inner) in the stage's leaf
 * iter vars. Exactly one of factor/nparts is expected to be defined; the
 * other is passed as a null PrimExpr and resolved later by bound inference.
 */
void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts,
                 IterVar* p_outer, IterVar* p_inner) {
  // Check if split is valid: only data-parallel, reduction or ordered axes
  // may be split.
  ICHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce ||
         parent->iter_type == kOrdered)
      << "Cannot split on " << IterVarType2String(parent->iter_type);
  // The new axes inherit the parent's iter type; ranges stay undefined
  // until bound inference runs.
  IterVar outer = IterVar(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
  IterVar inner = IterVar(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
  *p_outer = outer;
  *p_inner = inner;
  // The splits
  Array<IterVar>& all_vars = self->all_iter_vars;
  Array<IterVar>& leaf_vars = self->leaf_iter_vars;
  // Look up parent's leaf position BEFORE mutating the arrays (aborts if
  // parent is not a leaf var of this stage).
  size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent);
  self->relations.push_back(Split(parent, outer, inner, factor, nparts));
  // add vars to all vars
  all_vars.push_back(outer);
  all_vars.push_back(inner);
  // replace the position: parent is removed and (outer, inner) take its slot.
  leaf_vars.erase(leaf_vars.begin() + pos);
  leaf_vars.insert(leaf_vars.begin() + pos, inner);
  leaf_vars.insert(leaf_vars.begin() + pos, outer);
}
93 | |
94 | Stage::Stage(Operation op) { |
95 | auto n = make_object<StageNode>(); |
96 | n->op = op; |
97 | n->origin_op = op; |
98 | n->all_iter_vars = op->root_iter_vars(); |
99 | // remove opaque var from leaf. |
100 | Array<IterVar> clean; |
101 | for (IterVar iv : n->all_iter_vars) { |
102 | if (iv->iter_type != kOpaque) clean.push_back(iv); |
103 | } |
104 | if (clean.size() == n->all_iter_vars.size()) { |
105 | n->leaf_iter_vars = n->all_iter_vars; |
106 | } else { |
107 | n->leaf_iter_vars = clean; |
108 | } |
109 | data_ = std::move(n); |
110 | } |
111 | |
112 | bool Stage::is_scheduled() const { |
113 | const StageNode* n = operator->(); |
114 | return !(n->relations.empty() && n->attach_type == kGroupRoot && |
115 | n->all_iter_vars.same_as(n->leaf_iter_vars)); |
116 | } |
117 | |
118 | Stage Stage::GetAttachSpec() const { |
119 | Stage attach_spec = *this; |
120 | while (attach_spec->attach_type == kGroupRoot && attach_spec->group.defined()) { |
121 | attach_spec = attach_spec->group; |
122 | } |
123 | return attach_spec; |
124 | } |
125 | |
126 | Stage& Stage::set_scope(std::string scope) { // NOLINT(*) |
127 | (*this)->scope = scope; |
128 | return *this; |
129 | } |
130 | |
131 | Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) |
132 | ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates" ; |
133 | // Group constraint checking. |
134 | Stage group = (*this)->group; |
135 | if (group.defined()) { |
136 | Stage pg = parent->group; |
137 | while (pg.defined() && !pg.same_as(group)) { |
138 | pg = pg->group; |
139 | } |
140 | ICHECK(pg.same_as(group)) << "Can only assign compute_at to stages within the same group" ; |
141 | } |
142 | |
143 | (*this)->attach_type = kScope; |
144 | (*this)->attach_ivar = scope; |
145 | (*this)->attach_stage = parent; |
146 | bool found = false; |
147 | for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) { |
148 | if (scope == parent->leaf_iter_vars[i]) { |
149 | found = true; |
150 | break; |
151 | } |
152 | } |
153 | ICHECK(found) << "Cannot find the axis " << scope << " in parent's leaf_iter_vars" |
154 | << " parent=" << parent; |
155 | return *this; |
156 | } |
157 | |
158 | Stage& Stage::compute_inline() { // NOLINT(*) |
159 | ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates" ; |
160 | (*this)->attach_type = kInline; |
161 | return *this; |
162 | } |
163 | |
164 | Stage& Stage::compute_root() { // NOLINT(*) |
165 | ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates" ; |
166 | (*this)->attach_type = kGroupRoot; |
167 | return *this; |
168 | } |
169 | |
170 | Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) |
171 | StageNode* self = operator->(); |
172 | ICHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce) |
173 | << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread" ; |
174 | ICHECK(thread_ivar->iter_type == kThreadIndex) |
175 | << "Cannot rebase by " << IterVarType2String(ivar->iter_type) |
176 | << ", only thread axis is allowed so far" ; |
177 | ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); |
178 | ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); |
179 | FindLeafVar(all_vars, leaf_vars, ivar); |
180 | |
181 | auto it = self->iter_var_attrs.find(ivar); |
182 | ObjectPtr<IterVarAttrNode> n; |
183 | if (it != self->iter_var_attrs.end()) { |
184 | n = make_object<IterVarAttrNode>(*(*it).second.operator->()); |
185 | if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) { |
186 | LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread; |
187 | } |
188 | } else { |
189 | n = make_object<IterVarAttrNode>(); |
190 | } |
191 | n->bind_thread = thread_ivar; |
192 | self->iter_var_attrs.Set(ivar, IterVarAttr(n)); |
193 | return *this; |
194 | } |
195 | |
196 | Stage& Stage::env_threads(Array<IterVar> threads) { |
197 | StageNode* self = operator->(); |
198 | ICHECK(self->op.defined() && self->op.as<ScanOpNode>()) |
199 | << "env_threads is only valid for composite ops such as ScanOp" ; |
200 | ICHECK_EQ(self->env_threads.size(), 0U) << "Already set env_threads" ; |
201 | Array<IterVar>& leaf_vars = self->leaf_iter_vars; |
202 | Array<IterVar>& all_vars = self->all_iter_vars; |
203 | std::vector<IterVar> temp; |
204 | for (IterVar iv : threads) { |
205 | temp.push_back(iv); |
206 | } |
207 | leaf_vars.insert(leaf_vars.begin(), temp.begin(), temp.end()); |
208 | all_vars.insert(all_vars.end(), temp.begin(), temp.end()); |
209 | self->env_threads = threads; |
210 | return *this; |
211 | } |
212 | |
213 | Stage& Stage::set_store_predicate(PrimExpr predicate) { |
214 | StageNode* self = operator->(); |
215 | self->store_predicate = predicate; |
216 | return *this; |
217 | } |
218 | |
219 | Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, |
220 | IterVar* p_inner) { // NOLINT(*) |
221 | SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); |
222 | return *this; |
223 | } |
224 | |
225 | Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, |
226 | IterVar* p_inner) { // NOLINT(*) |
227 | SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); |
228 | return *this; |
229 | } |
230 | |
Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) {  // NOLINT(*)
  StageNode* self = operator->();
  // Only data-parallel, reduction or ordered axes may be fused.
  ICHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce ||
         outer->iter_type == kOrdered)
      << "Cannot fuse " << IterVarType2String(outer->iter_type);
  ICHECK(inner->iter_type == kDataPar || inner->iter_type == kCommReduce ||
         inner->iter_type == kOrdered)
      << "Cannot fuse " << IterVarType2String(inner->iter_type);

  // The fused axis takes the larger of the two iter-type enum values and an
  // integer dtype wide enough for both source axes.
  IterVarType iter_type = outer->iter_type;
  if (inner->iter_type > iter_type) iter_type = inner->iter_type;
  std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused";
  DataType iter_dtype = MatchDataType({inner->var.dtype(), outer->var.dtype()});

  IterVar fused = IterVar(Range(), Var(fused_name, iter_dtype), iter_type);

  Array<IterVar>& all_vars = self->all_iter_vars;
  Array<IterVar>& leaf_vars = self->leaf_iter_vars;

  size_t pos_inner = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), inner);
  size_t pos_outer = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), outer);
  // Accept the two axes in either order as long as they are adjacent;
  // canonicalize so that outer immediately precedes inner.
  if (pos_inner + 1 == pos_outer) {
    std::swap(outer, inner);
    std::swap(pos_inner, pos_outer);
  }
  ICHECK_EQ(pos_inner, pos_outer + 1)
      << "Can only fuse iterations that are consecutive between each other";
  self->relations.push_back(Fuse(outer, inner, fused));
  all_vars.push_back(fused);
  // Replace the two adjacent leaf axes with the single fused axis in place.
  leaf_vars.erase(leaf_vars.begin() + pos_outer, leaf_vars.begin() + pos_inner + 1);
  leaf_vars.insert(leaf_vars.begin() + pos_outer, fused);
  *p_target = fused;
  return *this;
}
265 | |
266 | Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(*) |
267 | if (axes.size() != 0) { |
268 | IterVar fused = axes[0]; |
269 | for (size_t i = 1; i < axes.size(); ++i) { |
270 | this->fuse(fused, axes[i], &fused); |
271 | } |
272 | *p_target = std::move(fused); |
273 | } else { |
274 | StageNode* self = operator->(); |
275 | // special handle fuse empty array. |
276 | // insert at the outer most loop |
277 | IterVar singleton = |
278 | IterVar(Range::FromMinExtent(0, 1), Var("singleton" , DataType::Int(32)), kDataPar); |
279 | self->relations.push_back(Singleton(singleton)); |
280 | Array<IterVar>& all_vars = self->all_iter_vars; |
281 | Array<IterVar>& leaf_vars = self->leaf_iter_vars; |
282 | all_vars.push_back(singleton); |
283 | leaf_vars.insert(leaf_vars.begin(), singleton); |
284 | *p_target = singleton; |
285 | } |
286 | return *this; |
287 | } |
288 | |
Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
  std::unordered_set<IterVar> seen_var;
  StageNode* self = operator->();
  // Only data-parallel, reduction and thread axes may be reordered, and
  // each axis may appear at most once in the requested order.
  for (IterVar iv : order) {
    ICHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce ||
           iv->iter_type == kThreadIndex)
        << "Cannot reorder IterVar(" << IterVarType2String(iv->iter_type) << ")";

    ICHECK_EQ(seen_var.count(iv), 0) << "Same axis can not appear more than once " << iv;
    seen_var.insert(iv);
  }
  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
  std::vector<size_t> pos;

  // Current leaf position of each requested axis (aborts if not a leaf var).
  for (size_t i = 0; i < order.size(); ++i) {
    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
  }
  // Snapshot the axes in the requested order ...
  std::vector<ObjectRef> temp;
  for (size_t i = 0; i < pos.size(); ++i) {
    temp.emplace_back(leaf_vars->at(pos[i]));
  }
  // ... then write them back into the same set of slots in ascending slot
  // order, so axes not mentioned in `order` keep their positions.
  std::sort(pos.begin(), pos.end());
  for (size_t i = 0; i < pos.size(); ++i) {
    leaf_vars->SetItem(pos[i], temp[i]);
  }
  return *this;
}
317 | |
318 | Stage& Stage::tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor, |
319 | IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner) { |
320 | split(x_parent, x_factor, p_x_outer, p_x_inner); |
321 | split(y_parent, y_factor, p_y_outer, p_y_inner); |
322 | reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner})); |
323 | return *this; |
324 | } |
325 | |
/*!
 * \brief Find-or-create the IterVarAttr of var and apply fupdate to it.
 * \param self The stage whose attr map is updated.
 * \param var The iteration variable whose attr is mutated.
 * \param fupdate Callback mutating the (freshly copied) attr node.
 * \param need_leaf If true, require var to currently be a leaf iter var
 *                  of the stage (aborts otherwise).
 */
template <typename FUpdate>
inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate,
                              bool need_leaf = true) {
  if (need_leaf) {
    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
    FindLeafVar(all_vars, leaf_vars, var);
  }
  auto it = self->iter_var_attrs.find(var);
  ObjectPtr<IterVarAttrNode> n;
  if (it != self->iter_var_attrs.end()) {
    // Copy-on-write: clone the existing attr node before mutating it.
    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
  } else {
    n = make_object<IterVarAttrNode>();
  }
  fupdate(n.get());
  self->iter_var_attrs.Set(var, IterVarAttr(n));
}
344 | |
345 | inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) { |
346 | UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { n->iter_type = iter_type; }); |
347 | } |
348 | |
349 | Stage& Stage::vectorize(IterVar var) { // NOLINT(*) |
350 | ICHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled || |
351 | var->iter_type == kVectorized || var->iter_type == kTensorized || |
352 | var->iter_type == kParallelized) |
353 | << "Cannot vectorize on " << IterVarType2String(var->iter_type); |
354 | SetAttrIterType(operator->(), var, kVectorized); |
355 | return *this; |
356 | } |
357 | |
358 | Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) |
359 | UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) { |
360 | n->iter_type = kTensorized; |
361 | n->tensor_intrin = f; |
362 | }); |
363 | return *this; |
364 | } |
365 | |
366 | Stage& Stage::unroll(IterVar var) { // NOLINT(*) |
367 | SetAttrIterType(operator->(), var, kUnrolled); |
368 | return *this; |
369 | } |
370 | |
371 | Stage& Stage::parallel(IterVar var) { // NOLINT(*) |
372 | SetAttrIterType(operator->(), var, kParallelized); |
373 | return *this; |
374 | } |
375 | |
376 | Stage& Stage::pragma(IterVar var, const std::string& pragma_type, |
377 | const PrimExpr& pragma_value) { // NOLINT(*) |
378 | if (pragma_type == "unroll" ) { |
379 | this->unroll(var); |
380 | } else if (pragma_type == "vectorize" ) { |
381 | this->vectorize(var); |
382 | } else { |
383 | UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { |
384 | n->pragma_keys.push_back(tir::StringImm(pragma_type)); |
385 | n->pragma_values.push_back(pragma_value); |
386 | }); |
387 | } |
388 | return *this; |
389 | } |
390 | |
391 | Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) { |
392 | StageNode* self = operator->(); |
393 | ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); |
394 | ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); |
395 | FindLeafVar(all_vars, leaf_vars, var); |
396 | auto it = self->iter_var_attrs.find(var); |
397 | ObjectPtr<IterVarAttrNode> n; |
398 | if (it != self->iter_var_attrs.end()) { |
399 | n = make_object<IterVarAttrNode>(*(*it).second.operator->()); |
400 | } else { |
401 | n = make_object<IterVarAttrNode>(); |
402 | } |
403 | n->prefetch_data.push_back(tensor); |
404 | n->prefetch_offset.push_back(offset); |
405 | self->iter_var_attrs.Set(var, IterVarAttr(n)); |
406 | return *this; |
407 | } |
408 | |
409 | Stage& Stage::storage_align(IterVar axis, int factor, int offset) { |
410 | StageNode* self = operator->(); |
411 | UpdateIterVarAttr( |
412 | self, axis, |
413 | [factor, offset](IterVarAttrNode* n) { |
414 | n->dim_align_factor = factor; |
415 | n->dim_align_offset = offset; |
416 | }, |
417 | false); |
418 | return *this; |
419 | } |
420 | |
421 | Stage& Stage::double_buffer() { |
422 | StageNode* self = operator->(); |
423 | ICHECK(!self->is_output) << "Cannot apply double buffer on output" ; |
424 | self->double_buffer = true; |
425 | return *this; |
426 | } |
427 | |
428 | Stage& Stage::rolling_buffer() { |
429 | StageNode* self = operator->(); |
430 | ICHECK(!self->is_output) << "Cannot apply rolling buffer on output" ; |
431 | self->rolling_buffer = true; |
432 | return *this; |
433 | } |
/*!
 * \brief Apply an index-map layout transformation to this stage.
 * \param initial_indices Variables representing the original indices.
 * \param final_indices Expressions computing the transformed indices.
 * \param out_iter_vars If non-null, receives the new iteration variables.
 * \return reference to self.
 * \note The transform is always recorded in layout_transforms; the
 *       iteration axes are only rewritten for ComputeOp stages.
 */
Stage& Stage::transform_layout(const Array<Var>& initial_indices,
                               const Array<PrimExpr>& final_indices,
                               Array<IterVar>* out_iter_vars) {
  StageNode* self = operator->();
  IndexMap map(initial_indices, final_indices);
  self->layout_transforms.push_back(map);

  auto* compute = self->op.as<ComputeOpNode>();

  // Can only rewrite the indices of compute op nodes.
  if (!compute) {
    return *this;
  }

  CHECK_EQ(initial_indices.size(), compute->axis.size())
      << "Expected number of initial indices in transformation to match the dimension of "
      << self->op->name;

  // Locate the IterVar objects for the data axes; [first, second) is the
  // half-open range of their positions in leaf_iter_vars.
  auto leaf_iter_range = [&]() -> std::pair<size_t, size_t> {
    std::vector<size_t> leaf_var_indices;
    for (const auto& axis : compute->axis) {
      leaf_var_indices.push_back(
          FindLeafVar(self->all_iter_vars.CopyOnWrite(), self->leaf_iter_vars.CopyOnWrite(), axis));
    }
    auto minmax_element = std::minmax_element(leaf_var_indices.begin(), leaf_var_indices.end());
    return {*minmax_element.first, *minmax_element.second + 1};
  }();
  // The data axes must still be contiguous in leaf_iter_vars.
  CHECK_EQ(leaf_iter_range.first + compute->axis.size(), leaf_iter_range.second)
      << "Cannot transform indices if they have already been reordered";

  // Determine the updated ranges of iteration.
  Array<Range> initial_ranges;
  for (const auto& iter_var : compute->axis) {
    initial_ranges.push_back(iter_var->dom);
  }
  Array<Range> final_ranges = map->MapRanges(initial_ranges);

  // Make IterVar objects to represent the new iterations; the inverse
  // map's initial indices become the new loop variables.
  auto inverse = map.Inverse(initial_ranges);
  Array<IterVar> final_indices_iter;
  ICHECK_EQ(inverse->initial_indices.size(), final_ranges.size());
  for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
    final_indices_iter.push_back(IterVar(final_ranges[i], inverse->initial_indices[i], kDataPar));
  }

  // Append the new IterVar objects to all_iter_vars
  for (const auto& iter_var : final_indices_iter) {
    self->all_iter_vars.push_back(iter_var);
  }

  // Replace the existing IterVar objects in leaf_iter_vars with the
  // new IterVar objects.
  self->leaf_iter_vars.erase(self->leaf_iter_vars.begin() + leaf_iter_range.first,
                             self->leaf_iter_vars.begin() + leaf_iter_range.second);
  self->leaf_iter_vars.insert(self->leaf_iter_vars.begin() + leaf_iter_range.first,
                              final_indices_iter.begin(), final_indices_iter.end());

  // Define a relationship for each new axis
  self->relations.push_back(Transform(compute->axis, final_indices_iter, map, inverse));

  // Return the iteration variables as an output.
  if (out_iter_vars) {
    *out_iter_vars = final_indices_iter;
  }

  return *this;
}
502 | |
503 | Stage& Stage::set_axis_separators(const Array<IntImm>& axis_separators) { |
504 | StageNode* self = operator->(); |
505 | self->axis_separators = axis_separators; |
506 | return *this; |
507 | } |
508 | |
509 | Stage CopyStage(const Stage& s) { |
510 | ObjectPtr<StageNode> n = make_object<StageNode>(*s.operator->()); |
511 | return Stage(n); |
512 | } |
513 | |
514 | Schedule Schedule::copy() const { |
515 | // map of stages. |
516 | const ScheduleNode* self = operator->(); |
517 | std::unordered_map<Stage, Stage, ObjectPtrHash, ObjectPtrEqual> smap; |
518 | ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>(); |
519 | n->outputs = self->outputs; |
520 | // Copy the stages. |
521 | for (Stage s : self->stages) { |
522 | Stage scopy = CopyStage(s); |
523 | smap[s] = scopy; |
524 | n->stages.push_back(scopy); |
525 | } |
526 | for (Stage g : self->groups) { |
527 | Stage gcopy = CopyStage(g); |
528 | smap[g] = gcopy; |
529 | n->groups.push_back(gcopy); |
530 | } |
531 | // Remaps the reference relations. |
532 | for (auto kv : self->stage_map) { |
533 | n->stage_map.Set(kv.first, smap.at(kv.second)); |
534 | } |
535 | for (Stage s : n->stages) { |
536 | if (s->attach_stage.defined()) { |
537 | ICHECK(smap.find(s->attach_stage) != smap.end()) |
538 | << s->attach_stage << " not found in " << (*this); |
539 | s->attach_stage = smap.at(s->attach_stage); |
540 | } |
541 | if (s->group.defined()) { |
542 | ICHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); |
543 | s->group = smap.at(s->group); |
544 | } |
545 | } |
546 | for (Stage s : n->groups) { |
547 | if (s->attach_stage.defined()) { |
548 | ICHECK(smap.find(s->attach_stage) != smap.end()) |
549 | << s->attach_stage << " not found in " << (*this); |
550 | s->attach_stage = smap.at(s->attach_stage); |
551 | } |
552 | if (s->group.defined()) { |
553 | ICHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); |
554 | s->group = smap.at(s->group); |
555 | } |
556 | } |
557 | return Schedule(n); |
558 | } |
559 | |
560 | Stage Schedule::operator[](const Operation& op) { |
561 | auto it = (*this)->stage_map.find(op); |
562 | ICHECK(it != (*this)->stage_map.end()) |
563 | << "Cannot find Stage for operator " << op << " in the schedule" ; |
564 | return (*it).second; |
565 | } |
566 | |
567 | Stage LeastCommonAncestor(Stage g1, Stage g2) { |
568 | if (!g1.defined()) return g1; |
569 | if (!g2.defined()) return g2; |
570 | if (g1.same_as(g2)) return g1; |
571 | Stage g = g1; |
572 | while (g.defined()) { |
573 | if (g.same_as(g2)) return g2; |
574 | g = g->group; |
575 | } |
576 | g = g2; |
577 | while (g.defined()) { |
578 | if (g.same_as(g1)) return g1; |
579 | g = g->group; |
580 | } |
581 | return g; |
582 | } |
583 | |
584 | Array<Tensor> RemapTensor(ScheduleNode* self, const Array<Tensor>& arr) { |
585 | self->InitCache(); |
586 | const auto& op2stage_cache = self->op2stage_cache_; |
587 | Array<Tensor> ret; |
588 | for (Tensor t : arr) { |
589 | if (!op2stage_cache.count(t->op.get())) { |
590 | ICHECK(self->stage_map.count(t->op)) << "Given tensor is not in the schedule plan" ; |
591 | t = self->stage_map[t->op]->op.output(t->value_index); |
592 | } |
593 | ret.push_back(t); |
594 | } |
595 | return ret; |
596 | } |
597 | |
// Group the schedule stages.
/*!
 * \brief Create a new group stage covering the subgraph between outputs and
 *        inputs, and move the covered stages (and fully-covered subgroups)
 *        into it.
 * \param outputs Output tensors of the region to group.
 * \param inputs Input tensors bounding the region.
 * \param include_inputs Whether the input ops themselves join the group.
 * \return The newly created group stage.
 */
Stage Schedule::create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
                             bool include_inputs) {
  ScheduleNode* self = operator->();
  self->InitCache();
  const auto& op2stage_cache = self->op2stage_cache_;
  // Get the ops.
  Array<Operation> ops =
      te::GetSubGraph(RemapTensor(self, outputs), RemapTensor(self, inputs), include_inputs);
  // local counter entry
  // Automatically initialize to 0 during creation.
  struct Entry {
    int count{0};
  };
  // Map of group->touched counter
  std::unordered_map<Stage, Entry, ObjectPtrHash, ObjectPtrEqual> counter;
  // The parent group;
  Stage parent_group;
  // Detect common parent and child: the parent group is the LCA of the
  // groups of all ops in the subgraph.
  for (size_t i = 0; i < ops.size(); ++i) {
    Operation op = ops[i];
    auto it = op2stage_cache.find(op.get());
    ICHECK(it != op2stage_cache.end());
    Stage op_group = it->second->group;
    if (i == 0) {
      parent_group = op_group;
    } else {
      parent_group = LeastCommonAncestor(parent_group, op_group);
    }
    if (op_group.defined()) {
      ++counter[op_group].count;
    }
  }
  // Create the new group stage.
  Stage gstage(make_object<StageNode>());
  gstage->group = parent_group;
  if (parent_group.defined()) {
    ++parent_group->num_child_stages;
  }
  // Propagate the counter statistics from by checking if subgroup
  // Is full and propagate: a group whose children are all inside the
  // subgraph contributes one "touched" child to its own parent.
  std::vector<Stage> stack;
  for (auto& kv : counter) {
    if (!kv.first.same_as(parent_group)) {
      if (kv.first->num_child_stages == kv.second.count) {
        stack.push_back(kv.first);
      }
    }
  }
  while (!stack.empty()) {
    Stage g = stack.back();
    stack.pop_back();
    if (g->group.defined() && !g->group.same_as(parent_group)) {
      Entry& e = counter[g->group];
      ++e.count;
      if (e.count == g->group->num_child_stages) {
        stack.push_back(g->group);
      }
    }
  }
  // Verification and remapping the subgroups: every touched group must be
  // fully covered, otherwise the new group would intersect it.
  for (auto& kv : counter) {
    if (kv.first.same_as(parent_group)) continue;
    ICHECK_EQ(kv.first->num_child_stages, kv.second.count)
        << "Trying to group region that intersect with an already existed group";
    if (kv.first->group.same_as(parent_group)) {
      Stage s = kv.first;
      s->group = gstage;
      ++gstage->num_child_stages;
      if (parent_group.defined()) {
        --parent_group->num_child_stages;
      }
    }
  }
  // Remap the group of op stages that sat directly under the parent group.
  for (Operation op : ops) {
    auto it = op2stage_cache.find(op.get());
    ICHECK(it != op2stage_cache.end());
    Stage s = it->second;
    if (s->group.same_as(parent_group)) {
      s->group = gstage;
      ++gstage->num_child_stages;
      if (parent_group.defined()) {
        --parent_group->num_child_stages;
      }
    }
  }
  // Correct the attach to keep everything in group: compute_at targets
  // outside the new group are reset to root.
  for (Operation op : ops) {
    auto it = op2stage_cache.find(op.get());
    ICHECK(it != op2stage_cache.end());
    Stage s = it->second;
    if (s->attach_type == kScope) {
      Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage);
      if (!cg.same_as(gstage)) {
        LOG(WARNING) << "group invalidates some previous compute_at relation "
                     << " and keeps things to be computed inside the group";
        s.compute_root();
      }
    }
  }

  self->groups.push_back(gstage);
  return gstage;
}
703 | |
704 | void ScheduleNode::InvalidateCache() { op2stage_cache_.clear(); } |
705 | |
void ScheduleNode::InitCache() {
  // Fast path: cache already consistent with the current stage list.
  if (op2stage_cache_.size() == stages.size()) return;
  InvalidateCache();
  for (Stage s : stages) {
    if (s->op.defined()) {
      op2stage_cache_[s->op.get()] = s;
    }
  }
  // Every stage is expected to carry a defined, unique op.
  ICHECK_EQ(op2stage_cache_.size(), stages.size());
}
716 | |
717 | bool ScheduleNode::Contain(const Operation& op) const { |
718 | return stage_map.find(op) != stage_map.end(); |
719 | } |
720 | |
/*!
 * \brief Build a schedule for the given output operations: create one stage
 *        per op in dependency (post-DFS) order and group scan updates under
 *        their scan op.
 */
Schedule::Schedule(Array<Operation> ops) {
  auto n = make_object<ScheduleNode>();
  data_ = n;
  n->outputs = ops;
  auto g = te::CreateReadGraph(n->outputs);
  Array<Operation> post_order = te::PostDFSOrder(n->outputs, g);
  // output set.
  std::unordered_set<Operation> output_set;
  for (Operation x : ops) {
    output_set.insert(x);
  }
  for (Operation op : post_order) {
    Stage stage(op);
    stage->is_output = output_set.count(op) != 0;
    n->stages.push_back(stage);
    n->stage_map.Set(op, stage);
    // mark scan updates.
    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
      Array<Tensor> inputs;
      for (Tensor t : scan->state_placeholder) {
        inputs.push_back(t);
      }
      for (Tensor t : scan->inputs) {
        inputs.push_back(t);
      }
      // Create the scan group: all update stages are grouped and attached
      // to the scan stage itself.
      Stage scan_group = this->create_group(scan->update, inputs, false);
      scan_group->attach_type = kScanUpdate;
      scan_group->attach_stage = stage;

      // Sanity check: every update stage must have landed in the scan group.
      for (size_t i = 0; i < scan->update.size(); ++i) {
        Stage s = n->stage_map[scan->update[i]->op];
        ICHECK(scan_group.same_as(s->group));
      }
    }
  }
}
758 | |
759 | Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { |
760 | auto n = make_object<SplitNode>(); |
761 | n->parent = parent; |
762 | n->outer = outer; |
763 | n->inner = inner; |
764 | n->factor = factor; |
765 | n->nparts = nparts; |
766 | data_ = std::move(n); |
767 | } |
768 | |
769 | Fuse::Fuse(IterVar outer, IterVar inner, IterVar fused) { |
770 | auto n = make_object<FuseNode>(); |
771 | n->outer = outer; |
772 | n->inner = inner; |
773 | n->fused = fused; |
774 | data_ = std::move(n); |
775 | } |
776 | |
777 | Rebase::Rebase(IterVar parent, IterVar rebased) { |
778 | auto n = make_object<RebaseNode>(); |
779 | n->parent = parent; |
780 | n->rebased = rebased; |
781 | data_ = std::move(n); |
782 | } |
783 | |
784 | Singleton::Singleton(IterVar iter) { |
785 | auto n = make_object<SingletonNode>(); |
786 | n->iter = iter; |
787 | data_ = std::move(n); |
788 | } |
789 | |
790 | Transform::Transform(Array<IterVar> original_variables, Array<IterVar> transformed_variables, |
791 | IndexMap forward_transformation, IndexMap inverse_transformation) { |
792 | auto n = make_object<TransformNode>(); |
793 | n->original_variables = original_variables; |
794 | n->transformed_variables = transformed_variables; |
795 | n->forward_transformation = forward_transformation; |
796 | n->inverse_transformation = inverse_transformation; |
797 | data_ = std::move(n); |
798 | } |
799 | |
800 | SpecializedCondition::SpecializedCondition(Array<PrimExpr> conditions) { |
801 | ObjectPtr<SpecializedConditionNode> n = make_object<SpecializedConditionNode>(); |
802 | n->clauses = std::move(conditions); |
803 | data_ = std::move(n); |
804 | } |
805 | |
/*! \brief Entry to hold the SpecializedCondition context stack. */
struct TVMSpecializationThreadLocalEntry {
  /*! \brief The current specialized condition */
  std::stack<SpecializedCondition> condition_stack;
};

/*! \brief Thread local store to hold the SpecializedCondition context stack. */
typedef dmlc::ThreadLocalStore<TVMSpecializationThreadLocalEntry> TVMSpecializationThreadLocalStore;
814 | |
815 | void SpecializedCondition::EnterWithScope() { |
816 | TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); |
817 | entry->condition_stack.push(*this); |
818 | } |
819 | |
820 | void SpecializedCondition::ExitWithScope() { |
821 | TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); |
822 | ICHECK(!entry->condition_stack.empty()); |
823 | ICHECK(entry->condition_stack.top().same_as(*this)); |
824 | entry->condition_stack.pop(); |
825 | } |
826 | |
827 | SpecializedCondition SpecializedCondition::Current() { |
828 | TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); |
829 | SpecializedCondition cond; |
830 | if (entry->condition_stack.size() > 0) { |
831 | cond = entry->condition_stack.top(); |
832 | } |
833 | return cond; |
834 | } |
835 | |
836 | class SpecializedCondition::Internal { |
837 | public: |
838 | static void EnterScope(SpecializedCondition cond) { cond.EnterWithScope(); } |
839 | |
840 | static void ExitScope(SpecializedCondition cond) { cond.ExitWithScope(); } |
841 | }; |
842 | |
// Register the schedule node types with the reflection machinery so they can
// be created, visited, and serialized generically.
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
TVM_REGISTER_NODE_TYPE(SpecializedConditionNode);
851 | |
852 | // Printer |
853 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
854 | .set_dispatch<StageNode>([](const ObjectRef& node, ReprPrinter* p) { |
855 | auto* op = static_cast<const StageNode*>(node.get()); |
856 | if (op->op.defined()) { |
857 | p->stream << "stage(" << op->origin_op->name << ", " << op->op << ")" ; |
858 | } else { |
859 | p->stream << "group-stage(" << op << ")" ; |
860 | } |
861 | }) |
862 | .set_dispatch<IterVarAttrNode>([](const ObjectRef& node, ReprPrinter* p) { |
863 | auto* op = static_cast<const IterVarAttrNode*>(node.get()); |
864 | p->stream << IterVarType2String(op->iter_type); |
865 | }) |
866 | .set_dispatch<SplitNode>([](const ObjectRef& node, ReprPrinter* p) { |
867 | auto* op = static_cast<const SplitNode*>(node.get()); |
868 | p->stream << "split(parent=" ; |
869 | p->Print(op->parent); |
870 | p->stream << ", outer=" ; |
871 | p->Print(op->outer); |
872 | p->stream << ", inner=" ; |
873 | p->Print(op->inner); |
874 | if (op->factor.defined()) { |
875 | p->stream << ", factor=" ; |
876 | p->Print(op->factor); |
877 | } else { |
878 | p->stream << ", nparts=" ; |
879 | p->Print(op->nparts); |
880 | } |
881 | p->stream << ')'; |
882 | }) |
883 | .set_dispatch<FuseNode>([](const ObjectRef& node, ReprPrinter* p) { |
884 | auto* op = static_cast<const FuseNode*>(node.get()); |
885 | p->stream << "fuse(" ; |
886 | p->stream << "outer=" ; |
887 | p->Print(op->outer); |
888 | p->stream << ", inner=" ; |
889 | p->Print(op->inner); |
890 | p->stream << ", fused=" ; |
891 | p->Print(op->fused); |
892 | p->stream << ')'; |
893 | }) |
894 | .set_dispatch<RebaseNode>([](const ObjectRef& node, ReprPrinter* p) { |
895 | auto* op = static_cast<const RebaseNode*>(node.get()); |
896 | p->stream << "rebase(" ; |
897 | p->stream << "parent=" ; |
898 | p->Print(op->parent); |
899 | p->stream << ", rebased=" ; |
900 | p->Print(op->rebased); |
901 | p->stream << ')'; |
902 | }) |
903 | .set_dispatch<SingletonNode>([](const ObjectRef& node, ReprPrinter* p) { |
904 | auto* op = static_cast<const SingletonNode*>(node.get()); |
905 | p->stream << "singleton(" ; |
906 | p->Print(op->iter); |
907 | p->stream << ')'; |
908 | }) |
909 | .set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) { |
910 | auto* op = static_cast<const ScheduleNode*>(node.get()); |
911 | p->stream << "schedule(" << op << ")" ; |
912 | }) |
913 | .set_dispatch<SpecializedConditionNode>([](const ObjectRef& node, ReprPrinter* p) { |
914 | auto* op = static_cast<const SpecializedConditionNode*>(node.get()); |
915 | p->stream << "specialized_condition(" ; |
916 | p->Print(op->clauses); |
917 | p->stream << ')'; |
918 | }); |
919 | |
// FFI bindings: schedule creation and the axis split/fuse Stage primitives.
TVM_REGISTER_GLOBAL("te.CreateSchedule" ).set_body_typed(create_schedule);

TVM_REGISTER_GLOBAL("te.StageSetScope" ).set_body_method(&Stage::set_scope);

TVM_REGISTER_GLOBAL("te.StageBind" ).set_body_method(&Stage::bind);

// Split `parent` with a fixed inner extent `factor`; returns [outer, inner].
TVM_REGISTER_GLOBAL("te.StageSplitByFactor" )
    .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
      IterVar outer, inner;
      stage.split(parent, factor, &outer, &inner);
      return Array<IterVar>({outer, inner});
    });

// Split `parent` into `nparts` outer iterations; returns [outer, inner].
TVM_REGISTER_GLOBAL("te.StageSplitByNParts" )
    .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
      IterVar outer, inner;
      stage.split_by_nparts(parent, nparts, &outer, &inner);
      return Array<IterVar>({outer, inner});
    });

// Fuse the given axes into a single iteration variable and return it.
TVM_REGISTER_GLOBAL("te.StageFuse" ).set_body_typed([](Stage stage, Array<IterVar> axes) {
  IterVar fused;
  stage.fuse(axes, &fused);
  return fused;
});
945 | |
// FFI bindings: remaining Stage scheduling primitives.
TVM_REGISTER_GLOBAL("te.StageComputeAt" ).set_body_method(&Stage::compute_at);

TVM_REGISTER_GLOBAL("te.StageComputeInline" ).set_body_method(&Stage::compute_inline);

TVM_REGISTER_GLOBAL("te.StageComputeRoot" ).set_body_method(&Stage::compute_root);

TVM_REGISTER_GLOBAL("te.StageReorder" ).set_body_method(&Stage::reorder);

// Two-dimensional split; returns [x_outer, y_outer, x_inner, y_inner].
TVM_REGISTER_GLOBAL("te.StageTile" )
    .set_body_typed([](Stage stage, IterVar x_parent, IterVar y_parent, PrimExpr x_factor,
                       PrimExpr y_factor) {
      IterVar x_outer, y_outer, x_inner, y_inner;
      stage.tile(x_parent, y_parent, x_factor, y_factor, &x_outer, &y_outer, &x_inner, &y_inner);
      return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
    });

TVM_REGISTER_GLOBAL("te.StageEnvThreads" ).set_body_method(&Stage::env_threads);

TVM_REGISTER_GLOBAL("te.StageSetStorePredicate" ).set_body_method(&Stage::set_store_predicate);

TVM_REGISTER_GLOBAL("te.StageUnroll" ).set_body_method(&Stage::unroll);

TVM_REGISTER_GLOBAL("te.StageVectorize" ).set_body_method(&Stage::vectorize);

TVM_REGISTER_GLOBAL("te.StageTensorize" ).set_body_method(&Stage::tensorize);

TVM_REGISTER_GLOBAL("te.StageParallel" ).set_body_method(&Stage::parallel);

TVM_REGISTER_GLOBAL("te.StagePragma" ).set_body_method(&Stage::pragma);

TVM_REGISTER_GLOBAL("te.StagePrefetch" ).set_body_method(&Stage::prefetch);

TVM_REGISTER_GLOBAL("te.StageStorageAlign" ).set_body_method(&Stage::storage_align);

TVM_REGISTER_GLOBAL("te.StageDoubleBuffer" ).set_body_method(&Stage::double_buffer);

TVM_REGISTER_GLOBAL("te.StageRollingBuffer" ).set_body_method(&Stage::rolling_buffer);

// Apply a layout transformation to the stage; returns the iteration variables
// of the transformed layout.
TVM_REGISTER_GLOBAL("te.StageTransformLayout" )
    .set_body_typed([](Stage stage, const Array<Var>& initial_indices,
                       const Array<PrimExpr>& final_indices) {
      Array<IterVar> new_iter_vars;
      stage.transform_layout(initial_indices, final_indices, &new_iter_vars);
      return new_iter_vars;
    });

TVM_REGISTER_GLOBAL("te.StageSetAxisSeparators" ).set_body_method(&Stage::set_axis_separators);
993 | |
// FFI bindings: Schedule-level operations and specialization-scope management.
TVM_REGISTER_GLOBAL("te.ScheduleNormalize" ).set_body_method(&Schedule::normalize);

TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup" ).set_body_method(&Schedule::create_group);

TVM_REGISTER_GLOBAL("te.ScheduleCacheRead" ).set_body_method(&Schedule::cache_read);

// cache_write has overloads for a single Tensor and an Array<Tensor>;
// dispatch on the runtime type of the second packed argument.
TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite" ).set_body([](TVMArgs args, TVMRetValue* ret) {
  if (args[1].IsObjectRef<Tensor>()) {
    *ret = args[0].operator Schedule().cache_write(args[1].operator Tensor(), args[2]);
  } else {
    *ret = args[0].operator Schedule().cache_write(args[1].operator Array<Tensor>(), args[2]);
  }
});

TVM_REGISTER_GLOBAL("te.ScheduleRFactor" ).set_body_method(&Schedule::rfactor);

TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition" ).set_body_typed([](Array<PrimExpr> condition) {
  return SpecializedCondition(condition);
});

// Returns the innermost active SpecializedCondition (undefined if none).
TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization" ).set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = SpecializedCondition::Current();
});

TVM_REGISTER_GLOBAL("te.EnterSpecializationScope" )
    .set_body_typed(SpecializedCondition::Internal::EnterScope);

TVM_REGISTER_GLOBAL("te.ExitSpecializationScope" )
    .set_body_typed(SpecializedCondition::Internal::ExitScope);
1023 | |
1024 | } // namespace te |
1025 | } // namespace tvm |
1026 | |