/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief The gradient based task scheduler. */
class GradientBasedNode final : public TaskSchedulerNode {
 public:
  /*! \brief The weight blending the windowed and overall improvement rates in the gradient. */
  double alpha;
  /*! \brief The size of the sliding window used to estimate the recent improvement rate. */
  int window_size;
  /*! \brief The random state used to break ties among tasks with equal gradients. */
  support::LinearCongruentialEngine::TRandState rand_state;

  /*! \brief The number of tasks already dispatched in the initial round-robin pass. */
  int round_robin_rounds_;
  /*! \brief The per-task history of the best latency (in ms) observed after each join. */
  std::vector<std::vector<double>> best_latency_history_;

  void VisitAttrs(tvm::AttrVisitor* v) {
    TaskSchedulerNode::VisitAttrs(v);
    v->Visit("alpha", &alpha);
    v->Visit("window_size", &window_size);
    // `rand_state` is not visited.
    // `round_robin_rounds_` is not visited.
    // `best_latency_history_` is not visited.
  }

  static constexpr const char* _type_key = "meta_schedule.GradientBased";
  TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode);

 public:
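  /*!
   * \brief Reset the round-robin counter and the per-task best-latency history,
   *  then delegate the main tuning loop to TaskSchedulerNode::Tune.
   */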
  void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global,
            int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner,
            Array<MeasureCallback> measure_callbacks, Optional<Database> database,
            Optional<CostModel> cost_model) final {
    int n_tasks = tasks.size();
    round_robin_rounds_ = 0;
    best_latency_history_.resize(n_tasks, std::vector<double>());
    TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task,
                            num_trials_per_iter, builder, runner, measure_callbacks, database,
                            cost_model);
  }

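  /*!
   * \brief Pick the next task to tune: first dispatch every task once in round-robin order,
   *  then repeatedly choose the alive task whose best-latency history has the largest
   *  estimated gradient, breaking ties randomly.
   */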
  int NextTaskId() final {
    int n_tasks = this->tasks_.size();
    // Step 1. Check if it's in round robin mode.
    if (round_robin_rounds_ == 0) {
      TVM_PY_LOG_CLEAR_SCREEN(this->logger);
      this->PrintTuningStatistics();
    }
    if (round_robin_rounds_ < n_tasks) {
      return round_robin_rounds_++;
    }
    if (round_robin_rounds_ == n_tasks) {
      for (int i = 0; i < n_tasks; ++i) {
        if (this->tasks_[i]->runner_futures.defined()) {
          this->JoinRunningTask(i);
        }
      }
      ++round_robin_rounds_;
    }
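    // At this point the initial round-robin pass is complete: every task has been dispatched
    // once and its pending results have been joined, so the gradient computed below is defined
    // for every task that produced a valid measurement.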
    // Step 2. Collect the tasks that are not terminated yet
    std::vector<int> tasks_alive;
    {
      tasks_alive.reserve(n_tasks);
      for (int i = 0; i < n_tasks; ++i) {
        this->TouchTask(i);
        if (!this->tasks_[i]->is_terminated) {
          tasks_alive.push_back(i);
        }
      }
      if (tasks_alive.empty()) {
        return -1;
      }
    }
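    // Step 3 below scores every alive task by an estimated gradient of its best latency:
    //   g1 = (best[n - 1 - w] - best[n - 1]) / w   // improvement rate over the last `w` joins
    //   g2 = best[n - 1] / n                       // average improvement rate since the start
    //   g  = (alpha * g1 + (1 - alpha) * g2) * task_weight
    // Illustrative numbers (not defaults): with w = 3 and a history whose last entries are
    // 13.0, 12.0, 11.0, 10.0 ms after n = 10 joins, g1 = (13.0 - 10.0) / 3 = 1.0 and
    // g2 = 10.0 / 10 = 1.0, so g = task_weight regardless of alpha; a task whose latency has
    // plateaued gets a smaller g1 and hence a lower priority.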
    // Step 3. Calculate the gradient of each task alive
    std::vector<double> grad;
    grad.reserve(n_tasks);
    for (int task_id : tasks_alive) {
      const std::vector<double>& best_latency = this->best_latency_history_.at(task_id);
      int n = best_latency.size();
      double task_weight = this->tasks_[task_id]->task_weight;
      int w = this->window_size;
      if (n > 0 && best_latency[n - 1] < 1e9) {
        double best = best_latency[n - 1];
        double g1 = (n >= 1 + w) ? (best_latency[n - 1 - w] - best) / w : 0.0;
        double g2 = best / n;
        double g = alpha * g1 + (1 - alpha) * g2;
        grad.push_back(g * task_weight);
      } else {
        // The task has no valid measurement yet, so assign the lowest gradient to deprioritize it.
        grad.push_back(-1e9);
      }
    }
    // Step 4. Select the task with the largest gradient
    auto max_grad = std::max_element(grad.begin(), grad.end());
    auto min_grad = std::min_element(grad.begin(), grad.end());
    int task_id = -1;
    if (*max_grad == *min_grad) {
      task_id = tasks_alive[tir::SampleInt(&this->rand_state, 0, tasks_alive.size())];
    } else {
      task_id = tasks_alive[std::distance(grad.begin(), max_grad)];
    }
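    // Join any pending results of the selected task so that its best-latency history is
    // up to date before it is handed out for another round of trials.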
    if (this->tasks_[task_id]->runner_futures.defined()) {
      JoinRunningTask(task_id);
    }
    return task_id;
  }

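  /*!
   * \brief Join the task's pending runner futures, then append the best latency observed in
   *  this batch of results to the task's best-latency history.
   */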
  Array<RunnerResult> JoinRunningTask(int task_id) final {
    Array<RunnerResult> results = TaskSchedulerNode::JoinRunningTask(task_id);
    TaskRecordNode* task = this->tasks_[task_id].get();
    if (task->latency_ms.size() > 0) {
      this->best_latency_history_.at(task_id).push_back(
          *std::min_element(task->latency_ms.begin(),  //
                            task->latency_ms.end()));
    }
    return results;
  }
};

TaskScheduler TaskScheduler::GradientBased(PackedFunc logger, double alpha, int window_size,
                                           support::LinearCongruentialEngine::TRandState seed) {
  ObjectPtr<GradientBasedNode> n = make_object<GradientBasedNode>();
  n->logger = logger;
  n->alpha = alpha;
  n->window_size = window_size;
  n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(seed);
  return TaskScheduler(n);
}
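
// A minimal construction sketch (illustrative values; `logger` is assumed to be a PackedFunc
// supplied by the caller):
//   TaskScheduler scheduler = TaskScheduler::GradientBased(
//       /*logger=*/logger, /*alpha=*/0.2, /*window_size=*/3, /*seed=*/-1);
// The same constructor is exposed to the FFI through the global registration below.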

TVM_REGISTER_NODE_TYPE(GradientBasedNode);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased")
    .set_body_typed(TaskScheduler::GradientBased);

}  // namespace meta_schedule
}  // namespace tvm