1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/search_policy/empty_policy.cc
22 * \brief A simple example of the search policy which always returns the initial naive schedule
23 * (state).
24 */
25
26#include "empty_policy.h"
27
28#include <tvm/auto_scheduler/measure.h>
29#include <tvm/runtime/registry.h>
30
31#include <utility>
32
33#include "utils.h"
34
35namespace tvm {
36namespace auto_scheduler {
37
// Register EmptyPolicyNode as a TVM node type.
TVM_REGISTER_NODE_TYPE(EmptyPolicyNode);
39
40EmptyPolicy::EmptyPolicy(SearchTask task, Optional<Array<SearchCallback>> init_search_callbacks) {
41 auto node = make_object<EmptyPolicyNode>();
42 node->search_task = task;
43
44 // Run init_search_callbacks before the search process
45 // This Interface is usually used to set some init status
46 if (init_search_callbacks) {
47 node->RunCallbacks(init_search_callbacks.value());
48 }
49
50 data_ = std::move(node);
51}
52
53State EmptyPolicyNode::Search(int num_measure_trials, int early_stopping,
54 int num_measures_per_round, ProgramMeasurer measurer) {
55 // Basic design principe: `SearchOneRound()` several times to get candidate states,
56 // measure them and return the best one
57 // Measure is disabled if num_measure_trials <= 1
58 if (num_measure_trials <= 1) {
59 const auto& res = SearchOneRound();
60 ICHECK_GT(res.size(), 0);
61
62 return res[0];
63 } else {
64 Array<MeasureInput> inputs;
65 Array<MeasureResult> results;
66
67 measurer->Reset();
68 int ct = 0;
69 // In each round, we call SearchOneRound to get several candidate states,
70 // then use ProgramMeasurer to measure their performance.
71 while (ct < num_measure_trials) {
72 const auto& res = SearchOneRound();
73 ct += res.size();
74 // Build MeasureInputs for measuring
75 inputs.clear();
76 for (const auto& state : res) {
77 inputs.push_back(MeasureInput(search_task, state));
78 }
79 // Perform measurement.
80 // ProgramMeasurer will record the state with best performance during measure process
81 results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
82 }
83
84 // Return a state with best measured performance
85 return measurer->best_state[search_task->workload_key];
86 }
87}
88
89std::pair<Array<MeasureInput>, Array<MeasureResult>> EmptyPolicyNode::ContinueSearchOneRound(
90 int num_measure, ProgramMeasurer measurer) {
91 Array<State> best_states;
92 Array<MeasureInput> inputs;
93 Array<MeasureResult> results;
94
95 // Search one round to get promising states
96 PrintTitle("Search", verbose);
97 best_states = SearchOneRound();
98
99 // Measure these states
100 PrintTitle("Measure", verbose);
101 for (const auto& state : best_states) {
102 inputs.push_back(MeasureInput(search_task, state));
103 }
104 results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
105
106 return std::make_pair(std::move(inputs), std::move(results));
107}
108
109// As an example policy, EmptyPolicy always returns a init state
110Array<State> EmptyPolicyNode::SearchOneRound() {
111 Array<State> res;
112
113 // Simply return the initial naive schedule (state).
114 res.push_back(search_task->compute_dag->init_state);
115
116 return res;
117}
118
119TVM_REGISTER_GLOBAL("auto_scheduler.EmptyPolicy")
120 .set_body_typed([](SearchTask task, Optional<Array<SearchCallback>> init_search_callbacks) {
121 return EmptyPolicy(task, init_search_callbacks);
122 });
123
124} // namespace auto_scheduler
125} // namespace tvm
126