1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file tir/analysis/usmp/algo/hill_climb.cc
 * \brief Implement hill-climb memory planning algorithm
 */
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/usmp/algo/greedy.h>
#include <tvm/tir/usmp/utils.h>

#include <algorithm>
#include <cstdlib>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | namespace usmp { |
39 | namespace algo { |
40 | |
41 | /* |
42 | * Simulated annealing / Hill climb |
43 | * |
44 | * Works by continiously invoking 'greedy-by-size' allocation, |
45 | * assessing the result, and introducing permutations to the allocation |
46 | * order which hopefully will led to more 'compact' memory allocation. |
47 | * Do not forget to use srand for repeatable results |
48 | */ |
49 | class HillClimbAllocator : public GreedyBase { |
50 | private: |
51 | size_t memory_pressure_ = 0; |
52 | |
53 | public: |
54 | explicit HillClimbAllocator(size_t memory_pressure) |
55 | : GreedyBase(), memory_pressure_(memory_pressure) {} |
56 | |
57 | protected: |
58 | using alloc_map_t = std::unordered_map<const BufferInfoNode*, PoolAllocation>; |
59 | |
60 | /* |
61 | * Initial sorting routine |
62 | */ |
63 | template <typename T> |
64 | void sort_vector(std::vector<T>* buffer_info_vec) { |
65 | std::sort(buffer_info_vec->begin(), buffer_info_vec->end(), [](const T& a, const T& b) { |
66 | if (a->size_bytes->value == b->size_bytes->value) { |
67 | if (a->conflicts.size() == b->conflicts.size()) { |
68 | return std::string(a->name_hint->data) > std::string(b->name_hint->data); |
69 | } else { |
70 | return a->conflicts.size() > b->conflicts.size(); |
71 | } |
72 | } |
73 | return a->size_bytes->value > b->size_bytes->value; |
74 | }); |
75 | } |
76 | |
77 | /* |
78 | * HillClimb's version of greedy allocation |
79 | * \param buffer_info_vec - buffers in specific order for allocation |
80 | */ |
81 | alloc_map_t greedy(const std::vector<BufferInfo>& buffer_info_vec, bool* could_not_fit) { |
82 | alloc_map_t pool_allocations(buffer_info_vec.size()); |
83 | for (const auto& buf_info : buffer_info_vec) { |
84 | std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates; |
85 | |
86 | // check whether we can fit the buffer into the empty pool candidate |
87 | for (const auto& pool_info : buf_info->pool_candidates) { |
88 | if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { |
89 | pool_offset_candidates[pool_info] = 0; |
90 | } |
91 | } |
92 | // select conflicting buffers which have already been allocated |
93 | std::vector<const BufferInfoNode*> buf_conf; |
94 | for (const auto& conflict_buf_info_obj : buf_info->conflicts) { |
95 | const BufferInfoNode* conflict_buf_info = conflict_buf_info_obj.as<BufferInfoNode>(); |
96 | if (pool_allocations.end() != pool_allocations.find(conflict_buf_info)) { |
97 | buf_conf.push_back(conflict_buf_info); |
98 | } |
99 | } |
100 | |
101 | // extra sorting for pool offsets |
102 | std::sort(buf_conf.begin(), buf_conf.end(), |
103 | [&pool_allocations](const auto* a, const auto* b) { |
104 | return pool_allocations[a]->byte_offset->value < |
105 | pool_allocations[b]->byte_offset->value; |
106 | }); |
107 | |
108 | for (const auto* conflict_buf_info : buf_conf) { |
109 | size_t next_offset = 0; |
110 | auto pool_allocation = pool_allocations[conflict_buf_info]; |
111 | if (!pool_offset_candidates.count(pool_allocation->pool_info)) { |
112 | continue; |
113 | } |
114 | |
115 | next_offset = |
116 | pool_allocation->byte_offset.IntValue() + conflict_buf_info->size_bytes.IntValue(); |
117 | next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); |
118 | |
119 | if (IsValidPlacement(pool_allocation->pool_info, next_offset, |
120 | buf_info->size_bytes->value)) { |
121 | // extra check whether the previous attempt to fit the buffer is clashing with the current |
122 | // conflict |
123 | if (next_offset > pool_offset_candidates[pool_allocation->pool_info] && |
124 | pool_offset_candidates[pool_allocation->pool_info] + |
125 | static_cast<size_t>(buf_info->size_bytes.IntValue()) > |
126 | static_cast<size_t>(pool_allocation->byte_offset.IntValue())) { |
127 | pool_offset_candidates[pool_allocation->pool_info] = next_offset; |
128 | } |
129 | } else { |
130 | pool_offset_candidates.erase(pool_allocation->pool_info); |
131 | } |
132 | } |
133 | auto selected_pool = NullValue<PoolInfo>(); |
134 | for (const auto& pi : buf_info->pool_candidates) { |
135 | if (pool_offset_candidates.count(pi)) { |
136 | selected_pool = pi; |
137 | break; |
138 | } |
139 | } |
140 | |
141 | if (selected_pool.same_as(NullValue<PoolInfo>())) { |
142 | *could_not_fit = true; |
143 | } |
144 | |
145 | pool_allocations[buf_info.as<BufferInfoNode>()] = |
146 | PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])); |
147 | } |
148 | return pool_allocations; |
149 | } |
150 | |
151 | /* |
152 | * Finds highest allocated memory address for each pool |
153 | */ |
154 | std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> find_highest( |
155 | alloc_map_t* pool_allocations) { |
156 | std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_sizes; |
157 | for (const auto& it : *pool_allocations) { |
158 | const BufferInfoNode* buf = it.first; |
159 | const PoolAllocation& pa = it.second; |
160 | if (pa->pool_info.same_as(NullValue<PoolInfo>())) { |
161 | continue; |
162 | } |
163 | size_t high_sz = pa->byte_offset.IntValue() + buf->size_bytes.IntValue(); |
164 | if (pool_sizes[pa->pool_info] <= high_sz) { |
165 | pool_sizes[pa->pool_info] = high_sz; |
166 | } |
167 | } |
168 | return pool_sizes; |
169 | } |
170 | |
171 | /* |
172 | * Collects lists of first and secind level neigbors for provided buf. |
173 | * First level are the immediate neighbors of the buf and |
174 | * second level are the immediate neighbors of the first level nodes |
175 | */ |
176 | template <typename TPos> |
177 | void collect_neighbor_lists(const BufferInfoNode* buf, |
178 | std::vector<const BufferInfoNode*>* first_level, |
179 | std::vector<const BufferInfoNode*>* second_level, const TPos& _pos) { |
180 | auto buf_pos = _pos(buf); |
181 | for (const auto& c1 : buf->conflicts) { |
182 | const auto* c1_buf = c1.as<BufferInfoNode>(); |
183 | int c1_pos = _pos(c1_buf); |
184 | if (buf_pos > c1_pos) { |
185 | first_level->push_back(c1_buf); |
186 | } |
187 | int c2_pos = -1; |
188 | for (const auto& c2 : c1_buf->conflicts) { |
189 | const auto c2_buf = c2.as<BufferInfoNode>(); |
190 | if (c1_pos > (c2_pos = _pos(c2_buf))) { |
191 | second_level->push_back(c2_buf); |
192 | } |
193 | } |
194 | } |
195 | } |
196 | |
197 | public: |
198 | Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) { |
199 | // rand_r does not exist on Windows platform |
200 | #if defined(__linux__) || defined(__ANDROID__) |
201 | unsigned int _seedp = 0; |
202 | #define rnd_func() rand_r(&_seedp) |
203 | #else |
204 | #define rnd_func() rand() |
205 | #endif |
206 | Map<BufferInfo, PoolAllocation> result; |
207 | if (!buffer_info_arr.size()) { |
208 | return result; |
209 | } |
210 | std::vector<BufferInfo> buffer_info_vec; |
211 | for (const auto& buffer_info : buffer_info_arr) { |
212 | ICHECK(buffer_info->pool_candidates.size()) |
213 | << "Cannot process buffer \"" << buffer_info->name_hint << "\" with no pool candidates" ; |
214 | buffer_info_vec.push_back(std::move(buffer_info)); |
215 | } |
216 | sort_vector<BufferInfo>(&buffer_info_vec); |
217 | |
218 | // populate positional index map |
219 | std::unordered_map<const BufferInfoNode*, int> _pos_map; |
220 | for (size_t index = 0; index < buffer_info_vec.size(); ++index) { |
221 | _pos_map[buffer_info_vec[index].as<BufferInfoNode>()] = index; |
222 | } |
223 | |
224 | size_t total_size = 0; |
225 | int attempts = 0; |
226 | |
227 | int swap_i1 = -1; |
228 | int swap_i2 = -1; |
229 | size_t desired_bytes_ = memory_pressure_; |
230 | constexpr auto _max_attempts = 500; |
231 | alloc_map_t rollback_pool_allocations; |
232 | alloc_map_t result_pool_allocations; |
233 | alloc_map_t pool_allocations; |
234 | |
235 | auto swap_buffers = [&buffer_info_vec, &_pos_map](int i1, int i2) { |
236 | if (i1 == i2) return; |
237 | auto b1 = buffer_info_vec[i1]; |
238 | auto b2 = buffer_info_vec[i2]; |
239 | buffer_info_vec[i1] = b2; |
240 | buffer_info_vec[i2] = b1; |
241 | |
242 | _pos_map[b1.as<BufferInfoNode>()] = i2; |
243 | _pos_map[b2.as<BufferInfoNode>()] = i1; |
244 | }; |
245 | |
246 | auto _pos = [&_pos_map](const auto* e) { |
247 | auto it = _pos_map.find(e); |
248 | if (it != _pos_map.end()) { |
249 | return it->second; |
250 | } |
251 | LOG(FATAL) << "node is not indexed in the _pos_map" ; |
252 | }; |
253 | |
254 | for (; attempts < _max_attempts; ++attempts) { |
255 | rollback_pool_allocations = std::move(pool_allocations); |
256 | bool could_not_fit = false; |
257 | pool_allocations = std::move(greedy(buffer_info_vec, &could_not_fit)); |
258 | |
259 | // estimate result buffers |
260 | std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_sizes = |
261 | find_highest(&pool_allocations); |
262 | if (!pool_sizes.size()) { |
263 | CHECK(false) << "TVM USMP Error: Please increase the size_hints for memory pools." ; |
264 | } |
265 | |
266 | // calculate summary |
267 | size_t total = 0; |
268 | for (const auto& el : pool_sizes) { |
269 | total += el.second; |
270 | } |
271 | // accept/reject result heuristic |
272 | if (!total_size || /* first run */ |
273 | (!could_not_fit && |
274 | (total_size > total || /* always accept if better or with some probability */ |
275 | rnd_func() % 100 < static_cast<int>(50 * (total - total_size) / total / attempts)))) { |
276 | // remember winning combination |
277 | result_pool_allocations = pool_allocations; |
278 | if (!could_not_fit) { |
279 | total_size = total; |
280 | // reached desired size |
281 | if (total_size <= desired_bytes_) { |
282 | break; |
283 | } |
284 | } |
285 | |
286 | } else { |
287 | // rollback |
288 | swap_buffers(swap_i2, swap_i1); |
289 | pool_allocations = std::move(rollback_pool_allocations); |
290 | pool_sizes = find_highest(&pool_allocations); |
291 | } |
292 | |
293 | std::vector<const BufferInfoNode*> max_pool_buf; |
294 | |
295 | for (const auto& it : pool_allocations) { |
296 | const auto* buf = it.first; |
297 | const auto pa = it.second; |
298 | if (pa->pool_info.same_as(NullValue<PoolInfo>())) { |
299 | continue; |
300 | } |
301 | size_t high_sz = pa->byte_offset.IntValue() + buf->size_bytes.IntValue(); |
302 | if (pool_sizes[pa->pool_info] == high_sz) { |
303 | max_pool_buf.push_back(buf); |
304 | } |
305 | } |
306 | if (!max_pool_buf.size()) { |
307 | CHECK(false) << "TVM USMP Error: Please increase the size_hints for memory pools." ; |
308 | } |
309 | sort(max_pool_buf.begin(), max_pool_buf.end(), |
310 | [&_pos](const auto* a, const auto* b) { return _pos(a) < _pos(b); }); |
311 | // pick highest |
312 | const BufferInfoNode* node = max_pool_buf[rnd_func() % max_pool_buf.size()]; |
313 | std::vector<const BufferInfoNode*> first_level; |
314 | std::vector<const BufferInfoNode*> second_level; |
315 | collect_neighbor_lists(node, &first_level, &second_level, _pos); |
316 | sort(first_level.begin(), first_level.end(), |
317 | [&_pos](const auto* a, const auto* b) { return _pos(a) < _pos(b); }); |
318 | sort(second_level.begin(), second_level.end(), |
319 | [&_pos](const auto* a, const auto* b) { return _pos(a) < _pos(b); }); |
320 | |
321 | // retry if no first level neightbors were collected |
322 | if (!first_level.size()) { |
323 | continue; |
324 | } |
325 | |
326 | // pick the buffers |
327 | const BufferInfoNode* swap_buf1 = first_level[rnd_func() % first_level.size()]; |
328 | const BufferInfoNode* swap_buf2 = swap_buf1; |
329 | while (swap_buf2 == swap_buf1) { |
330 | swap_buf2 = second_level.size() && (!first_level.size() || (rnd_func() % 100 > 25)) |
331 | ? second_level[rnd_func() % second_level.size()] |
332 | : first_level[rnd_func() % first_level.size()]; |
333 | |
334 | if (second_level.size() < 2 && first_level.size() < 2) break; |
335 | } |
336 | if (swap_buf1 == swap_buf2) { |
337 | continue; |
338 | } |
339 | |
340 | swap_i1 = _pos(swap_buf1); |
341 | swap_i2 = _pos(swap_buf2); |
342 | // do swap |
343 | swap_buffers(swap_i1, swap_i2); |
344 | } |
345 | |
346 | // return winning combination |
347 | for (auto it : result_pool_allocations) { |
348 | // post-check that everything was fit |
349 | const BufferInfoNode* buf = it.first; |
350 | const PoolAllocation& pa = it.second; |
351 | if (NullValue<PoolInfo>().same_as(pa->pool_info) || |
352 | !IsValidPlacement(pa->pool_info, pa->byte_offset->value, buf->size_bytes->value)) { |
353 | std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> m = {}; |
354 | SelectPlacementPool(GetRef<BufferInfo>(buf), m); |
355 | } |
356 | result.Set(GetRef<BufferInfo>(it.first), it.second); |
357 | } |
358 | return result; |
359 | } |
360 | }; |
361 | |
362 | Map<BufferInfo, PoolAllocation> HillClimb(const Array<BufferInfo>& buffer_info_arr, |
363 | const Integer& memory_pressure) { |
364 | return HillClimbAllocator(memory_pressure.IntValue()).PlanMemory(buffer_info_arr); |
365 | } |
366 | |
367 | TVM_REGISTER_GLOBAL("tir.usmp.algo.hill_climb" ) |
368 | .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) { |
369 | return HillClimb(buffer_info_arr, memory_pressure); |
370 | }); |
371 | |
372 | } // namespace algo |
373 | } // namespace usmp |
374 | } // namespace tir |
375 | } // namespace tvm |
376 | |