1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24/*! \brief The union of design space generators. */
25class ScheduleFnNode : public SpaceGeneratorNode {
26 public:
27 /*! \brief The random state. -1 means using random number. */
28 TRandState rand_state_ = -1;
29 /*! \brief The schedule function. */
30 runtime::PackedFunc schedule_fn_;
31
32 void VisitAttrs(tvm::AttrVisitor* v) {
33 SpaceGeneratorNode::VisitAttrs(v);
34 // `schedule_fn_` is not visited.
35 }
36
37 void InitializeWithTuneContext(const TuneContext& context) final {
38 SpaceGeneratorNode::InitializeWithTuneContext(context);
39 this->rand_state_ = ForkSeed(&context->rand_state);
40 }
41
42 Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
43 tir::Schedule sch = tir::Schedule::Traced(
44 /*mod=*/mod,
45 /*rand_state=*/ForkSeed(&this->rand_state_),
46 /*debug_mode=*/0,
47 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
48 runtime::TVMRetValue rv;
49 rv = this->schedule_fn_(sch);
50 if (rv.type_code() == kTVMNullptr) {
51 return {sch};
52 }
53 ObjectRef obj = rv;
54 if (const auto* sch = obj.as<tir::ScheduleNode>()) {
55 return {GetRef<tir::Schedule>(sch)};
56 }
57 if (const auto* arr = obj.as<runtime::ArrayNode>()) {
58 Array<tir::Schedule> result;
59 result.reserve(arr->size());
60 for (const ObjectRef& obj : *arr) {
61 if (const auto* sch = obj.as<tir::ScheduleNode>()) {
62 result.push_back(GetRef<tir::Schedule>(sch));
63 } else {
64 LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or "
65 "List[Schedule], but got: "
66 << obj->GetTypeKey();
67 }
68 }
69 return result;
70 }
71 LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or "
72 "List[Schedule], but got: "
73 << obj->GetTypeKey();
74 throw;
75 }
76
77 SpaceGenerator Clone() const final {
78 ObjectPtr<ScheduleFnNode> n = make_object<ScheduleFnNode>(*this);
79 CloneRules(this, n.get());
80 return SpaceGenerator(n);
81 }
82
83 static constexpr const char* _type_key = "meta_schedule.ScheduleFn";
84 TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode);
85};
86
87SpaceGenerator SpaceGenerator::ScheduleFn(PackedFunc schedule_fn,
88 Optional<Array<ScheduleRule>> sch_rules,
89 Optional<Array<Postproc>> postprocs,
90 Optional<Map<Mutator, FloatImm>> mutator_probs) {
91 ObjectPtr<ScheduleFnNode> n = make_object<ScheduleFnNode>();
92 n->sch_rules = std::move(sch_rules);
93 n->postprocs = std::move(postprocs);
94 n->mutator_probs = std::move(mutator_probs);
95 n->schedule_fn_ = std::move(schedule_fn);
96 return SpaceGenerator(n);
97}
98
99TVM_REGISTER_NODE_TYPE(ScheduleFnNode);
100TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn")
101 .set_body_typed(SpaceGenerator::ScheduleFn);
102
103} // namespace meta_schedule
104} // namespace tvm
105