1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) {
25 StmtSRef block_sref = sch->GetSRef(block_rv);
26 return block_sref->parent == nullptr;
27}
28
29bool CheckSpatialPrimFunc(const Schedule& sch, const BlockRV& root_block_rv) {
30 return IsSpatialPrimFunc(
31 GetRef<PrimFunc>(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr)));
32}
33
34} // namespace tir
35} // namespace tvm
36
37namespace tvm {
38namespace meta_schedule {
39
40class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
41 public:
42 // Inherited from ScheduleRuleNode
43 void InitializeWithTuneContext(const TuneContext& context) final {
44 ICHECK(context->target.defined());
45 if (this->max_jobs_per_core != -1) {
46 Target target = context->target.value();
47 this->max_parallel_extent_ = GetTargetNumCores(target) * max_jobs_per_core;
48 }
49 }
50
51 // Inherited from ScheduleRuleNode
52 Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) {
53 // Currently only mark the root block with annotations.
54 if (!tir::IsRootBlock(sch, root_rv)) {
55 return {sch};
56 }
57
58 // Parallelization
59 if (max_jobs_per_core != -1) {
60 sch->Annotate(root_rv, tir::attr::meta_schedule_parallel,
61 Integer(this->max_parallel_extent_));
62 }
63 // Vectorization
64 if (max_vectorize_extent != -1) {
65 sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent));
66 }
67 // Unroll
68 if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) {
69 int n = unroll_max_steps.size();
70 double prob = 1.0 / n;
71 Array<FloatImm> probs(n, FloatImm(DataType::Float(64), prob));
72 PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs);
73 if (unroll_explicit) {
74 sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step);
75 } else {
76 sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_implicit, max_step);
77 }
78 }
79 return {sch};
80 }
81
82 // Inherited from ScheduleRuleNode
83 ScheduleRule Clone() const final {
84 ObjectPtr<ParallelizeVectorizeUnrollNode> n =
85 make_object<ParallelizeVectorizeUnrollNode>(*this);
86 return ScheduleRule(n);
87 }
88
89 public:
90 /*!
91 * \brief The maximum number of jobs to be launched per CPU core. It sets the
92 * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
93 * parallelism.
94 */
95 int64_t max_jobs_per_core;
96 /*!
97 * \brief The maximum extent to be vectorized.
98 * It sets the upper limit of the hardware target vectorization. Use -1 to disable vectorization.
99 */
100 int max_vectorize_extent;
101 /*!
102 * \brief The options of the maximum number of unroll steps to be done.
103 * Use an empty array to disable unroll.
104 */
105 Array<Integer> unroll_max_steps;
106 /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */
107 bool unroll_explicit;
108 /*! \brief The number of maximum available jobs in CPU. */
109 int64_t max_parallel_extent_;
110
111 void VisitAttrs(tvm::AttrVisitor* v) {
112 v->Visit("max_jobs_per_core", &max_jobs_per_core);
113 v->Visit("max_vectorize_extent", &max_vectorize_extent);
114 v->Visit("unroll_max_steps", &unroll_max_steps);
115 v->Visit("unroll_explicit", &unroll_explicit);
116 // `max_parallel_extent_` is not visited
117 }
118
119 static constexpr const char* _type_key = "meta_schedule.ParallelizeVectorizeUnroll";
120 TVM_DECLARE_FINAL_OBJECT_INFO(ParallelizeVectorizeUnrollNode, ScheduleRuleNode);
121};
122
123ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core,
124 int max_vectorize_extent,
125 Array<Integer> unroll_max_steps,
126 bool unroll_explicit) {
127 ObjectPtr<ParallelizeVectorizeUnrollNode> n = make_object<ParallelizeVectorizeUnrollNode>();
128 n->max_jobs_per_core = max_jobs_per_core;
129 n->max_vectorize_extent = max_vectorize_extent;
130 n->unroll_max_steps = unroll_max_steps;
131 n->unroll_explicit = unroll_explicit;
132 n->max_parallel_extent_ = -1;
133 return ScheduleRule(n);
134}
135
// Register ParallelizeVectorizeUnrollNode as a node type (see TVM_REGISTER_NODE_TYPE).
TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode);
// Register the factory ScheduleRule::ParallelizeVectorizeUnroll as the typed body of the
// global named "meta_schedule.ScheduleRuleParallelizeVectorizeUnroll".
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll")
    .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll);
139
140} // namespace meta_schedule
141} // namespace tvm
142