/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Softmax op constructions
 * \file nn/softmax.h
 */
#ifndef TVM_TOPI_NN_SOFTMAX_H_
#define TVM_TOPI_NN_SOFTMAX_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
 * \brief Softmax activation
 *
 * Computes softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)) along
 * the given axis; the maximum is subtracted before exponentiation for
 * numerical stability.
 *
 * \param x The input tensor. Can be of any dimension
 * \param axis The channel axis along which softmax is performed
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the softmax operation
 */
inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor",
                      std::string tag = "softmax_output") {
  auto input_shape = x->shape;
  auto ndim = input_shape.size();
  if (axis < 0) {
    axis = ndim + axis;
  }
  ICHECK_LT(axis, ndim) << "axis parameter should be less than input dim";

  // Two reduce axes over the softmax dimension: k1 for the max reduction,
  // k2 for the sum of exponentials.
  auto k1 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k1");
  auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2");
  // Shape of the intermediate tensors: the input shape with the softmax axis removed.
  auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);

  tvm::Map<String, ObjectRef> attrs;
  attrs.Set("axis", Integer(axis));

  // Map the (ndim - 1) non-reduced output indices to a full ndim input index,
  // substituting the reduce iterator at the softmax axis.
  auto insert_reduce_index = [axis, ndim](const Array<Var>& indices, const IterVar& reduce_index) {
    Array<PrimExpr> eval_range;
    int arg_counter = 0;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) == axis) {
        eval_range.push_back(reduce_index);
      } else {
        eval_range.push_back(indices[arg_counter++]);
      }
    }
    return eval_range;
  };
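  // For example, with ndim = 3 and axis = 1, output indices {i, j} expand to
  // the evaluation range {i, k, j}, where k is the reduce iterator.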

  // Drop the softmax-axis index, producing indices into the reduced
  // (max / expsum) tensors.
  auto get_non_reduce_indices = [axis, ndim](const Array<Var>& indices) {
    Array<PrimExpr> non_reduce_indices;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) != axis) non_reduce_indices.push_back(indices[i]);
    }
    return non_reduce_indices;
  };

  // Max over the softmax axis, used for numerical stability.
  auto _compute_max = [&](const Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k1);
    return topi::MaxOp(x(eval_range), {k1});
  };

  // Elementwise exp(x - max).
  auto _compute_exp = [&](const Tensor& max_elem, const Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return tvm::exp(x(indices) - max_elem(non_reduce_indices));
  };

  // Sum of the exponentials over the softmax axis.
  auto _compute_expsum = [&](const Tensor& exp, const Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k2);
    return tvm::sum(exp(eval_range), {k2});
  };

  // Divide each exponential by the sum to produce the softmax output.
  auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return exp(indices) / expsum(non_reduce_indices);
  };

  // Chain the four stages: max, exp, sum of exps, then normalization.
  auto max_elem = tvm::te::compute(reduced_shape, _compute_max);
  auto exp = tvm::te::compute(
      input_shape, [&](const Array<Var>& indices) { return _compute_exp(max_elem, indices); });
  auto expsum = tvm::te::compute(
      reduced_shape, [&](const Array<Var>& indices) { return _compute_expsum(exp, indices); });
  return tvm::te::compute(
      input_shape, [&](const Array<Var>& indices) { return _normalize(exp, expsum, indices); },
      name, tag, attrs);
}
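
// A minimal usage sketch (illustrative, not part of this header). The
// placeholder name and shape are hypothetical; the calls follow the standard
// tvm::te workflow:
//
//   te::Tensor data = te::placeholder({16, 128}, DataType::Float(32), "data");
//   te::Tensor out = topi::nn::softmax(data, /*axis=*/1);
//   te::Schedule sched = te::create_schedule({out->op});
//   // sched can then be lowered and built with the usual TVM driver APIs.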

/*!
 * \brief Log softmax activation
 *
 * Computes log_softmax(x)(i, j) =
 *   x(i, j) - max_k x(i, k) - log(sum_k exp(x(i, k) - max_k x(i, k)))
 *
 * \param x The input tensor. 2-D where log softmax is performed along the second dimension
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the log softmax operation
 */
inline Tensor log_softmax(const Tensor& x, std::string name = "tensor",
                          std::string tag = "log_softmax_output") {
  ICHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";

  PrimExpr m = x->shape[0];
  PrimExpr n = x->shape[1];

  auto k = tvm::te::reduce_axis(Range(0, n), "k");
  // Row-wise max, subtracted before exponentiation for numerical stability.
  auto max_elem =
      tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array<IterVar>{k}); });
  // Re-create the reduce axis for the sum reduction.
  k = tvm::te::reduce_axis(Range(0, n), "k");

  // Row-wise sum of exp(x - max).
  auto expsum =
      tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); });

  // log softmax = x - max - log(sum(exp(x - max)))
  return tvm::te::compute(
      x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name,
      tag);
}
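
// A minimal usage sketch (illustrative only): log_softmax requires a 2-D
// input; the placeholder below is a hypothetical batch of 10-class logits.
//
//   te::Tensor logits = te::placeholder({32, 10}, DataType::Float(32), "logits");
//   te::Tensor out = topi::nn::log_softmax(logits);
//   // Mathematically equal to tvm::log of softmax(logits), computed directly
//   // for better numerical behavior.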

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_SOFTMAX_H_