1 | #include "test/cpp/tensorexpr/padded_buffer.h" |
---|---|
2 | |
3 | #include <c10/util/Logging.h> |
4 | #include <c10/util/irange.h> |
5 | #include <sstream> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace tensorexpr { |
10 | |
11 | int PaddedBufferBase::Index(const std::vector<int>& indices) const { |
12 | TORCH_DCHECK_EQ(dims_.size(), indices.size()); |
13 | int total_index = 0; |
14 | for (const auto i : c10::irange(dims_.size())) { |
15 | total_index += indices[i] * strides_[i]; |
16 | } |
17 | return total_index; |
18 | } |
19 | |
20 | PaddedBufferBase::PaddedBufferBase( |
21 | const std::vector<int>& dims, |
22 | // NOLINTNEXTLINE(modernize-pass-by-value) |
23 | const std::string& name) |
24 | : dims_(dims), name_(name), strides_(dims.size()) { |
25 | for (int i = (int)dims.size() - 1; i >= 0; --i) { |
26 | if (i == (int)dims.size() - 1) { |
27 | strides_[i] = 1; |
28 | } else { |
29 | strides_[i] = strides_[i + 1] * dims[i + 1]; |
30 | } |
31 | } |
32 | total_size_ = strides_[0] * dims[0]; |
33 | } |
34 | |
35 | } // namespace tensorexpr |
36 | } // namespace jit |
37 | } // namespace torch |
38 |