1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/array_ops.cc.
17
18#define EIGEN_USE_THREADS
19
#include <algorithm>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/edit_distance.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
33
34namespace tensorflow {
35
36namespace {
37
38Status ValidateShapes(OpKernelContext* ctx, const Tensor& hypothesis_indices,
39 const Tensor& hypothesis_values,
40 const Tensor& hypothesis_shape,
41 const Tensor& truth_indices, const Tensor& truth_values,
42 const Tensor& truth_shape) {
43 if (!TensorShapeUtils::IsMatrix(hypothesis_indices.shape()))
44 return errors::InvalidArgument(
45 "hypothesis_indices should be a matrix, but got shape: ",
46 hypothesis_indices.shape().DebugString());
47 if (!TensorShapeUtils::IsMatrix(truth_indices.shape()))
48 return errors::InvalidArgument(
49 "truth_indices should be a matrix, but got shape: ",
50 truth_indices.shape().DebugString());
51 if (!TensorShapeUtils::IsVector(hypothesis_values.shape()))
52 return errors::InvalidArgument(
53 "hypothesis_values should be a vector, but got shape: ",
54 hypothesis_values.shape().DebugString());
55 if (!TensorShapeUtils::IsVector(truth_values.shape()))
56 return errors::InvalidArgument(
57 "truth_values should be a vector, but got shape: ",
58 truth_values.shape().DebugString());
59 if (!TensorShapeUtils::IsVector(hypothesis_shape.shape()))
60 return errors::InvalidArgument(
61 "hypothesis_shape should be a vector, but got shape: ",
62 hypothesis_shape.shape().DebugString());
63 if (!TensorShapeUtils::IsVector(truth_shape.shape()))
64 return errors::InvalidArgument(
65 "truth_shape should be a vector, but got shape: ",
66 truth_shape.shape().DebugString());
67 if (hypothesis_values.NumElements() != hypothesis_indices.dim_size(0))
68 return errors::InvalidArgument(
69 "Expected hypothesis_values.NumElements == "
70 "#rows(hypothesis_indices), their shapes are: ",
71 hypothesis_values.shape().DebugString(), " and ",
72 hypothesis_indices.shape().DebugString());
73 if (hypothesis_shape.NumElements() != hypothesis_indices.dim_size(1))
74 return errors::InvalidArgument(
75 "Expected hypothesis_shape.NumElements == "
76 "#cols(hypothesis_indices), their shapes are: ",
77 hypothesis_shape.shape().DebugString(), " and ",
78 hypothesis_indices.shape().DebugString());
79 if (truth_shape.NumElements() < 2)
80 return errors::InvalidArgument(
81 "Input SparseTensors must have rank at least 2, but truth_shape "
82 "rank is: ",
83 truth_shape.NumElements());
84 if (truth_values.NumElements() != truth_indices.dim_size(0))
85 return errors::InvalidArgument(
86 "Expected truth_values.NumElements == "
87 "#rows(truth_indices), their shapes are: ",
88 truth_values.shape().DebugString(), " and ",
89 truth_indices.shape().DebugString());
90 if (truth_shape.NumElements() != truth_indices.dim_size(1))
91 return errors::InvalidArgument(
92 "Expected truth_shape.NumElements == "
93 "#cols(truth_indices), their shapes are: ",
94 truth_shape.shape().DebugString(), " and ",
95 truth_indices.shape().DebugString());
96 if (truth_shape.NumElements() != hypothesis_shape.NumElements())
97 return errors::InvalidArgument(
98 "Expected truth and hypothesis to have matching ranks, but "
99 "their shapes are: ",
100 truth_shape.shape().DebugString(), " and ",
101 hypothesis_shape.shape().DebugString());
102
103 return OkStatus();
104}
105
106} // namespace
107
// Computes the (optionally normalized) Levenshtein edit distance between
// corresponding sequences of two SparseTensor inputs, "hypothesis" and
// "truth". The last dimension of each input holds the variable-length
// sequences; every prefix of indices (dims 0..rank-2) identifies one
// sequence, and the output holds one float distance per sequence slot.
template <typename T>
class EditDistanceOp : public OpKernel {
 public:
  explicit EditDistanceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // "normalize" attr: when true, distances are divided by truth length
    // (with special-casing for empty truth/hypothesis below).
    OP_REQUIRES_OK(ctx, ctx->GetAttr("normalize", &normalize_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* hypothesis_indices;
    const Tensor* hypothesis_values;
    const Tensor* hypothesis_shape;
    const Tensor* truth_indices;
    const Tensor* truth_values;
    const Tensor* truth_shape;
    OP_REQUIRES_OK(ctx, ctx->input("hypothesis_indices", &hypothesis_indices));
    OP_REQUIRES_OK(ctx, ctx->input("hypothesis_values", &hypothesis_values));
    OP_REQUIRES_OK(ctx, ctx->input("hypothesis_shape", &hypothesis_shape));
    OP_REQUIRES_OK(ctx, ctx->input("truth_indices", &truth_indices));
    OP_REQUIRES_OK(ctx, ctx->input("truth_values", &truth_values));
    OP_REQUIRES_OK(ctx, ctx->input("truth_shape", &truth_shape));

    // Reject malformed inputs before building SparseTensors from them.
    OP_REQUIRES_OK(
        ctx, ValidateShapes(ctx, *hypothesis_indices, *hypothesis_values,
                            *hypothesis_shape, *truth_indices, *truth_values,
                            *truth_shape));

    // Convert the dense shape vectors (int64) into TensorShapes.
    TensorShape hypothesis_st_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(
                       hypothesis_shape->vec<int64_t>().data(),
                       hypothesis_shape->NumElements(), &hypothesis_st_shape));
    TensorShape truth_st_shape;
    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                            truth_shape->vec<int64_t>().data(),
                            truth_shape->NumElements(), &truth_st_shape));

    // Assume indices are sorted in row-major order.
    std::vector<int64_t> sorted_order(truth_st_shape.dims());
    std::iota(sorted_order.begin(), sorted_order.end(), 0);

    sparse::SparseTensor hypothesis;
    OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create(
                            *hypothesis_indices, *hypothesis_values,
                            hypothesis_st_shape, sorted_order, &hypothesis));

    sparse::SparseTensor truth;
    OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create(
                            *truth_indices, *truth_values, truth_st_shape,
                            sorted_order, &truth));

    // Group dims 0, 1, ..., RANK - 1. The very last dim is assumed
    // to store the variable length sequences.
    std::vector<int64_t> group_dims(truth_st_shape.dims() - 1);
    std::iota(group_dims.begin(), group_dims.end(), 0);

    // The output covers the elementwise max of the two grouped shapes, so
    // every sequence present in either input has a slot.
    TensorShape output_shape;
    for (int d = 0; d < static_cast<int>(group_dims.size()); ++d) {
      output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d),
                                   truth_st_shape.dim_size(d)));
    }
    const auto output_elements = output_shape.num_elements();
    OP_REQUIRES(
        ctx, output_elements > 0,
        errors::InvalidArgument("Got output shape ", output_shape.DebugString(),
                                " which has 0 elements"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("output", output_shape, &output));
    auto output_t = output->flat<float>();
    // Slots with no sequence in either input stay at distance 0.
    output_t.setZero();

    // Row-major strides of output_shape, used to map a group's coordinates
    // to a flat offset into output_t via inner_product below.
    std::vector<int64_t> output_strides(output_shape.dims());
    output_strides[output_shape.dims() - 1] = 1;
    for (int d = output_shape.dims() - 2; d >= 0; --d) {
      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
    }

    // Iterate both inputs grouped by their leading (non-sequence) dims.
    auto hypothesis_grouper = hypothesis.group(group_dims);
    auto truth_grouper = truth.group(group_dims);

    auto hypothesis_iter = hypothesis_grouper.begin();
    auto truth_iter = truth_grouper.begin();

    auto cmp = std::equal_to<T>();

    // Merge the two sorted group streams. Three cases per step:
    //   * groups match -> real edit distance;
    //   * truth group is ahead -> this hypothesis has no truth (empty truth);
    //   * hypothesis group is ahead -> this truth has no hypothesis.
    while (hypothesis_iter != hypothesis_grouper.end() &&
           truth_iter != truth_grouper.end()) {
      sparse::Group truth_i = *truth_iter;
      sparse::Group hypothesis_j = *hypothesis_iter;
      std::vector<int64_t> g_truth = truth_i.group();
      std::vector<int64_t> g_hypothesis = hypothesis_j.group();
      auto truth_seq = truth_i.values<T>();
      auto hypothesis_seq = hypothesis_j.values<T>();

      if (g_truth == g_hypothesis) {
        // Flat offset of this group in the output buffer.
        auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
                                      output_strides.begin(), int64_t{0});
        OP_REQUIRES(
            ctx, 0 <= loc && loc < output_elements,
            errors::Internal("Got an inner product ", loc,
                             " which would require writing to outside of "
                             "the buffer for the output tensor (max elements ",
                             output_elements, ")"));
        output_t(loc) =
            gtl::LevenshteinDistance<T>(truth_seq, hypothesis_seq, cmp);
        if (normalize_) output_t(loc) /= truth_seq.size();

        ++hypothesis_iter;
        ++truth_iter;
      } else if (g_truth > g_hypothesis) {  // zero-length truth
        auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(),
                                      output_strides.begin(), int64_t{0});
        OP_REQUIRES(
            ctx, 0 <= loc && loc < output_elements,
            errors::Internal("Got an inner product ", loc,
                             " which would require writing to outside of "
                             "the buffer for the output tensor (max elements ",
                             output_elements, ")"));
        // Distance to an empty truth is the hypothesis length; normalized
        // by a zero-length truth it becomes infinity.
        output_t(loc) = hypothesis_seq.size();
        if (normalize_ && output_t(loc) != 0.0f) {
          output_t(loc) = std::numeric_limits<float>::infinity();
        }
        ++hypothesis_iter;
      } else {  // zero-length hypothesis
        auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
                                      output_strides.begin(), int64_t{0});
        OP_REQUIRES(
            ctx, 0 <= loc && loc < output_elements,
            errors::Internal("Got an inner product ", loc,
                             " which would require writing to outside of "
                             "the buffer for the output tensor (max elements ",
                             output_elements, ")"));
        // Distance from an empty hypothesis is the truth length (or exactly
        // 1.0 when normalized by that same length).
        output_t(loc) = (normalize_) ? 1.0 : truth_seq.size();
        ++truth_iter;
      }
    }
    while (hypothesis_iter != hypothesis_grouper.end()) {  // zero-length truths
      sparse::Group hypothesis_j = *hypothesis_iter;
      std::vector<int64_t> g_hypothesis = hypothesis_j.group();
      auto hypothesis_seq = hypothesis_j.values<T>();
      auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(),
                                    output_strides.begin(), int64_t{0});
      OP_REQUIRES(
          ctx, 0 <= loc && loc < output_elements,
          errors::Internal("Got an inner product ", loc,
                           " which would require writing to outside of the "
                           "buffer for the output tensor (max elements ",
                           output_elements, ")"));
      output_t(loc) = hypothesis_seq.size();
      if (normalize_ && output_t(loc) != 0.0f) {
        output_t(loc) = std::numeric_limits<float>::infinity();
      }
      ++hypothesis_iter;
    }
    while (truth_iter != truth_grouper.end()) {  // missing hypotheses
      sparse::Group truth_i = *truth_iter;
      std::vector<int64_t> g_truth = truth_i.group();
      auto truth_seq = truth_i.values<T>();
      auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
                                    output_strides.begin(), int64_t{0});
      OP_REQUIRES(
          ctx, 0 <= loc && loc < output_elements,
          errors::Internal("Got an inner product ", loc,
                           " which would require writing to outside of the "
                           "buffer for the output tensor (max elements ",
                           output_elements, ")"));
      output_t(loc) = (normalize_) ? 1.0 : truth_seq.size();
      ++truth_iter;
    }
  }

 private:
  bool normalize_;  // Value of the "normalize" attr; see constructor.

  TF_DISALLOW_COPY_AND_ASSIGN(EditDistanceOp);
};
284
// Register the CPU kernel once per value type T; TF_CALL_POD_STRING_TYPES
// expands this for every POD and string type TensorFlow supports.
#define REGISTER_CPU_KERNEL(T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("EditDistance").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      EditDistanceOp<T>);

TF_CALL_POD_STRING_TYPES(REGISTER_CPU_KERNEL);

#undef REGISTER_CPU_KERNEL
293
294} // end namespace tensorflow
295