1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) { |
25 | StmtSRef block_sref = sch->GetSRef(block_rv); |
26 | return block_sref->parent == nullptr; |
27 | } |
28 | |
29 | bool CheckSpatialPrimFunc(const Schedule& sch, const BlockRV& root_block_rv) { |
30 | return IsSpatialPrimFunc( |
31 | GetRef<PrimFunc>(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr))); |
32 | } |
33 | |
34 | } // namespace tir |
35 | } // namespace tvm |
36 | |
37 | namespace tvm { |
38 | namespace meta_schedule { |
39 | |
40 | class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { |
41 | public: |
42 | // Inherited from ScheduleRuleNode |
43 | void InitializeWithTuneContext(const TuneContext& context) final { |
44 | ICHECK(context->target.defined()); |
45 | if (this->max_jobs_per_core != -1) { |
46 | Target target = context->target.value(); |
47 | this->max_parallel_extent_ = GetTargetNumCores(target) * max_jobs_per_core; |
48 | } |
49 | } |
50 | |
51 | // Inherited from ScheduleRuleNode |
52 | Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) { |
53 | // Currently only mark the root block with annotations. |
54 | if (!tir::IsRootBlock(sch, root_rv)) { |
55 | return {sch}; |
56 | } |
57 | |
58 | // Parallelization |
59 | if (max_jobs_per_core != -1) { |
60 | sch->Annotate(root_rv, tir::attr::meta_schedule_parallel, |
61 | Integer(this->max_parallel_extent_)); |
62 | } |
63 | // Vectorization |
64 | if (max_vectorize_extent != -1) { |
65 | sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent)); |
66 | } |
67 | // Unroll |
68 | if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) { |
69 | int n = unroll_max_steps.size(); |
70 | double prob = 1.0 / n; |
71 | Array<FloatImm> probs(n, FloatImm(DataType::Float(64), prob)); |
72 | PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs); |
73 | if (unroll_explicit) { |
74 | sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step); |
75 | } else { |
76 | sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_implicit, max_step); |
77 | } |
78 | } |
79 | return {sch}; |
80 | } |
81 | |
82 | // Inherited from ScheduleRuleNode |
83 | ScheduleRule Clone() const final { |
84 | ObjectPtr<ParallelizeVectorizeUnrollNode> n = |
85 | make_object<ParallelizeVectorizeUnrollNode>(*this); |
86 | return ScheduleRule(n); |
87 | } |
88 | |
89 | public: |
90 | /*! |
91 | * \brief The maximum number of jobs to be launched per CPU core. It sets the |
92 | * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable |
93 | * parallelism. |
94 | */ |
95 | int64_t max_jobs_per_core; |
96 | /*! |
97 | * \brief The maximum extent to be vectorized. |
98 | * It sets the upper limit of the hardware target vectorization. Use -1 to disable vectorization. |
99 | */ |
100 | int max_vectorize_extent; |
101 | /*! |
102 | * \brief The options of the maximum number of unroll steps to be done. |
103 | * Use an empty array to disable unroll. |
104 | */ |
105 | Array<Integer> unroll_max_steps; |
106 | /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */ |
107 | bool unroll_explicit; |
108 | /*! \brief The number of maximum available jobs in CPU. */ |
109 | int64_t max_parallel_extent_; |
110 | |
111 | void VisitAttrs(tvm::AttrVisitor* v) { |
112 | v->Visit("max_jobs_per_core" , &max_jobs_per_core); |
113 | v->Visit("max_vectorize_extent" , &max_vectorize_extent); |
114 | v->Visit("unroll_max_steps" , &unroll_max_steps); |
115 | v->Visit("unroll_explicit" , &unroll_explicit); |
116 | // `max_parallel_extent_` is not visited |
117 | } |
118 | |
119 | static constexpr const char* _type_key = "meta_schedule.ParallelizeVectorizeUnroll" ; |
120 | TVM_DECLARE_FINAL_OBJECT_INFO(ParallelizeVectorizeUnrollNode, ScheduleRuleNode); |
121 | }; |
122 | |
123 | ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, |
124 | int max_vectorize_extent, |
125 | Array<Integer> unroll_max_steps, |
126 | bool unroll_explicit) { |
127 | ObjectPtr<ParallelizeVectorizeUnrollNode> n = make_object<ParallelizeVectorizeUnrollNode>(); |
128 | n->max_jobs_per_core = max_jobs_per_core; |
129 | n->max_vectorize_extent = max_vectorize_extent; |
130 | n->unroll_max_steps = unroll_max_steps; |
131 | n->unroll_explicit = unroll_explicit; |
132 | n->max_parallel_extent_ = -1; |
133 | return ScheduleRule(n); |
134 | } |
135 | |
// Register the node type, and expose the ScheduleRule::ParallelizeVectorizeUnroll factory as the
// global function "meta_schedule.ScheduleRuleParallelizeVectorizeUnroll".
TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll")
    .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll);
139 | |
140 | } // namespace meta_schedule |
141 | } // namespace tvm |
142 | |