1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_policy/utils.h |
22 | * \brief Common utilities for search policies. |
23 | */ |
24 | |
25 | #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_ |
26 | #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_ |
27 | |
28 | #include <dmlc/common.h> |
29 | #include <tvm/auto_scheduler/loop_state.h> |
30 | #include <tvm/auto_scheduler/search_policy.h> |
31 | #include <tvm/ir/expr.h> |
32 | #include <tvm/te/operation.h> |
33 | |
34 | #include <algorithm> |
35 | #include <condition_variable> |
36 | #include <set> |
37 | #include <string> |
38 | #include <tuple> |
39 | #include <unordered_map> |
40 | #include <unordered_set> |
41 | #include <utility> |
42 | #include <vector> |
43 | |
44 | #include "../utils.h" |
45 | |
46 | namespace tvm { |
47 | namespace auto_scheduler { |
48 | |
49 | /*! \brief Return whether the search task is targeting a CPU. */ |
50 | inline bool IsCPUTask(const SearchTask& task) { |
51 | return (task)->target->GetTargetDeviceType() == kDLCPU; |
52 | } |
53 | |
54 | /*! \brief Return whether the search task is targeting a GPU. */ |
55 | inline bool IsGPUTask(const SearchTask& task) { |
56 | int device_type = (task)->target->GetTargetDeviceType(); |
57 | return device_type == kDLCUDA || device_type == kDLOpenCL || device_type == kDLVulkan || |
58 | device_type == kDLMetal || device_type == kDLROCM || device_type == kOpenGL; |
59 | } |
60 | |
61 | /*! \brief Return whether the search task is targeting a Hexagon. */ |
62 | inline bool IsHexagonTask(const SearchTask& task) { |
63 | return (task)->target->GetTargetDeviceType() == kDLHexagon; |
64 | } |
65 | |
66 | /*! \brief Return whether the search task is targeting a CUDA GPU. */ |
67 | inline bool IsCUDATask(const SearchTask& task) { |
68 | return (task)->target->GetTargetDeviceType() == kDLCUDA; |
69 | } |
70 | |
71 | /*! \brief Return whether the search task is targeting a OpenCL GPU. */ |
72 | inline bool IsOpenCLTask(const SearchTask& task) { |
73 | return (task)->target->GetTargetDeviceType() == kDLOpenCL; |
74 | } |
75 | |
/*! \brief Argsort. Order: largest to smallest */
template <typename T>
inline std::vector<int> Argsort(const std::vector<T>& scores) {
  // Start from the identity permutation, then sort indices by descending score.
  std::vector<int> indices(scores.size());
  for (int i = 0; i < static_cast<int>(indices.size()); ++i) {
    indices[i] = i;
  }
  std::sort(indices.begin(), indices.end(),
            [&scores](int lhs, int rhs) { return scores[lhs] > scores[rhs]; });
  return indices;
}
88 | |
89 | /*! \brief Convert operation to stage id. */ |
90 | inline int OperationToStage(const te::Operation& op, const State& state) { |
91 | for (size_t i = 0; i < state->stages.size(); ++i) { |
92 | if (op == state->stages[i]->op) { |
93 | return i; |
94 | } |
95 | } |
96 | LOG(FATAL) << "Cannot find op: " << op; |
97 | } |
98 | |
99 | /********** Get Parameters **********/ |
100 | |
101 | /*! \brief Get an integer from a tvm str Map. */ |
102 | inline int GetIntParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) { |
103 | ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; |
104 | auto pint = attr_dict[key].as<IntImmNode>(); |
105 | ICHECK(pint != nullptr); |
106 | return pint->value; |
107 | } |
108 | |
109 | /*! \brief Get a double from a tvm str Map. */ |
110 | inline double GetDoubleParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) { |
111 | ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; |
112 | auto pdouble = attr_dict[key].as<FloatImmNode>(); |
113 | ICHECK(pdouble != nullptr); |
114 | return pdouble->value; |
115 | } |
116 | |
117 | /*! \brief Get a string from a tvm str Map. */ |
118 | inline std::string GetStringParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) { |
119 | ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; |
120 | const auto& target = attr_dict[key]; |
121 | if (auto pstr = target.as<StringImmNode>()) { |
122 | return pstr->value; |
123 | } |
124 | auto pstr = target.as<StringObj>(); |
125 | ICHECK(pstr != nullptr); |
126 | return pstr->data; |
127 | } |
128 | |
129 | /*! \brief Get a iterator name set from a tvm str Map. */ |
130 | inline std::set<std::string> GetIterNameSetParam(const Map<String, ObjectRef>& attr_dict, |
131 | const std::string& key) { |
132 | std::set<std::string> ret; |
133 | ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; |
134 | auto names = attr_dict[key].as<ArrayNode>(); |
135 | ICHECK(names != nullptr); |
136 | for (const auto& name : *names) { |
137 | ret.insert(name.as<StringObj>()->data); |
138 | } |
139 | return ret; |
140 | } |
141 | |
142 | /********** Checks with ComputeDAG **********/ |
143 | |
144 | /*! \brief Return whether an op is strictly-inlineable. */ |
145 | inline bool IsStrictlyInlineable(const SearchTask& task, const State& state, int stage_id) { |
146 | if (state->current_compute_dag) { |
147 | return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsStrictlyInlineable( |
148 | state->stages[stage_id]->op); |
149 | } else { |
150 | return task->compute_dag->access_analyzer.IsStrictlyInlineable(state->stages[stage_id]->op); |
151 | } |
152 | } |
153 | |
154 | /*! \brief Return whether an op is an output op. */ |
155 | inline bool IsOutputOp(const SearchTask& task, const State& state, int stage_id) { |
156 | if (state->current_compute_dag) { |
157 | return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsOutput( |
158 | state->stages[stage_id]->op); |
159 | } else { |
160 | return task->compute_dag->access_analyzer.IsOutput(state->stages[stage_id]->op); |
161 | } |
162 | } |
163 | |
164 | /*! \brief Return whether an op needs multi level tiling. */ |
165 | inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, int stage_id) { |
166 | if (state->current_compute_dag) { |
167 | return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.NeedsMultiLevelTiling( |
168 | state->stages[stage_id]->op); |
169 | } else { |
170 | return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(state->stages[stage_id]->op); |
171 | } |
172 | } |
173 | |
174 | /*! \brief Get all consumers for a stage. This function propagates the relation for inlined ops. */ |
175 | inline std::set<int> GetConsumers(const SearchTask& task, const State& state, int stage_id) { |
176 | std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers; |
177 | std::set<int> ret; |
178 | |
179 | if (state->current_compute_dag) { |
180 | consumers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetConsumers( |
181 | state, state->stages[stage_id]->op); |
182 | } else { |
183 | consumers = task->compute_dag->access_analyzer.GetConsumers(state, state->stages[stage_id]->op); |
184 | } |
185 | |
186 | for (const auto& op : consumers) { |
187 | ret.insert(OperationToStage(op, state)); |
188 | } |
189 | return ret; |
190 | } |
191 | |
192 | /*! \brief Check if a stage has single consumer or all of its consumers share a common root, return |
193 | * the target consumer root or -1. */ |
194 | inline int GetSingleConsumerId(const SearchTask& task, const State& state, int stage_id) { |
195 | const std::set<int>& consumers = GetConsumers(task, state, stage_id); |
196 | if (consumers.empty()) { |
197 | return -1; |
198 | } |
199 | |
200 | if (consumers.size() == 1) { |
201 | return *consumers.begin(); |
202 | } else { |
203 | // Check all consumers share a common root |
204 | int common_root_id = -1; |
205 | bool mismatch = false; |
206 | for (const auto& consumer_stage_id : consumers) { |
207 | int root_id = -1; |
208 | if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kRoot) { |
209 | root_id = consumer_stage_id; |
210 | } else if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kIter) { |
211 | root_id = state->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; |
212 | } else { |
213 | LOG(FATAL) << "Invalid case" ; |
214 | } |
215 | |
216 | if (common_root_id == -1) { |
217 | common_root_id = root_id; |
218 | } else { |
219 | if (common_root_id != root_id) { |
220 | mismatch = true; |
221 | break; |
222 | } |
223 | } |
224 | } |
225 | |
226 | return mismatch ? -1 : common_root_id; |
227 | } |
228 | } |
229 | |
230 | /*! \brief Get all producers for a stage. This function propagates the relation for inlined ops. */ |
231 | inline std::set<int> GetProducers(const SearchTask& task, const State& state, int stage_id) { |
232 | std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers; |
233 | std::set<int> ret; |
234 | |
235 | if (state->current_compute_dag) { |
236 | producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetProducers( |
237 | state, state->stages[stage_id]->op); |
238 | } else { |
239 | producers = task->compute_dag->access_analyzer.GetProducers(state, state->stages[stage_id]->op); |
240 | } |
241 | |
242 | for (const auto& op : producers) { |
243 | ret.insert(OperationToStage(op, state)); |
244 | } |
245 | return ret; |
246 | } |
247 | |
248 | /*! \brief Get all producers for a stage. This function DOES NOT propagates the relation for |
249 | * inlined ops. */ |
250 | inline std::set<int> GetDirectProducers(const SearchTask& task, const State& state, int stage_id) { |
251 | std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers; |
252 | std::set<int> ret; |
253 | |
254 | if (state->current_compute_dag) { |
255 | producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetDirectProducers( |
256 | state->stages[stage_id]->op); |
257 | } else { |
258 | producers = task->compute_dag->access_analyzer.GetDirectProducers(state->stages[stage_id]->op); |
259 | } |
260 | |
261 | for (const auto& op : producers) { |
262 | ret.insert(OperationToStage(op, state)); |
263 | } |
264 | return ret; |
265 | } |
266 | |
267 | /*! \brief Get the number of common outer iterators. This function propagates the relation for |
268 | * chains with multiple ops. */ |
269 | inline int GetNumCommonOuterIterator(const SearchTask& task, const State& state, int stage_id, |
270 | int target_stage_id) { |
271 | if (state->current_compute_dag) { |
272 | return state->current_compute_dag.as<ComputeDAGNode>() |
273 | ->access_analyzer.GetNumCommonOuterIterator(state->stages[stage_id]->op, |
274 | state->stages[target_stage_id]->op); |
275 | } else { |
276 | return task->compute_dag->access_analyzer.GetNumCommonOuterIterator( |
277 | state->stages[stage_id]->op, state->stages[target_stage_id]->op); |
278 | } |
279 | } |
280 | |
281 | /*! \brief Return whether two ops are elementwise-matched. */ |
282 | inline bool ElementwiseMatch(const SearchTask& task, const State& state, int stage_id, |
283 | int target_stage_id) { |
284 | const auto& op = state->stages[stage_id]->op; |
285 | const auto& target_op = state->stages[target_stage_id]->op; |
286 | if (state->current_compute_dag) { |
287 | return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.ElementWiseMatch( |
288 | op, target_op); |
289 | } else { |
290 | return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op); |
291 | } |
292 | } |
293 | |
294 | /********** Get informations from Stage/Iterator **********/ |
295 | |
296 | /*! \brief Return the extent of an iterator. */ |
297 | inline int64_t GetExtent(const Iterator& it) { |
298 | if (it->range.defined()) { |
299 | if (auto pint = it->range->extent.as<IntImmNode>()) { |
300 | return pint->value; |
301 | } |
302 | } |
303 | return -1; |
304 | } |
305 | |
306 | /*! \brief Compute the product of lengths of all space iters and all reduce iters, respectively. */ |
307 | inline std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const Stage& stage) { |
308 | int64_t cum_space_len = 1, cum_reduce_len = 1; |
309 | for (const auto& iter : stage->iters) { |
310 | if (iter->iter_kind == IteratorKind::kSpatial) { |
311 | cum_space_len *= GetExtent(iter); |
312 | } else if (iter->iter_kind == IteratorKind::kReduction) { |
313 | cum_reduce_len *= GetExtent(iter); |
314 | } |
315 | } |
316 | return std::make_pair(cum_space_len, cum_reduce_len); |
317 | } |
318 | |
319 | /*! \brief Return whether this stage needs rfactor. */ |
320 | inline bool NeedsRfactor(const SearchTask& task, const State& state, int stage_id) { |
321 | const auto& op = state->stages[stage_id]->op; |
322 | if (op->IsInstance<te::ComputeOpNode>()) { |
323 | // Compute the product of lengths of all space iters and all reduce iters |
324 | int cum_space_len, cum_reduce_len; |
325 | std::tie(cum_space_len, cum_reduce_len) = |
326 | GetCumulativeSpaceAndReductionLength(state->stages[stage_id]); |
327 | |
328 | if (NeedsMultilevelTiling(task, state, stage_id)) { |
329 | // Do not use rfactor if we have enough parallelism on space iters |
330 | if (cum_space_len > cum_reduce_len || cum_space_len > task->hardware_params->num_cores * 16) { |
331 | return false; |
332 | } else { |
333 | return true; |
334 | } |
335 | } else if (cum_reduce_len > 1) { |
336 | // Always try rfactor for reduction ops |
337 | return cum_reduce_len > task->hardware_params->num_cores; |
338 | } |
339 | } |
340 | |
341 | return false; |
342 | } |
343 | |
344 | /*! \brief Return whether the stage has reduce iterators. */ |
345 | inline bool HasReduceIter(const Stage& stage) { |
346 | for (const auto& iter : stage->iters) { |
347 | if (iter->iter_kind != IteratorKind::kSpatial) { |
348 | return true; |
349 | } |
350 | } |
351 | return false; |
352 | } |
353 | |
354 | /*! \brief Return whether the stage has specific annotated iterators. */ |
355 | inline bool HasAnnotatedIter(const Stage& stage, IteratorAnnotation type) { |
356 | for (const auto& iter : stage->iters) { |
357 | if (iter->annotation == type) { |
358 | return true; |
359 | } |
360 | } |
361 | return false; |
362 | } |
363 | |
364 | /*! \brief Return whether the stage has only one consumer and they are elementwise-matched. */ |
365 | inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const State& state, |
366 | int stage_id, int* target_stage_id = nullptr) { |
367 | // Temporal object to be used if the input pointer is nullptr |
368 | int temp_target_stage_id; |
369 | if (target_stage_id == nullptr) { |
370 | target_stage_id = &temp_target_stage_id; |
371 | } |
372 | const std::set<int>& consumers = GetConsumers(task, state, stage_id); |
373 | if (consumers.size() == 1) { |
374 | *target_stage_id = *consumers.begin(); |
375 | if (ElementwiseMatch(task, state, stage_id, *target_stage_id) && |
376 | (!(HasReduceIter(state->stages[stage_id]) && |
377 | HasReduceIter(state->stages[*target_stage_id]))) && |
378 | (!StrEndsWith(state->stages[*target_stage_id]->op->name, ".shared" ))) { |
379 | return true; |
380 | } |
381 | } |
382 | return false; |
383 | } |
384 | |
385 | /*! \brief Return whether the step changes the number of stages */ |
386 | inline bool IsStageNumberChangingStep(const Step& step) { |
387 | return step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>() || |
388 | step->IsInstance<RfactorStepNode>(); |
389 | } |
390 | |
391 | /*! \brief Return whether the state does cache_read for stage_id. */ |
392 | inline bool HasCacheReadStage(const State& s, int stage_id) { |
393 | for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) { |
394 | if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) { |
395 | if (stage_id == ps->stage_id) { |
396 | return true; |
397 | } |
398 | } |
399 | |
400 | if (IsStageNumberChangingStep(s->transform_steps[i])) { |
401 | if (stage_id > s->transform_steps[i]->stage_id) { |
402 | stage_id--; |
403 | } |
404 | } |
405 | } |
406 | return false; |
407 | } |
408 | |
409 | /*! \brief Return whether the state does cache_write for stage_id. */ |
410 | inline bool HasCacheWriteStage(const State& s, int stage_id) { |
411 | for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) { |
412 | if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) { |
413 | if (stage_id == ps->stage_id) { |
414 | return true; |
415 | } |
416 | } |
417 | |
418 | if (IsStageNumberChangingStep(s->transform_steps[i])) { |
419 | if (stage_id > s->transform_steps[i]->stage_id) { |
420 | stage_id--; |
421 | } |
422 | } |
423 | } |
424 | return false; |
425 | } |
426 | |
427 | /*! \brief Return whether the state does rfactor for stage_id. */ |
428 | inline bool HasRfactorStage(const State& s, int stage_id) { |
429 | for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) { |
430 | if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) { |
431 | if (stage_id == ps->stage_id) { |
432 | return true; |
433 | } |
434 | } |
435 | |
436 | if (IsStageNumberChangingStep(s->transform_steps[i])) { |
437 | if (stage_id > s->transform_steps[i]->stage_id) { |
438 | stage_id--; |
439 | } |
440 | } |
441 | } |
442 | return false; |
443 | } |
444 | |
/*! \brief Return whether the stage does cross thread reduction.
 * True if the stage itself, or any stage compute_at-attached to one of its
 * iterators, has a reduction iterator bound to threadIdx.x.
 */
