1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file thread_storage_scope.h
 * \brief Storage and thread scope definitions, and extraction of the launch
 *  parameter configuration from TVMArgs.
 */
24 | #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ |
25 | #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ |
26 | |
27 | #include <tvm/runtime/metadata.h> |
28 | #include <tvm/runtime/packed_func.h> |
29 | |
30 | #include <string> |
31 | #include <vector> |
32 | |
33 | #include "meta_data.h" |
34 | |
35 | namespace tvm { |
36 | namespace runtime { |
37 | |
/*!
 * \brief Memory hierarchy rank in the storage system.
 * \note The global rank and shared rank have one to one
 *  correspondence to the thread rank.
 */
enum class StorageRank {
  kGlobal = 0,  ///< global memory
  kShared = 1,  ///< shared memory among thread group
  /*!
   * \brief Reserved for warp memory.
   *  This is only used by the programming model; there is usually no such
   *  memory in a GPU. Instead, it can be simulated with registers and shuffles.
   */
  kWarp = 2,
  kLocal = 3,             ///< thread local memory
  kWMMAMatrixA = 4,       ///< wmma scope memory of matrix_a
  kWMMAMatrixB = 5,       ///< wmma scope memory of matrix_b
  kWMMAAccumulator = 6,   ///< wmma scope memory of accumulator
  kTexture = 7,           ///< global scope texture memory
  kAMXTMM = 8,            ///< global scope amx tmm memory
};
68 | |
69 | /*! |
70 | * \param thread_scope_rank The thread scope rank |
71 | * \return default storage rank given the thread scope |
72 | */ |
73 | inline StorageRank DefaultStorageRank(int thread_scope_rank) { |
74 | switch (thread_scope_rank) { |
75 | case -1: |
76 | return StorageRank::kGlobal; |
77 | case 0: |
78 | return StorageRank::kShared; |
79 | case 1: |
80 | return StorageRank::kLocal; |
81 | default: { |
82 | LOG(FATAL) << "unknown rank" ; |
83 | } |
84 | } |
85 | } |
86 | |
87 | /*! \brief class to represent storage scope */ |
88 | struct StorageScope { |
89 | /*! \brief The rank of the storage */ |
90 | StorageRank rank{StorageRank::kGlobal}; |
91 | /*! \brief tag for special purpose memory. */ |
92 | std::string tag; |
93 | // comparator |
94 | inline bool operator==(const StorageScope& other) const { |
95 | return rank == other.rank && tag == other.tag; |
96 | } |
97 | inline bool operator!=(const StorageScope& other) const { return !(*this == other); } |
98 | inline std::string to_string() const { |
99 | std::string ret; |
100 | switch (rank) { |
101 | case StorageRank::kGlobal: |
102 | return "global" + tag; |
103 | case StorageRank::kShared: |
104 | return "shared" + tag; |
105 | case StorageRank::kWarp: |
106 | return "warp" + tag; |
107 | case StorageRank::kLocal: |
108 | return "local" + tag; |
109 | case StorageRank::kWMMAMatrixA: |
110 | return "wmma.matrix_a" + tag; |
111 | case StorageRank::kWMMAMatrixB: |
112 | return "wmma.matrix_b" + tag; |
113 | case StorageRank::kWMMAAccumulator: |
114 | return "wmma.accumulator" + tag; |
115 | case StorageRank::kTexture: |
116 | return "texture" + tag; |
117 | default: |
118 | LOG(FATAL) << "unknown storage scope" ; |
119 | } |
120 | } |
121 | /*! |
122 | * \brief Create storage scope from string |
123 | * \param s The string to be parsed. |
124 | * \return The storage scope. |
125 | */ |
126 | static StorageScope Create(const std::string& s) { |
127 | StorageScope r; |
128 | if (s.empty()) { |
129 | r.rank = StorageRank::kGlobal; |
130 | } else if (s.compare(0, 6, "global" ) == 0) { |
131 | r.rank = StorageRank::kGlobal; |
132 | r.tag = s.substr(6, std::string::npos); |
133 | } else if (s.compare(0, 6, "shared" ) == 0) { |
134 | r.rank = StorageRank::kShared; |
135 | r.tag = s.substr(6, std::string::npos); |
136 | } else if (s.compare(0, 4, "warp" ) == 0) { |
137 | r.rank = StorageRank::kWarp; |
138 | r.tag = s.substr(4, std::string::npos); |
139 | } else if (s.compare(0, 5, "local" ) == 0) { |
140 | r.rank = StorageRank::kLocal; |
141 | r.tag = s.substr(5, std::string::npos); |
142 | } else if (s.compare(0, 13, "wmma.matrix_a" ) == 0) { |
143 | r.rank = StorageRank::kWMMAMatrixA; |
144 | r.tag = s.substr(13, std::string::npos); |
145 | } else if (s.compare(0, 13, "wmma.matrix_b" ) == 0) { |
146 | r.rank = StorageRank::kWMMAMatrixB; |
147 | r.tag = s.substr(13, std::string::npos); |
148 | } else if (s.compare(0, 16, "wmma.accumulator" ) == 0) { |
149 | r.rank = StorageRank::kWMMAAccumulator; |
150 | r.tag = s.substr(16, std::string::npos); |
151 | } else if (s.compare(0, 7, "texture" ) == 0) { |
152 | r.rank = StorageRank::kTexture; |
153 | r.tag = s.substr(7, std::string::npos); |
154 | } else if (s.compare(0, 7, "amx.tmm" ) == 0) { |
155 | r.rank = StorageRank::kAMXTMM; |
156 | r.tag = s.substr(7, std::string::npos); |
157 | } else { |
158 | LOG(FATAL) << "unknown storage scope " << s; |
159 | } |
160 | return r; |
161 | } |
162 | }; |
163 | |
164 | /*! \brief class to represent thread scope */ |
165 | struct ThreadScope { |
166 | /*! \brief The rank of thread scope */ |
167 | int rank{0}; |
168 | /*! \brief the dimension index under the rank */ |
169 | int dim_index{0}; |
170 | /*! |
171 | * \brief Create storage scope from string |
172 | * \param s The string to be parsed. |
173 | * \return The storage scope. |
174 | */ |
175 | static ThreadScope Create(const std::string& s) { |
176 | ThreadScope r; |
177 | if (s.compare(0, 7, "vthread" ) == 0 || s == "cthread" ) { |
178 | // virtual thread at the same level as local |
179 | r.rank = 1; |
180 | r.dim_index = -1; |
181 | } else if (s.compare(0, 9, "blockIdx." ) == 0) { |
182 | r.rank = 0; |
183 | r.dim_index = static_cast<int>(s[9] - 'x'); |
184 | } else if (s.compare(0, 10, "threadIdx." ) == 0) { |
185 | r.rank = 1; |
186 | r.dim_index = static_cast<int>(s[10] - 'x'); |
187 | } else { |
188 | LOG(FATAL) << "Unknown threadscope " << s; |
189 | } |
190 | return r; |
191 | } |
192 | }; |
193 | |
/*! \brief workload specification */
struct ThreadWorkLoad {
  /*!
   * \brief Launch extents. Entries 0-2 are the grid dimensions (blockIdx.x/y/z);
   *  entries 3-5 are the block dimensions (threadIdx.x/y/z).
   *  Value-initialized so a default-constructed workload never holds indeterminate values.
   */
  size_t work_size[6]{};
  /*! \brief Dynamic shared memory allocation size in bytes. */
  size_t dyn_shmem_size{0};
  /*!
   * \param i The block dimension.
   * \return i-th block dim
   */
  inline size_t block_dim(size_t i) const { return work_size[i + 3]; }
  /*!
   * \param i The grid dimension.
   * \return i-th grid dim
   */
  inline size_t grid_dim(size_t i) const { return work_size[i]; }
};
211 | /*! \brief Launch parameters configuration */ |
212 | class LaunchParamConfig { |
213 | public: |
214 | void Init(size_t base, const std::vector<std::string>& launch_param_tags) { |
215 | base_ = base; |
216 | std::vector<bool> filled(6, false); |
217 | for (size_t i = 0; i < launch_param_tags.size(); ++i) { |
218 | const std::string& tag = launch_param_tags[i]; |
219 | if (tag == launch_param::kUseDynamicSharedMemoryTag) { |
220 | ICHECK_EQ(i, launch_param_tags.size() - 1) |
221 | << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags." ; |
222 | use_dyn_shared_memory_ = true; |
223 | } else { |
224 | ThreadScope ts = ThreadScope::Create(tag); |
225 | arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); |
226 | filled[ts.rank * 3 + ts.dim_index] = true; |
227 | } |
228 | } |
229 | work_dim_ = 1; |
230 | for (int i = 0; i < 3; ++i) { |
231 | if (filled[i] || filled[i + 3]) { |
232 | work_dim_ = i + 1; |
233 | } |
234 | } |
235 | } |
236 | // extract workload from arguments. |
237 | ThreadWorkLoad (TVMArgs x) const { |
238 | ThreadWorkLoad w; |
239 | std::fill(w.work_size, w.work_size + 6, 1); |
240 | for (size_t i = 0; i < arg_index_map_.size(); ++i) { |
241 | // Dynamic shapes can result in 0 dim size. Guard to ensure that the dim size is at least 1. |
242 | size_t size = static_cast<size_t>(x.values[base_ + i].v_int64); |
243 | if (size > 0) { |
244 | w.work_size[arg_index_map_[i]] = size; |
245 | } |
246 | } |
247 | if (use_dyn_shared_memory_) { |
248 | w.dyn_shmem_size = static_cast<size_t>(x.values[base_ + arg_index_map_.size()].v_int64); |
249 | } |
250 | return w; |
251 | } |
252 | // return the work dim |
253 | size_t work_dim() const { return work_dim_; } |
254 | |
255 | private: |
256 | /*! \brief base axis */ |
257 | size_t base_; |
258 | /*! \brief The worker dimension */ |
259 | size_t work_dim_; |
260 | /*! \brief The index mapping. */ |
261 | std::vector<uint32_t> arg_index_map_; |
262 | /*! \brief Whether or not use dynamic shared memory. */ |
263 | bool use_dyn_shared_memory_{false}; |
264 | }; |
265 | |
266 | } // namespace runtime |
267 | } // namespace tvm |
268 | |
269 | namespace std { |
270 | template <> |
271 | struct hash<::tvm::runtime::StorageScope> { |
272 | std::size_t operator()(const ::tvm::runtime::StorageScope& k) const { |
273 | return static_cast<size_t>(k.rank); |
274 | } |
275 | }; |
276 | } // namespace std |
277 | #endif // TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ |
278 | |