/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_XENT_OP_H_
#define TENSORFLOW_CORE_KERNELS_XENT_OP_H_
// Functor definition for XentOp; must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

// Functor used by XentOp to do the computations.
template <typename Device, typename T>
struct XentFunctor {
  // Computes the cross-entropy loss and its gradient (backprop).
  //
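  // shape: common (post-broadcast) 2D shape of the inputs: batch_size,
  //   num_classes.
  // logits_bcast: broadcast factors that expand logits to `shape`.
  // labels_bcast: broadcast factors that expand labels to `shape`.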
  // logits: dims: batch_size, num_classes.
  // labels: dims: batch_size, num_classes.
  // scratch: temporary tensor, dims: batch_size, 1.
  // loss: output tensor for the loss, dims: batch_size.
  // backprop: output tensor for the backprop, dims: batch_size, num_classes.
  void operator()(const Device &d,
                  const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
                  const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
                  const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
                  typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<T>::ConstMatrix labels,
                  typename TTypes<T>::Matrix scratch,
                  typename TTypes<T>::Vec loss,
                  typename TTypes<T>::Matrix backprop);
};
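
// A minimal usage sketch (illustrative only; the real plumbing lives in the
// op kernel that instantiates the functor). Assuming `ctx` is an
// OpKernelContext* and the Tensors have the shapes documented above:
//
//   functor::XentFunctor<Device, T> xent;
//   xent(ctx->eigen_device<Device>(), shape, logits_bcast, labels_bcast,
//        logits.matrix<T>(), labels.matrix<T>(), scratch.matrix<T>(),
//        loss.vec<T>(), backprop.matrix<T>());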

// Eigen code implementing XentFunctor::operator().
// This code works for both CPU and GPU and is used by the functor
// specializations for both device types.
template <typename Device, typename T>
struct XentEigenImpl {
  static void Compute(const Device &d,
                      const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
                      const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
                      const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
                      typename TTypes<T>::ConstMatrix logits,
                      typename TTypes<T>::ConstMatrix labels,
                      typename TTypes<T>::Matrix scratch,
                      typename TTypes<T>::Vec loss,
                      typename TTypes<T>::Matrix backprop) {
    // NOTE(touts): This duplicates some of the computations in softmax_op
    // because we need the intermediate (logits - max(logits)) values to
    // avoid a log(exp()) in the computation of the loss.
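    //
    // In math, for each row i (writing m_i = max_j logits_ij):
    //   loss_i = sum_j labels_ij *
    //            (log(sum_k exp(logits_ik - m_i)) - (logits_ij - m_i))
    //   backprop_ij =
    //       exp(logits_ij - m_i) / sum_k exp(logits_ik - m_i) - labels_ij
    // Subtracting the per-row max before exponentiating keeps exp() from
    // overflowing without changing the result.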

    const int kBatchDim = 0;
    const int kClassDim = 1;

    const int batch_size = shape[kBatchDim];
    const int num_classes = shape[kClassDim];

    // These arrays are used to reduce along the class dimension, and
    // broadcast the resulting value to all classes.
    Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
    Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
    batch_by_one.set(0, batch_size);
    Eigen::IndexList<int> batch_only;
    batch_only.set(0, batch_size);
    Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
    one_by_class.set(1, num_classes);
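    // For example, with batch_size = 2 and num_classes = 3: batch_by_one is
    // {2, 1}, batch_only is {2}, and one_by_class is {1, 3}, so broadcasting
    // a {2, 1} column by one_by_class replicates it across all 3 classes.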

    // max_logits along classes.
    scratch.reshape(batch_only).device(d) =
        logits.broadcast(logits_bcast).maximum(along_class);

    // logits - max_logits.
    backprop.device(d) =
        logits.broadcast(logits_bcast) - scratch.broadcast(one_by_class);

    // sum(exp(logits - max_logits)) along classes.
    scratch.reshape(batch_only).device(d) = backprop.exp().sum(along_class);

    // NOTE(keveman): Eigen on GPU dispatches to an optimized implementation
    // for an expression of the form lhs = rhs.sum().
    // lhs = -rhs.sum() doesn't match that pattern, so we fold the negation
    // into rhs before calling sum(). The expression below computes
    //   sum(-labels *
    //       ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
    // along classes.
    loss.device(d) = (labels.broadcast(labels_bcast) *
                      (scratch.log().eval().broadcast(one_by_class) - backprop))
                         .eval()
                         .sum(along_class);

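    // Note: prob - labels below is the exact gradient of the loss only when
    // each row of labels sums to 1 (a valid probability distribution); in
    // general the gradient is prob * sum(labels) - labels.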
    // backprop: prob - labels, where
    //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits)).
    backprop.device(d) = (backprop.exp() / scratch.broadcast(one_by_class)) -
                         labels.broadcast(labels_bcast);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_XENT_OP_H_