1 | #pragma once |
2 | |
#include <sstream>
#include <string>
#include <vector>

#include <c10/util/irange.h>
#include "torch/csrc/jit/tensorexpr/eval.h"
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace tensorexpr { |
12 | |
// Per-element-type sentinel ("watermark") value that PaddedBuffer<T> writes
// into its padding slots. The primary template is declared but not defined,
// so DefaultPaddedValue<T>::kValue only exists for explicitly specialized T.
template <typename T>
struct DefaultPaddedValue;

// All sentinels are `static constexpr`: in C++17 such members are implicitly
// inline, so they may be ODR-used (e.g. bound to a const reference or have
// their address taken) without a separate out-of-line definition, which a
// plain `static const` member would require.
template <>
struct DefaultPaddedValue<int> {
  static constexpr int kValue = static_cast<int>(0xDEADBEEF);
};

template <>
struct DefaultPaddedValue<int8_t> {
  static constexpr int8_t kValue = static_cast<int8_t>(0xBE);
};

template <>
struct DefaultPaddedValue<uint8_t> {
  static constexpr uint8_t kValue = static_cast<uint8_t>(0xBE);
};

template <>
struct DefaultPaddedValue<int16_t> {
  static constexpr int16_t kValue = static_cast<int16_t>(0xBEEF);
};

// Note: 0xDEADBEEF is a 32-bit literal, so this is the positive value
// 0x00000000DEADBEEF; the upper 32 bits of the sentinel are zero.
template <>
struct DefaultPaddedValue<int64_t> {
  static constexpr int64_t kValue = static_cast<int64_t>(0xDEADBEEF);
};

template <>
struct DefaultPaddedValue<float> {
  // Float literal: avoids an implicit double -> float narrowing conversion.
  static constexpr float kValue = 0.1357f;
};
45 | |
46 | template <> |
47 | struct DefaultPaddedValue<at::Half> { |
48 | // at::Half ctor isn't constexpr, so just fill it with bits. |
49 | static constexpr uint16_t kValue = 1357; |
50 | }; |
51 | |
52 | template <> |
53 | struct DefaultPaddedValue<double> { |
54 | static constexpr double kValue = 0.1357; |
55 | }; |
56 | |
// Non-template base class shared by every PaddedBuffer<T>. It holds the
// element-type-independent state: name, dimensions, strides and the number of
// useful (non-padding) elements.
class PaddedBufferBase {
 public:
  virtual ~PaddedBufferBase() = default;

  // Name given at construction.
  const std::string& name() const {
    return name_;
  }

  // Number of useful elements; the paddings are not counted.
  int size() const {
    return total_size_;
  }

  // Number of elements including one padding region on each side.
  int raw_size() const {
    return size() + 2 * kPaddingSize;
  }

 protected:
  // Defined out of line.
  explicit PaddedBufferBase(
      const std::vector<int>& dims,
      const std::string& name);

  // Flat offset of `indices` within the useful (unpadded) region; defined out
  // of line.
  int Index(const std::vector<int>& indices) const;

  std::vector<int> dims_;
  std::string name_;
  std::vector<int> strides_;
  // Total number of useful elements; does not include the paddings.
  int total_size_;
  // Number of watermark slots on each side of the useful region.
  static constexpr int kPaddingSize = 64;
};
87 | |
88 | // A padded buffer with wartermarks for testing. |
89 | // The buffer carries padded watermarks on both sides to catch potential |
90 | // out-of-bounds writes. For read-only data that are not supposed to change, it |
91 | // can also make a backup and be compared later. |
92 | template <typename T> |
93 | class PaddedBuffer : public PaddedBufferBase { |
94 | public: |
95 | PaddedBuffer(int d0, const std::string& name = "" ) |
96 | : PaddedBuffer(std::vector<int>({d0}), name) {} |
97 | PaddedBuffer(int d0, int d1, const std::string& name = "" ) |
98 | : PaddedBuffer(std::vector<int>({d0, d1}), name) {} |
99 | PaddedBuffer(int d0, int d1, int d2, const std::string& name = "" ) |
100 | : PaddedBuffer(std::vector<int>({d0, d1, d2}), name) {} |
101 | PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "" ) |
102 | : PaddedBuffer(std::vector<int>({d0, d1, d2, d3}), name) {} |
103 | PaddedBuffer(const std::vector<int>& dims, const std::string& name = "" ) |
104 | : PaddedBufferBase(dims, name) { |
105 | data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue); |
106 | } |
107 | PaddedBuffer(const PaddedBuffer& other, const std::string& name) |
108 | : PaddedBuffer(other) { |
109 | this->name_ = name; |
110 | } |
111 | |
112 | T* data() { |
113 | return data_.data() + kPaddingSize; |
114 | } |
115 | const T* data() const { |
116 | return const_cast<PaddedBuffer*>(this)->data(); |
117 | } |
118 | T* raw_data() { |
119 | return data_.data(); |
120 | } |
121 | const T* raw_data() const { |
122 | return const_cast<PaddedBuffer*>(this)->raw_data(); |
123 | } |
124 | T& operator()(int i0) { |
125 | // There is a bit performance impact with forming a vector here. But this |
126 | // data structure is for testing only, and not performance critical. |
127 | return this->operator()(std::vector<int>({i0})); |
128 | } |
129 | const T& operator()(int i0) const { |
130 | return const_cast<PaddedBuffer*>(this)->operator()(i0); |
131 | } |
132 | T& operator()(int i0, int i1) { |
133 | return this->operator()(std::vector<int>({i0, i1})); |
134 | } |
135 | const T& operator()(int i0, int i1) const { |
136 | return const_cast<PaddedBuffer*>(this)->operator()(i0, i1); |
137 | } |
138 | T& operator()(int i0, int i1, int i2) { |
139 | return this->operator()(std::vector<int>({i0, i1, i2})); |
140 | } |
141 | const T& operator()(int i0, int i1, int i2) const { |
142 | return const_cast<PaddedBuffer*>(this)->operator()(i0, i1, i2); |
143 | } |
144 | T& operator()(int i0, int i1, int i2, int i3) { |
145 | return this->operator()(std::vector<int>({i0, i1, i2, i3})); |
146 | } |
147 | const T& operator()(int i0, int i1, int i2, int i3) const { |
148 | return const_cast<PaddedBuffer*>(this)->operator()(i0, i1, i2, i3); |
149 | } |
150 | T& operator()(const std::vector<int>& indices) { |
151 | return data_[kPaddingSize + Index(indices)]; |
152 | } |
153 | const T& operator()(const std::vector<int>& indices) const { |
154 | return const_cast<PaddedBuffer*>(this)->operator()(indices); |
155 | } |
156 | |
157 | template <typename U> |
158 | friend void ExpectAllNear( |
159 | const PaddedBuffer<U>& v1, |
160 | const PaddedBuffer<U>& v2, |
161 | float abs_error); |
162 | template <typename U> |
163 | friend void ExpectAllEqual( |
164 | const PaddedBuffer<U>& v1, |
165 | const PaddedBuffer<U>& v2); |
166 | void Backup() { |
167 | backup_data_ = data_; |
168 | } |
169 | |
170 | // Verify the watermarks in the paddings are intact. |
171 | void ValidateWatermark() const { |
172 | for (const auto i : c10::irange(kPaddingSize)) { |
173 | ASSERT_EQ(data_[i], kPaddingValue); |
174 | ASSERT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue); |
175 | } |
176 | } |
177 | |
178 | void CheckBackup() const { |
179 | ValidateWatermark(); |
180 | DCHECK(backup_data_.size() == data_.size()) |
181 | << "Please make sure you have call Backup() before calling CheckBackup()" ; |
182 | for (const auto i : c10::irange(total_size_)) { |
183 | ASSERT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]); |
184 | } |
185 | } |
186 | |
187 | private: |
188 | std::vector<T> data_; |
189 | std::vector<T> backup_data_; |
190 | T kPaddingValue = DefaultPaddedValue<T>::kValue; |
191 | }; |
192 | |
193 | template <typename T> |
194 | inline CodeGen::CallArg::CallArg(const PaddedBuffer<T>& buffer) |
195 | : data_(const_cast<T*>(buffer.data())) {} |
196 | |
197 | template <typename T> |
198 | std::string CompareErrorMsg( |
199 | const PaddedBuffer<T>& v1, |
200 | const PaddedBuffer<T>& v2, |
201 | int index) { |
202 | std::ostringstream oss; |
203 | oss << "index: " << index << ", v1: (" << v1.name() << ", " << v1(index) |
204 | << ")" |
205 | << ", v2: (" << v2.name() << ", " << v2(index) << ")" ; |
206 | return oss.str(); |
207 | } |
208 | |
209 | template <typename T> |
210 | void ExpectAllEqual(const PaddedBuffer<T>& f1, const PaddedBuffer<T>& f2) { |
211 | const std::vector<T>& v1 = f1.data_; |
212 | const std::vector<T>& v2 = f2.data_; |
213 | const int kPaddingSize = f1.kPaddingSize; |
214 | const int total_size = f1.total_size_; |
215 | ASSERT_EQ(v1.size(), v2.size()); |
216 | f1.ValidateWatermark(); |
217 | f2.ValidateWatermark(); |
218 | for (const auto i : c10::irange(total_size)) { |
219 | ASSERT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]); |
220 | } |
221 | } |
222 | |
223 | template <typename T> |
224 | void ExpectAllNear( |
225 | const PaddedBuffer<T>& f1, |
226 | const PaddedBuffer<T>& f2, |
227 | float abs_error) { |
228 | const std::vector<T>& v1 = f1.data_; |
229 | const std::vector<T>& v2 = f2.data_; |
230 | const int kPaddingSize = f1.kPaddingSize; |
231 | const int total_size = f1.total_size_; |
232 | ASSERT_EQ(v1.size(), v2.size()); |
233 | f1.ValidateWatermark(); |
234 | f2.ValidateWatermark(); |
235 | for (const auto i : c10::irange(total_size)) { |
236 | ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error); |
237 | } |
238 | } |
239 | |
240 | } // namespace tensorexpr |
241 | } // namespace jit |
242 | } // namespace torch |
243 | |