1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/array_ops.cc. |
17 | |
#define EIGEN_USE_THREADS

#include <algorithm>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/edit_distance.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
33 | |
34 | namespace tensorflow { |
35 | |
36 | namespace { |
37 | |
// Validates the six component tensors of the two input SparseTensors
// (hypothesis and truth): indices must be rank-2, values and shapes rank-1,
// the element counts mutually consistent, and the two sparse tensors must
// have equal rank of at least 2.  Returns the first violation found (check
// order therefore determines which error message a caller sees).
// `ctx` is not read by the current checks.
Status ValidateShapes(OpKernelContext* ctx, const Tensor& hypothesis_indices,
                      const Tensor& hypothesis_values,
                      const Tensor& hypothesis_shape,
                      const Tensor& truth_indices, const Tensor& truth_values,
                      const Tensor& truth_shape) {
  // Indices of a SparseTensor are an N x R matrix (N entries, rank R).
  if (!TensorShapeUtils::IsMatrix(hypothesis_indices.shape()))
    return errors::InvalidArgument(
        "hypothesis_indices should be a matrix, but got shape: ",
        hypothesis_indices.shape().DebugString());
  if (!TensorShapeUtils::IsMatrix(truth_indices.shape()))
    return errors::InvalidArgument(
        "truth_indices should be a matrix, but got shape: ",
        truth_indices.shape().DebugString());
  // Values and dense-shape components are vectors.
  if (!TensorShapeUtils::IsVector(hypothesis_values.shape()))
    return errors::InvalidArgument(
        "hypothesis_values should be a vector, but got shape: ",
        hypothesis_values.shape().DebugString());
  if (!TensorShapeUtils::IsVector(truth_values.shape()))
    return errors::InvalidArgument(
        "truth_values should be a vector, but got shape: ",
        truth_values.shape().DebugString());
  if (!TensorShapeUtils::IsVector(hypothesis_shape.shape()))
    return errors::InvalidArgument(
        "hypothesis_shape should be a vector, but got shape: ",
        hypothesis_shape.shape().DebugString());
  if (!TensorShapeUtils::IsVector(truth_shape.shape()))
    return errors::InvalidArgument(
        "truth_shape should be a vector, but got shape: ",
        truth_shape.shape().DebugString());
  // One value per index row.
  if (hypothesis_values.NumElements() != hypothesis_indices.dim_size(0))
    return errors::InvalidArgument(
        "Expected hypothesis_values.NumElements == "
        "#rows(hypothesis_indices), their shapes are: ",
        hypothesis_values.shape().DebugString(), " and ",
        hypothesis_indices.shape().DebugString());
  // Rank implied by the shape vector must match the index width.
  if (hypothesis_shape.NumElements() != hypothesis_indices.dim_size(1))
    return errors::InvalidArgument(
        "Expected hypothesis_shape.NumElements == "
        "#cols(hypothesis_indices), their shapes are: ",
        hypothesis_shape.shape().DebugString(), " and ",
        hypothesis_indices.shape().DebugString());
  // Rank >= 2: at least one batch dim plus the sequence dim.
  if (truth_shape.NumElements() < 2)
    return errors::InvalidArgument(
        "Input SparseTensors must have rank at least 2, but truth_shape "
        "rank is: ",
        truth_shape.NumElements());
  if (truth_values.NumElements() != truth_indices.dim_size(0))
    return errors::InvalidArgument(
        "Expected truth_values.NumElements == "
        "#rows(truth_indices), their shapes are: ",
        truth_values.shape().DebugString(), " and ",
        truth_indices.shape().DebugString());
  if (truth_shape.NumElements() != truth_indices.dim_size(1))
    return errors::InvalidArgument(
        "Expected truth_shape.NumElements == "
        "#cols(truth_indices), their shapes are: ",
        truth_shape.shape().DebugString(), " and ",
        truth_indices.shape().DebugString());
  // Equal ranks; combined with the check above this also forces the
  // hypothesis rank to be >= 2.
  if (truth_shape.NumElements() != hypothesis_shape.NumElements())
    return errors::InvalidArgument(
        "Expected truth and hypothesis to have matching ranks, but "
        "their shapes are: ",
        truth_shape.shape().DebugString(), " and ",
        hypothesis_shape.shape().DebugString());

  return OkStatus();
}
105 | |
106 | } // namespace |
107 | |
108 | template <typename T> |
109 | class EditDistanceOp : public OpKernel { |
110 | public: |
111 | explicit EditDistanceOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
112 | OP_REQUIRES_OK(ctx, ctx->GetAttr("normalize" , &normalize_)); |
113 | } |
114 | |
115 | void Compute(OpKernelContext* ctx) override { |
116 | const Tensor* hypothesis_indices; |
117 | const Tensor* hypothesis_values; |
118 | const Tensor* hypothesis_shape; |
119 | const Tensor* truth_indices; |
120 | const Tensor* truth_values; |
121 | const Tensor* truth_shape; |
122 | OP_REQUIRES_OK(ctx, ctx->input("hypothesis_indices" , &hypothesis_indices)); |
123 | OP_REQUIRES_OK(ctx, ctx->input("hypothesis_values" , &hypothesis_values)); |
124 | OP_REQUIRES_OK(ctx, ctx->input("hypothesis_shape" , &hypothesis_shape)); |
125 | OP_REQUIRES_OK(ctx, ctx->input("truth_indices" , &truth_indices)); |
126 | OP_REQUIRES_OK(ctx, ctx->input("truth_values" , &truth_values)); |
127 | OP_REQUIRES_OK(ctx, ctx->input("truth_shape" , &truth_shape)); |
128 | |
129 | OP_REQUIRES_OK( |
130 | ctx, ValidateShapes(ctx, *hypothesis_indices, *hypothesis_values, |
131 | *hypothesis_shape, *truth_indices, *truth_values, |
132 | *truth_shape)); |
133 | |
134 | TensorShape hypothesis_st_shape; |
135 | OP_REQUIRES_OK(ctx, |
136 | TensorShapeUtils::MakeShape( |
137 | hypothesis_shape->vec<int64_t>().data(), |
138 | hypothesis_shape->NumElements(), &hypothesis_st_shape)); |
139 | TensorShape truth_st_shape; |
140 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( |
141 | truth_shape->vec<int64_t>().data(), |
142 | truth_shape->NumElements(), &truth_st_shape)); |
143 | |
144 | // Assume indices are sorted in row-major order. |
145 | std::vector<int64_t> sorted_order(truth_st_shape.dims()); |
146 | std::iota(sorted_order.begin(), sorted_order.end(), 0); |
147 | |
148 | sparse::SparseTensor hypothesis; |
149 | OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create( |
150 | *hypothesis_indices, *hypothesis_values, |
151 | hypothesis_st_shape, sorted_order, &hypothesis)); |
152 | |
153 | sparse::SparseTensor truth; |
154 | OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create( |
155 | *truth_indices, *truth_values, truth_st_shape, |
156 | sorted_order, &truth)); |
157 | |
158 | // Group dims 0, 1, ..., RANK - 1. The very last dim is assumed |
159 | // to store the variable length sequences. |
160 | std::vector<int64_t> group_dims(truth_st_shape.dims() - 1); |
161 | std::iota(group_dims.begin(), group_dims.end(), 0); |
162 | |
163 | TensorShape output_shape; |
164 | for (int d = 0; d < static_cast<int>(group_dims.size()); ++d) { |
165 | output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d), |
166 | truth_st_shape.dim_size(d))); |
167 | } |
168 | const auto output_elements = output_shape.num_elements(); |
169 | OP_REQUIRES( |
170 | ctx, output_elements > 0, |
171 | errors::InvalidArgument("Got output shape " , output_shape.DebugString(), |
172 | " which has 0 elements" )); |
173 | |
174 | Tensor* output = nullptr; |
175 | OP_REQUIRES_OK(ctx, ctx->allocate_output("output" , output_shape, &output)); |
176 | auto output_t = output->flat<float>(); |
177 | output_t.setZero(); |
178 | |
179 | std::vector<int64_t> output_strides(output_shape.dims()); |
180 | output_strides[output_shape.dims() - 1] = 1; |
181 | for (int d = output_shape.dims() - 2; d >= 0; --d) { |
182 | output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1); |
183 | } |
184 | |
185 | auto hypothesis_grouper = hypothesis.group(group_dims); |
186 | auto truth_grouper = truth.group(group_dims); |
187 | |
188 | auto hypothesis_iter = hypothesis_grouper.begin(); |
189 | auto truth_iter = truth_grouper.begin(); |
190 | |
191 | auto cmp = std::equal_to<T>(); |
192 | |
193 | while (hypothesis_iter != hypothesis_grouper.end() && |
194 | truth_iter != truth_grouper.end()) { |
195 | sparse::Group truth_i = *truth_iter; |
196 | sparse::Group hypothesis_j = *hypothesis_iter; |
197 | std::vector<int64_t> g_truth = truth_i.group(); |
198 | std::vector<int64_t> g_hypothesis = hypothesis_j.group(); |
199 | auto truth_seq = truth_i.values<T>(); |
200 | auto hypothesis_seq = hypothesis_j.values<T>(); |
201 | |
202 | if (g_truth == g_hypothesis) { |
203 | auto loc = std::inner_product(g_truth.begin(), g_truth.end(), |
204 | output_strides.begin(), int64_t{0}); |
205 | OP_REQUIRES( |
206 | ctx, 0 <= loc && loc < output_elements, |
207 | errors::Internal("Got an inner product " , loc, |
208 | " which would require writing to outside of " |
209 | "the buffer for the output tensor (max elements " , |
210 | output_elements, ")" )); |
211 | output_t(loc) = |
212 | gtl::LevenshteinDistance<T>(truth_seq, hypothesis_seq, cmp); |
213 | if (normalize_) output_t(loc) /= truth_seq.size(); |
214 | |
215 | ++hypothesis_iter; |
216 | ++truth_iter; |
217 | } else if (g_truth > g_hypothesis) { // zero-length truth |
218 | auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(), |
219 | output_strides.begin(), int64_t{0}); |
220 | OP_REQUIRES( |
221 | ctx, 0 <= loc && loc < output_elements, |
222 | errors::Internal("Got an inner product " , loc, |
223 | " which would require writing to outside of " |
224 | "the buffer for the output tensor (max elements " , |
225 | output_elements, ")" )); |
226 | output_t(loc) = hypothesis_seq.size(); |
227 | if (normalize_ && output_t(loc) != 0.0f) { |
228 | output_t(loc) = std::numeric_limits<float>::infinity(); |
229 | } |
230 | ++hypothesis_iter; |
231 | } else { // zero-length hypothesis |
232 | auto loc = std::inner_product(g_truth.begin(), g_truth.end(), |
233 | output_strides.begin(), int64_t{0}); |
234 | OP_REQUIRES( |
235 | ctx, 0 <= loc && loc < output_elements, |
236 | errors::Internal("Got an inner product " , loc, |
237 | " which would require writing to outside of " |
238 | "the buffer for the output tensor (max elements " , |
239 | output_elements, ")" )); |
240 | output_t(loc) = (normalize_) ? 1.0 : truth_seq.size(); |
241 | ++truth_iter; |
242 | } |
243 | } |
244 | while (hypothesis_iter != hypothesis_grouper.end()) { // zero-length truths |
245 | sparse::Group hypothesis_j = *hypothesis_iter; |
246 | std::vector<int64_t> g_hypothesis = hypothesis_j.group(); |
247 | auto hypothesis_seq = hypothesis_j.values<T>(); |
248 | auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(), |
249 | output_strides.begin(), int64_t{0}); |
250 | OP_REQUIRES( |
251 | ctx, 0 <= loc && loc < output_elements, |
252 | errors::Internal("Got an inner product " , loc, |
253 | " which would require writing to outside of the " |
254 | "buffer for the output tensor (max elements " , |
255 | output_elements, ")" )); |
256 | output_t(loc) = hypothesis_seq.size(); |
257 | if (normalize_ && output_t(loc) != 0.0f) { |
258 | output_t(loc) = std::numeric_limits<float>::infinity(); |
259 | } |
260 | ++hypothesis_iter; |
261 | } |
262 | while (truth_iter != truth_grouper.end()) { // missing hypotheses |
263 | sparse::Group truth_i = *truth_iter; |
264 | std::vector<int64_t> g_truth = truth_i.group(); |
265 | auto truth_seq = truth_i.values<T>(); |
266 | auto loc = std::inner_product(g_truth.begin(), g_truth.end(), |
267 | output_strides.begin(), int64_t{0}); |
268 | OP_REQUIRES( |
269 | ctx, 0 <= loc && loc < output_elements, |
270 | errors::Internal("Got an inner product " , loc, |
271 | " which would require writing to outside of the " |
272 | "buffer for the output tensor (max elements " , |
273 | output_elements, ")" )); |
274 | output_t(loc) = (normalize_) ? 1.0 : truth_seq.size(); |
275 | ++truth_iter; |
276 | } |
277 | } |
278 | |
279 | private: |
280 | bool normalize_; |
281 | |
282 | TF_DISALLOW_COPY_AND_ASSIGN(EditDistanceOp); |
283 | }; |
284 | |
// Registers EditDistanceOp<T> as the CPU kernel for the "EditDistance" op,
// once per type T supplied by TF_CALL_POD_STRING_TYPES (POD and string
// types, per register_types.h).  The macro is #undef'd afterwards so it
// does not leak into other translation-unit includes.
#define REGISTER_CPU_KERNEL(T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("EditDistance").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      EditDistanceOp<T>);

TF_CALL_POD_STRING_TYPES(REGISTER_CPU_KERNEL);

#undef REGISTER_CPU_KERNEL
293 | |
294 | } // end namespace tensorflow |
295 | |