/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/tensor_map.h"
#include "tensorflow/core/util/batch_util.h"
#include "tensorflow/core/util/tensor_ops_util.h"

namespace tensorflow {

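// Looks up the TensorMap stored in the scalar DT_VARIANT input at `index` and
// returns it through `ret_map`. Returns an InvalidArgument error if that
// input is not a scalar or does not hold a TensorMap.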
Status GetInputMap(OpKernelContext* ctx, int index, const TensorMap** ret_map) {
  if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) {
    return errors::InvalidArgument("Input map must be a scalar. Saw: ",
                                   ctx->input(index).shape().DebugString());
  }
  const TensorMap* map = ctx->input(index).scalar<Variant>()().get<TensorMap>();
  if (map == nullptr) {
    return errors::InvalidArgument(
        "Input handle is not a map. Saw: '",
        ctx->input(index).scalar<Variant>()().DebugString(), "'");
  }
  *ret_map = map;
  return OkStatus();
}

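// Makes the map from `input_index` available for mutation as output
// `output_index`: reuses the input buffer when it can be forwarded and holds
// the only reference to its TensorMap, and otherwise allocates a new output
// and copies `input_map` into it.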
// TODO(kattian): change into templated function
Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, int32_t input_index,
                                  int32_t output_index,
                                  const TensorMap& input_map,
                                  TensorMap** output_map) {
  // Attempt to forward the input tensor to the output if possible.
  std::unique_ptr<Tensor> maybe_output = ctx->forward_input(
      input_index, output_index, DT_VARIANT, TensorShape{},
      ctx->input_memory_type(input_index), AllocatorAttributes());
  Tensor* output_tensor;
  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
      maybe_output->NumElements() == 1) {
    output_tensor = maybe_output.get();
    TensorMap* tmp_out = output_tensor->scalar<Variant>()().get<TensorMap>();
    if (tmp_out == nullptr) {
      return errors::InvalidArgument(
          "Expected input ", input_index, " to be a TensorMap but saw ",
          output_tensor->scalar<Variant>()().TypeName());
    }
    if (tmp_out->RefCountIsOne()) {
      // Woohoo, forwarding succeeded!
      ctx->set_output(output_index, *output_tensor);
      *output_map = tmp_out;
      return OkStatus();
    }
  }

  // If forwarding is not possible, allocate a new output tensor and copy
  // `input_map` into it.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(output_index, {}, &output_tensor, attr));
  output_tensor->scalar<Variant>()() = input_map.Copy();

  *output_map = output_tensor->scalar<Variant>()().get<TensorMap>();
  return OkStatus();
}

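// Kernel for the EmptyTensorMap op: allocates a scalar DT_VARIANT output on
// the host and stores a newly constructed, empty TensorMap in it.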
class EmptyTensorMap : public OpKernel {
 public:
  explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorMap empty;
    result->scalar<Variant>()() = std::move(empty);
  }
};

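// Kernel for the TensorMapSize op: outputs the number of entries in the input
// map as a scalar int32.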
class TensorMapSize : public OpKernel {
 public:
  explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~TensorMapSize() override {}

  void Compute(OpKernelContext* ctx) override {
    const TensorMap* map = nullptr;
    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
    Tensor* result;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = map->tensors().size();
  }
};

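// Kernel for the TensorMapLookup op: outputs the value stored under the key
// given by input 1, or fails with InvalidArgument if the key is absent.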
class TensorMapLookup : public OpKernel {
 public:
  explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~TensorMapLookup() override {}

  void Compute(OpKernelContext* ctx) override {
    const TensorKey& key = ctx->input(1);
    const TensorMap* map = nullptr;
    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));

    OP_REQUIRES(
        ctx, map->tensors().find(key) != map->tensors().end(),
        errors::InvalidArgument("Trying to look up a non-existent key. Could "
                                "not find key \"" +
                                key.SummarizeValue(100) + "\"."));

    ctx->set_output(0, map->tensors().find(key)->second);
  }
};

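// Kernel for the TensorMapInsert op: forwards (or copies) the input map and
// stores the value from input 2 under the key from input 1, overwriting any
// existing entry for that key.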
class TensorMapInsert : public OpKernel {
 public:
  explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~TensorMapInsert() override {}

  void Compute(OpKernelContext* ctx) override {
    const TensorKey& key = ctx->input(1);
    const Tensor& value = ctx->input(2);
    const TensorMap* map = nullptr;
    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));

    TensorMap* output_map = nullptr;
    OP_REQUIRES_OK(ctx,
                   ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
    output_map->replace(key, value);
  }
};

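// Kernel for the TensorMapErase op: forwards (or copies) the input map and
// removes the entry for the key from input 1, failing with InvalidArgument if
// the key is absent.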
class TensorMapErase : public OpKernel {
 public:
  explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const TensorKey& key = ctx->input(1);
    const TensorMap* map = nullptr;
    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));

    OP_REQUIRES(
        ctx, map->tensors().find(key) != map->tensors().end(),
        errors::InvalidArgument("Trying to erase a non-existent item. Could "
                                "not find key \"" +
                                key.SummarizeValue(100) + "\"."));

    TensorMap* output_map = nullptr;
    OP_REQUIRES_OK(ctx,
                   ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
    output_map->tensors().erase(key);
  }
};

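// Kernel for the TensorMapHasKey op: outputs a scalar bool indicating whether
// the key from input 1 is present in the input map.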
class TensorMapHasKey : public OpKernel {
 public:
  explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~TensorMapHasKey() override {}

  void Compute(OpKernelContext* ctx) override {
    const TensorKey& key = ctx->input(1);
    const TensorMap* map = nullptr;
    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
    Tensor* result;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
    result->scalar<bool>()() = map->tensors().find(key) != map->tensors().end();
  }
};

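// Kernel for the TensorMapStackKeys op: stacks all keys of the (non-empty)
// input map into one output tensor along a new leading dimension. Every key
// must have the same shape and match the `key_dtype` attr.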
class TensorMapStackKeys : public OpKernel {
 public:
  explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype", &key_dtype_));
  }
  ~TensorMapStackKeys() override {}

  void Compute(OpKernelContext* ctx) override {
    const TensorMap* map = nullptr;
    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));

    OP_REQUIRES(ctx, map->size() != 0,
                errors::InvalidArgument(
                    "TensorMapStackKeys cannot be called on an empty map."));

    auto it = map->tensors().begin();
    TensorShape output_shape = it->first.shape();
    output_shape.InsertDim(0, map->tensors().size());
    Tensor* result;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));

    int i = 0;
    size_t sz = map->tensors().size();
    TensorShape key_shape = it->first.shape();
    while (it != map->tensors().end() && i < sz) {
      OP_REQUIRES(
          ctx, it->first.dtype() == key_dtype_,
          errors::InvalidArgument("Key does not match requested dtype."));
      OP_REQUIRES(
          ctx, it->first.shape() == key_shape,
          errors::InvalidArgument("Keys must all have the same shape."));
      OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i));
      i++;
      it++;
    }
  }

 private:
  DataType key_dtype_;
};

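// Element-wise addition of two TensorMaps: the result contains the union of
// the keys, and values present in both maps are added with BinaryAddTensors.
// Presumably wired up as the ADD variant binary op for TensorMap in the
// accompanying .cc file; a sketch of such a registration (not part of this
// header) might look like:
//
//   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP,
//                                             DEVICE_CPU, TensorMap,
//                                             TensorMapBinaryAdd<CPUDevice>);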
template <typename Device>
Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a,
                          const TensorMap& b, TensorMap* out) {
  // Binary add returns a map containing the union of keys.
  // Values with keys in the intersection are added.
  out->tensors() = a.tensors();
  for (const std::pair<TensorKey, Tensor>& p : b.tensors()) {
    absl::flat_hash_map<TensorKey, Tensor>::iterator it =
        out->tensors().find(p.first);
    if (it != out->tensors().end()) {
      Tensor out_tensor;
      TF_RETURN_IF_ERROR(
          BinaryAddTensors<Device>(ctx, p.second, it->second, &out_tensor));
      it->second = out_tensor;
    } else {
      out->tensors().emplace(p.first, p.second);
    }
  }
  return OkStatus();
}

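// Zeros-like for TensorMap variants: the "zero" of a map is an empty map, so
// no per-value zero tensors are produced. Presumably registered as the
// ZEROS_LIKE variant unary op for TensorMap in the accompanying .cc file.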
template <typename Device>
Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x,
                          TensorMap* y) {
  // Zeros like returns an empty map.
  return OkStatus();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_