1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h" |
18 | #include "tensorflow/core/framework/op.h" |
19 | #include "tensorflow/core/framework/op_def_builder.h" |
20 | #include "tensorflow/core/framework/shape_inference.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | using shape_inference::DimensionHandle; |
25 | using shape_inference::InferenceContext; |
26 | using shape_inference::ShapeAndType; |
27 | using shape_inference::ShapeHandle; |
28 | |
29 | // -------------------------------------------------------------------------- |
30 | |
31 | namespace { |
32 | Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { |
33 | ShapeHandle handle; |
34 | DimensionHandle unused_handle; |
35 | for (int i = 0; i < c->num_inputs(); ++i) { |
36 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); |
37 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); |
38 | } |
39 | for (int i = 0; i < c->num_outputs(); ++i) { |
40 | c->set_output(i, c->Scalar()); |
41 | } |
42 | return OkStatus(); |
43 | } |
44 | |
45 | Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { |
46 | ShapeHandle handle; |
47 | DimensionHandle unused_handle; |
48 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
49 | for (int i = 1; i < c->num_inputs(); ++i) { |
50 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); |
51 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); |
52 | } |
53 | for (int i = 0; i < c->num_outputs(); ++i) { |
54 | c->set_output(i, c->Scalar()); |
55 | } |
56 | return OkStatus(); |
57 | } |
58 | |
59 | Status TwoElementOutput(InferenceContext* c) { |
60 | c->set_output(0, c->Vector(2)); |
61 | return OkStatus(); |
62 | } |
63 | |
64 | Status ScalarOutput(InferenceContext* c) { |
65 | c->set_output(0, c->Scalar()); |
66 | return OkStatus(); |
67 | } |
68 | } // namespace |
69 | |
70 | REGISTER_OP("LookupTableFind" ) |
71 | .Input("table_handle: Ref(string)" ) |
72 | .Input("keys: Tin" ) |
73 | .Input("default_value: Tout" ) |
74 | .Output("values: Tout" ) |
75 | .Attr("Tin: type" ) |
76 | .Attr("Tout: type" ) |
77 | .SetShapeFn([](InferenceContext* c) { |
78 | ShapeHandle handle; |
79 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
80 | DimensionHandle unused_dim; |
81 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
82 | |
83 | // Default value must be scalar or vector. |
84 | ShapeHandle unused; |
85 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); |
86 | c->set_output(0, c->UnknownShape()); |
87 | return OkStatus(); |
88 | }); |
89 | |
90 | Status ValidateTableType(InferenceContext* c, |
91 | const ShapeAndType& key_shape_and_type, |
92 | const string& key_dtype_attr, |
93 | const ShapeAndType& value_shape_and_type, |
94 | const string& value_dtype_attr) { |
95 | DataType key_dtype; |
96 | TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype)); |
97 | if (key_shape_and_type.dtype != key_dtype) { |
98 | return errors::InvalidArgument( |
99 | "Trying to read value with wrong dtype. " |
100 | "Expected " , |
101 | DataTypeString(key_shape_and_type.dtype), " got " , |
102 | DataTypeString(key_dtype)); |
103 | } |
104 | DataType value_dtype; |
105 | TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype)); |
106 | if (value_shape_and_type.dtype != value_dtype) { |
107 | return errors::InvalidArgument( |
108 | "Trying to read value with wrong dtype. " |
109 | "Expected " , |
110 | DataTypeString(value_shape_and_type.dtype), " got " , |
111 | DataTypeString(value_dtype)); |
112 | } |
113 | return OkStatus(); |
114 | } |
115 | |
// Infers the shape and dtype produced by looking up `keys` in the table whose
// resource handle is input 0 of `c`.
//
// Behavior, as implemented below:
//  * If input 0 has no handle data, or its handle data does not hold exactly
//    two entries, `*output_shape_and_type` is set to an unknown shape with
//    dtype DT_INVALID and OK is returned.
//  * Otherwise handle-data entry 0 is used as the table's key shape/dtype and
//    entry 1 as its value shape/dtype. Both dtypes are checked against the
//    attrs named by `key_dtype_attr` / `value_dtype_attr` (ValidateTableType),
//    and the output dtype is the value dtype.
//  * If both the table key shape and `keys` have known rank, `keys` must have
//    at least as many dimensions as the table key shape. The output shape is
//    then the leading `rank(keys) - rank(table key)` dimensions of `keys`
//    followed by the table's value shape. If either rank is unknown, the
//    output shape is unknown.
//
// Args:
//   c: shape inference context of the lookup op.
//   keys: shape of the keys being looked up (taken by value; modified only
//     locally).
//   key_dtype_attr: name of the attr holding the expected key dtype.
//   value_dtype_attr: name of the attr holding the expected value dtype.
//   output_shape_and_type: out-param receiving the inferred shape and dtype.
// Returns InvalidArgument on a dtype mismatch or when `keys` has fewer
// dimensions than the table key shape. Errors from the InferenceContext
// helpers are propagated.
Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys,
                                   const string& key_dtype_attr,
                                   const string& value_dtype_attr,
                                   ShapeAndType* output_shape_and_type) {
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data == nullptr || handle_data->size() != 2) {
    // No usable {key, value} handle data: nothing can be inferred.
    output_shape_and_type->shape = c->UnknownShape();
    output_shape_and_type->dtype = DT_INVALID;
  } else {
    const ShapeAndType& key_shape_and_type = (*handle_data)[0];
    const ShapeAndType& value_shape_and_type = (*handle_data)[1];
    TF_RETURN_IF_ERROR(ValidateTableType(c, key_shape_and_type, key_dtype_attr,
                                         value_shape_and_type,
                                         value_dtype_attr));
    output_shape_and_type->dtype = value_shape_and_type.dtype;
    if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) {
      int keys_rank = c->Rank(keys);
      int key_suffix_rank = c->Rank(key_shape_and_type.shape);
      if (keys_rank < key_suffix_rank) {
        return errors::InvalidArgument(
            "Expected keys to have suffix ",
            c->DebugString(key_shape_and_type.shape),
            " but saw shape: ", c->DebugString(keys));
      }
      for (int d = 0; d < key_suffix_rank; d++) {
        // Overwrite the trailing dims of `keys` with the table's key dims.
        // NOTE(review): ReplaceDim replaces rather than merges, so conflicting
        // trailing dims are not reported, and only the leading (prefix) dims
        // of `keys` are read below. Confirm whether Merge was intended here.
        DimensionHandle dim = c->Dim(key_shape_and_type.shape, d);
        TF_RETURN_IF_ERROR(
            c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys));
      }
      // Collect the leading dims of `keys` that are not covered by the table's
      // key shape; these become the leading dims of the output.
      std::vector<DimensionHandle> keys_prefix_vec;
      keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
      for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
        keys_prefix_vec.push_back(c->Dim(keys, d));
      }
      ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec);
      // Output shape = keys prefix ++ table value shape.
      TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix, value_shape_and_type.shape,
                                        &output_shape_and_type->shape));
    } else {
      output_shape_and_type->shape = c->UnknownShape();
    }
  }
  return OkStatus();
}
160 | |
161 | REGISTER_OP("LookupTableFindV2" ) |
162 | .Input("table_handle: resource" ) |
163 | .Input("keys: Tin" ) |
164 | .Input("default_value: Tout" ) |
165 | .Output("values: Tout" ) |
166 | .Attr("Tin: type" ) |
167 | .Attr("Tout: type" ) |
168 | .SetShapeFn([](InferenceContext* c) { |
169 | ShapeHandle handle; |
170 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
171 | |
172 | ShapeAndType value_shape_and_type; |
173 | TF_RETURN_IF_ERROR(ValidateTableResourceHandle( |
174 | c, |
175 | /*keys=*/c->input(1), |
176 | /*key_dtype_attr=*/"Tin" , |
177 | /*value_dtype_attr=*/"Tout" , &value_shape_and_type)); |
178 | c->set_output(0, value_shape_and_type.shape); |
179 | |
180 | return OkStatus(); |
181 | }); |
182 | ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2" ); |
183 | // TODO(b/72710477): Update this. |
184 | |
185 | REGISTER_OP("LookupTableInsert" ) |
186 | .Input("table_handle: Ref(string)" ) |
187 | .Input("keys: Tin" ) |
188 | .Input("values: Tout" ) |
189 | .Attr("Tin: type" ) |
190 | .Attr("Tout: type" ) |
191 | .SetShapeFn([](InferenceContext* c) { |
192 | ShapeHandle handle; |
193 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
194 | DimensionHandle unused_dim; |
195 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
196 | |
197 | // TODO(ebrevdo): Validate keys and values shape. |
198 | return OkStatus(); |
199 | }); |
200 | |
201 | REGISTER_OP("LookupTableInsertV2" ) |
202 | .Input("table_handle: resource" ) |
203 | .Input("keys: Tin" ) |
204 | .Input("values: Tout" ) |
205 | .Attr("Tin: type" ) |
206 | .Attr("Tout: type" ) |
207 | .SetShapeFn([](InferenceContext* c) { |
208 | ShapeHandle handle; |
209 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
210 | |
211 | // TODO: Validate keys and values shape. |
212 | return OkStatus(); |
213 | }); |
214 | |
215 | REGISTER_OP("LookupTableRemoveV2" ) |
216 | .Input("table_handle: resource" ) |
217 | .Input("keys: Tin" ) |
218 | .Attr("Tin: type" ) |
219 | .SetShapeFn([](InferenceContext* c) { |
220 | ShapeHandle handle; |
221 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
222 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle)); |
223 | |
224 | // TODO(turboale): Validate keys shape. |
225 | return OkStatus(); |
226 | }); |
227 | |
// Legacy (Ref(string)-handle) table size. The shape fn requires the handle to
// be a 2-element vector and makes the int64 `size` output a scalar. The op is
// also registered with ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS.
REGISTER_OP("LookupTableSize")
    .Input("table_handle: Ref(string)")
    .Output("size: int64")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableSize");
