1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | void PyCostModelNode::Load(const String& path) { |
25 | ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!" ; |
26 | f_load(path); |
27 | } |
28 | |
29 | void PyCostModelNode::Save(const String& path) { |
30 | ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!" ; |
31 | f_save(path); |
32 | } |
33 | |
34 | void PyCostModelNode::Update(const TuneContext& context, const Array<MeasureCandidate>& candidates, |
35 | const Array<RunnerResult>& results) { |
36 | ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!" ; |
37 | f_update(context, candidates, results); |
38 | } |
39 | |
40 | std::vector<double> PyCostModelNode::Predict(const TuneContext& context, |
41 | const Array<MeasureCandidate>& candidates) { |
42 | ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!" ; |
43 | std::vector<double> result(candidates.size(), 0.0); |
44 | f_predict(context, candidates, result.data()); |
45 | return result; |
46 | } |
47 | |
48 | CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // |
49 | PyCostModelNode::FSave f_save, // |
50 | PyCostModelNode::FUpdate f_update, // |
51 | PyCostModelNode::FPredict f_predict, // |
52 | PyCostModelNode::FAsString f_as_string) { |
53 | ObjectPtr<PyCostModelNode> n = make_object<PyCostModelNode>(); |
54 | n->f_load = std::move(f_load); |
55 | n->f_save = std::move(f_save); |
56 | n->f_update = std::move(f_update); |
57 | n->f_predict = std::move(f_predict); |
58 | n->f_as_string = std::move(f_as_string); |
59 | return CostModel(n); |
60 | } |
61 | |
62 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
63 | .set_dispatch<PyCostModelNode>([](const ObjectRef& n, ReprPrinter* p) { |
64 | const auto* self = n.as<PyCostModelNode>(); |
65 | ICHECK(self); |
66 | PyCostModelNode::FAsString f_as_string = (*self).f_as_string; |
67 | ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!" ; |
68 | p->stream << f_as_string(); |
69 | }); |
70 | |
// Register the cost-model object types with TVM's runtime type system:
// CostModelNode via TVM_REGISTER_OBJECT_TYPE, PyCostModelNode via TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_OBJECT_TYPE(CostModelNode);
TVM_REGISTER_NODE_TYPE(PyCostModelNode);
73 | |
74 | TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad" ).set_body_method<CostModel>(&CostModelNode::Load); |
75 | TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave" ).set_body_method<CostModel>(&CostModelNode::Save); |
76 | TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate" ) |
77 | .set_body_method<CostModel>(&CostModelNode::Update); |
78 | TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict" ) |
79 | .set_body_typed([](CostModel model, // |
80 | const TuneContext& context, // |
81 | Array<MeasureCandidate> candidates, // |
82 | void* p_addr) -> void { |
83 | std::vector<double> result = model->Predict(context, candidates); |
84 | std::copy(result.begin(), result.end(), static_cast<double*>(p_addr)); |
85 | }); |
86 | TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel" ).set_body_typed(CostModel::PyCostModel); |
87 | |
88 | } // namespace meta_schedule |
89 | } // namespace tvm |
90 | |