1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/roll_op.h" |
17 | |
18 | #include "tensorflow/core/framework/bounds_check.h" |
19 | #include "tensorflow/core/framework/common_shape_fns.h" |
20 | #include "tensorflow/core/framework/op.h" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/framework/register_types_traits.h" |
24 | #include "tensorflow/core/framework/shape_inference.h" |
25 | #include "tensorflow/core/lib/gtl/array_slice.h" |
26 | #include "tensorflow/core/platform/types.h" |
27 | #include "tensorflow/core/util/work_sharder.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | typedef Eigen::ThreadPoolDevice CPUDevice; |
32 | typedef Eigen::GpuDevice GPUDevice; |
33 | |
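// RollOp implements the "Roll" op (tf.roll): elements are shifted circularly
// along the requested axes. For example, rolling [0, 1, 2, 3, 4] by shift = 2
// along axis 0 yields [3, 4, 0, 1, 2]; a negative shift rolls toward the
// front.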
34 | template <typename Device, typename T, typename Tshift, typename Taxis> |
35 | class RollOp : public OpKernel { |
36 | public: |
37 | explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {} |
38 | |
39 | void Compute(OpKernelContext* context) override { |
    // Grab the input tensors
41 | const Tensor& input = context->input(0); |
42 | const Tensor& shift = context->input(1); |
43 | const Tensor& axis = context->input(2); |
44 | |
45 | auto shift_flat = shift.flat<Tshift>(); |
46 | auto axis_flat = axis.flat<Taxis>(); |
47 | |
48 | OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()), |
49 | errors::InvalidArgument("input must be 1-D or higher" )); |
50 | OP_REQUIRES(context, shift.shape().dims() <= 1, |
51 | errors::InvalidArgument( |
52 | "shift must be a scalar or a 1-D vector. Found: " , |
53 | shift.shape().DebugString())); |
54 | OP_REQUIRES(context, axis.shape().dims() <= 1, |
55 | errors::InvalidArgument( |
56 | "axis must be a scalar or a 1-D vector. Found: " , |
57 | axis.shape().DebugString())); |
58 | OP_REQUIRES( |
59 | context, shift.shape() == axis.shape(), |
60 | errors::InvalidArgument("shift and axis must have the same size" )); |
61 | const int64_t num_elements = input.NumElements(); |
62 | const int num_shifts = static_cast<int>(shift_flat.size()); |
63 | const int num_dims = input.dims(); |
64 | |
    // shift_mod_sum holds the net shift for each dimension, modulo that
    // dimension's size; shifts applied to duplicate axes are accumulated
67 | gtl::InlinedVector<int32, 4> shift_mod_sum(num_dims, 0); |
68 | for (int i = 0; i < num_shifts; i++) { |
69 | int axis = axis_flat(i); |
70 | if (axis < 0) { |
71 | axis += num_dims; |
72 | } |
73 | OP_REQUIRES(context, FastBoundsCheck(axis, num_dims), |
74 | errors::InvalidArgument("axis " , axis, " is out of range" )); |
75 | const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1); |
76 | const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i)); |
77 | // modulo that works with negatives: ((x % y) + y) % y |
78 | shift_mod_sum[axis] = (sum % ds + ds) % ds; |
79 | } |
80 | // the size of each dimension |
81 | gtl::InlinedVector<int32, 4> dim_size(num_dims); |
    // threshold[i] is the index at which the roll starts to wrap back to the
    // front
83 | gtl::InlinedVector<int32, 4> threshold(num_dims); |
84 | // dim_range is the number of indices over in the flattened tensor |
85 | // you need to skip in order to make it over from one side of a dimension |
86 | // to the other. Used to make the shifts wrap around after a threshold. |
87 | gtl::InlinedVector<int64_t, 4> dim_range(num_dims); |
88 | int64_t dim_size_prod = 1; // dimension size product |
    // inner shift dimension (innermost shifted dimension)
90 | int64_t isd = 0; |
91 | for (int i = num_dims - 1; i >= 0; i--) { |
92 | if (isd == 0 && shift_mod_sum[i] != 0) isd = i; |
93 | const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1); |
94 | dim_size[i] = ds; |
95 | threshold[i] = (ds - shift_mod_sum[i]) % ds; |
96 | dim_size_prod *= static_cast<int64_t>(input.dim_size(i)); |
97 | dim_range[i] = dim_size_prod; |
98 | } |
99 | |
100 | Tensor* output = nullptr; |
101 | OP_REQUIRES_OK(context, |
102 | context->allocate_output(0, input.shape(), &output)); |
103 | auto input_flat = input.flat<T>().data(); |
104 | auto output_flat = output->flat<T>().data(); |
105 | |
106 | functor::Roll<Device, T>()(context, num_elements, num_dims, dim_size, |
107 | input_flat, output_flat, threshold, dim_range, |
108 | isd); |
109 | } |
110 | }; |
111 | |
112 | namespace functor { |
113 | |
114 | // dim_size - the size of each dimension |
115 | // dim_range - the number of indices over in the flattened tensor |
116 | // you need to skip in order to make it over from one side of a dimension |
117 | // to the other. Used to make the shifts wrap around after a threshold. |
// threshold - the index for each dimension at which the roll starts to wrap
// back to the front
120 | template <typename T> |
121 | void DoRoll(const OpKernelContext* context, const int64_t num_elements, |
122 | const int num_dims, const gtl::ArraySlice<int32> dim_size, |
123 | const T* input, T* output, const gtl::ArraySlice<int32> threshold, |
124 | const gtl::ArraySlice<int64_t> dim_range) { |
125 | auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range]( |
126 | int64_t start, int64_t end) { |
127 | // array of indices for each dimension |
128 | gtl::InlinedVector<int, 4> indices(num_dims); |
    // the shift along the flattened tensor for the current element
    int64_t offset = 0;
130 | // initialize indices and offset |
131 | for (int i = 0; i < num_dims; i++) { |
132 | // stride is the number of indices over in the flattened tensor |
133 | // you need to skip in order to make it over to an adjacent element |
134 | // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1) |
135 | const int64_t stride = dim_range[i] / dim_size[i]; |
136 | const int shift = dim_size[i] - threshold[i]; |
137 | const int indx = (start / stride) % dim_size[i]; |
138 | indices[i] = indx; |
139 | // calculate dimension index after the shift |
140 | const int shifted_indx = (indx + shift) % dim_size[i]; |
141 | offset += (shifted_indx - indx) * stride; |
142 | } |
143 | |
144 | for (int64_t i = start; i < end; i++) { |
145 | output[i + offset] = input[i]; |
146 | // create next combination of indices |
147 | // while at it adjust offset if needed |
148 | for (int j = num_dims - 1; j >= 0; j--) { |
149 | const int indx = (indices[j] + 1) % dim_size[j]; |
150 | indices[j] = indx; |
151 | if (indx != 0) { |
152 | if (indx == threshold[j]) { // we've reached the threshold |
            // With stride == dim_range[j] / dim_size[j]:
            //   dim_range[j] == (threshold[j] + shift[j]) * stride
            //   offset == shift[j] * stride + ... other offsets
            //   offset - dim_range[j] == -threshold[j] * stride + ... others
            // thus we undo our previous offset as well as add a new offset of
            // -threshold[j] * stride in one operation
158 | offset -= dim_range[j]; // now wraps around |
159 | } |
160 | break; // indx != 0 don't need to carry |
161 | } else if (threshold[j] != 0) { // if threshold is 0 shift is 0 |
162 | offset += dim_range[j]; // indx became 0 so reverse wrap around |
163 | } |
164 | } |
165 | } |
166 | }; |
167 | // Shard |
168 | auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); |
169 | // 15 - experimentally determined with float and bool types |
170 | const int cost_per_element = 15 * sizeof(T); // rough estimate |
171 | Shard(worker_threads->num_threads, worker_threads->workers, num_elements, |
172 | cost_per_element, std::move(work)); |
173 | } |
174 | |
175 | // dim_size - the size of each dimension |
176 | // dim_range - the number of indices over in the flattened tensor |
177 | // you need to skip in order to make it over from one side of a dimension |
178 | // to the other. Used to make the shifts wrap around after a threshold. |
// threshold - the index for each dimension at which the roll starts to wrap
// back to the front
181 | // isd - inner shift dimension |
182 | template <typename T> |
183 | // Use memcpy to copy memory in groups when the data type supports memcpy |
184 | void DoRollWithMemcpy(const OpKernelContext* context, |
185 | const int64_t num_elements, const int num_dims, |
186 | const gtl::ArraySlice<int32> dim_size, const T* input, |
187 | T* output, const gtl::ArraySlice<int32> threshold, |
188 | const gtl::ArraySlice<int64_t> dim_range, |
189 | const int64_t isd) { |
190 | auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd]( |
191 | int64_t start, int64_t end) { |
192 | // the number of indices over in the flattened tensor you need to skip in |
193 | // order to make it over from one side of the isd to the other |
    const int64_t isd_range = std::max<int64_t>(dim_range[isd], 1);
    // the distance along the flattened tensor to the next element in the isd
    const int64_t isd_stride = isd_range / std::max<int>(dim_size[isd], 1);
197 | |
    // start and end are group indices here, so convert them into indices of
    // elements in the flattened tensor.
    // There are 2 groups per isd slice: one for all elements before
    // threshold[isd] and another for all elements at or after threshold[isd].
202 | const int64_t start_remainder = (start % 2) * threshold[isd] * isd_stride; |
203 | const int64_t end_remainder = (end % 2) * threshold[isd] * isd_stride; |
204 | start = (start / 2) * isd_range + start_remainder; |
205 | end = (end / 2) * isd_range + end_remainder; |
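    // e.g. group 2 * k starts at isd index 0 of the k-th isd slice and covers
    // indices [0, threshold[isd]); group 2 * k + 1 covers
    // [threshold[isd], dim_size[isd]). Each group is a contiguous run and is
    // copied with a single memcpy in the loop below.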
206 | |
207 | const T* in_ptr = &input[0]; |
208 | T* out_ptr = &output[0]; |
209 | in_ptr += start; |
210 | out_ptr += start; |
211 | |
212 | // array of indices for each dimension |
213 | // indices = [i, j, k, l, m, n] |
214 | gtl::InlinedVector<int, 4> indices(num_dims); |
215 | // the offset needed to make all inner non-shifting dimensions become 0 |
216 | int64_t remainder_offset = 0; |
217 | // initialize indices |
218 | for (int i = 0; i < num_dims; i++) { |
219 | // stride is the number of indices over in the flattened tensor |
220 | // you need to skip in order to make it over to an adjacent element |
221 | // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1) |
222 | const int64_t stride = dim_range[i] / dim_size[i]; |
223 | const int shift = dim_size[i] - threshold[i]; |
224 | const int indx = (start / stride) % dim_size[i]; |
225 | indices[i] = indx; |
226 | // calculate dimension index after the shift |
227 | int out_indx = (indx + shift) % dim_size[i]; |
228 | if (i > isd) { |
229 | // trailing zeroes for indices after the inner shifted dimension |
230 | out_indx = 0; |
231 | remainder_offset += (out_indx - indx) * stride; |
232 | } |
233 | out_ptr += (out_indx - indx) * stride; |
234 | } |
235 | // set trailing zeroes for indices after the inner shifted dimension |
236 | for (int i = num_dims - 1; i > isd; i--) indices[i] = 0; |
237 | |
238 | // the number of indices in the isd dimension the next group will skip |
239 | // to make it to the next threshold or end point |
240 | int isd_indx_skip = 0; |
241 | // the size of the next group |
242 | int64_t group_size = 0; |
243 | // initialize isd_indx_skip and group_size |
244 | if (indices[isd] < threshold[isd]) { |
245 | isd_indx_skip = threshold[isd] - indices[isd]; |
246 | group_size = isd_indx_skip * isd_stride + remainder_offset; |
247 | } else { |
248 | isd_indx_skip = dim_size[isd] - indices[isd]; |
249 | group_size = isd_indx_skip * isd_stride + remainder_offset; |
250 | } |
251 | |
252 | int64_t i = start; |
253 | while (i < end) { |
254 | // copy group of elements |
255 | memcpy(out_ptr, in_ptr, group_size * sizeof(T)); |
256 | |
257 | // shift i and the pointers over to the next group position |
258 | i += group_size; |
259 | out_ptr += group_size; |
260 | in_ptr += group_size; |
261 | |
262 | // produce next combination of indices and adjust the out_ptr position |
263 | // to fix the offset if necessary |
264 | // the isd (inner shift dim) should skip to next threshold or endpoint |
265 | // all dimensions to the left increment by 1 when a digit is carried |
266 | // all dimensions to the right remain set to 0 |
267 | // +1 +1 +1 +isd_indx_skip |
268 | // indices = [i, j, k, l, 0, 0] |
269 | // ^isd |
270 | for (int j = isd; j >= 0; j--) { |
271 | int inc = 1; |
272 | if (j == isd) inc = isd_indx_skip; |
273 | const int indx = (indices[j] + inc) % dim_size[j]; |
274 | indices[j] = indx; |
275 | if (indx != 0) { |
276 | if (indx == threshold[j]) { |
277 | out_ptr -= dim_range[j]; // now wraps around |
278 | } |
279 | break; // indx != 0 don't need to carry |
280 | } else if (threshold[j] != 0) { // if threshold is 0 shift is 0 |
281 | out_ptr += dim_range[j]; // indx became 0 so reverse wrap around |
282 | } |
283 | } |
284 | |
285 | // set isd_indx_skip and group_size for next iteration |
286 | if (indices[isd] < threshold[isd]) { |
287 | isd_indx_skip = threshold[isd] - indices[isd]; |
288 | group_size = isd_indx_skip * isd_stride; |
289 | } else { |
290 | isd_indx_skip = dim_size[isd] - indices[isd]; |
291 | group_size = isd_indx_skip * isd_stride; |
292 | } |
293 | } |
294 | }; |
295 | // Shard |
296 | auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); |
297 | const int64_t ave_group_size = dim_range[isd] / 2; |
  const int64_t total_work =
      2 * num_elements / std::max<int64_t>(dim_range[isd], 1);
299 | // 25000 - experimentally determined with float and bool types |
  const int64_t cost_per_group = 25000 * sizeof(T) * ave_group_size;
301 | Shard(worker_threads->num_threads, worker_threads->workers, total_work, |
302 | cost_per_group, std::move(work)); |
303 | } |
304 | |
305 | template <typename T> |
306 | struct Roll<CPUDevice, T> { |
307 | void operator()(const OpKernelContext* context, const int64_t num_elements, |
308 | const int num_dims, const gtl::ArraySlice<int32> dim_size, |
309 | const T* input, T* output, |
310 | const gtl::ArraySlice<int32> threshold, |
311 | const gtl::ArraySlice<int64_t> dim_range, const int64_t isd) { |
312 | if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) { |
      // DoRollWithMemcpy copies memory in groups instead of element by
      // element
314 | DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size, input, |
315 | output, threshold, dim_range, isd); |
316 | } else { |
      // in case memcpy does not work for the current data type
318 | DoRoll<T>(context, num_elements, num_dims, dim_size, input, output, |
319 | threshold, dim_range); |
320 | } |
321 | }; |
322 | }; |
323 | } // namespace functor |
324 | |
325 | // Register the CPU kernels. |
326 | #define REGISTER_CPU(type) \ |
327 | REGISTER_KERNEL_BUILDER(Name("Roll") \ |
328 | .Device(DEVICE_CPU) \ |
329 | .TypeConstraint<type>("T") \ |
330 | .TypeConstraint<int32>("Tshift") \ |
331 | .TypeConstraint<int32>("Taxis") \ |
332 | .HostMemory("shift") \ |
333 | .HostMemory("axis"), \ |
334 | RollOp<CPUDevice, type, int32, int32>) \ |
335 | REGISTER_KERNEL_BUILDER(Name("Roll") \ |
336 | .Device(DEVICE_CPU) \ |
337 | .TypeConstraint<type>("T") \ |
338 | .TypeConstraint<int64_t>("Tshift") \ |
339 | .TypeConstraint<int32>("Taxis") \ |
340 | .HostMemory("shift") \ |
341 | .HostMemory("axis"), \ |
342 | RollOp<CPUDevice, type, int64, int32>) \ |
343 | REGISTER_KERNEL_BUILDER(Name("Roll") \ |
344 | .Device(DEVICE_CPU) \ |
345 | .TypeConstraint<type>("T") \ |
346 | .TypeConstraint<int32>("Tshift") \ |
347 | .TypeConstraint<int64_t>("Taxis") \ |
348 | .HostMemory("shift") \ |
349 | .HostMemory("axis"), \ |
350 | RollOp<CPUDevice, type, int32, int64>) \ |
351 | REGISTER_KERNEL_BUILDER(Name("Roll") \ |
352 | .Device(DEVICE_CPU) \ |
353 | .TypeConstraint<type>("T") \ |
354 | .TypeConstraint<int64_t>("Tshift") \ |
355 | .TypeConstraint<int64_t>("Taxis") \ |
356 | .HostMemory("shift") \ |
357 | .HostMemory("axis"), \ |
358 | RollOp<CPUDevice, type, int64, int64>) |
359 | |
360 | TF_CALL_ALL_TYPES(REGISTER_CPU); |
361 | #undef REGISTER_CPU |
362 | |
363 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
364 | #define REGISTER_KERNEL(type) \ |
365 | REGISTER_KERNEL_BUILDER(Name("Roll") \ |
366 | .Device(DEVICE_GPU) \ |
367 | .TypeConstraint<type>("T") \ |
368 | .TypeConstraint<int32>("Tshift") \ |
369 | .TypeConstraint<int32>("Taxis") \ |
370 | .HostMemory("shift") \ |
371 | .HostMemory("axis"), \ |
372 | RollOp<GPUDevice, type, int32, int32>) \ |
373 | REGISTER_KERNEL_BUILDER(Name("Roll") \ |
374 | .Device(DEVICE_GPU) \ |
375 | .TypeConstraint<type>("T") \ |
376 | .TypeConstraint<int64_t>("Tshift") \ |
377 | .TypeConstraint<int32>("Taxis") \ |
378 | .HostMemory("shift") \ |
379 | .HostMemory("axis"), \ |
380 | RollOp<GPUDevice, type, int64, int32>) \ |
381 | REGISTER_KERNEL_BUILDER(Name("Roll") \ |
382 | .Device(DEVICE_GPU) \ |
383 | .TypeConstraint<type>("T") \ |
384 | .TypeConstraint<int32>("Tshift") \ |
385 | .TypeConstraint<int64_t>("Taxis") \ |
386 | .HostMemory("shift") \ |
387 | .HostMemory("axis"), \ |
388 | RollOp<GPUDevice, type, int32, int64>) \ |
389 | REGISTER_KERNEL_BUILDER(Name("Roll") \ |
390 | .Device(DEVICE_GPU) \ |
391 | .TypeConstraint<type>("T") \ |
392 | .TypeConstraint<int64_t>("Tshift") \ |
393 | .TypeConstraint<int64_t>("Taxis") \ |
394 | .HostMemory("shift") \ |
395 | .HostMemory("axis"), \ |
396 | RollOp<GPUDevice, type, int64, int64>) |
397 | |
398 | TF_CALL_int32(REGISTER_KERNEL); |
399 | TF_CALL_int64(REGISTER_KERNEL); |
400 | TF_CALL_uint32(REGISTER_KERNEL); |
401 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL); |
402 | TF_CALL_COMPLEX_TYPES(REGISTER_KERNEL); |
403 | |
404 | #undef REGISTER_KERNEL |
405 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
406 | } // namespace tensorflow |
407 | |