233 | |
// Resource-handle table size. The op has a single (scalar) resource input, so
// the shared shape fn only checks that input 0 is a scalar; its 2-element
// vector loop over inputs 1..N does not run. The int64 `size` output is a
// scalar. The op is also registered with
// ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS.
REGISTER_OP("LookupTableSizeV2")
    .Input("table_handle: resource")
    .Output("size: int64")
    .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs);
ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableSizeV2");
239 | |
240 | REGISTER_OP("LookupTableExport" ) |
241 | .Input("table_handle: Ref(string)" ) |
242 | .Output("keys: Tkeys" ) |
243 | .Output("values: Tvalues" ) |
244 | .Attr("Tkeys: type" ) |
245 | .Attr("Tvalues: type" ) |
246 | .SetShapeFn([](InferenceContext* c) { |
247 | ShapeHandle handle; |
248 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
249 | DimensionHandle unused_dim; |
250 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
251 | |
252 | ShapeHandle values = c->UnknownShape(); |
253 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); |
254 | ShapeHandle keys = c->Vector(c->Dim(values, 0)); |
255 | c->set_output(0, keys); |
256 | c->set_output(1, values); |
257 | return OkStatus(); |
258 | }); |
259 | |
260 | REGISTER_OP("LookupTableExportV2" ) |
261 | .Input("table_handle: resource" ) |
262 | .Output("keys: Tkeys" ) |
263 | .Output("values: Tvalues" ) |
264 | .Attr("Tkeys: type" ) |
265 | .Attr("Tvalues: type" ) |
266 | .SetShapeFn([](InferenceContext* c) { |
267 | ShapeHandle handle; |
268 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
269 | auto* handle_data = c->input_handle_shapes_and_types(0); |
270 | if (handle_data != nullptr && handle_data->size() == 2) { |
271 | const ShapeAndType& key_shape_and_type = (*handle_data)[0]; |
272 | const ShapeAndType& value_shape_and_type = (*handle_data)[1]; |
273 | TF_RETURN_IF_ERROR(ValidateTableType(c, key_shape_and_type, |
274 | /*key_dtype_attr*/ "Tkeys" , |
275 | value_shape_and_type, |
276 | /*value_dtype_attr*/ "Tvalues" )); |
277 | } |
278 | // Different lookup tables have different output shapes. |
279 | c->set_output(0, c->UnknownShape()); |
280 | c->set_output(1, c->UnknownShape()); |
281 | return OkStatus(); |
282 | }); |
283 | |
284 | REGISTER_OP("LookupTableImport" ) |
285 | .Input("table_handle: Ref(string)" ) |
286 | .Input("keys: Tin" ) |
287 | .Input("values: Tout" ) |
288 | .Attr("Tin: type" ) |
289 | .Attr("Tout: type" ) |
290 | .SetShapeFn([](InferenceContext* c) { |
291 | ShapeHandle handle; |
292 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
293 | DimensionHandle unused_dim; |
294 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
295 | |
296 | // TODO(ebrevdo): Validate keys and values shape. |
297 | return OkStatus(); |
298 | }); |
299 | |
300 | REGISTER_OP("LookupTableImportV2" ) |
301 | .Input("table_handle: resource" ) |
302 | .Input("keys: Tin" ) |
303 | .Input("values: Tout" ) |
304 | .Attr("Tin: type" ) |
305 | .Attr("Tout: type" ) |
306 | .SetShapeFn([](InferenceContext* c) { |
307 | ShapeHandle handle; |
308 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
309 | |
310 | ShapeHandle keys; |
311 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); |
312 | DimensionHandle unused; |
313 | TF_RETURN_IF_ERROR( |
314 | c->Merge(c->Dim(keys, 0), c->Dim(c->input(2), 0), &unused)); |
315 | return OkStatus(); |
316 | }); |
317 | |
318 | Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key, |
319 | const ShapeHandle& value) { |
320 | c->set_output(0, c->Scalar()); |
321 | |
322 | ShapeHandle key_s; |
323 | TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s)); |
324 | |
325 | DataType key_t; |
326 | TF_RETURN_IF_ERROR(c->GetAttr("key_dtype" , &key_t)); |
327 | |
328 | DataType value_t; |
329 | TF_RETURN_IF_ERROR(c->GetAttr("value_dtype" , &value_t)); |
330 | |
331 | // ShapeAndType vector for {key, value}. |
332 | c->set_output_handle_shapes_and_types( |
333 | 0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}}); |
334 | |
335 | return OkStatus(); |
336 | } |
337 | |
338 | Status MutableHashTableShapeFn(InferenceContext* c) { |
339 | return MutableHashTableShape(c, /*key=*/c->Scalar(), |
340 | /*value=*/c->Scalar()); |
341 | } |
342 | |
343 | Status (InferenceContext* c) { |
344 | PartialTensorShape value_p; |
345 | TF_RETURN_IF_ERROR(c->GetAttr("value_shape" , &value_p)); |
346 | ShapeHandle value_s; |
347 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s)); |
348 | return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s); |
349 | } |
350 | |
351 | Status MutableDenseHashTableShapeFn(InferenceContext* c) { |
352 | PartialTensorShape value_p; |
353 | TF_RETURN_IF_ERROR(c->GetAttr("value_shape" , &value_p)); |
354 | ShapeHandle value_s; |
355 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s)); |
356 | return MutableHashTableShape(c, /*key=*/c->input(0), /*value=*/value_s); |
357 | } |
358 | |
// Immutable hash table constructors.
// HashTable emits a legacy Ref(string) handle shaped as a 2-element vector
// (TwoElementOutput). HashTableV2 and AnonymousHashTable emit a scalar resource
// handle (ScalarOutput). None of them attach key/value handle data.
REGISTER_OP("HashTable")
    .Output("table_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("HashTableV2")
    .Output("table_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(ScalarOutput);

REGISTER_OP("AnonymousHashTable")
    .Output("table_handle: resource")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(ScalarOutput);
385 | |
// Mutable hash tables with scalar keys and scalar values.
// The legacy MutableHashTable emits a 2-element-vector Ref(string) handle. The
// resource variants use MutableHashTableShapeFn, which also records {scalar
// key, scalar value} handle data on the output.
REGISTER_OP("MutableHashTable")
    .Output("table_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("MutableHashTableV2")
    .Output("table_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(MutableHashTableShapeFn);

REGISTER_OP("AnonymousMutableHashTable")
    .Output("table_handle: resource")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .SetIsStateful()
    .SetShapeFn(MutableHashTableShapeFn);
412 | |
// Mutable hash tables whose values are tensors of shape `value_shape`
// (default: scalar, `{}`). The legacy op emits a 2-element-vector Ref(string)
// handle. The resource variants use MutableHashTableOfTensorsShapeFn, which
// records {scalar key, value_shape} handle data on the output.
REGISTER_OP("MutableHashTableOfTensors")
    .Output("table_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("MutableHashTableOfTensorsV2")
    .Output("table_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .SetIsStateful()
    .SetShapeFn(MutableHashTableOfTensorsShapeFn);

REGISTER_OP("AnonymousMutableHashTableOfTensors")
    .Output("table_handle: resource")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .SetIsStateful()
    .SetShapeFn(MutableHashTableOfTensorsShapeFn);
442 | |
// Mutable dense hash tables, configured by an `empty_key` input (plus a
// `deleted_key` input in the resource variants), `initial_num_buckets` and
// `max_load_factor`. The legacy op emits a 2-element-vector Ref(string)
// handle. The resource variants use MutableDenseHashTableShapeFn, which takes
// the key shape from `empty_key` (rank <= 1) and the value shape from
// `value_shape`.
REGISTER_OP("MutableDenseHashTable")
    .Input("empty_key: key_dtype")
    .Output("table_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .Attr("initial_num_buckets: int = 131072")  // 2^17
    .Attr("max_load_factor: float = 0.8")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("MutableDenseHashTableV2")
    .Input("empty_key: key_dtype")
    .Input("deleted_key: key_dtype")
    .Output("table_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("use_node_name_sharing: bool = false")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .Attr("initial_num_buckets: int = 131072")  // 2^17
    .Attr("max_load_factor: float = 0.8")
    .SetIsStateful()
    .SetShapeFn(MutableDenseHashTableShapeFn);

REGISTER_OP("AnonymousMutableDenseHashTable")
    .Input("empty_key: key_dtype")
    .Input("deleted_key: key_dtype")
    .Output("table_handle: resource")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("value_shape: shape = {}")
    .Attr("initial_num_buckets: int = 131072")  // 2^17
    .Attr("max_load_factor: float = 0.8")
    .SetIsStateful()
    .SetShapeFn(MutableDenseHashTableShapeFn);
483 | |
484 | REGISTER_OP("InitializeTable" ) |
485 | .Input("table_handle: Ref(string)" ) |
486 | .Input("keys: Tkey" ) |
487 | .Input("values: Tval" ) |
488 | .Attr("Tkey: type" ) |
489 | .Attr("Tval: type" ) |
490 | .SetShapeFn([](InferenceContext* c) { |
491 | ShapeHandle handle; |
492 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
493 | DimensionHandle unused_dim; |
494 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
495 | |
496 | ShapeHandle keys; |
497 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); |
498 | TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); |
499 | return OkStatus(); |
500 | }); |
501 | |
502 | REGISTER_OP("InitializeTableV2" ) |
503 | .Input("table_handle: resource" ) |
504 | .Input("keys: Tkey" ) |
505 | .Input("values: Tval" ) |
506 | .Attr("Tkey: type" ) |
507 | .Attr("Tval: type" ) |
508 | .SetShapeFn([](InferenceContext* c) { |
509 | ShapeHandle handle; |
510 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
511 | |
512 | ShapeHandle keys; |
513 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); |
514 | TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); |
515 | return OkStatus(); |
516 | }); |
517 | |
518 | REGISTER_OP("InitializeTableFromTextFile" ) |
519 | .Input("table_handle: Ref(string)" ) |
520 | .Input("filename: string" ) |
521 | .Attr("key_index: int >= -2" ) |
522 | .Attr("value_index: int >= -2" ) |
523 | .Attr("vocab_size: int >= -1 = -1" ) |
524 | .Attr("delimiter: string = '\t'" ) |
525 | .Attr("offset: int = 0" ) |
526 | .SetShapeFn([](InferenceContext* c) { |
527 | ShapeHandle handle; |
528 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
529 | DimensionHandle unused_dim; |
530 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
531 | |
532 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); |
533 | return OkStatus(); |
534 | }); |
535 | |
536 | REGISTER_OP("InitializeTableFromTextFileV2" ) |
537 | .Input("table_handle: resource" ) |
538 | .Input("filename: string" ) |
539 | .Attr("key_index: int >= -2" ) |
540 | .Attr("value_index: int >= -2" ) |
541 | .Attr("vocab_size: int >= -1 = -1" ) |
542 | .Attr("delimiter: string = '\t'" ) |
543 | .Attr("offset: int = 0" ) |
544 | .SetShapeFn([](InferenceContext* c) { |
545 | ShapeHandle handle; |
546 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
547 | |
548 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); |
549 | return OkStatus(); |
550 | }); |
551 | |
552 | } // namespace tensorflow |
553 | |