1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <utility>

#include "../utils.h"

21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | /*! \brief The gradient based task scheduler. */ |
25 | class GradientBasedNode final : public TaskSchedulerNode { |
26 | public: |
27 | double alpha; |
28 | int window_size; |
29 | support::LinearCongruentialEngine::TRandState rand_state; |
30 | |
31 | int round_robin_rounds_; |
32 | std::vector<std::vector<double>> best_latency_history_; |
33 | |
34 | void VisitAttrs(tvm::AttrVisitor* v) { |
35 | TaskSchedulerNode::VisitAttrs(v); |
36 | v->Visit("alpha" , &alpha); |
37 | v->Visit("window_size" , &window_size); |
38 | // `rand_state` is not visited. |
39 | // `num_rounds_already_` is not visited. |
40 | // `best_latency_history_` is not visited. |
41 | } |
42 | |
43 | static constexpr const char* _type_key = "meta_schedule.GradientBased" ; |
44 | TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode); |
45 | |
46 | public: |
47 | void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global, |
48 | int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, |
49 | Array<MeasureCallback> measure_callbacks, Optional<Database> database, |
50 | Optional<CostModel> cost_model) final { |
51 | int n_tasks = tasks.size(); |
52 | round_robin_rounds_ = 0; |
53 | best_latency_history_.resize(n_tasks, std::vector<double>()); |
54 | TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task, |
55 | num_trials_per_iter, builder, runner, measure_callbacks, database, |
56 | cost_model); |
57 | } |
58 | |
59 | int NextTaskId() final { |
60 | int n_tasks = this->tasks_.size(); |
61 | // Step 1. Check if it's in round robin mode. |
62 | if (round_robin_rounds_ == 0) { |
63 | TVM_PY_LOG_CLEAR_SCREEN(this->logger); |
64 | this->PrintTuningStatistics(); |
65 | } |
66 | if (round_robin_rounds_ < n_tasks) { |
67 | return round_robin_rounds_++; |
68 | } |
69 | if (round_robin_rounds_ == n_tasks) { |
70 | for (int i = 0; i < n_tasks; ++i) { |
71 | if (this->tasks_[i]->runner_futures.defined()) { |
72 | this->JoinRunningTask(i); |
73 | } |
74 | } |
75 | ++round_robin_rounds_; |
76 | } |
77 | // Step 2. Collect the tasks that are not terminated yet |
78 | std::vector<int> tasks_alive; |
79 | { |
80 | tasks_alive.reserve(n_tasks); |
81 | for (int i = 0; i < n_tasks; ++i) { |
82 | this->TouchTask(i); |
83 | if (!this->tasks_[i]->is_terminated) { |
84 | tasks_alive.push_back(i); |
85 | } |
86 | } |
87 | if (tasks_alive.empty()) { |
88 | return -1; |
89 | } |
90 | } |
91 | // Step 3. Calculate the gradient of each task alive |
92 | std::vector<double> grad; |
93 | grad.reserve(n_tasks); |
94 | for (int task_id : tasks_alive) { |
95 | const std::vector<double>& best_latency = this->best_latency_history_.at(task_id); |
96 | int n = best_latency.size(); |
97 | double task_weight = this->tasks_[task_id]->task_weight; |
98 | int w = this->window_size; |
99 | if (n > 0 && best_latency[n - 1] < 1e9) { |
100 | double best = best_latency[n - 1]; |
101 | double g1 = (n >= 1 + w) ? (best_latency[n - 1 - w] - best) / w : 0.0; |
102 | double g2 = best / n; |
103 | double g = alpha * g1 + (1 - alpha) * g2; |
104 | grad.push_back(g * task_weight); |
105 | } else { |
106 | // If the best time cost is unavailable, it means some task is not valid. Skip it. |
107 | grad.push_back(-1e9); |
108 | } |
109 | } |
110 | // Step 4. Select the task with the largest gradient |
111 | auto max_grad = std::max_element(grad.begin(), grad.end()); |
112 | auto min_grad = std::min_element(grad.begin(), grad.end()); |
113 | int task_id = -1; |
114 | if (*max_grad == *min_grad) { |
115 | task_id = tasks_alive[tir::SampleInt(&this->rand_state, 0, tasks_alive.size())]; |
116 | } else { |
117 | task_id = tasks_alive[std::distance(grad.begin(), max_grad)]; |
118 | } |
119 | if (this->tasks_[task_id]->runner_futures.defined()) { |
120 | JoinRunningTask(task_id); |
121 | } |
122 | return task_id; |
123 | } |
124 | |
125 | Array<RunnerResult> JoinRunningTask(int task_id) final { |
126 | Array<RunnerResult> results = TaskSchedulerNode::JoinRunningTask(task_id); |
127 | TaskRecordNode* task = this->tasks_[task_id].get(); |
128 | if (task->latency_ms.size() > 0) { |
129 | this->best_latency_history_.at(task_id).push_back( |
130 | *std::min_element(task->latency_ms.begin(), // |
131 | task->latency_ms.end())); |
132 | } |
133 | return results; |
134 | } |
135 | }; |
136 | |
137 | TaskScheduler TaskScheduler::GradientBased(PackedFunc logger, double alpha, int window_size, |
138 | support::LinearCongruentialEngine::TRandState seed) { |
139 | ObjectPtr<GradientBasedNode> n = make_object<GradientBasedNode>(); |
140 | n->logger = logger; |
141 | n->alpha = alpha; |
142 | n->window_size = window_size; |
143 | n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(seed); |
144 | return TaskScheduler(n); |
145 | } |
146 | |
TVM_REGISTER_NODE_TYPE(GradientBasedNode);
// Expose the constructor to the FFI as `meta_schedule.TaskSchedulerGradientBased`.
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased")
    .set_body_typed(TaskScheduler::GradientBased);
150 | |
151 | } // namespace meta_schedule |
152 | } // namespace tvm |
153 | |