1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include "../../tir/schedule/analysis.h"
21#include "../../tir/schedule/transform.h"
22#include "../utils.h"
23#include "multi_level_tiling.h"
24
25namespace tvm {
26namespace meta_schedule {
27
28using tir::BlockRV;
29using tir::LoopRV;
30using tir::Schedule;
31
32/*!
33 * \brief Extension of MultiLevelTiling for backends with wide vectors.
34 * The loop over the innermost spatial axis of the output buffer is always vectorized with the
35 * maximum vector length.
36 */
37class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode {
38 public:
39 size_t vector_length_in_bits;
40
41 static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWideVector";
42 TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode);
43
44 protected:
45 ScheduleRule Clone() const final {
46 ObjectPtr<MultiLevelTilingWideVectorNode> n =
47 make_object<MultiLevelTilingWideVectorNode>(*this);
48 return ScheduleRule(n);
49 }
50
51 Array<tir::LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const;
52};
53
54Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv,
55 LoopRV loop_rv, int n_tiles) const {
56 const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
57 const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
58 const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
59 const tir::BlockRealize block_realize = tir::GetBlockRealize(sch->state(), block_sref);
60 ICHECK(block_node && block_node->writes.size() == 1);
61
62 const auto out_dtype = block_node->writes[0]->buffer->dtype;
63 const int vec_len = vector_length_in_bits / out_dtype.bits();
64
65 // Determine if this loop is over the innermost axis of the output buffer.
66 // In the example below, we look for a loop whose loop var is bound to the axis co.
67
68 // for (i0, 0, 1) {
69 // for (i1, 0, 56) {
70 // for (i2, 0, 56) {
71 // for (i3, 0, 64) {
72 // for (i4, 0, 3) {
73 // for (i5, 0, 3) {
74 // for (i6, 0, 64) {
75 // block conv2d_nhwc(...) {
76 // ...
77 // bind(co, i3)
78 // ...
79 // writes([conv2d_nhwc[n, h, w, co]])
80 // ...
81 // conv2d_nhwc[n, h, w, co] = ...
82 // }
83 const size_t innermost_axis = block_node->writes[0]->region.size() - 1;
84 const PrimExpr innermost_iter_value = block_realize->iter_values[innermost_axis];
85
86 if (!arith::Analyzer().CanProve(loop->loop_var == innermost_iter_value)) {
87 // If this is not the innermost spatial loop, split the loop in the normal way.
88 return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles);
89 } else {
90 // We split the innermost spatial loop in a way that always uses the maximum vector length.
91 const int64_t* extent_int = tir::GetLoopIntExtent(loop);
92 if (extent_int && *extent_int > vec_len) {
93 Array<tir::LoopRV> inner_splits = sch->Split(/*loop=*/loop_rv,
94 /*factors=*/{NullOpt, PrimExpr(vec_len)});
95 Array<tir::ExprRV> outer_factors = sch->SamplePerfectTile(
96 /*loop=*/inner_splits[0],
97 /*n=*/n_tiles - 1,
98 /*max_innermost_factor=*/max_innermost_factor);
99 Array<tir::LoopRV> outer_splits = sch->Split(
100 /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()});
101 outer_splits.push_back(inner_splits[1]);
102 return outer_splits;
103 } else {
104 Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1));
105 factors.push_back(loop->extent);
106 return sch->Split(/*loop=*/loop_rv,
107 /*factors=*/{factors.begin(), factors.end()});
108 }
109 }
110}
111
112ScheduleRule ScheduleRule::MultiLevelTilingWideVector(
113 String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
114 Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
115 auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>(
116 structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write);
117 node->vector_length_in_bits = vector_length_in_bits->value;
118 return ScheduleRule(node);
119}
120
// Register the node type with the TVM object system, and expose the factory above as the global
// function "meta_schedule.ScheduleRuleMultiLevelTilingWideVector".
TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector")
    .set_body_typed(ScheduleRule::MultiLevelTilingWideVector);
124
125} // namespace meta_schedule
126} // namespace tvm
127