#include <parallel_dimension_map.h>

#include <ATen/cuda/CUDAContext.h>
#include <expr_evaluator.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel_expr_evaluator.h>
#include <lower2device.h>

#include <sstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

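// Build a map from each thread parallel type (TIDx/y/z and BIDx/y/z) to
// the dimension it is launched with. A parallel type is marked "exact"
// when its mapped dimension is guaranteed to equal the extent of every
// IterDomain parallelized by it.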
void ParallelDimensionMap::build(Fusion* fusion) {
  // Scan all TVs to build ParallelType maps
  auto all_vals = fusion->usedMathVals();
  for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
    for (auto id : tv->domain()->domain()) {
      registerConstantExtent(id);
      if (!isParallelTypeThread(id->getParallelType())) {
        continue;
      }
      handleParallelDomain(id);
    }
  }

  // Populate the dimension map for each parallel type
  for (const auto& kv : concrete_dom_map_) {
    auto pt = kv.first;
    const auto& concrete_dom_set = kv.second;
    TORCH_INTERNAL_ASSERT(!concrete_dom_set.empty());
    if (concrete_dom_set.size() == 1) {
      populateDimensionMapWithSingleCASet(pt, concrete_dom_set);
    } else {
      populateDimensionMapWithMultipleCASet(pt, concrete_dom_set);
    }
  }

  adjustMappingsForWarpPadding();
}

void ParallelDimensionMap::registerConstantExtent(IterDomain* id) {
  if (!id->extent()->isConstScalar()) {
    // Nothing to do if not constant
    return;
  }

  ExpressionEvaluator ee(id->fusion());
  auto extent_int = ee.evaluate(id->extent());
  TORCH_INTERNAL_ASSERT(
      extent_int.has_value(),
      "Extent of ",
      id->toString(),
      " should have been constant, but could not be evaluated at compile time.");

  auto const_extent = extent_int->as<int64_t>();

  // Use the CA map to get the concrete domain of id
  auto concrete_id = getCAMappedConcreteDomain(id);

  auto existing_it = constant_extent_map_.find(concrete_id);

  // Adds the constant extent to the set for the concrete domain. If
  // multiple constants are found, this concrete domain has multiple
  // distinctive extents, which can happen with broadcast.
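  // (Hypothetical illustration: a broadcast domain of extent 1 and a
  // non-broadcast domain of extent 128 that map to the same concrete
  // domain leave both 1 and 128 in the set.)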
  if (existing_it == constant_extent_map_.end()) {
    constant_extent_map_.insert({concrete_id, {const_extent}});
  } else {
    existing_it->second.insert(const_extent);
  }
}

// Adds the concrete domain of id to the mapped set for its
// parallel type
void ParallelDimensionMap::handleParallelDomain(IterDomain* id) {
  auto pt = id->getParallelType();
  TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt));
  auto concrete_id = getCAMappedConcreteDomain(id);

  auto it = concrete_dom_map_.find(pt);
  if (it == concrete_dom_map_.end()) {
    concrete_dom_map_.insert({pt, {concrete_id}});
  } else {
    it->second.insert(concrete_id);
  }
}

void ParallelDimensionMap::populateDimensionMapWithSingleCASet(
    ParallelType pt,
    const std::unordered_set<IterDomain*>& dom_set) {
  TORCH_INTERNAL_ASSERT(dom_set.size() == 1);

  // pt is used by only one concrete domain
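  // With a single concrete domain the dimension is fully determined, so
  // pt is exact in either branch below: a known constant extent (e.g., a
  // hypothetical extent of 128 maps TIDx to 128) is used directly, and a
  // symbolic extent falls back to blockDim/gridDim.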
  auto id = *dom_set.begin();
  auto it = constant_extent_map_.find(id);

  if (it != constant_extent_map_.end()) {
    TORCH_INTERNAL_ASSERT(
        it->second.size() == 1,
        "Only one value found mapped to parallel type ",
        stringifyThread(pt),
        " yet it is bound to multiple extents.");
    dim_map_.insert({pt, IrBuilder::create<Int>(*(it->second.begin()))});
    exact_types_.insert(pt);
  } else {
    // Use blockDim/gridDim if the extent is not constant
    dim_map_.insert({pt, NamedScalar::getParallelDim(pt)});
    exact_types_.insert(pt);
  }
}

void ParallelDimensionMap::populateDimensionMapWithMultipleCASet(
    ParallelType pt,
    const std::unordered_set<IterDomain*>& dom_set) {
  TORCH_INTERNAL_ASSERT(dom_set.size() > 1);

  bool all_equal = true;
  // Use nullptr to signal it's not initialized yet
  Val* known_dimension = nullptr;
  // Use -1 to signal it's not initialized yet
  int64_t known_const = -1;

  // Check all concrete domains to see if they all match.
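  // (Hypothetical illustration: extents {128, 128} leave pt exact,
  // whereas {128, 64} make it non-exact.)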
  for (auto concrete_id : dom_set) {
    if (concrete_id->isBroadcast()) {
      // Broadcast concrete ids don't specify anything about the shape
      continue;
    }
    // If this concrete domain has a constant extent, check if it
    // matches the known constant extent.
    auto it = constant_extent_map_.find(concrete_id);
    if (it != constant_extent_map_.end()) {
      const auto& const_extent_set = it->second;
      // If multiple constants are detected, it's not exact.
      if (const_extent_set.size() > 1) {
        all_equal = false;
        break;
      }
      auto this_const = *(const_extent_set.begin());
      // known_const is initialized to -1
      if (known_const == -1) {
        known_const = this_const;
      } else if (known_const == this_const) {
        // Matched the previously known constant. The extent of this
        // domain must be equal to it.
        continue;
      } else {
        // Unmatched. The extents in this dom_set are not unique.
        all_equal = false;
        break;
      }
    }

    // At this point, it still remains undetermined whether this id
    // matches the previously looked-at ones. The constant check failed,
    // but symbolic matching may still succeed.
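    // (Hypothetical illustration: two domains split from the same root
    // can share an extent like ceilDiv(i0, 4), which equalDim can prove
    // equal even though it is not constant.)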
    auto this_dimension = concrete_id->extent();
    if (known_dimension == nullptr) {
      // No previous dimension found yet
      known_dimension = this_dimension;
    } else {
      if (!equalDim(known_dimension, this_dimension)) {
        all_equal = false;
        break;
      }
    }
  }

  // If all_equal is still true, the dimension of this parallel type
  // must be exact.
  if (all_equal) {
    exact_types_.insert(pt);
  }
  // Use the const value, if found, as its dimension
  if (all_equal && known_const != -1) {
    dim_map_.insert({pt, IrBuilder::create<Int>(known_const)});
  } else {
    dim_map_.insert({pt, NamedScalar::getParallelDim(pt)});
  }
}

void ParallelDimensionMap::adjustMappingsForWarpPadding() {
  const auto gpu_lower = GpuLower::current();

  // If TIDx is padded to a multiple of the warp size, mark it as
  // non-exact.
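  // (Hypothetical illustration: a TIDx extent of 100 padded up to 128
  // with a warp size of 32 no longer matches the original extent, so
  // TIDx cannot remain exact.)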

  auto& warp_info = gpu_lower->getWarpPaddedParallelInfo();
  // TIDx isn't really padded if there isn't a warp reduction (this could
  // change)
  if (!(warp_info.is_tidx_padded && warp_info.has_warp_reduction)) {
    return;
  }

  const auto tidx_pt = ParallelType::TIDx;
  auto warp_size = at::cuda::warp_size();

  // If the dimension of TIDx is actually a multiple of the warp size
  // before padding, it can be left as exact
  if (isExact(tidx_pt)) {
    auto tidx_dim = dynamic_cast<Int*>(get(tidx_pt));
    if (tidx_dim && tidx_dim->isConst()) {
      auto tidx_dim_val = tidx_dim->value().value();
      if (tidx_dim_val % warp_size == 0) {
        // Dimension of TIDx is a multiple of the warp size
        return;
      }
    }
    // If TIDx is strictly defined as blockDim.x, then it must already be
    // set to a multiple of the warp size and can be considered exact
    bool tidx_def_trivial = true;
    for (auto entry : concrete_dom_map_.at(tidx_pt)) {
      if (!entry->extent()->isA<NamedScalar>() ||
          !entry->extent()->as<NamedScalar>()->sameAs(
              NamedScalar::getParallelDim(tidx_pt))) {
        tidx_def_trivial = false;
      }
    }
    if (tidx_def_trivial) {
      return;
    }
  }

  // TIDx is padded to a multiple of the warp size. If it's known to be a
  // single warp, use the constant warp size as the dimension of
  // TIDx. Otherwise, just use blockDim.x.
  if (warp_info.is_tidx_single_warp) {
    dim_map_.at(ParallelType::TIDx) = IrBuilder::create<Int>(warp_size);
  } else {
    dim_map_.at(ParallelType::TIDx) =
        NamedScalar::getParallelDim(ParallelType::TIDx);
  }

  // TIDx is no longer exact
  exact_types_.erase(ParallelType::TIDx);
}

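// Returns the mapped dimension of pt, or nullptr if pt is not used by
// any IterDomain in the fusion.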
Val* ParallelDimensionMap::get(ParallelType pt) const {
  TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
  auto it = dim_map_.find(pt);
  if (it == dim_map_.end()) {
    return nullptr;
  } else {
    return it->second;
  }
}

bool ParallelDimensionMap::isExact(ParallelType pt) const {
  return exact_types_.find(pt) != exact_types_.end();
}

IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) {
  return GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::EXACT);
}

// Symbolically compares two vals for equality. The comparison is done
// conservatively, so returning false does not guarantee non-equality.
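// (Hypothetical illustration: equalDim(i0, i0) is trivially true;
// equalDim(ceilDiv(i0, 4), ceilDiv(i0, 4)) holds for two distinct but
// structurally identical expressions; equalDim(i0, i1) returns false
// even if i0 and i1 happen to be equal at runtime.)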
bool ParallelDimensionMap::equalDim(Val* dim1, Val* dim2) {
  TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr);

  if (dim1 == dim2) {
    return true;
  }

  // When both are Int, they are equal if both have the same constant
  // value
  auto dim1_int = dynamic_cast<Int*>(dim1);
  auto dim2_int = dynamic_cast<Int*>(dim2);
  if (dim1_int && dim2_int) {
    if (dim1_int->isConst() && dim2_int->isConst()) {
      return dim1_int->value() == dim2_int->value();
    }
  }

  // When both are NamedScalar, they are equal if both have the same
  // name
  auto dim1_ns = dynamic_cast<NamedScalar*>(dim1);
  auto dim2_ns = dynamic_cast<NamedScalar*>(dim2);
  if (dim1_ns && dim2_ns) {
    return dim1_ns->name() == dim2_ns->name();
  }

  // Check their definitions recursively

  auto dim1_def = dim1->definition();
  auto dim2_def = dim2->definition();

  if (dim1_def == nullptr || dim2_def == nullptr) {
    return false;
  }

  // If both are BinaryOp or UnaryOp, check their inputs. Since these
  // Vals are IterDomain extents, UnaryOp should not occur, but
  // checking shouldn't be harmful.
  // TODO:
  // We might be able to replace this with dim1->toInlineString() ==
  // dim2->toInlineString()
  // If we want this less conservative, we could add an "exact map" as
  // another compute-at mode that maps all iter domains except
  // concretized broadcast axes and only forwards through non-concretized
  // broadcast axes.
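  // (Hypothetical illustration: extents i0 * 4 and i0 * 4 defined by two
  // distinct Mul ops compare equal because their op types match and each
  // pair of corresponding inputs compares equal.)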
  if ((dim1_def->isA<BinaryOp>() && dim2_def->isA<BinaryOp>() &&
       (dim1_def->as<BinaryOp>()->getBinaryOpType() ==
        dim2_def->as<BinaryOp>()->getBinaryOpType())) ||
      (dim1_def->isA<UnaryOp>() && dim2_def->isA<UnaryOp>() &&
       (dim1_def->as<UnaryOp>()->getUnaryOpType() ==
        dim2_def->as<UnaryOp>()->getUnaryOpType()))) {
    // Compare corresponding inputs pairwise. Note: the original loop
    // always compared inputs()[0], ignoring the loop index.
    for (const auto i : c10::irange(dim1_def->inputs().size())) {
      if (!equalDim(dim1_def->inputs()[i], dim2_def->inputs()[i])) {
        return false;
      }
    }
    return true;
  }

  return false;
}

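// Prints one line per thread parallel type, e.g. (hypothetical output):
//   threadIdx.x: 128, exact
//   threadIdx.y: unused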
std::string ParallelDimensionMap::toString() const {
  std::stringstream ss;
  for (auto pt : kParallelTypeThreads) {
    ss << pt << ": ";
    auto dim = get(pt);
    if (dim != nullptr) {
      ss << dim->toString();
      if (isExact(pt)) {
        ss << ", exact";
      } else {
        ss << ", non-exact";
      }
    } else {
      ss << "unused";
    }
    ss << "\n";
  }

  return ss.str();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch