/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/core/util/tensor_format.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
using GPUDevice = Eigen::GpuDevice;

namespace {
using se::Stream;
using se::StreamExecutor;
using se::dnn::RnnStateTensorDescriptor;
using se::dnn::ToDataType;

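// Computes per-batch label lengths from the sparse label indices: column 0 of
// each (batch, time) index pair is a batch index, so counting occurrences of
// each batch index gives the length of that batch element's label sequence.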
template <typename T>
void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
                 int num_indices, int batch_size,
                 std::vector<int>* labels_lengths) {
  const T* h_in = labels_indices->flat<T>().data();
  for (int i = 0; i < num_indices; i++) {
    const T& key = h_in[i * 2];
    // Reject out-of-range batch indices, which would otherwise write past
    // the end of labels_lengths.
    OP_REQUIRES(ctx, FastBoundsCheck(key, batch_size),
                errors::InvalidArgument("labels_indices(", i,
                                        ", 0) expected to be in [0, ",
                                        batch_size, ") but got ", key));
    (*labels_lengths)[key]++;
  }
}

}  // end namespace
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
class CTCLossOp : public OpKernel {
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
      InputMap;
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
      OutputMap;

 public:
  explicit CTCLossOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
                                     &preprocess_collapse_repeated_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
                                     &ignore_longer_outputs_than_inputs_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* labels_indices;
    const Tensor* labels_values;
    const Tensor* seq_len;
    OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
    OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
    OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
    OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

    OP_REQUIRES(ctx, inputs->shape().dims() == 3,
                errors::InvalidArgument("inputs is not a 3-Tensor"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
                errors::InvalidArgument("sequence_length is not a vector"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
                errors::InvalidArgument("labels_indices is not a matrix"));
    OP_REQUIRES(ctx, labels_indices->dim_size(1) > 1,
                errors::InvalidArgument(
                    "labels_indices second dimension must be > 1. Received ",
                    labels_indices->dim_size(1)));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
                errors::InvalidArgument("labels_values is not a vector"));

    const TensorShape& inputs_shape = inputs->shape();
    const int64_t max_time = inputs_shape.dim_size(0);
    OP_REQUIRES(ctx, max_time != 0,
                errors::InvalidArgument(
                    "Max time or first dimension of input cannot be 0."));
    const int64_t batch_size = inputs_shape.dim_size(1);
    const int64_t num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    OP_REQUIRES(
        ctx, batch_size == seq_len->dim_size(0),
        errors::InvalidArgument("len(sequence_length) != batch_size. ",
                                "len(sequence_length): ", seq_len->dim_size(0),
                                " batch_size: ", batch_size));
    auto seq_len_t = seq_len->vec<int32>();

    OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
                errors::InvalidArgument(
                    "labels_indices and labels_values must contain the "
                    "same number of rows, but saw shapes: ",
                    labels_indices->shape().DebugString(), " vs. ",
                    labels_values->shape().DebugString()));

    OP_REQUIRES(ctx, batch_size != 0,
                errors::InvalidArgument("batch_size must not be 0"));

    // Figure out the maximum label length to use as sparse tensor dimension.
    auto labels_indices_t = labels_indices->matrix<int64_t>();
    int64_t max_label_len = 0;
    for (int64_t i = 0; i < labels_indices->dim_size(0); i++) {
      max_label_len = std::max(max_label_len, labels_indices_t(i, 1) + 1);
    }

    TensorShape labels_shape({batch_size, max_label_len});
    std::vector<int64_t> order{0, 1};
    sparse::SparseTensor labels_sp;
    OP_REQUIRES_OK(
        ctx, sparse::SparseTensor::Create(*labels_indices, *labels_values,
                                          labels_shape, order, &labels_sp));

    Status labels_sp_valid = labels_sp.IndicesValid();
    OP_REQUIRES(ctx, labels_sp_valid.ok(),
                errors::InvalidArgument("label SparseTensor is not valid: ",
                                        labels_sp_valid.error_message()));

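    // Group the sparse labels by batch index (dimension 0) and copy each
    // group's values into a dense per-batch vector of label ids.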
    typename ctc::CTCLossCalculator<T>::LabelSequences labels_t(batch_size);
    for (const auto& g : labels_sp.group({0})) {  // iterate by batch
      const int64_t batch_indices = g.group()[0];
      OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size),
                  errors::InvalidArgument("labels batch index must be in [0, ",
                                          batch_size,
                                          ") but saw: ", batch_indices));

      auto values = g.values<int32>();
      std::vector<int>* b_values = &labels_t[batch_indices];
      b_values->resize(values.size());
      for (int i = 0; i < values.size(); ++i) (*b_values)[i] = values(i);
    }

    OP_REQUIRES(ctx, static_cast<size_t>(batch_size) == labels_t.size(),
                errors::InvalidArgument("len(labels) != batch_size. ",
                                        "len(labels): ", labels_t.size(),
                                        " batch_size: ", batch_size));
    for (int64_t b = 0; b < batch_size; ++b) {
      OP_REQUIRES(ctx, seq_len_t(b) <= max_time,
                  errors::InvalidArgument("sequence_length(", b,
                                          ") must be <= ", max_time,
                                          " but saw: ", seq_len_t(b)));
    }

    Tensor* loss = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
    auto loss_t = loss->vec<T>();

    Tensor* gradient;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("gradient", inputs_shape, &gradient));
    auto gradient_t = gradient->tensor<T, 3>();
    auto inputs_t = inputs->tensor<T, 3>();
    std::vector<OutputMap> gradient_list_t;
    std::vector<InputMap> input_list_t;

    // Build per-timestep [batch_size x num_classes] views into the input and
    // gradient tensors.
    for (int64_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
      gradient_list_t.emplace_back(
          gradient_t.data() + t * batch_size * num_classes, batch_size,
          num_classes);
    }

    gradient_t.setZero();

    // Assumption: the blank index is num_classes - 1; the second constructor
    // argument is the output delay, which is 0 here.
    ctc::CTCLossCalculator<T> ctc_loss_calculator(num_classes - 1, 0);
    DeviceBase::CpuWorkerThreads workers =
        *ctx->device()->tensorflow_cpu_worker_threads();
    OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
                            seq_len_t, labels_t, input_list_t,
                            preprocess_collapse_repeated_, ctc_merge_repeated_,
                            ignore_longer_outputs_than_inputs_, &loss_t,
                            &gradient_list_t, &workers));
  }

 private:
  bool preprocess_collapse_repeated_;
  bool ctc_merge_repeated_;
  bool ignore_longer_outputs_than_inputs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp<T>);
};

#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCLossOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU

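// The GPU kernel delegates to the cuDNN (or MIOpen, on ROCm) CTC loss
// implementation through StreamExecutor; CUDNN_VERSION >= 7603 corresponds to
// cuDNN 7.6.3.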
#if ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
class CTCLossOpGPU : public OpKernel {
 public:
  explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool preprocess_collapse_repeated;
    bool ctc_merge_repeated;
    bool ignore_longer_outputs_than_inputs;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
                                     &preprocess_collapse_repeated));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
                                     &ignore_longer_outputs_than_inputs));

    OP_REQUIRES(ctx, !preprocess_collapse_repeated,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "preprocess_collapse_repeated to be "
                                        "false"));
    OP_REQUIRES(ctx, ctc_merge_repeated,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "ctc_merge_repeated to be "
                                        "true"));
    OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "ignore_longer_outputs_than_inputs "
                                        "to be false"));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* labels_indices;
    const Tensor* labels_values;
    const Tensor* seq_len;
    OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
    OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
    OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
    OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

    OP_REQUIRES(ctx, inputs->shape().dims() == 3,
                errors::InvalidArgument("inputs is not a 3-Tensor"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
                errors::InvalidArgument("sequence_length is not a vector"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
                errors::InvalidArgument("labels_indices is not a matrix"));
    // DoHistogram reads the indices with a fixed row stride of 2
    // (batch, time), so require exactly two columns.
    OP_REQUIRES(ctx, labels_indices->dim_size(1) == 2,
                errors::InvalidArgument(
                    "labels_indices must have exactly 2 columns, but saw: ",
                    labels_indices->dim_size(1)));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
                errors::InvalidArgument("labels_values is not a vector"));

    const TensorShape& inputs_shape = inputs->shape();
    const int64_t max_time_raw = inputs_shape.dim_size(0);
    const int64_t batch_size_raw = inputs_shape.dim_size(1);
    const int64_t num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(ctx,
                FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("max_time cannot exceed max int"));
    OP_REQUIRES(
        ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("batch_size cannot exceed max int"));
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int max_time = static_cast<int>(max_time_raw);
    const int batch_size = static_cast<int>(batch_size_raw);
    const int num_classes = static_cast<int>(num_classes_raw);

    OP_REQUIRES(
        ctx, batch_size == seq_len->dim_size(0),
        errors::InvalidArgument("len(sequence_length) != batch_size. ",
                                "len(sequence_length): ", seq_len->dim_size(0),
                                " batch_size: ", batch_size));

    OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
                errors::InvalidArgument(
                    "labels_indices and labels_values must contain the "
                    "same number of rows, but saw shapes: ",
                    labels_indices->shape().DebugString(), " vs. ",
                    labels_values->shape().DebugString()));
    auto num_indices = labels_indices->dim_size(0);

    OP_REQUIRES(ctx, batch_size != 0,
                errors::InvalidArgument("batch_size must not be 0"));

    Tensor* loss = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));

    Tensor* gradient = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("gradient", inputs_shape, &gradient));

    // Convert the labels_indices to labels_lengths.
    std::vector<int> labels_lengths(batch_size, 0);
    DoHistogram<int64_t>(ctx, labels_indices, num_indices, batch_size,
                         &labels_lengths);
    // DoHistogram may fail via OP_REQUIRES; bail out before launching cuDNN.
    if (!ctx->status().ok()) return;

    StreamExecutor* executor = ctx->op_device_context()->stream()->parent();
    se::dnn::DataType data_type = ToDataType<float>::value;

    auto probs_desc_s = executor->createRnnStateTensorDescriptor(
        max_time, batch_size, num_classes, data_type);
    OP_REQUIRES_OK(ctx, probs_desc_s.status());
    std::unique_ptr<RnnStateTensorDescriptor> probs_desc =
        std::move(probs_desc_s).value();

    auto grads_desc_s = executor->createRnnStateTensorDescriptor(
        max_time, batch_size, num_classes, data_type);
    OP_REQUIRES_OK(ctx, grads_desc_s.status());
    std::unique_ptr<RnnStateTensorDescriptor> grads_desc =
        std::move(grads_desc_s).value();

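    // The label and length arguments are read on the host: the kernel
    // registration below places labels_indices, labels_values, and
    // sequence_length in HostMemory.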
    absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
                                        num_indices);
    absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
                                                batch_size);
    absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(),
                                               batch_size);

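    // This path computes in float only: the probabilities, loss, and gradient
    // buffers are viewed as float device memory for the cuDNN call.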
    auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
    auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
    auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);

    // Cap the cuDNN workspace scratch allocation at 4 GiB (1LL << 32 bytes).
    DnnScratchAllocator workspace_allocator(1LL << 32, ctx);

    Stream* stream = ctx->op_device_context()->stream();
    bool cudnn_launch_status =
        stream
            ->ThenCtcLoss(*probs_desc, probs_data, labels_data,
                          labels_lengths_data, input_lengths_data, &costs_data,
                          *grads_desc, &grads_data, &workspace_allocator)
            .ok();

    if (!cudnn_launch_status) {
      ctx->SetStatus(errors::Internal("cuDNN CTCLoss launch failure"));
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
};

REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("labels_indices")
                            .HostMemory("labels_values")
                            .HostMemory("sequence_length"),
                        CTCLossOpGPU);
#endif  // ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
}  // end namespace tensorflow
383