1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
---|---|
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | namespace { |
22 | |
23 | // TODO(kttian): Support non-scalar values |
24 | REGISTER_OP("EmptyTensorMap") |
25 | .Output("handle: variant") |
26 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
27 | c->set_output(0, c->Scalar()); |
28 | return OkStatus(); |
29 | }); |
30 | |
31 | REGISTER_OP("TensorMapSize") |
32 | .Input("input_handle: variant") |
33 | .Output("size: int32") |
34 | .SetShapeFn(shape_inference::ScalarShape); |
35 | |
36 | REGISTER_OP("TensorMapLookup") |
37 | .Input("input_handle: variant") |
38 | .Input("key: key_dtype") |
39 | .Output("value: value_dtype") |
40 | .Attr("key_dtype: type") |
41 | .Attr("value_dtype: type") |
42 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
43 | c->set_output(0, c->UnknownShape()); |
44 | return OkStatus(); |
45 | }); |
46 | |
47 | REGISTER_OP("TensorMapInsert") |
48 | .Input("input_handle: variant") |
49 | .Input("key: key_dtype") |
50 | .Input("value: value_dtype") |
51 | .Output("output_handle: variant") |
52 | .Attr("key_dtype: type") |
53 | .Attr("value_dtype: type") |
54 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
55 | c->set_output(0, c->Scalar()); |
56 | return OkStatus(); |
57 | }); |
58 | |
59 | REGISTER_OP("TensorMapErase") |
60 | .Input("input_handle: variant") |
61 | .Input("key: key_dtype") |
62 | .Output("output_handle: variant") |
63 | .Attr("key_dtype: type") |
64 | .Attr("value_dtype: type") |
65 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
66 | c->set_output(0, c->Scalar()); // output map |
67 | return OkStatus(); |
68 | }); |
69 | |
70 | REGISTER_OP("TensorMapHasKey") |
71 | .Input("input_handle: variant") |
72 | .Input("key: key_dtype") |
73 | .Output("has_key: bool") |
74 | .Attr("key_dtype: type") |
75 | .SetShapeFn(shape_inference::ScalarShape); |
76 | |
77 | REGISTER_OP("TensorMapStackKeys") |
78 | .Input("input_handle: variant") |
79 | .Output("keys: key_dtype") |
80 | .Attr("key_dtype: type") |
81 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
82 | c->set_output(0, c->UnknownShape()); // output keys |
83 | return OkStatus(); |
84 | }); |
85 | |
86 | } // namespace |
87 | } // namespace tensorflow |
88 |