1#include <fusion.h>
2#include <ir_builder.h>
3#include <ir_cloner.h>
4#include <kernel.h>
5
6namespace torch {
7namespace jit {
8namespace fuser {
9namespace cuda {
10
11//! Clone an IR node, forwarding the arguments to the IrCloner constructor.
12template <class T>
13T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) {
14 TORCH_INTERNAL_ASSERT(
15 ir_cloner != nullptr,
16 "Cannot use create when a cloner object is set. Use clone.");
17
18 TORCH_INTERNAL_ASSERT(
19 ir_cloner->container() != nullptr,
20 "Cloner doesn't have a valid container to store cloned object.");
21
22 T* dest = new T(src, ir_cloner);
23 const Statement* src_stmt = dynamic_cast<const Statement*>(src);
24 Statement* dest_stmt = dynamic_cast<Statement*>(dest);
25
26 auto dest_container = ir_cloner->container();
27 auto src_container = src_stmt->container();
28
29 dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt);
30
31 if (src_container != dest_container) {
32 dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name());
33 }
34
35 ir_cloner->registerClone(src_stmt, dest_stmt);
36
37 return dest;
38}
39
40#define IR_BUILDER_INSTANTIATE(T) \
41 template T* IrBuilder::clone(const T* src, IrCloner* ir_cloner);
42
43// Vals
44IR_BUILDER_INSTANTIATE(IterDomain)
45IR_BUILDER_INSTANTIATE(TensorDomain)
46IR_BUILDER_INSTANTIATE(TensorView)
47IR_BUILDER_INSTANTIATE(Bool)
48IR_BUILDER_INSTANTIATE(Double)
49IR_BUILDER_INSTANTIATE(Int)
50IR_BUILDER_INSTANTIATE(ComplexDouble)
51IR_BUILDER_INSTANTIATE(NamedScalar)
52
53// Exprs
54IR_BUILDER_INSTANTIATE(Split)
55IR_BUILDER_INSTANTIATE(Merge)
56IR_BUILDER_INSTANTIATE(Swizzle2D)
57IR_BUILDER_INSTANTIATE(TransposeOp)
58IR_BUILDER_INSTANTIATE(ExpandOp)
59IR_BUILDER_INSTANTIATE(ShiftOp)
60IR_BUILDER_INSTANTIATE(GatherOp)
61IR_BUILDER_INSTANTIATE(ViewAsScalar)
62IR_BUILDER_INSTANTIATE(ViewOp)
63IR_BUILDER_INSTANTIATE(FullOp)
64IR_BUILDER_INSTANTIATE(ARangeOp)
65IR_BUILDER_INSTANTIATE(EyeOp)
66IR_BUILDER_INSTANTIATE(UnaryOp)
67IR_BUILDER_INSTANTIATE(BinaryOp)
68IR_BUILDER_INSTANTIATE(TernaryOp)
69IR_BUILDER_INSTANTIATE(RNGOp)
70IR_BUILDER_INSTANTIATE(ReductionOp)
71IR_BUILDER_INSTANTIATE(GroupedReductionOp)
72IR_BUILDER_INSTANTIATE(WelfordOp)
73IR_BUILDER_INSTANTIATE(LoadStoreOp)
74IR_BUILDER_INSTANTIATE(MmaOp)
75IR_BUILDER_INSTANTIATE(BroadcastOp)
76
77Val* IrBuilder::newResult(DataType dtype) {
78 switch (dtype) {
79 case DataType::Bool:
80 return IrBuilder::create<Bool>(c10::nullopt);
81 case DataType::Double:
82 return IrBuilder::create<Double>(c10::nullopt);
83 case DataType::Int:
84 return IrBuilder::create<Int>(c10::nullopt);
85 default:
86 TORCH_CHECK(false, "Unexpected data type");
87 }
88}
89
90Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
91 TORCH_CHECK(
92 lhs != nullptr && rhs != nullptr,
93 "Either lhs or rhs is a nullptr in newArithmeticExpr.");
94 TORCH_CHECK(
95 lhs->dtype() == rhs->dtype(),
96 "Incompatible operand types: ",
97 lhs->dtype(),
98 " and ",
99 rhs->dtype());
100 auto result = newResult(lhs->dtype());
101 IrBuilder::create<BinaryOp>(op_type, result, lhs, rhs);
102 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
103 return result;
104}
105
106Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
107 TORCH_CHECK(
108 lhs != nullptr && rhs != nullptr,
109 "Either lhs or rhs is a nullptr in newLogicExpr.");
110 auto result = IrBuilder::create<Bool>(c10::nullopt);
111 IrBuilder::create<BinaryOp>(op_type, result, lhs, rhs);
112 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
113 return result;
114}
115
116Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) {
117 TORCH_CHECK(
118 pred != nullptr && lhs != nullptr && rhs != nullptr,
119 "Either pred, lhs, or rhs is a nullptr in whereExpr.");
120 TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types");
121 auto result = newResult(lhs->dtype());
122 IrBuilder::create<TernaryOp>(TernaryOpType::Where, result, pred, lhs, rhs);
123 return result;
124}
125
126Val* IrBuilder::negExpr(Val* val) {
127 TORCH_CHECK(val != nullptr, "val is a nullptr in negExpr.");
128 auto result = newResult(val->dtype());
129 IrBuilder::create<UnaryOp>(UnaryOpType::Neg, result, val);
130 return result;
131}
132
133Val* IrBuilder::notExpr(Val* val) {
134 TORCH_CHECK(val != nullptr, "val is a nullptr in notExpr.");
135 auto result = newResult(val->dtype());
136 IrBuilder::create<UnaryOp>(UnaryOpType::Not, result, val);
137 return result;
138}
139
140Val* IrBuilder::setExpr(Val* val) {
141 TORCH_CHECK(val != nullptr, "val is a nullptr in setExpr.");
142 auto result = newResult(val->dtype());
143 IrBuilder::create<UnaryOp>(UnaryOpType::Set, result, val);
144 return result;
145}
146
147Val* IrBuilder::setExprNamedScalar(const std::string& name, Val* val) {
148 TORCH_CHECK(val != nullptr, "val is a nullptr in setExprNamedScalar.");
149 auto result = IrBuilder::create<NamedScalar>(name, val->dtype());
150 IrBuilder::create<UnaryOp>(UnaryOpType::Set, result, val);
151 return result;
152}
153
154Val* IrBuilder::addressExprNamedScalar(const std::string& name, Val* val) {
155 TORCH_CHECK(val != nullptr, "val is a nullptr in addressExprNamedScalar.");
156 auto result = IrBuilder::create<NamedScalar>(name, DataType::Int);
157 IrBuilder::create<UnaryOp>(UnaryOpType::Address, result, val);
158 return result;
159}
160
161Val* IrBuilder::andExpr(Val* lhs, Val* rhs) {
162 return newLogicExpr(BinaryOpType::And, lhs, rhs);
163}
164
165Val* IrBuilder::eqExpr(Val* lhs, Val* rhs) {
166 return newLogicExpr(BinaryOpType::Eq, lhs, rhs);
167}
168
169Val* IrBuilder::gtExpr(Val* lhs, Val* rhs) {
170 return newLogicExpr(BinaryOpType::GT, lhs, rhs);
171}
172
173Val* IrBuilder::ltExpr(Val* lhs, Val* rhs) {
174 return newLogicExpr(BinaryOpType::LT, lhs, rhs);
175}
176
177Val* IrBuilder::leExpr(Val* lhs, Val* rhs) {
178 return newLogicExpr(BinaryOpType::LE, lhs, rhs);
179}
180
181Val* IrBuilder::geExpr(Val* lhs, Val* rhs) {
182 return newLogicExpr(BinaryOpType::GE, lhs, rhs);
183}
184
185Val* IrBuilder::addExpr(Val* lhs, Val* rhs) {
186 return newArithmeticExpr(BinaryOpType::Add, lhs, rhs);
187}
188
189Val* IrBuilder::subExpr(Val* lhs, Val* rhs) {
190 return newArithmeticExpr(BinaryOpType::Sub, lhs, rhs);
191}
192
193Val* IrBuilder::mulExpr(Val* lhs, Val* rhs) {
194 return newArithmeticExpr(BinaryOpType::Mul, lhs, rhs);
195}
196
197Val* IrBuilder::divExpr(Val* lhs, Val* rhs) {
198 return newArithmeticExpr(BinaryOpType::Div, lhs, rhs);
199}
200
201Val* IrBuilder::ceilDivExpr(Val* lhs, Val* rhs) {
202 return newArithmeticExpr(BinaryOpType::CeilDiv, lhs, rhs);
203}
204
205Val* IrBuilder::modExpr(Val* lhs, Val* rhs) {
206 return newArithmeticExpr(BinaryOpType::Mod, lhs, rhs);
207}
208
209Val* IrBuilder::maxExpr(Val* lhs, Val* rhs) {
210 return newArithmeticExpr(BinaryOpType::Max, lhs, rhs);
211}
212
213Val* IrBuilder::minExpr(Val* lhs, Val* rhs) {
214 return newArithmeticExpr(BinaryOpType::Min, lhs, rhs);
215}
216
217Val* IrBuilder::swizzle2DIntExpr(
218 Val* in_x,
219 Val* in_y,
220 Val* extent_x,
221 Val* extent_y,
222 Swizzle2DType swizzle_type) {
223 auto result = create<kir::IntPair>();
224
225 create<kir::Swizzle2DInt>(
226 result, in_x, in_y, extent_x, extent_y, swizzle_type);
227 return result;
228}
229
230Val* IrBuilder::pairSelectExpr(Val* in, kir::PairSelect::Selection sel) {
231 auto int_pair = dynamic_cast<kir::IntPair*>(in);
232 TORCH_INTERNAL_ASSERT(int_pair != nullptr);
233 auto result = create<Int>();
234 create<kir::PairSelect>(result, int_pair, sel);
235 return result;
236}
237
238Val* SimplifyingIrBuilder::negExpr(Val* val) {
239 if (auto int_val = dynamic_cast<Int*>(val)) {
240 if (int_val->isConst()) {
241 return IrBuilder::create<Int>(-int_val->value().value());
242 }
243 }
244 return IrBuilder::negExpr(val);
245}
246
247Val* SimplifyingIrBuilder::notExpr(Val* val) {
248 if (auto bool_val = dynamic_cast<Bool*>(val)) {
249 if (bool_val->isConst()) {
250 if (bool_val->value().value()) {
251 return FusionGuard::getCurFusion()->falseVal();
252 } else {
253 return FusionGuard::getCurFusion()->trueVal();
254 }
255 }
256 }
257 return IrBuilder::notExpr(val);
258}
259
260Val* SimplifyingIrBuilder::addExpr(Int* lhs, Int::ScalarType rhs) {
261 if (rhs == 0) {
262 return lhs;
263 } else if (lhs == nullptr) {
264 return IrBuilder::IrBuilder::create<Int>(rhs);
265 } else if (lhs->isConst()) {
266 return IrBuilder::IrBuilder::create<Int>(lhs->value().value() + rhs);
267 } else if (rhs > 0) {
268 return IrBuilder::addExpr(lhs, IrBuilder::IrBuilder::create<Int>(rhs));
269 } else {
270 return IrBuilder::subExpr(lhs, IrBuilder::IrBuilder::create<Int>(-rhs));
271 }
272}
273
274Val* SimplifyingIrBuilder::addExpr(Int* lhs, Int* rhs) {
275 if (rhs == nullptr) {
276 return lhs;
277 } else if (lhs == nullptr) {
278 return rhs;
279 } else if (lhs->isConst()) {
280 return addExpr(rhs, lhs->value().value());
281 } else if (rhs->isConst()) {
282 return addExpr(lhs, rhs->value().value());
283 } else {
284 return IrBuilder::addExpr(lhs, rhs);
285 }
286}
287
288Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) {
289 TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr);
290 if (lhs == nullptr || lhs->isZeroInt()) {
291 return rhs;
292 } else if (rhs == nullptr || rhs->isZeroInt()) {
293 return lhs;
294 }
295 auto lhs_int = dynamic_cast<Int*>(lhs);
296 auto rhs_int = dynamic_cast<Int*>(rhs);
297 if (lhs_int != nullptr && rhs_int != nullptr) {
298 return addExpr(lhs_int, rhs_int);
299 } else {
300 return IrBuilder::addExpr(lhs, rhs);
301 }
302}
303
304Val* SimplifyingIrBuilder::addExpr(Val* lhs, Int::ScalarType rhs) {
305 auto lhs_int = dynamic_cast<Int*>(lhs);
306 if (lhs_int != nullptr) {
307 return addExpr(lhs_int, rhs);
308 } else {
309 return addExpr(lhs, IrBuilder::create<Int>(rhs));
310 }
311}
312
313Val* SimplifyingIrBuilder::subExpr(Val* lhs, Val* rhs) {
314 return addExpr(lhs, negExpr(rhs));
315}
316
317Val* SimplifyingIrBuilder::mulExpr(Int* lhs, Int::ScalarType rhs) {
318 if (rhs == 0) {
319 return lhs->container()->zeroVal();
320 } else if (rhs == 1) {
321 return lhs;
322 } else if (lhs == nullptr) {
323 return IrBuilder::create<Int>(rhs);
324 } else if (lhs->isConst()) {
325 return IrBuilder::create<Int>(lhs->value().value() * rhs);
326 } else {
327 return IrBuilder::mulExpr(lhs, IrBuilder::create<Int>(rhs));
328 }
329}
330
331Val* SimplifyingIrBuilder::mulExpr(Val* lhs, Int::ScalarType rhs) {
332 auto lhs_int = dynamic_cast<Int*>(lhs);
333 if (lhs_int != nullptr) {
334 return mulExpr(lhs_int, rhs);
335 } else {
336 return IrBuilder::mulExpr(lhs, IrBuilder::create<Int>(rhs));
337 }
338}
339
340Val* SimplifyingIrBuilder::mulExpr(Int* lhs, Int* rhs) {
341 if (rhs == nullptr) {
342 return lhs;
343 } else if (lhs == nullptr) {
344 return rhs;
345 } else if (lhs->isConst()) {
346 return mulExpr(rhs, lhs->value().value());
347 } else if (rhs->isConst()) {
348 return mulExpr(lhs, rhs->value().value());
349 } else {
350 return IrBuilder::mulExpr(lhs, rhs);
351 }
352}
353
354Val* SimplifyingIrBuilder::mulExpr(Val* lhs, Val* rhs) {
355 TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr);
356 if (lhs == nullptr || lhs->isOneInt()) {
357 return rhs;
358 } else if (rhs == nullptr || rhs->isOneInt()) {
359 return lhs;
360 } else if (lhs->isZeroInt() || rhs->isZeroInt()) {
361 return lhs->container()->zeroVal();
362 }
363 auto lhs_int = dynamic_cast<Int*>(lhs);
364 auto rhs_int = dynamic_cast<Int*>(rhs);
365 if (lhs_int != nullptr && rhs_int != nullptr) {
366 return mulExpr(lhs_int, rhs_int);
367 } else {
368 return IrBuilder::mulExpr(lhs, rhs);
369 }
370}
371
372Val* SimplifyingIrBuilder::andExpr(Val* lhs, Val* rhs) {
373 TORCH_INTERNAL_ASSERT(!(lhs == nullptr && rhs == nullptr));
374
375 if (lhs == nullptr) {
376 return rhs;
377 } else if (rhs == nullptr) {
378 return lhs;
379 }
380
381 bool lhs_definitely_true = false;
382 bool lhs_definitely_false = false;
383 auto lhs_bool = dynamic_cast<Bool*>(lhs);
384 if (lhs_bool && lhs_bool->isConst()) {
385 lhs_definitely_true = lhs_bool->value().value();
386 lhs_definitely_false = !lhs_bool->value().value();
387 }
388 auto rhs_bool = dynamic_cast<Bool*>(rhs);
389 bool rhs_definitely_true = false;
390 bool rhs_definitely_false = false;
391 if (rhs_bool && rhs_bool->isConst()) {
392 rhs_definitely_true = rhs_bool->value().value();
393 rhs_definitely_false = !rhs_bool->value().value();
394 }
395
396 if (lhs_definitely_true && rhs_definitely_true) {
397 return FusionGuard::getCurFusion()->trueVal();
398 } else if (lhs_definitely_false || rhs_definitely_false) {
399 return FusionGuard::getCurFusion()->falseVal();
400 } else if (lhs_definitely_true) {
401 return rhs;
402 } else if (rhs_definitely_true) {
403 return lhs;
404 }
405
406 return IrBuilder::andExpr(lhs, rhs);
407}
408
409namespace {
410
411template <typename IrBuilderFunc, typename IntFunc>
412Val* minOrMaxExpr(
413 Int* lhs,
414 Int* rhs,
415 IrBuilderFunc ir_builder_func,
416 IntFunc int_func) {
417 if (rhs == nullptr) {
418 return lhs;
419 } else if (lhs == nullptr) {
420 return rhs;
421 } else if (lhs->isConst() && rhs->isConst()) {
422 return IrBuilder::create<Int>(
423 int_func(lhs->value().value(), rhs->value().value()));
424 } else {
425 return ir_builder_func(lhs, rhs);
426 }
427}
428
429template <typename IrBuilderFunc, typename IntFunc>
430Val* minOrMaxExpr(
431 Val* lhs,
432 Val* rhs,
433 IrBuilderFunc ir_builder_func,
434 IntFunc int_func) {
435 TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr);
436 if (lhs == nullptr) {
437 return rhs;
438 } else if (rhs == nullptr || lhs == rhs) {
439 return lhs;
440 }
441 auto lhs_int = dynamic_cast<Int*>(lhs);
442 auto rhs_int = dynamic_cast<Int*>(rhs);
443 if (lhs_int != nullptr && rhs_int != nullptr) {
444 return minOrMaxExpr(lhs_int, rhs_int, ir_builder_func, int_func);
445 } else {
446 return ir_builder_func(lhs, rhs);
447 }
448}
449
450} // namespace
451
452Val* SimplifyingIrBuilder::maxExpr(Val* lhs, Val* rhs) {
453 return minOrMaxExpr(
454 lhs,
455 rhs,
456 [](Val* lhs, Val* rhs) { return IrBuilder::maxExpr(lhs, rhs); },
457 [](int64_t lhs, int64_t rhs) { return std::max(lhs, rhs); });
458}
459
460Val* SimplifyingIrBuilder::minExpr(Val* lhs, Val* rhs) {
461 return minOrMaxExpr(
462 lhs,
463 rhs,
464 [](Val* lhs, Val* rhs) { return IrBuilder::minExpr(lhs, rhs); },
465 [](int64_t lhs, int64_t rhs) { return std::min(lhs, rhs); });
466}
467
468} // namespace cuda
469} // namespace fuser
470} // namespace jit
471} // namespace torch
472