inline bool HasCrossThreadReduction(const State& state, int stage_id) {
  // A stage performs cross-thread reduction when a reduction iterator is
  // annotated with kThreadX (i.e. the reduction is distributed across threads).
  std::function<bool(const Stage&)> check_stage = [](const Stage& in_stage) {
    for (const auto& iter : in_stage->iters) {
      if (iter->annotation == IteratorAnnotation::kThreadX &&
          iter->iter_kind == IteratorKind::kReduction) {
        return true;
      }
    }
    return false;
  };

  // Check the stage itself
  if (check_stage(state->stages[stage_id])) {
    return true;
  }

  // Check the attached stages: a cross-thread reduction in any stage attached
  // to one of this stage's iterators also counts.
  for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); iter_id++) {
    const auto& res =
        state->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
    if (res != state->attach_map->iter_to_attached_stages.end()) {
      for (int attached_stage_id : res->second) {
        if (check_stage(state->stages[attached_stage_id])) {
          return true;
        }
      }
    }
  }

  return false;
}
477 | |
478 | /*! \brief Return whether the stage has been tiled already. */ |
479 | inline bool IsTiled(const Stage& stage) { |
480 | auto op = stage->op.as<te::ComputeOpNode>(); |
481 | ICHECK(op != nullptr); |
482 | return stage->iters.size() != op->axis.size() + op->reduce_axis.size(); |
483 | } |
484 | |
/*! \brief Extract primitive iterators from a nested fused or splitted iterator's name.
 * Splits `name` on '@' (fuse) and '.' (split) and inserts every segment that is
 * an original iterator name (i.e. does not start with a digit or a separator)
 * into *rets.
 * Note: the function name was missing in the declaration; it is restored here to
 * match the call site in GetLastReduceIteratorInOutermostReduceTile.
 */
inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) {
  size_t last_pos = 0;
  for (size_t i = 0; i < name.size(); ++i) {
    if (name[i] == '@' || name[i] == '.') {  // '@' for fuse and '.' for split
      // Segments that begin with a digit are split/fuse suffixes, not names.
      if (!isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') {
        rets->insert(name.substr(last_pos, i - last_pos));
      }
      last_pos = i + 1;
    }
  }

  // Handle the trailing segment after the last separator.
  if (last_pos < name.size() && !isdigit(name[last_pos]) && name[last_pos] != '@' &&
      name[last_pos] != '.') {
    rets->insert(name.substr(last_pos, name.size() - last_pos));
  }
}
502 | |
/*! \brief Get the last reduce iterator in the outermost reduce tile.
 * Aborts (LOG(FATAL)) if the stage has no reduction iterator at all.
 */
inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
  auto pop = stage->op.as<te::ComputeOpNode>();
  ICHECK(pop != nullptr);
  std::set<std::string> original_names;

  // Axes listed under "no_split_at_inner" are excluded when counting how many
  // original reduce axes a fused/split iterator must cover.
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();
  // Count the splittable original reduce axes.
  size_t reduce_axis_size = 0;
  for (const auto axis : pop->reduce_axis) {
    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
      reduce_axis_size++;
    }
  }
  if (reduce_axis_size) {
    // Return the first reduction iterator whose (accumulated) original-name set
    // covers all splittable reduce axes.
    // NOTE(review): original_names accumulates across loop iterations without
    // being cleared — presumably intentional so names seen in earlier reduce
    // iterators still count; confirm against the tiling-name conventions.
    for (const auto& iter : stage->iters) {
      if (iter->iter_kind == IteratorKind::kReduction) {
        ExtractOriginalIterators(iter->name, &original_names);
        if (original_names.size() == reduce_axis_size) {
          return iter;
        }
      }
    }
  } else {
    // All reduce axes are no-split: return the first reduce iterator.
    for (const auto& iter : stage->iters) {
      if (iter->iter_kind == IteratorKind::kReduction) {
        return iter;
      }
    }
  }

  LOG(FATAL) << "Cannot find the iterator." ;
}
539 | |
/*! \brief Get the target stage id of a history step in the new state.
 * We need this because the stage_id in the history may be stale due to later steps */
