/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda/reduction.h
 * \brief CUDA schedule for reduction operations
 */
#ifndef TVM_TOPI_CUDA_REDUCTION_H_
#define TVM_TOPI_CUDA_REDUCTION_H_

#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/detail/fuse.h>
#include <tvm/topi/tags.h>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace cuda {
/*!
 * \brief Schedule a given reduce operation.
 *
 * \param target The target to generate a schedule for.
 * \param op The operation representing the reduce computation to schedule.
 * \param sch The schedule to apply this scheduling to.
 * \param is_idx_reduce Pass true to schedule a reduce op that returns
 * an index, such as argmax or argmin.
 *
 * \return The schedule given by sch.
 */
Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
                        bool is_idx_reduce = false) {
  Tensor data_out;
  Tensor data_in;

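  // For argmax/argmin (kCommReduceIdx), `op` is the final index-extraction
  // step; the stage we actually schedule is its input, the fused
  // (index, value) reduction.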
  if (!is_idx_reduce) {
    data_in = op->InputTensors()[0];
    data_out = op.output(0);
  } else {
    data_out = op->InputTensors()[0];
  }

  auto out_stage = sch[data_out];
  ICHECK_GT(out_stage->op.as<ComputeOpNode>()->reduce_axis.size(), 0)
      << "reduce_axis must not be empty";

  bool all_reduce;
  int num_thread;
  IterVar block_x, thread_x, thread_y;

  if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
    all_reduce = false;
    num_thread = 32;
    if (target->kind->name == "opencl" || target->kind->name == "metal") {
      // Without this, CL_INVALID_WORK_GROUP_SIZE occurs in the Python tests
      // (the exact cause is unclear).
      num_thread = 16;
    }
    block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
    thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
    thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y");
  } else {
    all_reduce = true;
    num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
    thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
  }

  auto fused_reduce = detail::Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);

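  // Split the fused reduction axis by num_thread and rfactor the inner part:
  // each thread accumulates a partial result over the outer chunks, and the
  // partials are combined by a cross-thread reduction bound to threadIdx.x.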
  IterVar ko, ki;
  out_stage.split(fused_reduce, num_thread, &ko, &ki);
  auto data_out_rf = sch.rfactor(data_out, ki)[0];
  auto tx = out_stage->op.as<ComputeOpNode>()->reduce_axis[0];
  out_stage.bind(tx, thread_x);
  sch[data_out_rf].compute_at(out_stage, tx);

  Tensor real_output;
  Tensor temp_idx_input, temp_val_input;
  if (is_idx_reduce) {
    real_output = op.output(0);
    temp_idx_input = data_out->op.output(0);
    temp_val_input = data_out->op.output(1);
  } else {
    real_output = data_out;
  }

  auto stage_real = sch[real_output];
  if (!all_reduce) {
    // Fuse and split the axis
    auto fused_outer = detail::Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
    IterVar bx, outer_in;
    stage_real.split(fused_outer, num_thread, &bx, &outer_in);

    // Bind the axes to threads and blocks
    stage_real.bind(outer_in, thread_y);
    stage_real.bind(bx, block_x);
    if (is_idx_reduce) {
      sch[temp_idx_input].compute_at(stage_real, outer_in);
      sch[temp_val_input].compute_at(stage_real, outer_in);
    }
  } else {
    if (is_idx_reduce) {
      sch[temp_idx_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
      sch[temp_val_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
    }
  }

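  // Only the first lane of the cross-thread reduction (threadIdx.x == 0)
  // stores the final reduced value.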
  stage_real.set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
  return sch;
}

/*!
 * \brief Recursively traverse operator inputs, setting injective inputs
 * to be computed inline.
 *
 * \param s The schedule we are building
 * \param op The current op in the traversal
 */
void TraverseBeforeReduce(Schedule s, Operation op) {
  if (op->IsInstance<PlaceholderOpNode>()) {
    return;
  } else if (is_injective(op->tag)) {
    s[op].compute_inline();
    for (auto tensor : op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else {
    LOG(ERROR) << "Unsupported operator " << op->tag;
  }
}

/*!
 * \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each
 * of the op's inputs.
 *
 * \param target The target to generate a schedule for.
 * \param s The schedule we are building
 * \param op The reduce op
 */
void TraverseAfterReduce(const Target& target, Schedule s, Operation op) {
  if (is_broadcast(op->tag)) {
    LOG(ERROR) << "Elementwise op after reduce is not yet supported";
  } else if (op->tag == kCommReduce) {
    ScheduleReduce(target, op, s, false);
    for (auto tensor : op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else if (op->tag == kCommReduceIdx) {
    ScheduleReduce(target, op, s, true);
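    // For argmax/argmin the reduce op's direct input is the temporary
    // (index, value) tensor, so step through it to reach the real producers.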
    for (auto tensor : op->InputTensors()[0]->op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else {
    LOG(ERROR) << "Unsupported operator " << op->tag;
  }
}

/*!
 * \brief Create a CUDA schedule for a reduce operation.
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
Schedule schedule_reduce(const Target& target, Array<Tensor> outs) {
  ICHECK_EQ(outs.size(), 1) << "outs must have size 1";
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);
  TraverseAfterReduce(target, s, outs[0]->op);
  return s;
}
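
// Usage sketch (illustrative only, not exercised by this header): one way to
// drive schedule_reduce from C++, assuming the usual te/topi entry points
// (tvm::te::placeholder, topi::sum, Target). Shapes and names are made up.
//
//   tvm::te::Tensor A = tvm::te::placeholder({1024, 1024}, tvm::DataType::Float(32), "A");
//   tvm::te::Tensor B = tvm::topi::sum(A, {1});  // reduce over the second axis
//   tvm::Target target("cuda");
//   tvm::te::Schedule s = tvm::topi::cuda::schedule_reduce(target, {B});
//   // `s` can then be lowered and built for the CUDA target as usual.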

} // namespace cuda
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_CUDA_REDUCTION_H_