1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24void PyCostModelNode::Load(const String& path) {
25 ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!";
26 f_load(path);
27}
28
29void PyCostModelNode::Save(const String& path) {
30 ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!";
31 f_save(path);
32}
33
34void PyCostModelNode::Update(const TuneContext& context, const Array<MeasureCandidate>& candidates,
35 const Array<RunnerResult>& results) {
36 ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!";
37 f_update(context, candidates, results);
38}
39
40std::vector<double> PyCostModelNode::Predict(const TuneContext& context,
41 const Array<MeasureCandidate>& candidates) {
42 ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!";
43 std::vector<double> result(candidates.size(), 0.0);
44 f_predict(context, candidates, result.data());
45 return result;
46}
47
48CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, //
49 PyCostModelNode::FSave f_save, //
50 PyCostModelNode::FUpdate f_update, //
51 PyCostModelNode::FPredict f_predict, //
52 PyCostModelNode::FAsString f_as_string) {
53 ObjectPtr<PyCostModelNode> n = make_object<PyCostModelNode>();
54 n->f_load = std::move(f_load);
55 n->f_save = std::move(f_save);
56 n->f_update = std::move(f_update);
57 n->f_predict = std::move(f_predict);
58 n->f_as_string = std::move(f_as_string);
59 return CostModel(n);
60}
61
62TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
63 .set_dispatch<PyCostModelNode>([](const ObjectRef& n, ReprPrinter* p) {
64 const auto* self = n.as<PyCostModelNode>();
65 ICHECK(self);
66 PyCostModelNode::FAsString f_as_string = (*self).f_as_string;
67 ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!";
68 p->stream << f_as_string();
69 });
70
// Register the cost-model object types with TVM's runtime type registry.
TVM_REGISTER_OBJECT_TYPE(CostModelNode);
TVM_REGISTER_NODE_TYPE(PyCostModelNode);

// Expose the CostModel interface as global packed functions. Each set_body_method<CostModel>
// binding invokes the named CostModelNode method on a CostModel receiver.
TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method<CostModel>(&CostModelNode::Load);
TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method<CostModel>(&CostModelNode::Save);
TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate")
    .set_body_method<CostModel>(&CostModelNode::Update);
78TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict")
79 .set_body_typed([](CostModel model, //
80 const TuneContext& context, //
81 Array<MeasureCandidate> candidates, //
82 void* p_addr) -> void {
83 std::vector<double> result = model->Predict(context, candidates);
84 std::copy(result.begin(), result.end(), static_cast<double*>(p_addr));
85 });
// Expose the CostModel::PyCostModel factory so a frontend (presumably Python) can build a
// CostModel from its own Load/Save/Update/Predict/AsString callbacks.
TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel);
87
88} // namespace meta_schedule
89} // namespace tvm
90