inline int GetTargetStageIDInState(const State& s, int step_id) {
  int stage_inc = 0;

  // Replay every step after `step_id`: each stage-inserting step
  // (cache_read / cache_write / rfactor) shifts the ids of all stages at or
  // after its own stage_id by one, so accumulate the shifts that land at or
  // before our (already-shifted) target id.
  for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) {
    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      if (s->transform_steps[i]->stage_id <= s->transform_steps[step_id]->stage_id + stage_inc)
        stage_inc++;
    }
  }
  return s->transform_steps[step_id]->stage_id + stage_inc;
}
553 | |
554 | /*! \brief Get all split steps for one stage. */ |
555 | inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) { |
556 | for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) { |
557 | if (auto ps = s->transform_steps[i].as<SplitStepNode>()) { |
558 | if (stage_id == ps->stage_id) { |
559 | split_step_ids->push_back(i); |
560 | } |
561 | } |
562 | |
563 | if (IsStageNumberChangingStep(s->transform_steps[i])) { |
564 | if (stage_id > s->transform_steps[i]->stage_id) { |
565 | stage_id--; |
566 | } |
567 | } |
568 | } |
569 | } |
570 | |
571 | /*! \brief Fuse all reduction iterators. */ |
572 | inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter, |
573 | Array<Iterator>* space_iters, |
574 | Array<Iterator>* reduce_iters) { |
575 | space_iters->clear(); |
576 | reduce_iters->clear(); |
577 | |
578 | for (const auto& iter : state->stages[stage_id]->iters) { |
579 | if (iter->iter_kind == IteratorKind::kSpatial) { |
580 | space_iters->push_back(iter); |
581 | } else if (iter->iter_kind == IteratorKind::kReduction) { |
582 | reduce_iters->push_back(iter); |
583 | } |
584 | } |
585 | |
586 | ICHECK(!reduce_iters->empty()); |
587 | State tmp_s = state; |
588 | if (reduce_iters->size() > 1) { |
589 | *fused_iter = tmp_s.fuse(stage_id, *reduce_iters); |
590 | } else { |
591 | *fused_iter = (*reduce_iters)[0]; |
592 | } |
593 | return tmp_s; |
594 | } |
595 | |
/*! \brief Fuse all outer level space iterators.
 * Collects the leading run of un-annotated spatial iterators and fuses them;
 * *fused_iter receives the fused (or sole) iterator.
 */
inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterator* fused_iter) {
  std::vector<Iterator> to_fuse;
  for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); ++iter_id) {
    const auto& it = state->stages[stage_id]->iters[iter_id];
    // Stop at reduce iterator or annotated iterator
    if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
      break;
    }
    // Stop at compute_at attach point
    // NOTE(review): this checks iter_id - 1 (the previous iterator); at
    // iter_id == 0 the size_t subtraction wraps, which should simply miss in
    // the map — confirm that is the intended "no previous iterator" behavior.
    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id - 1))) {
      break;
    }
    to_fuse.push_back(it);
  }

  State tmp_s = state;
  if (to_fuse.size() == 1) {
    // A single outer iterator needs no fusing.
    *fused_iter = to_fuse[0];
  } else {
    *fused_iter = tmp_s.fuse(stage_id, to_fuse);
  }
  return tmp_s;
}
620 | |
621 | /*! \brief Random sample states. */ |
622 | inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen, |
623 | size_t out_size) { |
624 | Array<State> out_states; |
625 | for (size_t i = 0; i < out_size; i++) { |
626 | out_states.push_back(in_states[(*random_gen)() % in_states.size()]); |
627 | } |
628 | return out_states; |
629 | } |
630 | |
/*! \brief Compute prefix-sum probabiilty based on the given weights.
 * Negative weights are clamped to zero; the last entry of *prefix_sum_probs is 1. */
