/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Check whether updates.shape = indices.shape + params.shape[1:]
static bool ValidShapes(const Tensor& params, const Tensor& updates,
                        const Tensor& indices) {
  if (updates.dims() == 0) return true;
  if (updates.dims() != indices.dims() + params.dims() - 1) return false;
  for (int d = 0; d < indices.dims(); d++) {
    if (updates.dim_size(d) != indices.dim_size(d)) {
      return false;
    }
  }
  for (int d = 1; d < params.dims(); d++) {
    if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) {
      return false;
    }
  }
  return true;
}
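
// Example: for params.shape = [5, 3] and indices.shape = [2, 4], the only
// valid non-scalar updates.shape is [2, 4, 3]; the leading dims must match
// indices.shape exactly and the trailing dims must match params.shape[1:].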

static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
                                 const Tensor& indices, const Tensor& updates) {
  OP_REQUIRES(c, params.IsInitialized(),
              errors::FailedPrecondition("Null ref for params"));
  OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
              errors::InvalidArgument("params must be at least 1-D, got shape ",
                                      params.shape().DebugString()));
  OP_REQUIRES(
      c, ValidShapes(params, updates, indices),
      errors::InvalidArgument("Must have updates.shape = indices.shape + "
                              "params.shape[1:] or updates.shape = [], got ",
                              "updates.shape ", updates.shape().DebugString(),
                              ", indices.shape ", indices.shape().DebugString(),
                              ", params.shape ", params.shape().DebugString()));
}

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp : public OpKernel {
 public:
  // QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
  // etc. here. Should we have the framework do some sort of
  // integer promotion automatically, or should that be something
  // that users have to do explicitly with a conversion operator
  // in the graph?
  explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
    if (std::is_same<Device, GPUDevice>::value) {
      OP_REQUIRES(
          c, !OpDeterminismRequired(),
          errors::Unimplemented(
              "Determinism is not yet supported in GPU implementation of "
              "Scatter ops with ref inputs. Consider using resource variables "
              "instead if you want to run Scatter when op determinism is "
              "enabled."));
    }
  }

  void Compute(OpKernelContext* c) override {
    if (use_exclusive_lock_) {
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    Tensor params = c->mutable_input(0, use_exclusive_lock_);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    DoValidationChecking(c, params, indices, updates);
    if (!c->status().ok()) return;

    // Check that we have enough index space
    const int64_t N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // We always return the input ref.
    c->forward_ref_input_to_ref_output(0, 0);

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params.flat_outer_dims<T>();

      if (TensorShapeUtils::IsScalar(updates.shape())) {
        const auto update = updates.scalar<T>();
        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      } else {
        auto updates_flat =
            updates.shaped<T, 2>({N, updates.NumElements() / N});

        functor::ScatterFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      }
    }
  }
};
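
// In effect, for each index i the kernel applies, along the first dimension
// of params:
//   params[indices[i], ...]  =  updates[i, ...]  (ASSIGN)
//   params[indices[i], ...] op= updates[i, ...]  (ADD, SUB, MUL, DIV, MIN, MAX)
// For example, ScatterUpdate on params = [[1, 2], [3, 4], [5, 6]] with
// indices = [2, 0] and updates = [[7, 8], [9, 10]] leaves params as
// [[9, 10], [3, 4], [7, 8]].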

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(Name(name)                                   \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          ScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64_t, dev, name, op);

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);

#define REGISTER_SCATTER_MINMAX(type, dev)                                      \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN);  \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX);

#define REGISTER_SCATTER_UPDATE(type, dev)            \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);
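
// For instance, REGISTER_SCATTER_UPDATE(float, CPU) registers "ScatterUpdate"
// kernels for T = float with both int32 and int64_t index types on
// DEVICE_CPU, instantiating
// ScatterUpdateOp<CPUDevice, float, int32, scatter_op::UpdateOp::ASSIGN>
// (and likewise for int64_t).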

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);

#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);

#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);

#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_ARITHMETIC_GPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_MINMAX_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow