1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/state_ops.cc. |
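//
// Semantics (an illustrative summary; the op docs referenced above are
// authoritative): for ScatterUpdate,
//   params[indices[i], ...] = updates[i, ...]
// while ScatterAdd/Sub/Mul/Div/Min/Max apply the corresponding arithmetic
// operation instead of assignment.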
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/framework/register_types.h" |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/kernels/scatter_functor.h" |
22 | #include "tensorflow/core/platform/mutex.h" |
23 | #include "tensorflow/core/platform/types.h" |
24 | #include "tensorflow/core/util/determinism.h" |
25 | #include "tensorflow/core/util/util.h" |
26 | |
27 | |
28 | namespace tensorflow { |
29 | |
30 | typedef Eigen::ThreadPoolDevice CPUDevice; |
31 | typedef Eigen::GpuDevice GPUDevice; |
32 | |
33 | // Check whether updates.shape = indices.shape + params.shape[1:] |
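// For example (illustrative shapes, not taken from the original comment):
//   params.shape = [8, 3], indices.shape = [4]    -> updates.shape = [4, 3]
//   params.shape = [8, 3], indices.shape = [2, 2] -> updates.shape = [2, 2, 3]
// A scalar updates tensor (dims() == 0) is always accepted and broadcast.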
34 | static bool ValidShapes(const Tensor& params, const Tensor& updates, |
35 | const Tensor& indices) { |
36 | if (updates.dims() == 0) return true; |
37 | if (updates.dims() != indices.dims() + params.dims() - 1) return false; |
38 | for (int d = 0; d < indices.dims(); d++) { |
39 | if (updates.dim_size(d) != indices.dim_size(d)) { |
40 | return false; |
41 | } |
42 | } |
43 | for (int d = 1; d < params.dims(); d++) { |
44 | if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) { |
45 | return false; |
46 | } |
47 | } |
48 | return true; |
49 | } |
50 | |
51 | static void DoValidationChecking(OpKernelContext* c, const Tensor& params, |
52 | const Tensor& indices, const Tensor& updates) { |
53 | OP_REQUIRES(c, params.IsInitialized(), |
54 | errors::FailedPrecondition("Null ref for params" )); |
55 | OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()), |
56 | errors::InvalidArgument("params must be at least 1-D, got shape " , |
57 | params.shape().DebugString())); |
58 | OP_REQUIRES( |
59 | c, ValidShapes(params, updates, indices), |
60 | errors::InvalidArgument("Must have updates.shape = indices.shape + " |
61 | "params.shape[1:] or updates.shape = [], got " , |
62 | "updates.shape " , updates.shape().DebugString(), |
63 | ", indices.shape " , indices.shape().DebugString(), |
64 | ", params.shape " , params.shape().DebugString())); |
65 | } |
66 | |
67 | template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> |
68 | class ScatterUpdateOp : public OpKernel { |
69 | public: |
70 | // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, |
71 | // etc. here. Should we have the framework do some sort of |
72 | // integer promotion automatically, or should that be something |
73 | // that users have to do explicitly with a conversion operator |
74 | // in the graph? |
75 | explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) { |
76 | OP_REQUIRES_OK(c, c->GetAttr("use_locking" , &use_exclusive_lock_)); |
77 | if (std::is_same<Device, GPUDevice>::value) { |
78 | OP_REQUIRES( |
79 | c, !OpDeterminismRequired(), |
80 | errors::Unimplemented( |
81 | "Determinism is not yet supported in GPU implementation of " |
82 | "Scatter ops with ref inputs. Consider using resource variables " |
83 | "instead if you want to run Scatter when op determinism is " |
84 | "enabled." )); |
85 | } |
86 | } |
87 | |
88 | void Compute(OpKernelContext* c) override { |
89 | if (use_exclusive_lock_) { |
90 | // Hold mutex while we apply updates |
91 | mutex_lock l(*c->input_ref_mutex(0)); |
92 | DoCompute(c); |
93 | } else { |
94 | DoCompute(c); |
95 | } |
96 | } |
97 | |
98 | private: |
  bool use_exclusive_lock_;  // Value of the "use_locking" attr.
100 | |
101 | void DoCompute(OpKernelContext* c) { |
102 | Tensor params = c->mutable_input(0, use_exclusive_lock_); |
103 | const Tensor& indices = c->input(1); |
104 | const Tensor& updates = c->input(2); |
105 | DoValidationChecking(c, params, indices, updates); |
106 | if (!c->status().ok()) return; |
107 | |
108 | // Check that we have enough index space |
109 | const int64_t N_big = indices.NumElements(); |
110 | OP_REQUIRES( |
111 | c, N_big <= std::numeric_limits<Index>::max(), |
112 | errors::InvalidArgument("indices has too many elements for " , |
113 | DataTypeString(DataTypeToEnum<Index>::v()), |
114 | " indexing: " , N_big, " > " , |
115 | std::numeric_limits<Index>::max())); |
116 | const Index N = static_cast<Index>(indices.NumElements()); |
117 | OP_REQUIRES( |
118 | c, params.dim_size(0) <= std::numeric_limits<Index>::max(), |
119 | errors::InvalidArgument("params.shape[0] too large for " , |
120 | DataTypeString(DataTypeToEnum<Index>::v()), |
121 | " indexing: " , params.dim_size(0), " > " , |
122 | std::numeric_limits<Index>::max())); |
123 | |
124 | // We always return the input ref. |
125 | c->forward_ref_input_to_ref_output(0, 0); |
126 | |
127 | if (N > 0) { |
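      // Flatten indices to a vector and view params as a
      // [params.dim_size(0), slice_size] matrix, so each index selects one
      // row of params to update.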
128 | auto indices_flat = indices.flat<Index>(); |
129 | auto params_flat = params.flat_outer_dims<T>(); |
130 | |
131 | if (TensorShapeUtils::IsScalar(updates.shape())) { |
132 | const auto update = updates.scalar<T>(); |
133 | functor::ScatterScalarFunctor<Device, T, Index, op> functor; |
134 | const Index bad_i = functor(c, c->template eigen_device<Device>(), |
135 | params_flat, update, indices_flat); |
136 | OP_REQUIRES(c, bad_i < 0, |
137 | errors::InvalidArgument( |
138 | "indices" , SliceDebugString(indices.shape(), bad_i), |
139 | " = " , indices_flat(bad_i), " is not in [0, " , |
140 | params.dim_size(0), ")" )); |
141 | } else { |
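        // Reshape updates to [N, slice_size] so that row i holds the slice
        // scattered into params row indices_flat(i).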
142 | auto updates_flat = |
143 | updates.shaped<T, 2>({N, updates.NumElements() / N}); |
144 | |
145 | functor::ScatterFunctor<Device, T, Index, op> functor; |
146 | const Index bad_i = functor(c, c->template eigen_device<Device>(), |
147 | params_flat, updates_flat, indices_flat); |
148 | OP_REQUIRES(c, bad_i < 0, |
149 | errors::InvalidArgument( |
150 | "indices" , SliceDebugString(indices.shape(), bad_i), |
151 | " = " , indices_flat(bad_i), " is not in [0, " , |
152 | params.dim_size(0), ")" )); |
153 | } |
154 | } |
155 | } |
156 | }; |
157 | |
158 | |
159 | #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \ |
160 | REGISTER_KERNEL_BUILDER(Name(name) \ |
161 | .Device(DEVICE_##dev) \ |
162 | .TypeConstraint<type>("T") \ |
163 | .TypeConstraint<index_type>("Tindices"), \ |
164 | ScatterUpdateOp<dev##Device, type, index_type, op>) |
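// For instance (a hand-expanded illustration, not generated output),
//   REGISTER_SCATTER_KERNEL_INDEX(float, int32, CPU, "ScatterAdd",
//                                 scatter_op::UpdateOp::ADD)
// registers ScatterUpdateOp<CPUDevice, float, int32, scatter_op::UpdateOp::ADD>
// as the "ScatterAdd" kernel for T = float, Tindices = int32 on DEVICE_CPU.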
165 | |
166 | #define REGISTER_SCATTER_KERNEL(type, dev, name, op) \ |
167 | REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ |
168 | REGISTER_SCATTER_KERNEL_INDEX(type, int64_t, dev, name, op); |
169 | |
170 | #define REGISTER_SCATTER_ARITHMETIC(type, dev) \ |
171 | REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \ |
172 | REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \ |
173 | REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \ |
174 | REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB); |
175 | |
176 | #define REGISTER_SCATTER_MINMAX(type, dev) \ |
177 | REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \ |
178 | REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX); |
179 | |
180 | #define REGISTER_SCATTER_UPDATE(type, dev) \ |
181 | REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \ |
182 | scatter_op::UpdateOp::ASSIGN); |
183 | |
184 | // Registers CPU kernels. |
185 | #define REGISTER_SCATTER_ARITHMETIC_CPU(type) \ |
186 | REGISTER_SCATTER_ARITHMETIC(type, CPU); |
187 | |
188 | #define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU); |
189 | |
190 | #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU); |
191 | |
192 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); |
193 | TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); |
194 | TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU); |
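
// Each TF_CALL_* macro above invokes its argument once per dtype in the
// corresponding type list; e.g. TF_CALL_NUMBER_TYPES stamps out the
// arithmetic scatter kernels for every numeric T.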
195 | |
196 | // Registers GPU kernels. |
197 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
198 | #define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ |
199 | REGISTER_SCATTER_ARITHMETIC(type, GPU); |
200 | |
201 | #define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU); |
202 | |
203 | #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); |
204 | |
205 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU); |
206 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU); |
207 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU); |
208 | |
209 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
210 | |
213 | #undef REGISTER_SCATTER_ARITHMETIC |
214 | #undef REGISTER_SCATTER_ARITHMETIC_CPU |
215 | #undef REGISTER_SCATTER_ARITHMETIC_GPU |
216 | #undef REGISTER_SCATTER_MINMAX |
217 | #undef REGISTER_SCATTER_MINMAX_CPU |
218 | #undef REGISTER_SCATTER_MINMAX_GPU |
219 | #undef REGISTER_SCATTER_UPDATE |
220 | #undef REGISTER_SCATTER_UPDATE_CPU |
221 | #undef REGISTER_SCATTER_UPDATE_GPU |
222 | #undef REGISTER_SCATTER_KERNEL |
223 | #undef REGISTER_SCATTER_KERNEL_INDEX |
224 | |
225 | } // namespace tensorflow |
226 | |