1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/analysis/usmp/algo/greedy.cc |
22 | * \brief This source contains greedy algorithms for planning |
23 | * memory for USMP. There are two algorithms present here : |
24 | * 1) greedy_by_size and 2) greedy_by_conflicts. |
25 | * |
26 | * greedy_by_size : this algorithm prioritizes placing the |
27 | * largest size buffer to the given pools. The BufferInfo objects |
28 | * are sorted based on the size and placed on each pool adhering |
29 | * to size_hint constraint. |
30 | * |
 * greedy_by_conflicts : this algorithm prioritizes placing the
 * most liveness-conflicted buffer to the given pools. The
33 | * BufferInfo objects are sorted based on the number of conflicts |
34 | * and placed on each pool adhering to size_hint constraint. |
35 | */ |
36 | |
37 | #include <tvm/arith/analyzer.h> |
38 | #include <tvm/runtime/device_api.h> |
39 | #include <tvm/tir/builtin.h> |
40 | #include <tvm/tir/function.h> |
41 | #include <tvm/tir/stmt_functor.h> |
42 | #include <tvm/tir/usmp/algo/greedy.h> |
43 | #include <tvm/tir/usmp/algorithms.h> |
44 | #include <tvm/tir/usmp/utils.h> |
45 | |
46 | namespace tvm { |
47 | namespace tir { |
48 | namespace usmp { |
49 | namespace algo { |
50 | |
51 | /*! |
52 | * \brief Rounds up the offset to satisfy the alignement requirement |
53 | */ |
54 | size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, |
55 | const int& byte_alignment) { |
56 | return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; |
57 | } |
58 | |
59 | /*! |
60 | * \brief A helper function check whether a offset is valid given the constraints |
61 | */ |
62 | bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, |
63 | const size_t& size_bytes) { |
64 | Integer size_hint_bytes = -1; |
65 | if (const auto* p = candidate_pool.as<WorkspacePoolInfoNode>()) { |
66 | size_hint_bytes = p->size_hint_bytes; |
67 | } else if (const auto* p = candidate_pool.as<ConstantPoolInfoNode>()) { |
68 | size_hint_bytes = p->size_hint_bytes; |
69 | } else { |
70 | LOG(FATAL) << "Pool '" << candidate_pool->GetTypeKey() << "' is not supported" ; |
71 | } |
72 | |
73 | if (size_hint_bytes == kUnrestrictedPoolSizeHint) { |
74 | // this means pool is not bounded |
75 | return true; |
76 | } |
77 | auto pool_size = static_cast<size_t>(size_hint_bytes.IntValue()); |
78 | auto max_address = next_offset + size_bytes; |
79 | if (max_address <= pool_size) { |
80 | return true; |
81 | } |
82 | return false; |
83 | } |
84 | |
85 | /*! |
86 | * \brief Selects a pool for placement in the given set of ordered pool candidates |
87 | */ |
88 | PoolInfo GreedyBase::SelectPlacementPool( |
89 | const BufferInfo& buf_info, |
90 | const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) { |
91 | // Here the pool candidates are ordered when it is consumed by the algorithm. |
92 | // This could be from order the user has specified. However, schedulers are |
93 | // welcome to change the order for performance reasons. |
94 | for (const auto& pool_info : buf_info->pool_candidates) { |
95 | if (pool_offsets.count(pool_info)) { |
96 | return pool_info; |
97 | } |
98 | } |
99 | CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when " |
100 | "trying to allocate the buffer : " |
101 | << buf_info << "\n. Please increase the size_hints for memory pools." ; |
102 | return PoolInfo(); |
103 | } |
104 | |
105 | /*! |
106 | * \brief This is the base allocation function that works on sorted BufferInfo objects based |
107 | * on the greedy heuristic. The sorting algorithm has to be called before calling this. |
108 | */ |
109 | Map<BufferInfo, PoolAllocation> GreedyBase::PostSortAllocation( |
110 | const std::vector<BufferInfo>& buffer_info_vec) { |
111 | Map<BufferInfo, PoolAllocation> pool_allocations; |
112 | for (const auto& buf_info : buffer_info_vec) { |
113 | std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates; |
114 | for (const auto& pool_info : buf_info->pool_candidates) { |
115 | // Mark pool candidates that satisfy the size constraints. |
116 | if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { |
117 | pool_offset_candidates[pool_info] = 0; |
118 | } |
119 | } |
120 | |
121 | for (const auto& conflict_buf_info_obj : buf_info->conflicts) { |
122 | auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj); |
123 | size_t next_offset = 0; |
124 | // We only look at already allocated BufferInfo in-terms of conflicts. |
125 | if (pool_allocations.count(conflict_buf_info)) { |
126 | auto pool_allocation = pool_allocations[conflict_buf_info]; |
127 | next_offset = |
128 | pool_allocation->byte_offset.IntValue() + conflict_buf_info->size_bytes.IntValue(); |
129 | next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); |
130 | // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid. |
131 | if (IsValidPlacement(pool_allocation->pool_info, next_offset, |
132 | buf_info->size_bytes->value)) { |
133 | // There could be multiple conflicting BufferInfo in the same pool. |
134 | // Thus, we need to make sure we pick the largest offset of them all. |
135 | if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) { |
136 | pool_offset_candidates[pool_allocation->pool_info] = next_offset; |
137 | } |
138 | } else { |
139 | pool_offset_candidates.erase(pool_allocation->pool_info); |
140 | } |
141 | } |
142 | } |
143 | auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates); |
144 | pool_allocations.Set( |
145 | buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); |
146 | } |
147 | return pool_allocations; |
148 | } |
149 | |
150 | /*! |
151 | * \brief This class implements Greedy by the size of BufferInfo |
152 | * greedy algorithm. Please refer to main documentation of the file |
153 | * for more details. |
154 | */ |
155 | class GreedySize : public GreedyBase { |
156 | public: |
157 | GreedySize() {} |
158 | Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) { |
159 | std::vector<BufferInfo> buffer_info_vec; |
160 | Map<BufferInfo, PoolAllocation> pool_allocations; |
161 | for (const auto& buffer_info : buffer_info_arr) { |
162 | buffer_info_vec.push_back(std::move(buffer_info)); |
163 | } |
164 | std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), |
165 | [](const BufferInfo& a, const BufferInfo& b) { |
166 | if (a->size_bytes->value == b->size_bytes->value) { |
167 | if (a->conflicts.size() == b->conflicts.size()) { |
168 | return std::string(a->name_hint->data) > std::string(b->name_hint->data); |
169 | } else { |
170 | return a->conflicts.size() > b->conflicts.size(); |
171 | } |
172 | } |
173 | return a->size_bytes.IntValue() > b->size_bytes.IntValue(); |
174 | }); |
175 | return PostSortAllocation(buffer_info_vec); |
176 | } |
177 | }; |
178 | |
179 | /*! |
180 | * \brief This class implements Greedy by the number of conflicts of |
181 | * BufferInfo greedy algorithm. Please refer to main documentation |
182 | * of the file for more details. |
183 | */ |
184 | class GreedyConflicts : public GreedyBase { |
185 | public: |
186 | GreedyConflicts() {} |
187 | Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) { |
188 | std::vector<BufferInfo> buffer_info_vec; |
189 | Map<BufferInfo, PoolAllocation> pool_allocations; |
190 | for (const auto& buffer_info : buffer_info_arr) { |
191 | buffer_info_vec.push_back(std::move(buffer_info)); |
192 | } |
193 | std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), |
194 | [](const BufferInfo& a, const BufferInfo& b) { |
195 | if (a->conflicts.size() == b->conflicts.size()) { |
196 | if (a->size_bytes->value == b->size_bytes->value) { |
197 | return std::string(a->name_hint->data) > std::string(b->name_hint->data); |
198 | } else { |
199 | return a->size_bytes->value > b->size_bytes->value; |
200 | } |
201 | } |
202 | return a->conflicts.size() > b->conflicts.size(); |
203 | }); |
204 | return PostSortAllocation(buffer_info_vec); |
205 | } |
206 | }; |
207 | |
208 | Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr, |
209 | const Integer& memory_pressure) { |
210 | return GreedySize().PlanMemory(buffer_info_arr); |
211 | } |
212 | |
213 | Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr, |
214 | const Integer& memory_pressure) { |
215 | return GreedyConflicts().PlanMemory(buffer_info_arr); |
216 | } |
217 | |
218 | TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size" ) |
219 | .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) { |
220 | return GreedyBySize(buffer_info_arr, memory_pressure); |
221 | }); |
222 | |
223 | TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts" ) |
224 | .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) { |
225 | return GreedyByConflicts(buffer_info_arr, memory_pressure); |
226 | }); |
227 | |
228 | } // namespace algo |
229 | } // namespace usmp |
230 | } // namespace tir |
231 | } // namespace tvm |
232 | |