1#pragma once
2#include <ATen/native/TensorIterator.h>
3#include <c10/util/SmallBuffer.h>
4#include <c10/util/irange.h>
5
6namespace at {
7
8struct DimCounter {
9 DimCounter(IntArrayRef shape, Range range);
10
11 void increment(const std::array<int64_t, 2>& step);
12 bool is_done() const;
13 std::array<int64_t, 2> max_2d_step() const;
14
15 IntArrayRef shape;
16 Range range;
17 c10::SmallBuffer<int64_t, 4> values;
18 int64_t offset;
19};
20
21namespace internal {
22
23inline void get_data_ptrs(
24 char** ptrs,
25 ArrayRef<char*> base,
26 IntArrayRef strides,
27 IntArrayRef counter) {
28 const int64_t ntensors = base.size();
29 const int64_t ndim = counter.size();
30 std::copy(base.begin(), base.end(), ptrs);
31 for (const auto dim : c10::irange(ndim)) {
32 int64_t value = counter[dim];
33 for (const auto arg : c10::irange(ntensors)) {
34 ptrs[arg] += value * strides[dim * ntensors + arg];
35 }
36 }
37}
38
39inline void serial_for_each(
40 IntArrayRef shape,
41 IntArrayRef strides,
42 char** base_ptrs,
43 size_t ntensors,
44 typename TensorIteratorBase::loop2d_t loop,
45 Range range) {
46 const auto ndim = shape.size();
47 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
48 strides.size() == ntensors * std::max(size_t{2}, ndim));
49
50 if (ndim <= 1) {
51 if (range.begin == 0) {
52 loop(base_ptrs, strides.data(), range.size(), 1);
53 } else {
54 c10::SmallBuffer<char*, 4> ptrs(ntensors);
55 get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
56 loop(ptrs.data(), strides.data(), range.size(), 1);
57 }
58 } else {
59 c10::SmallBuffer<char*, 4> ptrs(ntensors);
60 auto counter = DimCounter(shape, range);
61 while (!counter.is_done()) {
62 get_data_ptrs(
63 ptrs.data(), {base_ptrs, ntensors}, strides, counter.values);
64 auto step = counter.max_2d_step();
65 loop(ptrs.data(), strides.data(), step[0], step[1]);
66 counter.increment(step);
67 }
68 }
69}
70
71} // namespace internal
72} // namespace at
73