1 | // Copyright (c) Facebook, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // This source code is licensed under the BSD-style license found in the |
5 | // LICENSE file in the root directory of this source tree. |
6 | |
7 | #pragma once |
8 | #include <ATen/ATen.h> |
9 | #include "minpybind.h" |
10 | |
#ifdef _WIN32
#include <intrin.h>
// https://stackoverflow.com/questions/355967/how-to-use-msvc-intrinsics-to-get-the-equivalent-of-this-gcc-code
// MSVC has no __builtin_clz; emulate the GCC builtin with _BitScanReverse,
// which yields the index of the highest set bit. Like the real builtin, the
// result is undefined for x == 0 (r is left at 0 when no bit is found).
inline unsigned int __builtin_clz(unsigned int x) {
    unsigned long r = 0;
    _BitScanReverse(&r, x);
    return (31 - r);
}
#endif
20 | |
// Round num up to the next power of two, never returning less than 8.
// ORing in 4 guarantees the clz argument is nonzero (avoiding undefined
// behavior) and forces the computed exponent to be at least 3, so the
// smallest possible result is 1 << 3 == 8.
inline int round2min8(int num) {
    unsigned int bits_needed = 32 - __builtin_clz((num - 1) | 4);
    return 1 << bits_needed;
}
25 | |
26 | struct Arena; |
27 | template<typename T> |
28 | struct OwnedSlice; |
29 | |
// A lightweight, non-owning view of a contiguous run of T with vector-like
// mutation helpers. The mutating operations (append/extend/insert) allocate
// replacement storage from an Arena rather than owning memory, so Slices are
// cheap to copy and individual buffers are never freed.
// NOTE(review): _insert uses std::memcpy, so these helpers assume T is
// trivially copyable — confirm for all instantiations.
template<typename T>
struct Slice {
    Slice()
    : begin_(nullptr), size_(0), capacity_(0) {}

    // Construct from the concatenation of args (each either a T or a
    // Slice<T>); defined after Arena, at the bottom of this file.
    template<typename... Args>
    Slice(Arena& arena, Args&&... args);

    T* begin() const {
        return begin_;
    }
    T* end() const {
        return begin_ + size_;
    }
    int size() const {
        return size_;
    }
    int capacity() const {
        return capacity_;
    }

    // Element counted from the end: back() is the last element, back(-2)
    // the one before it. No bounds checking.
    T& back(int i=-1) {
        return begin_[size_ + i];
    }

    T& operator[](int i) const {
        return begin_[i];
    }
    // Index of the first element equal to value, or nullopt. Linear scan.
    c10::optional<int> index(const T& value) {
        for (int i : enumerate()) {
            if (begin_[i] == value) {
                return i;
            }
        }
        return c10::nullopt;
    }
    bool contains(const T& value) {
        return index(value).has_value();
    }

    // Replace the subrange `where` with `to_insert`, reallocating from the
    // arena when needed; defined below, after Arena.
    void insert(Arena& arena, Slice where, Slice to_insert);
    void insert(Arena& arena, Slice where, T v) {
        // v is copied out of this one-element stack-backed slice before any
        // reallocation happens, so taking its address here is safe.
        return insert(arena, where, Slice(&v, &v + 1));
    }
    // Insert a single value before index `where`.
    void insert(Arena& arena, int where, T v) {
        return insert(arena, slice(where, where), v);
    }
    void append(Arena& arena, T value);
    void extend(Arena& arena, Slice to_insert);
    void extend(Arena& arena, const T* begin, const T* end) {
        return extend(arena, Slice<T>((T*)begin, (T*)end));
    }

    // Remove the first occurrence of value; returns whether one was found.
    bool remove(Arena& A, T value) {
        auto idx = index(value);
        if (idx) {
            // Replacing [idx, idx+1) with an empty slice deletes the element.
            insert(A, slice(*idx, *idx + 1), Slice());
        }
        return idx.has_value();
    }

