/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/meta_schedule/schedule/cuda/thread_bind.h>
#include <tvm/meta_schedule/schedule/generic/winograd.h>

#include <vector>

#include "../../utils.h"

namespace tvm {
namespace meta_schedule {

using namespace tvm::tir;

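/*!
 * \brief Tile and unroll the loops of a Winograd data-pack / inverse block.
 *
 * The block is expected to carry exactly 6 loops. The two loops selected by `tiled` are each
 * split into an (outer, inner) pair using sampled perfect tiles; the four loops selected by
 * `unrolled` are unrolled; and the loops are reordered so that the tile loops come first.
 *
 * \return The tile loops in the order {outer0, outer1, inner0, inner1}.
 */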
static Array<tir::LoopRV> ScheduleDataPack(tir::Schedule sch, tir::BlockRV block,
                                           std::vector<int> tiled, std::vector<int> unrolled) {
  // This method is used for the NHWC layout only. It will likely be refactored into a more
  // general schedule.
  using namespace tvm::tir;
  ICHECK_EQ(tiled.size(), 2);
  ICHECK_EQ(unrolled.size(), 4);
  Array<ExprRV> factors{nullptr};
  Array<LoopRV> loops = sch->GetLoops(block);
  ICHECK_EQ(loops.size(), 6);

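  // Split each of the two tiled loops into an (outer, inner) pair, sampling a perfect tile with
  // an innermost factor of at most 64.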
  factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64);
  Array<LoopRV> t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()});
  ICHECK_EQ(t0.size(), 2);

  factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64);
  Array<LoopRV> t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()});
  ICHECK_EQ(t1.size(), 2);

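  // Unroll the four requested loops and reorder so that the tile loops (outer0, outer1, inner0,
  // inner1) come first and the unrolled loops are innermost.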
  sch->Unroll(loops[unrolled[0]]);
  sch->Unroll(loops[unrolled[1]]);
  sch->Unroll(loops[unrolled[2]]);
  sch->Unroll(loops[unrolled[3]]);
  sch->Reorder({
      t0[0],
      t1[0],
      t0[1],
      t1[1],
      loops[unrolled[0]],
      loops[unrolled[1]],
      loops[unrolled[2]],
      loops[unrolled[3]],
  });
  return {t0[0], t1[0], t0[1], t1[1]};
}

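// Schedule rule for the data-pack stage of conv2d NHWC Winograd: tile and unroll the data-pack
// block, stage its output and the input tile in local memory, inline the padding, and bind the
// fused outer tile loops to GPU blocks and threads via BindSpatialLoop.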
TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack")
    .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
      BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack);
      BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile);
      Array<LoopRV> loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5});
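      // Stage the data-pack output in a local cache, written back at the innermost tile loop.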
      {
        BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local");
        sch->ReverseComputeAt(data_pack_local, loops.back(), /*preserve_unit_loops=*/true);
      }
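      // Compute the input tile at the same loop in local scope and inline the padding stage.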
      {
        sch->ComputeAt(input_tile, /*loop_rv=*/loops.back(), /*preserve_unit_loops=*/true);
        sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local");
        sch->ComputeInline(data_pad);
      }
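      // Fuse the four tile loops and bind them to GPU blocks and threads.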
      {
        int64_t max_threadblocks = 256;
        int64_t max_threads_per_block = 1024;
        Array<LoopRV> loops = sch->GetLoops(data_pack);
        ICHECK_EQ(loops.size(), 8);
        BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks,
                        max_threads_per_block);
      }
      return {sch};
    });

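// Schedule rule for the inverse-transform stage of conv2d NHWC Winograd: tile and unroll the
// inverse block, then bind the fused outer tile loops to GPU blocks and threads.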
TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse")
    .set_body_typed([](Schedule sch, BlockRV inverse) -> Array<Schedule> {
      GetWinogradProducerAndInlineConst(sch, inverse);
      ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5});
      int64_t max_threadblocks = 256;
      int64_t max_threads_per_block = 1024;
      Array<LoopRV> loops = sch->GetLoops(inverse);
      ICHECK_EQ(loops.size(), 8);
      BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks,
                      max_threads_per_block);
      return {sch};
    });

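// Schedule rule for the data-pack stage of conv2d NCHW Winograd: unroll the four transform
// loops, fuse and bind the remaining two loops to GPU blocks and threads, stage the data-pack
// output and the input tile in local memory, and inline the padding.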
TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack")
    .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
      int64_t max_threadblocks = 256;
      int64_t max_threads_per_block = 1024;
      BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack);
      BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile);
      LoopRV outer{nullptr};
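      // Hoist loops 2 and 3 to the front, unroll the other four loops, then fuse the hoisted pair
      // and bind it to GPU blocks and threads; `outer` is the innermost of the bound loops.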
      {
        Array<LoopRV> loops = sch->GetLoops(data_pack);
        ICHECK_EQ(loops.size(), 6);
        sch->Reorder({loops[2], loops[3], loops[0], loops[1], loops[4], loops[5]});
        sch->Unroll(loops[0]);
        sch->Unroll(loops[1]);
        sch->Unroll(loops[4]);
        sch->Unroll(loops[5]);
        outer = BindSpatialLoop(sch, sch->Fuse({loops[2], loops[3]}), max_threadblocks,
                                max_threads_per_block, /*get_factor=*/nullptr)
                    .back();
      }
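      // Stage the data-pack output in a local cache, written back under `outer`.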
      {
        BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local");
        sch->ReverseComputeAt(data_pack_local, outer, /*preserve_unit_loops=*/true);
      }
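      // Compute the input tile under `outer` in local scope and inline the padding stage.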
      {
        sch->ComputeAt(input_tile, /*loop_rv=*/outer, /*preserve_unit_loops=*/true);
        sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local");
        sch->ComputeInline(data_pad);
      }
      return {sch};
    });

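// Schedule rule for the inverse-transform stage of conv2d NCHW Winograd: tile the output's
// spatial loops by the Winograd tile size, compute the inverse transform at the tile loop in
// local scope, and unroll its innermost loops.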
TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse")
    .set_body_typed([](Schedule sch, BlockRV inverse) -> Array<Schedule> {
      GetWinogradProducerAndInlineConst(sch, inverse);
      // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha]
      int64_t tile_size = Downcast<IntImm>(sch->Get(inverse)->writes[0]->buffer->shape[2])->value;
      LoopRV outer{nullptr};
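      // Split the consumer's H and W loops by the Winograd tile size and reorder them to
      // (h_outer, w_outer, h_inner, w_inner); `outer` is the w_outer loop.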
      {
        BlockRV output = sch->GetConsumers(inverse)[0];
        Array<LoopRV> nchw = sch->GetLoops(output);
        ICHECK_EQ(nchw.size(), 4);
        Array<LoopRV> hs = sch->Split(nchw[2], {NullOpt, Integer(tile_size)});
        Array<LoopRV> ws = sch->Split(nchw[3], {NullOpt, Integer(tile_size)});
        sch->Reorder({hs[0], ws[0], hs[1], ws[1]});
        outer = ws[0];
      }
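      // Compute the inverse transform at `outer` in local scope and unroll its four innermost
      // loops ([tile_size, tile_size, alpha, alpha]).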
      {
        sch->ComputeAt(inverse, /*loop_rv=*/outer, /*preserve_unit_loops=*/true);
        sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local");
        Array<LoopRV> loops = sch->GetLoops(inverse);
        ICHECK_EQ(loops.size(), 10);
        sch->Unroll(loops[6]);
        sch->Unroll(loops[7]);
        sch->Unroll(loops[8]);
        sch->Unroll(loops[9]);
      }
      return {sch};
    });

}  // namespace meta_schedule
}  // namespace tvm