1 | #pragma once |
2 | #include <ATen/native/TensorIterator.h> |
3 | #include <c10/util/SmallBuffer.h> |
4 | #include <c10/util/irange.h> |
5 | |
6 | namespace at { |
7 | |
8 | struct DimCounter { |
9 | DimCounter(IntArrayRef shape, Range range); |
10 | |
11 | void increment(const std::array<int64_t, 2>& step); |
12 | bool is_done() const; |
13 | std::array<int64_t, 2> max_2d_step() const; |
14 | |
15 | IntArrayRef shape; |
16 | Range range; |
17 | c10::SmallBuffer<int64_t, 4> values; |
18 | int64_t offset; |
19 | }; |
20 | |
21 | namespace internal { |
22 | |
23 | inline void get_data_ptrs( |
24 | char** ptrs, |
25 | ArrayRef<char*> base, |
26 | IntArrayRef strides, |
27 | IntArrayRef counter) { |
28 | const int64_t ntensors = base.size(); |
29 | const int64_t ndim = counter.size(); |
30 | std::copy(base.begin(), base.end(), ptrs); |
31 | for (const auto dim : c10::irange(ndim)) { |
32 | int64_t value = counter[dim]; |
33 | for (const auto arg : c10::irange(ntensors)) { |
34 | ptrs[arg] += value * strides[dim * ntensors + arg]; |
35 | } |
36 | } |
37 | } |
38 | |
39 | inline void serial_for_each( |
40 | IntArrayRef shape, |
41 | IntArrayRef strides, |
42 | char** base_ptrs, |
43 | size_t ntensors, |
44 | typename TensorIteratorBase::loop2d_t loop, |
45 | Range range) { |
46 | const auto ndim = shape.size(); |
47 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
48 | strides.size() == ntensors * std::max(size_t{2}, ndim)); |
49 | |
50 | if (ndim <= 1) { |
51 | if (range.begin == 0) { |
52 | loop(base_ptrs, strides.data(), range.size(), 1); |
53 | } else { |
54 | c10::SmallBuffer<char*, 4> ptrs(ntensors); |
55 | get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin}); |
56 | loop(ptrs.data(), strides.data(), range.size(), 1); |
57 | } |
58 | } else { |
59 | c10::SmallBuffer<char*, 4> ptrs(ntensors); |
60 | auto counter = DimCounter(shape, range); |
61 | while (!counter.is_done()) { |
62 | get_data_ptrs( |
63 | ptrs.data(), {base_ptrs, ntensors}, strides, counter.values); |
64 | auto step = counter.max_2d_step(); |
65 | loop(ptrs.data(), strides.data(), step[0], step[1]); |
66 | counter.increment(step); |
67 | } |
68 | } |
69 | } |
70 | |
71 | } // namespace internal |
72 | } // namespace at |
73 | |