1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/ctc_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include <limits>
21
22#include "tensorflow/core/framework/op.h"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/framework/register_types.h"
25#include "tensorflow/core/framework/types.h"
26#include "tensorflow/core/lib/core/status.h"
27#include "tensorflow/core/platform/logging.h"
28#include "tensorflow/core/platform/macros.h"
29#include "tensorflow/core/util/ctc/ctc_beam_search.h"
30#include "tensorflow/core/util/sparse/sparse_tensor.h"
31#include "tensorflow/core/util/work_sharder.h"
32
33namespace tensorflow {
34
35typedef Eigen::ThreadPoolDevice CPUDevice;
36
37template <typename T>
38inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
39 int* c) {
40 *c = 0;
41 CHECK_LT(0, m.dimension(1));
42 auto p = m(r, 0);
43 for (int i = 1; i < m.dimension(1); ++i) {
44 if (m(r, i) > p) {
45 p = m(r, i);
46 *c = i;
47 }
48 }
49 return p;
50}
51
// Shared helper for the CTC decoder kernels below: validates the op inputs
// common to both decoders, allocates the outputs, and materializes decoded
// label sequences as SparseTensor components (indices / values / dense
// shape), one triple per requested path.
class CTCDecodeHelper {
 public:
  // Defaults to emitting a single decoded path per batch entry.
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

