1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | /*! \brief The union of design space generators. */ |
25 | class ScheduleFnNode : public SpaceGeneratorNode { |
26 | public: |
27 | /*! \brief The random state. -1 means using random number. */ |
28 | TRandState rand_state_ = -1; |
29 | /*! \brief The schedule function. */ |
30 | runtime::PackedFunc schedule_fn_; |
31 | |
32 | void VisitAttrs(tvm::AttrVisitor* v) { |
33 | SpaceGeneratorNode::VisitAttrs(v); |
34 | // `schedule_fn_` is not visited. |
35 | } |
36 | |
37 | void InitializeWithTuneContext(const TuneContext& context) final { |
38 | SpaceGeneratorNode::InitializeWithTuneContext(context); |
39 | this->rand_state_ = ForkSeed(&context->rand_state); |
40 | } |
41 | |
42 | Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final { |
43 | tir::Schedule sch = tir::Schedule::Traced( |
44 | /*mod=*/mod, |
45 | /*rand_state=*/ForkSeed(&this->rand_state_), |
46 | /*debug_mode=*/0, |
47 | /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); |
48 | runtime::TVMRetValue rv; |
49 | rv = this->schedule_fn_(sch); |
50 | if (rv.type_code() == kTVMNullptr) { |
51 | return {sch}; |
52 | } |
53 | ObjectRef obj = rv; |
54 | if (const auto* sch = obj.as<tir::ScheduleNode>()) { |
55 | return {GetRef<tir::Schedule>(sch)}; |
56 | } |
57 | if (const auto* arr = obj.as<runtime::ArrayNode>()) { |
58 | Array<tir::Schedule> result; |
59 | result.reserve(arr->size()); |
60 | for (const ObjectRef& obj : *arr) { |
61 | if (const auto* sch = obj.as<tir::ScheduleNode>()) { |
62 | result.push_back(GetRef<tir::Schedule>(sch)); |
63 | } else { |
64 | LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or " |
65 | "List[Schedule], but got: " |
66 | << obj->GetTypeKey(); |
67 | } |
68 | } |
69 | return result; |
70 | } |
71 | LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or " |
72 | "List[Schedule], but got: " |
73 | << obj->GetTypeKey(); |
74 | throw; |
75 | } |
76 | |
77 | SpaceGenerator Clone() const final { |
78 | ObjectPtr<ScheduleFnNode> n = make_object<ScheduleFnNode>(*this); |
79 | CloneRules(this, n.get()); |
80 | return SpaceGenerator(n); |
81 | } |
82 | |
83 | static constexpr const char* _type_key = "meta_schedule.ScheduleFn" ; |
84 | TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode); |
85 | }; |
86 | |
87 | SpaceGenerator SpaceGenerator::ScheduleFn(PackedFunc schedule_fn, |
88 | Optional<Array<ScheduleRule>> sch_rules, |
89 | Optional<Array<Postproc>> postprocs, |
90 | Optional<Map<Mutator, FloatImm>> mutator_probs) { |
91 | ObjectPtr<ScheduleFnNode> n = make_object<ScheduleFnNode>(); |
92 | n->sch_rules = std::move(sch_rules); |
93 | n->postprocs = std::move(postprocs); |
94 | n->mutator_probs = std::move(mutator_probs); |
95 | n->schedule_fn_ = std::move(schedule_fn); |
96 | return SpaceGenerator(n); |
97 | } |
98 | |
99 | TVM_REGISTER_NODE_TYPE(ScheduleFnNode); |
100 | TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn" ) |
101 | .set_body_typed(SpaceGenerator::ScheduleFn); |
102 | |
103 | } // namespace meta_schedule |
104 | } // namespace tvm |
105 | |