1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file cuda/softmax.h
 * \brief CUDA schedule for softmax operations
 */
24#ifndef TVM_TOPI_CUDA_SOFTMAX_H_
25#define TVM_TOPI_CUDA_SOFTMAX_H_
26
27#include <tvm/target/generic_func.h>
28#include <tvm/te/operation.h>
29#include <tvm/te/schedule_pass.h>
30#include <tvm/topi/detail/fuse.h>
31#include <tvm/topi/tags.h>
32
33namespace tvm {
34namespace topi {
35
36using namespace tvm::te;
37
38namespace cuda {
39
40/*!
41 * \brief Create a CUDA schedule for the given softmax output tensors.
42 *
43 * \param target The target to generate a schedule for.
44 * \param outs The output tensors.
45 *
46 * \return A schedule for the given ops.
47 */
48inline Schedule schedule_softmax(const Target& target, const Array<Tensor>& outs) {
49 Array<Operation> out_ops;
50 for (auto t : outs) {
51 out_ops.push_back(t->op);
52 }
53 auto s = create_schedule(out_ops);
54
55 auto softmax = outs[0];
56 tvm::te::Tensor max_elem;
57 tvm::te::Tensor expsum;
58 tvm::te::Tensor exp;
59 bool has_exp = false;
60
61 auto tag = softmax->op.as<ComputeOpNode>()->tag;
62 if (tag == "softmax_output") {
63 expsum = softmax->op->InputTensors()[1];
64 exp = softmax->op->InputTensors()[0];
65 max_elem = s[exp]->op->InputTensors()[1];
66 has_exp = true;
67 } else if (tag == "log_softmax_output") {
68 max_elem = softmax->op->InputTensors()[1];
69 expsum = softmax->op->InputTensors()[2];
70 } else {
71 LOG(ERROR) << "Tag is expected to be softmax_output or log_softmax_output. Got " << tag;
72 }
73
74 int num_thread = 64;
75 auto block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
76 auto thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
77
78 if (has_exp) {
79 s[exp].bind(exp->op.as<ComputeOpNode>()->axis[0], block_x);
80 }
81
82 s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
83
84 auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
85 IterVar ko, ki;
86 s[expsum].split(k, num_thread, &ko, &ki);
87 auto EF = s.rfactor(expsum, ki)[0];
88 s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->axis[0], block_x);
89 s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
90 s[EF].compute_at(s[expsum], s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0]);
91 s[expsum].set_store_predicate(thread_x->var == 0);
92
93 IterVar tx, xi;
94 s[softmax].split_by_nparts(softmax->op.as<ComputeOpNode>()->axis[1], num_thread, &tx, &xi);
95 s[softmax].bind(tx, thread_x);
96
97 return s;
98}
99
100} // namespace cuda
101} // namespace topi
102} // namespace tvm
103#endif // TVM_TOPI_CUDA_SOFTMAX_H_
104