1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op.h" |
17 | #include "tensorflow/core/framework/shape_inference.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | using shape_inference::DimensionHandle; |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
25 | // CTC is Connectionist Temporal Classification. See util/ctc/ for details. |
26 | |
27 | REGISTER_OP("CTCLoss" ) |
28 | .Input("inputs: T" ) |
29 | .Input("labels_indices: int64" ) |
30 | .Input("labels_values: int32" ) |
31 | .Input("sequence_length: int32" ) |
32 | .Attr("preprocess_collapse_repeated: bool = false" ) |
33 | .Attr("ctc_merge_repeated: bool = true" ) |
34 | .Attr("ignore_longer_outputs_than_inputs: bool = false" ) |
35 | .Output("loss: T" ) |
36 | .Output("gradient: T" ) |
37 | .Attr("T: {float, double} = DT_FLOAT" ) |
38 | .SetShapeFn([](InferenceContext* c) { |
39 | ShapeHandle inputs; |
40 | ShapeHandle labels_indices; |
41 | ShapeHandle labels_values; |
42 | ShapeHandle sequence_length; |
43 | |
44 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); |
45 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices)); |
46 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values)); |
47 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length)); |
48 | |
49 | DimensionHandle unused; |
50 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0), |
51 | c->Dim(labels_values, 0), &unused)); |
52 | |
53 | // Get batch size from inputs and sequence_length, and update inputs |
54 | // with the merged batch_size since it is returned. |
55 | DimensionHandle batch_size; |
56 | TF_RETURN_IF_ERROR( |
57 | c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); |
58 | TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs)); |
59 | |
60 | c->set_output(0, c->Vector(batch_size)); |
61 | c->set_output(1, inputs); |
62 | return OkStatus(); |
63 | }); |
64 | |
65 | REGISTER_OP("CTCLossV2" ) |
66 | .Input("inputs: float" ) |
67 | .Input("labels_indices: int64" ) |
68 | .Input("labels_values: int32" ) |
69 | .Input("sequence_length: int32" ) |
70 | .Attr("preprocess_collapse_repeated: bool = false" ) |
71 | .Attr("ctc_merge_repeated: bool = true" ) |
72 | .Attr("ignore_longer_outputs_than_inputs: bool = false" ) |
73 | .Output("loss: float" ) |
74 | .Output("gradient: float" ) |
75 | .SetShapeFn([](InferenceContext* c) { |
76 | ShapeHandle inputs; |
77 | ShapeHandle labels_indices; |
78 | ShapeHandle labels_values; |
79 | ShapeHandle sequence_length; |
80 | |
81 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); |
82 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices)); |
83 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values)); |
84 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length)); |
85 | |
86 | DimensionHandle unused; |
87 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0), |
88 | c->Dim(labels_values, 0), &unused)); |
89 | |
90 | // Get batch size from inputs and sequence_length, and update inputs |
91 | // with the merged batch_size since it is returned. |
92 | DimensionHandle batch_size; |
93 | TF_RETURN_IF_ERROR( |
94 | c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); |
95 | TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs)); |
96 | |
97 | c->set_output(0, c->Vector(batch_size)); |
98 | c->set_output(1, inputs); |
99 | return OkStatus(); |
100 | }); |
101 | |
102 | REGISTER_OP("CTCGreedyDecoder" ) |
103 | .Input("inputs: T" ) |
104 | .Input("sequence_length: int32" ) |
105 | .Attr("merge_repeated: bool = false" ) |
106 | .Attr("blank_index: int = -1" ) |
107 | .Output("decoded_indices: int64" ) |
108 | .Output("decoded_values: int64" ) |
109 | .Output("decoded_shape: int64" ) |
110 | .Output("log_probability: T" ) |
111 | .Attr("T: {float, double} = DT_FLOAT" ) |
112 | .SetShapeFn([](InferenceContext* c) { |
113 | ShapeHandle inputs; |
114 | ShapeHandle sequence_length; |
115 | |
116 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); |
117 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length)); |
118 | |
119 | // Get batch size from inputs and sequence_length. |
120 | DimensionHandle batch_size; |
121 | TF_RETURN_IF_ERROR( |
122 | c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); |
123 | |
124 | DimensionHandle total_decoded_outputs = c->UnknownDim(); |
125 | c->set_output(0, c->Matrix(total_decoded_outputs, 2)); |
126 | c->set_output(1, c->Vector(total_decoded_outputs)); |
127 | c->set_output(2, c->Vector(2)); |
128 | c->set_output(3, c->Matrix(batch_size, 1)); |
129 | return OkStatus(); |
130 | }); |
131 | |
132 | REGISTER_OP("CTCBeamSearchDecoder" ) |
133 | .Input("inputs: T" ) |
134 | .Input("sequence_length: int32" ) |
135 | .Attr("beam_width: int >= 1" ) |
136 | .Attr("top_paths: int >= 1" ) |
137 | .Attr("merge_repeated: bool = true" ) |
138 | .Output("decoded_indices: top_paths * int64" ) |
139 | .Output("decoded_values: top_paths * int64" ) |
140 | .Output("decoded_shape: top_paths * int64" ) |
141 | .Output("log_probability: T" ) |
142 | .Attr("T: {float, double} = DT_FLOAT" ) |
143 | .SetShapeFn([](InferenceContext* c) { |
144 | ShapeHandle inputs; |
145 | ShapeHandle sequence_length; |
146 | |
147 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); |
148 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length)); |
149 | |
150 | // Get batch size from inputs and sequence_length. |
151 | DimensionHandle batch_size; |
152 | TF_RETURN_IF_ERROR( |
153 | c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); |
154 | |
155 | int32_t top_paths; |
156 | TF_RETURN_IF_ERROR(c->GetAttr("top_paths" , &top_paths)); |
157 | |
158 | // Outputs. |
159 | int out_idx = 0; |
160 | for (int i = 0; i < top_paths; ++i) { // decoded_indices |
161 | c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2)); |
162 | } |
163 | for (int i = 0; i < top_paths; ++i) { // decoded_values |
164 | c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim)); |
165 | } |
166 | ShapeHandle shape_v = c->Vector(2); |
167 | for (int i = 0; i < top_paths; ++i) { // decoded_shape |
168 | c->set_output(out_idx++, shape_v); |
169 | } |
170 | c->set_output(out_idx++, c->Matrix(batch_size, top_paths)); |
171 | return OkStatus(); |
172 | }); |
173 | |
174 | } // namespace tensorflow |
175 | |