  // Checks the "inputs" and "sequence_length" op inputs and allocates the op
  // outputs:
  //   - "inputs" must be a non-empty 3-D tensor; dims 0 and 1 are treated as
  //     max_time and batch_size respectively (time-major layout);
  //   - "sequence_length" must be an int32 vector of length batch_size whose
  //     entries never exceed max_time.
  // On success *log_prob points at a newly allocated [batch_size, top_paths_]
  // output tensor and the three OpOutputLists are bound to the corresponding
  // op output lists. Returns a non-OK status (without allocating the later
  // outputs) on the first violated precondition.
  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }
    if (inputs_shape.num_elements() == 0) {
      return errors::InvalidArgument("inputs must not be empty");
    }

    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (!(batch_size == (*seq_len)->dim_size(0))) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size. ",
          "len(sequence_length): ", (*seq_len)->dim_size(0),
          " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    // Every per-batch sequence length must fit within the time dimension.
    for (int b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b,
                                          ") <= ", max_time);
      }
    }

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return OkStatus();
  }

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
  // Writes each path p as a sparse [batch_size, max_decoded] matrix:
  //   decoded_indices[p]: [num_entries, 2] rows of (batch, position) pairs,
  //   decoded_values[p]:  the num_entries decoded class ids in row order,
  //   decoded_shape[p]:   {batch_size, longest sequence for this path}.
  // CHECK-fails if any batch entry does not supply exactly top_paths_ paths.
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    // Calculate the total number of entries for each path
    const int64_t batch_size = sequences.size();
    std::vector<int64_t> num_entries(top_paths_, 0);

    // Calculate num_entries per path
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64_t p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64_t>();
      auto values_t = p_values->vec<int64_t>();
      auto shape_t = p_shape->vec<int64_t>();

      // `offset` is the running write position shared by the values copy and
      // the index fill below; the index loop advances it per element.
      int64_t max_decoded = 0;
      int64_t offset = 0;

      for (int64_t b = 0; b < batch_size; ++b) {
        auto& p_batch = sequences[b][p];
        int64_t num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        if (num_decoded > 0) {
          DCHECK_NE(values_t.data(), nullptr)
              << "values_t should not be nullptr: p_num=" << p_num
              << " num_decoded=" << num_decoded;
          DCHECK_LT(offset, values_t.size())
              << "offset should be smaller than values_t.size()";
          std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        }
        // Index rows are (batch, position-within-sequence) pairs.
        for (int64_t t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return OkStatus();
  }

 private:
  int top_paths_;  // Number of decoded paths to emit per batch entry.
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
184
185template <typename T>
186class CTCGreedyDecoderOp : public OpKernel {
187 public:
188 explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
189 OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
190 OP_REQUIRES_OK(ctx, ctx->GetAttr("blank_index", &blank_index_));
191 }
192
193 void Compute(OpKernelContext* ctx) override {
194 const Tensor* inputs;
195 const Tensor* seq_len;
196 Tensor* log_prob = nullptr;
197 OpOutputList decoded_indices;
198 OpOutputList decoded_values;
199 OpOutputList decoded_shape;
200 OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
201 ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
202 &decoded_values, &decoded_shape));
203
204 const TensorShape& inputs_shape = inputs->shape();
205
206 std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
207 const int64_t max_time = inputs_shape.dim_size(0);
208 const int64_t batch_size = inputs_shape.dim_size(1);
209 const int64_t num_classes_raw = inputs_shape.dim_size(2);
210 OP_REQUIRES(
211 ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
212 errors::InvalidArgument("num_classes cannot exceed max int"));
213 const int num_classes = static_cast<const int>(num_classes_raw);
214
215 auto inputs_t = inputs->tensor<T, 3>();
216
217 input_list_t.reserve(max_time);
218 for (std::size_t t = 0; t < max_time; ++t) {
219 input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
220 batch_size, num_classes);
221 }
222 auto seq_len_t = seq_len->vec<int32>();
223 auto log_prob_t = log_prob->matrix<T>();
224
225 log_prob_t.setZero();
226
227 int blank_index =
228 (blank_index_ < 0) ? num_classes + blank_index_ : blank_index_;
229 OP_REQUIRES(ctx, FastBoundsCheck(blank_index, num_classes),
230 errors::InvalidArgument("blank_index expected to be between ",
231 -num_classes, " and ", num_classes - 1,
232 " but was ", blank_index_));
233
234 // Perform best path decoding
235 std::vector<std::vector<std::vector<int> > > sequences(batch_size);
236 auto decode = [&](const int64_t begin, const int64_t end) {
237 for (int b = begin; b < end; ++b) {
238 sequences[b].resize(1);
239 auto &sequence = sequences[b][0];
240 int prev_indices = -1;
241 for (int t = 0; t < seq_len_t(b); ++t) {
242 int max_class_indices;
243 OP_REQUIRES(ctx, input_list_t[t].dimension(1) > 0,
244 errors::InvalidArgument("Invalid input dimensions."));
245 log_prob_t(b, 0) +=
246 -RowMax<T>(input_list_t[t], b, &max_class_indices);
247 if (max_class_indices != blank_index &&
248 !(merge_repeated_ && max_class_indices == prev_indices)) {
249 sequence.push_back(max_class_indices);
250 }
251 prev_indices = max_class_indices;
252 }
253 }
254 };
255
256 const int64_t kCostPerUnit = 50 * max_time * num_classes;
257 const int64_t total = batch_size;
258 const DeviceBase::CpuWorkerThreads& worker_threads =
259 *ctx->device()->tensorflow_cpu_worker_threads();
260 Shard(worker_threads.num_threads, worker_threads.workers, total,
261 kCostPerUnit, decode);
262
263 OP_REQUIRES_OK(
264 ctx, decode_helper_.StoreAllDecodedSequences(
265 sequences, &decoded_indices, &decoded_values, &decoded_shape));
266 }
267
268 private:
269 CTCDecodeHelper decode_helper_;
270 bool merge_repeated_;
271 int blank_index_;
272
273 TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
274};
275
// Register the CPU greedy-decoder kernel for each supported logit type T.
#define REGISTER_CPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCGreedyDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU
285
286// CTC beam search
287template <typename T>
288class CTCBeamSearchDecoderOp : public OpKernel {
289 public:
290 explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
291 OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
292 OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
293 int top_paths;
294 OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
295 decode_helper_.SetTopPaths(top_paths);
296 }
297
298 void Compute(OpKernelContext* ctx) override {
299 const Tensor* inputs;
300 const Tensor* seq_len;
301 Tensor* log_prob = nullptr;
302 OpOutputList decoded_indices;
303 OpOutputList decoded_values;
304 OpOutputList decoded_shape;
305 OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
306 ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
307 &decoded_values, &decoded_shape));
308
309 auto inputs_t = inputs->tensor<T, 3>();
310 auto seq_len_t = seq_len->vec<int32>();
311 auto log_prob_t = log_prob->matrix<T>();
312
313 const TensorShape& inputs_shape = inputs->shape();
314
315 const int64_t max_time = inputs_shape.dim_size(0);
316 const int64_t batch_size = inputs_shape.dim_size(1);
317 const int64_t num_classes_raw = inputs_shape.dim_size(2);
318 OP_REQUIRES(
319 ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
320 errors::InvalidArgument("num_classes cannot exceed max int"));
321 const int num_classes = static_cast<const int>(num_classes_raw);
322
323 log_prob_t.setZero();
324
325 std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
326
327 input_list_t.reserve(max_time);
328 for (std::size_t t = 0; t < max_time; ++t) {
329 input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
330 batch_size, num_classes);
331 }
332
333 ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
334 &beam_scorer_, 1 /* batch_size */,
335 merge_repeated_);
336 Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
337 auto input_chip_t = input_chip.flat<T>();
338
339 std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
340 std::vector<T> log_probs;
341
342 // Assumption: the blank index is num_classes - 1
343 for (int b = 0; b < batch_size; ++b) {
344 auto& best_paths_b = best_paths[b];
345 best_paths_b.resize(decode_helper_.GetTopPaths());
346 for (int t = 0; t < seq_len_t(b); ++t) {
347 input_chip_t = input_list_t[t].chip(b, 0);
348 auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
349 input_chip_t.data(), num_classes);
350 beam_search.Step(input_bi);
351 }
352 OP_REQUIRES_OK(
353 ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
354 &log_probs, merge_repeated_));
355
356 beam_search.Reset();
357
358 for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
359 log_prob_t(b, bp) = log_probs[bp];
360 }
361 }
362
363 OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
364 best_paths, &decoded_indices, &decoded_values,
365 &decoded_shape));
366 }
367
368 private:
369 CTCDecodeHelper decode_helper_;
370 typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
371 bool merge_repeated_;
372 int beam_width_;
373 TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
374};
375
// Register the CPU beam-search decoder kernel for each supported logit type T.
#define REGISTER_CPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCBeamSearchDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU
385
386} // end namespace tensorflow
387