1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_policy/empty_policy.cc |
22 | * \brief A simple example of the search policy which always returns the initial naive schedule |
23 | * (state). |
24 | */ |
25 | |
26 | #include "empty_policy.h" |
27 | |
28 | #include <tvm/auto_scheduler/measure.h> |
29 | #include <tvm/runtime/registry.h> |
30 | |
31 | #include <utility> |
32 | |
33 | #include "utils.h" |
34 | |
35 | namespace tvm { |
36 | namespace auto_scheduler { |
37 | |
38 | TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); |
39 | |
40 | EmptyPolicy::EmptyPolicy(SearchTask task, Optional<Array<SearchCallback>> init_search_callbacks) { |
41 | auto node = make_object<EmptyPolicyNode>(); |
42 | node->search_task = task; |
43 | |
44 | // Run init_search_callbacks before the search process |
45 | // This Interface is usually used to set some init status |
46 | if (init_search_callbacks) { |
47 | node->RunCallbacks(init_search_callbacks.value()); |
48 | } |
49 | |
50 | data_ = std::move(node); |
51 | } |
52 | |
53 | State EmptyPolicyNode::Search(int num_measure_trials, int early_stopping, |
54 | int num_measures_per_round, ProgramMeasurer measurer) { |
55 | // Basic design principe: `SearchOneRound()` several times to get candidate states, |
56 | // measure them and return the best one |
57 | // Measure is disabled if num_measure_trials <= 1 |
58 | if (num_measure_trials <= 1) { |
59 | const auto& res = SearchOneRound(); |
60 | ICHECK_GT(res.size(), 0); |
61 | |
62 | return res[0]; |
63 | } else { |
64 | Array<MeasureInput> inputs; |
65 | Array<MeasureResult> results; |
66 | |
67 | measurer->Reset(); |
68 | int ct = 0; |
69 | // In each round, we call SearchOneRound to get several candidate states, |
70 | // then use ProgramMeasurer to measure their performance. |
71 | while (ct < num_measure_trials) { |
72 | const auto& res = SearchOneRound(); |
73 | ct += res.size(); |
74 | // Build MeasureInputs for measuring |
75 | inputs.clear(); |
76 | for (const auto& state : res) { |
77 | inputs.push_back(MeasureInput(search_task, state)); |
78 | } |
79 | // Perform measurement. |
80 | // ProgramMeasurer will record the state with best performance during measure process |
81 | results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs); |
82 | } |
83 | |
84 | // Return a state with best measured performance |
85 | return measurer->best_state[search_task->workload_key]; |
86 | } |
87 | } |
88 | |
89 | std::pair<Array<MeasureInput>, Array<MeasureResult>> EmptyPolicyNode::ContinueSearchOneRound( |
90 | int num_measure, ProgramMeasurer measurer) { |
91 | Array<State> best_states; |
92 | Array<MeasureInput> inputs; |
93 | Array<MeasureResult> results; |
94 | |
95 | // Search one round to get promising states |
96 | PrintTitle("Search" , verbose); |
97 | best_states = SearchOneRound(); |
98 | |
99 | // Measure these states |
100 | PrintTitle("Measure" , verbose); |
101 | for (const auto& state : best_states) { |
102 | inputs.push_back(MeasureInput(search_task, state)); |
103 | } |
104 | results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs); |
105 | |
106 | return std::make_pair(std::move(inputs), std::move(results)); |
107 | } |
108 | |
109 | // As an example policy, EmptyPolicy always returns a init state |
110 | Array<State> EmptyPolicyNode::SearchOneRound() { |
111 | Array<State> res; |
112 | |
113 | // Simply return the initial naive schedule (state). |
114 | res.push_back(search_task->compute_dag->init_state); |
115 | |
116 | return res; |
117 | } |
118 | |
119 | TVM_REGISTER_GLOBAL("auto_scheduler.EmptyPolicy" ) |
120 | .set_body_typed([](SearchTask task, Optional<Array<SearchCallback>> init_search_callbacks) { |
121 | return EmptyPolicy(task, init_search_callbacks); |
122 | }); |
123 | |
124 | } // namespace auto_scheduler |
125 | } // namespace tvm |
126 | |