1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
18#include "tensorflow/core/framework/op.h"
19#include "tensorflow/core/framework/op_def_builder.h"
20#include "tensorflow/core/framework/shape_inference.h"
21
22namespace tensorflow {
23
24using shape_inference::DimensionHandle;
25using shape_inference::InferenceContext;
26using shape_inference::ShapeAndType;
27using shape_inference::ShapeHandle;
28
29// --------------------------------------------------------------------------
30
31namespace {
32Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
33 ShapeHandle handle;
34 DimensionHandle unused_handle;
35 for (int i = 0; i < c->num_inputs(); ++i) {
36 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
37 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
38 }
39 for (int i = 0; i < c->num_outputs(); ++i) {
40 c->set_output(i, c->Scalar());
41 }
42 return OkStatus();
43}
44
45Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
46 ShapeHandle handle;
47 DimensionHandle unused_handle;
48 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
49 for (int i = 1; i < c->num_inputs(); ++i) {
50 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
51 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
52 }
53 for (int i = 0; i < c->num_outputs(); ++i) {
54 c->set_output(i, c->Scalar());
55 }
56 return OkStatus();
57}
58
59Status TwoElementOutput(InferenceContext* c) {
60 c->set_output(0, c->Vector(2));
61 return OkStatus();
62}
63
64Status ScalarOutput(InferenceContext* c) {
65 c->set_output(0, c->Scalar());
66 return OkStatus();
67}
68} // namespace
69
70REGISTER_OP("LookupTableFind")
71 .Input("table_handle: Ref(string)")
72 .Input("keys: Tin")
73 .Input("default_value: Tout")
74 .Output("values: Tout")
75 .Attr("Tin: type")
76 .Attr("Tout: type")
77 .SetShapeFn([](InferenceContext* c) {
78 ShapeHandle handle;
79 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
80 DimensionHandle unused_dim;
81 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
82
83 // Default value must be scalar or vector.
84 ShapeHandle unused;
85 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
86 c->set_output(0, c->UnknownShape());
87 return OkStatus();
88 });
89
90Status ValidateTableType(InferenceContext* c,
91 const ShapeAndType& key_shape_and_type,
92 const string& key_dtype_attr,
93 const ShapeAndType& value_shape_and_type,
94 const string& value_dtype_attr) {
95 DataType key_dtype;
96 TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype));
97 if (key_shape_and_type.dtype != key_dtype) {
98 return errors::InvalidArgument(
99 "Trying to read value with wrong dtype. "
100 "Expected ",
101 DataTypeString(key_shape_and_type.dtype), " got ",
102 DataTypeString(key_dtype));
103 }
104 DataType value_dtype;
105 TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype));
106 if (value_shape_and_type.dtype != value_dtype) {
107 return errors::InvalidArgument(
108 "Trying to read value with wrong dtype. "
109 "Expected ",
110 DataTypeString(value_shape_and_type.dtype), " got ",
111 DataTypeString(value_dtype));
112 }
113 return OkStatus();
114}
115
116Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys,
117 const string& key_dtype_attr,
118 const string& value_dtype_attr,
119 ShapeAndType* output_shape_and_type) {
120 auto* handle_data = c->input_handle_shapes_and_types(0);
121 if (handle_data == nullptr || handle_data->size() != 2) {
122 output_shape_and_type->shape = c->UnknownShape();
123 output_shape_and_type->dtype = DT_INVALID;
124 } else {
125 const ShapeAndType& key_shape_and_type = (*handle_data)[0];
126 const ShapeAndType& value_shape_and_type = (*handle_data)[1];
127 TF_RETURN_IF_ERROR(ValidateTableType(c, key_shape_and_type, key_dtype_attr,
128 value_shape_and_type,
129 value_dtype_attr));
130 output_shape_and_type->dtype = value_shape_and_type.dtype;
131 if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) {
132 int keys_rank = c->Rank(keys);
133 int key_suffix_rank = c->Rank(key_shape_and_type.shape);
134 if (keys_rank < key_suffix_rank) {
135 return errors::InvalidArgument(
136 "Expected keys to have suffix ",
137 c->DebugString(key_shape_and_type.shape),
138 " but saw shape: ", c->DebugString(keys));
139 }
140 for (int d = 0; d < key_suffix_rank; d++) {
141 // Ensure the suffix of keys match what's in the Table.
142 DimensionHandle dim = c->Dim(key_shape_and_type.shape, d);
143 TF_RETURN_IF_ERROR(
144 c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys));
145 }
146 std::vector<DimensionHandle> keys_prefix_vec;
147 keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
148 for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
149 keys_prefix_vec.push_back(c->Dim(keys, d));
150 }
151 ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec);
152 TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix, value_shape_and_type.shape,
153 &output_shape_and_type->shape));
154 } else {
155 output_shape_and_type->shape = c->UnknownShape();
156 }
157 }
158 return OkStatus();
159}
160
161REGISTER_OP("LookupTableFindV2")
162 .Input("table_handle: resource")
163 .Input("keys: Tin")
164 .Input("default_value: Tout")
165 .Output("values: Tout")
166 .Attr("Tin: type")
167 .Attr("Tout: type")
168 .SetShapeFn([](InferenceContext* c) {
169 ShapeHandle handle;
170 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
171
172 ShapeAndType value_shape_and_type;
173 TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
174 c,
175 /*keys=*/c->input(1),
176 /*key_dtype_attr=*/"Tin",
177 /*value_dtype_attr=*/"Tout", &value_shape_and_type));
178 c->set_output(0, value_shape_and_type.shape);
179
180 return OkStatus();
181 });
182ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2");
183// TODO(b/72710477): Update this.
184
185REGISTER_OP("LookupTableInsert")
186 .Input("table_handle: Ref(string)")
187 .Input("keys: Tin")
188 .Input("values: Tout")
189 .Attr("Tin: type")
190 .Attr("Tout: type")
191 .SetShapeFn([](InferenceContext* c) {
192 ShapeHandle handle;
193 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
194 DimensionHandle unused_dim;
195 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
196
197 // TODO(ebrevdo): Validate keys and values shape.
198 return OkStatus();
199 });
200
201REGISTER_OP("LookupTableInsertV2")
202 .Input("table_handle: resource")
203 .Input("keys: Tin")
204 .Input("values: Tout")
205 .Attr("Tin: type")
206 .Attr("Tout: type")
207 .SetShapeFn([](InferenceContext* c) {
208 ShapeHandle handle;
209 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
210
211 // TODO: Validate keys and values shape.
212 return OkStatus();
213 });
214
215REGISTER_OP("LookupTableRemoveV2")
216 .Input("table_handle: resource")
217 .Input("keys: Tin")
218 .Attr("Tin: type")
219 .SetShapeFn([](InferenceContext* c) {
220 ShapeHandle handle;
221 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
222 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle));
223
224 // TODO(turboale): Validate keys shape.
225 return OkStatus();
226 });
227
// Legacy (ref-handle) size op with a single int64 `size` output.
// The shape function requires every input to be a length-2 vector and makes
// every output a scalar.
REGISTER_OP("LookupTableSize")
    .Input("table_handle: Ref(string)")
    .Output("size: int64")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
