1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include "../../tir/schedule/analysis.h" |
21 | #include "../../tir/schedule/transform.h" |
22 | #include "../utils.h" |
23 | #include "multi_level_tiling.h" |
24 | |
25 | namespace tvm { |
26 | namespace meta_schedule { |
27 | |
28 | using tir::BlockRV; |
29 | using tir::LoopRV; |
30 | using tir::Schedule; |
31 | |
32 | /*! |
33 | * \brief Extension of MultiLevelTiling for backends with wide vectors. |
34 | * The loop over the innermost spatial axis of the output buffer is always vectorized with the |
35 | * maximum vector length. |
36 | */ |
37 | class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { |
38 | public: |
39 | size_t vector_length_in_bits; |
40 | |
41 | static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWideVector" ; |
42 | TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode); |
43 | |
44 | protected: |
45 | ScheduleRule Clone() const final { |
46 | ObjectPtr<MultiLevelTilingWideVectorNode> n = |
47 | make_object<MultiLevelTilingWideVectorNode>(*this); |
48 | return ScheduleRule(n); |
49 | } |
50 | |
51 | Array<tir::LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const; |
52 | }; |
53 | |
54 | Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv, |
55 | LoopRV loop_rv, int n_tiles) const { |
56 | const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); |
57 | const tir::StmtSRef block_sref = sch->GetSRef(block_rv); |
58 | const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>(); |
59 | const tir::BlockRealize block_realize = tir::GetBlockRealize(sch->state(), block_sref); |
60 | ICHECK(block_node && block_node->writes.size() == 1); |
61 | |
62 | const auto out_dtype = block_node->writes[0]->buffer->dtype; |
63 | const int vec_len = vector_length_in_bits / out_dtype.bits(); |
64 | |
65 | // Determine if this loop is over the innermost axis of the output buffer. |
66 | // In the example below, we look for a loop whose loop var is bound to the axis co. |
67 | |
68 | // for (i0, 0, 1) { |
69 | // for (i1, 0, 56) { |
70 | // for (i2, 0, 56) { |
71 | // for (i3, 0, 64) { |
72 | // for (i4, 0, 3) { |
73 | // for (i5, 0, 3) { |
74 | // for (i6, 0, 64) { |
75 | // block conv2d_nhwc(...) { |
76 | // ... |
77 | // bind(co, i3) |
78 | // ... |
79 | // writes([conv2d_nhwc[n, h, w, co]]) |
80 | // ... |
81 | // conv2d_nhwc[n, h, w, co] = ... |
82 | // } |
83 | const size_t innermost_axis = block_node->writes[0]->region.size() - 1; |
84 | const PrimExpr innermost_iter_value = block_realize->iter_values[innermost_axis]; |
85 | |
86 | if (!arith::Analyzer().CanProve(loop->loop_var == innermost_iter_value)) { |
87 | // If this is not the innermost spatial loop, split the loop in the normal way. |
88 | return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles); |
89 | } else { |
90 | // We split the innermost spatial loop in a way that always uses the maximum vector length. |
91 | const int64_t* extent_int = tir::GetLoopIntExtent(loop); |
92 | if (extent_int && *extent_int > vec_len) { |
93 | Array<tir::LoopRV> inner_splits = sch->Split(/*loop=*/loop_rv, |
94 | /*factors=*/{NullOpt, PrimExpr(vec_len)}); |
95 | Array<tir::ExprRV> outer_factors = sch->SamplePerfectTile( |
96 | /*loop=*/inner_splits[0], |
97 | /*n=*/n_tiles - 1, |
98 | /*max_innermost_factor=*/max_innermost_factor); |
99 | Array<tir::LoopRV> outer_splits = sch->Split( |
100 | /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()}); |
101 | outer_splits.push_back(inner_splits[1]); |
102 | return outer_splits; |
103 | } else { |
104 | Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1)); |
105 | factors.push_back(loop->extent); |
106 | return sch->Split(/*loop=*/loop_rv, |
107 | /*factors=*/{factors.begin(), factors.end()}); |
108 | } |
109 | } |
110 | } |
111 | |
112 | ScheduleRule ScheduleRule::MultiLevelTilingWideVector( |
113 | String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor, |
114 | Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) { |
115 | auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>( |
116 | structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write); |
117 | node->vector_length_in_bits = vector_length_in_bits->value; |
118 | return ScheduleRule(node); |
119 | } |
120 | |
// Register the node type, and register the factory ScheduleRule::MultiLevelTilingWideVector as
// a typed global function named "meta_schedule.ScheduleRuleMultiLevelTilingWideVector".
// NOTE(review): presumably this global name is what language frontends look up -- confirm
// against the callers before renaming it.
TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector" )
    .set_body_typed(ScheduleRule::MultiLevelTilingWideVector);
124 | |
125 | } // namespace meta_schedule |
126 | } // namespace tvm |
127 | |