/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file thread_storage_scope.h
 * \brief Extract launch parameters configuration from TVMArgs.
 */
#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_

#include <tvm/runtime/metadata.h>
#include <tvm/runtime/packed_func.h>

#include <string>
#include <vector>

#include "meta_data.h"

namespace tvm {
namespace runtime {

/*!
 * \brief Memory hierarchy rank in the storage system.
 * \note The global rank and the shared rank have a one-to-one
 *       correspondence to the thread rank.
 */
enum class StorageRank {
  /*! \brief global memory */
  kGlobal = 0,
  /*! \brief shared memory among thread group */
  kShared = 1,
  /*!
   * \brief reserved for warp memory.
   *  This is only used by the programming model.
   *  GPUs usually have no such physical memory;
   *  instead, it can be simulated with registers and shuffles.
   */
  kWarp = 2,
  /*! \brief thread local memory */
  kLocal = 3,
  /*! \brief wmma scope memory of matrix_a */
  kWMMAMatrixA = 4,
  /*! \brief wmma scope memory of matrix_b */
  kWMMAMatrixB = 5,
  /*! \brief wmma scope memory of accumulator */
  kWMMAAccumulator = 6,
  /*! \brief global scope texture memory */
  kTexture = 7,
  /*! \brief global scope amx tmm memory */
  kAMXTMM = 8,
};

/*!
 * \brief Get the default storage rank for a given thread scope rank.
 * \param thread_scope_rank The thread scope rank.
 * \return The default storage rank given the thread scope.
 */
inline StorageRank DefaultStorageRank(int thread_scope_rank) {
  switch (thread_scope_rank) {
    case -1:
      return StorageRank::kGlobal;
    case 0:
      return StorageRank::kShared;
    case 1:
      return StorageRank::kLocal;
    default: {
      LOG(FATAL) << "unknown thread scope rank " << thread_scope_rank;
    }
  }
}
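
// Illustrative mapping (see ThreadScope below for how ranks are assigned):
//   DefaultStorageRank(-1);  // -> StorageRank::kGlobal (outside any launch scope)
//   DefaultStorageRank(0);   // -> StorageRank::kShared (blockIdx level)
//   DefaultStorageRank(1);   // -> StorageRank::kLocal  (threadIdx level)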

/*! \brief class to represent storage scope */
struct StorageScope {
  /*! \brief The rank of the storage */
  StorageRank rank{StorageRank::kGlobal};
  /*! \brief tag for special purpose memory. */
  std::string tag;
  // comparator
  inline bool operator==(const StorageScope& other) const {
    return rank == other.rank && tag == other.tag;
  }
  inline bool operator!=(const StorageScope& other) const { return !(*this == other); }
  inline std::string to_string() const {
    switch (rank) {
      case StorageRank::kGlobal:
        return "global" + tag;
      case StorageRank::kShared:
        return "shared" + tag;
      case StorageRank::kWarp:
        return "warp" + tag;
      case StorageRank::kLocal:
        return "local" + tag;
      case StorageRank::kWMMAMatrixA:
        return "wmma.matrix_a" + tag;
      case StorageRank::kWMMAMatrixB:
        return "wmma.matrix_b" + tag;
      case StorageRank::kWMMAAccumulator:
        return "wmma.accumulator" + tag;
      case StorageRank::kTexture:
        return "texture" + tag;
      // Keep to_string consistent with Create below, which parses "amx.tmm".
      case StorageRank::kAMXTMM:
        return "amx.tmm" + tag;
      default:
        LOG(FATAL) << "unknown storage scope";
    }
  }
  /*!
   * \brief Create storage scope from string
   * \param s The string to be parsed.
   * \return The storage scope.
   */
  static StorageScope Create(const std::string& s) {
    StorageScope r;
    if (s.empty()) {
      r.rank = StorageRank::kGlobal;
    } else if (s.compare(0, 6, "global") == 0) {
      r.rank = StorageRank::kGlobal;
      r.tag = s.substr(6, std::string::npos);
    } else if (s.compare(0, 6, "shared") == 0) {
      r.rank = StorageRank::kShared;
      r.tag = s.substr(6, std::string::npos);
    } else if (s.compare(0, 4, "warp") == 0) {
      r.rank = StorageRank::kWarp;
      r.tag = s.substr(4, std::string::npos);
    } else if (s.compare(0, 5, "local") == 0) {
      r.rank = StorageRank::kLocal;
      r.tag = s.substr(5, std::string::npos);
    } else if (s.compare(0, 13, "wmma.matrix_a") == 0) {
      r.rank = StorageRank::kWMMAMatrixA;
      r.tag = s.substr(13, std::string::npos);
    } else if (s.compare(0, 13, "wmma.matrix_b") == 0) {
      r.rank = StorageRank::kWMMAMatrixB;
      r.tag = s.substr(13, std::string::npos);
    } else if (s.compare(0, 16, "wmma.accumulator") == 0) {
      r.rank = StorageRank::kWMMAAccumulator;
      r.tag = s.substr(16, std::string::npos);
    } else if (s.compare(0, 7, "texture") == 0) {
      r.rank = StorageRank::kTexture;
      r.tag = s.substr(7, std::string::npos);
    } else if (s.compare(0, 7, "amx.tmm") == 0) {
      r.rank = StorageRank::kAMXTMM;
      r.tag = s.substr(7, std::string::npos);
    } else {
      LOG(FATAL) << "unknown storage scope " << s;
    }
    return r;
  }
};
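
// Example (illustrative): round-tripping a scope string.
//   StorageScope scope = StorageScope::Create("shared.dyn");
//   // scope.rank == StorageRank::kShared, scope.tag == ".dyn"
//   ICHECK_EQ(scope.to_string(), "shared.dyn");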

/*! \brief class to represent thread scope */
struct ThreadScope {
  /*! \brief The rank of thread scope */
  int rank{0};
  /*! \brief the dimension index under the rank */
  int dim_index{0};
  /*!
   * \brief Create thread scope from string
   * \param s The string to be parsed.
   * \return The thread scope.
   */
  static ThreadScope Create(const std::string& s) {
    ThreadScope r;
    if (s.compare(0, 7, "vthread") == 0 || s == "cthread") {
      // virtual thread at the same level as local
      r.rank = 1;
      r.dim_index = -1;
    } else if (s.compare(0, 9, "blockIdx.") == 0) {
      r.rank = 0;
      r.dim_index = static_cast<int>(s[9] - 'x');
    } else if (s.compare(0, 10, "threadIdx.") == 0) {
      r.rank = 1;
      r.dim_index = static_cast<int>(s[10] - 'x');
    } else {
      LOG(FATAL) << "Unknown thread scope " << s;
    }
    return r;
  }
};
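
// Example (illustrative): parsing launch-parameter tags.
//   ThreadScope ts = ThreadScope::Create("threadIdx.x");
//   // ts.rank == 1 (thread level), ts.dim_index == 0 ('x')
//   ThreadScope bs = ThreadScope::Create("blockIdx.z");
//   // bs.rank == 0 (block level), bs.dim_index == 2 ('z')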

/*! \brief workload specification */
struct ThreadWorkLoad {
  // Work size array: the first three entries are the grid dimensions,
  // the last three are the block dimensions.
  size_t work_size[6];
  // Dynamic shared memory allocation size in bytes.
  size_t dyn_shmem_size{0};
  /*!
   * \param i The block dimension.
   * \return i-th block dim
   */
  inline size_t block_dim(size_t i) const { return work_size[i + 3]; }
  /*!
   * \param i The grid dimension.
   * \return i-th grid dim
   */
  inline size_t grid_dim(size_t i) const { return work_size[i]; }
};
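
// Example (illustrative): a launch of <<<dim3(64, 1, 1), dim3(256, 1, 1)>>> maps to
//   work_size = {64, 1, 1, 256, 1, 1},
// so grid_dim(0) == 64 and block_dim(0) == 256.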

/*! \brief Launch parameters configuration */
class LaunchParamConfig {
 public:
  void Init(size_t base, const std::vector<std::string>& launch_param_tags) {
    base_ = base;
    // filled[i] records whether the i-th entry of work_size is bound by a tag.
    std::vector<bool> filled(6, false);
    for (size_t i = 0; i < launch_param_tags.size(); ++i) {
      const std::string& tag = launch_param_tags[i];
      if (tag == launch_param::kUseDynamicSharedMemoryTag) {
        ICHECK_EQ(i, launch_param_tags.size() - 1)
            << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags.";
        use_dyn_shared_memory_ = true;
      } else {
        ThreadScope ts = ThreadScope::Create(tag);
        arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
        filled[ts.rank * 3 + ts.dim_index] = true;
      }
    }
    // The work dimension is the highest axis bound on either the grid or block side.
    work_dim_ = 1;
    for (int i = 0; i < 3; ++i) {
      if (filled[i] || filled[i + 3]) {
        work_dim_ = i + 1;
      }
    }
  }
  // extract workload from arguments.
  ThreadWorkLoad Extract(TVMArgs x) const {
    ThreadWorkLoad w;
    std::fill(w.work_size, w.work_size + 6, 1);
    for (size_t i = 0; i < arg_index_map_.size(); ++i) {
      // Dynamic shapes can result in 0 dim size. Guard to ensure that the dim size is at least 1.
      size_t size = static_cast<size_t>(x.values[base_ + i].v_int64);
      if (size > 0) {
        w.work_size[arg_index_map_[i]] = size;
      }
    }
    if (use_dyn_shared_memory_) {
      w.dyn_shmem_size = static_cast<size_t>(x.values[base_ + arg_index_map_.size()].v_int64);
    }
    return w;
  }
  // return the work dim
  size_t work_dim() const { return work_dim_; }

 private:
  /*! \brief base axis */
  size_t base_;
  /*! \brief The work dimension */
  size_t work_dim_;
  /*! \brief The index mapping. */
  std::vector<uint32_t> arg_index_map_;
  /*! \brief Whether or not to use dynamic shared memory. */
  bool use_dyn_shared_memory_{false};
};
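
// Example (illustrative, assuming a hypothetical `num_args` and incoming `args`):
//   LaunchParamConfig config;
//   config.Init(num_args, {"blockIdx.x", "threadIdx.x",
//                          launch_param::kUseDynamicSharedMemoryTag});
//   ThreadWorkLoad wl = config.Extract(args);
//   // wl.grid_dim(0), wl.block_dim(0), and wl.dyn_shmem_size are populated from
//   // args.values[num_args], args.values[num_args + 1], args.values[num_args + 2].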

}  // namespace runtime
}  // namespace tvm

namespace std {
template <>
struct hash<::tvm::runtime::StorageScope> {
  // Hash on the rank only; scopes that differ only by tag collide,
  // which is correct but relies on operator== to disambiguate.
  std::size_t operator()(const ::tvm::runtime::StorageScope& k) const {
    return static_cast<size_t>(k.rank);
  }
};
}  // namespace std
#endif  // TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_