1 | #include <fusion.h> |
2 | #include <ir_builder.h> |
3 | #include <ir_cloner.h> |
4 | #include <kernel.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | namespace fuser { |
9 | namespace cuda { |
10 | |
11 | //! Clone an IR node, forwarding the arguments to the IrCloner constructor. |
12 | template <class T> |
13 | T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { |
14 | TORCH_INTERNAL_ASSERT( |
15 | ir_cloner != nullptr, |
16 | "Cannot use create when a cloner object is set. Use clone." ); |
17 | |
18 | TORCH_INTERNAL_ASSERT( |
19 | ir_cloner->container() != nullptr, |
20 | "Cloner doesn't have a valid container to store cloned object." ); |
21 | |
22 | T* dest = new T(src, ir_cloner); |
23 | const Statement* src_stmt = dynamic_cast<const Statement*>(src); |
24 | Statement* dest_stmt = dynamic_cast<Statement*>(dest); |
25 | |
26 | auto dest_container = ir_cloner->container(); |
27 | auto src_container = src_stmt->container(); |
28 | |
29 | dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); |
30 | |
31 | if (src_container != dest_container) { |
32 | dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); |
33 | } |
34 | |
35 | ir_cloner->registerClone(src_stmt, dest_stmt); |
36 | |
37 | return dest; |
38 | } |
39 | |
40 | #define IR_BUILDER_INSTANTIATE(T) \ |
41 | template T* IrBuilder::clone(const T* src, IrCloner* ir_cloner); |
42 | |
43 | // Vals |
44 | IR_BUILDER_INSTANTIATE(IterDomain) |
45 | IR_BUILDER_INSTANTIATE(TensorDomain) |
46 | IR_BUILDER_INSTANTIATE(TensorView) |
47 | IR_BUILDER_INSTANTIATE(Bool) |
48 | IR_BUILDER_INSTANTIATE(Double) |
49 | IR_BUILDER_INSTANTIATE(Int) |
50 | IR_BUILDER_INSTANTIATE(ComplexDouble) |
51 | IR_BUILDER_INSTANTIATE(NamedScalar) |
52 | |
53 | // Exprs |
54 | IR_BUILDER_INSTANTIATE(Split) |
55 | IR_BUILDER_INSTANTIATE(Merge) |
56 | IR_BUILDER_INSTANTIATE(Swizzle2D) |
57 | IR_BUILDER_INSTANTIATE(TransposeOp) |
58 | IR_BUILDER_INSTANTIATE(ExpandOp) |
59 | IR_BUILDER_INSTANTIATE(ShiftOp) |
60 | IR_BUILDER_INSTANTIATE(GatherOp) |
61 | IR_BUILDER_INSTANTIATE(ViewAsScalar) |
62 | IR_BUILDER_INSTANTIATE(ViewOp) |
63 | IR_BUILDER_INSTANTIATE(FullOp) |
64 | IR_BUILDER_INSTANTIATE(ARangeOp) |
65 | IR_BUILDER_INSTANTIATE(EyeOp) |
66 | IR_BUILDER_INSTANTIATE(UnaryOp) |
67 | IR_BUILDER_INSTANTIATE(BinaryOp) |
68 | IR_BUILDER_INSTANTIATE(TernaryOp) |
69 | IR_BUILDER_INSTANTIATE(RNGOp) |
70 | IR_BUILDER_INSTANTIATE(ReductionOp) |
71 | IR_BUILDER_INSTANTIATE(GroupedReductionOp) |
72 | IR_BUILDER_INSTANTIATE(WelfordOp) |
73 | IR_BUILDER_INSTANTIATE(LoadStoreOp) |
74 | IR_BUILDER_INSTANTIATE(MmaOp) |
75 | IR_BUILDER_INSTANTIATE(BroadcastOp) |
76 | |
77 | Val* IrBuilder::newResult(DataType dtype) { |
78 | switch (dtype) { |
79 | case DataType::Bool: |
80 | return IrBuilder::create<Bool>(c10::nullopt); |
81 | case DataType::Double: |
82 | return IrBuilder::create<Double>(c10::nullopt); |
83 | case DataType::Int: |
84 | return IrBuilder::create<Int>(c10::nullopt); |
85 | default: |
86 | TORCH_CHECK(false, "Unexpected data type" ); |
87 | } |
88 | } |
89 | |
90 | Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { |
91 | TORCH_CHECK( |
92 | lhs != nullptr && rhs != nullptr, |
93 | "Either lhs or rhs is a nullptr in newArithmeticExpr." ); |
94 | TORCH_CHECK( |
95 | lhs->dtype() == rhs->dtype(), |
96 | "Incompatible operand types: " , |
97 | lhs->dtype(), |
98 | " and " , |
99 | rhs->dtype()); |
100 | auto result = newResult(lhs->dtype()); |
101 | IrBuilder::create<BinaryOp>(op_type, result, lhs, rhs); |
102 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
103 | return result; |
104 | } |
105 | |
106 | Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { |
107 | TORCH_CHECK( |
108 | lhs != nullptr && rhs != nullptr, |
109 | "Either lhs or rhs is a nullptr in newLogicExpr." ); |
110 | auto result = IrBuilder::create<Bool>(c10::nullopt); |
111 | IrBuilder::create<BinaryOp>(op_type, result, lhs, rhs); |
112 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
113 | return result; |
114 | } |
115 | |
116 | Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) { |
117 | TORCH_CHECK( |
118 | pred != nullptr && lhs != nullptr && rhs != nullptr, |
119 | "Either pred, lhs, or rhs is a nullptr in whereExpr." ); |
120 | TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types" ); |
121 | auto result = newResult(lhs->dtype()); |
122 | IrBuilder::create<TernaryOp>(TernaryOpType::Where, result, pred, lhs, rhs); |
123 | return result; |
124 | } |
125 | |
126 | Val* IrBuilder::negExpr(Val* val) { |
127 | TORCH_CHECK(val != nullptr, "val is a nullptr in negExpr." ); |
128 | auto result = newResult(val->dtype()); |
129 | IrBuilder::create<UnaryOp>(UnaryOpType::Neg, result, val); |
130 | return result; |
131 | } |
132 | |
133 | Val* IrBuilder::notExpr(Val* val) { |
134 | TORCH_CHECK(val != nullptr, "val is a nullptr in notExpr." ); |
135 | auto result = newResult(val->dtype()); |
136 | IrBuilder::create<UnaryOp>(UnaryOpType::Not, result, val); |
137 | return result; |
138 | } |
139 | |
140 | Val* IrBuilder::setExpr(Val* val) { |
141 | TORCH_CHECK(val != nullptr, "val is a nullptr in setExpr." ); |
142 | auto result = newResult(val->dtype()); |
143 | IrBuilder::create<UnaryOp>(UnaryOpType::Set, result, val); |
144 | return result; |
145 | } |
146 | |
147 | Val* IrBuilder::setExprNamedScalar(const std::string& name, Val* val) { |
148 | TORCH_CHECK(val != nullptr, "val is a nullptr in setExprNamedScalar." ); |
149 | auto result = IrBuilder::create<NamedScalar>(name, val->dtype()); |
150 | IrBuilder::create<UnaryOp>(UnaryOpType::Set, result, val); |
151 | return result; |
152 | } |
153 | |
154 | Val* IrBuilder::addressExprNamedScalar(const std::string& name, Val* val) { |
155 | TORCH_CHECK(val != nullptr, "val is a nullptr in addressExprNamedScalar." ); |
156 | auto result = IrBuilder::create<NamedScalar>(name, DataType::Int); |
157 | IrBuilder::create<UnaryOp>(UnaryOpType::Address, result, val); |
158 | return result; |
159 | } |
160 | |
161 | Val* IrBuilder::andExpr(Val* lhs, Val* rhs) { |
162 | return newLogicExpr(BinaryOpType::And, lhs, rhs); |
163 | } |
164 | |
165 | Val* IrBuilder::eqExpr(Val* lhs, Val* rhs) { |
166 | return newLogicExpr(BinaryOpType::Eq, lhs, rhs); |
167 | } |
168 | |
169 | Val* IrBuilder::gtExpr(Val* lhs, Val* rhs) { |
170 | return newLogicExpr(BinaryOpType::GT, lhs, rhs); |
171 | } |
172 | |
173 | Val* IrBuilder::ltExpr(Val* lhs, Val* rhs) { |
174 | return newLogicExpr(BinaryOpType::LT, lhs, rhs); |
175 | } |
176 | |
177 | Val* IrBuilder::leExpr(Val* lhs, Val* rhs) { |
178 | return newLogicExpr(BinaryOpType::LE, lhs, rhs); |
179 | } |
180 | |
181 | Val* IrBuilder::geExpr(Val* lhs, Val* rhs) { |
182 | return newLogicExpr(BinaryOpType::GE, lhs, rhs); |
183 | } |
184 | |
185 | Val* IrBuilder::addExpr(Val* lhs, Val* rhs) { |
186 | return newArithmeticExpr(BinaryOpType::Add, lhs, rhs); |
187 | } |
188 | |
189 | Val* IrBuilder::subExpr(Val* lhs, Val* rhs) { |
190 | return newArithmeticExpr(BinaryOpType::Sub, lhs, rhs); |
191 | } |
192 | |
193 | Val* IrBuilder::mulExpr(Val* lhs, Val* rhs) { |
194 | return newArithmeticExpr(BinaryOpType::Mul, lhs, rhs); |
195 | } |
196 | |
197 | Val* IrBuilder::divExpr(Val* lhs, Val* rhs) { |
198 | return newArithmeticExpr(BinaryOpType::Div, lhs, rhs); |
199 | } |
200 | |
201 | Val* IrBuilder::ceilDivExpr(Val* lhs, Val* rhs) { |
202 | return newArithmeticExpr(BinaryOpType::CeilDiv, lhs, rhs); |
203 | } |
204 | |
205 | Val* IrBuilder::modExpr(Val* lhs, Val* rhs) { |
206 | return newArithmeticExpr(BinaryOpType::Mod, lhs, rhs); |
207 | } |
208 | |
209 | Val* IrBuilder::maxExpr(Val* lhs, Val* rhs) { |
210 | return newArithmeticExpr(BinaryOpType::Max, lhs, rhs); |
211 | } |
212 | |
213 | Val* IrBuilder::minExpr(Val* lhs, Val* rhs) { |
214 | return newArithmeticExpr(BinaryOpType::Min, lhs, rhs); |
215 | } |
216 | |
217 | Val* IrBuilder::swizzle2DIntExpr( |
218 | Val* in_x, |
219 | Val* in_y, |
220 | Val* extent_x, |
221 | Val* extent_y, |
222 | Swizzle2DType swizzle_type) { |
223 | auto result = create<kir::IntPair>(); |
224 | |
225 | create<kir::Swizzle2DInt>( |
226 | result, in_x, in_y, extent_x, extent_y, swizzle_type); |
227 | return result; |
228 | } |
229 | |
230 | Val* IrBuilder::pairSelectExpr(Val* in, kir::PairSelect::Selection sel) { |
231 | auto int_pair = dynamic_cast<kir::IntPair*>(in); |
232 | TORCH_INTERNAL_ASSERT(int_pair != nullptr); |
233 | auto result = create<Int>(); |
234 | create<kir::PairSelect>(result, int_pair, sel); |
235 | return result; |
236 | } |
237 | |
238 | Val* SimplifyingIrBuilder::negExpr(Val* val) { |
239 | if (auto int_val = dynamic_cast<Int*>(val)) { |
240 | if (int_val->isConst()) { |
241 | return IrBuilder::create<Int>(-int_val->value().value()); |
242 | } |
243 | } |
244 | return IrBuilder::negExpr(val); |
245 | } |
246 | |
247 | Val* SimplifyingIrBuilder::notExpr(Val* val) { |
248 | if (auto bool_val = dynamic_cast<Bool*>(val)) { |
249 | if (bool_val->isConst()) { |
250 | if (bool_val->value().value()) { |
251 | return FusionGuard::getCurFusion()->falseVal(); |
252 | } else { |
253 | return FusionGuard::getCurFusion()->trueVal(); |
254 | } |
255 | } |
256 | } |
257 | return IrBuilder::notExpr(val); |
258 | } |
259 | |
260 | Val* SimplifyingIrBuilder::addExpr(Int* lhs, Int::ScalarType rhs) { |
261 | if (rhs == 0) { |
262 | return lhs; |
263 | } else if (lhs == nullptr) { |
264 | return IrBuilder::IrBuilder::create<Int>(rhs); |
265 | } else if (lhs->isConst()) { |
266 | return IrBuilder::IrBuilder::create<Int>(lhs->value().value() + rhs); |
267 | } else if (rhs > 0) { |
268 | return IrBuilder::addExpr(lhs, IrBuilder::IrBuilder::create<Int>(rhs)); |
269 | } else { |
270 | return IrBuilder::subExpr(lhs, IrBuilder::IrBuilder::create<Int>(-rhs)); |
271 | } |
272 | } |
273 | |
274 | Val* SimplifyingIrBuilder::addExpr(Int* lhs, Int* rhs) { |
275 | if (rhs == nullptr) { |
276 | return lhs; |
277 | } else if (lhs == nullptr) { |
278 | return rhs; |
279 | } else if (lhs->isConst()) { |
280 | return addExpr(rhs, lhs->value().value()); |
281 | } else if (rhs->isConst()) { |
282 | return addExpr(lhs, rhs->value().value()); |
283 | } else { |
284 | return IrBuilder::addExpr(lhs, rhs); |
285 | } |
286 | } |
287 | |
288 | Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { |
289 | TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr); |
290 | if (lhs == nullptr || lhs->isZeroInt()) { |
291 | return rhs; |
292 | } else if (rhs == nullptr || rhs->isZeroInt()) { |
293 | return lhs; |
294 | } |
295 | auto lhs_int = dynamic_cast<Int*>(lhs); |
296 | auto rhs_int = dynamic_cast<Int*>(rhs); |
297 | if (lhs_int != nullptr && rhs_int != nullptr) { |
298 | return addExpr(lhs_int, rhs_int); |
299 | } else { |
300 | return IrBuilder::addExpr(lhs, rhs); |
301 | } |
302 | } |
303 | |
304 | Val* SimplifyingIrBuilder::addExpr(Val* lhs, Int::ScalarType rhs) { |
305 | auto lhs_int = dynamic_cast<Int*>(lhs); |
306 | if (lhs_int != nullptr) { |
307 | return addExpr(lhs_int, rhs); |
308 | } else { |
309 | return addExpr(lhs, IrBuilder::create<Int>(rhs)); |
310 | } |
311 | } |
312 | |
313 | Val* SimplifyingIrBuilder::subExpr(Val* lhs, Val* rhs) { |
314 | return addExpr(lhs, negExpr(rhs)); |
315 | } |
316 | |
317 | Val* SimplifyingIrBuilder::mulExpr(Int* lhs, Int::ScalarType rhs) { |
318 | if (rhs == 0) { |
319 | return lhs->container()->zeroVal(); |
320 | } else if (rhs == 1) { |
321 | return lhs; |
322 | } else if (lhs == nullptr) { |
323 | return IrBuilder::create<Int>(rhs); |
324 | } else if (lhs->isConst()) { |
325 | return IrBuilder::create<Int>(lhs->value().value() * rhs); |
326 | } else { |
327 | return IrBuilder::mulExpr(lhs, IrBuilder::create<Int>(rhs)); |
328 | } |
329 | } |
330 | |
331 | Val* SimplifyingIrBuilder::mulExpr(Val* lhs, Int::ScalarType rhs) { |
332 | auto lhs_int = dynamic_cast<Int*>(lhs); |
333 | if (lhs_int != nullptr) { |
334 | return mulExpr(lhs_int, rhs); |
335 | } else { |
336 | return IrBuilder::mulExpr(lhs, IrBuilder::create<Int>(rhs)); |
337 | } |
338 | } |
339 | |
340 | Val* SimplifyingIrBuilder::mulExpr(Int* lhs, Int* rhs) { |
341 | if (rhs == nullptr) { |
342 | return lhs; |
343 | } else if (lhs == nullptr) { |
344 | return rhs; |
345 | } else if (lhs->isConst()) { |
346 | return mulExpr(rhs, lhs->value().value()); |
347 | } else if (rhs->isConst()) { |
348 | return mulExpr(lhs, rhs->value().value()); |
349 | } else { |
350 | return IrBuilder::mulExpr(lhs, rhs); |
351 | } |
352 | } |
353 | |
354 | Val* SimplifyingIrBuilder::mulExpr(Val* lhs, Val* rhs) { |
355 | TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr); |
356 | if (lhs == nullptr || lhs->isOneInt()) { |
357 | return rhs; |
358 | } else if (rhs == nullptr || rhs->isOneInt()) { |
359 | return lhs; |
360 | } else if (lhs->isZeroInt() || rhs->isZeroInt()) { |
361 | return lhs->container()->zeroVal(); |
362 | } |
363 | auto lhs_int = dynamic_cast<Int*>(lhs); |
364 | auto rhs_int = dynamic_cast<Int*>(rhs); |
365 | if (lhs_int != nullptr && rhs_int != nullptr) { |
366 | return mulExpr(lhs_int, rhs_int); |
367 | } else { |
368 | return IrBuilder::mulExpr(lhs, rhs); |
369 | } |
370 | } |
371 | |
372 | Val* SimplifyingIrBuilder::andExpr(Val* lhs, Val* rhs) { |
373 | TORCH_INTERNAL_ASSERT(!(lhs == nullptr && rhs == nullptr)); |
374 | |
375 | if (lhs == nullptr) { |
376 | return rhs; |
377 | } else if (rhs == nullptr) { |
378 | return lhs; |
379 | } |
380 | |
381 | bool lhs_definitely_true = false; |
382 | bool lhs_definitely_false = false; |
383 | auto lhs_bool = dynamic_cast<Bool*>(lhs); |
384 | if (lhs_bool && lhs_bool->isConst()) { |
385 | lhs_definitely_true = lhs_bool->value().value(); |
386 | lhs_definitely_false = !lhs_bool->value().value(); |
387 | } |
388 | auto rhs_bool = dynamic_cast<Bool*>(rhs); |
389 | bool rhs_definitely_true = false; |
390 | bool rhs_definitely_false = false; |
391 | if (rhs_bool && rhs_bool->isConst()) { |
392 | rhs_definitely_true = rhs_bool->value().value(); |
393 | rhs_definitely_false = !rhs_bool->value().value(); |
394 | } |
395 | |
396 | if (lhs_definitely_true && rhs_definitely_true) { |
397 | return FusionGuard::getCurFusion()->trueVal(); |
398 | } else if (lhs_definitely_false || rhs_definitely_false) { |
399 | return FusionGuard::getCurFusion()->falseVal(); |
400 | } else if (lhs_definitely_true) { |
401 | return rhs; |
402 | } else if (rhs_definitely_true) { |
403 | return lhs; |
404 | } |
405 | |
406 | return IrBuilder::andExpr(lhs, rhs); |
407 | } |
408 | |
409 | namespace { |
410 | |
411 | template <typename IrBuilderFunc, typename IntFunc> |
412 | Val* minOrMaxExpr( |
413 | Int* lhs, |
414 | Int* rhs, |
415 | IrBuilderFunc ir_builder_func, |
416 | IntFunc int_func) { |
417 | if (rhs == nullptr) { |
418 | return lhs; |
419 | } else if (lhs == nullptr) { |
420 | return rhs; |
421 | } else if (lhs->isConst() && rhs->isConst()) { |
422 | return IrBuilder::create<Int>( |
423 | int_func(lhs->value().value(), rhs->value().value())); |
424 | } else { |
425 | return ir_builder_func(lhs, rhs); |
426 | } |
427 | } |
428 | |
429 | template <typename IrBuilderFunc, typename IntFunc> |
430 | Val* minOrMaxExpr( |
431 | Val* lhs, |
432 | Val* rhs, |
433 | IrBuilderFunc ir_builder_func, |
434 | IntFunc int_func) { |
435 | TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr); |
436 | if (lhs == nullptr) { |
437 | return rhs; |
438 | } else if (rhs == nullptr || lhs == rhs) { |
439 | return lhs; |
440 | } |
441 | auto lhs_int = dynamic_cast<Int*>(lhs); |
442 | auto rhs_int = dynamic_cast<Int*>(rhs); |
443 | if (lhs_int != nullptr && rhs_int != nullptr) { |
444 | return minOrMaxExpr(lhs_int, rhs_int, ir_builder_func, int_func); |
445 | } else { |
446 | return ir_builder_func(lhs, rhs); |
447 | } |
448 | } |
449 | |
450 | } // namespace |
451 | |
452 | Val* SimplifyingIrBuilder::maxExpr(Val* lhs, Val* rhs) { |
453 | return minOrMaxExpr( |
454 | lhs, |
455 | rhs, |
456 | [](Val* lhs, Val* rhs) { return IrBuilder::maxExpr(lhs, rhs); }, |
457 | [](int64_t lhs, int64_t rhs) { return std::max(lhs, rhs); }); |
458 | } |
459 | |
460 | Val* SimplifyingIrBuilder::minExpr(Val* lhs, Val* rhs) { |
461 | return minOrMaxExpr( |
462 | lhs, |
463 | rhs, |
464 | [](Val* lhs, Val* rhs) { return IrBuilder::minExpr(lhs, rhs); }, |
465 | [](int64_t lhs, int64_t rhs) { return std::min(lhs, rhs); }); |
466 | } |
467 | |
468 | } // namespace cuda |
469 | } // namespace fuser |
470 | } // namespace jit |
471 | } // namespace torch |
472 | |