/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

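/*!
 * \brief The schedule rule that adds an rfactor stage to a reduction block when the loop
 * structure does not expose enough parallelism for the target CPU.
 */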
class AddRFactorNode : public ScheduleRuleNode {
 public:
  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    Target target = context->target.value();
    this->max_parallel_basic_ = GetTargetNumCores(target);
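    // A `max_jobs_per_core` of -1 disables parallelism, in which case `max_parallel_extent_`
    // keeps its default value of -1.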
    if (this->max_jobs_per_core != -1) {
      this->max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core;
    }
  }

  // Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;

  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
    ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>(*this);
    return ScheduleRule(n);
  }

 public:
  /*!
   * \brief The maximum number of jobs to be launched per core.
   * It sets the upper limit of parallelism, i.e. `num_cores * max_jobs_per_core`.
   * Use -1 to disable parallelism.
   */
  int max_jobs_per_core;
  /*! \brief The maximum size of the innermost factor */
  int max_innermost_factor;
  /*! \brief The upper limit of parallelism, i.e. `num_cores * max_jobs_per_core`. */
  int max_parallel_extent_;
  /*! \brief The number of cores. */
  int max_parallel_basic_;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_jobs_per_core", &max_jobs_per_core);
    v->Visit("max_innermost_factor", &max_innermost_factor);
    // `max_parallel_extent_` is not visited
    // `max_parallel_basic_` is not visited
  }

  static constexpr const char* _type_key = "meta_schedule.AddRFactor";
  TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode);
};

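/*!
 * \brief Creates the add-rfactor schedule rule.
 * \param max_jobs_per_core The maximum number of jobs to be launched per core; use -1 to
 * disable parallelism.
 * \param max_innermost_factor The maximum size of the innermost factor; NullOpt means no limit.
 * \return The schedule rule created.
 *
 * A minimal usage sketch (the values below are hypothetical):
 * \code
 *   ScheduleRule rule = ScheduleRule::AddRFactor(16, Integer(64));
 * \endcode
 */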
ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
                                      Optional<Integer> max_innermost_factor) {
  ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>();
  n->max_jobs_per_core = max_jobs_per_core;
  n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
  n->max_parallel_extent_ = -1;
  n->max_parallel_basic_ = -1;
  return ScheduleRule(n);
}

Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
  tir::StmtSRef block_sref = sch->GetSRef(block_rv);
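  // Skip blocks that need neither rfactor nor cross-thread reduction under the target's
  // parallelism budget.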
  if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
                                          max_parallel_basic_)) {
    return {sch};
  }

  // Make a copy of the original schedule so it can be returned alongside any rfactor variants.
  tir::Schedule ori_sch = sch->Copy();
  ori_sch->Seed(sch->ForkSeed());

  // Reorder the loop axes if reduction loops are not innermost.
  // After the reordering, fuse all the reduction loops.
  size_t num_spatial_loops;
  tir::LoopRV fused_reduce_loop;
  ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);

  // Split the fused reduction loop into two parts, with the factors sampled as a perfect tile.
  Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
  Array<tir::LoopRV> split_loops = sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});

  Array<tir::Schedule> res;
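  // Attempt rfactor with each of the two split loops as the factored axis; every attempt that
  // succeeds contributes one candidate schedule.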
  for (const tir::LoopRV& split_loop : split_loops) {
    tir::Schedule sch_tmp = sch->Copy();
    sch_tmp->Seed(sch->ForkSeed());
    try {
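      // Factor the reduction over `split_loop` into a separate rfactor block;
      // `num_spatial_loops` is passed as the factor axis, i.e. the position at which
      // the new rfactor dimension is introduced.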
      const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
      Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
      ICHECK_GT(axes.size(), num_spatial_loops);

      // Annotate that the rfactor block, which is now the producer of the original block, needs to
      // be considered by the rule Random-Compute-Location.
      sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Integer(1));
      res.push_back(sch_tmp);
    } catch (const tvm::runtime::Error& e) {
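      // `RFactor` may throw if its preconditions are not met; such candidates are simply skipped.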
    }
  }

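  // Always keep the unmodified schedule as a fallback candidate.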
  res.push_back(ori_sch);
  return res;
}

TVM_REGISTER_NODE_TYPE(AddRFactorNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor")
    .set_body_typed(ScheduleRule::AddRFactor);

}  // namespace meta_schedule
}  // namespace tvm