1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24class ApplyCustomRuleNode : public ScheduleRuleNode {
25 public:
26 // Inherited from ScheduleRuleNode
27 void InitializeWithTuneContext(const TuneContext& context) final {
28 CHECK(context->target.defined()) << "ValueError: Target is not defined in the tune context.";
29 this->target_ = context->target;
30 }
31
32 static std::string GetCustomRuleName(const std::string& name, const std::string& key) {
33 return "meta_schedule." + key + "." + name;
34 }
35
36 // Inherited from ScheduleRuleNode
37 Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
38 CHECK(this->target_.defined())
39 << "ValueError: ApplyCustomRule is not initialized with TuneContext that has a Target.";
40 Array<String> keys = this->target_.value()->keys;
41 if (Optional<String> ann = tir::GetAnn<String>(sch->GetSRef(block_rv), "schedule_rule")) {
42 if (ann.value() != "None") {
43 for (const String& key : keys) {
44 if (const runtime::PackedFunc* custom_schedule_fn =
45 runtime::Registry::Get(GetCustomRuleName(ann.value(), key))) {
46 Array<tir::Schedule> result = ((*custom_schedule_fn)(sch, block_rv));
47 return result;
48 }
49 }
50 std::ostringstream os;
51 os << "Unknown schedule rule \"" << ann.value() << "\" for target keys \"" << keys
52 << "\". Checked PackedFuncs:";
53 for (const String& key : keys) {
54 os << "\n " << GetCustomRuleName(ann.value(), key);
55 }
56 LOG(WARNING) << os.str();
57 sch->Unannotate(block_rv, "schedule_rule");
58 }
59 }
60 return {sch};
61 }
62
63 // Inherited from ScheduleRuleNode
64 ScheduleRule Clone() const final {
65 ObjectPtr<ApplyCustomRuleNode> n = make_object<ApplyCustomRuleNode>(*this);
66 n->target_ = target_;
67 return ScheduleRule(n);
68 }
69
70 public:
71 Optional<Target> target_ = NullOpt;
72
73 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("target_", &target_); }
74
75 static constexpr const char* _type_key = "meta_schedule.ApplyCustomRule";
76 TVM_DECLARE_FINAL_OBJECT_INFO(ApplyCustomRuleNode, ScheduleRuleNode);
77};
78
79ScheduleRule ScheduleRule::ApplyCustomRule() {
80 ObjectPtr<ApplyCustomRuleNode> n = make_object<ApplyCustomRuleNode>();
81 return ScheduleRule(n);
82}
83
84bool ScheduleRule::IsApplyCustomRule(const ScheduleRule& rule) {
85 return rule->IsInstance<ApplyCustomRuleNode>();
86}
87
88TVM_REGISTER_NODE_TYPE(ApplyCustomRuleNode);
89TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApplyCustomRule")
90 .set_body_typed(ScheduleRule::ApplyCustomRule);
91
92} // namespace meta_schedule
93} // namespace tvm
94