1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/meta_schedule/schedule/generic/winograd.h>
20
21#include "../../utils.h"
22
23namespace tvm {
24namespace meta_schedule {
25
26using namespace tvm::tir;
27
28static Array<tir::LoopRV> ScheduleDataPack(tir::Schedule sch, tir::BlockRV block,
29 std::vector<int> tiled, std::vector<int> unrolled) {
30 using namespace tvm::tir;
31 ICHECK_EQ(tiled.size(), 2);
32 ICHECK_EQ(unrolled.size(), 4);
33 Array<ExprRV> factors{nullptr};
34 Array<LoopRV> loops = sch->GetLoops(block);
35 ICHECK_EQ(loops.size(), 6);
36
37 factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64);
38 Array<LoopRV> t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()});
39 ICHECK_EQ(t0.size(), 2);
40
41 factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64);
42 Array<LoopRV> t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()});
43 ICHECK_EQ(t1.size(), 2);
44
45 sch->Unroll(loops[unrolled[0]]);
46 sch->Unroll(loops[unrolled[1]]);
47 sch->Unroll(loops[unrolled[2]]);
48 sch->Unroll(loops[unrolled[3]]);
49 sch->Reorder({
50 t0[0],
51 t1[0],
52 t0[1],
53 t1[1],
54 loops[unrolled[0]],
55 loops[unrolled[1]],
56 loops[unrolled[2]],
57 loops[unrolled[3]],
58 });
59 return {t0[0], t1[0], t0[1], t1[1]};
60}
61
62TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack")
63 .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
64 BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack);
65 BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile);
66 ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5});
67 sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile),
68 /*preserve_unit_loops=*/true);
69 sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad),
70 /*preserve_unit_loops=*/true);
71 return {sch};
72 });
73
74TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_inverse")
75 .set_body_typed([](Schedule sch, BlockRV block) -> Array<Schedule> {
76 GetWinogradProducerAndInlineConst(sch, block);
77 ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5});
78 return {sch};
79 });
80
81TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack")
82 .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
83 BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack);
84 BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile);
85 ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5});
86 sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile),
87 /*preserve_unit_loops=*/true);
88 sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad),
89 /*preserve_unit_loops=*/true);
90 return {sch};
91 });
92
93TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_inverse")
94 .set_body_typed([](Schedule sch, BlockRV block) -> Array<Schedule> {
95 GetWinogradProducerAndInlineConst(sch, block);
96 ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5});
97 return {sch};
98 });
99
100} // namespace meta_schedule
101} // namespace tvm
102