// Registers the op in the dataset stateful-op allowlist.
ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableSize");

// Resource-handle variant.
// The shape function requires input 0 (the handle) to be a scalar and any
// further inputs (none here) to be length-2 vectors. Every output is a scalar.
REGISTER_OP("LookupTableSizeV2")
    .Input("table_handle: resource")
    .Output("size: int64")
    .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs);
ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableSizeV2");
239
240REGISTER_OP("LookupTableExport")
241 .Input("table_handle: Ref(string)")
242 .Output("keys: Tkeys")
243 .Output("values: Tvalues")
244 .Attr("Tkeys: type")
245 .Attr("Tvalues: type")
246 .SetShapeFn([](InferenceContext* c) {
247 ShapeHandle handle;
248 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
249 DimensionHandle unused_dim;
250 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
251
252 ShapeHandle values = c->UnknownShape();
253 TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
254 ShapeHandle keys = c->Vector(c->Dim(values, 0));
255 c->set_output(0, keys);
256 c->set_output(1, values);
257 return OkStatus();
258 });
259
260REGISTER_OP("LookupTableExportV2")
261 .Input("table_handle: resource")
262 .Output("keys: Tkeys")
263 .Output("values: Tvalues")
264 .Attr("Tkeys: type")
265 .Attr("Tvalues: type")
266 .SetShapeFn([](InferenceContext* c) {
267 ShapeHandle handle;
268 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
269 auto* handle_data = c->input_handle_shapes_and_types(0);
270 if (handle_data != nullptr && handle_data->size() == 2) {
271 const ShapeAndType& key_shape_and_type = (*handle_data)[0];
272 const ShapeAndType& value_shape_and_type = (*handle_data)[1];
273 TF_RETURN_IF_ERROR(ValidateTableType(c, key_shape_and_type,
274 /*key_dtype_attr*/ "Tkeys",
275 value_shape_and_type,
276 /*value_dtype_attr*/ "Tvalues"));
277 }
278 // Different lookup tables have different output shapes.
279 c->set_output(0, c->UnknownShape());
280 c->set_output(1, c->UnknownShape());
281 return OkStatus();
282 });
283
284REGISTER_OP("LookupTableImport")
285 .Input("table_handle: Ref(string)")
286 .Input("keys: Tin")
287 .Input("values: Tout")
288 .Attr("Tin: type")
289 .Attr("Tout: type")
290 .SetShapeFn([](InferenceContext* c) {
291 ShapeHandle handle;
292 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
293 DimensionHandle unused_dim;
294 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
295
296 // TODO(ebrevdo): Validate keys and values shape.
297 return OkStatus();
298 });
299
300REGISTER_OP("LookupTableImportV2")
301 .Input("table_handle: resource")
302 .Input("keys: Tin")
303 .Input("values: Tout")
304 .Attr("Tin: type")
305 .Attr("Tout: type")
306 .SetShapeFn([](InferenceContext* c) {
307 ShapeHandle handle;
308 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
309
310 ShapeHandle keys;
311 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
312 DimensionHandle unused;
313 TF_RETURN_IF_ERROR(
314 c->Merge(c->Dim(keys, 0), c->Dim(c->input(2), 0), &unused));
315 return OkStatus();
316 });
317
318Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key,
319 const ShapeHandle& value) {
320 c->set_output(0, c->Scalar());
321
322 ShapeHandle key_s;
323 TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s));
324
325 DataType key_t;
326 TF_RETURN_IF_ERROR(c->GetAttr("key_dtype", &key_t));
327
328 DataType value_t;
329 TF_RETURN_IF_ERROR(c->GetAttr("value_dtype", &value_t));
330
331 // ShapeAndType vector for {key, value}.
332 c->set_output_handle_shapes_and_types(
333 0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}});
334
335 return OkStatus();
336}
337
338Status MutableHashTableShapeFn(InferenceContext* c) {
339 return MutableHashTableShape(c, /*key=*/c->Scalar(),
340 /*value=*/c->Scalar());
341}
342
343Status MutableHashTableOfTensorsShapeFn(InferenceContext* c) {
344 PartialTensorShape value_p;
345 TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
346 ShapeHandle value_s;
347 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
348 return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s);
349}
350
351Status MutableDenseHashTableShapeFn(InferenceContext* c) {
352 PartialTensorShape value_p;
353 TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
354 ShapeHandle value_s;
355 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
356 return MutableHashTableShape(c, /*key=*/c->input(0), /*value=*/value_s);
357}
358
// Hash table constructor ops.
//
// HashTable outputs a legacy ref handle. Its inferred shape is a length-2
// vector (TwoElementOutput).
REGISTER_OP("HashTable")
    .Output("table_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Resource-handle variant: the handle is a scalar (ScalarOutput). No key/value
// handle metadata is attached by this shape function.
REGISTER_OP("HashTableV2")
    .Output("table_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(ScalarOutput);

// Like HashTableV2 but without the container / shared_name /
// use_node_name_sharing attrs.
REGISTER_OP("AnonymousHashTable")
    .Output("table_handle: resource")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(ScalarOutput);
385
// Mutable hash table with a legacy ref handle.
// The inferred handle shape is a length-2 vector.
REGISTER_OP("MutableHashTable")
    .Output("table_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Resource-handle variant.
// The handle is a scalar, with {scalar key, scalar value} handle metadata
// attached (MutableHashTableShapeFn).
REGISTER_OP("MutableHashTableV2")
    .Output("table_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(MutableHashTableShapeFn);

// Like MutableHashTableV2 but without the sharing attrs.
REGISTER_OP("AnonymousMutableHashTable")
    .Output("table_handle: resource")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(MutableHashTableShapeFn);
412
// Mutable table whose per-key value shape is given by the `value_shape` attr.
// This is the legacy ref-handle form: the inferred handle shape is a length-2
// vector.
REGISTER_OP("MutableHashTableOfTensors")
    .Output("table_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Resource-handle variant.
// The handle is a scalar. Its metadata records a scalar key shape and the
// `value_shape` attr as the value shape.
REGISTER_OP("MutableHashTableOfTensorsV2")
    .Output("table_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .SetIsStateful()
    .SetShapeFn(MutableHashTableOfTensorsShapeFn);

// Like MutableHashTableOfTensorsV2 but without the sharing attrs.
REGISTER_OP("AnonymousMutableHashTableOfTensors")
    .Output("table_handle: resource")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .SetIsStateful()
    .SetShapeFn(MutableHashTableOfTensorsShapeFn);
442
// Mutable dense hash table with a legacy ref handle.
// The inferred handle shape is a length-2 vector.
REGISTER_OP("MutableDenseHashTable")
    .Input("empty_key: key_dtype")
    .Output("table_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .Attr("initial_num_buckets: int = 131072")  // 2^17
    .Attr("max_load_factor: float = 0.8")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Resource-handle variant. The handle is a scalar. Its metadata records:
//   * key shape: the shape of `empty_key`, which must have rank <= 1;
//   * value shape: the `value_shape` attr.
// `deleted_key` is not inspected by the shape function.
REGISTER_OP("MutableDenseHashTableV2")
    .Input("empty_key: key_dtype")
    .Input("deleted_key: key_dtype")
    .Output("table_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .Attr("initial_num_buckets: int = 131072")  // 2^17
    .Attr("max_load_factor: float = 0.8")
    .SetIsStateful()
    .SetShapeFn(MutableDenseHashTableShapeFn);

// Like MutableDenseHashTableV2 but without the sharing attrs.
REGISTER_OP("AnonymousMutableDenseHashTable")
    .Input("empty_key: key_dtype")
    .Input("deleted_key: key_dtype")
    .Output("table_handle: resource")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .Attr("initial_num_buckets: int = 131072")  // 2^17
    .Attr("max_load_factor: float = 0.8")
    .SetIsStateful()
    .SetShapeFn(MutableDenseHashTableShapeFn);
483
484REGISTER_OP("InitializeTable")
485 .Input("table_handle: Ref(string)")
486 .Input("keys: Tkey")
487 .Input("values: Tval")
488 .Attr("Tkey: type")
489 .Attr("Tval: type")
490 .SetShapeFn([](InferenceContext* c) {
491 ShapeHandle handle;
492 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
493 DimensionHandle unused_dim;
494 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
495
496 ShapeHandle keys;
497 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
498 TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
499 return OkStatus();
500 });
501
502REGISTER_OP("InitializeTableV2")
503 .Input("table_handle: resource")
504 .Input("keys: Tkey")
505 .Input("values: Tval")
506 .Attr("Tkey: type")
507 .Attr("Tval: type")
508 .SetShapeFn([](InferenceContext* c) {
509 ShapeHandle handle;
510 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
511
512 ShapeHandle keys;
513 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
514 TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
515 return OkStatus();
516 });
517
518REGISTER_OP("InitializeTableFromTextFile")
519 .Input("table_handle: Ref(string)")
520 .Input("filename: string")
521 .Attr("key_index: int >= -2")
522 .Attr("value_index: int >= -2")
523 .Attr("vocab_size: int >= -1 = -1")
524 .Attr("delimiter: string = '\t'")
525 .Attr("offset: int = 0")
526 .SetShapeFn([](InferenceContext* c) {
527 ShapeHandle handle;
528 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
529 DimensionHandle unused_dim;
530 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
531
532 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
533 return OkStatus();
534 });
535
536REGISTER_OP("InitializeTableFromTextFileV2")
537 .Input("table_handle: resource")
538 .Input("filename: string")
539 .Attr("key_index: int >= -2")
540 .Attr("value_index: int >= -2")
541 .Attr("vocab_size: int >= -1 = -1")
542 .Attr("delimiter: string = '\t'")
543 .Attr("offset: int = 0")
544 .SetShapeFn([](InferenceContext* c) {
545 ShapeHandle handle;
546 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
547
548 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
549 return OkStatus();
550 });
551
552} // namespace tensorflow
553