1#include <arith.h>
2
3#include <c10/util/BFloat16.h>
4#include <c10/util/Exception.h>
5#include <c10/util/Half.h>
6#include <c10/util/irange.h>
7#include <ir_all_nodes.h>
8#include <ir_builder.h>
9#include <ir_iostream.h>
10#include <ir_utils.h>
11#include <type.h>
12#include <type_promotion.h>
13#include <cfloat>
14
15namespace torch {
16namespace jit {
17namespace fuser {
18namespace cuda {
19
20namespace {
21
22TensorView* maybe_broadcast_inner_to_rank(TensorView* t, size_t rank) {
23 size_t t_rank = TensorDomain::noReductions(t->getMaybeRFactorDomain()).size();
24
  // Broadcast t up to the given rank by prepending broadcast dimensions, so
  // that the existing dimensions of t align to the innermost positions.
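  // For example, with rank == 4 and a 2-D input, inner_bcast_dims becomes
  // {true, true, false, false}: two broadcast dimensions are added in front of
  // the two existing ones.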
26 if (t_rank < rank) {
27 const int num_bcast = static_cast<int>(rank - t_rank);
28 std::vector<bool> inner_bcast_dims(rank, false);
29 std::fill(
30 inner_bcast_dims.begin(), inner_bcast_dims.begin() + num_bcast, true);
31 t = broadcast(t, inner_bcast_dims);
32 }
33 return t;
34}
35
36Val* simplifiedInt(Val* val) {
37 TORCH_INTERNAL_ASSERT(
38 val->isConstInt(), "Expecting Const Int's only in this routine.");
39 if (val->as<Int>()->value().has_value()) {
40 return val;
41 }
42 return IrBuilder::create<Int>(val->evaluateInt());
43}
44
// If one size is nullptr, return the other. If both are symbolic, just return
// v1. If only one is concrete, prefer that one (simplified). If both are
// concrete, make sure they're the same size.
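// For example, promoteSize(i0 /*symbolic*/, 8) returns a constant Int 8, while
// promoteSize(4, 8) triggers an assertion because the concrete sizes differ.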
48Val* promoteSize(Val* v1, Val* v2) {
49 if (v1 == nullptr) {
50 TORCH_INTERNAL_ASSERT(
51 v2 == nullptr || v2->isAnInt(),
52 "Expecting Int's only in this routine.");
53 return v2;
54 }
55 if (v2 == nullptr) {
56 return v1;
57 }
58 TORCH_INTERNAL_ASSERT(
59 v1->isAnInt() && v2->isAnInt(), "Expecting Int's only in this routine.");
60
61 if (!v1->isConstInt() && !v2->isConstInt()) {
62 return v1;
63 } else if (v1->isConstInt() && v2->isConstInt()) {
64 TORCH_INTERNAL_ASSERT(
65 v1->evaluateInt() == v2->evaluateInt(),
        "Expected sizes of ",
67 v1->toString(),
68 " and ",
69 v2->toString(),
70 " to match but found ",
71 v1->evaluateInt(),
72 " and ",
73 v2->evaluateInt(),
74 ".");
75 return simplifiedInt(v1);
76 } else if (v1->isConstInt()) {
77 return simplifiedInt(v1);
78 }
79 return simplifiedInt(v2);
80}
81
// Returns a new scalar Val with the given ValType and DataType.
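// For example, newScalar(ValType::Scalar, DataType::Half) returns a Double,
// since reduced-precision floating-point scalars are represented as Double
// here.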
83Val* newScalar(ValType vtype, DataType dtype) {
84 switch (vtype) {
85 case (ValType::NamedScalar):
86 case (ValType::Scalar):
87 switch (dtype) {
88 case DataType::Bool:
89 return IrBuilder::create<Bool>();
90 case DataType::Double:
91 case DataType::Float:
92 case DataType::Half:
93 case DataType::BFloat16:
94 return IrBuilder::create<Double>();
95 case DataType::Int32:
96 case DataType::Int:
97 return IrBuilder::create<Int>();
98 case DataType::ComplexFloat:
99 case DataType::ComplexDouble:
100 return IrBuilder::create<ComplexDouble>();
101 default:
102 break;
103 }
104 default:
105 break;
106 }
107
108 TORCH_CHECK(
109 false,
110 "Cannot handle ValType: ",
111 vtype,
      " with DataType: ",
113 dtype,
114 " in newScalar.");
115}
116
117IterType promoteIterType(IterType type1, IterType type2) {
118 // Iteration: Default
119 // Reduction: Should not appear here
120 // Broadcast: Propagated only if type1 and type2 are Broadcast
121 // Gather: Converted to Iteration
  // Stride: Should not appear here
123 // VectorComponent: Converted to Iteration
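  //
  // For example, promoteIterType(Broadcast, Iteration) and
  // promoteIterType(Gather, Broadcast) both yield Iteration, while
  // promoteIterType(Broadcast, Broadcast) stays Broadcast.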
124
125 TORCH_INTERNAL_ASSERT(
126 type1 != IterType::Reduction && type1 != IterType::Stride,
127 "Invalid IterType: ",
      type1);
129 TORCH_INTERNAL_ASSERT(
130 type2 != IterType::Reduction && type2 != IterType::Stride,
131 "Invalid IterType: ",
132 type2);
133
134 // Do not propagate Gather and VectorComponent
135 if (type1 == IterType::Gather || type1 == IterType::VectorComponent) {
136 type1 = IterType::Iteration;
137 }
138 if (type2 == IterType::Gather || type2 == IterType::VectorComponent) {
139 type2 = IterType::Iteration;
140 }
141
142 // At this point, type1 and type2 must be either Iteration or
143 // Broadcast
144 TORCH_INTERNAL_ASSERT(
145 type1 == IterType::Iteration || type1 == IterType::Broadcast,
146 "Unexpected IterType: ",
147 type1);
148 TORCH_INTERNAL_ASSERT(
149 type2 == IterType::Iteration || type2 == IterType::Broadcast,
150 "Unexpected IterType: ",
151 type2);
152
153 if (type1 == IterType::Broadcast) {
154 return type2;
155 } else {
156 return type1;
157 }
158}
159
160TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
161 std::vector<TensorView*> tvs;
162 for (auto val : vals) {
163 if (val->getValType() == ValType::TensorView) {
164 tvs.push_back(val->as<TensorView>());
165 }
166 }
167 TORCH_CHECK(
168 !tvs.empty(),
169 "Tried to create new output TensorView but received empty list.");
170
171 std::vector<IterDomain*> out_domain(
172 TensorDomain::noReductions(tvs[0]->getMaybeRFactorDomain()).size(),
173 nullptr);
174
175 // For the start and stop offsets, take the maximum of input axes.
176 // For now, the offsets of both start and stop are always integer
  // constants, so we can statically compute them. It is unclear
178 // whether we would need to support dynamic offsetting, e.g.,
179 // shifting by a dynamic offset.
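  // For example, if one input axis has a start offset of 1 and another has 0,
  // the output axis gets the larger offset, 1.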
180 std::vector<int64_t> start_offsets(out_domain.size(), 0);
181 std::vector<int64_t> stop_offsets(out_domain.size(), 0);
182 std::vector<Val*> extent_vals(out_domain.size(), nullptr);
183 std::vector<Val*> expanded_extent_vals(out_domain.size(), nullptr);
184 std::vector<c10::optional<IterType>> iter_types(
185 out_domain.size(), c10::nullopt);
186
187 for (auto tv : tvs) {
188 auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
189 TORCH_INTERNAL_ASSERT(
190 dom.size() == out_domain.size(),
        "Invalid tensor view found while producing an output; it has ",
192 dom.size(),
193 " dimensions but expected ",
194 out_domain.size());
195 for (const auto i : c10::irange(dom.size())) {
196 if (dom[i]->isBroadcast()) {
197 if (dom[i]->hasExpandedExtent()) {
198 expanded_extent_vals[i] =
199 promoteSize(expanded_extent_vals[i], dom[i]->expandedExtent());
200 }
201 continue;
202 }
203 extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
204 if (iter_types[i].has_value()) {
205 iter_types[i] =
206 promoteIterType(iter_types[i].value(), dom[i]->getIterType());
207 } else {
208 iter_types[i] = dom[i]->getIterType();
209 }
210
211 auto start_offset = dom[i]->start()->as<Int>();
212 auto stop_offset = dom[i]->stopOffset()->as<Int>();
213 // Currently, start is always constant
214 TORCH_INTERNAL_ASSERT(
215 start_offset->isConstInt(),
216 "Invalid IterDomain start: ",
217 start_offset);
218 TORCH_INTERNAL_ASSERT(
219 stop_offset->isConstInt(),
220 "Invalid IterDomain stop offset: ",
221 stop_offset);
222 start_offsets[i] =
223 std::max(start_offsets[i], start_offset->evaluateInt());
224 stop_offsets[i] = std::max(stop_offsets[i], stop_offset->evaluateInt());
225 }
226 }
227 for (const auto dim_i : c10::irange(out_domain.size())) {
228 if (extent_vals[dim_i] != nullptr) {
229 TORCH_INTERNAL_ASSERT(
230 iter_types[dim_i].has_value(),
231 "Could not deduce iter type for new tensor view.");
232 out_domain[dim_i] =
233 IterDomainBuilder(
234 IrBuilder::create<Int>(start_offsets[dim_i]), extent_vals[dim_i])
235 .stop_offset(IrBuilder::create<Int>(stop_offsets[dim_i]))
236 .iter_type(iter_types[dim_i].value())
237 .build();
238 } else {
239 out_domain[dim_i] = IterDomainBuilder(
240 FusionGuard::getCurFusion()->zeroVal(),
241 FusionGuard::getCurFusion()->oneVal())
242 .expanded_extent(expanded_extent_vals[dim_i])
243 .iter_type(IterType::Broadcast)
244 .build();
245 }
246 }
247
248 return IrBuilder::create<TensorView>(
249 IrBuilder::create<TensorDomain>(
250 out_domain, std::vector<bool>(out_domain.size(), true)),
251 dtype);
252}
253
254std::vector<Val*> maybeBroadcast(const std::vector<Val*>& vals) {
255 std::vector<Val*> out_vals(vals.size(), nullptr);
256 size_t n_dims = 0;
257 for (auto val : vals) {
258 if (val->getValType().value() == ValType::TensorView) {
259 n_dims = std::max(
260 n_dims,
261 TensorDomain::noReductions(
262 val->as<TensorView>()->getMaybeRFactorDomain())
263 .size());
264 }
265 }
266
267 for (const auto i : c10::irange(vals.size())) {
268 if (vals[i]->getValType().value() == ValType::TensorView) {
269 auto tv = vals[i]->as<TensorView>();
270 out_vals[i] = maybe_broadcast_inner_to_rank(tv, n_dims);
271 } else {
272 out_vals[i] = vals[i];
273 }
274 }
275 return out_vals;
276}
277
278Val* newValLike(Val* val, DataType dtype) {
279 TORCH_CHECK(
280 dtype != DataType::Null, "Invalid datatype provided for new value.");
281
282 const ValType vtype = val->getValType().value();
283
284 if (vtype == ValType::TensorView)
285 return newOutputTV({val}, dtype);
286
287 return newScalar(vtype, dtype);
288}
289
290// returns the minimum init value for reduction:
291// -inf for floating type;
292// lowest value for integer type;
293// false for bool.
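// For example, getMinimumValue(DataType::Float) is -inf (held as a Double) and
// getMinimumValue(DataType::Int32) is std::numeric_limits<int32_t>::lowest().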
294Val* getMinimumValue(DataType v) {
295 switch (v) {
296 case (DataType::Double):
297 return IrBuilder::create<Double>(
298 -std::numeric_limits<double>::infinity());
299 break;
300 case (DataType::Float):
301 return IrBuilder::create<Double>(-std::numeric_limits<float>::infinity());
302 break;
303 case (DataType::Half):
304 return IrBuilder::create<Double>(
305 static_cast<double>(-std::numeric_limits<c10::Half>::infinity()));
306 break;
307 case DataType::BFloat16:
308 return IrBuilder::create<Double>(
309 static_cast<double>(-std::numeric_limits<c10::BFloat16>::infinity()));
310 break;
311 case (DataType::Int):
312 return IrBuilder::create<Int>(std::numeric_limits<int64_t>::lowest());
313 break;
314 case (DataType::Int32):
315 return IrBuilder::create<Int>(std::numeric_limits<int32_t>::lowest());
316 break;
317 case (DataType::Bool):
318 return IrBuilder::create<Bool>(false);
319 break;
320 default:
321 TORCH_CHECK(
322 false, "Could not generate a min op for tensor with type: ", v);
323 }
324 return nullptr;
325}
326
327// returns the maximum init value for reduction:
328// inf for floating type;
329// highest value for integer type;
330// true for bool.
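// For example, getMaximumValue(DataType::Float) is +inf (held as a Double) and
// getMaximumValue(DataType::Int32) is std::numeric_limits<int32_t>::max().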
331Val* getMaximumValue(DataType v) {
332 switch (v) {
333 case (DataType::Double):
334 return IrBuilder::create<Double>(std::numeric_limits<double>::infinity());
335 break;
336 case (DataType::Float):
337 return IrBuilder::create<Double>(std::numeric_limits<float>::infinity());
338 break;
339 case (DataType::Half):
340 return IrBuilder::create<Double>(
341 static_cast<double>(std::numeric_limits<c10::Half>::infinity()));
342 break;
343 case DataType::BFloat16:
344 return IrBuilder::create<Double>(
345 static_cast<double>(std::numeric_limits<c10::BFloat16>::infinity()));
346 break;
347 case (DataType::Int):
348 return IrBuilder::create<Int>(std::numeric_limits<int64_t>::max());
349 break;
350 case (DataType::Int32):
351 return IrBuilder::create<Int>(std::numeric_limits<int32_t>::max());
352 break;
353 case (DataType::Bool):
354 return IrBuilder::create<Bool>(true);
355 break;
356 default:
357 TORCH_CHECK(
358 false, "Could not generate a max op for tensor with type: ", v);
359 }
360 return nullptr;
361}
362
363} // namespace
364
365Val* castOp(DataType dtype, Val* v1) {
366 if (v1->getDataType().value() == dtype) {
367 return set(v1);
368 }
369
370 if (cast_func_str(std::make_pair(v1->getDataType().value(), dtype)) ==
371 c10::nullopt) {
372 TORCH_CHECK(
373 false,
374 "Illegal Cast value from DataType: ",
375 v1->getDataType().value(),
376 " to DataType: ",
377 dtype);
378 }
379
380 Val* out = newValLike(v1, dtype);
381 IrBuilder::create<UnaryOp>(UnaryOpType::Cast, out, v1);
382 return out;
383}
384
385TensorView* castOp(DataType dtype, TensorView* v1) {
386 return castOp(dtype, v1->as<Val>())->as<TensorView>();
387}
388
389Val* bitCastOp(DataType dtype, Val* v1) {
390 if (v1->getDataType().value() == dtype) {
391 return v1;
392 }
393
394 TORCH_CHECK(
395 dataTypeSize(v1->getDataType().value()) == dataTypeSize(dtype),
396 "BitCast only works for types of the same size");
397
398 Val* out = newValLike(v1, dtype);
399 IrBuilder::create<UnaryOp>(UnaryOpType::BitCast, out, v1);
400 return out;
401}
402
403TensorView* bitCastOp(DataType dtype, TensorView* v1) {
404 return bitCastOp(dtype, v1->as<Val>())->as<TensorView>();
405}
406
407Val* unaryOp(UnaryOpType type, Val* v1) {
408 TORCH_INTERNAL_ASSERT(
409 type != UnaryOpType::Address,
410 "The reference operator & is not accessible in the Fusion IR");
411 Val* out = newValLike(v1, v1->getDataType().value());
412 IrBuilder::create<UnaryOp>(type, out, v1);
413 return out;
414}
415
416TensorView* unaryOp(UnaryOpType type, TensorView* v1) {
417 return unaryOp(type, v1->as<Val>())->as<TensorView>();
418}
419
420Val* unaryIsOp(UnaryOpType type, Val* v) {
421 Val* out = newValLike(v, DataType::Bool);
422 IrBuilder::create<UnaryOp>(type, out, v);
423 return out;
424}
425
426TensorView* unaryIsOp(UnaryOpType type, TensorView* v) {
  return unaryIsOp(type, v->asVal())->as<TensorView>();
428}
429
430Val* unaryOp(UnaryOpType type, Val* v1, const TypePromotionConfig& config) {
431 auto cast_v1 = promoteValues(config, {v1}).front();
432 return unaryOp(type, cast_v1);
433}
434
435TensorView* unaryOp(
436 UnaryOpType type,
437 TensorView* v1,
438 const TypePromotionConfig& config) {
439 auto cast_v1 = promoteValues(config, {v1}).front();
440 return unaryOp(type, cast_v1)->as<TensorView>();
441}
442
443// TENSOR FACTORIES
444TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
445 auto n = shape.size();
446 auto out = TensorViewBuilder()
447 .ndims(n)
448 .dtype(dtype)
449 .contiguity(std::vector<bool>(n, true))
450 .shape(shape)
451 .build();
452 IrBuilder::create<RNGOp>(RNGOpType::Uniform, out, dtype);
453 return out;
454}
455
457TensorView* uniform(
458 const std::vector<Val*>& shape,
459 Val* low,
460 Val* high,
461 DataType dtype) {
462 auto n = shape.size();
463 auto out = TensorViewBuilder()
464 .ndims(n)
465 .dtype(dtype)
466 .contiguity(std::vector<bool>(n, true))
467 .shape(shape)
468 .build();
469 IrBuilder::create<RNGOp>(
470 RNGOpType::UniformRange, out, dtype, std::vector<Val*>{low, high});
471 return out;
472}
473
474TensorView* rand_like(TensorView* tv) {
475 TORCH_CHECK(
476 isFloatingPointType(tv->dtype()),
477 "input must have floating point type, but got ",
478 tv->dtype());
479 std::vector<Val*> shape;
480 auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
481 shape.reserve(dom.size());
482 for (auto id : dom) {
483 shape.emplace_back(id->getMaybeExpandedExtent());
484 }
485 return rand(shape, tv->dtype());
486}
487
488Val* rand_like(Val* v) {
489 return rand_like(v->as<TensorView>());
490}
491
492TensorView* full(
493 const std::vector<Val*>& shape,
494 Val* fill_value,
495 DataType dtype) {
496 auto n = shape.size();
497 auto out = TensorViewBuilder()
498 .ndims(n)
499 .dtype(dtype)
500 .contiguity(std::vector<bool>(n, true))
501 .shape(shape)
502 .build();
503 IrBuilder::create<FullOp>(out, fill_value, dtype);
504 return out;
505}
506
507TensorView* full_like(TensorView* tv, Val* fill_value) {
508 std::vector<Val*> shape;
509 auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
510 shape.reserve(dom.size());
511 for (auto id : dom) {
512 shape.emplace_back(id->getMaybeExpandedExtent());
513 }
514 return full(shape, fill_value, tv->dtype());
515}
516
517Val* full_like(Val* v, Val* fill_value) {
518 return full_like(v->as<TensorView>(), fill_value);
519}
520
521TensorView* zeros(const std::vector<Val*>& shape, DataType dtype) {
522 return full(shape, FusionGuard::getCurFusion()->zeroVal(), dtype);
523}
524
525TensorView* zeros_like(TensorView* tv) {
526 return full_like(tv, FusionGuard::getCurFusion()->zeroVal());
527}
528
529Val* zeros_like(Val* v) {
530 return zeros_like(v->as<TensorView>());
531}
532
533TensorView* ones(const std::vector<Val*>& shape, DataType dtype) {
534 return full(shape, FusionGuard::getCurFusion()->oneVal(), dtype);
535}
536
537TensorView* ones_like(TensorView* tv) {
538 return full_like(tv, FusionGuard::getCurFusion()->oneVal());
539}
540
541Val* ones_like(Val* v) {
542 return ones_like(v->as<TensorView>());
543}
544
545TensorView* arange(Val* end, DataType dtype) {
546 return arange(FusionGuard::getCurFusion()->zeroVal(), end, dtype);
547}
548
549TensorView* arange(Val* start, Val* end, DataType dtype) {
550 return arange(start, end, FusionGuard::getCurFusion()->oneVal(), dtype);
551}
552
553TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
554 if (isIntegralType(dtype)) {
555 start = castOp(DataType::Int, start);
556 end = castOp(DataType::Int, end);
557 step = castOp(DataType::Int, step);
558 } else if (isFloatingPointType(dtype)) {
559 start = castOp(DataType::Double, start);
560 end = castOp(DataType::Double, end);
561 step = castOp(DataType::Double, step);
562 }
563 // Make sure no negative value is passed to ceilDiv as the device
564 // implementation of ceilDiv assumes positive inputs
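  // For example, arange(0, 10, 3) has size ceilDiv(|10 - 0|, |3|) = 4 and
  // produces {0, 3, 6, 9}.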
565 auto size = castOp(DataType::Int, ceilDiv(abs(sub(end, start)), abs(step)));
566 auto out = TensorViewBuilder()
567 .ndims(1)
568 .dtype(dtype)
569 .contiguity({true})
570 .shape({size})
571 .build();
572 IrBuilder::create<ARangeOp>(out, start, end, step, dtype);
573 return out;
574}
575
576TensorView* eye(Val* rows, Val* cols, DataType dtype) {
577 TORCH_CHECK(rows->getDataType() == DataType::Int, "rows must have type Int");
578 TORCH_CHECK(cols->getDataType() == DataType::Int, "cols must have type Int");
579 auto out = TensorViewBuilder()
580 .ndims(2)
581 .dtype(dtype)
582 .contiguity({true, true})
583 .shape(std::vector<Val*>{rows, cols})
584 .build();
585 IrBuilder::create<EyeOp>(out, dtype);
586 return out;
587}
588
589TensorView* eye(Val* size, DataType dtype) {
590 return eye(size, size, dtype);
591}
592
593// UNARY OPERATIONS
594
595#define NVFUSER_DEFINE_UNARY_OP(op_name, op_type) \
596 Val* op_name(Val* v) { \
597 return unaryOp(UnaryOpType::op_type, v); \
598 } \
599 TensorView* op_name(TensorView* tv) { \
600 return unaryOp(UnaryOpType::op_type, tv); \
601 }
602
603NVFUSER_DEFINE_UNARY_OP(set, Set)
604NVFUSER_DEFINE_UNARY_OP(ceil, Ceil)
605NVFUSER_DEFINE_UNARY_OP(floor, Floor)
606NVFUSER_DEFINE_UNARY_OP(frac, Frac)
607NVFUSER_DEFINE_UNARY_OP(neg, Neg)
608NVFUSER_DEFINE_UNARY_OP(relu, Relu)
609NVFUSER_DEFINE_UNARY_OP(round, Round)
610NVFUSER_DEFINE_UNARY_OP(silu, Silu)
611NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
612NVFUSER_DEFINE_UNARY_OP(print, Print)
613#undef NVFUSER_DEFINE_UNARY_OP
614
615Val* bitwise_not(Val* v) {
616 TORCH_CHECK(
617 isIntegralType(v->dtype()) || v->dtype() == DataType::Bool,
618 "input must have integral or boolean type, but got ",
619 v->dtype());
620 return unaryOp(UnaryOpType::Not, v);
621}
622
623TensorView* bitwise_not(TensorView* tv) {
624 TORCH_CHECK(
625 isIntegralType(tv->dtype()) || tv->dtype() == DataType::Bool,
626 "input must have integral or boolean type, but got ",
627 tv->dtype());
628 return unaryOp(UnaryOpType::Not, tv);
629}
630
// The output of abs(complex_tensor) is real-valued
632Val* abs(Val* v) {
633 if (v->getDataType() == DataType::ComplexDouble) {
634 Val* out = newValLike(v, DataType::Double);
635 IrBuilder::create<UnaryOp>(UnaryOpType::Abs, out, v);
636 return out;
637 }
638 if (v->getDataType() == DataType::ComplexFloat) {
639 Val* out = newValLike(v, DataType::Float);
640 IrBuilder::create<UnaryOp>(UnaryOpType::Abs, out, v);
641 return out;
642 }
643 return unaryOp(UnaryOpType::Abs, v);
644}
645
646TensorView* abs(TensorView* tv) {
647 return abs(tv->as<Val>())->as<TensorView>();
648}
649
// The output of real(complex_tensor) is real-valued
651Val* real(Val* v) {
652 if (v->getDataType() == DataType::ComplexDouble) {
653 Val* out = newValLike(v, DataType::Double);
654 IrBuilder::create<UnaryOp>(UnaryOpType::Real, out, v);
655 return out;
656 }
657 if (v->getDataType() == DataType::ComplexFloat) {
658 Val* out = newValLike(v, DataType::Float);
659 IrBuilder::create<UnaryOp>(UnaryOpType::Real, out, v);
660 return out;
661 }
662 // We use UnaryOpType::Set instead of UnaryOpType::Real to support non-complex
663 // tensors
664 return unaryOp(UnaryOpType::Set, v);
665}
666
667TensorView* real(TensorView* tv) {
668 return real(tv->as<Val>())->as<TensorView>();
669}
670
// The output of imag(complex_tensor) is real-valued
672Val* imag(Val* v) {
673 if (v->getDataType() == DataType::ComplexDouble) {
674 Val* out = newValLike(v, DataType::Double);
675 IrBuilder::create<UnaryOp>(UnaryOpType::Imag, out, v);
676 return out;
677 }
678 if (v->getDataType() == DataType::ComplexFloat) {
679 Val* out = newValLike(v, DataType::Float);
680 IrBuilder::create<UnaryOp>(UnaryOpType::Imag, out, v);
681 return out;
682 }
683 TORCH_CHECK(false, "imag not supported for non-complex tensors");
684}
685
686TensorView* imag(TensorView* tv) {
687 return imag(tv->as<Val>())->as<TensorView>();
688}
689
690// UNARY FLOAT CAST OPERATIONS
691
692#define NVFUSER_DEFINE_UNARY_FLOAT_OP(op_name, op_type) \
693 Val* op_name(Val* v) { \
694 return unaryOp(UnaryOpType::op_type, v, TypePromotion::float_op_config); \
695 } \
696 TensorView* op_name(TensorView* tv) { \
697 return unaryOp(UnaryOpType::op_type, tv, TypePromotion::float_op_config); \
698 }
699
700NVFUSER_DEFINE_UNARY_FLOAT_OP(acos, Acos)
701NVFUSER_DEFINE_UNARY_FLOAT_OP(asin, Asin)
702NVFUSER_DEFINE_UNARY_FLOAT_OP(atan, Atan)
703NVFUSER_DEFINE_UNARY_FLOAT_OP(atanh, Atanh)
704NVFUSER_DEFINE_UNARY_FLOAT_OP(cos, Cos)
705NVFUSER_DEFINE_UNARY_FLOAT_OP(cosh, Cosh)
706NVFUSER_DEFINE_UNARY_FLOAT_OP(exp, Exp)
707NVFUSER_DEFINE_UNARY_FLOAT_OP(expm1, Expm1)
708NVFUSER_DEFINE_UNARY_FLOAT_OP(erf, Erf)
709NVFUSER_DEFINE_UNARY_FLOAT_OP(erfc, Erfc)
710NVFUSER_DEFINE_UNARY_FLOAT_OP(lgamma, Lgamma)
711NVFUSER_DEFINE_UNARY_FLOAT_OP(log, Log)
712NVFUSER_DEFINE_UNARY_FLOAT_OP(log10, Log10)
713NVFUSER_DEFINE_UNARY_FLOAT_OP(log1p, Log1p)
714NVFUSER_DEFINE_UNARY_FLOAT_OP(log2, Log2)
715NVFUSER_DEFINE_UNARY_FLOAT_OP(reciprocal, Reciprocal)
716NVFUSER_DEFINE_UNARY_FLOAT_OP(rsqrt, Rsqrt)
717NVFUSER_DEFINE_UNARY_FLOAT_OP(sigmoid, Sigmoid)
718NVFUSER_DEFINE_UNARY_FLOAT_OP(sin, Sin)
719NVFUSER_DEFINE_UNARY_FLOAT_OP(sinh, Sinh)
720NVFUSER_DEFINE_UNARY_FLOAT_OP(sqrt, Sqrt)
721NVFUSER_DEFINE_UNARY_FLOAT_OP(tan, Tan)
722NVFUSER_DEFINE_UNARY_FLOAT_OP(tanh, Tanh)
723#undef NVFUSER_DEFINE_UNARY_FLOAT_OP
724
725#define NVFUSER_DEFINE_UNARY_IS_OP(op_name, op_type) \
726 Val* op_name(Val* v) { \
727 return unaryIsOp(UnaryOpType::op_type, v); \
728 } \
729 TensorView* op_name(TensorView* tv) { \
730 return unaryIsOp(UnaryOpType::op_type, tv); \
731 }
732
733NVFUSER_DEFINE_UNARY_IS_OP(isfinite, IsFinite)
734NVFUSER_DEFINE_UNARY_IS_OP(isinf, IsInf)
735NVFUSER_DEFINE_UNARY_IS_OP(isnan, IsNan)
736NVFUSER_DEFINE_UNARY_IS_OP(isneginf, IsNegInf)
737NVFUSER_DEFINE_UNARY_IS_OP(isposinf, IsPosInf)
738NVFUSER_DEFINE_UNARY_IS_OP(isreal, IsReal)
739#undef NVFUSER_DEFINE_UNARY_IS_OP
740
741// BINARY OPERATIONS
742
743namespace {
744// Helper function to reduce repetitive code
745template <typename T1, typename T2>
746TensorView* arithOpOverloads(Val* (*func)(Val*, Val*), T1* v1, T2* v2) {
747 Val* out = func(v1->template as<Val>(), v2->template as<Val>());
748 TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
749 return out->as<TensorView>();
750}
751
752template <typename T1, typename T2>
753TensorView* arithOpOverloads(
754 BinaryOpType type,
755 T1* v1,
756 T2* v2,
757 DataType common_dtype) {
758 Val* out = binaryOp(
759 type, v1->template as<Val>(), v2->template as<Val>(), common_dtype);
760 TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
761 return out->as<TensorView>();
762}
763
764template <typename T1, typename T2, typename T3>
765TensorView* arithOpOverloads(
766 Val* (*func)(Val*, Val*, Val*),
767 T1* v1,
768 T2* v2,
769 T3* v3) {
770 auto vals = maybeBroadcast({v1, v2, v3});
771 Val* out = func(
772 vals[0]->template as<Val>(),
773 vals[1]->template as<Val>(),
774 vals[2]->template as<Val>());
775 TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
776 return out->as<TensorView>();
777}
778
779template <typename T1, typename T2, typename T3, typename T4>
780TensorView* arithOpOverloads(
781 Val* (*func)(Val*, Val*, Val*, Val*),
782 T1* v1,
783 T2* v2,
784 T3* v3,
785 T4* v4) {
786 auto vals = maybeBroadcast({v1, v2, v3, v4});
787 Val* out = func(
788 vals[0]->template as<Val>(),
789 vals[1]->template as<Val>(),
790 vals[2]->template as<Val>(),
791 vals[3]->template as<Val>());
792 TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
793 return out->as<TensorView>();
794}
795
796// Output type promotion logic for binary operators
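// For example, a logical op such as Eq always yields Bool, while Add on Int
// and Double operands yields Double.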
797DataType getOutputType(
798 BinaryOpType op_type,
799 Val* v1,
800 Val* v2,
801 DataType common_dtype) {
802 if (isLogicalOp(op_type)) {
803 return DataType::Bool;
804 } else if (common_dtype == DataType::Null) {
805 return promote_type(v1->getDataType().value(), v2->getDataType().value());
806 } else {
807 return common_dtype;
808 }
809}
810
811} // namespace
812
813Val* binaryOp(BinaryOpType type, Val* v1, Val* v2, DataType common_dtype) {
814 const auto out_dtype = getOutputType(type, v1, v2, common_dtype);
815 const auto out_vtype =
816 promote_type(v1->getValType().value(), v2->getValType().value());
817 auto vals = maybeBroadcast({v1, v2});
818 Val* out = nullptr;
819 if (out_vtype == ValType::TensorView) {
820 out = newOutputTV(vals, out_dtype);
821 } else {
822 out = newScalar(out_vtype, out_dtype);
823 }
824 IrBuilder::create<BinaryOp>(type, out, vals[0], vals[1]);
825 return out;
826}
827
828TensorView* binaryOp(
829 BinaryOpType type,
830 TensorView* v1,
831 Val* v2,
832 DataType common_dtype) {
833 return arithOpOverloads(type, v1, v2, common_dtype);
834}
835
836TensorView* binaryOp(
837 BinaryOpType type,
838 Val* v1,
839 TensorView* v2,
840 DataType common_dtype) {
841 return arithOpOverloads(type, v1, v2, common_dtype);
842}
843
844TensorView* binaryOp(
845 BinaryOpType type,
846 TensorView* v1,
847 TensorView* v2,
848 DataType common_dtype) {
849 return arithOpOverloads(type, v1, v2, common_dtype);
850}
851
852Val* binaryOp(
853 BinaryOpType type,
854 Val* v1,
855 Val* v2,
856 const TypePromotionConfig& config) {
857 std::vector<Val*> operands = {v1, v2};
858 auto common_dtype = computeTypes(config, operands);
859 auto cast_values = promoteValues(operands, common_dtype);
860 return binaryOp(type, cast_values.front(), cast_values.back(), common_dtype);
861}
862
863TensorView* binaryOp(
864 BinaryOpType type,
865 TensorView* v1,
866 Val* v2,
867 const TypePromotionConfig& config) {
868 std::vector<Val*> operands = {v1, v2};
869 auto common_dtype = computeTypes(config, operands);
870 auto cast_values = promoteValues(operands, common_dtype);
871 return binaryOp(
872 type,
873 cast_values.front()->as<TensorView>(),
874 cast_values.back(),
875 common_dtype);
876}
877
878TensorView* binaryOp(
879 BinaryOpType type,
880 Val* v1,
881 TensorView* v2,
882 const TypePromotionConfig& config) {
883 std::vector<Val*> operands = {v1, v2};
884 auto common_dtype = computeTypes(config, operands);
885 auto cast_values = promoteValues(operands, common_dtype);
886 return binaryOp(
887 type,
888 cast_values.front(),
889 cast_values.back()->as<TensorView>(),
890 common_dtype);
891}
892
893TensorView* binaryOp(
894 BinaryOpType type,
895 TensorView* v1,
896 TensorView* v2,
897 const TypePromotionConfig& config) {
898 std::vector<Val*> operands = {v1, v2};
899 auto common_dtype = computeTypes(config, operands);
900 auto cast_values = promoteValues(operands, common_dtype);
901 return binaryOp(
902 type,
903 cast_values.front()->as<TensorView>(),
904 cast_values.back()->as<TensorView>(),
905 common_dtype);
906}
907
908#define NVFUSER_DEFINE_BINARY_FLOAT_OP(op_name, op_type) \
909 Val* op_name(Val* v1, Val* v2) { \
910 return binaryOp( \
911 BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \
912 } \
913 TensorView* op_name(TensorView* v1, Val* v2) { \
914 return binaryOp( \
915 BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \
916 } \
917 TensorView* op_name(Val* v1, TensorView* v2) { \
918 return binaryOp( \
919 BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \
920 } \
921 TensorView* op_name(TensorView* v1, TensorView* v2) { \
922 return binaryOp( \
923 BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \
924 }
925
926NVFUSER_DEFINE_BINARY_FLOAT_OP(div, Div)
927NVFUSER_DEFINE_BINARY_FLOAT_OP(atan2, Atan2)
928#undef NVFUSER_DEFINE_BINARY_FLOAT_OP
929
930#define NVFUSER_DEFINE_BINARY_CAST_OP(op_name, op_type) \
931 Val* op_name(Val* v1, Val* v2) { \
932 return binaryOp( \
933 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
934 } \
935 TensorView* op_name(TensorView* v1, Val* v2) { \
936 return binaryOp( \
937 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
938 } \
939 TensorView* op_name(Val* v1, TensorView* v2) { \
940 return binaryOp( \
941 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
942 } \
943 TensorView* op_name(TensorView* v1, TensorView* v2) { \
944 return binaryOp( \
945 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
946 }
947
948// Integer binary ops
949NVFUSER_DEFINE_BINARY_CAST_OP(mod, Mod)
950NVFUSER_DEFINE_BINARY_CAST_OP(ceilDiv, CeilDiv)
951NVFUSER_DEFINE_BINARY_CAST_OP(add, Add)
952NVFUSER_DEFINE_BINARY_CAST_OP(fmod, Fmod)
953NVFUSER_DEFINE_BINARY_CAST_OP(mul, Mul)
954NVFUSER_DEFINE_BINARY_CAST_OP(pow, Pow)
955NVFUSER_DEFINE_BINARY_CAST_OP(remainder, Remainder)
956NVFUSER_DEFINE_BINARY_CAST_OP(sub, Sub)
957#undef NVFUSER_DEFINE_BINARY_CAST_OP
958
959#define NVFUSER_DEFINE_BITWISE_OP(op_name, op_type) \
960 Val* op_name(Val* v1, Val* v2) { \
961 TORCH_CHECK( \
962 (isIntegralType(v1->dtype()) || v1->dtype() == DataType::Bool) && \
963 (isIntegralType(v2->dtype()) || v2->dtype() == DataType::Bool), \
964 "input must have integral or boolean type, but got ", \
965 v1->dtype(), \
966 " and ", \
967 v2->dtype()); \
968 return binaryOp( \
969 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
970 } \
971 TensorView* op_name(TensorView* v1, Val* v2) { \
972 TORCH_CHECK( \
973 (isIntegralType(v1->dtype()) || v1->dtype() == DataType::Bool) && \
974 (isIntegralType(v2->dtype()) || v2->dtype() == DataType::Bool), \
975 "input must have integral or boolean type, but got ", \
976 v1->dtype(), \
977 " and ", \
978 v2->dtype()); \
979 return binaryOp( \
980 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
981 } \
982 TensorView* op_name(Val* v1, TensorView* v2) { \
983 TORCH_CHECK( \
984 (isIntegralType(v1->dtype()) || v1->dtype() == DataType::Bool) && \
985 (isIntegralType(v2->dtype()) || v2->dtype() == DataType::Bool), \
986 "input must have integral or boolean type, but got ", \
987 v1->dtype(), \
988 " and ", \
989 v2->dtype()); \
990 return binaryOp( \
991 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
992 } \
993 TensorView* op_name(TensorView* v1, TensorView* v2) { \
994 TORCH_CHECK( \
995 (isIntegralType(v1->dtype()) || v1->dtype() == DataType::Bool) && \
996 (isIntegralType(v2->dtype()) || v2->dtype() == DataType::Bool), \
997 "input must have integral or boolean type, but got ", \
998 v1->dtype(), \
999 " and ", \
1000 v2->dtype()); \
1001 return binaryOp( \
1002 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
1003 }
1004
1005NVFUSER_DEFINE_BITWISE_OP(bitwise_and, And)
1006NVFUSER_DEFINE_BITWISE_OP(bitwise_or, Or)
1007NVFUSER_DEFINE_BITWISE_OP(bitwise_xor, Xor)
1008#undef NVFUSER_DEFINE_BITWISE_OP
1009
1010#define NVFUSER_DEFINE_BITWISE_SHIFT_OP(op_name, op_type) \
1011 Val* op_name(Val* v1, Val* v2) { \
1012 TORCH_CHECK( \
1013 isIntegralType(v1->dtype()) && isIntegralType(v2->dtype()), \
1014 "input must have integral type, but got ", \
1015 v1->dtype(), \
1016 " and ", \
1017 v2->dtype()); \
1018 return binaryOp( \
1019 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
1020 } \
1021 TensorView* op_name(TensorView* v1, Val* v2) { \
1022 TORCH_CHECK( \
1023 isIntegralType(v1->dtype()) && isIntegralType(v2->dtype()), \
1024 "input must have integral type, but got ", \
1025 v1->dtype(), \
1026 " and ", \
1027 v2->dtype()); \
1028 return binaryOp( \
1029 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
1030 } \
1031 TensorView* op_name(Val* v1, TensorView* v2) { \
1032 TORCH_CHECK( \
        isIntegralType(v1->dtype()) && isIntegralType(v2->dtype()),          \
1034 "input must have integral type, but got ", \
1035 v1->dtype(), \
1036 " and ", \
1037 v2->dtype()); \
1038 return binaryOp( \
1039 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
1040 } \
1041 TensorView* op_name(TensorView* v1, TensorView* v2) { \
1042 TORCH_CHECK( \
1043 isIntegralType(v1->dtype()) && isIntegralType(v2->dtype()), \
1044 "input must have integral type, but got ", \
1045 v1->dtype(), \
1046 " and ", \
1047 v2->dtype()); \
1048 return binaryOp( \
1049 BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \
1050 }
1051
1052NVFUSER_DEFINE_BITWISE_SHIFT_OP(bitwise_left_shift, Lshift)
1053NVFUSER_DEFINE_BITWISE_SHIFT_OP(bitwise_right_shift, Rshift)
1054#undef NVFUSER_DEFINE_BITWISE_SHIFT_OP
1055
1056#define NVFUSER_DEFINE_BINARY_COMPARE_OP(op_name, op_type) \
1057 Val* op_name(Val* v1, Val* v2) { \
1058 return binaryOp( \
1059 BinaryOpType::op_type, v1, v2, TypePromotion::comparison_op_config); \
1060 } \
1061 TensorView* op_name(TensorView* v1, Val* v2) { \
1062 return binaryOp( \
1063 BinaryOpType::op_type, v1, v2, TypePromotion::comparison_op_config); \
1064 } \
1065 TensorView* op_name(Val* v1, TensorView* v2) { \
1066 return binaryOp( \
1067 BinaryOpType::op_type, v1, v2, TypePromotion::comparison_op_config); \
1068 } \
1069 TensorView* op_name(TensorView* v1, TensorView* v2) { \
1070 return binaryOp( \
1071 BinaryOpType::op_type, v1, v2, TypePromotion::comparison_op_config); \
1072 }
1073
1074// Logical binary ops
1075NVFUSER_DEFINE_BINARY_COMPARE_OP(eq, Eq)
1076NVFUSER_DEFINE_BINARY_COMPARE_OP(ge, GE)
1077NVFUSER_DEFINE_BINARY_COMPARE_OP(gt, GT)
1078NVFUSER_DEFINE_BINARY_COMPARE_OP(le, LE)
1079NVFUSER_DEFINE_BINARY_COMPARE_OP(lt, LT)
1080NVFUSER_DEFINE_BINARY_COMPARE_OP(ne, NE)
1081#undef NVFUSER_DEFINE_BINARY_COMPARE_OP
1082
1083// REDUCTION OPERATIONS
1084
1085// TODO: How do we adjust this so we can reduce to a single scalar value?
1086static TensorView* newForReduction(
1087 TensorView* tv,
1088 const std::vector<unsigned int>& axes,
1089 DataType data_type = DataType::Null) {
1090 auto orig_domain = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
1091 std::set<unsigned int> axes_set(axes.begin(), axes.end());
1092
1093 std::vector<IterDomain*> new_domain;
1094
1095 TORCH_INTERNAL_ASSERT(
1096 !axes_set.empty(),
1097 "Asked for output of reduction, but no reduction axis provided.");
1098
1099 TORCH_INTERNAL_ASSERT(
1100 (*(axes_set.rbegin())) < orig_domain.size(),
1101 "Error setting up reduction, reduction axis (",
1102 *(axes_set.rbegin()),
1103 ") is outside nDims (",
1104 orig_domain.size(),
1105 "). Keep in mind reductions are relative to root domains, not modified views.");
1106
1107 auto axis_iter = axes_set.begin();
1108 for (const auto dim : c10::irange(orig_domain.size())) {
1109 bool isReduction = false;
1110 if (axis_iter != axes_set.end() && *axis_iter == dim) {
1111 isReduction = true;
1112 axis_iter++;
1113 }
1114
1115 const IterDomain* id = orig_domain[dim];
1116
1117 TORCH_CHECK(
1118 !(isReduction && id->isBroadcast() && !id->isImplicitBroadcast()),
        "Cannot reduce an axis that is marked as broadcast, as it has an undetermined size. Tried to reduce ID = ",
1120 id,
1121 " of tensor ",
1122 tv);
1123
1124 new_domain.push_back(
1125 IterDomainBuilder(id)
1126 // If the domain is being reduced, but it's coming in as an expanded
1127 // extent, we need to realize the expand.
1128 .extent(
1129 isReduction && id->hasExpandedExtent() ? id->expandedExtent()
1130 : id->extent())
1131 .resetSchedulingParams()
1132 .iter_type(isReduction ? IterType::Reduction : id->getIterType())
1133 .build());
1134 }
1135
1136 TensorDomain* td = IrBuilder::create<TensorDomain>(
1137 new_domain, std::vector<bool>(new_domain.size(), true));
1138
1139 data_type =
1140 data_type == DataType::Null ? tv->getDataType().value() : data_type;
1141 return IrBuilder::create<TensorView>(td, data_type);
1142}
1143
1144namespace {
1145
1146// PyTorch accepts reductions of zero-dimensional tensors, which are
1147// just ignored.
1148TensorView* reductionOpZeroDimTensor(TensorView* inp) {
1149 TORCH_INTERNAL_ASSERT(inp->domain()->noReductions().size() == 0);
1150 return set(inp);
1151}
1152
1153} // namespace
1154
1155TensorView* reductionOp(
1156 BinaryOpType reduction_op_type,
1157 const std::vector<int>& axes,
1158 Val* init,
1159 TensorView* tv,
1160 bool keep_dim /*=false*/,
1161 DataType dtype /* DataType::Null */) {
1162 TORCH_CHECK(
1163 init->isConstScalar(),
1164 "Cannot create a reduction operation where the initial value is not a const scalar.");
1165
1166 TORCH_CHECK(
1167 TensorDomain::sameAs(tv->getMaybeRFactorDomain(), tv->domain()->domain()),
      "Reducing a tensor after it has undergone transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
1169
1170 TORCH_CHECK(axes.size() > 0, "No reduction axis specified");
1171
1172 // PyTorch allows reduction of 0-dim tensors
1173 if (tv->domain()->noReductions().size() == 0) {
1174 return reductionOpZeroDimTensor(tv);
1175 }
1176
1177 std::vector<unsigned int> uint_axes;
1178 const int ndims = tv->domain()->noReductions().size();
1179 for (int axis : axes) {
1180 if (axis < 0) {
1181 axis += ndims;
1182 }
1183
1184 TORCH_CHECK(
1185 axis >= 0 && axis < ndims,
1186 "Reduction on invalid axis, received: ",
1187 axis,
1188 " however tensor view only has ",
1189 ndims,
1190 " non-reduction dims.");
1191
1192 uint_axes.push_back((unsigned int)axis);
1193 }
1194
1195 TensorView* out = newForReduction(tv, uint_axes, dtype);
1196 const auto out_type = out->getDataType().value();
1197 const auto init_type = init->getDataType().value();
1198 TORCH_CHECK(
1199 (isFloatingPointType(out_type) && isFloatingPointType(init_type)) ||
1200 (isComplexType(out_type) && isComplexType(init_type)) ||
1201 (isIntegralType(out_type) && isIntegralType(init_type)) ||
1202 (isBooleanType(out_type) && isBooleanType(init_type)),
1203 "Types should match for reduction ops but received: ",
1204 out_type,
1205 " and ",
1206 init_type);
1207 IrBuilder::create<ReductionOp>(reduction_op_type, init, out, tv);
1208
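  // With keep_dim, the reduced axes are re-inserted as size-1 broadcast
  // dimensions, e.g. a [M, N] input reduced over axis 1 yields shape [M, 1]
  // (broadcast) rather than [M].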
1209 if (keep_dim) {
1210 auto tv_root = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
1211 std::vector<bool> is_broadcast(tv_root.size(), false);
1212 for (auto axis : uint_axes) {
1213 is_broadcast.at(axis) = true;
1214 }
1215 out = broadcast(out, is_broadcast);
1216 }
1217 return out;
1218}
1219
1220TensorView* sum(
1221 TensorView* v1,
1222 const std::vector<int>& axes,
1223 bool keep_dim /*=false*/,
1224 DataType dtype /* DataType::Null */) {
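  // When no dtype is given, boolean and integral inputs are accumulated as Int
  // (int64), mirroring PyTorch's default promotion for sum.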
1225 if (dtype == DataType::Null) {
1226 auto initial_v1_dtype = v1->getDataType().value();
1227 if (isBooleanType(initial_v1_dtype) || isIntegralType(initial_v1_dtype)) {
1228 dtype = DataType::Int;
1229 }
1230 }
1231
1232 // Cast input tensor to dtype before the operation is performed
1233 if (dtype != DataType::Null) {
1234 v1 = optionalCastStrict(dtype, v1)->as<TensorView>();
1235 }
1236
1237 Val* init = nullptr;
1238 auto v1_dtype = v1->getDataType().value();
1239 if (isFloatingPointType(v1_dtype)) {
1240 init = IrBuilder::create<Double>(0.0);
1241 } else if (isComplexType(v1_dtype)) {
1242 init = IrBuilder::create<ComplexDouble>(c10::complex<double>(0.0, 0.0));
1243 } else if (isIntegralType(v1_dtype)) {
1244 init = FusionGuard::getCurFusion()->zeroVal();
1245 } else if (isBooleanType(v1_dtype)) {
1246 init = IrBuilder::create<Bool>(false);
1247 } else {
1248 TORCH_CHECK(
1249 false, "Could not generate a sum op for tensor with type: ", v1_dtype);
1250 }
1251
1252 return reductionOp(BinaryOpType::Add, axes, init, v1, keep_dim, dtype);
1253}
1254
1255TensorView* max(
1256 TensorView* v1,
1257 const std::vector<int>& axes,
1258 bool keep_dim /*=false*/,
1259 DataType dtype /* DataType::Null */) {
1260 TORCH_CHECK(
1261 dtype == DataType::Null,
1262 "A dtype other than Null is not currently supported.");
1263 Val* init = getMinimumValue(v1->getDataType().value());
1264 TORCH_CHECK(init != nullptr, "Missing initial value");
1265 return reductionOp(BinaryOpType::Max, axes, init, v1, keep_dim);
1266}
1267
1268TensorView* min(
1269 TensorView* v1,
1270 const std::vector<int>& axes,
1271 bool keep_dim /*=false*/,
1272 DataType dtype /* DataType::Null */) {
1273 TORCH_CHECK(
1274 dtype == DataType::Null,
1275 "A dtype other than Null is not currently supported.");
1276 Val* init = getMaximumValue(v1->getDataType().value());
1277 TORCH_CHECK(init != nullptr, "Missing initial value");
1278 return reductionOp(BinaryOpType::Min, axes, init, v1, keep_dim);
1279}
1280
1281TensorView* broadcast(
1282 TensorView* inp,
1283 const std::vector<bool>& is_broadcast_dim) {
1284 auto nBCastDims = is_broadcast_dim.size();
1285 // Validate is_broadcast_dim
1286 unsigned int n_broadcasts = 0;
1287 for (auto ent : is_broadcast_dim) {
1288 if (ent) {
1289 n_broadcasts++;
1290 }
1291 }
1292
1293 TORCH_CHECK(
1294 nBCastDims - n_broadcasts ==
1295 TensorDomain::noReductions(inp->getMaybeRFactorDomain()).size(),
1296 "Invalid broadcast, number of false entries in is_broadcast_dim expected to be ",
1297 TensorDomain::noReductions(inp->getMaybeRFactorDomain()).size(),
1298 " but received ",
1299 nBCastDims - n_broadcasts);
1300
1301 if (n_broadcasts == 0) {
1302 auto identity = set(inp);
1303 TORCH_INTERNAL_ASSERT(
1304 identity->getValType().value() == ValType::TensorView,
1305 "Expected identity op, but didn't get a TensorView back.");
1306 return identity->as<TensorView>();
1307 }
1308
1309 std::vector<IterDomain*> out_domain;
1310 // Don't propagate reduction IDs through arith ops.
1311 auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
1312 size_t iinp = 0, ibdim = 0;
1313 while (ibdim < is_broadcast_dim.size()) {
1314 if (is_broadcast_dim[ibdim]) {
1315 out_domain.push_back(IterDomainBuilder(
1316 FusionGuard::getCurFusion()->zeroVal(),
1317 FusionGuard::getCurFusion()->oneVal())
1318 .iter_type(IterType::Broadcast)
1319 .build());
1320 } else {
1321 out_domain.push_back(
1322 IterDomainBuilder(inp_domain[iinp]).resetSchedulingParams().build());
1323 iinp++;
1324 }
1325 ibdim++;
1326 }
1327
1328 TensorView* out_tensor = IrBuilder::create<TensorView>(
1329 IrBuilder::create<TensorDomain>(
1330 out_domain, std::vector<bool>(out_domain.size(), true)),
1331 inp->getDataType().value());
1332 IrBuilder::create<BroadcastOp>(out_tensor, inp, is_broadcast_dim);
1333 return out_tensor;
1334}
1335
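// Expand the broadcast dimensions of inp to the given sizes. A size of -1
// keeps the input extent for that dimension. For example (illustrative),
// expanding a tensor of shape [4, 1 (broadcast)] with sizes {-1, 8} produces
// an output of shape [4, 8] whose second dimension has an expanded extent of 8.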
1336TensorView* expand(TensorView* inp, const std::vector<Val*>& expanded_sizes) {
1337 auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
1338
1339 TORCH_CHECK(
1340 expanded_sizes.size() >= inp_domain.size(),
1341 "Invalid expand, number of sizes provided is expected to be at least ",
1342 inp_domain.size(),
1343 " but received ",
1344 expanded_sizes.size());
1345
1346 inp = maybe_broadcast_inner_to_rank(inp, expanded_sizes.size());
1347 inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
1348
1349 std::vector<Val*> maybe_expanded_sizes;
1350 maybe_expanded_sizes.resize(inp_domain.size(), nullptr);
1351
1352 // Did a dimension actually get expanded
1353 bool expanded = false;
1354
1355 std::vector<IterDomain*> out_domain;
1356 for (auto i : c10::irange(inp_domain.size())) {
1357 auto inp_id = inp_domain[i];
1358 auto out_id_builder = IterDomainBuilder(inp_id);
1359 maybe_expanded_sizes[i] = inp_domain[i]->extent();
1360
1361 auto expanded_size_int = expanded_sizes[i]->getInt();
1362
1363 // If the expanded size is -1, let the input extent be propagated
1364 // as is
1365 if (expanded_size_int == -1) {
1366 // This is just done for clarity. It isn't necessary as it's
1367 // already done when constructing out_id_builder.
1368 out_id_builder.extent(inp_id->extent());
1369 } else if (inp_id->isBroadcast() && expanded_size_int != 1) {
1370 // When input id is a broadcast, expand the extent to the given
1371 // size, which can be concrete or symbolic.
1372 expanded = true;
1373 out_id_builder.expanded_extent(expanded_sizes[i]);
1374 maybe_expanded_sizes[i] = expanded_sizes[i];
1375 } else if (!inp_id->extent()->isConstInt()) {
1376 // Input id is non-broadcast and its extent is symbolic. Promote
1377 // the extent to the given expanded size.
1378 // Note that expansion to 1 just means its extent becomes 1 and
1379 // does not mean the ID becomes a broadcast.
1380 out_id_builder.extent(expanded_sizes[i]);
1381 } else {
1382 // Input id is non-expand and its extent is concrete. Nothing
1383 // to expand, but the input and expanded sizes should match if
1384 // the expanded size is also concrete.
1385 auto inp_id_size_int = inp_id->extent()->getInt();
1386 if (expanded_size_int.has_value()) {
1387 TORCH_CHECK(
1388 inp_id_size_int == expanded_size_int,
1389 "Invalid expand size, ",
1390 expanded_sizes[i]->toString(),
1391 ", for ",
1392 inp_id->toString());
1393 }
1394 }
1395 out_domain.push_back(out_id_builder.build());
1396 }
1397
1398 TensorView* out_tensor = IrBuilder::create<TensorView>(
1399 IrBuilder::create<TensorDomain>(
1400 out_domain, std::vector<bool>(out_domain.size(), true)),
1401 inp->getDataType().value());
1402 if (!expanded) {
1403 IrBuilder::create<UnaryOp>(UnaryOpType::Set, out_tensor, inp);
1404 } else {
1405 IrBuilder::create<ExpandOp>(out_tensor, inp, maybe_expanded_sizes);
1406 }
1407 return out_tensor;
1408}
1409
1410TensorView* expand_as(TensorView* inp, TensorView* other) {
1411 auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
1412 auto other_domain =
1413 TensorDomain::noReductions(other->getMaybeRFactorDomain());
1414
1415 TORCH_CHECK(
1416 inp_domain.size() <= other_domain.size(),
      "Invalid expand_as, the rank of inp is higher than the rank of other; expected other to have a rank of at least ",
1418 inp_domain.size(),
1419 " but received ",
1420 other_domain.size());
1421
1422 inp = maybe_broadcast_inner_to_rank(inp, other_domain.size());
1423 inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
1424
1425 std::vector<IterDomain*> out_domain;
1426 std::vector<Val*> maybe_expanded_sizes;
1427 bool expanded = false;
1428 for (auto i : c10::irange(inp_domain.size())) {
1429 auto inp_id = inp_domain[i];
1430 auto other_id = other_domain[i];
1431
1432 auto out_id_builder = IterDomainBuilder(inp_id);
1433 Val* maybe_expanded_size = inp_id->extent();
1434
1435 if (!inp_id->isBroadcast()) {
1436 TORCH_INTERNAL_ASSERT(
1437 !other_id->isBroadcast(),
1438 "Cannot expand as a tensor if other has broadcast dimensions that don't map to broadcast dimensions in the input.");
1439 if (!inp_id->isConstInt() && other_id->isConstInt()) {
1440 out_id_builder.extent(
1441 promoteSize(inp_id->extent(), other_id->extent()));
1442 }
1443 } else {
1444 if (!other_id->isBroadcast()) {
1445 expanded = true;
1446 out_id_builder.expanded_extent(other_id->extent());
1447 maybe_expanded_size = other_id->extent();
1448 } else if (other_id->isBroadcast() && other_id->hasExpandedExtent()) {
1449 expanded = true;
1450 out_id_builder.expanded_extent(other_id->expandedExtent());
1451 maybe_expanded_size = other_id->expandedExtent();
1452 }
1453 }
1454 out_domain.push_back(out_id_builder.build());
1455 maybe_expanded_sizes.push_back(maybe_expanded_size);
1456 }
1457
1458 TensorView* out_tensor = IrBuilder::create<TensorView>(
1459 IrBuilder::create<TensorDomain>(
1460 out_domain, std::vector<bool>(out_domain.size(), true)),
1461 inp->getDataType().value());
1462 if (!expanded) {
1463 IrBuilder::create<UnaryOp>(UnaryOpType::Set, out_tensor, inp);
1464 } else {
1465 IrBuilder::create<ExpandOp>(out_tensor, inp, maybe_expanded_sizes);
1466 }
1467 return out_tensor;
1468}
1469
1470WelfordResult Welford(
1471 TensorView* tv,
1472 const std::vector<int>& axes,
1473 TensorView* init_avg,
1474 TensorView* init_var,
1475 Int* init_N) {
1476 TORCH_CHECK(
1477 TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
      "Reducing a tensor after it has undergone transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
1479
1480 TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");
1481 TORCH_CHECK(axes.size() > 0, "No reduction axis specified");
1482
1483 if (init_N == nullptr) {
1484 init_N = FusionGuard::getCurFusion()->zeroVal();
1485 }
1486
  // Initial values for the welford op are tensors, so their dims have to match
  // the output dims, i.e. original_dims - dims_to_be_reduced.
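  // For example, reducing axes {1, 2} of a 3-D tensor requires init_avg and
  // init_var to be 1-D tensors.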
1490 Val* init_avg_val = nullptr;
1491 Val* init_var_val = nullptr;
1492 if (!init_N->isZeroInt()) {
1493 TORCH_CHECK(
1494 init_avg != nullptr && init_var != nullptr && init_N != nullptr,
1495 "welford op: all init values need to be provided");
1496 TORCH_CHECK(
1497 (axes.size() + init_avg->getRootDomain().size()) ==
1498 tv->getRootDomain().size(),
1499 "welford op: initial tensor mismatch");
1500 TORCH_CHECK(
1501 (axes.size() + init_var->getRootDomain().size()) ==
1502 tv->getRootDomain().size(),
1503 "welford op: initial tensor mismatch");
1504 init_avg_val = init_avg;
1505 init_var_val = init_var;
1506 } else {
1507 init_avg_val = IrBuilder::create<Double>(0);
1508 init_var_val = IrBuilder::create<Double>(0);
1509 }
1510
1511 // Check and collect reduction axes
1512 std::vector<unsigned int> uint_axes;
1513 const int ndims = tv->domain()->noReductions().size();
1514 for (int axis : axes) {
1515 if (axis < 0) {
1516 axis += ndims;
1517 }
1518
1519 TORCH_CHECK(
1520 axis >= 0 && axis < ndims,
1521 "Reduction on invalid axis, received: ",
1522 axis,
1523 " however tensor view only has ",
1524 ndims,
1525 " non-reduction dims.");
1526
1527 uint_axes.push_back((unsigned int)axis);
1528 }
1529
1530 // Create tensor outputs
1531 TensorView* out_avg = newForReduction(tv, uint_axes);
1532 TensorView* out_var = newForReduction(tv, uint_axes);
1533 TensorView* out_N = newForReduction(tv, uint_axes, DataType::Index);
1534
1535 IrBuilder::create<WelfordOp>(
1536 out_avg,
1537 out_var,
      out_N, /* out avg/var/count */
      tv, /* in avg/var/count */
1540 FusionGuard::getCurFusion()->zeroVal(),
1541 FusionGuard::getCurFusion()->oneVal(),
1542 init_avg_val,
1543 init_var_val,
      init_N); /* init avg/var/count */
1545
1546 return WelfordResult(out_avg, out_var, out_N);
1547}
1548
1549WelfordResult::WelfordResult(
1550 TensorView* in_avg,
1551 TensorView* in_var_sum,
1552 TensorView* in_n)
1553 : avg(in_avg), var_sum(in_var_sum), n(in_n) {
1554 TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(var_sum->definition()));
1555 TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition()));
1556}
1557
1558// COMPOUND OPERATIONS
1559
1560// add_alpha
1561Val* add_alpha(Val* v1, Val* v2, Val* s) {
1562 TORCH_CHECK(
1563 s->getValType().value() == ValType::Scalar,
      "Alpha value should be a Scalar ValType and not ",
1565 s->getValType().value());
1566
1567 std::vector<Val*> operands = {v1, v2};
1568 auto common_dtype = computeTypes(TypePromotion::default_op_config, operands);
1569 auto cast_values = promoteValues({v1, v2, s}, common_dtype);
1570 auto vals = maybeBroadcast(cast_values);
1571 Val* intrm = mul(vals[1], vals[2]);
1572 return add(vals[0], intrm);
1573}
1574TensorView* add_alpha(TensorView* v1, Val* v2, Val* v3) {
1575 return arithOpOverloads(add_alpha, v1, v2, v3);
1576}
1577TensorView* add_alpha(Val* v1, TensorView* v2, Val* v3) {
1578 return arithOpOverloads(add_alpha, v1, v2, v3);
1579}
1580TensorView* add_alpha(TensorView* v1, TensorView* v2, Val* v3) {
1581 return arithOpOverloads(add_alpha, v1, v2, v3);
1582}
1583// sub_alpha
1584Val* sub_alpha(Val* v1, Val* v2, Val* s) {
1585 TORCH_CHECK(
1586 s->getValType().value() == ValType::Scalar,
      "Alpha value should be a Scalar ValType and not ",
1588 s->getValType().value());
1589
1590 std::vector<Val*> operands = {v1, v2};
1591 auto common_dtype = computeTypes(TypePromotion::default_op_config, operands);
1592 auto cast_values = promoteValues({v1, v2, s}, common_dtype);
1593 auto vals = maybeBroadcast(cast_values);
1594 Val* intrm = mul(vals[1], vals[2]);
1595 return sub(vals[0], intrm);
1596}
1597TensorView* sub_alpha(TensorView* v1, Val* v2, Val* v3) {
1598 return arithOpOverloads(sub_alpha, v1, v2, v3);
1599}
1600TensorView* sub_alpha(Val* v1, TensorView* v2, Val* v3) {
1601 return arithOpOverloads(sub_alpha, v1, v2, v3);
1602}
1603TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* v3) {
1604 return arithOpOverloads(sub_alpha, v1, v2, v3);
1605}
1606// lerp
1607Val* lerp(Val* start, Val* end, Val* weight) {
1608 auto cast_values =
1609 promoteValues(TypePromotion::default_op_config, {start, end, weight});
1610 start = cast_values[0];
1611 end = cast_values[1];
1612 weight = cast_values[2];
1613
1614 auto out_dtype =
1615 promote_type(start->getDataType().value(), end->getDataType().value());
1616 auto out_vtype =
1617 promote_type(start->getValType().value(), end->getValType().value());
1618
1619 auto vals = maybeBroadcast({start, end, weight});
1620 Val* out = nullptr;
1621 if (out_vtype == ValType::TensorView) {
1622 out = newOutputTV(vals, out_dtype);
1623 } else {
1624 out = newScalar(out_vtype, out_dtype);
1625 }
1626
1627 IrBuilder::create<TernaryOp>(
1628 TernaryOpType::Lerp, out, vals[0], vals[1], vals[2]);
1629 return out;
1630}
1631TensorView* lerp(TensorView* v1, Val* v2, Val* v3) {
1632 return arithOpOverloads(lerp, v1, v2, v3);
1633}
1634TensorView* lerp(Val* v1, TensorView* v2, Val* v3) {
1635 return arithOpOverloads(lerp, v1, v2, v3);
1636}
1637TensorView* lerp(Val* v1, Val* v2, TensorView* v3) {
1638 return arithOpOverloads(lerp, v1, v2, v3);
1639}
1640TensorView* lerp(TensorView* v1, TensorView* v2, Val* v3) {
1641 return arithOpOverloads(lerp, v1, v2, v3);
1642}
1643TensorView* lerp(TensorView* v1, Val* v2, TensorView* v3) {
1644 return arithOpOverloads(lerp, v1, v2, v3);
1645}
1646TensorView* lerp(Val* v1, TensorView* v2, TensorView* v3) {
1647 return arithOpOverloads(lerp, v1, v2, v3);
1648}
1649TensorView* lerp(TensorView* v1, TensorView* v2, TensorView* v3) {
1650 return arithOpOverloads(lerp, v1, v2, v3);
1651}
1652// addcmul
1653Val* addcmul(Val* v1, Val* v2, Val* v3, Val* s) {
1654 TORCH_CHECK(
1655 s->getValType().value() == ValType::Scalar,
      "Alpha value should be a Scalar ValType and not ",
1657 s->getValType().value());
1658
1659 std::vector<Val*> operands = {v1, v2, v3};
1660 auto common_dtype = computeTypes(TypePromotion::default_op_config, operands);
1661 auto cast_values = promoteValues({v1, v2, v3, s}, common_dtype);
1662 auto vals = maybeBroadcast(cast_values);
1663 Val* intrm1 = mul(vals[2], vals[3]);
1664 Val* intrm2 = mul(vals[1], intrm1);
1665 return add(vals[0], intrm2);
1666}
1667TensorView* addcmul(TensorView* v1, Val* v2, Val* v3, Val* v4) {
1668 return arithOpOverloads(addcmul, v1, v2, v3, v4);
1669}
1670TensorView* addcmul(Val* v1, TensorView* v2, Val* v3, Val* v4) {
1671 return arithOpOverloads(addcmul, v1, v2, v3, v4);
1672}
1673TensorView* addcmul(Val* v1, Val* v2, TensorView* v3, Val* v4) {
1674 return arithOpOverloads(addcmul, v1, v2, v3, v4);
1675}
1676TensorView* addcmul(TensorView* v1, TensorView* v2, Val* v3, Val* v4) {
1677 return arithOpOverloads(addcmul, v1, v2, v3, v4);
1678}
1679TensorView* addcmul(TensorView* v1, Val* v2, TensorView* v3, Val* v4) {
1680 return arithOpOverloads(addcmul, v1, v2, v3, v4);
1681}
1682TensorView* addcmul(Val* v1, TensorView* v2, TensorView* v3, Val* v4) {
1683 return arithOpOverloads(addcmul, v1, v2, v3, v4);
1684}
1685TensorView* addcmul(TensorView* v1, TensorView* v2, TensorView* v3, Val* v4) {
1686 return arithOpOverloads(addcmul, v1, v2, v3, v4);
1687}
1688
1689// TERNARY OPERATIONS
1690// where (c ? v1 : v2)
1691Val* where(Val* c, Val* v1, Val* v2) {
1692 TORCH_CHECK(
1693 c->getDataType().value() == DataType::Bool,
1694 "Condition should be of DataType Bool, not ",
1695 c->getDataType().value());
1696
1697 std::vector<Val*> operands = {v1, v2};
1698 auto common_dtype = computeTypes(TypePromotion::default_op_config, operands);
1699 auto cast_values = promoteValues(operands, common_dtype);
1700 v1 = cast_values[0];
1701 v2 = cast_values[1];
1702
1703 TORCH_CHECK(c->getDataType().value() == DataType::Bool);
1704 auto out_dtype = common_dtype;
1705 auto out_vtype =
1706 promote_type(v1->getValType().value(), v2->getValType().value());
1707 // Even when v1 and v2 are scalar, the output is a tensor if the
1708 // conditional input is a tensor.
1709 if (c->getValType() == ValType::TensorView) {
1710 out_vtype = ValType::TensorView;
1711 }
1712 auto vals = maybeBroadcast({c, v1, v2});
1713 Val* out = nullptr;
1714 if (out_vtype == ValType::TensorView) {
1715 out = newOutputTV(vals, out_dtype);
1716 } else {
1717 out = newScalar(out_vtype, out_dtype);
1718 }
1719 IrBuilder::create<TernaryOp>(
1720 TernaryOpType::Where, out, vals[0], vals[1], vals[2]);
1721 return out;
1722}
1723
1724TensorView* where(TensorView* v1, Val* v2, Val* v3) {
1725 return arithOpOverloads(where, v1, v2, v3);
1726}
1727TensorView* where(Val* v1, TensorView* v2, Val* v3) {
1728 return arithOpOverloads(where, v1, v2, v3);
1729}
1730TensorView* where(Val* v1, Val* v2, TensorView* v3) {
1731 return arithOpOverloads(where, v1, v2, v3);
1732}
1733TensorView* where(TensorView* v1, TensorView* v2, Val* v3) {
1734 return arithOpOverloads(where, v1, v2, v3);
1735}
1736TensorView* where(TensorView* v1, Val* v2, TensorView* v3) {
1737 return arithOpOverloads(where, v1, v2, v3);
1738}
1739TensorView* where(Val* v1, TensorView* v2, TensorView* v3) {
1740 return arithOpOverloads(where, v1, v2, v3);
1741}
1742TensorView* where(TensorView* v1, TensorView* v2, TensorView* v3) {
1743 return arithOpOverloads(where, v1, v2, v3);
1744}
1747
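// threshold(in, thresh, value): elements of `in` greater than `thresh` are
// kept and the rest are replaced by `value` (the usual aten::threshold
// semantics); the elementwise behavior itself is defined by
// TernaryOpType::Threshold in the generated code.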
1748Val* threshold(Val* in, Val* thresh, Val* value) {
1749 TORCH_CHECK(
1750 (thresh->getValType().value() == ValType::Scalar ||
1751 thresh->getValType().value() == ValType::NamedScalar) &&
1752 (value->getValType().value() == ValType::Scalar ||
1753 value->getValType().value() == ValType::NamedScalar),
      "For the threshold operation, thresh and value must be scalars.");
1755
1756 thresh = optionalCast(in->getDataType().value(), thresh);
1757 value = optionalCast(in->getDataType().value(), value);
1758 Val* out = newValLike(in, in->getDataType().value());
1759
1760 IrBuilder::create<TernaryOp>(
1761 TernaryOpType::Threshold, out, in, thresh, value);
1762 return out;
1763}
1764
1765TensorView* threshold(TensorView* in, Val* thresh, Val* value) {
1766 return threshold(in->as<Val>(), thresh, value)->as<TensorView>();
1767}
1768
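// clamp(in, min_val, max_val) limits `in` to [min_val, max_val]. A null
// min_val or max_val falls back to the smallest or largest value
// representable by the input dtype (via getMinimumValue/getMaximumValue),
// which effectively disables that bound.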
1769Val* clamp(Val* in, Val* min_val, Val* max_val) {
1770 TORCH_CHECK(
1771 (min_val == nullptr || min_val->getValType().value() == ValType::Scalar ||
1772 min_val->getValType().value() == ValType::NamedScalar) &&
1773 (max_val == nullptr ||
1774 max_val->getValType().value() == ValType::Scalar ||
1775 max_val->getValType().value() == ValType::NamedScalar),
      "For the clamp operation, min and max values must be scalars.");
1777
1778 min_val = (min_val == nullptr)
1779 ? getMinimumValue(in->getDataType().value())
1780 : optionalCast(in->getDataType().value(), min_val);
1781 TORCH_CHECK(min_val != nullptr, "Missing minimum value");
1782
1783 max_val = (max_val == nullptr)
1784 ? getMaximumValue(in->getDataType().value())
1785 : optionalCast(in->getDataType().value(), max_val);
1786 TORCH_CHECK(max_val != nullptr, "Missing maximum value");
1787
1788 Val* out = newValLike(in, in->getDataType().value());
1789 IrBuilder::create<TernaryOp>(TernaryOpType::Clamp, out, in, min_val, max_val);
1790 return out;
1791}
1792
1793TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) {
1794 return clamp(in->as<Val>(), min_val, max_val)->as<TensorView>();
1795}
1796
1797// sum_to operator
1798
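// sum_to reduces `in` so that its shape can broadcast to `sum_to_size`:
// leading dims not covered by sum_to_size are reduced away, and any dim
// where sum_to_size is 1 but the input extent is not is reduced and then
// broadcast back to keep the requested rank. For example, summing a
// [4, 3, 5] tensor to size [3, 1] reduces axes {0, 2} and broadcasts the
// last axis back.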
1799TensorView* sum_to(TensorView* in, const std::vector<Int*>& sum_to_size) {
1800 const auto& root = TensorDomain::noReductions(in->getMaybeRFactorDomain());
1801
1802 TORCH_CHECK(
1803 root.size() >= sum_to_size.size(),
      "sum_to: Error trying to reduce ",
      in,
      " into a shape of size ",
      sum_to_size.size());
1808
1809 // If no reduction is needed sum_to returns the input tv
1810 TensorView* out = in;
1811
1812 const int64_t leading_dims = root.size() - sum_to_size.size();
1813
1814 // Generate reduction axes for leading dims
1815 std::vector<int> reduce_dims(leading_dims);
1816 std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
1817
1818 // Generate reduction axes for dims within sum_to_size
1819 std::vector<bool> inner_red_dims(sum_to_size.size(), false);
1820 bool reduction_within_shape = false;
1821
1822 // Reduce rest of the dims with keep_dim
1823 for (const auto i : c10::irange(leading_dims, root.size())) {
1824 if (sum_to_size[i - leading_dims]->isOneInt() &&
1825 !root[i]->extent()->isOneInt()) {
1826 inner_red_dims[i - leading_dims] = true;
1827 reduce_dims.push_back(i);
1828 reduction_within_shape = true;
1829 }
1830 }
1831
1832 // Reduction step
1833 if (!reduce_dims.empty()) {
1834 out = sum(in, reduce_dims);
1835 }
1836
1837 // Broadcast back reduced dims within shape
1838 if (reduction_within_shape) {
1839 out = broadcast(out, inner_red_dims);
1840 }
1841
1842 return out;
1843}
1844
1845TensorView* sum_to(TensorView* in, const std::vector<int64_t>& sum_to_size) {
1846 const auto& root = TensorDomain::noReductions(in->getMaybeRFactorDomain());
1847
1848 TORCH_CHECK(
1849 root.size() >= sum_to_size.size(),
      "sum_to: Error trying to reduce ",
      in,
      " into a shape of size ",
      sum_to_size.size());
1854
1855 // If no reduction is needed sum_to returns the input tv
1856 TensorView* out = in;
1857
1858 const int64_t leading_dims = root.size() - sum_to_size.size();
1859
1860 // Generate reduction axes for leading dims
1861 std::vector<int> reduce_dims(leading_dims);
1862 std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
1863
1864 // Generate reduction axes for dims within sum_to_size
1865 std::vector<bool> inner_red_dims(sum_to_size.size(), false);
1866 bool reduction_within_shape = false;
1867
1868 // Reduce rest of the dims with keep_dim
1869 for (const auto i : c10::irange(leading_dims, root.size())) {
1870 if (sum_to_size[i - leading_dims] == 1 && !root[i]->extent()->isOneInt()) {
1871 inner_red_dims[i - leading_dims] = true;
1872 reduce_dims.push_back(i);
1873 reduction_within_shape = true;
1874 }
1875 }
1876
1877 // Reduction step
1878 if (!reduce_dims.empty()) {
1879 out = sum(in, reduce_dims);
1880 }
1881
1882 // Broadcast back reduced dims within shape
1883 if (reduction_within_shape) {
1884 out = broadcast(out, inner_red_dims);
1885 }
1886
1887 return out;
1888}
1889
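// For example, shift(tv, {1}, true) uses a pad width of 1 for dimension 0,
// keeping the output extent equal to the input extent, while
// shift(tv, {1}, false) applies no padding.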
1890TensorView* shift(TensorView* inp, const std::vector<int>& offsets, bool pad) {
1891 // When pad is false, no padding is given. When it is true, padding
1892 // sizes are set so that output domains have the same extents as
1893 // input domains.
1894 std::vector<int> pad_width(offsets.size(), 0);
1895 if (pad) {
1896 for (const auto i : c10::irange(offsets.size())) {
1897 pad_width[i] = std::abs(offsets[i]);
1898 }
1899 }
1900 return shift(inp, offsets, pad_width);
1901}
1902
1903TensorView* shift(
1904 TensorView* inp,
1905 const std::vector<int>& offsets,
1906 const std::vector<int>& pad_width_param) {
1907 auto inp_dom = TensorDomain::noReductions(inp->getRootDomain());
1908 const auto ndims = inp_dom.size();
1909
1910 auto pad_width = pad_width_param;
1911 // Default padding is set so that the extent is kept unchanged
1912 if (pad_width.empty()) {
1913 pad_width = offsets;
1914 for (auto& p : pad_width) {
1915 p = std::abs(p);
1916 }
1917 }
1918
1919 TORCH_CHECK(
1920 ndims == offsets.size(),
1921 "Invalid shift offsets, number of entries in offsets expected to be ",
1922 ndims,
1923 " but received ",
1924 offsets.size());
1925
1926 TORCH_CHECK(
1927 ndims == pad_width.size(),
1928 "Invalid padding width list, number of entries in pad_width expected to be ",
1929 ndims,
1930 " but received ",
1931 pad_width.size());
1932
1933 std::for_each(pad_width.begin(), pad_width.end(), [](const auto& pad) {
1934 TORCH_CHECK(pad >= 0, "Padding width must be >= 0: ", pad);
1935 });
1936
1937 TensorView* out = nullptr;
1938
1939 std::vector<IterDomain*> out_dom;
1940 for (const auto i : c10::irange(ndims)) {
1941 const auto inp_axis = inp_dom[i];
1942 const auto offset = offsets[i];
1943 const auto pad = pad_width[i];
1944
1945 if (offset == 0) {
1946 out_dom.push_back(inp_axis->cloneWithoutRFactor());
1947 continue;
1948 }
1949
1950 Int* current_start_offset = dynamic_cast<Int*>(inp_axis->start());
1951 TORCH_INTERNAL_ASSERT(
1952 current_start_offset != nullptr && current_start_offset->isConst(),
        "Invalid IterDomain start value: ",
1954 current_start_offset);
1955
1956 Int* current_stop_offset = dynamic_cast<Int*>(inp_axis->stopOffset());
1957 TORCH_INTERNAL_ASSERT(
1958 current_stop_offset != nullptr && current_stop_offset->isConst(),
        "Invalid IterDomain stop offset value: ",
1960 current_stop_offset);
1961
1962 const auto cur_start_offset_value = current_start_offset->value().value();
1963 const auto cur_stop_offset_value = current_stop_offset->value().value();
1964
1965 int64_t out_start_offset = 0;
1966 int64_t out_stop_offset = 0;
1967
1968 if (offset > 0) {
1969 // shift to right; extent remains the same, start and stop
1970 // positions are moved right
1971 out_start_offset = cur_start_offset_value + offset - pad;
1972 out_stop_offset = std::max(cur_stop_offset_value - offset, int64_t(0));
1973 // If pad > offset, the extent of the output ID could be larger than the
1974 // input, and the start offset of the output domain could become
1975 // negative, which is not supported.
1976 TORCH_CHECK(
1977 out_start_offset >= 0,
          "Invalid shift offset and padding. Padding must not be larger than the absolute value of the shift offset. Padding: ",
1979 pad,
1980 ". Shift: ",
1981 offset,
1982 ".");
1983 } else {
1984 // shift to left; extent remains the same, start and stop
1985 // positions are moved left
1986 out_start_offset = std::max(cur_start_offset_value + offset, int64_t(0));
1987 out_stop_offset = cur_stop_offset_value - offset - pad;
      // Similar to the above case where offset is positive, if pad >
1989 // -offset (note offset is negative), the extent of the output
1990 // ID could be larger than the input, and the stop offset of the
1991 // output domain could become negative.
1992 TORCH_CHECK(
1993 out_stop_offset >= 0,
          "Invalid shift offset and padding. Padding must not be larger than the absolute value of the shift offset. Padding: ",
1995 pad,
1996 ". Shift: ",
1997 offset,
1998 ".");
1999 }
2000
2001 out_dom.push_back(
2002 IterDomainBuilder(
2003 IrBuilder::create<Int>(out_start_offset), inp_axis->extent())
2004 .stop_offset(IrBuilder::create<Int>(out_stop_offset))
2005 .iter_type(inp_axis->getIterType())
2006 .build());
2007 }
2008
2009 out = IrBuilder::create<TensorView>(
2010 IrBuilder::create<TensorDomain>(
2011 out_dom, std::vector<bool>(out_dom.size(), true)),
2012 inp->getDataType().value());
2013
2014 IrBuilder::create<ShiftOp>(out, inp, offsets, pad_width);
2015 return out;
2016}
2017
2018namespace {
2019
// Return a new TensorDomain with the given root domains. Apply strides if
// necessary. With non-unit strides, the strided domains become rfactor
// domains.
2023TensorDomain* generateTensorDomainWithStrides(
2024 const std::vector<IterDomain*>& root_domains,
2025 const std::vector<int>& strides,
2026 bool skip_unit_stride) {
2027 std::vector<IterDomain*> strided_domains;
2028
2029 // If strides are just unit strides, don't apply striding
2030 if (strides.empty() ||
2031 (skip_unit_stride &&
2032 std::all_of(
2033 strides.begin(), strides.end(), [](int s) { return s == 1; }))) {
2034 return IrBuilder::create<TensorDomain>(
2035 root_domains, std::vector<bool>(root_domains.size(), true));
2036 }
2037
2038 for (const auto i : c10::irange(root_domains.size())) {
2039 auto root_dom = root_domains.at(i);
2040
2041 if (i >= strides.size() || (skip_unit_stride && strides[i] == 1)) {
2042 strided_domains.push_back(root_dom);
2043 continue;
2044 }
2045
2046 // Split the root domain by the stride
2047 auto split_out = root_dom->stridedSplit(strides[i]);
2048 strided_domains.push_back(split_out.first);
2049 strided_domains.push_back(split_out.second);
2050 }
2051
2052 auto contig_vector_size = strided_domains.size();
2053
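  // The strided domains serve as both the rfactor domain and the initial
  // leaf domain, while the original root domains are preserved as the root.
  // This is what turns the strided split into an rfactor transformation.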
2054 auto strided_td = IrBuilder::create<TensorDomain>(
2055 root_domains,
2056 strided_domains,
2057 strided_domains,
2058 std::vector<bool>(contig_vector_size, true));
2059
2060 return strided_td;
2061}
2062
2063} // namespace
2064
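// gather builds an output whose root domain is the (possibly narrowed)
// input domain followed by one extra IterType::Gather axis per input
// dimension holding the window elements, so the output has 2 * ndims root
// axes before any striding is applied.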
2065TensorView* gather(
2066 TensorView* inp,
2067 const std::vector<int>& window_shape,
2068 const std::vector<std::vector<int>>& pad_width,
2069 const std::vector<int>& strides,
2070 bool trim_out_of_bounds) {
2071 auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
2072 const auto ndims = inp_dom.size();
2073
2074 TORCH_CHECK(
2075 ndims == window_shape.size(),
2076 "Invalid window shape: number of entries expected to be ",
2077 ndims,
2078 " but received ",
2079 window_shape.size());
2080
2081 std::for_each(window_shape.begin(), window_shape.end(), [](const auto& w) {
2082 TORCH_CHECK(w > 0, "Window size must be > 0: ", w);
2083 });
2084
2085 TORCH_CHECK(
2086 ndims == pad_width.size(),
2087 "Invalid pad width: number of entries expected to be ",
2088 ndims,
2089 " but received ",
2090 pad_width.size());
2091
2092 std::for_each(pad_width.begin(), pad_width.end(), [](const auto& p) {
2093 TORCH_CHECK(
2094 p.size() == 2,
2095 "Each entry of pad_width must have two non-negative integers.");
2096 std::for_each(p.begin(), p.end(), [](const auto& p_left_or_right) {
2097 TORCH_CHECK(
2098 p_left_or_right >= 0, "Padding must be >= 0: ", p_left_or_right);
2099 });
2100 });
2101
2102 TORCH_CHECK(
2103 strides.empty() || ndims == strides.size(),
2104 "Invalid strides: number of entries expected to be ",
2105 ndims,
2106 " but received ",
2107 strides.size());
2108
2109 std::for_each(strides.begin(), strides.end(), [](const auto& s) {
2110 TORCH_CHECK(s > 0, "Stride must be > 0: ", s);
2111 });
2112
2113 std::vector<IterDomain*> out_root_domains;
2114 std::vector<IterDomain*> out_gather_dom;
2115
2116 for (const auto i : c10::irange(ndims)) {
2117 const auto inp_axis = inp_dom[i];
2118 const auto window_dim = window_shape[i];
2119 const auto pad_left = pad_width[i][0];
2120 const auto pad_right = pad_width[i][1];
2121 // This may be over-conservative
2122 TORCH_INTERNAL_ASSERT(inp_axis->start()->isZeroInt());
2123 TORCH_INTERNAL_ASSERT(
2124 inp_axis->stopOffset()->isConstInt(),
2125 "Dynamic stop offset not supported: ",
2126 inp_axis);
2127 const auto inp_stop_offset = inp_axis->stopOffset()->evaluateInt();
2128 const auto extent_adjustment = window_dim - 1 - pad_left - pad_right;
2129 TORCH_CHECK(
2130 extent_adjustment >= 0,
2131 "Invalid gather window and padding as output extent would be larger than input.",
2132 " Window: ",
2133 window_dim,
2134 ". Padding left: ",
2135 pad_left,
2136 ". Padding right: ",
2137 pad_right);
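    // The output root ID keeps the input extent; its valid range is
    // narrowed by pushing the stop offset out by
    // (window - 1 - pad_left - pad_right).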
2138 const auto out_stop_offset = inp_stop_offset + extent_adjustment;
2139 out_root_domains.push_back(
2140 IterDomainBuilder(
2141 FusionGuard::getCurFusion()->zeroVal(), inp_axis->extent())
2142 .stop_offset(IrBuilder::create<Int>(out_stop_offset))
2143 .iter_type(inp_axis->getIterType())
2144 .build());
2145 // create a new axis for the gathered domain
2146 out_gather_dom.push_back(IterDomainBuilder(
2147 FusionGuard::getCurFusion()->zeroVal(),
2148 IrBuilder::create<Int>(window_dim))
2149 .iter_type(IterType::Gather)
2150 .build());
2151 }
2152
2153 out_root_domains.insert(
2154 out_root_domains.end(), out_gather_dom.begin(), out_gather_dom.end());
2155
2156 TensorDomain* out_td = nullptr;
2157
2158 if (trim_out_of_bounds) {
    // If no stride vector is given, just use stride 1. That has no striding
    // effect, but out-of-bounds values are trimmed.
    auto s = strides.empty() ? std::vector<int>(ndims, 1) : strides;
    out_td = generateTensorDomainWithStrides(out_root_domains, s, false);
2163 } else {
2164 out_td = generateTensorDomainWithStrides(out_root_domains, strides, true);
2165 }
2166
2167 auto out_tv =
2168 IrBuilder::create<TensorView>(out_td, inp->getDataType().value());
2169
2170 IrBuilder::create<GatherOp>(out_tv, inp, window_shape, pad_width);
2171 return out_tv;
2172}
2173
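// viewAsScalar reinterprets a tensor of a vector dtype as a tensor of its
// element dtype by appending an innermost IterType::VectorComponent axis
// whose extent is the vector width.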
2174TensorView* viewAsScalar(TensorView* inp) {
2175 auto inp_type = inp->getDataType().value();
2176 TORCH_CHECK(
2177 isVectorType(inp_type),
2178 "Invalid type to viewAsScalar. A vector type is expected but ",
2179 inp_type,
2180 " is given.");
2181 int vec_size = getVectorSizeFromType(inp_type);
2182 auto out_type = getTypeFromVectorType(inp_type);
2183
2184 std::vector<IterDomain*> out_domain;
2185 auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
2186 out_domain.reserve(inp_domain.size());
2187 for (auto d : inp_domain) {
2188 out_domain.push_back(d->cloneWithoutRFactor());
2189 }
2190
2191 IterDomain* id = IterDomainBuilder(
2192 inp_domain[0]->container()->zeroVal(),
2193 IrBuilder::create<Int>(vec_size))
2194 .iter_type(IterType::VectorComponent)
2195 .build();
2196 out_domain.push_back(id);
2197
2198 auto out = IrBuilder::create<TensorView>(
2199 inp->container(),
2200 IrBuilder::create<TensorDomain>(
2201 out_domain, std::vector<bool>(out_domain.size(), true)),
2202 out_type);
2203
2204 IrBuilder::create<ViewAsScalar>(inp->container(), out, inp, id);
2205
2206 return out;
2207}
2208
2209namespace {
2210
2211//! Create new output for mma
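//! Similar to the generic reduction-output construction, every axis listed
//! in `axes` becomes a reduction axis; per dimension, the non-broadcast
//! IterDomain of tv_a or tv_b is used as the reference, so a broadcast in
//! one operand can be resolved by the other.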
2212static TensorView* newForMma(
2213 TensorView* tv_a,
2214 TensorView* tv_b,
2215 const std::vector<unsigned int>& axes,
2216 DataType data_type = DataType::Float) {
2217 auto orig_domain_a =
2218 TensorDomain::noReductions(tv_a->getMaybeRFactorDomain());
2219 auto orig_domain_b =
2220 TensorDomain::noReductions(tv_b->getMaybeRFactorDomain());
2221
2222 TORCH_INTERNAL_ASSERT(
2223 orig_domain_a.size() == orig_domain_b.size(),
2224 "MMA op: need matching dim input");
2225
2226 std::set<unsigned int> axes_set(axes.begin(), axes.end());
2227 std::vector<IterDomain*> new_domain;
2228
2229 TORCH_INTERNAL_ASSERT(
2230 !axes_set.empty(),
2231 "Asked for output of reduction, but no reduction axis provided.");
2232
2233 TORCH_INTERNAL_ASSERT(
2234 (*(axes_set.rbegin())) < orig_domain_a.size(),
2235 "Error setting up reduction, reduction axis (",
2236 *(axes_set.rbegin()),
2237 ") is outside nDims (",
2238 orig_domain_a.size(),
2239 "). Keep in mind reductions are relative to root domains, not modified views.");
2240
2241 auto axis_iter = axes_set.begin();
2242 for (const auto dim : c10::irange(orig_domain_a.size())) {
2243 bool isReduction = false;
2244 if (axis_iter != axes_set.end() && *axis_iter == dim) {
2245 isReduction = true;
2246 axis_iter++;
2247 }
2248
2249 const IterDomain* id = orig_domain_a[dim]->isBroadcast()
2250 ? orig_domain_b[dim]
2251 : orig_domain_a[dim];
2252
2253 TORCH_CHECK(
2254 !(isReduction && id->isBroadcast() && !id->isImplicitBroadcast()),
2255 "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ",
2256 id,
2257 " of tensor ",
2258 tv_a,
        " and ",
2260 tv_b);
2261
2262 new_domain.push_back(
2263 IterDomainBuilder(id->start(), id->extent())
2264 .stop_offset(id->stopOffset())
2265 .iter_type(isReduction ? IterType::Reduction : id->getIterType())
2266 .build());
2267 }
2268
2269 TensorDomain* td = IrBuilder::create<TensorDomain>(
2270 new_domain, std::vector<bool>(new_domain.size(), true));
2271
2272 return IrBuilder::create<TensorView>(td, data_type);
2273}
2274
2275} // namespace
2276
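// fusedMultiplySum(tv_a, tv_b, axes, init) builds a single MmaOp computing
// the reduction of tv_a * tv_b over `axes`, starting from `init` (0.0 when
// init is nullptr). Currently restricted to Half inputs and a single
// reduction axis; see the TODOs below.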
2277TensorView* fusedMultiplySum(
2278 TensorView* tv_a,
2279 TensorView* tv_b,
2280 const std::vector<int>& axes,
2281 Val* init) {
2282 if (init == nullptr) {
2283 init = IrBuilder::create<Double>(0);
2284 }
2285
2286 // TODO:
2287 // We will want to support initialize and rfactor with
2288 // mma as well, for maybe fusing bias in prolog.
2289 // TODO: check init type if given a tv,
2290 // not supported currently though.
2291 TORCH_CHECK(
2292 init->isConstScalar(),
2293 "Cannot create a reduction operation where the initial value is not a const scalar.");
2294
2295 // TODO:
2296 // Validate axis relationships between a and b
2297 TORCH_CHECK(tv_a->nDims() > 0, "Tried to reduce a 0-dim tensor");
2298
2299 // TODO:
2300 // Add tf32 and other mma data types
2301 // Add fallback path for non-mma data types.
2302 TORCH_CHECK(tv_a->getDataType().value() == DataType::Half);
2303 TORCH_CHECK(tv_b->getDataType().value() == DataType::Half);
2304
2305 TORCH_CHECK(axes.size() > 0, "No reduction axis specified");
2306
2307 // TODO:
2308 // will lift this in a follow up when we have a
2309 // more generic axes matching.
2310 TORCH_CHECK(
2311 axes.size() == 1, "Single axis reduction only for mma op instantiation.")
2312
2313 std::vector<unsigned int> uint_axes;
2314 const int ndims = tv_a->domain()->noReductions().size();
2315 for (int axis : axes) {
2316 if (axis < 0) {
2317 axis += ndims;
2318 }
2319
2320 TORCH_CHECK(
2321 axis >= 0 && axis < ndims,
2322 "Reduction on invalid axis, received: ",
2323 axis,
2324 " however tensor view only has ",
2325 ndims,
2326 " non-reduction dims.");
2327
2328 uint_axes.push_back((unsigned int)axis);
2329 }
2330
2331 TensorView* out = newForMma(tv_a, tv_b, uint_axes);
2332 IrBuilder::create<MmaOp>(out, tv_a, tv_b, init);
2333
2334 return out;
2335}
2336
2337} // namespace cuda
2338} // namespace fuser
2339} // namespace jit
2340} // namespace torch
2341