/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.

#define EIGEN_USE_THREADS

#include <limits>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

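// Returns the largest value in row `r` of matrix `m`, and writes the column
// index of that maximum (the argmax) to `*c`.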
template <typename T>
inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
                int* c) {
  *c = 0;
  CHECK_LT(0, m.dimension(1));
  auto p = m(r, 0);
  for (int i = 1; i < m.dimension(1); ++i) {
    if (m(r, i) > p) {
      p = m(r, i);
      *c = i;
    }
  }
  return p;
}

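// Shared helper for the two CTC decoder kernels below: validates the common
// inputs, allocates the outputs, and packs decoded sequences into the
// SparseTensor-style (indices, values, shape) output lists.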
class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

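  // Checks that `inputs` is a non-empty [max_time, batch_size, num_classes]
  // tensor and that `sequence_length` is a length-batch_size vector whose
  // entries do not exceed max_time, then allocates `log_prob` and the three
  // decoded output lists.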
  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }
    if (inputs_shape.num_elements() == 0) {
      return errors::InvalidArgument("inputs must not be empty");
    }

    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (!(batch_size == (*seq_len)->dim_size(0))) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size. ",
          "len(sequence_length): ", (*seq_len)->dim_size(0),
          " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    for (int b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b,
                                          ") <= ", max_time);
      }
    }

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return OkStatus();
  }

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
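  // Output p of each list holds one SparseTensor in coordinate format: row k
  // of decoded_indices[p] is the (batch, time) position of
  // decoded_values[p](k), and decoded_shape[p] is
  // [batch_size, max_decoded_length].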
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    // Calculate the total number of entries for each path.
    const int64_t batch_size = sequences.size();
    std::vector<int64_t> num_entries(top_paths_, 0);

    // Calculate num_entries per path.
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64_t p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64_t>();
      auto values_t = p_values->vec<int64_t>();
      auto shape_t = p_shape->vec<int64_t>();

      int64_t max_decoded = 0;
      int64_t offset = 0;

      for (int64_t b = 0; b < batch_size; ++b) {
        auto& p_batch = sequences[b][p];
        int64_t num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        if (num_decoded > 0) {
          DCHECK_NE(values_t.data(), nullptr)
              << "values_t should not be nullptr: p_num=" << p_num
              << " num_decoded=" << num_decoded;
          DCHECK_LT(offset, values_t.size())
              << "offset should be smaller than values_t.size()";
          std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        }
        for (int64_t t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return OkStatus();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};

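// CTC greedy (best path) decoder: at every time step it picks the most likely
// class, then drops blanks and, when merge_repeated is set, collapses runs of
// identical labels.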
template <typename T>
class CTCGreedyDecoderOp : public OpKernel {
 public:
  explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("blank_index", &blank_index_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    const TensorShape& inputs_shape = inputs->shape();

    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);
    const int64_t num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    auto inputs_t = inputs->tensor<T, 3>();

    input_list_t.reserve(max_time);
    for (int64_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<T>();

    log_prob_t.setZero();

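    // A negative blank_index counts back from the end, Python-style: -1
    // selects the last class, num_classes - 1.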
    int blank_index =
        (blank_index_ < 0) ? num_classes + blank_index_ : blank_index_;
    OP_REQUIRES(ctx, FastBoundsCheck(blank_index, num_classes),
                errors::InvalidArgument("blank_index expected to be between ",
                                        -num_classes, " and ", num_classes - 1,
                                        " but was ", blank_index_));

    // Perform best path decoding.
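    // Each shard handles batch entries [begin, end): take the argmax class at
    // every valid time step, accumulate the negated row maximum into the log
    // probability, and keep the label unless it is the blank or a merged
    // repeat of the previous step.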
    std::vector<std::vector<std::vector<int> > > sequences(batch_size);
    auto decode = [&](const int64_t begin, const int64_t end) {
      for (int b = begin; b < end; ++b) {
        sequences[b].resize(1);
        auto& sequence = sequences[b][0];
        int prev_indices = -1;
        for (int t = 0; t < seq_len_t(b); ++t) {
          int max_class_indices;
          OP_REQUIRES(ctx, input_list_t[t].dimension(1) > 0,
                      errors::InvalidArgument("Invalid input dimensions."));
          log_prob_t(b, 0) +=
              -RowMax<T>(input_list_t[t], b, &max_class_indices);
          if (max_class_indices != blank_index &&
              !(merge_repeated_ && max_class_indices == prev_indices)) {
            sequence.push_back(max_class_indices);
          }
          prev_indices = max_class_indices;
        }
      }
    };

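    // kCostPerUnit is a rough estimate of the work per batch entry; Shard
    // uses it to decide how many entries to hand to each worker thread.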
    const int64_t kCostPerUnit = 50 * max_time * num_classes;
    const int64_t total = batch_size;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *ctx->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, total,
          kCostPerUnit, decode);

    OP_REQUIRES_OK(
        ctx, decode_helper_.StoreAllDecodedSequences(
                 sequences, &decoded_indices, &decoded_values, &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  bool merge_repeated_;
  int blank_index_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};

#define REGISTER_CPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCGreedyDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU
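
// Illustrative example (hypothetical values): with max_time = 2,
// batch_size = 1, num_classes = 3, blank_index = 2, sequence_length = {2} and
// inputs[t][0] = {0.1, 0.8, 0.1} at both time steps, the argmax is class 1
// twice, so the decoder emits [1] with merge_repeated=true and [1, 1] with
// merge_repeated=false; log_probability accumulates the negated row maxima,
// -(0.8 + 0.8) = -1.6.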

// CTC beam search
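// Unlike the greedy decoder above, this op keeps beam_width candidate
// prefixes per batch entry at each time step and emits the top_paths most
// probable decodings together with their log probabilities.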
template <typename T>
class CTCBeamSearchDecoderOp : public OpKernel {
 public:
  explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
    int top_paths;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
    decode_helper_.SetTopPaths(top_paths);
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    auto inputs_t = inputs->tensor<T, 3>();
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<T>();

    const TensorShape& inputs_shape = inputs->shape();

    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);
    const int64_t num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    log_prob_t.setZero();

    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;

    input_list_t.reserve(max_time);
    for (int64_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }

    ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
                                             &beam_scorer_, 1 /* batch_size */,
                                             merge_repeated_);
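    // Scratch tensor holding the num_classes scores of a single (time, batch)
    // slice; it is refilled each iteration below and fed to
    // beam_search.Step().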
    Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
    auto input_chip_t = input_chip.flat<T>();

    std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
    std::vector<T> log_probs;

    // Assumption: the blank index is num_classes - 1.
    for (int b = 0; b < batch_size; ++b) {
      auto& best_paths_b = best_paths[b];
      best_paths_b.resize(decode_helper_.GetTopPaths());
      for (int t = 0; t < seq_len_t(b); ++t) {
        input_chip_t = input_list_t[t].chip(b, 0);
        auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
            input_chip_t.data(), num_classes);
        beam_search.Step(input_bi);
      }
      OP_REQUIRES_OK(
          ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                                    &log_probs, merge_repeated_));

      beam_search.Reset();

      for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
        log_prob_t(b, bp) = log_probs[bp];
      }
    }

    OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
                            best_paths, &decoded_indices, &decoded_values,
                            &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
  bool merge_repeated_;
  int beam_width_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
};

#define REGISTER_CPU(T)                                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCBeamSearchDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU

}  // end namespace tensorflow
387 | |