1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/meta_schedule/schedule/generic/winograd.h> |
20 | |
21 | #include "../../utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace meta_schedule { |
25 | |
26 | using namespace tvm::tir; |
27 | |
28 | static Array<tir::LoopRV> ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, |
29 | std::vector<int> tiled, std::vector<int> unrolled) { |
30 | using namespace tvm::tir; |
31 | ICHECK_EQ(tiled.size(), 2); |
32 | ICHECK_EQ(unrolled.size(), 4); |
33 | Array<ExprRV> factors{nullptr}; |
34 | Array<LoopRV> loops = sch->GetLoops(block); |
35 | ICHECK_EQ(loops.size(), 6); |
36 | |
37 | factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); |
38 | Array<LoopRV> t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); |
39 | ICHECK_EQ(t0.size(), 2); |
40 | |
41 | factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64); |
42 | Array<LoopRV> t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); |
43 | ICHECK_EQ(t1.size(), 2); |
44 | |
45 | sch->Unroll(loops[unrolled[0]]); |
46 | sch->Unroll(loops[unrolled[1]]); |
47 | sch->Unroll(loops[unrolled[2]]); |
48 | sch->Unroll(loops[unrolled[3]]); |
49 | sch->Reorder({ |
50 | t0[0], |
51 | t1[0], |
52 | t0[1], |
53 | t1[1], |
54 | loops[unrolled[0]], |
55 | loops[unrolled[1]], |
56 | loops[unrolled[2]], |
57 | loops[unrolled[3]], |
58 | }); |
59 | return {t0[0], t1[0], t0[1], t1[1]}; |
60 | } |
61 | |
62 | TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack" ) |
63 | .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> { |
64 | BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); |
65 | BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); |
66 | ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); |
67 | sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), |
68 | /*preserve_unit_loops=*/true); |
69 | sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), |
70 | /*preserve_unit_loops=*/true); |
71 | return {sch}; |
72 | }); |
73 | |
74 | TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_inverse" ) |
75 | .set_body_typed([](Schedule sch, BlockRV block) -> Array<Schedule> { |
76 | GetWinogradProducerAndInlineConst(sch, block); |
77 | ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); |
78 | return {sch}; |
79 | }); |
80 | |
81 | TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack" ) |
82 | .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> { |
83 | BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); |
84 | BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); |
85 | ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); |
86 | sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), |
87 | /*preserve_unit_loops=*/true); |
88 | sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), |
89 | /*preserve_unit_loops=*/true); |
90 | return {sch}; |
91 | }); |
92 | |
93 | TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_inverse" ) |
94 | .set_body_typed([](Schedule sch, BlockRV block) -> Array<Schedule> { |
95 | GetWinogradProducerAndInlineConst(sch, block); |
96 | ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); |
97 | return {sch}; |
98 | }); |
99 | |
100 | } // namespace meta_schedule |
101 | } // namespace tvm |
102 | |