1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_ |
16 | #define TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_ |
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/kernels/tensor_map.h" |
20 | #include "tensorflow/core/util/batch_util.h" |
21 | #include "tensorflow/core/util/tensor_ops_util.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | Status GetInputMap(OpKernelContext* ctx, int index, const TensorMap** ret_map) { |
26 | if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) { |
27 | return errors::InvalidArgument("Input map must be a scalar. Saw: " , |
28 | ctx->input(index).shape().DebugString()); |
29 | } |
30 | const TensorMap* map = ctx->input(index).scalar<Variant>()().get<TensorMap>(); |
31 | if (map == nullptr) { |
32 | return errors::InvalidArgument( |
33 | "Input handle is not a map. Saw: '" , |
34 | ctx->input(index).scalar<Variant>()().DebugString(), "'" ); |
35 | } |
36 | *ret_map = map; |
37 | return OkStatus(); |
38 | } |
39 | |
40 | // TODO(kattian): change into templated function |
41 | Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, int32_t input_index, |
42 | int32_t output_index, |
43 | const TensorMap& input_map, |
44 | TensorMap** output_map) { |
45 | // Attempt to forward the input tensor to the output if possible. |
46 | std::unique_ptr<Tensor> maybe_output = ctx->forward_input( |
47 | input_index, output_index, DT_VARIANT, TensorShape{}, |
48 | ctx->input_memory_type(input_index), AllocatorAttributes()); |
49 | Tensor* output_tensor; |
50 | if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT && |
51 | maybe_output->NumElements() == 1) { |
52 | output_tensor = maybe_output.get(); |
53 | TensorMap* tmp_out = output_tensor->scalar<Variant>()().get<TensorMap>(); |
54 | if (tmp_out == nullptr) { |
55 | return errors::InvalidArgument( |
56 | "Expected input " , input_index, " to be a TensorMap but saw " , |
57 | output_tensor->scalar<Variant>()().TypeName()); |
58 | } |
59 | if (tmp_out->RefCountIsOne()) { |
60 | // Woohoo, forwarding succeeded! |
61 | ctx->set_output(output_index, *output_tensor); |
62 | *output_map = tmp_out; |
63 | return OkStatus(); |
64 | } |
65 | } |
66 | |
67 | // If forwarding is not possible allocate a new output tensor and copy |
68 | // the `input_map` to it. |
69 | AllocatorAttributes attr; |
70 | attr.set_on_host(true); |
71 | TF_RETURN_IF_ERROR( |
72 | ctx->allocate_output(output_index, {}, &output_tensor, attr)); |
73 | output_tensor->scalar<Variant>()() = input_map.Copy(); |
74 | |
75 | *output_map = output_tensor->scalar<Variant>()().get<TensorMap>(); |
76 | return OkStatus(); |
77 | } |
78 | |
79 | class EmptyTensorMap : public OpKernel { |
80 | public: |
81 | explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
82 | |
83 | void Compute(OpKernelContext* ctx) override { |
84 | Tensor* result; |
85 | AllocatorAttributes attr; |
86 | attr.set_on_host(true); |
87 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr)); |
88 | TensorMap empty; |
89 | result->scalar<Variant>()() = std::move(empty); |
90 | } |
91 | }; |
92 | |
93 | class TensorMapSize : public OpKernel { |
94 | public: |
95 | explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
96 | ~TensorMapSize() override {} |
97 | |
98 | void Compute(OpKernelContext* ctx) override { |
99 | const TensorMap* map = nullptr; |
100 | OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); |
101 | Tensor* result; |
102 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result)); |
103 | result->scalar<int32>()() = map->tensors().size(); |
104 | } |
105 | }; |
106 | |
107 | class TensorMapLookup : public OpKernel { |
108 | public: |
109 | explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
110 | ~TensorMapLookup() override {} |
111 | |
112 | void Compute(OpKernelContext* ctx) override { |
113 | const TensorKey& key = ctx->input(1); |
114 | const TensorMap* map = nullptr; |
115 | OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); |
116 | |
117 | OP_REQUIRES( |
118 | ctx, map->tensors().find(key) != map->tensors().end(), |
119 | errors::InvalidArgument("Trying to lookup non-existent key. Could not " |
120 | "find key \"" + |
121 | key.SummarizeValue(100) + "\"." )); |
122 | |
123 | ctx->set_output(0, map->tensors().find(key)->second); |
124 | } |
125 | }; |
126 | |
127 | class TensorMapInsert : public OpKernel { |
128 | public: |
129 | explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
130 | ~TensorMapInsert() override {} |
131 | |
132 | void Compute(OpKernelContext* ctx) override { |
133 | const TensorKey& key = ctx->input(1); |
134 | const Tensor& value = ctx->input(2); |
135 | const TensorMap* map = nullptr; |
136 | OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); |
137 | |
138 | TensorMap* output_map = nullptr; |
139 | OP_REQUIRES_OK(ctx, |
140 | ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map)); |
141 | output_map->replace(key, value); |
142 | } |
143 | }; |
144 | |
145 | class TensorMapErase : public OpKernel { |
146 | public: |
147 | explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
148 | |
149 | void Compute(OpKernelContext* ctx) override { |
150 | const TensorKey& key = ctx->input(1); |
151 | const TensorMap* map = nullptr; |
152 | OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); |
153 | |
154 | OP_REQUIRES( |
155 | ctx, map->tensors().find(key) != map->tensors().end(), |
156 | errors::InvalidArgument("Trying to erase non-existent item. Could not " |
157 | "find key \"" + |
158 | key.SummarizeValue(100) + "\"." )); |
159 | |
160 | TensorMap* output_map = nullptr; |
161 | OP_REQUIRES_OK(ctx, |
162 | ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map)); |
163 | output_map->tensors().erase(key); |
164 | } |
165 | }; |
166 | |
167 | class TensorMapHasKey : public OpKernel { |
168 | public: |
169 | explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
170 | ~TensorMapHasKey() override {} |
171 | |
172 | void Compute(OpKernelContext* ctx) override { |
173 | const TensorKey& key = ctx->input(1); |
174 | const TensorMap* map = nullptr; |
175 | OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); |
176 | Tensor* result; |
177 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result)); |
178 | result->scalar<bool>()() = map->tensors().find(key) != map->tensors().end(); |
179 | } |
180 | }; |
181 | |
182 | class TensorMapStackKeys : public OpKernel { |
183 | public: |
184 | explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) { |
185 | OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype" , &key_dtype_)); |
186 | } |
187 | ~TensorMapStackKeys() override {} |
188 | |
189 | void Compute(OpKernelContext* ctx) override { |
190 | const TensorMap* map = nullptr; |
191 | OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); |
192 | |
193 | OP_REQUIRES(ctx, map->size() != 0, |
194 | errors::InvalidArgument( |
195 | "TensorMapStackKeys cannot be called on empty map." )); |
196 | |
197 | auto it = map->tensors().begin(); |
198 | TensorShape output_shape = it->first.shape(); |
199 | output_shape.InsertDim(0, map->tensors().size()); |
200 | Tensor* result; |
201 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result)); |
202 | |
203 | int i = 0; |
204 | size_t sz = map->tensors().size(); |
205 | TensorShape key_shape = it->first.shape(); |
206 | while (it != map->tensors().end() && i < sz) { |
207 | OP_REQUIRES( |
208 | ctx, it->first.dtype() == key_dtype_, |
209 | errors::InvalidArgument("Key does not match requested dtype." )); |
210 | OP_REQUIRES( |
211 | ctx, it->first.shape() == key_shape, |
212 | errors::InvalidArgument("Keys must all have the same shape." )); |
213 | OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i)); |
214 | i++; |
215 | it++; |
216 | } |
217 | } |
218 | |
219 | private: |
220 | DataType key_dtype_; |
221 | }; |
222 | |
223 | template <typename Device> |
224 | Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a, |
225 | const TensorMap& b, TensorMap* out) { |
226 | // Binary add returns a map containing the union of keys. |
227 | // Values with keys in the intersection are added. |
228 | out->tensors() = a.tensors(); |
229 | for (const std::pair<TensorKey, Tensor>& p : b.tensors()) { |
230 | absl::flat_hash_map<TensorKey, Tensor>::iterator it = |
231 | out->tensors().find(p.first); |
232 | if (it != out->tensors().end()) { |
233 | Tensor out_tensor; |
234 | TF_RETURN_IF_ERROR( |
235 | BinaryAddTensors<Device>(ctx, p.second, it->second, &out_tensor)); |
236 | it->second = out_tensor; |
237 | } else { |
238 | out->tensors().emplace(p.first, p.second); |
239 | } |
240 | } |
241 | return OkStatus(); |
242 | } |
243 | |
244 | template <typename Device> |
245 | Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x, |
246 | TensorMap* y) { |
247 | // Zeros like returns an empty map. |
248 | return OkStatus(); |
249 | } |
250 | |
251 | } // namespace tensorflow |
252 | |
253 | #endif // TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_ |
254 | |