/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/search_policy/utils.h
 * \brief Common utilities for search policies.
 */

#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_

#include <dmlc/common.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/ir/expr.h>
#include <tvm/te/operation.h>

#include <algorithm>
#include <cctype>
#include <condition_variable>
#include <functional>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../utils.h"

namespace tvm {
namespace auto_scheduler {

/*! \brief Return whether the search task is targeting a CPU. */
inline bool IsCPUTask(const SearchTask& task) {
  return (task)->target->GetTargetDeviceType() == kDLCPU;
}

/*! \brief Return whether the search task is targeting a GPU. */
inline bool IsGPUTask(const SearchTask& task) {
  int device_type = (task)->target->GetTargetDeviceType();
  return device_type == kDLCUDA || device_type == kDLOpenCL || device_type == kDLVulkan ||
         device_type == kDLMetal || device_type == kDLROCM || device_type == kOpenGL;
}

/*! \brief Return whether the search task is targeting Hexagon. */
inline bool IsHexagonTask(const SearchTask& task) {
  return (task)->target->GetTargetDeviceType() == kDLHexagon;
}

/*! \brief Return whether the search task is targeting a CUDA GPU. */
inline bool IsCUDATask(const SearchTask& task) {
  return (task)->target->GetTargetDeviceType() == kDLCUDA;
}

/*! \brief Return whether the search task is targeting an OpenCL GPU. */
inline bool IsOpenCLTask(const SearchTask& task) {
  return (task)->target->GetTargetDeviceType() == kDLOpenCL;
}

/*! \brief Argsort. Order: largest to smallest */
template <typename T>
inline std::vector<int> Argsort(const std::vector<T>& scores) {
  std::vector<int> index;
  index.reserve(scores.size());
  for (size_t i = 0; i < scores.size(); ++i) {
    index.push_back(i);
  }
  auto cmp = [&scores](int l, int r) { return scores[l] > scores[r]; };
  std::sort(index.begin(), index.end(), cmp);
  return index;
}
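// For example, given scores = {0.2, 0.9, 0.5}, Argsort would return {1, 2, 0}: the indices of
// the scores ordered from the largest value to the smallest.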

/*! \brief Convert operation to stage id. */
inline int OperationToStage(const te::Operation& op, const State& state) {
  for (size_t i = 0; i < state->stages.size(); ++i) {
    if (op == state->stages[i]->op) {
      return i;
    }
  }
  LOG(FATAL) << "Cannot find op: " << op;
}

/********** Get Parameters **********/

/*! \brief Get an integer from a TVM string Map. */
inline int GetIntParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
  ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
  auto pint = attr_dict[key].as<IntImmNode>();
  ICHECK(pint != nullptr);
  return pint->value;
}

/*! \brief Get a double from a TVM string Map. */
inline double GetDoubleParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
  ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
  auto pdouble = attr_dict[key].as<FloatImmNode>();
  ICHECK(pdouble != nullptr);
  return pdouble->value;
}

/*! \brief Get a string from a TVM string Map. */
inline std::string GetStringParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
  ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
  const auto& target = attr_dict[key];
  if (auto pstr = target.as<StringImmNode>()) {
    return pstr->value;
  }
  auto pstr = target.as<StringObj>();
  ICHECK(pstr != nullptr);
  return pstr->data;
}

/*! \brief Get an iterator name set from a TVM string Map. */
inline std::set<std::string> GetIterNameSetParam(const Map<String, ObjectRef>& attr_dict,
                                                 const std::string& key) {
  std::set<std::string> ret;
  ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
  auto names = attr_dict[key].as<ArrayNode>();
  ICHECK(names != nullptr);
  for (const auto& name : *names) {
    ret.insert(name.as<StringObj>()->data);
  }
  return ret;
}
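// These getters are used, for example, to read per-op search hints from stage->op->attrs, such
// as the SearchPolicyKey::no_split_at_inner iterator name set consulted by
// GetLastReduceIteratorInOutermostReduceTile further below.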

/********** Checks with ComputeDAG **********/

/*! \brief Return whether an op is strictly-inlineable. */
inline bool IsStrictlyInlineable(const SearchTask& task, const State& state, int stage_id) {
  if (state->current_compute_dag) {
    return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsStrictlyInlineable(
        state->stages[stage_id]->op);
  } else {
    return task->compute_dag->access_analyzer.IsStrictlyInlineable(state->stages[stage_id]->op);
  }
}

/*! \brief Return whether an op is an output op. */
inline bool IsOutputOp(const SearchTask& task, const State& state, int stage_id) {
  if (state->current_compute_dag) {
    return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsOutput(
        state->stages[stage_id]->op);
  } else {
    return task->compute_dag->access_analyzer.IsOutput(state->stages[stage_id]->op);
  }
}

/*! \brief Return whether an op needs multi level tiling. */
inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, int stage_id) {
  if (state->current_compute_dag) {
    return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.NeedsMultiLevelTiling(
        state->stages[stage_id]->op);
  } else {
    return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(state->stages[stage_id]->op);
  }
}

/*! \brief Get all consumers for a stage. This function propagates the relation for inlined ops. */
inline std::set<int> GetConsumers(const SearchTask& task, const State& state, int stage_id) {
  std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers;
  std::set<int> ret;

  if (state->current_compute_dag) {
    consumers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetConsumers(
        state, state->stages[stage_id]->op);
  } else {
    consumers = task->compute_dag->access_analyzer.GetConsumers(state, state->stages[stage_id]->op);
  }

  for (const auto& op : consumers) {
    ret.insert(OperationToStage(op, state));
  }
  return ret;
}

/*! \brief Check if a stage has a single consumer, or if all of its consumers share a common root.
 * Return the id of that consumer (or common root) stage, or -1 if there is none. */
inline int GetSingleConsumerId(const SearchTask& task, const State& state, int stage_id) {
  const std::set<int>& consumers = GetConsumers(task, state, stage_id);
  if (consumers.empty()) {
    return -1;
  }

  if (consumers.size() == 1) {
    return *consumers.begin();
  } else {
    // Check whether all consumers share a common root
    int common_root_id = -1;
    bool mismatch = false;
    for (const auto& consumer_stage_id : consumers) {
      int root_id = -1;
      if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kRoot) {
        root_id = consumer_stage_id;
      } else if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kIter) {
        root_id = state->attach_map->stage_to_attach_iter.at(consumer_stage_id).first;
      } else {
        LOG(FATAL) << "Invalid case";
      }

      if (common_root_id == -1) {
        common_root_id = root_id;
      } else {
        if (common_root_id != root_id) {
          mismatch = true;
          break;
        }
      }
    }

    return mismatch ? -1 : common_root_id;
  }
}
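// For example, if a stage's only consumers are two stages that are both compute_at'ed under the
// same stage k, the shared root k would be returned; consumers attached under different roots
// (or a mix of root-level and attached consumers) yield -1.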

/*! \brief Get all producers for a stage. This function propagates the relation for inlined ops. */
inline std::set<int> GetProducers(const SearchTask& task, const State& state, int stage_id) {
  std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers;
  std::set<int> ret;

  if (state->current_compute_dag) {
    producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetProducers(
        state, state->stages[stage_id]->op);
  } else {
    producers = task->compute_dag->access_analyzer.GetProducers(state, state->stages[stage_id]->op);
  }

  for (const auto& op : producers) {
    ret.insert(OperationToStage(op, state));
  }
  return ret;
}

/*! \brief Get all producers for a stage. This function DOES NOT propagate the relation for
 * inlined ops. */
inline std::set<int> GetDirectProducers(const SearchTask& task, const State& state, int stage_id) {
  std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers;
  std::set<int> ret;

  if (state->current_compute_dag) {
    producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetDirectProducers(
        state->stages[stage_id]->op);
  } else {
    producers = task->compute_dag->access_analyzer.GetDirectProducers(state->stages[stage_id]->op);
  }

  for (const auto& op : producers) {
    ret.insert(OperationToStage(op, state));
  }
  return ret;
}

/*! \brief Get the number of common outer iterators. This function propagates the relation for
 * chains with multiple ops. */
inline int GetNumCommonOuterIterator(const SearchTask& task, const State& state, int stage_id,
                                     int target_stage_id) {
  if (state->current_compute_dag) {
    return state->current_compute_dag.as<ComputeDAGNode>()
        ->access_analyzer.GetNumCommonOuterIterator(state->stages[stage_id]->op,
                                                    state->stages[target_stage_id]->op);
  } else {
    return task->compute_dag->access_analyzer.GetNumCommonOuterIterator(
        state->stages[stage_id]->op, state->stages[target_stage_id]->op);
  }
}

/*! \brief Return whether two ops are elementwise-matched. */
inline bool ElementwiseMatch(const SearchTask& task, const State& state, int stage_id,
                             int target_stage_id) {
  const auto& op = state->stages[stage_id]->op;
  const auto& target_op = state->stages[target_stage_id]->op;
  if (state->current_compute_dag) {
    return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.ElementWiseMatch(
        op, target_op);
  } else {
    return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op);
  }
}

/********** Get information from Stage/Iterator **********/

/*! \brief Return the extent of an iterator. */
inline int64_t GetExtent(const Iterator& it) {
  if (it->range.defined()) {
    if (auto pint = it->range->extent.as<IntImmNode>()) {
      return pint->value;
    }
  }
  return -1;
}
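// Note: -1 is returned when the extent is not a constant integer (e.g. a symbolic expression).
// Helpers below that multiply extents, such as GetCumulativeSpaceAndReductionLength, implicitly
// assume constant extents; a -1 here would make those products meaningless.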

/*! \brief Compute the product of lengths of all space iters and all reduce iters, respectively. */
inline std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const Stage& stage) {
  int64_t cum_space_len = 1, cum_reduce_len = 1;
  for (const auto& iter : stage->iters) {
    if (iter->iter_kind == IteratorKind::kSpatial) {
      cum_space_len *= GetExtent(iter);
    } else if (iter->iter_kind == IteratorKind::kReduction) {
      cum_reduce_len *= GetExtent(iter);
    }
  }
  return std::make_pair(cum_space_len, cum_reduce_len);
}
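// For example, for a matmul stage C[i, j] = sum_k A[i, k] * B[k, j] with i = j = 512 and k = 128,
// this would return (512 * 512, 128) = (262144, 128).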

/*! \brief Return whether this stage needs rfactor. */
inline bool NeedsRfactor(const SearchTask& task, const State& state, int stage_id) {
  const auto& op = state->stages[stage_id]->op;
  if (op->IsInstance<te::ComputeOpNode>()) {
    // Compute the product of lengths of all space iters and all reduce iters
    int64_t cum_space_len, cum_reduce_len;
    std::tie(cum_space_len, cum_reduce_len) =
        GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);

    if (NeedsMultilevelTiling(task, state, stage_id)) {
      // Do not use rfactor if we have enough parallelism on space iters
      if (cum_space_len > cum_reduce_len || cum_space_len > task->hardware_params->num_cores * 16) {
        return false;
      } else {
        return true;
      }
    } else if (cum_reduce_len > 1) {
      // Try rfactor for other reduction ops when the reduction is larger than the number of cores
      return cum_reduce_len > task->hardware_params->num_cores;
    }
  }

  return false;
}

/*! \brief Return whether the stage has reduce iterators. */
inline bool HasReduceIter(const Stage& stage) {
  for (const auto& iter : stage->iters) {
    if (iter->iter_kind != IteratorKind::kSpatial) {
      return true;
    }
  }
  return false;
}

/*! \brief Return whether the stage has specific annotated iterators. */
inline bool HasAnnotatedIter(const Stage& stage, IteratorAnnotation type) {
  for (const auto& iter : stage->iters) {
    if (iter->annotation == type) {
      return true;
    }
  }
  return false;
}

/*! \brief Return whether the stage has only one consumer and the two are elementwise-matched. */
inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const State& state,
                                                int stage_id, int* target_stage_id = nullptr) {
  // Temporary variable used when the input pointer is nullptr
  int temp_target_stage_id;
  if (target_stage_id == nullptr) {
    target_stage_id = &temp_target_stage_id;
  }
  const std::set<int>& consumers = GetConsumers(task, state, stage_id);
  if (consumers.size() == 1) {
    *target_stage_id = *consumers.begin();
    if (ElementwiseMatch(task, state, stage_id, *target_stage_id) &&
        (!(HasReduceIter(state->stages[stage_id]) &&
           HasReduceIter(state->stages[*target_stage_id]))) &&
        (!StrEndsWith(state->stages[*target_stage_id]->op->name, ".shared"))) {
      return true;
    }
  }
  return false;
}

/*! \brief Return whether the step changes the number of stages */
inline bool IsStageNumberChangingStep(const Step& step) {
  return step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>() ||
         step->IsInstance<RfactorStepNode>();
}

/*! \brief Return whether the state does cache_read for stage_id. */
inline bool HasCacheReadStage(const State& s, int stage_id) {
  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
    if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) {
      if (stage_id == ps->stage_id) {
        return true;
      }
    }

    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      if (stage_id > s->transform_steps[i]->stage_id) {
        stage_id--;
      }
    }
  }
  return false;
}

/*! \brief Return whether the state does cache_write for stage_id. */
inline bool HasCacheWriteStage(const State& s, int stage_id) {
  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
    if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) {
      if (stage_id == ps->stage_id) {
        return true;
      }
    }

    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      if (stage_id > s->transform_steps[i]->stage_id) {
        stage_id--;
      }
    }
  }
  return false;
}

/*! \brief Return whether the state does rfactor for stage_id. */
inline bool HasRfactorStage(const State& s, int stage_id) {
  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
    if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
      if (stage_id == ps->stage_id) {
        return true;
      }
    }

    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      if (stage_id > s->transform_steps[i]->stage_id) {
        stage_id--;
      }
    }
  }
  return false;
}

/*! \brief Return whether the stage does cross thread reduction. */
inline bool HasCrossThreadReduction(const State& state, int stage_id) {
  std::function<bool(const Stage&)> check_stage = [](const Stage& in_stage) {
    for (const auto& iter : in_stage->iters) {
      if (iter->annotation == IteratorAnnotation::kThreadX &&
          iter->iter_kind == IteratorKind::kReduction) {
        return true;
      }
    }
    return false;
  };

  // Check the stage itself
  if (check_stage(state->stages[stage_id])) {
    return true;
  }

  // Check the attached stages
  for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); iter_id++) {
    const auto& res =
        state->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
    if (res != state->attach_map->iter_to_attached_stages.end()) {
      for (int attached_stage_id : res->second) {
        if (check_stage(state->stages[attached_stage_id])) {
          return true;
        }
      }
    }
  }

  return false;
}

/*! \brief Return whether the stage has been tiled already. */
inline bool IsTiled(const Stage& stage) {
  auto op = stage->op.as<te::ComputeOpNode>();
  ICHECK(op != nullptr);
  return stage->iters.size() != op->axis.size() + op->reduce_axis.size();
}
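// A split or fuse changes the stage's iterator count away from the op's original axis count;
// this check relies on that count difference rather than inspecting the transform history.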

/*! \brief Extract primitive iterators from a nested fused or split iterator's name. */
inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) {
  size_t last_pos = 0;
  for (size_t i = 0; i < name.size(); ++i) {
    if (name[i] == '@' || name[i] == '.') {  // '@' for fuse and '.' for split
      if (!isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') {
        rets->insert(name.substr(last_pos, i - last_pos));
      }
      last_pos = i + 1;
    }
  }

  if (last_pos < name.size() && !isdigit(name[last_pos]) && name[last_pos] != '@' &&
      name[last_pos] != '.') {
    rets->insert(name.substr(last_pos, name.size() - last_pos));
  }
}
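// For example, a fused-and-split iterator name such as "i.0@j.1@" (illustrative) would yield the
// primitive names {"i", "j"}; purely numeric fragments produced by splits are skipped.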

/*! \brief Get the last reduce iterator in the outermost reduce tile. */
inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
  auto pop = stage->op.as<te::ComputeOpNode>();
  ICHECK(pop != nullptr);
  std::set<std::string> original_names;

  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();
  size_t reduce_axis_size = 0;
  for (const auto axis : pop->reduce_axis) {
    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
      reduce_axis_size++;
    }
  }
  if (reduce_axis_size) {
    for (const auto& iter : stage->iters) {
      if (iter->iter_kind == IteratorKind::kReduction) {
        ExtractOriginalIterators(iter->name, &original_names);
        if (original_names.size() == reduce_axis_size) {
          return iter;
        }
      }
    }
  } else {
    // Return the first reduce iterator
    for (const auto& iter : stage->iters) {
      if (iter->iter_kind == IteratorKind::kReduction) {
        return iter;
      }
    }
  }

  LOG(FATAL) << "Cannot find the iterator.";
}

/*! \brief Get the target stage id of a history step in the new state.
 * We need this because the stage_id in the history may be stale due to later steps */
inline int GetTargetStageIDInState(const State& s, int step_id) {
  int stage_inc = 0;

  for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) {
    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      if (s->transform_steps[i]->stage_id <= s->transform_steps[step_id]->stage_id + stage_inc)
        stage_inc++;
    }
  }
  return s->transform_steps[step_id]->stage_id + stage_inc;
}
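// For example, a later CacheReadStep/CacheWriteStep/RfactorStep that inserts a stage at or before
// the recorded position shifts the target stage back by one; stage_inc accumulates those shifts.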

/*! \brief Get all split steps for one stage. */
inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) {
  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
    if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
      if (stage_id == ps->stage_id) {
        split_step_ids->push_back(i);
      }
    }

    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      if (stage_id > s->transform_steps[i]->stage_id) {
        stage_id--;
      }
    }
  }
}

/*! \brief Fuse all reduction iterators. */
inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter,
                                       Array<Iterator>* space_iters,
                                       Array<Iterator>* reduce_iters) {
  space_iters->clear();
  reduce_iters->clear();

  for (const auto& iter : state->stages[stage_id]->iters) {
    if (iter->iter_kind == IteratorKind::kSpatial) {
      space_iters->push_back(iter);
    } else if (iter->iter_kind == IteratorKind::kReduction) {
      reduce_iters->push_back(iter);
    }
  }

  ICHECK(!reduce_iters->empty());
  State tmp_s = state;
  if (reduce_iters->size() > 1) {
    *fused_iter = tmp_s.fuse(stage_id, *reduce_iters);
  } else {
    *fused_iter = (*reduce_iters)[0];
  }
  return tmp_s;
}
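// Typical usage (sketch): callers pass scratch containers and take the returned state, e.g.
//   Array<Iterator> space_iters, reduce_iters;
//   Iterator fused;
//   state = FuseAllReductionIterators(state, stage_id, &fused, &space_iters, &reduce_iters);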

/*! \brief Fuse all outer level space iterators. */
inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterator* fused_iter) {
  std::vector<Iterator> to_fuse;
  for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); ++iter_id) {
    const auto& it = state->stages[stage_id]->iters[iter_id];
    // Stop at reduce iterator or annotated iterator
    if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
      break;
    }
    // Stop at compute_at attach point
    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id - 1))) {
      break;
    }
    to_fuse.push_back(it);
  }

  State tmp_s = state;
  if (to_fuse.size() == 1) {
    *fused_iter = to_fuse[0];
  } else {
    *fused_iter = tmp_s.fuse(stage_id, to_fuse);
  }
  return tmp_s;
}

/*! \brief Random sample states. */
inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen,
                                       size_t out_size) {
  Array<State> out_states;
  for (size_t i = 0; i < out_size; i++) {
    out_states.push_back(in_states[(*random_gen)() % in_states.size()]);
  }
  return out_states;
}
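// Note that the sampling is done with replacement, so the same state can appear multiple times
// in the output, especially when out_size is close to or larger than in_states.size().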

/*! \brief Compute prefix-sum probability based on the given weights */
inline void ComputePrefixSumProb(const std::vector<float>& weights,
                                 std::vector<double>* prefix_sum_probs) {
  // Compute selection probabilities.
  float sum = 0.0;
  prefix_sum_probs->resize(weights.size());
  for (size_t i = 0; i < weights.size(); ++i) {
    sum += std::max(weights[i], 0.0f);
    (*prefix_sum_probs)[i] = sum;
  }
  for (size_t i = 0; i < weights.size(); ++i) {
    (*prefix_sum_probs)[i] /= sum;
  }
}
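// For example, weights = {1.0, 3.0} would produce prefix_sum_probs = {0.25, 1.0}: index 0 is
// selected with probability 0.25 and index 1 with probability 0.75.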

/*! \brief Randomly choose an index according to a prefix sum probability. */
inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) {
  std::uniform_real_distribution<> dis(0.0, 1.0);
  double x = dis(*random_gen);

  ICHECK(!prefix_sum_probs.empty());

  return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) -
         prefix_sum_probs.begin();
}
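// Continuing the example above, a draw of x = 0.6 falls in the interval [0.25, 1.0], and
// std::lower_bound maps it to index 1.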

/*! \brief Print a title */
inline void PrintTitle(const std::string& title, int verbose) {
  StdCout(verbose) << Chars('-', 70) << "\n"
                   << Chars('-', 30) << " [ " << title << " ]\n"
                   << Chars('-', 70) << std::endl;
}

/*!
 * \brief Enumerate all possible factorization schemes for splitting an axis.
 * \note This class memoizes the results for reuse.
 */
class SplitFactorizationMemo {
 public:
  using QueryKey = std::tuple<int, int, int>;

  const Array<Array<Integer>>& GetFactorizationSchemes(int extent, int n_lengths,
                                                       int max_innermost_factor);
  const std::vector<int>& GetFactors(int n);

 private:
  void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);

  std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;

  int n_lengths_;
  Array<Integer> tmp_stack_;
  Array<Array<Integer>>* results_;
  std::unordered_map<int, std::vector<int>> factor_memory_;
};
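// Each returned scheme is one tuple of n_lengths tile sizes for the axis; results are cached in
// memory_, presumably keyed by (extent, n_lengths, max_innermost_factor), so repeated queries for
// the same split configuration reuse earlier enumerations.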

/*! \brief Get the indices of the SplitSteps that operate on spatial iterators. */
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);

/*! \brief Get the possible compute locations for a stage. */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
                                                              const State& state, int stage_id);

// Apply a multi-level tiling structure according to a string format,
// where "S" stands for a space level and "R" stands for a reduction level.
// For example, if the format is "SSRSRS", then we will
// use the tiling structure: space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3.
// If we apply "SSRSRS" to matrix multiplication with
// space iterators i and j and reduce iterator k,
// the tiling structure is: i0, j0, i1, j1, k0, i2, j2, k1, i3, j3.
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                         std::vector<int>* spatial_split_step_ids = nullptr);

// Apply tiling structure: space, space, space, ..., with tile sizes from another SplitStep
State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
                   int n_split);

// Prune invalid states and return the results in-place.
void PruneInvalidState(const SearchTask& task, Array<State>* states);

}  // namespace auto_scheduler
}  // namespace tvm

#endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_