/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
// Functor definition for SparseXentOp, must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace sparse_xent_helpers {

template <typename T>
typename TTypes<const T, 1>::Tensor32Bit To32BitConst(
    typename TTypes<T>::Vec in) {
  return To32Bit(typename TTypes<T>::ConstVec(in.data(), in.dimensions()));
}

template <typename T>
typename TTypes<const T, 2>::Tensor32Bit To32BitConst(
    typename TTypes<T>::Matrix in) {
  return To32Bit(typename TTypes<T>::ConstMatrix(in.data(), in.dimensions()));
}

}  // namespace sparse_xent_helpers

namespace generator {
// Generator for calculation of the sparse Xent loss.
// This generator takes the logits, the sum of the exponentiated
// logits, and the label indices. For each minibatch entry, ignoring
// the batch index b, it calculates:
//
//   loss[j] = (log(sum_exp_logits) - logits[j]) * 1{ j == label }
//
// for j = 0 .. num_classes - 1. These values must be summed over all j
// to produce the final loss for the minibatch entry.
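//
// For example, with hypothetical shifted logits [1.0, 2.0, 3.0] and
// label = 2: sum_exp_logits = e^1 + e^2 + e^3 ~= 30.19, so the generated
// row is approximately [0, 0, log(30.19) - 3.0] ~= [0, 0, 0.41].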
template <typename T, typename Index>
class SparseXentLossGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentLossGenerator(
      typename TTypes<const T, 2>::Tensor32Bit logits,
      typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
      typename TTypes<const Index, 1>::Tensor32Bit labels,
      const Index max_depth)
      : logits_(logits),
        sum_exp_logits_(sum_exp_logits),
        labels_(labels),
        max_depth_(max_depth) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const Eigen::array<int, 2>& coords) const {
    const int batch = coords[0];
    const int depth = coords[1];
    const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch));
    if (!FastBoundsCheck(label, max_depth_)) {
      return Eigen::NumTraits<T>::quiet_NaN();
    }
    return TF_PREDICT_FALSE(label == depth)
               ? (Eigen::numext::log(sum_exp_logits_(batch)) - logits_(coords))
               : T(0.0);
  }

 private:
  typename TTypes<const T, 2>::Tensor32Bit logits_;
  typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
  typename TTypes<const Index, 1>::Tensor32Bit labels_;
  const Index max_depth_;
};

// Generator for calculation of the sparse Xent gradient.
// This generator takes the exponentiated logits, their sums, and the label
// indices. For each minibatch entry, ignoring the batch index b, it
// calculates:
//
//   grad[j] = exp_logits[j] / sum_exp_logits - 1{ j == label }
//
// for j = 0 .. num_classes - 1.
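//
// Continuing the example above (hypothetical values): with
// exp_logits = [2.72, 7.39, 20.09], sum_exp_logits ~= 30.19, and label = 2,
// the generated row is approximately [0.09, 0.24, -0.33]: the softmax
// probabilities, with 1 subtracted at the label index.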
template <typename T, typename Index>
class SparseXentGradGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentGradGenerator(
      typename TTypes<const T, 2>::Tensor32Bit exp_logits,
      typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
      typename TTypes<const Index, 1>::Tensor32Bit labels,
      const Index max_depth)
      : exp_logits_(exp_logits),
        sum_exp_logits_(sum_exp_logits),
        labels_(labels),
        max_depth_(max_depth) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const Eigen::array<int, 2>& coords) const {
    const int batch = coords[0];
    const int depth = coords[1];
    const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch));
    if (!FastBoundsCheck(label, max_depth_)) {
      return Eigen::NumTraits<T>::quiet_NaN();
    }
    const T subtract = TF_PREDICT_FALSE(depth == label) ? T(1.0) : T(0.0);
    return exp_logits_(coords) / sum_exp_logits_(batch) - subtract;
  }

 private:
  typename TTypes<const T, 2>::Tensor32Bit exp_logits_;
  typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
  typename TTypes<const Index, 1>::Tensor32Bit labels_;
  const Index max_depth_;
};

}  // namespace generator

namespace functor {

template <typename Device, typename T>
struct RowMaxReduction {
  // Computes the maximum across the rows of logits.
  //
  // logits: batch_size, num_classes.
  // maximum: temporary tensor, dims: batch_size.
  static inline void Compute(OpKernelContext* ctx,
                             typename TTypes<T>::ConstMatrix logits,
                             typename TTypes<T>::Vec maximum) {
    Eigen::IndexList<Eigen::type2index<1> > along_row;
    Device d = ctx->eigen_device<Device>();
    To32Bit(maximum).device(d) = To32Bit(logits).maximum(along_row);
  }
};
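
// Example (hypothetical values): for logits = [[1, 3], [5, 2]],
// RowMaxReduction<Device, T>::Compute fills maximum = [3, 5].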

// Functor used by SparseXentOp to do the computations.
template <typename Device, typename T, typename Index>
struct SparseXentFunctor {
  // Computes Cross Entropy loss and backprop.
  //
  // logits: batch_size, num_classes.
  // labels: batch_size.
  // scratch: temporary tensor, dims: batch_size.
  // loss: output tensor for the loss, dims: batch_size.
  // backprop: output tensor for the backprop, dims: batch_size, num_classes.
  void operator()(OpKernelContext* ctx, typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<Index>::ConstVec labels,
                  typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss,
                  typename TTypes<T>::Matrix backprop);
};
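
// A minimal sketch of how a kernel implementation might invoke the functor
// (the tensor names below are hypothetical):
//
//   functor::SparseXentFunctor<Device, T, Index> functor;
//   functor(context, logits.matrix<T>(), labels.vec<Index>(),
//           scratch.vec<T>(), loss.vec<T>(), backprop.matrix<T>());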

// Eigen code implementing SparseXentFunctor::operator().
// This code works for both CPU and GPU and is used by the functor
// specializations for both device types.
template <typename Device, typename T, typename Index>
struct SparseXentEigenImpl {
  static void Compute(OpKernelContext* ctx,
                      typename TTypes<T>::ConstMatrix logits,
                      typename TTypes<Index>::ConstVec labels,
                      typename TTypes<T>::Vec scratch,
                      typename TTypes<T>::Vec loss,
                      typename TTypes<T>::Matrix backprop) {
    // NOTE(touts): This duplicates some of the computations in softmax_op
    // because we need the intermediate (logits - max(logits)) values to
    // avoid a log(exp()) in the computation of the loss.
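    //
    // Subtracting the row maximum keeps the exponentials bounded without
    // changing the loss, since
    //   log(sum_j exp(logits[j])) - logits[label]
    //     = log(sum_j exp(logits[j] - max)) - (logits[label] - max).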

    const int kBatchDim = 0;
    const int kClassDim = 1;

    const int batch_size = logits.dimension(kBatchDim);
    const int num_classes = logits.dimension(kClassDim);

    // These arrays are used to reduce along the class dimension, and broadcast
    // the resulting value to all classes.
    Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
    Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
    batch_by_one.set(0, batch_size);
    Eigen::IndexList<int> batch_only;
    batch_only.set(0, batch_size);
    Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
    one_by_class.set(1, num_classes);

    // scratch = max_logits along classes.
    RowMaxReduction<Device, T>::Compute(ctx, logits, scratch);

    Device d = ctx->eigen_device<Device>();
    // backprop = logits - max_logits.
    To32Bit(backprop).device(d) =
        To32Bit(logits) -
        To32Bit(scratch).reshape(batch_by_one).broadcast(one_by_class);

    // scratch = sum(exp(logits - max_logits)) along classes.
    To32Bit(scratch).device(d) = To32Bit(backprop).exp().sum(along_class);

    // loss = sum(-1{ j == label } *
    //            ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
    //        along classes.
    generator::SparseXentLossGenerator<T, Index> sparse_xent_loss_gen(
        sparse_xent_helpers::To32BitConst<T>(backprop),
        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels),
        backprop.dimension(1) /* max_depth */);
    To32Bit(loss).device(d) =
        To32Bit(backprop).generate(sparse_xent_loss_gen).sum(along_class);
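    // The loss generator is nonzero only at j == label, so the sum along
    // classes extracts log(sum_exp_logits) - (logits[label] - max_logits)
    // for each batch entry.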

    // backprop: prob - 1{ j == label }, where
    //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits)).
    To32Bit(backprop).device(d) = To32Bit(backprop).exp();
    generator::SparseXentGradGenerator<T, Index> sparse_xent_grad_gen(
        sparse_xent_helpers::To32BitConst<T>(backprop),
        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels),
        backprop.dimension(1) /* max_depth */);
    To32Bit(backprop).device(d) =
        To32Bit(backprop).generate(sparse_xent_grad_gen);
  }
};

}  // namespace functor

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_