inline void ComputePrefixSumProb(const std::vector<float>& weights,
                                 std::vector<double>* prefix_sum_probs) {
  prefix_sum_probs->resize(weights.size());
  // Accumulate the clamped weights into a running prefix sum.
  float total = 0.0f;
  for (size_t idx = 0; idx < weights.size(); ++idx) {
    total += std::max(weights[idx], 0.0f);
    (*prefix_sum_probs)[idx] = total;
  }
  // Normalize the prefix sums into probabilities.
  for (size_t idx = 0; idx < weights.size(); ++idx) {
    (*prefix_sum_probs)[idx] /= total;
  }
}
645 | |
646 | /*! \brief Random choose an index according to a prefix sum probability. */ |
647 | inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) { |
648 | std::uniform_real_distribution<> dis(0.0, 1.0); |
649 | double x = dis(*random_gen); |
650 | |
651 | ICHECK(!prefix_sum_probs.empty()); |
652 | |
653 | return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - |
654 | prefix_sum_probs.begin(); |
655 | } |
656 | |
657 | /*! \brief Print a title */ |
658 | inline void PrintTitle(const std::string& title, int verbose) { |
659 | StdCout(verbose) << Chars('-', 70) << "\n" |
660 | << Chars('-', 30) << " [ " << title << " ]\n" |
661 | << Chars('-', 70) << std::endl; |
662 | } |
663 | |
/*!
 * \brief Enumerate all possible factorization schemes for splitting an axes.
 * \note This class will memorize the results for reuse.
 */
