/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_
// Functor definition for SoftmaxOp, must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

// Functor used by SoftmaxOp to do the computations.
template <typename Device, typename T>
struct SoftmaxFunctor {
  // Computes Softmax or LogSoftmax activation.
  //
  // logits: dims: batch_size, num_classes.
  // softmax: dims: batch_size, num_classes.
  // log: if true, compute LogSoftmax instead of Softmax.
  void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<T>::Matrix softmax, const bool log);
};
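
// A minimal usage sketch (an illustration, not code from this header): an op
// kernel would typically flatten its input to a 2-D matrix and invoke the
// functor on its device. The names `context`, `logits_in`, and `softmax_out`
// are hypothetical:
//
//   SoftmaxFunctor<Eigen::ThreadPoolDevice, float> functor;
//   functor(context->eigen_device<Eigen::ThreadPoolDevice>(),
//           logits_in.matrix<float>(), softmax_out->matrix<float>(),
//           /*log=*/false);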

// Eigen code implementing SoftmaxFunctor::operator() for both the Softmax
// and LogSoftmax cases.
// This code works for both CPU and GPU and is used by the functor
// specializations for both device types.
template <typename Device, typename T>
struct SoftmaxEigenImpl {
  static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
                      typename TTypes<T>::Matrix softmax, const bool log) {
    const int kBatchDim = 0;
    const int kClassDim = 1;

    const int batch_size = logits.dimension(kBatchDim);
    const int num_classes = logits.dimension(kClassDim);

    // These arrays are used to reduce along the class dimension, and
    // broadcast the resulting value to all classes.
    Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
    Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
    batch_by_one.set(0, batch_size);
    Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
    one_by_class.set(1, num_classes);
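
    // For example, with batch_size = 2 and num_classes = 3, a reduction such
    // as logits.maximum(along_class) yields shape [2]; reshaping by
    // batch_by_one gives [2, 1] and broadcasting by one_by_class gives
    // [2, 3], lining the per-row values back up elementwise with logits.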

    // shifted_logits = logits - max(logits along classes);
    auto shifted_logits = (logits - logits.maximum(along_class)
                                        .eval()
                                        .reshape(batch_by_one)
                                        .broadcast(one_by_class));
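    // Subtracting the per-row maximum above is the standard trick for
    // numerical stability: it prevents exp() from overflowing and leaves the
    // result unchanged, since
    //   exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)).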
    if (log) {
      // Calculate the log of the softmax
      // softmax = logits - max(logits along classes);
      softmax.device(d) = shifted_logits;
      // softmax = softmax - log(sum(exp(softmax along classes)));
      softmax.device(d) = (softmax - softmax.exp()
                                         .sum(along_class)
                                         .log()
                                         .eval()
                                         .reshape(batch_by_one)
                                         .broadcast(one_by_class));
    } else {
      // NOTE(touts): If you modify this implementation please run
      // the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
      //
      // softmax = exp(logits - max(logits along classes));
      softmax.device(d) = shifted_logits.exp();
      // softmax = softmax * (1 / sum(softmax along classes));
      softmax.device(d) = (softmax * softmax.sum(along_class)
                                         .inverse()
                                         .eval()
                                         .reshape(batch_by_one)
                                         .broadcast(one_by_class));
    }
  }
};
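
// The per-device kernels are expected to specialize SoftmaxFunctor and
// forward to the shared Eigen implementation above. A sketch of what such a
// specialization could look like (the real wiring lives in the kernel .cc
// files, e.g. softmax_op.cc, and may differ):
//
//   template <typename T>
//   struct SoftmaxFunctor<Eigen::ThreadPoolDevice, T> {
//     void operator()(const Eigen::ThreadPoolDevice& d,
//                     typename TTypes<T>::ConstMatrix logits,
//                     typename TTypes<T>::Matrix softmax, const bool log) {
//       SoftmaxEigenImpl<Eigen::ThreadPoolDevice, T>::Compute(d, logits,
//                                                             softmax, log);
//     }
//   };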

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_