1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/cost_model.cc
22 * \brief Cost models that estimate the performance of programs
23 */
24
25#include <tvm/auto_scheduler/cost_model.h>
26
27namespace tvm {
28namespace auto_scheduler {
29
// Register the node classes with TVM's object system so they can be
// reflected, type-checked, and passed across the FFI boundary.
TVM_REGISTER_OBJECT_TYPE(CostModelNode);
TVM_REGISTER_OBJECT_TYPE(RandomModelNode);
TVM_REGISTER_OBJECT_TYPE(PythonBasedModelNode);
33
RandomModel::RandomModel() {
  ObjectPtr<RandomModelNode> node = make_object<RandomModelNode>();
  // Look up the helper that fills a float buffer with random numbers. It must
  // have been registered (by the Python side) before a RandomModel is built.
  const auto* f = runtime::Registry::Get("auto_scheduler.cost_model.random_fill_float");
  ICHECK(f != nullptr);
  // Keep a pointer into the registry-owned PackedFunc, reinterpreted to the
  // typed signature void(size_t, void*). NOTE(review): this cast assumes
  // TypedPackedFunc is layout-compatible with PackedFunc — confirm against
  // the runtime headers before touching this line.
  node->random_number_func = reinterpret_cast<const TypedPackedFunc<void(size_t, void*)>*>(f);
  data_ = std::move(node);
}
41
// A random model has no parameters to learn, so measurement feedback is
// intentionally ignored.
void RandomModelNode::Update(const Array<MeasureInput>& inputs,
                             const Array<MeasureResult>& results) {}
44
45void RandomModelNode::Predict(const SearchTask& task, const Array<State>& states,
46 std::vector<float>* scores) {
47 scores->resize(states.size());
48 (*random_number_func)(states.size(), static_cast<void*>(scores->data()));
49}
50
51PythonBasedModel::PythonBasedModel(PackedFunc update_func, PackedFunc predict_func,
52 PackedFunc predict_stage_func) {
53 auto node = make_object<PythonBasedModelNode>();
54 node->update_func = std::move(update_func);
55 node->predict_func = std::move(predict_func);
56 node->predict_stage_func = std::move(predict_stage_func);
57 data_ = std::move(node);
58}
59
// Forward measurement records to the Python-side update callback for training.
void PythonBasedModelNode::Update(const Array<MeasureInput>& inputs,
                                  const Array<MeasureResult>& results) {
  update_func(inputs, results);
}
64
65void PythonBasedModelNode::Predict(const SearchTask& task, const Array<State>& states,
66 std::vector<float>* scores) {
67 scores->resize(states.size());
68 predict_func(task, states, static_cast<void*>(scores->data()));
69}
70
/*
 * \brief Score each state and (when available) each of its stages, using the
 * Python-side `predict_stage_func` callback. The callback writes everything into
 * one flat float buffer (format documented below), which is unpacked here into
 * `state_scores` and `stage_scores`.
 */
void PythonBasedModelNode::PredictStages(const SearchTask& task, const Array<State>& states,
                                         std::vector<float>* state_scores,
                                         std::vector<std::vector<float>>* stage_scores) {
  size_t n_states = states.size();
  size_t n_stages = task->compute_dag->init_state->stages.size();
  std::vector<float> flatten_scores;
  // Allocate sufficient spaces.
  // NOTE(review): assumes the packed payload never exceeds n_states * n_stages * 2
  // floats — confirm this bound against the Python side of the protocol.
  flatten_scores.resize(n_states * n_stages * 2);
  predict_stage_func(task, states, static_cast<void*>(flatten_scores.data()));

  /* For faster data copy between c++ and python, the python part returns scores in a
   * single flatten array using a packed format. The c++ part then unpacks the flatten array.
   *
   * The packed format is:
   * {
   *   float scores[N];                 // scores[i] is the score for states[i].
   *   int   n_stage_0;                 // the number of stages in states[0]
   *   float stage_scores_0[n_stage_0]; // the scores for all stages in states[0]
   *   int   n_stage_1;                 // the number of stages in states[1]
   *   float stage_scores_1[n_stage_1]; // the scores for all stages in states[1]
   *   ...
   *   int   n_stage_i;                 // the number of stages in states[i]
   *   float stage_scores_i[n_stage_i]; // the scores for all stages in states[i]
   *   ...                              // until i == N - 1
   * }
   * To implement this format, we also store int as float, so we can store all numbers
   * into a single float array.
   */

  // Unpack flatten scores.
  state_scores->clear();
  stage_scores->clear();

  // Score of each states.
  for (size_t i = 0; i < n_states; ++i) {
    state_scores->push_back(flatten_scores[i]);
  }

  // Score of each stage in each states.
  size_t idx = n_states;  // cursor into the packed region following the N state scores
  for (size_t i = 0; i < n_states; ++i) {
    // NOTE(review): LE permits idx == size, in which case the flatten_scores[idx++]
    // read below is out of bounds; ICHECK_LT looks like the intended check — confirm.
    ICHECK_LE(idx, flatten_scores.size());

    // Number of scored stages of this state.
    int s_length = static_cast<int>(flatten_scores[idx++]);

    if (s_length > 0) {
      std::vector<float> scores;
      int offset = 0;

      if ((*state_scores)[i] > -INFINITY) {
        // The state score is valid: copy the scored stages, assigning 0 to
        // placeholder and inlined stages (which the cost model did not score).
        // A score of -inf means the state failed to be lowered; in that case we
        // only advance idx past its payload without copying.
        // NOTE(review): in the -inf case nothing is pushed to stage_scores, so
        // stage_scores can end up shorter than states — confirm callers expect this.
        for (const Stage& stage : states[i]->stages) {
          if (stage->op_type == StageKind::kPlaceholder) {
            scores.push_back(0);
            continue;
          }
          if (stage->compute_at == ComputeAtKind::kInlined) {
            scores.push_back(0);
            continue;
          }
          scores.push_back(flatten_scores[idx + offset]);
          offset++;
        }
        ICHECK_EQ(offset, s_length);
        stage_scores->push_back(std::move(scores));
      }
      idx += s_length;
    } else {
      // Cost model does not provide any stage score details.
      stage_scores->push_back({});
    }
  }
}
147
// Expose the cost-model constructors to the Python side via the global registry.
TVM_REGISTER_GLOBAL("auto_scheduler.RandomModel").set_body_typed([]() { return RandomModel(); });

TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedModel")
    .set_body_typed([](PackedFunc update_func, PackedFunc predict_func,
                       PackedFunc predict_stage_func) {
      return PythonBasedModel(update_func, predict_func, predict_stage_func);
    });
155
// FFI entry point: train a cost model on measured (input, result) records.
TVM_REGISTER_GLOBAL("auto_scheduler.CostModelUpdate")
    .set_body_typed([](CostModel model, Array<MeasureInput> inputs, Array<MeasureResult> results) {
      model->Update(inputs, results);
    });
160
161TVM_REGISTER_GLOBAL("auto_scheduler.CostModelPredict")
162 .set_body_typed([](CostModel model, SearchTask task, Array<State> states) {
163 std::vector<float> scores;
164 model->Predict(task, states, &scores);
165 Array<FloatImm> ret;
166 for (auto x : scores) {
167 ret.push_back(FloatImm(DataType::Float(32), x));
168 }
169 return ret;
170 });
171
172} // namespace auto_scheduler
173} // namespace tvm
174