/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/roll_op.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

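// RollOp implements the Roll op (tf.roll): a circular shift of a tensor's
// elements along one or more axes. For example, rolling [0, 1, 2, 3, 4] by
// shift = 2 along axis 0 yields [3, 4, 0, 1, 2]; a negative shift moves
// elements toward lower indices instead.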
template <typename Device, typename T, typename Tshift, typename Taxis>
class RollOp : public OpKernel {
 public:
  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensors.
    const Tensor& input = context->input(0);
    const Tensor& shift = context->input(1);
    const Tensor& axis = context->input(2);

    auto shift_flat = shift.flat<Tshift>();
    auto axis_flat = axis.flat<Taxis>();

    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument("input must be 1-D or higher"));
    OP_REQUIRES(context, shift.shape().dims() <= 1,
                errors::InvalidArgument(
                    "shift must be a scalar or a 1-D vector. Found: ",
                    shift.shape().DebugString()));
    OP_REQUIRES(context, axis.shape().dims() <= 1,
                errors::InvalidArgument(
                    "axis must be a scalar or a 1-D vector. Found: ",
                    axis.shape().DebugString()));
    OP_REQUIRES(
        context, shift.shape() == axis.shape(),
        errors::InvalidArgument("shift and axis must have the same size"));
    const int64_t num_elements = input.NumElements();
    const int num_shifts = static_cast<int>(shift_flat.size());
    const int num_dims = input.dims();

    // If there are any duplicate axes, shift_mod_sum holds the total shift
    // for each dimension, reduced modulo that dimension's size.
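    // Example (illustrative): for a dimension of size 5 that appears twice in
    // `axis` with shifts {2, -3}, the accumulated shift is
    // ((2 - 3) % 5 + 5) % 5 = 4.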
    gtl::InlinedVector<int32, 4> shift_mod_sum(num_dims, 0);
    for (int i = 0; i < num_shifts; i++) {
      int axis = axis_flat(i);
      if (axis < 0) {
        axis += num_dims;
      }
      OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
                  errors::InvalidArgument("axis ", axis, " is out of range"));
      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
      const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
      // modulo that works with negatives: ((x % y) + y) % y
      shift_mod_sum[axis] = (sum % ds + ds) % ds;
    }
    // the size of each dimension
    gtl::InlinedVector<int32, 4> dim_size(num_dims);
    // threshold[i] is the index at which the roll starts to wrap back to the
    // front
    gtl::InlinedVector<int32, 4> threshold(num_dims);
    // dim_range is the number of indices over in the flattened tensor
    // you need to skip in order to make it over from one side of a dimension
    // to the other. Used to make the shifts wrap around after a threshold.
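    // Example (illustrative): for an input of shape [2, 3, 5],
    // dim_range = {30, 15, 5} and the per-dimension stride
    // dim_range[i] / dim_size[i] = {15, 5, 1}.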
    gtl::InlinedVector<int64_t, 4> dim_range(num_dims);
    int64_t dim_size_prod = 1;  // dimension size product
    // inner shift dimension (innermost shifted dimension)
    int64_t isd = 0;
    for (int i = num_dims - 1; i >= 0; i--) {
      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
      dim_size[i] = ds;
      threshold[i] = (ds - shift_mod_sum[i]) % ds;
      dim_size_prod *= static_cast<int64_t>(input.dim_size(i));
      dim_range[i] = dim_size_prod;
    }
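    // Example (illustrative): for a dimension of size 5 with a net shift of
    // 2, threshold = (5 - 2) % 5 = 3, so input indices 3 and 4 wrap around to
    // output indices 0 and 1 while indices 0..2 move to 2..4.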

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto input_flat = input.flat<T>().data();
    auto output_flat = output->flat<T>().data();

    functor::Roll<Device, T>()(context, num_elements, num_dims, dim_size,
                               input_flat, output_flat, threshold, dim_range,
                               isd);
  }
};

namespace functor {

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
//    you need to skip in order to make it over from one side of a dimension
//    to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension at which the roll starts to wrap
//    back to the front
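// Example (illustrative): for a 1-D tensor of size 5 rolled by 2,
//    threshold = 3 and dim_range = 5; elements 0..2 are written with offset
//    +2 (to positions 2..4) and elements 3..4 with offset -3 (to positions
//    0..1).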
template <typename T>
void DoRoll(const OpKernelContext* context, const int64_t num_elements,
            const int num_dims, const gtl::ArraySlice<int32> dim_size,
            const T* input, T* output, const gtl::ArraySlice<int32> threshold,
            const gtl::ArraySlice<int64_t> dim_range) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
                  int64_t start, int64_t end) {
    // array of indices for each dimension
    gtl::InlinedVector<int, 4> indices(num_dims);
    int offset = 0;  // the shift along the flattened tensor for current element
    // initialize indices and offset
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64_t stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      const int shifted_indx = (indx + shift) % dim_size[i];
      offset += (shifted_indx - indx) * stride;
    }

    for (int64_t i = start; i < end; i++) {
      output[i + offset] = input[i];
      // create next combination of indices
      // while at it adjust offset if needed
      for (int j = num_dims - 1; j >= 0; j--) {
        const int indx = (indices[j] + 1) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {  // we've reached the threshold
            // dim_range[j] = threshold[j] + shift[j]
            // offset = shift[j] + ... other offsets
            // offset - dim_range[j] = -threshold[j] + ... other offsets
            // thus we undo our previous offset as well as add a new offset of
            // -threshold[j] in one operation
            offset -= dim_range[j];  // now wraps around
          }
          break;                         // indx != 0, no need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0, shift is 0
          offset += dim_range[j];        // indx became 0 so reverse wrap around
        }
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  // 15 - experimentally determined with float and bool types
  const int cost_per_element = 15 * sizeof(T);  // rough estimate
  Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
        cost_per_element, std::move(work));
}

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
//    you need to skip in order to make it over from one side of a dimension
//    to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension at which the roll starts to wrap
//    back to the front
// isd - inner shift dimension
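// Within each isd "row" the shifted output forms two contiguous runs, one for
//    the elements before threshold[isd] and one for the elements at or after
//    it, so each run can be copied with a single memcpy. For example
//    (illustrative), a 1-D tensor of size 5 rolled by 2 is copied as the run
//    [0..2] -> output [2..4] and the run [3..4] -> output [0..1].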
// Use memcpy to copy memory in groups when the data type supports memcpy
template <typename T>
void DoRollWithMemcpy(const OpKernelContext* context,
                      const int64_t num_elements, const int num_dims,
                      const gtl::ArraySlice<int32> dim_size, const T* input,
                      T* output, const gtl::ArraySlice<int32> threshold,
                      const gtl::ArraySlice<int64_t> dim_range,
                      const int64_t isd) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range,
               isd](int64_t start, int64_t end) {
    // the number of indices over in the flattened tensor you need to skip in
    // order to make it over from one side of the isd to the other
    const int64_t isd_range = std::max<int64_t>(dim_range[isd], 1);
    // the distance along the flattened tensor to the next element in the isd
    const int64_t isd_stride = isd_range / std::max<int64_t>(dim_size[isd], 1);

    // start and end represent the i-th group currently, so we convert them
    // into numbers representing the i-th elements.
    // there are 2 groups per isd: one for all elements before threshold[isd]
    // and another for all elements after threshold[isd].
    const int64_t start_remainder = (start % 2) * threshold[isd] * isd_stride;
    const int64_t end_remainder = (end % 2) * threshold[isd] * isd_stride;
    start = (start / 2) * isd_range + start_remainder;
    end = (end / 2) * isd_range + end_remainder;
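    // Example (illustrative): with isd_range = 5, isd_stride = 1 and
    // threshold[isd] = 3, group indices 0, 1, 2, 3 map to element offsets
    // 0, 3, 5, 8 -- the groups alternate between the [0, threshold) and
    // [threshold, dim_size) portions of each isd row.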

    const T* in_ptr = &input[0];
    T* out_ptr = &output[0];
    in_ptr += start;
    out_ptr += start;

    // array of indices for each dimension
    // indices = [i, j, k, l, m, n]
    gtl::InlinedVector<int, 4> indices(num_dims);
    // the offset needed to make all inner non-shifting dimensions become 0
    int64_t remainder_offset = 0;
    // initialize indices
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64_t stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      int out_indx = (indx + shift) % dim_size[i];
      if (i > isd) {
        // trailing zeroes for indices after the inner shifted dimension
        out_indx = 0;
        remainder_offset += (out_indx - indx) * stride;
      }
      out_ptr += (out_indx - indx) * stride;
    }
    // set trailing zeroes for indices after the inner shifted dimension
    for (int i = num_dims - 1; i > isd; i--) indices[i] = 0;

    // the number of indices in the isd dimension the next group will skip
    // to make it to the next threshold or end point
    int isd_indx_skip = 0;
    // the size of the next group
    int64_t group_size = 0;
    // initialize isd_indx_skip and group_size
    if (indices[isd] < threshold[isd]) {
      isd_indx_skip = threshold[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    } else {
      isd_indx_skip = dim_size[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    }
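    // The first group runs from indices[isd] up to the next boundary in the
    // isd: either threshold[isd] or the end of the isd row, whichever comes
    // first.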

    int64_t i = start;
    while (i < end) {
      // copy group of elements
      memcpy(out_ptr, in_ptr, group_size * sizeof(T));

      // shift i and the pointers over to the next group position
      i += group_size;
      out_ptr += group_size;
      in_ptr += group_size;

      // produce next combination of indices and adjust the out_ptr position
      // to fix the offset if necessary
      // the isd (inner shift dim) should skip to next threshold or endpoint
      // all dimensions to the left increment by 1 when a digit is carried
      // all dimensions to the right remain set to 0
      //           +1 +1 +1 +isd_indx_skip
      // indices = [i, j, k, l, 0, 0]
      //                     ^isd
      for (int j = isd; j >= 0; j--) {
        int inc = 1;
        if (j == isd) inc = isd_indx_skip;
        const int indx = (indices[j] + inc) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {
            out_ptr -= dim_range[j];  // now wraps around
          }
          break;                         // indx != 0, no need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0, shift is 0
          out_ptr += dim_range[j];  // indx became 0 so reverse wrap around
        }
      }

      // set isd_indx_skip and group_size for next iteration
      if (indices[isd] < threshold[isd]) {
        isd_indx_skip = threshold[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      } else {
        isd_indx_skip = dim_size[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  const int64_t ave_group_size = dim_range[isd] / 2;
  const int64_t total_work =
      2 * num_elements / std::max<int64_t>(dim_range[isd], 1);
  // 25000 - experimentally determined with float and bool types
  const int64_t cost_per_group = 25000 * sizeof(T) * ave_group_size;
  Shard(worker_threads->num_threads, worker_threads->workers, total_work,
        cost_per_group, std::move(work));
}

template <typename T>
struct Roll<CPUDevice, T> {
  void operator()(const OpKernelContext* context, const int64_t num_elements,
                  const int num_dims, const gtl::ArraySlice<int32> dim_size,
                  const T* input, T* output,
                  const gtl::ArraySlice<int32> threshold,
                  const gtl::ArraySlice<int64_t> dim_range,
                  const int64_t isd) {
    if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
      // V2 copies memory in groups instead of element by element
      DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size, input,
                          output, threshold, dim_range, isd);
    } else {
      // in case memcpy does not work for the current data type
      DoRoll<T>(context, num_elements, num_dims, dim_size, input, output,
                threshold, dim_range);
    }
  }
};
}  // namespace functor

// Register the CPU kernels.
#define REGISTER_CPU(type)                                           \
  REGISTER_KERNEL_BUILDER(Name("Roll")                               \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int32>("Tshift")       \
                              .TypeConstraint<int32>("Taxis")        \
                              .HostMemory("shift")                   \
                              .HostMemory("axis"),                   \
                          RollOp<CPUDevice, type, int32, int32>)     \
  REGISTER_KERNEL_BUILDER(Name("Roll")                               \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int64_t>("Tshift")     \
                              .TypeConstraint<int32>("Taxis")        \
                              .HostMemory("shift")                   \
                              .HostMemory("axis"),                   \
                          RollOp<CPUDevice, type, int64_t, int32>)   \
  REGISTER_KERNEL_BUILDER(Name("Roll")                               \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int32>("Tshift")       \
                              .TypeConstraint<int64_t>("Taxis")      \
                              .HostMemory("shift")                   \
                              .HostMemory("axis"),                   \
                          RollOp<CPUDevice, type, int32, int64_t>)   \
  REGISTER_KERNEL_BUILDER(Name("Roll")                               \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int64_t>("Tshift")     \
                              .TypeConstraint<int64_t>("Taxis")      \
                              .HostMemory("shift")                   \
                              .HostMemory("axis"),                   \
                          RollOp<CPUDevice, type, int64_t, int64_t>)

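// TF_CALL_ALL_TYPES expands REGISTER_CPU once per standard TensorFlow dtype,
// so each element type gets the four (Tshift, Taxis) combinations above.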
TF_CALL_ALL_TYPES(REGISTER_CPU);
#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_KERNEL(type)                                        \
  REGISTER_KERNEL_BUILDER(Name("Roll")                               \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int32>("Tshift")       \
                              .TypeConstraint<int32>("Taxis")        \
                              .HostMemory("shift")                   \
                              .HostMemory("axis"),                   \
                          RollOp<GPUDevice, type, int32, int32>)     \
  REGISTER_KERNEL_BUILDER(Name("Roll")                               \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int64_t>("Tshift")     \
                              .TypeConstraint<int32>("Taxis")        \
                              .HostMemory("shift")                   \
                              .HostMemory("axis"),                   \
                          RollOp<GPUDevice, type, int64_t, int32>)   \
  REGISTER_KERNEL_BUILDER(Name("Roll")                               \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int32>("Tshift")       \
                              .TypeConstraint<int64_t>("Taxis")      \
                              .HostMemory("shift")                   \
                              .HostMemory("axis"),                   \
                          RollOp<GPUDevice, type, int32, int64_t>)   \
  REGISTER_KERNEL_BUILDER(Name("Roll")                               \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int64_t>("Tshift")     \
                              .TypeConstraint<int64_t>("Taxis")      \
                              .HostMemory("shift")                   \
                              .HostMemory("axis"),                   \
                          RollOp<GPUDevice, type, int64_t, int64_t>)

TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
TF_CALL_uint32(REGISTER_KERNEL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
TF_CALL_COMPLEX_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace tensorflow