1#pragma once
2
3#include <string>
4#include <vector>
5
6#include <c10/util/irange.h>
7#include "torch/csrc/jit/tensorexpr/eval.h"
8
9namespace torch {
10namespace jit {
11namespace tensorexpr {
12
// Per-element-type sentinel ("watermark") value. PaddedBuffer<T> fills its
// padding regions with DefaultPaddedValue<T>::kValue and later checks that the
// value is still there.
//
// Only the explicit specializations are defined; using an element type that
// has no specialization is a compile-time error.
//
// All values are `static constexpr` so that every specialization is declared
// the same way and (since C++17) the members are implicitly inline, which
// makes them safe to bind to `const T&` parameters without a separate
// out-of-class definition.
template <typename T>
struct DefaultPaddedValue;

// 32-bit pattern 0xDEADBEEF reinterpreted as a (negative) int.
template <>
struct DefaultPaddedValue<int> {
  static constexpr int kValue = static_cast<int>(0xDEADBEEF);
};

// Single byte 0xBE reinterpreted as a signed 8-bit value.
template <>
struct DefaultPaddedValue<int8_t> {
  static constexpr int8_t kValue = static_cast<int8_t>(0xBE);
};

// Single byte 0xBE.
template <>
struct DefaultPaddedValue<uint8_t> {
  static constexpr uint8_t kValue = static_cast<uint8_t>(0xBE);
};

// 16-bit pattern 0xBEEF reinterpreted as a signed 16-bit value.
template <>
struct DefaultPaddedValue<int16_t> {
  static constexpr int16_t kValue = static_cast<int16_t>(0xBEEF);
};

// 0xDEADBEEF widened to 64 bits (stays positive, unlike the int variant).
template <>
struct DefaultPaddedValue<int64_t> {
  static constexpr int64_t kValue = static_cast<int64_t>(0xDEADBEEF);
};

// Arbitrary, easily recognizable floating-point value.
template <>
struct DefaultPaddedValue<float> {
  static constexpr float kValue = 0.1357f;
};
45
46template <>
47struct DefaultPaddedValue<at::Half> {
48 // at::Half ctor isn't constexpr, so just fill it with bits.
49 static constexpr uint16_t kValue = 1357;
50};
51
52template <>
53struct DefaultPaddedValue<double> {
54 static constexpr double kValue = 0.1357;
55};
56
// Non-template base class of PaddedBuffer<T>. It holds everything that does
// not depend on the element type: the dimensions, the name, the strides and
// the element count. The constructor and Index() are only declared here;
// their definitions live outside this header.
class PaddedBufferBase {
 public:
  virtual ~PaddedBufferBase() = default;

  // Name given to the buffer at construction time.
  const std::string& name() const {
    return name_;
  }

  // Number of useful elements, excluding the padding.
  int size() const {
    return total_size_;
  }

  // Number of elements including kPaddingSize padding elements on each side.
  int raw_size() const {
    return 2 * kPaddingSize + total_size_;
  }

 protected:
  // Only derived classes may construct the base.
  explicit PaddedBufferBase(
      const std::vector<int>& dims,
      const std::string& name);

  // Maps a multi-dimensional index to an offset into the useful (unpadded)
  // region; derived classes add kPaddingSize to reach the raw storage.
  int Index(const std::vector<int>& indices) const;

  std::vector<int> dims_;
  std::string name_;
  std::vector<int> strides_;
  // Total number of useful elements; does not include the paddings.
  int total_size_;
  // Number of watermark elements placed on each side of the useful region.
  static constexpr int kPaddingSize = 64;
};
87
88// A padded buffer with wartermarks for testing.
89// The buffer carries padded watermarks on both sides to catch potential
90// out-of-bounds writes. For read-only data that are not supposed to change, it
91// can also make a backup and be compared later.
92template <typename T>
93class PaddedBuffer : public PaddedBufferBase {
94 public:
95 PaddedBuffer(int d0, const std::string& name = "")
96 : PaddedBuffer(std::vector<int>({d0}), name) {}
97 PaddedBuffer(int d0, int d1, const std::string& name = "")
98 : PaddedBuffer(std::vector<int>({d0, d1}), name) {}
99 PaddedBuffer(int d0, int d1, int d2, const std::string& name = "")
100 : PaddedBuffer(std::vector<int>({d0, d1, d2}), name) {}
101 PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "")
102 : PaddedBuffer(std::vector<int>({d0, d1, d2, d3}), name) {}
103 PaddedBuffer(const std::vector<int>& dims, const std::string& name = "")
104 : PaddedBufferBase(dims, name) {
105 data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue);
106 }
107 PaddedBuffer(const PaddedBuffer& other, const std::string& name)
108 : PaddedBuffer(other) {
109 this->name_ = name;
110 }
111
112 T* data() {
113 return data_.data() + kPaddingSize;
114 }
115 const T* data() const {
116 return const_cast<PaddedBuffer*>(this)->data();
117 }
118 T* raw_data() {
119 return data_.data();
120 }
121 const T* raw_data() const {
122 return const_cast<PaddedBuffer*>(this)->raw_data();
123 }
124 T& operator()(int i0) {
125 // There is a bit performance impact with forming a vector here. But this
126 // data structure is for testing only, and not performance critical.
127 return this->operator()(std::vector<int>({i0}));
128 }
129 const T& operator()(int i0) const {
130 return const_cast<PaddedBuffer*>(this)->operator()(i0);
131 }
132 T& operator()(int i0, int i1) {
133 return this->operator()(std::vector<int>({i0, i1}));
134 }
135 const T& operator()(int i0, int i1) const {
136 return const_cast<PaddedBuffer*>(this)->operator()(i0, i1);
137 }
138 T& operator()(int i0, int i1, int i2) {
139 return this->operator()(std::vector<int>({i0, i1, i2}));
140 }
141 const T& operator()(int i0, int i1, int i2) const {
142 return const_cast<PaddedBuffer*>(this)->operator()(i0, i1, i2);
143 }
144 T& operator()(int i0, int i1, int i2, int i3) {
145 return this->operator()(std::vector<int>({i0, i1, i2, i3}));
146 }
147 const T& operator()(int i0, int i1, int i2, int i3) const {
148 return const_cast<PaddedBuffer*>(this)->operator()(i0, i1, i2, i3);
149 }
150 T& operator()(const std::vector<int>& indices) {
151 return data_[kPaddingSize + Index(indices)];
152 }
153 const T& operator()(const std::vector<int>& indices) const {
154 return const_cast<PaddedBuffer*>(this)->operator()(indices);
155 }
156
157 template <typename U>
158 friend void ExpectAllNear(
159 const PaddedBuffer<U>& v1,
160 const PaddedBuffer<U>& v2,
161 float abs_error);
162 template <typename U>
163 friend void ExpectAllEqual(
164 const PaddedBuffer<U>& v1,
165 const PaddedBuffer<U>& v2);
166 void Backup() {
167 backup_data_ = data_;
168 }
169
170 // Verify the watermarks in the paddings are intact.
171 void ValidateWatermark() const {
172 for (const auto i : c10::irange(kPaddingSize)) {
173 ASSERT_EQ(data_[i], kPaddingValue);
174 ASSERT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue);
175 }
176 }
177
178 void CheckBackup() const {
179 ValidateWatermark();
180 DCHECK(backup_data_.size() == data_.size())
181 << "Please make sure you have call Backup() before calling CheckBackup()";
182 for (const auto i : c10::irange(total_size_)) {
183 ASSERT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]);
184 }
185 }
186
187 private:
188 std::vector<T> data_;
189 std::vector<T> backup_data_;
190 T kPaddingValue = DefaultPaddedValue<T>::kValue;
191};
192
193template <typename T>
194inline CodeGen::CallArg::CallArg(const PaddedBuffer<T>& buffer)
195 : data_(const_cast<T*>(buffer.data())) {}
196
197template <typename T>
198std::string CompareErrorMsg(
199 const PaddedBuffer<T>& v1,
200 const PaddedBuffer<T>& v2,
201 int index) {
202 std::ostringstream oss;
203 oss << "index: " << index << ", v1: (" << v1.name() << ", " << v1(index)
204 << ")"
205 << ", v2: (" << v2.name() << ", " << v2(index) << ")";
206 return oss.str();
207}
208
209template <typename T>
210void ExpectAllEqual(const PaddedBuffer<T>& f1, const PaddedBuffer<T>& f2) {
211 const std::vector<T>& v1 = f1.data_;
212 const std::vector<T>& v2 = f2.data_;
213 const int kPaddingSize = f1.kPaddingSize;
214 const int total_size = f1.total_size_;
215 ASSERT_EQ(v1.size(), v2.size());
216 f1.ValidateWatermark();
217 f2.ValidateWatermark();
218 for (const auto i : c10::irange(total_size)) {
219 ASSERT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]);
220 }
221}
222
223template <typename T>
224void ExpectAllNear(
225 const PaddedBuffer<T>& f1,
226 const PaddedBuffer<T>& f2,
227 float abs_error) {
228 const std::vector<T>& v1 = f1.data_;
229 const std::vector<T>& v2 = f2.data_;
230 const int kPaddingSize = f1.kPaddingSize;
231 const int total_size = f1.total_size_;
232 ASSERT_EQ(v1.size(), v2.size());
233 f1.ValidateWatermark();
234 f2.ValidateWatermark();
235 for (const auto i : c10::irange(total_size)) {
236 ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error);
237 }
238}
239
240} // namespace tensorexpr
241} // namespace jit
242} // namespace torch
243