1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/loop_state.cc |
 * \brief A lightweight IR (intermediate representation) for loop structures.
23 | * see auto_scheduler/loop_state.h for more explanation. |
24 | */ |
25 | |
26 | #include <tvm/auto_scheduler/compute_dag.h> |
27 | #include <tvm/auto_scheduler/loop_state.h> |
28 | #include <tvm/auto_scheduler/transform_step.h> |
29 | #include <tvm/runtime/registry.h> |
30 | #include <tvm/te/operation.h> |
31 | |
32 | #include <utility> |
33 | |
34 | #include "utils.h" |
35 | |
36 | namespace tvm { |
37 | namespace auto_scheduler { |
38 | |
// Register StepNode, StageNode, StateNode and IteratorNode through TVM's
// object/node registration macros.
TVM_REGISTER_OBJECT_TYPE(StepNode);
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(StateNode);
TVM_REGISTER_NODE_TYPE(IteratorNode);
43 | |
44 | /********** Iterator **********/ |
45 | Iterator::Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation, |
46 | const std::vector<Iterator>* orig_iters) { |
47 | auto node = make_object<IteratorNode>(); |
48 | node->name = std::move(name); |
49 | node->range = std::move(range); |
50 | node->iter_kind = iter_kind; |
51 | node->annotation = annotation; |
52 | if (orig_iters != nullptr) { |
53 | node->orig_iters = *orig_iters; |
54 | } |
55 | data_ = std::move(node); |
56 | } |
57 | |
58 | /********** Stage **********/ |
59 | Stage::Stage(te::Operation op) { |
60 | auto node = make_object<StageNode>(); |
61 | if (op->IsInstance<te::ComputeOpNode>()) { |
62 | node->op_type = StageKind::kCompute; |
63 | auto* pop = op.as<te::ComputeOpNode>(); |
64 | for (const auto& axis : pop->axis) { |
65 | node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, |
66 | IteratorKind::kSpatial, IteratorAnnotation::kNone)); |
67 | } |
68 | for (const auto& axis : pop->reduce_axis) { |
69 | node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, |
70 | IteratorKind::kReduction, IteratorAnnotation::kNone)); |
71 | } |
72 | } else if (op->IsInstance<te::PlaceholderOpNode>()) { |
73 | node->op_type = StageKind::kPlaceholder; |
74 | } else { |
75 | LOG(FATAL) << "Unsupported operator type" << op->_type_key; |
76 | } |
77 | |
78 | node->compute_at = ComputeAtKind::kRoot; |
79 | node->op = std::move(op); |
80 | node->attrs.auto_unroll_max_step = 0; |
81 | node->attrs.storage_offset = 0; |
82 | data_ = std::move(node); |
83 | } |
84 | |
85 | Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, |
86 | ComputeAtKind compute_at, StageAttributes attrs) { |
87 | auto node = make_object<StageNode>(); |
88 | node->op = std::move(op); |
89 | node->op_type = op_type; |
90 | node->iters = iters; |
91 | node->compute_at = compute_at; |
92 | node->attrs = attrs; |
93 | data_ = std::move(node); |
94 | } |
95 | |
96 | /********** AttachMap **********/ |
97 | void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { |
98 | AttachMapNode* pnode = CopyOnWrite(); |
99 | |
100 | // Delete the current entry of this stage |
101 | DeleteStageEntry(pnode, stage_id); |
102 | |
103 | // Store the new stage/iterator relations to map |
104 | IterKey iter_key(target_stage_id, target_iter_id); |
105 | pnode->stage_to_attach_iter[stage_id] = iter_key; |
106 | pnode->iter_to_attached_stages[iter_key].push_back(stage_id); |
107 | } |
108 | |
109 | void AttachMap::DeleteStage(int stage_id) { |
110 | AttachMapNode* pnode = CopyOnWrite(); |
111 | // Delete the original stage entry |
112 | DeleteStageEntry(pnode, stage_id); |
113 | } |
114 | |
115 | void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters, |
116 | const std::vector<IterKey>& new_iters) { |
117 | ICHECK_EQ(original_iters.size(), new_iters.size()); |
118 | AttachMapNode* pnode = CopyOnWrite(); |
119 | std::unordered_map<IterKey, std::vector<StageKey>> new_iter_to_attached_stages; |
120 | for (size_t i = 0; i < original_iters.size(); ++i) { |
121 | auto entry = pnode->iter_to_attached_stages.find(original_iters[i]); |
122 | // We get <IterKey, std::vector<StageKey>> from this map |
123 | if (entry == pnode->iter_to_attached_stages.end()) { |
124 | // Skip if this iterator does not have any attach relations |
125 | continue; |
126 | } |
127 | |
128 | // Update the attaching target of an stage to the new iter in `stage_to_attach_iter` |
129 | for (const auto& s : entry->second) { |
130 | pnode->stage_to_attach_iter[s] = new_iters[i]; |
131 | } |
132 | |
133 | // Remove the original iterator relation from `iter_to_attached_stages` and add the new |
134 | // iterator to it |
135 | std::vector<int> attached_stages = std::move(entry->second); |
136 | pnode->iter_to_attached_stages.erase(entry); |
137 | new_iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); |
138 | } |
139 | |
140 | // Update new entries |
141 | for (auto& it : new_iter_to_attached_stages) { |
142 | pnode->iter_to_attached_stages[it.first] = std::move(it.second); |
143 | } |
144 | } |
145 | |
146 | void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { |
147 | auto old_entry = pnode->stage_to_attach_iter.find(stage_id); |
148 | // We get <StageKey, IterKey> from this map |
149 | if (old_entry != pnode->stage_to_attach_iter.end()) { |
150 | // Delete the stage in `iter_to_attached_stages`, if the corresponding iterator does not have |
151 | // any attached stage, delete this iterm too |
152 | auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); |
153 | // We get <IterKey, std::vector<StageKey>> from this map |
154 | FindAndDeleteItem(&entry2->second, stage_id); |
155 | if (entry2->second.size() == 0) { |
156 | pnode->iter_to_attached_stages.erase(entry2); |
157 | } |
158 | // Delete the stage in `stage_to_attach_iter` |
159 | pnode->stage_to_attach_iter.erase(old_entry); |
160 | } |
161 | } |
162 | |
163 | AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const { |
164 | AttachMap map = AttachMap(make_object<AttachMapNode>()); |
165 | auto pmap = map.CopyOnWrite(); |
166 | for (const auto& x : operator->()->stage_to_attach_iter) { |
167 | auto key = x.first; |
168 | if (key >= start_id) { |
169 | key += offset; |
170 | } |
171 | auto value = x.second; |
172 | if (value.first >= start_id) { |
173 | value.first += offset; |
174 | } |
175 | pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); |
176 | } |
177 | for (const auto& x : operator->()->iter_to_attached_stages) { |
178 | auto key = x.first; |
179 | if (key.first >= start_id) { |
180 | key.first += offset; |
181 | } |
182 | auto value = x.second; |
183 | for (auto& i : value) { |
184 | if (i >= start_id) { |
185 | i += offset; |
186 | } |
187 | } |
188 | pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); |
189 | } |
190 | return map; |
191 | } |
192 | |
193 | /********** State **********/ |
194 | State::State(const Array<te::Operation>& ops) { |
195 | auto node = make_object<StateNode>(); |
196 | for (const auto& op : ops) { |
197 | node->stages.push_back(Stage(op)); |
198 | } |
199 | node->attach_map = AttachMap(make_object<AttachMapNode>()); |
200 | node->concrete = true; |
201 | data_ = std::move(node); |
202 | } |
203 | |
204 | /********** Schedule primitives apis for state **********/ |
205 | Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { |
206 | const Stage& stage = operator->()->stages[stage_id]; |
207 | if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) { |
208 | LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, " |
209 | << "kThreadX, kThreadY, kBlockZ, kThreadZ" ; |
210 | } |
211 | AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); |
212 | CopyOnWrite()->transform_steps.push_back(step); |
213 | return step->ApplyToState(this); |
214 | } |
215 | |
216 | Iterator State::parallel(int stage_id, const Iterator& it) { |
217 | const Stage& stage = operator->()->stages[stage_id]; |
218 | AnnotationStep step = |
219 | AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel); |
220 | CopyOnWrite()->transform_steps.push_back(step); |
221 | return step->ApplyToState(this); |
222 | } |
223 | |
224 | Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { |
225 | const Stage& stage = operator->()->stages[stage_id]; |
226 | |
227 | // Don't unroll if the extent is larger than max_unroll |
228 | if (max_unroll != -1 && it->range.defined()) { |
229 | if (auto imm = it->range->extent.as<IntImmNode>()) { |
230 | if (imm->value > max_unroll) { |
231 | return it; |
232 | } |
233 | } |
234 | } |
235 | |
236 | AnnotationStep step = |
237 | AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); |
238 | CopyOnWrite()->transform_steps.push_back(step); |
239 | return step->ApplyToState(this); |
240 | } |
241 | |
242 | Iterator State::vectorize(int stage_id, const Iterator& it) { |
243 | const Stage& stage = operator->()->stages[stage_id]; |
244 | AnnotationStep step = |
245 | AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); |
246 | CopyOnWrite()->transform_steps.push_back(step); |
247 | return step->ApplyToState(this); |
248 | } |
249 | |
250 | Iterator State::fuse(int stage_id, const Array<Iterator>& iters) { |
251 | const Stage& stage = operator->()->stages[stage_id]; |
252 | Array<Integer> indices; |
253 | GetIndices(stage->iters, iters, &indices); |
254 | FuseStep step = FuseStep(stage_id, indices); |
255 | CopyOnWrite()->transform_steps.push_back(step); |
256 | return step->ApplyToState(this); |
257 | } |
258 | |
259 | void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) { |
260 | const Stage& stage = operator->()->stages[stage_id]; |
261 | PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type); |
262 | CopyOnWrite()->transform_steps.push_back(step); |
263 | return step->ApplyToState(this); |
264 | } |
265 | |
266 | void State::reorder(int stage_id, const Array<Iterator>& order) { |
267 | const Stage& stage = operator->()->stages[stage_id]; |
268 | ICHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " |
269 | << "should be specified" ; |
270 | Array<Integer> after_ids; |
271 | GetIndices(stage->iters, order, &after_ids); |
272 | ReorderStep step = ReorderStep(stage_id, after_ids); |
273 | CopyOnWrite()->transform_steps.push_back(step); |
274 | step->ApplyToState(this); |
275 | } |
276 | |
277 | Array<Iterator> State::split(int stage_id, const Iterator& it, |
278 | const Array<Optional<Integer>>& lengths, bool inner_to_outer) { |
279 | const Stage& stage = operator->()->stages[stage_id]; |
280 | SplitStep step = |
281 | SplitStep(stage_id, GetIndex(stage->iters, it), |
282 | it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); |
283 | CopyOnWrite()->transform_steps.push_back(step); |
284 | return step->ApplyToState(this); |
285 | } |
286 | |
287 | Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id, |
288 | int n_split) { |
289 | const Stage& stage = operator->()->stages[stage_id]; |
290 | FollowSplitStep step = |
291 | FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); |
292 | CopyOnWrite()->transform_steps.push_back(step); |
293 | return step->ApplyToState(this); |
294 | } |
295 | |
296 | Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it, |
297 | const Array<Integer>& src_step_ids, int level, |
298 | bool factor_or_nparts) { |
299 | const Stage& stage = operator->()->stages[stage_id]; |
300 | FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), |
301 | src_step_ids, level, factor_or_nparts); |
302 | CopyOnWrite()->transform_steps.push_back(step); |
303 | return step->ApplyToState(this); |
304 | } |
305 | |
306 | void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) { |
307 | const Stage& stage = operator->()->stages[stage_id]; |
308 | StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset); |
309 | CopyOnWrite()->transform_steps.push_back(step); |
310 | return step->ApplyToState(this); |
311 | } |
312 | |
313 | void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { |
314 | const Stage& target_stage = operator->()->stages[target_stage_id]; |
315 | ComputeAtStep step = |
316 | ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); |
317 | CopyOnWrite()->transform_steps.push_back(step); |
318 | step->ApplyToState(this); |
319 | } |
320 | |
321 | void State::compute_inline(int stage_id) { |
322 | ComputeInlineStep step = ComputeInlineStep(stage_id); |
323 | CopyOnWrite()->transform_steps.push_back(step); |
324 | step->ApplyToState(this); |
325 | } |
326 | |
327 | void State::compute_root(int stage_id) { |
328 | ComputeRootStep step = ComputeRootStep(stage_id); |
329 | CopyOnWrite()->transform_steps.push_back(step); |
330 | step->ApplyToState(this); |
331 | } |
332 | |
333 | int State::cache_read(int stage_id, const String& scope_name, |
334 | const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) { |
335 | CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids); |
336 | CopyOnWrite()->transform_steps.push_back(step); |
337 | return step->ApplyToState(this, dag); |
338 | } |
339 | |
340 | int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) { |
341 | CacheWriteStep step = CacheWriteStep(stage_id, scope_name); |
342 | CopyOnWrite()->transform_steps.push_back(step); |
343 | return step->ApplyToState(this, dag); |
344 | } |
345 | |
346 | int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) { |
347 | const Stage& stage = operator->()->stages[stage_id]; |
348 | RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id); |
349 | CopyOnWrite()->transform_steps.push_back(step); |
350 | return step->ApplyToState(this, dag); |
351 | } |
352 | |
353 | // Print stage to ostream |
354 | void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent, |
355 | bool delete_trivial_loop) { |
356 | const Stage& stage = state->stages[stage_id]; |
357 | |
358 | if (stage->attrs.auto_unroll_max_step != 0) { |
359 | for (size_t j = 0; j < base_indent; ++j) { |
360 | *os << " " ; |
361 | } |
362 | *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n" ; |
363 | } |
364 | if (stage->attrs.storage_offset != 0) { |
365 | for (size_t j = 0; j < base_indent; ++j) { |
366 | *os << " " ; |
367 | } |
368 | *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n" ; |
369 | } |
370 | |
371 | size_t indent = 0; |
372 | for (size_t i = 0; i < stage->iters.size(); ++i) { |
373 | const Iterator& iter = stage->iters[i]; |
374 | |
375 | if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) { |
376 | for (size_t j = 0; j < base_indent + indent; ++j) { |
377 | *os << " " ; |
378 | } |
379 | *os << IteratorAnnotationString[static_cast<int>(iter->annotation)] << " " ; |
380 | if (iter->range.defined()) { |
381 | *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")" ; |
382 | } else { |
383 | *os << iter->name << " (None)" ; |
384 | } |
385 | *os << "\n" ; |
386 | |
387 | indent += 2; |
388 | } |
389 | |
390 | if (state.defined()) { |
391 | IterKey iter_key(stage_id, i); |
392 | auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); |
393 | if (pair != state->attach_map->iter_to_attached_stages.end()) { |
394 | // Print the attached stage |
395 | for (const auto& attach_stage_id : pair->second) { |
396 | PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop); |
397 | } |
398 | } |
399 | } |
400 | } |
401 | |
402 | for (size_t j = 0; j < base_indent + indent; ++j) { |
403 | *os << " " ; |
404 | } |
405 | *os << stage->op->name << " = ...\n" ; |
406 | } |
407 | |
408 | // Print state to ostream |
409 | void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) { |
410 | // Gather placeholders |
411 | Array<String> placeholders; |
412 | for (const auto& stage : state->stages) { |
413 | if (stage->op_type == StageKind::kPlaceholder) { |
414 | placeholders.push_back(stage->op->name); |
415 | } |
416 | } |
417 | |
418 | *os << "Placeholder: " ; |
419 | for (size_t i = 0; i < placeholders.size(); ++i) { |
420 | *os << placeholders[i]; |
421 | if (i != placeholders.size() - 1) { |
422 | *os << ", " ; |
423 | } |
424 | } |
425 | *os << "\n" ; |
426 | |
427 | // Print all stages |
428 | for (size_t i = 0; i < state->stages.size(); ++i) { |
429 | const Stage& stage = state->stages[i]; |
430 | if (stage->op_type == StageKind::kPlaceholder) { |
431 | continue; |
432 | } else if (stage->op_type == StageKind::kCompute) { |
433 | if (stage->compute_at == ComputeAtKind::kRoot) { |
434 | PrintStage(os, i, state, 0, delete_trivial_loop); |
435 | } |
436 | } else { |
437 | LOG(FATAL) << "Invalid op type" ; |
438 | } |
439 | } |
440 | } |
441 | |
442 | String State::ToStr(bool delete_trivial_loop) const { |
443 | std::ostringstream os; |
444 | PrintState(&os, (*this), delete_trivial_loop); |
445 | return os.str(); |
446 | } |
447 | |
448 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
449 | .set_dispatch<StageNode>([](const ObjectRef& ref, ReprPrinter* p) { |
450 | const auto& stage = tvm::Downcast<Stage>(ref); |
451 | p->stream << stage->GetTypeKey() << "(" << stage.get() << ": " << stage->op->name << ")" ; |
452 | }); |
453 | |
454 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
455 | .set_dispatch<StateNode>([](const ObjectRef& ref, ReprPrinter* p) { |
456 | PrintState(&p->stream, tvm::Downcast<State>(ref), true); |
457 | }); |
458 | |
459 | /********** State interface API for ffi **********/ |
460 | TVM_REGISTER_GLOBAL("auto_scheduler.StateBind" ) |
461 | .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) { |
462 | const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type)); |
463 | return Array<ObjectRef>{state, res}; |
464 | }); |
465 | |
466 | TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel" ) |
467 | .set_body_typed([](State state, int stage_id, const Iterator& it) { |
468 | const auto& res = state.parallel(stage_id, it); |
469 | return Array<ObjectRef>{state, res}; |
470 | }); |
471 | |
472 | TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll" ) |
473 | .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) { |
474 | const auto& res = state.unroll(stage_id, it, max_unroll); |
475 | return Array<ObjectRef>{state, res}; |
476 | }); |
477 | |
478 | TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize" ) |
479 | .set_body_typed([](State state, int stage_id, const Iterator& it) { |
480 | const auto& res = state.vectorize(stage_id, it); |
481 | return Array<ObjectRef>{state, res}; |
482 | }); |
483 | |
484 | TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse" ) |
485 | .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) { |
486 | const auto& res = state.fuse(stage_id, iters); |
487 | return Array<ObjectRef>{state, res}; |
488 | }); |
489 | |
490 | TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma" ) |
491 | .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) { |
492 | state.pragma(stage_id, it, pragma_type); |
493 | return state; |
494 | }); |
495 | |
496 | TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder" ) |
497 | .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) { |
498 | state.reorder(stage_id, order); |
499 | return state; |
500 | }); |
501 | |
502 | TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit" ) |
503 | .set_body_typed([](State state, int stage_id, const Iterator& it, |
504 | const Array<Optional<Integer>>& lengths, bool inner_to_outer) { |
505 | const auto& res = state.split(stage_id, it, lengths, inner_to_outer); |
506 | return Array<ObjectRef>{state, res}; |
507 | }); |
508 | |
509 | TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit" ) |
510 | .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, |
511 | int n_split) { |
512 | const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); |
513 | return Array<ObjectRef>{state, Array<Iterator>(res)}; |
514 | }); |
515 | |
516 | TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit" ) |
517 | .set_body_typed([](State state, int stage_id, const Iterator& it, |
518 | const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) { |
519 | const auto& res = |
520 | state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts); |
521 | return Array<ObjectRef>{state, Array<Iterator>(res)}; |
522 | }); |
523 | |
524 | TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign" ) |
525 | .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) { |
526 | state.storage_align(stage_id, it, factor, offset); |
527 | return state; |
528 | }); |
529 | |
530 | TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt" ) |
531 | .set_body_typed([](State state, int stage_id, int target_stage_id, |
532 | const Iterator& target_iter) { |
533 | state.compute_at(stage_id, target_stage_id, target_iter); |
534 | return state; |
535 | }); |
536 | |
537 | TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline" ) |
538 | .set_body_typed([](State state, int stage_id) { |
539 | state.compute_inline(stage_id); |
540 | return state; |
541 | }); |
542 | |
543 | TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot" ) |
544 | .set_body_typed([](State state, int stage_id) { |
545 | state.compute_root(stage_id); |
546 | return state; |
547 | }); |
548 | |
549 | TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead" ) |
550 | .set_body_typed([](State state, int stage_id, const String& scope_name, |
551 | const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) { |
552 | int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag); |
553 | return Array<ObjectRef>{state, Integer(res)}; |
554 | }); |
555 | |
556 | TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite" ) |
557 | .set_body_typed([](State state, int stage_id, const String& scope_name, |
558 | const ComputeDAG& task_dag) { |
559 | int res = state.cache_write(stage_id, scope_name, task_dag); |
560 | return Array<ObjectRef>{state, Integer(res)}; |
561 | }); |
562 | |
563 | TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor" ) |
564 | .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id, |
565 | const ComputeDAG& dag) { |
566 | int res = state.rfactor(stage_id, it, factor_iter_id, dag); |
567 | return Array<ObjectRef>{state, Integer(res)}; |
568 | }); |
569 | |
570 | TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual" ).set_body_typed([](State state1, State state2) { |
571 | return std::equal_to<State>()(state1, state2); |
572 | }); |
573 | |
574 | } // namespace auto_scheduler |
575 | } // namespace tvm |
576 | |