1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/search_policy/search_policy.cc
22 * \brief The base class of search policies.
23 */
24
25#include <tvm/auto_scheduler/measure_record.h>
26#include <tvm/auto_scheduler/search_policy.h>
27#include <tvm/runtime/registry.h>
28
29#include "utils.h"
30
31namespace tvm {
32namespace auto_scheduler {
33
// Register the node types implemented in this file with TVM's object type registry.
TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode);
TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode);
37
38void SearchPolicyNode::PreloadMeasuredStates(const String& log_file) {
39 RecordReader reader = RecordReader(log_file);
40 const auto& res = reader->ReadLines(-1);
41 size_t log_size = res.first.size();
42 ICHECK_EQ(log_size, res.second.size());
43 if (log_size) {
44 Array<State> measured_states;
45 std::vector<float> measured_throughputs;
46 for (size_t i = 0; i < log_size; i++) {
47 const auto& inp = res.first[i];
48 if (inp->task->workload_key == search_task->workload_key &&
49 inp->task->target->kind->name.compare(search_task->target->kind->name) == 0) {
50 State state = search_task->compute_dag->init_state;
51 auto pstate = state.CopyOnWrite();
52 pstate->transform_steps = inp->state->transform_steps;
53 for (const auto& step : pstate->transform_steps) {
54 StepApplyToState(step, &state, search_task->compute_dag);
55 }
56 measured_states.push_back(std::move(state));
57 measured_throughputs.push_back(
58 res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0);
59 }
60 }
61 // We can assume the recorded states will all be valid after infer bound
62 measured_states = search_task->compute_dag.InferBound(measured_states);
63 for (size_t i = 0; i < measured_states.size(); i++) {
64 auto& state = measured_states[i];
65 const auto& state_str = state.ToStr();
66 if (!measured_states_set_.count(state_str)) {
67 measured_states_set_.insert(state_str);
68 if (measured_throughputs[i] != 0.0) {
69 measured_states_vector_.emplace_back(std::move(state));
70 measured_states_throughputs_.emplace_back(measured_throughputs[i]);
71 }
72 }
73 }
74
75 StdCout(verbose) << "SearchPolicy: Loaded " << measured_states_set_.size()
76 << " measurement records from " << log_file << " for "
77 << search_task->workload_key << std::endl;
78 } else {
79 StdCout(verbose) << "SearchPolicy: No measurement records found in " << log_file << " for "
80 << search_task->workload_key << std::endl;
81 }
82}
83
84void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) {
85 for (const auto& callback : callbacks) {
86 callback->Callback(this);
87 }
88}
89
90PreloadMeasuredStates::PreloadMeasuredStates(String filename) {
91 auto node = make_object<PreloadMeasuredStatesNode>();
92 node->filename = std::move(filename);
93 data_ = std::move(node);
94}
95
96void PreloadMeasuredStatesNode::Callback(SearchPolicyNode* policy) {
97 policy->PreloadMeasuredStates(filename);
98}
99
100TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks")
101 .set_body_typed([](SearchPolicy policy, Optional<Array<SearchCallback>> callbacks) {
102 if (callbacks) {
103 policy->RunCallbacks(callbacks.value());
104 }
105 });
106
107TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyContinueSearchOneRound")
108 .set_body_typed([](SearchPolicy policy, int num_measure, ProgramMeasurer measurer) {
109 auto [inputs, results] = policy->ContinueSearchOneRound(num_measure, measurer);
110 return Array<ObjectRef>{inputs, results};
111 });
112
113TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetVerbose")
114 .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; });
115
116TVM_REGISTER_GLOBAL("auto_scheduler.PreloadMeasuredStates").set_body_typed([](String filename) {
117 return PreloadMeasuredStates(filename);
118});
119
120} // namespace auto_scheduler
121} // namespace tvm
122