    // Subview from `begin` to the end.
    Slice slice(int begin) {
        return slice(begin, size_);
    }

    // Python-style subview: negative begin/end count from the end. The
    // result aliases this slice's storage (capacity == its size).
    Slice slice(int begin, int end) {
        if (begin < 0) {
            begin += size_;
        }
        if (end < 0) {
            end += size_;
        }
        Slice result;
        result.begin_ = begin_ + begin;
        result.size_ = end - begin;
        result.capacity_ = result.size_;
        return result;
    }

    // True when `where` is a subview of this slice's storage.
    bool inside(Slice where) {
        return begin() <= where.begin() && where.end() <= end();
    }

    // Range of indices 0..size-1 (via irange).
    irange enumerate() const {
        return irange(size_);
    }

    // Range of indices size-1..0 (via irange, step -1).
    irange reversed_enumerate() const {
        return irange(size_ - 1, -1, -1);
    }

    // Elementwise equality (same length and equal elements).
    bool operator==(const Slice<T>& rhs) const {
        if (size() != rhs.size()) {
            return false;
        }
        return std::equal(begin(), end(), rhs.begin());
    }

    // View over an existing [begin, end) range; does not take ownership.
    Slice(T* begin, T* end)
    : begin_(begin), size_(end - begin), capacity_(size_) {}

protected:
    // Length contributed by one variadic-constructor argument:
    // a single T counts as 1, a Slice as its size.
    static int _length(const T& t) {
        return 1;
    }
    static int _length(Slice t) {
        return t.size_;
    }
    // Write one variadic-constructor argument at dst, advancing dst.
    static T* _insert(T*& dst, T t) {
        *dst = std::move(t);
        return ++dst;
    }
    static T* _insert(T*& dst, Slice t) {
        std::memcpy(dst, t.begin_, sizeof(T)*t.size_);
        dst += t.size_;
        return dst;
    }
    T* begin_;
    int size_;
    int capacity_;
    friend struct OwnedSlice<T>;
};
152 | |
153 | template<typename T> |
154 | struct OwnedSlice { |
155 | typedef void (*deleter_t)(Slice<T>); |
156 | static void _no_delete(Slice<T>) {} |
157 | OwnedSlice() |
158 | : deleter_(_no_delete) {} |
159 | OwnedSlice(const OwnedSlice&) = delete; |
160 | OwnedSlice& operator=(const OwnedSlice&) = delete; |
161 | ~OwnedSlice() { |
162 | deleter_(slice_); |
163 | if (slice_.size_ > 8) { |
164 | delete [] slice_.begin_; |
165 | } |
166 | } |
167 | void set(Slice<T> to_own, deleter_t deleter = _no_delete) { |
168 | slice_.size_ = slice_.capacity_ = to_own.size(); |
169 | slice_.begin_ = (slice_.size_ > 8) ? new T[slice_.size_] : &small_buf[0]; |
170 | std::memcpy(slice_.begin_, to_own.begin(), slice_.size_ * sizeof(T)); |
171 | deleter_ = deleter; |
172 | } |
173 | Slice<T> slice() const { |
174 | return slice_; |
175 | } |
176 | private: |
177 | Slice<T> slice_; |
178 | deleter_t deleter_; |
179 | T small_buf[8]; |
180 | }; |
181 | |
182 | template<typename T> |
183 | inline std::ostream& operator<<(std::ostream& s, const Slice<T>& v) { |
184 | s << "[" ; |
185 | for (int i : v.enumerate()) { |
186 | if (i > 0) { |
187 | s << ", " ; |
188 | } |
189 | s << v[i]; |
190 | } |
191 | s << "]" ; |
192 | return s; |
193 | } |
194 | |
// A non-owning reference to a tensor, stored as a bare TensorImpl*.
// operator* / operator-> reinterpret `this` as an at::Tensor, which relies
// on at::Tensor consisting of a single TensorImpl pointer with identical
// layout. NOTE(review): layout assumption — confirm against the ATen
// version in use.
struct TensorRef {
    TensorRef()
    : impl_(nullptr){}
    // Borrows t's impl without bumping the refcount; the caller must keep
    // the tensor alive (or transfer ownership, as Arena::autorelease does).
    TensorRef(const at::Tensor& t)
    : impl_(t.unsafeGetTensorImpl()) {}
    const at::Tensor& operator*() const {
        return *(at::Tensor*)this;
    }
    at::Tensor* operator->() const {
        return (at::Tensor*)this;
    }
    // True when a tensor is referenced.
    operator bool() const {
        return impl_ != nullptr;
    }
private:
    at::TensorImpl* impl_;
};
212 | |
// Bytes in Arena's inline buffer; allocations beyond this spill to the heap.
constexpr int ARENA_MAX_SIZE = 4096;
// Alignment (bytes) applied when bumping the arena offset.
constexpr int ALIGNMENT = 8;
// Fast scope-local bump allocator. The first ARENA_MAX_SIZE bytes come from
// an inline buffer; once that is exhausted, each allocation falls back to
// its own heap block. Nothing is freed until the Arena is destroyed, so
// allocations must not outlive it.
struct Arena {
    Arena()
    : allocated_(0) {}
    // Allocate uninitialized storage for n objects of type T; returns
    // nullptr when n == 0. Inline-buffer results are only ALIGNMENT
    // (8-byte) aligned; NOTE(review): types requiring stronger alignment
    // would be misaligned — confirm callers only store small PODs/pointers.
    template<typename T>
    T* allocate(int n) {
        if (!n) {
            return nullptr;
        }
        int to_allocate = sizeof(T)*n;
        // Round the bump amount up so the next allocation stays aligned.
        int to_allocate_rounded = ALIGNMENT * ((to_allocate - 1) / ALIGNMENT + 1);
        auto prev_allocated = allocated_;
        allocated_ += to_allocate_rounded;
        // allocated_ keeps growing past the cap, so once the inline buffer
        // overflows, every later allocation also takes the heap path.
        if (C10_UNLIKELY_OR_CONST(allocated_ > ARENA_MAX_SIZE)) {
            overflow_.emplace_back(new char[to_allocate]);
            return (T*) &overflow_.back()[0];
        }
        return (T*) (buffer_ + prev_allocated);
    }
    // Take over s's TensorImpl refcount (released back in ~Arena) and hand
    // out a borrowed TensorRef that stays valid for the Arena's lifetime.
    TensorRef autorelease(at::Tensor s) {
        auto ref = TensorRef(s);
        s.unsafeReleaseTensorImpl();
        ar_tensors_.append(*this, ref);
        return ref;
    }
    // Same idea for Python objects: keep obj alive until the Arena dies and
    // return a borrowed handle to it.
    py::handle autorelease(py::object obj) {
        ar_objects_.append(*this, obj);
        obj.release();
        return ar_objects_.back();
    }
    ~Arena() {
        // Re-wrap each raw TensorImpl* in an intrusive_ptr so its refcount
        // is decremented, reversing unsafeReleaseTensorImpl above.
        for(TensorRef t: ar_tensors_) {
            c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(t->unsafeGetTensorImpl());
        }
        // Re-own each handle so its Python refcount is dropped.
        for(py::handle h: ar_objects_) {
            py::object::steal(h);
        }
    }
private:
    int64_t allocated_;            // bump offset into buffer_; only grows
    char buffer_[ARENA_MAX_SIZE];  // inline storage for the fast path
    Slice<TensorRef> ar_tensors_;  // tensors released in ~Arena
    Slice<py::handle> ar_objects_; // Python objects released in ~Arena
    std::vector<std::unique_ptr<char[]>> overflow_;  // heap fallbacks
};
259 | |
// Replace the subrange `where` (which must lie inside this slice) with the
// elements of `to_insert`, growing via the arena when the result exceeds
// capacity. Handles pure insertion (where empty), deletion (to_insert
// empty), and same-size overwrite uniformly.
template<typename T>
inline void Slice<T>::insert(Arena& arena, Slice where, Slice to_insert) {
    AT_ASSERT(inside(where));
    Slice result = *this;
    /// b------sb---se-----e, 0----n
    T* body_dest = where.begin();
    if (where.size() != to_insert.size()) {
        int new_size = size() - where.size() + to_insert.size();
        // Where the elements currently after `where` must end up.
        T* tail_dest = where.begin() + to_insert.size();
        if (new_size >= capacity_) {
            // Grow: fresh arena storage, then copy the prefix. The old
            // buffer is arena memory and stays valid, so the tail/body
            // copies below can still read from it.
            int new_capacity = new_size ? round2min8(new_size) : 0;
            result.capacity_ = new_capacity;
            result.begin_ = arena.allocate<T>(new_capacity);
            body_dest = result.begin_ + (where.begin() - begin());
            tail_dest = body_dest + to_insert.size();
            //std::memcpy(result.begin_, begin_, sizeof(T)*(where.begin() - begin()));
            std::copy(begin_, begin_ + (where.begin() - begin()), result.begin_);
        }
        // memmove: source and destination overlap when editing in place.
        std::memmove(tail_dest, where.end(), sizeof(T)*(end() - where.end()));
        result.size_ = new_size;
    }

    //std::memcpy(body_dest, to_insert.begin(), sizeof(T)*to_insert.size());
    std::copy(to_insert.begin(), to_insert.end(), body_dest);
    *this = result;
}
286 | |
287 | template<typename T> |
288 | inline void Slice<T>::append(Arena& arena, T value) { |
289 | Slice result = *this; |
290 | if (size_ == capacity_) { |
291 | int new_size = size_ ? round2min8(size_)*2 : 8; |
292 | T* n = arena.allocate<T>(new_size); |
293 | //memcpy(n, begin_, size_*sizeof(T)); |
294 | std::copy(begin_, begin_ + size_, n); |
295 | result.begin_ = n; |
296 | result.capacity_ = new_size; |
297 | } |
298 | result[result.size_++] = std::move(value); |
299 | *this = result; |
300 | } |
301 | |
302 | template<typename T> |
303 | inline void Slice<T>::extend(Arena& arena, Slice<T> rhs) { |
304 | Slice result = *this; |
305 | result.size_ = size_ + rhs.size(); |
306 | if (result.size_ > capacity_) { |
307 | int new_size = round2min8(result.size_); |
308 | T* n = arena.allocate<T>(new_size); |
309 | //memcpy(n, begin_, size_*sizeof(T)); |
310 | std::copy(begin_, begin_+size_, n); |
311 | result.begin_ = n; |
312 | result.capacity_ = new_size; |
313 | } |
314 | //memcpy(result.begin_ + size_, rhs.begin(), sizeof(T)*rhs.size()); |
315 | std::copy(rhs.begin(), rhs.end(), result.begin_ + size_); |
316 | *this = result; |
317 | } |
318 | |
// Build a Slice holding the concatenation of args, where each argument is
// either a single T (contributing one element) or a Slice<T> (contributing
// its elements). Storage comes from the arena, sized to the next power of
// two (minimum 8) or 0 when empty.
template<typename T>
template<typename... Args>
Slice<T>::Slice(Arena& arena, Args&&... args) {
    // Sum per-argument lengths via the _length overloads.
    int lens[] = {_length(args)...};
    size_ = 0;
    for (auto i : lens) {
        size_ += i;
    }
    capacity_ = size_ ? round2min8(size_) : 0;
    begin_ = arena.allocate<T>(capacity_);
    T* dst_ = begin_;
    // Copy each argument into place; the braced array forces the pack to be
    // expanded (and dst_ advanced) left to right.
    T* unused[] = {_insert(dst_, args)...};
    (void) unused;
}
333 | |