1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
/*!
 * \brief A schedule rule that applies rfactor to reduction-dominant blocks so that
 * the reduction can be parallelized on multi-core CPU targets.
 *
 * The rule first checks whether the block needs rfactor (or cross-thread reduction),
 * and if so, splits the fused reduction loop and rfactors on each split loop,
 * producing multiple candidate schedules.
 */
class AddRFactorNode : public ScheduleRuleNode {
 public:
  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    Target target = context->target.value();
    // Number of cores reported by the target; serves as the baseline parallelism.
    this->max_parallel_basic_ = GetTargetNumCores(target);
    if (this->max_jobs_per_core != -1) {
      // Upper limit of parallelism = num_cores * max_jobs_per_core.
      // When max_jobs_per_core == -1 (parallelism disabled), max_parallel_extent_
      // keeps the sentinel value -1 assigned by the creator function.
      this->max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core;
    }
  }

  // Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);

  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
    // Shallow-copies all attributes (including the derived max_parallel_* fields).
    ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>(*this);
    return ScheduleRule(n);
  }

 public:
  /*!
   * \brief The maximum number of jobs to be launched per core.
   * It sets the uplimit of parallelism, i.e. `num_cores * max_jobs_per_core`.
   * Use -1 to disable parallelism.
   */
  int max_jobs_per_core;
  /*! \brief The maximum size of the innermost factor (-1 means unlimited). */
  int max_innermost_factor;
  /*! \brief The number of uplimit of parallelism. Derived in InitializeWithTuneContext. */
  int max_parallel_extent_;
  /*! \brief The number of cores. Derived in InitializeWithTuneContext. */
  int max_parallel_basic_;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_jobs_per_core", &max_jobs_per_core);
    v->Visit("max_innermost_factor", &max_innermost_factor);
    // `max_parallel_extent_` is not visited
    // `max_parallel_basic_` is not visited
  }

  static constexpr const char* _type_key = "meta_schedule.AddRFactor";
  TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode);
};
69
70ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
71 Optional<Integer> max_innermost_factor) {
72 ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>();
73 n->max_jobs_per_core = max_jobs_per_core;
74 n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
75 n->max_parallel_extent_ = -1;
76 n->max_parallel_basic_ = -1;
77 return ScheduleRule(n);
78}
79
80Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
81 tir::StmtSRef block_sref = sch->GetSRef(block_rv);
82 if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
83 max_parallel_basic_)) {
84 return {sch};
85 }
86
87 // Make a copy of the original schedule.
88 tir::Schedule ori_sch = sch->Copy();
89 ori_sch->Seed(sch->ForkSeed());
90
91 // Reorder the loop axes if reduction loops are not innermost.
92 // After the reordering, fuse all the reduction loops.
93 size_t num_spatial_loops;
94 tir::LoopRV fused_reduce_loop;
95 ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
96
97 // Split the fused reduction loop.
98 Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
99 Array<tir::LoopRV> split_loops = sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});
100
101 Array<tir::Schedule> res;
102 for (const tir::LoopRV& split_loop : split_loops) {
103 tir::Schedule sch_tmp = sch->Copy();
104 sch_tmp->Seed(sch->ForkSeed());
105 try {
106 const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
107 Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
108 ICHECK_GT(axes.size(), num_spatial_loops);
109
110 // Annotate that the rfactor block, which is now the producer of the original block, needs to
111 // be considered by the rule Random-Compute-Location.
112 sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Integer(1));
113 res.push_back(sch_tmp);
114 } catch (const tvm::runtime::Error& e) {
115 }
116 }
117
118 res.push_back(ori_sch);
119 return res;
120}
121
// Register the node type and expose the creator to the Python/FFI side.
TVM_REGISTER_NODE_TYPE(AddRFactorNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor")
    .set_body_typed(ScheduleRule::AddRFactor);
125
126} // namespace meta_schedule
127} // namespace tvm
128