1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_policy/search_policy.cc |
22 | * \brief The base class of search policies. |
23 | */ |
24 | |
25 | #include <tvm/auto_scheduler/measure_record.h> |
26 | #include <tvm/auto_scheduler/search_policy.h> |
27 | #include <tvm/runtime/registry.h> |
28 | |
29 | #include "utils.h" |
30 | |
31 | namespace tvm { |
32 | namespace auto_scheduler { |
33 | |
// Register the node types used by search policies via TVM_REGISTER_OBJECT_TYPE:
// the search-callback base, the search-policy base, and the preload callback defined below.
TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode);
TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode);
37 | |
38 | void SearchPolicyNode::PreloadMeasuredStates(const String& log_file) { |
39 | RecordReader reader = RecordReader(log_file); |
40 | const auto& res = reader->ReadLines(-1); |
41 | size_t log_size = res.first.size(); |
42 | ICHECK_EQ(log_size, res.second.size()); |
43 | if (log_size) { |
44 | Array<State> measured_states; |
45 | std::vector<float> measured_throughputs; |
46 | for (size_t i = 0; i < log_size; i++) { |
47 | const auto& inp = res.first[i]; |
48 | if (inp->task->workload_key == search_task->workload_key && |
49 | inp->task->target->kind->name.compare(search_task->target->kind->name) == 0) { |
50 | State state = search_task->compute_dag->init_state; |
51 | auto pstate = state.CopyOnWrite(); |
52 | pstate->transform_steps = inp->state->transform_steps; |
53 | for (const auto& step : pstate->transform_steps) { |
54 | StepApplyToState(step, &state, search_task->compute_dag); |
55 | } |
56 | measured_states.push_back(std::move(state)); |
57 | measured_throughputs.push_back( |
58 | res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0); |
59 | } |
60 | } |
61 | // We can assume the recorded states will all be valid after infer bound |
62 | measured_states = search_task->compute_dag.InferBound(measured_states); |
63 | for (size_t i = 0; i < measured_states.size(); i++) { |
64 | auto& state = measured_states[i]; |
65 | const auto& state_str = state.ToStr(); |
66 | if (!measured_states_set_.count(state_str)) { |
67 | measured_states_set_.insert(state_str); |
68 | if (measured_throughputs[i] != 0.0) { |
69 | measured_states_vector_.emplace_back(std::move(state)); |
70 | measured_states_throughputs_.emplace_back(measured_throughputs[i]); |
71 | } |
72 | } |
73 | } |
74 | |
75 | StdCout(verbose) << "SearchPolicy: Loaded " << measured_states_set_.size() |
76 | << " measurement records from " << log_file << " for " |
77 | << search_task->workload_key << std::endl; |
78 | } else { |
79 | StdCout(verbose) << "SearchPolicy: No measurement records found in " << log_file << " for " |
80 | << search_task->workload_key << std::endl; |
81 | } |
82 | } |
83 | |
84 | void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) { |
85 | for (const auto& callback : callbacks) { |
86 | callback->Callback(this); |
87 | } |
88 | } |
89 | |
90 | PreloadMeasuredStates::PreloadMeasuredStates(String filename) { |
91 | auto node = make_object<PreloadMeasuredStatesNode>(); |
92 | node->filename = std::move(filename); |
93 | data_ = std::move(node); |
94 | } |
95 | |
96 | void PreloadMeasuredStatesNode::Callback(SearchPolicyNode* policy) { |
97 | policy->PreloadMeasuredStates(filename); |
98 | } |
99 | |
100 | TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks" ) |
101 | .set_body_typed([](SearchPolicy policy, Optional<Array<SearchCallback>> callbacks) { |
102 | if (callbacks) { |
103 | policy->RunCallbacks(callbacks.value()); |
104 | } |
105 | }); |
106 | |
107 | TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyContinueSearchOneRound" ) |
108 | .set_body_typed([](SearchPolicy policy, int num_measure, ProgramMeasurer measurer) { |
109 | auto [inputs, results] = policy->ContinueSearchOneRound(num_measure, measurer); |
110 | return Array<ObjectRef>{inputs, results}; |
111 | }); |
112 | |
113 | TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetVerbose" ) |
114 | .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; }); |
115 | |
116 | TVM_REGISTER_GLOBAL("auto_scheduler.PreloadMeasuredStates" ).set_body_typed([](String filename) { |
117 | return PreloadMeasuredStates(filename); |
118 | }); |
119 | |
120 | } // namespace auto_scheduler |
121 | } // namespace tvm |
122 | |