/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda/reduction.h
 * \brief CUDA schedule for reduction operations
 */
#ifndef TVM_TOPI_CUDA_REDUCTION_H_
#define TVM_TOPI_CUDA_REDUCTION_H_

#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/detail/fuse.h>
#include <tvm/topi/tags.h>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace cuda {
/*!
 * \brief Schedule a given reduce operation.
 *
 * \param target The target to generate a schedule for.
 * \param op The operation representing the reduce operation.
 * \param sch The schedule to apply this scheduling to.
 * \param is_idx_reduce Pass true to schedule a reduce op that returns
 * an index, such as argmax or argmin.
 *
 * \return The schedule given by sch
 */
Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
                        bool is_idx_reduce = false) {
  Tensor data_out;
  Tensor data_in;

  if (!is_idx_reduce) {
    data_in = op->InputTensors()[0];
    data_out = op.output(0);
  } else {
    data_out = op->InputTensors()[0];
  }

  auto out_stage = sch[data_out];
  ICHECK_GT(out_stage->op.as<ComputeOpNode>()->reduce_axis.size(), 0)
      << "reduce_axis must not be empty";

  bool all_reduce;
  int num_thread;
  IterVar block_x, thread_x, thread_y;

  if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
    all_reduce = false;
    num_thread = 32;
    if (target->kind->name == "opencl" || target->kind->name == "metal") {
      // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
      // The cause is not yet understood.
      num_thread = 16;
    }
    block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
    thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
    thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y");
  } else {
    all_reduce = true;
    num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
    thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
  }

  auto fused_reduce = detail::Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);

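  // Cross-thread reduction: split the fused reduction axis by num_thread and
  // rfactor over the inner part, so each threadIdx.x lane accumulates a
  // partial result (data_out_rf) that the final stage combines across threads.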
  IterVar ko, ki;
  out_stage.split(fused_reduce, num_thread, &ko, &ki);
  auto data_out_rf = sch.rfactor(data_out, ki)[0];
  auto tx = out_stage->op.as<ComputeOpNode>()->reduce_axis[0];
  out_stage.bind(tx, thread_x);
  sch[data_out_rf].compute_at(out_stage, tx);

  Tensor real_output;
  Tensor temp_idx_input, temp_val_input;
  if (is_idx_reduce) {
    real_output = op.output(0);
    temp_idx_input = data_out->op.output(0);
    temp_val_input = data_out->op.output(1);
  } else {
    real_output = data_out;
  }

  auto stage_real = sch[real_output];
  if (!all_reduce) {
    // Fuse and split the axis
    auto fused_outer = detail::Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
    IterVar bx, outer_in;
    stage_real.split(fused_outer, num_thread, &bx, &outer_in);

    // Bind the axes to threads and blocks
    stage_real.bind(outer_in, thread_y);
    stage_real.bind(bx, block_x);
    if (is_idx_reduce) {
      sch[temp_idx_input].compute_at(stage_real, outer_in);
      sch[temp_val_input].compute_at(stage_real, outer_in);
    }
  } else {
    if (is_idx_reduce) {
      sch[temp_idx_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
      sch[temp_val_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
    }
  }

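  // Only thread 0 holds the fully combined value, so restrict the store to it.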
  stage_real.set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
  return sch;
}

/*!
 * \brief Recursively traverse operator inputs, setting injective inputs
 * to be computed inline.
 *
 * \param s The schedule we are building
 * \param op The current op in the traversal
 */
void TraverseBeforeReduce(Schedule s, Operation op) {
  if (op->IsInstance<PlaceholderOpNode>()) {
    return;
  } else if (is_injective(op->tag)) {
    s[op].compute_inline();
    for (auto tensor : op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else {
    LOG(ERROR) << "Unsupported operator " << op->tag;
  }
}

/*!
 * \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each
 * of the op's inputs.
 *
 * \param target The target to generate a schedule for.
 * \param s The schedule we are building
 * \param op The reduce op
 */
void TraverseAfterReduce(const Target& target, Schedule s, Operation op) {
  if (is_broadcast(op->tag)) {
    LOG(ERROR) << "Elementwise op after reduce is not yet supported";
  } else if (op->tag == kCommReduce) {
    ScheduleReduce(target, op, s, false);
    for (auto tensor : op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else if (op->tag == kCommReduceIdx) {
    ScheduleReduce(target, op, s, true);
    for (auto tensor : op->InputTensors()[0]->op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else {
    LOG(ERROR) << "Unsupported operator " << op->tag;
  }
}

/*!
 * \brief Create a CUDA schedule for a reduce operation.
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
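 *
 * \code
 * // Illustrative usage (a minimal sketch, not part of this header). It assumes
 * // topi::sum from <tvm/topi/reduction.h>, which this header does not include:
 * te::Tensor data = te::placeholder({1024, 1024}, DataType::Float(32), "data");
 * te::Tensor red = topi::sum(data, {1});  // row-wise sum over the last axis
 * Target target("cuda");
 * te::Schedule sch = topi::cuda::schedule_reduce(target, {red});
 * \endcode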
 */
Schedule schedule_reduce(const Target& target, Array<Tensor> outs) {
  ICHECK_EQ(outs.size(), 1) << "outs must have size 1";
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);
  TraverseAfterReduce(target, s, outs[0]->op);
  return s;
}

}  // namespace cuda
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_CUDA_REDUCTION_H_