/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda/pooling.h
 * \brief CUDA schedule for pooling operations
 */
#ifndef TVM_TOPI_CUDA_POOLING_H_
#define TVM_TOPI_CUDA_POOLING_H_

#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/detail/array_utils.h>
#include <tvm/topi/detail/fuse.h>
#include <tvm/topi/tags.h>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace cuda {

/*!
 * \brief Create a CUDA schedule for pool
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_pool(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) {
    // Inline the padding stage into the pooling computation when possible.
    if (padded_input->op->IsInstance<ComputeOpNode>()) {
      s[padded_input].compute_inline();
    }
    int num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
    Tensor out;
    Tensor OL;
    if (detail::contains(s->outputs, pool->op)) {
      // The pool stage is itself an output: write it through a local cache stage.
      out = pool;
      OL = s.cache_write(pool, "local");
    } else {
      // The pool stage feeds a later output stage: keep it in local scope and
      // schedule around the final output instead.
      out = outs[0]->op.output(0);
      s[pool].set_scope("local");
    }
    // Fuse all output axes, then split the fused axis across thread blocks and threads.
    auto fused = detail::Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
    IterVar bx, tx;
    s[out].split(fused, num_thread, &bx, &tx);
    s[out].bind(bx, tvm::te::thread_axis(Range(), "blockIdx.x"));
    s[out].bind(tx, tvm::te::thread_axis(Range(), "threadIdx.x"));
    if (detail::contains(s->outputs, pool->op)) {
      s[OL].compute_at(s[out], tx);
    } else {
      s[pool].compute_at(s[out], tx);
    }
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output)
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag.rfind("pool", 0) == 0) {
      // If tag starts with pool
      auto padded_input = op->InputTensors()[0];
      auto pool = op.output(0);
      _schedule(padded_input, pool);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
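
// Illustrative usage (a sketch, not part of the scheduling logic in this
// header): given a pooling compute whose op tag starts with "pool", for
// example one built with the topi pooling ops declared in
// <tvm/topi/nn/pooling.h>, the schedule can be requested directly. The
// tensor `pool_out` below is a hypothetical placeholder for such an output.
//
//   Target target = Target("cuda");
//   Array<Tensor> outs = {pool_out};
//   Schedule s = schedule_pool(target, outs);
//   // `s` can then be lowered and built through the usual TVM build flow.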

/*!
 * \brief Create a CUDA schedule for global_pool
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_global_pool(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  auto _schedule = [&](const Tensor& pool) {
    auto num_thread = 8;
    auto block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
    auto block_y = tvm::te::thread_axis(Range(), "blockIdx.y");
    auto thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
    auto thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y");
    Tensor out;
    Tensor OL;
    if (detail::contains(s->outputs, pool->op)) {
      // The pool stage is itself an output: write it through a local cache stage.
      out = pool;
      OL = s.cache_write(pool, "local");
    } else {
      // The pool stage feeds a later output stage: keep it in local scope and
      // schedule around the final output instead.
      out = outs[0]->op.output(0);
      s[pool].set_scope("local");
    }

    // Split the first two output axes and map them to blockIdx.y/threadIdx.y
    // and blockIdx.x/threadIdx.x respectively.
    auto i = s[out]->op.as<ComputeOpNode>()->axis[0];
    auto c = s[out]->op.as<ComputeOpNode>()->axis[1];

    IterVar by, ty;
    s[out].split(i, num_thread, &by, &ty);
    IterVar bx, tx;
    s[out].split(c, num_thread, &bx, &tx);
    s[out].reorder({by, bx, ty, tx});
    s[out].bind(ty, thread_y);
    s[out].bind(tx, thread_x);
    s[out].bind(by, block_y);
    s[out].bind(bx, block_x);

    if (detail::contains(s->outputs, pool->op)) {
      s[OL].compute_at(s[out], tx);
    } else {
      s[pool].compute_at(s[out], tx);
    }
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output)
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag.rfind("global_pool", 0) == 0) {
      // If tag starts with global_pool
      auto pool = op.output(0);
      _schedule(pool);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
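
// Rough sketch of the loop nest _schedule above produces, assuming the pool
// output is laid out with batch on axis 0 and channel on axis 1 (which is
// what the axis[0]/axis[1] splits rely on). Each block covers an 8x8 tile of
// the (batch, channel) space, with one element per thread:
//
//   for (by : blockIdx.y)          // batch   / num_thread
//     for (bx : blockIdx.x)        // channel / num_thread
//       for (ty : threadIdx.y)     // num_thread = 8
//         for (tx : threadIdx.x)   // num_thread = 8
//           global pool reduction (local stage attached at tx)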

}  // namespace cuda
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_CUDA_POOLING_H_