1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/ctc_ops.cc. |
17 | |
18 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
19 | #define EIGEN_USE_GPU |
20 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
21 | |
22 | #include <utility> |
23 | |
24 | #include "tensorflow/core/framework/bounds_check.h" |
25 | #include "tensorflow/core/framework/op.h" |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/register_types.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | #include "tensorflow/core/platform/macros.h" |
31 | #include "tensorflow/core/util/ctc/ctc_loss_calculator.h" |
32 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
33 | |
34 | #if GOOGLE_CUDA |
35 | #include "third_party/gpus/cudnn/cudnn.h" |
36 | #endif // GOOGLE_CUDA |
37 | |
38 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
39 | #include "tensorflow/core/kernels/conv_ops_gpu.h" |
40 | #include "tensorflow/core/util/stream_executor_util.h" |
41 | #include "tensorflow/core/util/tensor_format.h" |
42 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
43 | |
44 | namespace tensorflow { |
45 | |
46 | typedef Eigen::ThreadPoolDevice CPUDevice; |
47 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
48 | using GPUDevice = Eigen::GpuDevice; |
49 | |
50 | namespace { |
51 | using se::Stream; |
52 | using se::StreamExecutor; |
53 | using se::dnn::RnnStateTensorDescriptor; |
54 | using se::dnn::ToDataType; |
55 | |
56 | template <typename T> |
57 | void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices, |
58 | int num_indices, int batch_size, |
59 | std::vector<int>* labels_lengths) { |
60 | const T* h_in = labels_indices->flat<T>().data(); |
61 | for (int i = 0; i < num_indices; i++) { |
62 | const T& key = h_in[i * 2]; |
63 | (*labels_lengths)[key]++; |
64 | } |
65 | } |
66 | |
67 | } // end namespace |
68 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
69 | |
70 | template <typename T> |
71 | class CTCLossOp : public OpKernel { |
72 | typedef Eigen::Map< |
73 | const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> > |
74 | InputMap; |
75 | typedef Eigen::Map< |
76 | Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> > |
77 | OutputMap; |
78 | |
79 | public: |
80 | explicit CTCLossOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
81 | OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated" , |
82 | &preprocess_collapse_repeated_)); |
83 | OP_REQUIRES_OK(ctx, |
84 | ctx->GetAttr("ctc_merge_repeated" , &ctc_merge_repeated_)); |
85 | OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs" , |
86 | &ignore_longer_outputs_than_inputs_)); |
87 | } |
88 | |
89 | void Compute(OpKernelContext* ctx) override { |
90 | const Tensor* inputs; |
91 | const Tensor* labels_indices; |
92 | const Tensor* labels_values; |
93 | const Tensor* seq_len; |
94 | OP_REQUIRES_OK(ctx, ctx->input("inputs" , &inputs)); |
95 | OP_REQUIRES_OK(ctx, ctx->input("labels_indices" , &labels_indices)); |
96 | OP_REQUIRES_OK(ctx, ctx->input("labels_values" , &labels_values)); |
97 | OP_REQUIRES_OK(ctx, ctx->input("sequence_length" , &seq_len)); |
98 | |
99 | OP_REQUIRES(ctx, inputs->shape().dims() == 3, |
100 | errors::InvalidArgument("inputs is not a 3-Tensor" )); |
101 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()), |
102 | errors::InvalidArgument("sequence_length is not a vector" )); |
103 | OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()), |
104 | errors::InvalidArgument("labels_indices is not a matrix" )); |
105 | OP_REQUIRES(ctx, labels_indices->dim_size(1) > 1, |
106 | errors::InvalidArgument( |
107 | "labels_indices second dimension must be >= 1. Received " , |
108 | labels_indices->dim_size(1))); |
109 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()), |
110 | errors::InvalidArgument("labels_values is not a vector" )); |
111 | |
112 | const TensorShape& inputs_shape = inputs->shape(); |
113 | const int64_t max_time = inputs_shape.dim_size(0); |
114 | OP_REQUIRES(ctx, max_time != 0, |
115 | errors::InvalidArgument( |
116 | "Max time or first dimension of input cannot be 0." )); |
117 | const int64_t batch_size = inputs_shape.dim_size(1); |
118 | const int64_t num_classes_raw = inputs_shape.dim_size(2); |
119 | OP_REQUIRES( |
120 | ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()), |
121 | errors::InvalidArgument("num_classes cannot exceed max int" )); |
122 | const int num_classes = static_cast<const int>(num_classes_raw); |
123 | |
124 | OP_REQUIRES( |
125 | ctx, batch_size == seq_len->dim_size(0), |
126 | errors::InvalidArgument("len(sequence_length) != batch_size. " , |
127 | "len(sequence_length): " , seq_len->dim_size(0), |
128 | " batch_size: " , batch_size)); |
129 | auto seq_len_t = seq_len->vec<int32>(); |
130 | |
131 | OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0), |
132 | errors::InvalidArgument( |
133 | "labels_indices and labels_values must contain the " |
134 | "same number of rows, but saw shapes: " , |
135 | labels_indices->shape().DebugString(), " vs. " , |
136 | labels_values->shape().DebugString())); |
137 | |
138 | OP_REQUIRES(ctx, batch_size != 0, |
139 | errors::InvalidArgument("batch_size must not be 0" )); |
140 | |
141 | // Figure out the maximum label length to use as sparse tensor dimension. |
142 | auto labels_indices_t = labels_indices->matrix<int64_t>(); |
143 | int64_t max_label_len = 0; |
144 | for (int i = 0; i < labels_indices->dim_size(0); i++) { |
145 | max_label_len = std::max(max_label_len, labels_indices_t(i, 1) + 1); |
146 | } |
147 | |
148 | TensorShape labels_shape({batch_size, max_label_len}); |
149 | std::vector<int64_t> order{0, 1}; |
150 | sparse::SparseTensor labels_sp; |
151 | OP_REQUIRES_OK( |
152 | ctx, sparse::SparseTensor::Create(*labels_indices, *labels_values, |
153 | labels_shape, order, &labels_sp)); |
154 | |
155 | Status labels_sp_valid = labels_sp.IndicesValid(); |
156 | OP_REQUIRES(ctx, labels_sp_valid.ok(), |
157 | errors::InvalidArgument("label SparseTensor is not valid: " , |
158 | labels_sp_valid.error_message())); |
159 | |
160 | typename ctc::CTCLossCalculator<T>::LabelSequences labels_t(batch_size); |
161 | for (const auto& g : labels_sp.group({0})) { // iterate by batch |
162 | const int64_t batch_indices = g.group()[0]; |
163 | OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size), |
164 | errors::InvalidArgument("labels batch index must be between " , |
165 | 0, " and " , batch_size, |
166 | " but saw: " , batch_indices)); |
167 | |
168 | auto values = g.values<int32>(); |
169 | std::vector<int>* b_values = &labels_t[batch_indices]; |
170 | b_values->resize(values.size()); |
171 | for (int i = 0; i < values.size(); ++i) (*b_values)[i] = values(i); |
172 | } |
173 | |
174 | OP_REQUIRES(ctx, static_cast<size_t>(batch_size) == labels_t.size(), |
175 | errors::InvalidArgument("len(labels) != batch_size. " , |
176 | "len(labels): " , labels_t.size(), |
177 | " batch_size: " , batch_size)); |
178 | |
179 | for (int64_t b = 0; b < batch_size; ++b) { |
180 | OP_REQUIRES( |
181 | ctx, seq_len_t(b) <= max_time, |
182 | errors::InvalidArgument("sequence_length(" , b, ") <= " , max_time)); |
183 | } |
184 | |
185 | Tensor* loss = nullptr; |
186 | OP_REQUIRES_OK(ctx, ctx->allocate_output("loss" , seq_len->shape(), &loss)); |
187 | auto loss_t = loss->vec<T>(); |
188 | |
189 | Tensor* gradient; |
190 | OP_REQUIRES_OK(ctx, |
191 | ctx->allocate_output("gradient" , inputs_shape, &gradient)); |
192 | auto gradient_t = gradient->tensor<T, 3>(); |
193 | auto inputs_t = inputs->tensor<T, 3>(); |
194 | std::vector<OutputMap> gradient_list_t; |
195 | std::vector<InputMap> input_list_t; |
196 | |
197 | for (std::size_t t = 0; t < max_time; ++t) { |
198 | input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes, |
199 | batch_size, num_classes); |
200 | gradient_list_t.emplace_back( |
201 | gradient_t.data() + t * batch_size * num_classes, batch_size, |
202 | num_classes); |
203 | } |
204 | |
205 | gradient_t.setZero(); |
206 | |
207 | // Assumption: the blank index is num_classes - 1 |
208 | ctc::CTCLossCalculator<T> ctc_loss_calculator(num_classes - 1, 0); |
209 | DeviceBase::CpuWorkerThreads workers = |
210 | *ctx->device()->tensorflow_cpu_worker_threads(); |
211 | OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss( |
212 | seq_len_t, labels_t, input_list_t, |
213 | preprocess_collapse_repeated_, ctc_merge_repeated_, |
214 | ignore_longer_outputs_than_inputs_, &loss_t, |
215 | &gradient_list_t, &workers)); |
216 | } |
217 | |
218 | private: |
219 | bool preprocess_collapse_repeated_; |
220 | bool ctc_merge_repeated_; |
221 | bool ignore_longer_outputs_than_inputs_; |
222 | |
223 | TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp<T>); |
224 | }; |
225 | |
226 | #define REGISTER_CPU(T) \ |
227 | REGISTER_KERNEL_BUILDER( \ |
228 | Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
229 | CTCLossOp<T>); |
230 | |
231 | REGISTER_CPU(float); |
232 | REGISTER_CPU(double); |
233 | |
234 | #undef REGISTER_CPU |
235 | |
236 | #if ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM) |
237 | class CTCLossOpGPU : public OpKernel { |
238 | public: |
239 | explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) { |
240 | bool preprocess_collapse_repeated; |
241 | bool ctc_merge_repeated; |
242 | bool ignore_longer_outputs_than_inputs; |
243 | OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated" , |
244 | &preprocess_collapse_repeated)); |
245 | OP_REQUIRES_OK(ctx, |
246 | ctx->GetAttr("ctc_merge_repeated" , &ctc_merge_repeated)); |
247 | OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs" , |
248 | &ignore_longer_outputs_than_inputs)); |
249 | |
250 | OP_REQUIRES(ctx, !preprocess_collapse_repeated, |
251 | errors::InvalidArgument("GPU CTCLossOp requires " |
252 | "preprocess_collapse_repeated to be " |
253 | "false" )); |
254 | OP_REQUIRES(ctx, ctc_merge_repeated, |
255 | errors::InvalidArgument("GPU CTCLossOp requires " |
256 | "ctc_merge_repeated to be " |
257 | "true" )); |
258 | OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs, |
259 | errors::InvalidArgument("GPU CTCLossOp requires " |
260 | "ignore_longer_outputs_than_inputs to" |
261 | "be false" )); |
262 | } |
263 | |
264 | void Compute(OpKernelContext* ctx) override { |
265 | const Tensor* inputs; |
266 | const Tensor* labels_indices; |
267 | const Tensor* labels_values; |
268 | const Tensor* seq_len; |
269 | OP_REQUIRES_OK(ctx, ctx->input("inputs" , &inputs)); |
270 | OP_REQUIRES_OK(ctx, ctx->input("labels_indices" , &labels_indices)); |
271 | OP_REQUIRES_OK(ctx, ctx->input("labels_values" , &labels_values)); |
272 | OP_REQUIRES_OK(ctx, ctx->input("sequence_length" , &seq_len)); |
273 | |
274 | OP_REQUIRES(ctx, inputs->shape().dims() == 3, |
275 | errors::InvalidArgument("inputs is not a 3-Tensor" )); |
276 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()), |
277 | errors::InvalidArgument("sequence_length is not a vector" )); |
278 | OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()), |
279 | errors::InvalidArgument("labels_indices is not a matrix" )); |
280 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()), |
281 | errors::InvalidArgument("labels_values is not a vector" )); |
282 | |
283 | const TensorShape& inputs_shape = inputs->shape(); |
284 | const int64_t max_time_raw = inputs_shape.dim_size(0); |
285 | const int64_t batch_size_raw = inputs_shape.dim_size(1); |
286 | const int64_t num_classes_raw = inputs_shape.dim_size(2); |
287 | OP_REQUIRES(ctx, |
288 | FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()), |
289 | errors::InvalidArgument("max_time_ cannot exceed max int" )); |
290 | OP_REQUIRES( |
291 | ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()), |
292 | errors::InvalidArgument("batch_size cannot exceed max int" )); |
293 | OP_REQUIRES( |
294 | ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()), |
295 | errors::InvalidArgument("num_classes cannot exceed max int" )); |
296 | const int max_time = static_cast<const int>(max_time_raw); |
297 | const int batch_size = static_cast<const int>(batch_size_raw); |
298 | const int num_classes = static_cast<const int>(num_classes_raw); |
299 | |
300 | OP_REQUIRES( |
301 | ctx, batch_size == seq_len->dim_size(0), |
302 | errors::InvalidArgument("len(sequence_length) != batch_size. " , |
303 | "len(sequence_length): " , seq_len->dim_size(0), |
304 | " batch_size: " , batch_size)); |
305 | |
306 | OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0), |
307 | errors::InvalidArgument( |
308 | "labels_indices and labels_values must contain the " |
309 | "same number of rows, but saw shapes: " , |
310 | labels_indices->shape().DebugString(), " vs. " , |
311 | labels_values->shape().DebugString())); |
312 | auto num_indices = labels_indices->dim_size(0); |
313 | |
314 | OP_REQUIRES(ctx, batch_size != 0, |
315 | errors::InvalidArgument("batch_size must not be 0" )); |
316 | |
317 | Tensor* loss = nullptr; |
318 | OP_REQUIRES_OK(ctx, ctx->allocate_output("loss" , seq_len->shape(), &loss)); |
319 | |
320 | Tensor* gradient = nullptr; |
321 | OP_REQUIRES_OK(ctx, |
322 | ctx->allocate_output("gradient" , inputs_shape, &gradient)); |
323 | |
324 | // Convert the labels_indices to labels_lengths. |
325 | std::vector<int> labels_lengths(batch_size, 0); |
326 | DoHistogram<int64_t>(ctx, labels_indices, num_indices, batch_size, |
327 | &labels_lengths); |
328 | |
329 | StreamExecutor* executor = ctx->op_device_context()->stream()->parent(); |
330 | se::dnn::DataType data_type = ToDataType<float>::value; |
331 | |
332 | auto probs_desc_s = executor->createRnnStateTensorDescriptor( |
333 | max_time, batch_size, num_classes, data_type); |
334 | OP_REQUIRES_OK(ctx, probs_desc_s.status()); |
335 | std::unique_ptr<RnnStateTensorDescriptor> probs_desc = |
336 | std::move(probs_desc_s).value(); |
337 | |
338 | auto grads_desc_s = executor->createRnnStateTensorDescriptor( |
339 | max_time, batch_size, num_classes, data_type); |
340 | OP_REQUIRES_OK(ctx, grads_desc_s.status()); |
341 | std::unique_ptr<RnnStateTensorDescriptor> grads_desc = |
342 | std::move(grads_desc_s).value(); |
343 | |
344 | absl::Span<const int32> labels_data(labels_values->flat<int32>().data(), |
345 | num_indices); |
346 | absl::Span<const int32> labels_lengths_data(labels_lengths.data(), |
347 | batch_size); |
348 | absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(), |
349 | batch_size); |
350 | |
351 | auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs); |
352 | auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss); |
353 | auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient); |
354 | |
355 | // Set the memory limitation to 4GB for workspace memory. |
356 | DnnScratchAllocator workspace_allocator(1LL << 32, ctx); |
357 | |
358 | Stream* stream = ctx->op_device_context()->stream(); |
359 | bool cudnn_launch_status = |
360 | stream |
361 | ->ThenCtcLoss(*probs_desc, probs_data, labels_data, |
362 | labels_lengths_data, input_lengths_data, &costs_data, |
363 | *grads_desc, &grads_data, &workspace_allocator) |
364 | .ok(); |
365 | |
366 | if (!cudnn_launch_status) { |
367 | ctx->SetStatus(errors::Internal("cuDNN CTCLoss launch failure" )); |
368 | } |
369 | } |
370 | |
371 | private: |
372 | TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU); |
373 | }; |
374 | |
375 | REGISTER_KERNEL_BUILDER(Name("CTCLossV2" ) |
376 | .Device(DEVICE_GPU) |
377 | .HostMemory("labels_indices" ) |
378 | .HostMemory("labels_values" ) |
379 | .HostMemory("sequence_length" ), |
380 | CTCLossOpGPU); |
381 | #endif // ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM) |
382 | } // end namespace tensorflow |
383 | |