1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24/*! \brief The round-robin style task scheduler. */
25class RoundRobinNode final : public TaskSchedulerNode {
26 public:
27 /*! \brief The current task id processed. */
28 int task_id = -1;
29
30 void VisitAttrs(tvm::AttrVisitor* v) {
31 TaskSchedulerNode::VisitAttrs(v);
32 v->Visit("task_id", &task_id);
33 }
34
35 static constexpr const char* _type_key = "meta_schedule.RoundRobin";
36 TVM_DECLARE_FINAL_OBJECT_INFO(RoundRobinNode, TaskSchedulerNode);
37
38 protected:
39 int NextTaskId() final {
40 int n_tasks = this->tasks_.size();
41 for (int i = 0; i < n_tasks; ++i) {
42 this->TouchTask(i);
43 }
44 for (int i = 0; i < n_tasks; ++i) {
45 task_id = (task_id + 1) % n_tasks;
46 TaskRecordNode* task = this->tasks_[task_id].get();
47 if (!task->is_terminated) {
48 if (task->runner_futures.defined()) {
49 JoinRunningTask(task_id);
50 }
51 return task_id;
52 }
53 }
54 return -1;
55 }
56};
57
58TaskScheduler TaskScheduler::RoundRobin(PackedFunc logger) {
59 ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>();
60 n->logger = logger;
61 n->task_id = -1;
62 return TaskScheduler(n);
63}
64
65TVM_REGISTER_NODE_TYPE(RoundRobinNode);
66TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin")
67 .set_body_typed(TaskScheduler::RoundRobin);
68
69} // namespace meta_schedule
70} // namespace tvm
71