class SplitFactorizationMemo {
 public:
  /*! \brief Memo key: (extent, number of split lengths, max innermost factor). */
  using QueryKey = std::tuple<int, int, int>;

  /*!
   * \brief Get all schemes that factorize `extent` into `n_lengths` factors,
   * with the innermost factor bounded by `max_innermost_factor`.
   * Results are cached in `memory_` and returned by reference.
   */
  const Array<Array<Integer>>& GetFactorizationSchemes(int extent, int n_lengths,
                                                       int max_innermost_factor);
  /*! \brief Get all factors of `n` (cached in `factor_memory_`). */
  const std::vector<int>& GetFactors(int n);

 private:
  /*! \brief Recursively enumerate factorization schemes into `results_`. */
  void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);

  /*! \brief Cache: query key -> all factorization schemes for that query. */
  std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;

  // Scratch state used during a single DfsEnumerate traversal.
  int n_lengths_;
  Array<Integer> tmp_stack_;
  Array<Array<Integer>>* results_;
  /*! \brief Cache: n -> sorted factors of n. */
  std::unordered_map<int, std::vector<int>> factor_memory_;
};
686 | |
687 | /*! \brief Get the indexes of SplitStep that processes on spatial iterator. */ |
688 | Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id); |
689 | |
690 | /*! \brief Get the possible compute locations for a stage. */ |
691 | std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task, |
692 | const State& state, int stage_id); |
693 | |
694 | // Apply multi-level tiling structure according to a string format, |
695 | // where "S" stands a space level, "R" stands for a reduction level. |
696 | // For example, if the format is "SSRSRS", then we will |
697 | // use tiling structure: space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3 |
698 | // For example, if apply "SSRSRS" to matrix multiplication, |
699 | // we have space iterators i and j, reduce iterator k. |
700 | // Then the tiling structure is : i0, j0, i1, j1, k0, i2, j2, k1, i3, j3 |
701 | State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, |
702 | std::vector<int>* spatial_split_step_ids = nullptr); |
703 | |
704 | // Apply tiling structure: space, space, space, ..., with tile sizes from other SplitStep |
705 | State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids, |
706 | int n_split); |
707 | |
708 | // Prune invalid states and return the results in-place. |
709 | void PruneInvalidState(const SearchTask& task, Array<State>* states); |
710 | |
711 | } // namespace auto_scheduler |
712 | } // namespace tvm |
713 | |
714 | #endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_ |
715 | |