1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_ |
18 | // Functor definition for SparseXentOp, must be compilable by nvcc. |
19 | |
20 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
21 | #include "tensorflow/core/framework/bounds_check.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/tensor_types.h" |
24 | #include "tensorflow/core/platform/macros.h" |
25 | #include "tensorflow/core/platform/types.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | namespace sparse_xent_helpers { |
30 | |
31 | template <typename T> |
32 | typename TTypes<const T, 1>::Tensor32Bit To32BitConst( |
33 | typename TTypes<T>::Vec in) { |
34 | return To32Bit(typename TTypes<T>::ConstVec(in.data(), in.dimensions())); |
35 | } |
36 | |
37 | template <typename T> |
38 | typename TTypes<const T, 2>::Tensor32Bit To32BitConst( |
39 | typename TTypes<T>::Matrix in) { |
40 | return To32Bit(typename TTypes<T>::ConstMatrix(in.data(), in.dimensions())); |
41 | } |
42 | |
43 | } // namespace sparse_xent_helpers |
44 | |
45 | namespace generator { |
46 | |
47 | // Generator for calculation of the sparse Xent loss. |
48 | // This generator takes the logits, the sum of the exponentiated |
49 | // logits, and the label indices. For each minibatch entry, ignoring |
50 | // the batch index b, it calculates: |
51 | // |
52 | // loss[j] = (log(sum_exp_logits) - logits[j]) * 1{ j == label } |
53 | // |
54 | // for j = 0 .. num_classes. This value must be summed over all j for |
55 | // the final loss. |
56 | template <typename T, typename Index> |
57 | class SparseXentLossGenerator { |
58 | public: |
59 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentLossGenerator( |
60 | typename TTypes<const T, 2>::Tensor32Bit logits, |
61 | typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits, |
62 | typename TTypes<const Index, 1>::Tensor32Bit labels, |
63 | const Index max_depth) |
64 | : logits_(logits), |
65 | sum_exp_logits_(sum_exp_logits), |
66 | labels_(labels), |
67 | max_depth_(max_depth) {} |
68 | |
69 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T |
70 | operator()(const Eigen::array<int, 2>& coords) const { |
71 | const int batch = coords[0]; |
72 | const int depth = coords[1]; |
73 | const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch)); |
74 | if (!FastBoundsCheck(label, max_depth_)) { |
75 | return Eigen::NumTraits<T>::quiet_NaN(); |
76 | } |
77 | return TF_PREDICT_FALSE(label == depth) |
78 | ? (Eigen::numext::log(sum_exp_logits_(batch)) - logits_(coords)) |
79 | : T(0.0); |
80 | }; |
81 | |
82 | private: |
83 | typename TTypes<const T, 2>::Tensor32Bit logits_; |
84 | typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_; |
85 | typename TTypes<const Index, 1>::Tensor32Bit labels_; |
86 | const Index max_depth_; |
87 | }; |
88 | |
89 | // Generator for calculation of the sparse Xent gradient. |
90 | // This generator takes the exponentiated logits, their sums, and the label |
91 | // indices. For each minibatch entry, ignoring the batch index b, it calculates: |
92 | // |
93 | // exp_logits[j] / sum_exp_logits - 1{ j == label } |
94 | // |
95 | // for j = 0 .. num_classes. |
96 | template <typename T, typename Index> |
97 | class SparseXentGradGenerator { |
98 | public: |
99 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentGradGenerator( |
100 | typename TTypes<const T, 2>::Tensor32Bit exp_logits, |
101 | typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits, |
102 | typename TTypes<const Index, 1>::Tensor32Bit labels, |
103 | const Index max_depth) |
104 | : exp_logits_(exp_logits), |
105 | sum_exp_logits_(sum_exp_logits), |
106 | labels_(labels), |
107 | max_depth_(max_depth) {} |
108 | |
109 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T |
110 | operator()(const Eigen::array<int, 2>& coords) const { |
111 | const int batch = coords[0]; |
112 | const int depth = coords[1]; |
113 | const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch)); |
114 | if (!FastBoundsCheck(label, max_depth_)) { |
115 | return Eigen::NumTraits<T>::quiet_NaN(); |
116 | } |
117 | T subtract = TF_PREDICT_FALSE(depth == label) ? T(1.0) : T(0.0); |
118 | return exp_logits_(coords) / sum_exp_logits_(batch) - subtract; |
119 | }; |
120 | |
121 | private: |
122 | typename TTypes<const T, 2>::Tensor32Bit exp_logits_; |
123 | typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_; |
124 | typename TTypes<const Index, 1>::Tensor32Bit labels_; |
125 | const Index max_depth_; |
126 | }; |
127 | |
128 | } // namespace generator |
129 | |
130 | namespace functor { |
131 | |
132 | template <typename Device, typename T> |
133 | struct RowMaxReduction { |
134 | // Computes the maximum across the rows of logits |
135 | // |
136 | // logits: batch_size, num_classes. |
137 | // maximum: temporary tensor, dims: batch_size, 1 |
138 | static inline void Compute(OpKernelContext* ctx, |
139 | typename TTypes<T>::ConstMatrix logits, |
140 | typename TTypes<T>::Vec maximum) { |
141 | Eigen::IndexList<Eigen::type2index<1> > along_row; |
142 | Device d = ctx->eigen_device<Device>(); |
143 | To32Bit(maximum).device(d) = To32Bit(logits).maximum(along_row); |
144 | } |
145 | }; |
146 | |
147 | // Functor used by SparseXentOp to do the computations. |
148 | template <typename Device, typename T, typename Index> |
149 | struct SparseXentFunctor { |
150 | // Computes Cross Entropy loss and backprop. |
151 | // |
152 | // logits: batch_size, num_classes. |
153 | // labels: num_classes. |
154 | // scratch: temporary tensor, dims: batch_size, 1 |
155 | // loss: output tensor for the loss, dims: batch_size. |
156 | // backprop: output tensor for the backprop, dims: batch_size, num_classes. |
157 | void operator()(OpKernelContext* ctx, typename TTypes<T>::ConstMatrix logits, |
158 | typename TTypes<Index>::ConstVec labels, |
159 | typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss, |
160 | typename TTypes<T>::Matrix backprop); |
161 | }; |
162 | |
163 | // Eigen code implementing SparseXentFunctor::operator(). |
164 | // This code works for both CPU and GPU and is used by the functor |
165 | // specializations for both device types. |
166 | template <typename Device, typename T, typename Index> |
167 | struct SparseXentEigenImpl { |
168 | static void Compute(OpKernelContext* ctx, |
169 | typename TTypes<T>::ConstMatrix logits, |
170 | typename TTypes<Index>::ConstVec labels, |
171 | typename TTypes<T>::Vec scratch, |
172 | typename TTypes<T>::Vec loss, |
173 | typename TTypes<T>::Matrix backprop) { |
174 | // NOTE(touts): This duplicates some of the computations in softmax_op |
175 | // because we need the intermediate (logits -max(logits)) values to |
176 | // avoid a log(exp()) in the computation of the loss. |
177 | |
178 | const int kBatchDim = 0; |
179 | const int kClassDim = 1; |
180 | |
181 | const int batch_size = logits.dimension(kBatchDim); |
182 | const int num_classes = logits.dimension(kClassDim); |
183 | |
184 | // These arrays are used to reduce along the class dimension, and broadcast |
185 | // the resulting value to all classes. |
186 | Eigen::IndexList<Eigen::type2index<kClassDim> > along_class; |
187 | Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one; |
188 | batch_by_one.set(0, batch_size); |
189 | Eigen::IndexList<int> batch_only; |
190 | batch_only.set(0, batch_size); |
191 | Eigen::IndexList<Eigen::type2index<1>, int> one_by_class; |
192 | one_by_class.set(1, num_classes); |
193 | |
194 | // scratch = max_logits along classes. |
195 | RowMaxReduction<Device, T>::Compute(ctx, logits, scratch); |
196 | |
197 | Device d = ctx->eigen_device<Device>(); |
198 | // backprop = logits - max_logits. |
199 | To32Bit(backprop).device(d) = |
200 | To32Bit(logits) - |
201 | To32Bit(scratch).reshape(batch_by_one).broadcast(one_by_class); |
202 | |
203 | // scratch = sum(exp(logits - max_logits)) along classes. |
204 | To32Bit(scratch).device(d) = To32Bit(backprop).exp().sum(along_class); |
205 | |
206 | // sum(-labels * |
207 | // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) |
208 | // along classes |
209 | generator::SparseXentLossGenerator<T, Index> sparse_xent_loss_gen( |
210 | sparse_xent_helpers::To32BitConst<T>(backprop), |
211 | sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels), |
212 | backprop.dimension(1) /* max_depth */); |
213 | To32Bit(loss).device(d) = |
214 | To32Bit(backprop).generate(sparse_xent_loss_gen).sum(along_class); |
215 | |
216 | // backprop: prob - labels, where |
217 | // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) |
218 | To32Bit(backprop).device(d) = To32Bit(backprop).exp(); |
219 | generator::SparseXentGradGenerator<T, Index> sparse_xent_grad_gen( |
220 | sparse_xent_helpers::To32BitConst<T>(backprop), |
221 | sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels), |
222 | backprop.dimension(1) /* max_depth */); |
223 | To32Bit(backprop).device(d) = |
224 | To32Bit(backprop).generate(sparse_xent_grad_gen); |
225 | } |
226 | }; |
227 | |
228 | } // namespace functor |
229 | |
230 | } // namespace tensorflow |
231 | |
232 | #endif // TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_ |
233 | |