/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/meta_schedule/schedule/cuda/thread_bind.h>
#include <tvm/meta_schedule/schedule/generic/winograd.h>

#include <vector>

#include "../../utils.h"

namespace tvm {
namespace meta_schedule {

using namespace tvm::tir;

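/*!
 * \brief Tile, unroll, and reorder the 6 loops of a Winograd data-pack / inverse block.
 * \param sch The schedule to apply the transformations to.
 * \param block The block to schedule; it is expected to have exactly 6 loops.
 * \param tiled Indices of the 2 loops to split with sampled perfect-tile factors.
 * \param unrolled Indices of the 4 loops to fully unroll.
 * \return The tiled loops in the order {outer0, outer1, inner0, inner1}.
 */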
static Array<tir::LoopRV> ScheduleDataPack(tir::Schedule sch, tir::BlockRV block,
                                           std::vector<int> tiled, std::vector<int> unrolled) {
  // This method is used for the NHWC layout only; it will likely be refactored into a more
  // general schedule helper.
  using namespace tvm::tir;
  ICHECK_EQ(tiled.size(), 2);
  ICHECK_EQ(unrolled.size(), 4);
  Array<ExprRV> factors{nullptr};
  Array<LoopRV> loops = sch->GetLoops(block);
  ICHECK_EQ(loops.size(), 6);

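  // Split each of the two tiled loops into an (outer, inner) pair using sampled perfect-tile
  // factors, with the innermost factor capped at 64.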
  factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64);
  Array<LoopRV> t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()});
  ICHECK_EQ(t0.size(), 2);

  factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64);
  Array<LoopRV> t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()});
  ICHECK_EQ(t1.size(), 2);

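  // Fully unroll the four short loops and reorder the nest into
  // [outer0, outer1, inner0, inner1, unrolled...].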
  sch->Unroll(loops[unrolled[0]]);
  sch->Unroll(loops[unrolled[1]]);
  sch->Unroll(loops[unrolled[2]]);
  sch->Unroll(loops[unrolled[3]]);
  sch->Reorder({
      t0[0],
      t1[0],
      t0[1],
      t1[1],
      loops[unrolled[0]],
      loops[unrolled[1]],
      loops[unrolled[2]],
      loops[unrolled[3]],
  });
  return {t0[0], t1[0], t0[1], t1[1]};
}

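// Schedule rule for the data-packing stage of Winograd conv2d in NHWC layout on CUDA.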
TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack")
    .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
      BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack);
      BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile);
      Array<LoopRV> loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5});
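      // Write the packed data through a "local"-scope cache and attach the cache write under
      // the innermost tiled loop.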
      {
        BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local");
        sch->ReverseComputeAt(data_pack_local, loops.back(), /*preserve_unit_loops=*/true);
      }
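      // Compute the input tile under the same loop, keep it in local scope, and inline the
      // padding block into it.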
      {
        sch->ComputeAt(input_tile, /*loop_rv=*/loops.back(), /*preserve_unit_loops=*/true);
        sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local");
        sch->ComputeInline(data_pad);
      }
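      // Fuse the four outer loops and bind the fused loop to blockIdx.x / threadIdx.x, subject
      // to the thread-block limits below.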
      {
        int64_t max_threadblocks = 256;
        int64_t max_threads_per_block = 1024;
        Array<LoopRV> loops = sch->GetLoops(data_pack);
        ICHECK_EQ(loops.size(), 8);
        BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks,
                        max_threads_per_block);
      }
      return {sch};
    });

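// Schedule rule for the inverse-transform stage of Winograd conv2d in NHWC layout on CUDA.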
TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse")
    .set_body_typed([](Schedule sch, BlockRV inverse) -> Array<Schedule> {
      GetWinogradProducerAndInlineConst(sch, inverse);
      ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5});
      int64_t max_threadblocks = 256;
      int64_t max_threads_per_block = 1024;
      Array<LoopRV> loops = sch->GetLoops(inverse);
      ICHECK_EQ(loops.size(), 8);
      BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks,
                      max_threads_per_block);
      return {sch};
    });

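// Schedule rule for the data-packing stage of Winograd conv2d in NCHW layout on CUDA.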
TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack")
    .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
      int64_t max_threadblocks = 256;
      int64_t max_threads_per_block = 1024;
      BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack);
      BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile);
      LoopRV outer{nullptr};
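      // Reorder loops 2 and 3 to the outermost positions, unroll the remaining short loops,
      // then fuse the outer pair and bind it to blockIdx.x / threadIdx.x; the last of the
      // bound loops is kept as the compute-at anchor for the blocks below.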
      {
        Array<LoopRV> loops = sch->GetLoops(data_pack);
        ICHECK_EQ(loops.size(), 6);
        sch->Reorder({loops[2], loops[3], loops[0], loops[1], loops[4], loops[5]});
        sch->Unroll(loops[0]);
        sch->Unroll(loops[1]);
        sch->Unroll(loops[4]);
        sch->Unroll(loops[5]);
        outer = BindSpatialLoop(sch, sch->Fuse({loops[2], loops[3]}), max_threadblocks,
                                max_threads_per_block, /*get_factor=*/nullptr)
                    .back();
      }
      {
        BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local");
        sch->ReverseComputeAt(data_pack_local, outer, /*preserve_unit_loops=*/true);
      }
      {
        sch->ComputeAt(input_tile, /*loop_rv=*/outer, /*preserve_unit_loops=*/true);
        sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local");
        sch->ComputeInline(data_pad);
      }
      return {sch};
    });

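// Schedule rule for the inverse-transform stage of Winograd conv2d in NCHW layout on CUDA.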
TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse")
    .set_body_typed([](Schedule sch, BlockRV inverse) -> Array<Schedule> {
      GetWinogradProducerAndInlineConst(sch, inverse);
      // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha]
      int64_t tile_size = Downcast<IntImm>(sch->Get(inverse)->writes[0]->buffer->shape[2])->value;
      LoopRV outer{nullptr};
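      // Tile the consumer's H and W loops by tile_size so that the inverse block can later be
      // computed once per output tile.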
      {
        BlockRV output = sch->GetConsumers(inverse)[0];
        Array<LoopRV> nchw = sch->GetLoops(output);
        ICHECK_EQ(nchw.size(), 4);
        Array<LoopRV> hs = sch->Split(nchw[2], {NullOpt, Integer(tile_size)});
        Array<LoopRV> ws = sch->Split(nchw[3], {NullOpt, Integer(tile_size)});
        sch->Reorder({hs[0], ws[0], hs[1], ws[1]});
        outer = ws[0];
      }
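      // Compute the inverse transform per output tile, keep its result in local scope, and
      // unroll the four innermost loops.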
      {
        sch->ComputeAt(inverse, /*loop_rv=*/outer, /*preserve_unit_loops=*/true);
        sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local");
        Array<LoopRV> loops = sch->GetLoops(inverse);
        ICHECK_EQ(loops.size(), 10);
        sch->Unroll(loops[6]);
        sch->Unroll(loops[7]);
        sch->Unroll(loops[8]);
        sch->Unroll(loops[9]);
      }
      return {sch};
    });

}  // namespace meta_schedule
}  // namespace tvm