1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file cuda/pooling.h |
22 | * \brief CUDA schedule for pooling operations |
23 | */ |
24 | #ifndef TVM_TOPI_CUDA_POOLING_H_ |
25 | #define TVM_TOPI_CUDA_POOLING_H_ |
26 | |
27 | #include <tvm/target/generic_func.h> |
28 | #include <tvm/te/operation.h> |
29 | #include <tvm/te/schedule_pass.h> |
30 | #include <tvm/topi/detail/array_utils.h> |
31 | #include <tvm/topi/detail/fuse.h> |
32 | #include <tvm/topi/tags.h> |
33 | |
34 | namespace tvm { |
35 | namespace topi { |
36 | |
37 | using namespace tvm::te; |
38 | |
39 | namespace cuda { |
40 | |
41 | /*! |
42 | * \brief Create a CUDA schedule for pool |
43 | * |
44 | * \param target The target to generate a schedule for. |
45 | * \param outs The output tensors. |
46 | * |
47 | * \return A schedule for the given ops. |
48 | */ |
49 | inline Schedule schedule_pool(const Target& target, const Array<Tensor>& outs) { |
50 | Array<Operation> out_ops; |
51 | for (auto t : outs) { |
52 | out_ops.push_back(t->op); |
53 | } |
54 | auto s = create_schedule(out_ops); |
55 | |
56 | auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) { |
57 | if (padded_input->op->IsInstance<ComputeOpNode>()) { |
58 | s[padded_input].compute_inline(); |
59 | } |
60 | int num_thread = target->GetAttr<Integer>("max_num_threads" ).value().IntValue(); |
61 | Tensor out; |
62 | Tensor OL; |
63 | if (detail::contains(s->outputs, pool->op)) { |
64 | out = pool; |
65 | OL = s.cache_write(pool, "local" ); |
66 | } else { |
67 | out = outs[0]->op.output(0); |
68 | s[pool].set_scope("local" ); |
69 | } |
70 | auto fused = detail::Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis); |
71 | IterVar bx, tx; |
72 | s[out].split(fused, num_thread, &bx, &tx); |
73 | s[out].bind(bx, tvm::te::thread_axis(Range(), "blockIdx.x" )); |
74 | s[out].bind(tx, tvm::te::thread_axis(Range(), "threadIdx.x" )); |
75 | if (detail::contains(s->outputs, pool->op)) { |
76 | s[OL].compute_at(s[out], tx); |
77 | } else { |
78 | s[pool].compute_at(s[out], tx); |
79 | } |
80 | }; |
81 | |
82 | std::function<void(Operation)> traverse; |
83 | traverse = [&](const Operation& op) { |
84 | // Inline all one-to-one-mapping operators except the last stage (output) |
85 | if (is_broadcast(op->tag)) { |
86 | if (!detail::contains(s->outputs, op)) { |
87 | s[op].compute_inline(); |
88 | } |
89 | for (auto tensor : op->InputTensors()) { |
90 | if (tensor->op->InputTensors().size() > 0) { |
91 | traverse(tensor->op); |
92 | } |
93 | } |
94 | } else if (op->tag.rfind("pool" , 0) == 0) { |
95 | // If tag starts with pool |
96 | auto padded_input = op->InputTensors()[0]; |
97 | auto pool = op.output(0); |
98 | _schedule(padded_input, pool); |
99 | } else { |
100 | LOG(ERROR) << "Unsupported operator " << op->tag; |
101 | } |
102 | }; |
103 | |
104 | traverse(outs[0]->op); |
105 | return s; |
106 | } |
107 | |
108 | /*! |
109 | * \brief Create a CUDA schedule for global_pool |
110 | * |
111 | * \param target The target to generate a schedule for. |
112 | * \param outs The output tensors. |
113 | * |
114 | * \return A schedule for the given ops. |
115 | */ |
116 | inline Schedule schedule_global_pool(const Target& target, const Array<Tensor>& outs) { |
117 | Array<Operation> out_ops; |
118 | for (auto t : outs) { |
119 | out_ops.push_back(t->op); |
120 | } |
121 | auto s = create_schedule(out_ops); |
122 | |
123 | auto _schedule = [&](const Tensor& pool) { |
124 | auto num_thread = 8; |
125 | auto block_x = tvm::te::thread_axis(Range(), "blockIdx.x" ); |
126 | auto block_y = tvm::te::thread_axis(Range(), "blockIdx.y" ); |
127 | auto thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x" ); |
128 | auto thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y" ); |
129 | Tensor out; |
130 | Tensor OL; |
131 | if (detail::contains(s->outputs, pool->op)) { |
132 | out = pool; |
133 | OL = s.cache_write(pool, "local" ); |
134 | } else { |
135 | out = outs[0]->op.output(0); |
136 | s[pool].set_scope("local" ); |
137 | } |
138 | |
139 | auto i = s[out]->op.as<ComputeOpNode>()->axis[0]; |
140 | auto c = s[out]->op.as<ComputeOpNode>()->axis[1]; |
141 | |
142 | IterVar by, ty; |
143 | s[out].split(i, num_thread, &by, &ty); |
144 | IterVar bx, tx; |
145 | s[out].split(c, num_thread, &bx, &tx); |
146 | s[out].reorder({by, bx, ty, tx}); |
147 | s[out].bind(ty, thread_y); |
148 | s[out].bind(tx, thread_x); |
149 | s[out].bind(by, block_y); |
150 | s[out].bind(bx, block_x); |
151 | |
152 | if (detail::contains(s->outputs, pool->op)) { |
153 | s[OL].compute_at(s[out], tx); |
154 | } else { |
155 | s[pool].compute_at(s[out], tx); |
156 | } |
157 | }; |
158 | |
159 | std::function<void(Operation)> traverse; |
160 | traverse = [&](const Operation& op) { |
161 | // Inline all one-to-one-mapping operators except the last stage (output) |
162 | if (is_broadcast(op->tag)) { |
163 | if (!detail::contains(s->outputs, op)) { |
164 | s[op].compute_inline(); |
165 | } |
166 | for (auto tensor : op->InputTensors()) { |
167 | if (tensor->op->InputTensors().size() > 0) { |
168 | traverse(tensor->op); |
169 | } |
170 | } |
171 | } else if (op->tag.rfind("global_pool" , 0) == 0) { |
172 | // If tag starts with global_pool |
173 | auto pool = op.output(0); |
174 | _schedule(pool); |
175 | } else { |
176 | LOG(ERROR) << "Unsupported operator " << op->tag; |
177 | } |
178 | }; |
179 | |
180 | traverse(outs[0]->op); |
181 | return s; |
182 | } |
183 | |
184 | } // namespace cuda |
185 | } // namespace topi |
186 | } // namespace tvm |
187 | #endif // TVM_TOPI_CUDA_POOLING_H_ |
188 | |