1// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// This source code is licensed under the BSD-style license found in the
5// LICENSE file in the root directory of this source tree.
6
7#pragma once
#include <algorithm>
#include <memory>
#include <utility>

#include <ATen/ATen.h>
#include "minpybind.h"
10
11#ifdef _WIN32
12#include <intrin.h>
13// https://stackoverflow.com/questions/355967/how-to-use-msvc-intrinsics-to-get-the-equivalent-of-this-gcc-code
14inline unsigned int __builtin_clz(unsigned int x) {
15 unsigned long r = 0;
16 _BitScanReverse(&r, x);
17 return (31 - r);
18}
19#endif
20
// Rounds `num` up to the nearest power of two, never returning less than 8.
//
// Any `num` of 8 or less - including zero and negative values - yields 8.
// Handling those up front keeps the clz argument non-zero and the shift
// amount strictly below 32.
// NOTE: `num` above 2^30 is not supported; the next power of two does not
// fit in an int.
inline int round2min8(int num) {
    if (num <= 8) {
        return 8;
    }
    // num - 1 >= 8 here, so its highest set bit alone determines the result.
    const int nzeros = __builtin_clz(static_cast<unsigned int>(num - 1));
    return 1 << (32 - nzeros);
}
25
// Forward declarations: Slice's growing operations take an Arena&, and Slice
// names OwnedSlice<T> as a friend. Both types are defined later in this file.
struct Arena;
template<typename T>
struct OwnedSlice;
29
30template<typename T>
31struct Slice {
32 Slice()
33 : begin_(nullptr), size_(0), capacity_(0) {}
34
35 template<typename... Args>
36 Slice(Arena& arena, Args&&... args);
37
38 T* begin() const {
39 return begin_;
40 }
41 T* end() const {
42 return begin_ + size_;
43 }
44 int size() const {
45 return size_;
46 }
47 int capacity() const {
48 return capacity_;
49 }
50
51 T& back(int i=-1) {
52 return begin_[size_ + i];
53 }
54
55 T& operator[](int i) const {
56 return begin_[i];
57 }
58 c10::optional<int> index(const T& value) {
59 for (int i : enumerate()) {
60 if (begin_[i] == value) {
61 return i;
62 }
63 }
64 return c10::nullopt;
65 }
66 bool contains(const T& value) {
67 return index(value).has_value();
68 }
69
70 void insert(Arena& arena, Slice where, Slice to_insert);
71 void insert(Arena& arena, Slice where, T v) {
72 return insert(arena, where, Slice(&v, &v + 1));
73 }
74 void insert(Arena& arena, int where, T v) {
75 return insert(arena, slice(where, where), v);
76 }
77 void append(Arena& arena, T value);
78 void extend(Arena& arena, Slice to_insert);
79 void extend(Arena& arena, const T* begin, const T* end) {
80 return extend(arena, Slice<T>((T*)begin, (T*)end));
81 }
82
83 bool remove(Arena& A, T value) {
84 auto idx = index(value);
85 if (idx) {
86 insert(A, slice(*idx, *idx + 1), Slice());
87 }
88 return idx.has_value();
89 }
90
91 Slice slice(int begin) {
92 return slice(begin, size_);
93 }
94
95 Slice slice(int begin, int end) {
96 if (begin < 0) {
97 begin += size_;
98 }
99 if (end < 0) {
100 end += size_;
101 }
102 Slice result;
103 result.begin_ = begin_ + begin;
104 result.size_ = end - begin;
105 result.capacity_ = result.size_;
106 return result;
107 }
108
109 bool inside(Slice where) {
110 return begin() <= where.begin() && where.end() <= end();
111 }
112
113 irange enumerate() const {
114 return irange(size_);
115 }
116
117 irange reversed_enumerate() const {
118 return irange(size_ - 1, -1, -1);
119 }
120
121 bool operator==(const Slice<T>& rhs) const {
122 if (size() != rhs.size()) {
123 return false;
124 }
125 return std::equal(begin(), end(), rhs.begin());
126 }
127
128 Slice(T* begin, T* end)
129 : begin_(begin), size_(end - begin), capacity_(size_) {}
130
131protected:
132 static int _length(const T& t) {
133 return 1;
134 }
135 static int _length(Slice t) {
136 return t.size_;
137 }
138 static T* _insert(T*& dst, T t) {
139 *dst = std::move(t);
140 return ++dst;
141 }
142 static T* _insert(T*& dst, Slice t) {
143 std::memcpy(dst, t.begin_, sizeof(T)*t.size_);
144 dst += t.size_;
145 return dst;
146 }
147 T* begin_;
148 int size_;
149 int capacity_;
150 friend struct OwnedSlice<T>;
151};
152
153template<typename T>
154struct OwnedSlice {
155 typedef void (*deleter_t)(Slice<T>);
156 static void _no_delete(Slice<T>) {}
157 OwnedSlice()
158 : deleter_(_no_delete) {}
159 OwnedSlice(const OwnedSlice&) = delete;
160 OwnedSlice& operator=(const OwnedSlice&) = delete;
161 ~OwnedSlice() {
162 deleter_(slice_);
163 if (slice_.size_ > 8) {
164 delete [] slice_.begin_;
165 }
166 }
167 void set(Slice<T> to_own, deleter_t deleter = _no_delete) {
168 slice_.size_ = slice_.capacity_ = to_own.size();
169 slice_.begin_ = (slice_.size_ > 8) ? new T[slice_.size_] : &small_buf[0];
170 std::memcpy(slice_.begin_, to_own.begin(), slice_.size_ * sizeof(T));
171 deleter_ = deleter;
172 }
173 Slice<T> slice() const {
174 return slice_;
175 }
176private:
177 Slice<T> slice_;
178 deleter_t deleter_;
179 T small_buf[8];
180};
181
182template<typename T>
183inline std::ostream& operator<<(std::ostream& s, const Slice<T>& v) {
184 s << "[";
185 for (int i : v.enumerate()) {
186 if (i > 0) {
187 s << ", ";
188 }
189 s << v[i];
190 }
191 s << "]";
192 return s;
193}
194
195struct TensorRef {
196 TensorRef()
197 : impl_(nullptr){}
198 TensorRef(const at::Tensor& t)
199 : impl_(t.unsafeGetTensorImpl()) {}
200 const at::Tensor& operator*() const {
201 return *(at::Tensor*)this;
202 }
203 at::Tensor* operator->() const {
204 return (at::Tensor*)this;
205 }
206 operator bool() const {
207 return impl_ != nullptr;
208 }
209private:
210 at::TensorImpl* impl_;
211};
212
// Size in bytes of the fixed buffer embedded in every Arena.
constexpr int ARENA_MAX_SIZE = 4096;
// Arena allocations are rounded up to a multiple of this many bytes.
constexpr int ALIGNMENT = 8;
215struct Arena {
216 Arena()
217 : allocated_(0) {}
218 template<typename T>
219 T* allocate(int n) {
220 if (!n) {
221 return nullptr;
222 }
223 int to_allocate = sizeof(T)*n;
224 int to_allocate_rounded = ALIGNMENT * ((to_allocate - 1) / ALIGNMENT + 1);
225 auto prev_allocated = allocated_;
226 allocated_ += to_allocate_rounded;
227 if (C10_UNLIKELY_OR_CONST(allocated_ > ARENA_MAX_SIZE)) {
228 overflow_.emplace_back(new char[to_allocate]);
229 return (T*) &overflow_.back()[0];
230 }
231 return (T*) (buffer_ + prev_allocated);
232 }
233 TensorRef autorelease(at::Tensor s) {
234 auto ref = TensorRef(s);
235 s.unsafeReleaseTensorImpl();
236 ar_tensors_.append(*this, ref);
237 return ref;
238 }
239 py::handle autorelease(py::object obj) {
240 ar_objects_.append(*this, obj);
241 obj.release();
242 return ar_objects_.back();
243 }
244 ~Arena() {
245 for(TensorRef t: ar_tensors_) {
246 c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(t->unsafeGetTensorImpl());
247 }
248 for(py::handle h: ar_objects_) {
249 py::object::steal(h);
250 }
251 }
252private:
253 int64_t allocated_;
254 char buffer_[ARENA_MAX_SIZE];
255 Slice<TensorRef> ar_tensors_;
256 Slice<py::handle> ar_objects_;
257 std::vector<std::unique_ptr<char[]>> overflow_;
258};
259
260template<typename T>
261inline void Slice<T>::insert(Arena& arena, Slice where, Slice to_insert) {
262 AT_ASSERT(inside(where));
263 Slice result = *this;
264 /// b------sb---se-----e, 0----n
265 T* body_dest = where.begin();
266 if (where.size() != to_insert.size()) {
267 int new_size = size() - where.size() + to_insert.size();
268 T* tail_dest = where.begin() + to_insert.size();
269 if (new_size >= capacity_) {
270 int new_capacity = new_size ? round2min8(new_size) : 0;
271 result.capacity_ = new_capacity;
272 result.begin_ = arena.allocate<T>(new_capacity);
273 body_dest = result.begin_ + (where.begin() - begin());
274 tail_dest = body_dest + to_insert.size();
275 //std::memcpy(result.begin_, begin_, sizeof(T)*(where.begin() - begin()));
276 std::copy(begin_, begin_ + (where.begin() - begin()), result.begin_);
277 }
278 std::memmove(tail_dest, where.end(), sizeof(T)*(end() - where.end()));
279 result.size_ = new_size;
280 }
281
282 //std::memcpy(body_dest, to_insert.begin(), sizeof(T)*to_insert.size());
283 std::copy(to_insert.begin(), to_insert.end(), body_dest);
284 *this = result;
285}
286
287template<typename T>
288inline void Slice<T>::append(Arena& arena, T value) {
289 Slice result = *this;
290 if (size_ == capacity_) {
291 int new_size = size_ ? round2min8(size_)*2 : 8;
292 T* n = arena.allocate<T>(new_size);
293 //memcpy(n, begin_, size_*sizeof(T));
294 std::copy(begin_, begin_ + size_, n);
295 result.begin_ = n;
296 result.capacity_ = new_size;
297 }
298 result[result.size_++] = std::move(value);
299 *this = result;
300}
301
302template<typename T>
303inline void Slice<T>::extend(Arena& arena, Slice<T> rhs) {
304 Slice result = *this;
305 result.size_ = size_ + rhs.size();
306 if (result.size_ > capacity_) {
307 int new_size = round2min8(result.size_);
308 T* n = arena.allocate<T>(new_size);
309 //memcpy(n, begin_, size_*sizeof(T));
310 std::copy(begin_, begin_+size_, n);
311 result.begin_ = n;
312 result.capacity_ = new_size;
313 }
314 //memcpy(result.begin_ + size_, rhs.begin(), sizeof(T)*rhs.size());
315 std::copy(rhs.begin(), rhs.end(), result.begin_ + size_);
316 *this = result;
317}
318
319template<typename T>
320template<typename... Args>
321Slice<T>::Slice(Arena& arena, Args&&... args) {
322 int lens[] = {_length(args)...};
323 size_ = 0;
324 for (auto i : lens) {
325 size_ += i;
326 }
327 capacity_ = size_ ? round2min8(size_) : 0;
328 begin_ = arena.allocate<T>(capacity_);
329 T* dst_ = begin_;
330 T* unused[] = {_insert(dst_, args)...};
331 (void) unused;
332}
333