1 | #include <parallel_dimension_map.h> |
2 | |
3 | #include <ATen/cuda/CUDAContext.h> |
4 | #include <expr_evaluator.h> |
5 | #include <ir_utils.h> |
6 | #include <iter_visitor.h> |
7 | #include <kernel_expr_evaluator.h> |
8 | #include <lower2device.h> |
9 | |
10 | #include <sstream> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | namespace fuser { |
15 | namespace cuda { |
16 | |
17 | void ParallelDimensionMap::build(Fusion* fusion) { |
18 | // Scan all TVs to build ParallelType maps |
19 | auto all_vals = fusion->usedMathVals(); |
20 | for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) { |
21 | for (auto id : tv->domain()->domain()) { |
22 | registerConstantExtent(id); |
23 | if (!isParallelTypeThread(id->getParallelType())) { |
24 | continue; |
25 | } |
26 | handleParallelDomain(id); |
27 | } |
28 | } |
29 | |
30 | // Populate the dimension map for each parallel type |
31 | for (const auto& kv : concrete_dom_map_) { |
32 | auto pt = kv.first; |
33 | const auto& concrete_dom_set = kv.second; |
34 | TORCH_INTERNAL_ASSERT(!concrete_dom_set.empty()); |
35 | if (concrete_dom_set.size() == 1) { |
36 | populateDimensionMapWithSingleCASet(pt, concrete_dom_set); |
37 | } else { |
38 | populateDimensionMapWithMultipleCASet(pt, concrete_dom_set); |
39 | } |
40 | } |
41 | |
42 | adjustMappingsForWarpPadding(); |
43 | } |
44 | |
45 | void ParallelDimensionMap::registerConstantExtent(IterDomain* id) { |
46 | if (!id->extent()->isConstScalar()) { |
47 | // Nothing to do if not constant |
48 | return; |
49 | } |
50 | |
51 | ExpressionEvaluator ee(id->fusion()); |
52 | auto extent_int = ee.evaluate(id->extent()); |
53 | TORCH_INTERNAL_ASSERT( |
54 | extent_int.has_value(), |
55 | "Extent of " , |
56 | id->toString(), |
57 | " should have been constant, but could not be evaluated at compile time." ); |
58 | |
59 | auto const_extent = extent_int->as<int64_t>(); |
60 | |
61 | // Uses index map |
62 | auto concrete_id = getCAMappedConcreteDomain(id); |
63 | |
64 | auto existing_it = constant_extent_map_.find(id); |
65 | |
66 | // Adds the constant extent to the set for the concrete domain. If |
67 | // multiple constants are found, this concrete domain has multiple |
68 | // distinctive extents, which can happen with broadcast. |
69 | if (existing_it == constant_extent_map_.end()) { |
70 | constant_extent_map_.insert({concrete_id, {const_extent}}); |
71 | } else { |
72 | existing_it->second.insert(const_extent); |
73 | } |
74 | } |
75 | |
// Adds the concrete domain of id to the mapped set for its
// parallel type
78 | void ParallelDimensionMap::handleParallelDomain(IterDomain* id) { |
79 | auto pt = id->getParallelType(); |
80 | TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt)); |
81 | auto concrete_id = getCAMappedConcreteDomain(id); |
82 | |
83 | auto it = concrete_dom_map_.find(pt); |
84 | if (it == concrete_dom_map_.end()) { |
85 | concrete_dom_map_.insert({pt, {concrete_id}}); |
86 | } else { |
87 | it->second.insert(concrete_id); |
88 | } |
89 | } |
90 | |
91 | void ParallelDimensionMap::populateDimensionMapWithSingleCASet( |
92 | ParallelType pt, |
93 | const std::unordered_set<IterDomain*>& dom_set) { |
94 | TORCH_INTERNAL_ASSERT(dom_set.size() == 1); |
95 | |
96 | // pt is used by only one concrete domain |
97 | auto id = *dom_set.begin(); |
98 | auto it = constant_extent_map_.find(id); |
99 | |
100 | if (it != constant_extent_map_.end()) { |
101 | TORCH_INTERNAL_ASSERT( |
102 | it->second.size() == 1, |
103 | "Only one value found mapped to parallel type " , |
104 | stringifyThread(pt), |
105 | " yet its bound to multiple extents." ); |
106 | dim_map_.insert({pt, IrBuilder::create<Int>(*(it->second.begin()))}); |
107 | exact_types_.insert(pt); |
108 | } else { |
109 | // Prefer to use blockDim/gridDim if not constant |
110 | dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); |
111 | exact_types_.insert(pt); |
112 | } |
113 | } |
114 | |
// Determines the dimension of pt when multiple concrete domains are
// mapped to it. pt is exact only if every non-broadcast concrete
// domain can be shown to have the same extent, either via matching
// compile-time constants or via symbolic comparison (equalDim).
void ParallelDimensionMap::populateDimensionMapWithMultipleCASet(
    ParallelType pt,
    const std::unordered_set<IterDomain*>& dom_set) {
  TORCH_INTERNAL_ASSERT(dom_set.size() > 1);

  bool all_equal = true;
  // Use nullptr to signal it's not initialized yet
  Val* known_dimension = nullptr;
  // Use -1 to signal it's not initialized yet
  int64_t known_const = -1;

  // Check all of concrete domains to see if they match all together.
  for (auto concrete_id : dom_set) {
    if (concrete_id->isBroadcast()) {
      // Broadcasted concrete id's don't specify anything about shape
      continue;
    }
    // If this concrete domain has a constant extent, check if it
    // matches with the known constant extent.
    auto it = constant_extent_map_.find(concrete_id);
    if (it != constant_extent_map_.end()) {
      const auto& const_extent_set = it->second;
      // If multiple constants are detected, it's not exact.
      if (const_extent_set.size() > 1) {
        all_equal = false;
        break;
      }
      auto this_const = *(const_extent_set.begin());
      // known_const is initialized to -1
      if (known_const == -1) {
        // First constant encountered; remember it and fall through to
        // the symbolic check below so known_dimension also gets set.
        known_const = this_const;
      } else if (known_const == this_const) {
        // Matched with previously known const. The extent of this
        // domain must be equal to that's previously known. Skip the
        // symbolic check — the constant match already proves equality.
        continue;
      } else {
        // Unmatched. This dom_set extents may not be unique.
        all_equal = false;
        break;
      }
    }

    // At this point, it still remains undetermined whether this id
    // matches with those previously looked at. Constant check failed,
    // but symbolic matching may succeed.
    auto this_dimension = concrete_id->extent();
    if (known_dimension == nullptr) {
      // No previous dimension found yet
      known_dimension = this_dimension;
    } else {
      if (!equalDim(known_dimension, this_dimension)) {
        all_equal = false;
        break;
      }
    }
  }

  // If all_equal is still true, the dimension of this parallel type
  // must be exact.
  if (all_equal) {
    exact_types_.insert(pt);
  }
  // Use the const value, if found, as its dimension
  if (all_equal && known_const != -1) {
    dim_map_.insert({pt, IrBuilder::create<Int>(known_const)});
  } else {
    dim_map_.insert({pt, NamedScalar::getParallelDim(pt)});
  }
}
184 | |
185 | void ParallelDimensionMap::adjustMappingsForWarpPadding() { |
186 | const auto gpu_lower = GpuLower::current(); |
187 | |
188 | // If TIDx is padded to a multiple of the warp size, mark it as |
189 | // non-exact. |
190 | |
191 | auto& warp_info = gpu_lower->getWarpPaddedParallelInfo(); |
192 | // TIDx isn't really padded if there isn't a warp reduction (this could |
193 | // change) |
194 | if (!(warp_info.is_tidx_padded && warp_info.has_warp_reduction)) { |
195 | return; |
196 | } |
197 | |
198 | const auto tidx_pt = ParallelType::TIDx; |
199 | auto warp_size = at::cuda::warp_size(); |
200 | |
201 | // If the dimension of TIDx is actually a multple of the warp size |
202 | // before padding, it can be left as exact |
203 | if (isExact(tidx_pt)) { |
204 | auto tidx_dim = dynamic_cast<Int*>(get(tidx_pt)); |
205 | if (tidx_dim && tidx_dim->isConst()) { |
206 | auto tidx_dim_val = tidx_dim->value().value(); |
207 | if (tidx_dim_val % warp_size == 0) { |
208 | // Dimension of TIDx is a multiple of the warp size |
209 | return; |
210 | } |
211 | } |
212 | // If tidx is strictly defined as blockDim.x then it must be set to a |
213 | // multiple of the warp and can be considered exact |
214 | bool tidx_def_trivial = true; |
215 | for (auto entry : concrete_dom_map_.at(tidx_pt)) { |
216 | if (!entry->isA<NamedScalar>() || |
217 | !entry->as<NamedScalar>()->sameAs( |
218 | NamedScalar::getParallelDim(tidx_pt))) { |
219 | tidx_def_trivial = false; |
220 | } |
221 | } |
222 | if (tidx_def_trivial) { |
223 | return; |
224 | } |
225 | } |
226 | |
227 | // TIDx is padded to a multiple of warp. If it's known to be a |
228 | // single warp, use the constant warp size as the dimension of |
229 | // TIDx. Otherwise, just use blockDim.x. |
230 | if (warp_info.is_tidx_single_warp) { |
231 | dim_map_.at(ParallelType::TIDx) = IrBuilder::create<Int>(warp_size); |
232 | } else { |
233 | dim_map_.at(ParallelType::TIDx) = |
234 | NamedScalar::getParallelDim(ParallelType::TIDx); |
235 | } |
236 | |
237 | // TIDx is no longer exact |
238 | exact_types_.erase(ParallelType::TIDx); |
239 | } |
240 | |
241 | Val* ParallelDimensionMap::get(ParallelType pt) const { |
242 | TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: " , pt); |
243 | auto it = dim_map_.find(pt); |
244 | if (it == dim_map_.end()) { |
245 | return nullptr; |
246 | } else { |
247 | return it->second; |
248 | } |
249 | } |
250 | |
251 | bool ParallelDimensionMap::isExact(ParallelType pt) const { |
252 | return exact_types_.find(pt) != exact_types_.end(); |
253 | } |
254 | |
255 | IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) { |
256 | return GpuLower::current()->caMap()->getConcreteMappedID( |
257 | id, IdMappingMode::EXACT); |
258 | } |
259 | |
260 | // Symbolically compares equality of two KIR vals. Comparison is done |
261 | // conservatively, so returning false does not guarantee non-equality. |
262 | bool ParallelDimensionMap::equalDim(Val* dim1, Val* dim2) { |
263 | TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr); |
264 | |
265 | if (dim1 == dim2) { |
266 | return true; |
267 | } |
268 | |
269 | // When Both are Int, they are same if both have the same constant |
270 | auto dim1_int = dynamic_cast<Int*>(dim1); |
271 | auto dim2_int = dynamic_cast<Int*>(dim2); |
272 | if (dim1_int && dim2_int) { |
273 | if (dim1_int->isConst() && dim2_int->isConst()) { |
274 | return dim1_int->value() == dim2_int->value(); |
275 | } |
276 | } |
277 | |
278 | // When both are NamedScalar, they are same if Both have the same |
279 | // name |
280 | auto dim1_ns = dynamic_cast<NamedScalar*>(dim1); |
281 | auto dim2_ns = dynamic_cast<NamedScalar*>(dim2); |
282 | if (dim1_ns && dim2_ns) { |
283 | return dim1_ns->name() == dim2_ns->name(); |
284 | } |
285 | |
286 | // Check recursively their definitions |
287 | |
288 | auto dim1_def = dim1->definition(); |
289 | auto dim2_def = dim2->definition(); |
290 | |
291 | if (dim1_def == nullptr || dim2_def == nullptr) { |
292 | return false; |
293 | } |
294 | |
295 | // If both are BinaryOp or UnaryOp, check their inputs. Since these |
296 | // Vals are IterDomain extents, UnaryOp should not occur, but |
297 | // checking shouldn't be harmful. |
298 | // TODO: |
299 | // We might be able to replace this with dim1->toInlineString() == |
300 | // dim2->toInlineString() |
301 | // If we want this less conservative we could make an "exact map" which |
302 | // could be another mode in compute at that maps all iter domains, but not |
303 | // concretized broadcast axes and only forwards through non-concretized |
304 | // broadcast axes. |
305 | if ((dim1_def->isA<BinaryOp>() && dim2_def->isA<BinaryOp>() && |
306 | (dim1_def->as<BinaryOp>()->getBinaryOpType() == |
307 | dim2_def->as<BinaryOp>()->getBinaryOpType())) || |
308 | (dim1_def->isA<UnaryOp>() && dim2_def->isA<UnaryOp>() && |
309 | (dim1_def->as<UnaryOp>()->getUnaryOpType() == |
310 | dim2_def->as<UnaryOp>()->getUnaryOpType()))) { |
311 | for (const auto i : c10::irange(dim1_def->inputs().size())) { |
312 | (void)i; // Suppress unused variable warning |
313 | if (!equalDim(dim1_def->inputs()[0], dim2_def->inputs()[0])) { |
314 | return false; |
315 | } |
316 | } |
317 | return true; |
318 | } |
319 | |
320 | return false; |
321 | } |
322 | |
323 | std::string ParallelDimensionMap::toString() const { |
324 | std::stringstream ss; |
325 | for (auto pt : kParallelTypeThreads) { |
326 | ss << pt << ": " ; |
327 | auto dim = get(pt); |
328 | if (dim != nullptr) { |
329 | ss << dim->toString(); |
330 | if (isExact(pt)) { |
331 | ss << ", exact" ; |
332 | } else { |
333 | ss << ", non-exact" ; |
334 | } |
335 | } else { |
336 | ss << "unused" ; |
337 | } |
338 | ss << "\n" ; |
339 | } |
340 | |
341 | return ss.str(); |
342 | } |
343 | |
344 | } // namespace cuda |
345 | } // namespace fuser |
346 | } // namespace jit |
347 | } // namespace torch |
348 | |