1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/analysis/usmp/algo/greedy.cc
22 * \brief This source contains greedy algorithms for planning
23 * memory for USMP. There are two algorithms present here :
24 * 1) greedy_by_size and 2) greedy_by_conflicts.
25 *
26 * greedy_by_size : this algorithm prioritizes placing the
27 * largest size buffer to the given pools. The BufferInfo objects
28 * are sorted based on the size and placed on each pool adhering
29 * to size_hint constraint.
30 *
 * greedy_by_conflicts : this algorithm prioritizes placing the
 * buffer with the most liveness conflicts into the given pools. The
33 * BufferInfo objects are sorted based on the number of conflicts
34 * and placed on each pool adhering to size_hint constraint.
35 */
36
37#include <tvm/arith/analyzer.h>
38#include <tvm/runtime/device_api.h>
39#include <tvm/tir/builtin.h>
40#include <tvm/tir/function.h>
41#include <tvm/tir/stmt_functor.h>
42#include <tvm/tir/usmp/algo/greedy.h>
43#include <tvm/tir/usmp/algorithms.h>
44#include <tvm/tir/usmp/utils.h>
45
46namespace tvm {
47namespace tir {
48namespace usmp {
49namespace algo {
50
51/*!
52 * \brief Rounds up the offset to satisfy the alignement requirement
53 */
54size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
55 const int& byte_alignment) {
56 return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
57}
58
59/*!
60 * \brief A helper function check whether a offset is valid given the constraints
61 */
62bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
63 const size_t& size_bytes) {
64 Integer size_hint_bytes = -1;
65 if (const auto* p = candidate_pool.as<WorkspacePoolInfoNode>()) {
66 size_hint_bytes = p->size_hint_bytes;
67 } else if (const auto* p = candidate_pool.as<ConstantPoolInfoNode>()) {
68 size_hint_bytes = p->size_hint_bytes;
69 } else {
70 LOG(FATAL) << "Pool '" << candidate_pool->GetTypeKey() << "' is not supported";
71 }
72
73 if (size_hint_bytes == kUnrestrictedPoolSizeHint) {
74 // this means pool is not bounded
75 return true;
76 }
77 auto pool_size = static_cast<size_t>(size_hint_bytes.IntValue());
78 auto max_address = next_offset + size_bytes;
79 if (max_address <= pool_size) {
80 return true;
81 }
82 return false;
83}
84
85/*!
86 * \brief Selects a pool for placement in the given set of ordered pool candidates
87 */
88PoolInfo GreedyBase::SelectPlacementPool(
89 const BufferInfo& buf_info,
90 const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
91 // Here the pool candidates are ordered when it is consumed by the algorithm.
92 // This could be from order the user has specified. However, schedulers are
93 // welcome to change the order for performance reasons.
94 for (const auto& pool_info : buf_info->pool_candidates) {
95 if (pool_offsets.count(pool_info)) {
96 return pool_info;
97 }
98 }
99 CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when "
100 "trying to allocate the buffer : "
101 << buf_info << "\n. Please increase the size_hints for memory pools.";
102 return PoolInfo();
103}
104
105/*!
106 * \brief This is the base allocation function that works on sorted BufferInfo objects based
107 * on the greedy heuristic. The sorting algorithm has to be called before calling this.
108 */
109Map<BufferInfo, PoolAllocation> GreedyBase::PostSortAllocation(
110 const std::vector<BufferInfo>& buffer_info_vec) {
111 Map<BufferInfo, PoolAllocation> pool_allocations;
112 for (const auto& buf_info : buffer_info_vec) {
113 std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
114 for (const auto& pool_info : buf_info->pool_candidates) {
115 // Mark pool candidates that satisfy the size constraints.
116 if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
117 pool_offset_candidates[pool_info] = 0;
118 }
119 }
120
121 for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
122 auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
123 size_t next_offset = 0;
124 // We only look at already allocated BufferInfo in-terms of conflicts.
125 if (pool_allocations.count(conflict_buf_info)) {
126 auto pool_allocation = pool_allocations[conflict_buf_info];
127 next_offset =
128 pool_allocation->byte_offset.IntValue() + conflict_buf_info->size_bytes.IntValue();
129 next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
130 // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid.
131 if (IsValidPlacement(pool_allocation->pool_info, next_offset,
132 buf_info->size_bytes->value)) {
133 // There could be multiple conflicting BufferInfo in the same pool.
134 // Thus, we need to make sure we pick the largest offset of them all.
135 if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) {
136 pool_offset_candidates[pool_allocation->pool_info] = next_offset;
137 }
138 } else {
139 pool_offset_candidates.erase(pool_allocation->pool_info);
140 }
141 }
142 }
143 auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates);
144 pool_allocations.Set(
145 buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
146 }
147 return pool_allocations;
148}
149
150/*!
151 * \brief This class implements Greedy by the size of BufferInfo
152 * greedy algorithm. Please refer to main documentation of the file
153 * for more details.
154 */
155class GreedySize : public GreedyBase {
156 public:
157 GreedySize() {}
158 Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) {
159 std::vector<BufferInfo> buffer_info_vec;
160 Map<BufferInfo, PoolAllocation> pool_allocations;
161 for (const auto& buffer_info : buffer_info_arr) {
162 buffer_info_vec.push_back(std::move(buffer_info));
163 }
164 std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
165 [](const BufferInfo& a, const BufferInfo& b) {
166 if (a->size_bytes->value == b->size_bytes->value) {
167 if (a->conflicts.size() == b->conflicts.size()) {
168 return std::string(a->name_hint->data) > std::string(b->name_hint->data);
169 } else {
170 return a->conflicts.size() > b->conflicts.size();
171 }
172 }
173 return a->size_bytes.IntValue() > b->size_bytes.IntValue();
174 });
175 return PostSortAllocation(buffer_info_vec);
176 }
177};
178
179/*!
180 * \brief This class implements Greedy by the number of conflicts of
181 * BufferInfo greedy algorithm. Please refer to main documentation
182 * of the file for more details.
183 */
184class GreedyConflicts : public GreedyBase {
185 public:
186 GreedyConflicts() {}
187 Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) {
188 std::vector<BufferInfo> buffer_info_vec;
189 Map<BufferInfo, PoolAllocation> pool_allocations;
190 for (const auto& buffer_info : buffer_info_arr) {
191 buffer_info_vec.push_back(std::move(buffer_info));
192 }
193 std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
194 [](const BufferInfo& a, const BufferInfo& b) {
195 if (a->conflicts.size() == b->conflicts.size()) {
196 if (a->size_bytes->value == b->size_bytes->value) {
197 return std::string(a->name_hint->data) > std::string(b->name_hint->data);
198 } else {
199 return a->size_bytes->value > b->size_bytes->value;
200 }
201 }
202 return a->conflicts.size() > b->conflicts.size();
203 });
204 return PostSortAllocation(buffer_info_vec);
205 }
206};
207
208Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr,
209 const Integer& memory_pressure) {
210 return GreedySize().PlanMemory(buffer_info_arr);
211}
212
213Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr,
214 const Integer& memory_pressure) {
215 return GreedyConflicts().PlanMemory(buffer_info_arr);
216}
217
218TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
219 .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
220 return GreedyBySize(buffer_info_arr, memory_pressure);
221 });
222
223TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts")
224 .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
225 return GreedyByConflicts(buffer_info_arr, memory_pressure);
226 });
227
228} // namespace algo
229} // namespace usmp
230} // namespace tir
231} // namespace tvm
232