#include <manager.h>
#include <parser.h>
#include <partition.h>
#include <register_interface.h>

#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/TensorShape.h>
#include <c10/util/CallOnce.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>

/*
 * Registers function pointers in interface.h
 */

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {
class RegisterInterface {
 public:
  RegisterInterface() {
    auto ptr = getFuserInterface();
    ptr->fn_compile_n = &compileCudaFusionGroup;
    ptr->fn_run_n_s = &runCudaFusionGroup;
    ptr->fn_fuse_graph = &CudaFuseGraph;
    ptr->fn_can_fuse_n = &isFusibleCudaFusionGroup;
    ptr->fn_insert_profile_inodes = &InsertProfileNodes;
    ptr->fn_profile_n = &shouldProfileNode;
    ptr->fn_skip_n = &skipNodeKind;
  }
};

static RegisterInterface register_interface_;

class RegisterNVFuserPass {
 public:
  RegisterNVFuserPass() {
    NVFuserPassManager::registerPass(true);
  }
};

static RegisterNVFuserPass register_nvfuser_pass_;

} // namespace

//! [ Note -- type guard logic in CudaFusionGuard ]
//!
//! CudaFusionGuard is used to guard input tensors to a `CudaFusionGroup`, so
//! that we do not feed inputs that violate the graph defined in `GraphCache`.
//!
//! see [ Note -- 2 level cache implementation ] for the definition of a
//! unique computational graph.
//! see [ Note -- CudaFusionGuard implementation ] for details on how the
//! guard works in the profiling executor.
//!
//! Type guard logic is used to query whether a runtime input `tensor`
//! complies with the profiled `guard_tensor_type`. `guard_tensor_type` is the
//! tensor type observed during profiling runs.
//!
//! At this moment we only do a single profiling run, so `guard_tensor_type`
//! has a static shape / stride / scalarType. This might be a little
//! confusing, as our implementation is actually more relaxed.
//!
//! Things that we check:
//!   a. identical rank & scalar type
//!   b. stride check:
//!        b.1. identical stride order
//!        b.2. identical contiguity
//!             note that contiguity here is used for tensor collapsing, so
//!             extra attention should be paid to contiguity across size-1
//!             dimensions.
//!   c. size check:
//!        c.1. broadcast check:
//!             making sure that broadcast semantics are identical. We want a
//!             given dimension to either be size-1 for both `tensor` and
//!             `guard_tensor_type`, or non-size-1 for both.
//!             This is because we specialize a size-1 dimension as a
//!             broadcasted dimension while translating a PyTorch tensor to
//!             Fusion IR.
//!        c.2. size-0 check:
//!             we don't specialize this in codegen, but we do specialize
//!             fusion logic for size-0 on reductions, hence the check
//!
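//! A hypothetical example (not taken from an actual profiling run): suppose
//! profiling observed a contiguous float tensor of shape [8, 1, 4] (strides
//! [4, 4, 1]). A runtime tensor of shape [16, 1, 4] with strides [4, 4, 1]
//! still complies: rank, scalar type, stride order, contiguity, and the
//! size-1 / non-size-1 pattern all match, even though the outer extent
//! changed. A runtime tensor of shape [8, 2, 4] fails check c.1, because the
//! middle dimension is no longer a broadcast; and a tensor of shape
//! [8, 1, 4] with the transposed strides [1, 1, 8] fails the stride checks
//! in b.
//!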
bool complyWith(
    const at::Tensor& tensor,
    const c10::TensorTypePtr& guard_tensor_type) {
  // guard broadcast semantics, contiguity & stride order;
  TORCH_INTERNAL_ASSERT(
      guard_tensor_type && guard_tensor_type->dim().has_value());

  // check a. if the rank check or the scalar type check fails
  if (*guard_tensor_type->dim() != static_cast<size_t>(tensor.ndimension()) ||
      (guard_tensor_type->scalarType().has_value() &&
       (guard_tensor_type->scalarType().value() != tensor.scalar_type())) ||
      (guard_tensor_type->device().has_value() &&
       (guard_tensor_type->device().value() != tensor.device())) ||
      (guard_tensor_type->requiresGrad().has_value() &&
       guard_tensor_type->requiresGrad().value() !=
           (tensor.requires_grad() && at::GradMode::is_enabled()))) {
    return false;
  }

  // TODO: should we get symbolic_size instead and check for size
  // consistency across tensors as well?
  const auto& sizes = guard_tensor_type->sizes();
  // see [ Note -- stride_properties in tensor type ]
  const auto& stride_properties = guard_tensor_type->stride_properties();

  const auto& t_sizes = tensor.sizes();
  const auto& t_strides = tensor.strides();
  int inner_dim = -1;
  for (const auto j : c10::irange(*guard_tensor_type->dim())) {
    // check b. for the stride check, we go along dimensions from the fastest
    // stride to the slowest stride
    int sorted_index = stride_properties[j]->stride_index_
        ? static_cast<int>(*stride_properties[j]->stride_index_)
        : -1;

    // only apply the stride check when we have stride_properties
    if (sorted_index != -1) {
      // check b.1. stride order [the current dimension has a stride larger
      // than its inner dimension(s)]; the check only applies when both:
      //   i. we have already encountered an inner dimension
      //  ii. we are not at the fastest dimension
      if (j != 0 && inner_dim != -1) {
        // we are not looking at dim-j, but dim-sorted_index, which
        // is the j-th fastest dim;
        // Note: we ignore 0-stride dimensions, since eager logic on stride
        // indices is ambiguous
        if (t_strides[sorted_index] != 0 && t_strides[inner_dim] != 0 &&
            t_strides[sorted_index] < t_strides[inner_dim]) {
          return false;
        }
      }

      // check b.2. contiguity, we only check when it's marked as
      // contiguous.
      if (stride_properties[j]->contiguous_ &&
          *stride_properties[j]->contiguous_) {
        if (j != 0) {
          // we use contiguity to collapse dimensions; a size-1 dimension is
          // always collapsible
          // computeStrideProps also defaults to contiguous when stride == 1
          if (t_sizes[sorted_index] != 1 && t_strides[sorted_index] != 1) {
            TORCH_INTERNAL_ASSERT(
                stride_properties[j - 1]->stride_index_.has_value(),
                "the contiguity check requires a known stride index on the inner dimension");
            // TODO: merge this check up
            if (t_strides[sorted_index] !=
                t_strides[inner_dim] * t_sizes[inner_dim]) {
              return false;
            }
          }
        } else {
          // TODO: merge this check up
          if (t_strides[sorted_index] != 1) {
            return false;
          }
        }
      }

      // update inner_dim to be the current dim. Note that we skip the update
      // when `t_sizes[sorted_index] == 1`, because:
      //   1. stride comparison on a size-1 dimension is meaningless
      //      [check b.1]
      //   2. contiguity on a size-1 dimension is misleading. For collapsing,
      //      we should actually look at the next non-size-1 dimension
      //      [check b.2]
      if (inner_dim == -1 || t_sizes[sorted_index] != 1) {
        inner_dim = sorted_index;
      }
    }

    // check c.1, we go along dimensions in semantic order and
    // check broadcast / size-1:
    bool guard_bcast = sizes[j].has_value() && sizes[j].value() == 1;
    if (guard_bcast != (t_sizes[j] == 1)) {
      return false;
    }

    // check c.2, check for size-0
    bool guard_size_0 = sizes[j].has_value() && sizes[j].value() == 0;
    if (guard_size_0 != (t_sizes[j] == 0)) {
      return false;
    }
  }

  return true;
}

} // namespace cuda
} // namespace fuser

namespace {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators size_eq_guard({
    Operator(
        //"prim::CudaFusionSizeEq(int[] size, int[] ref) -> bool",
        "prim::CudaFusionSizeEq(...) -> bool",
        // prim::CudaFusionSizeEq returns a fresh Boolean type without
        // aliasing. If we ever return a refined tensor, which would change
        // aliasing analysis, we should update the aliasdb pass.
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            at::ArrayRef<IValue> inputs = last(stack, 2);
            drop(stack, 2);

            if (!fuser::cuda::getCudaFusionGuardMode()) {
              push(stack, IValue(true));
              return;
            }

            // auto inp = inputs[0].toIntList();
            TORCH_INTERNAL_ASSERT(
                inputs[1].isIntList(), "reference needs to be an int list");
            auto ref = inputs[1].toIntList();

            auto ret = true;
            if (ref.empty()) {
              ret = inputs[0].isNone();
            } else {
              if (inputs[0].isIntList()) {
                auto inp = inputs[0].toIntList();
                if (inp.size() != ref.size()) {
                  push(stack, IValue(false));
                  return;
                }

                for (const auto i : c10::irange(inp.size())) {
                  if ((inp[i] == 1) != (ref[i] == 1)) {
                    ret = false;
                    break;
                  }
                }
              } else {
                ret = false;
              }
            }

            push(stack, IValue(ret));
            return;
          };
        },
        aliasAnalysisFromSchema()),
});
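
// An illustrative (hypothetical) example of the guard above: with
// ref = [4, 1, 8], a runtime size of [2, 1, 3] still passes, since only the
// size-1 / non-size-1 pattern is compared per dimension, while [4, 2, 8]
// fails because the middle dimension stops being a broadcast.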

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_fusion({
    Operator(
        prim::CudaFusionGroup,
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            fuser::cuda::runFusionGroup(node, stack);
          };
        },
        aliasAnalysisSpecialCase()),
});

RegisterOperators reg_guard({
    Operator(
        "prim::CudaFusionGuard(...) -> bool",
        // prim::CudaFusionGuard returns a fresh Boolean type without
        // aliasing. If we ever return a refined tensor, which would change
        // aliasing analysis, we should update the aliasdb pass.
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            // TODO: check latency here!!!!
            std::vector<TypePtr> types = node->tys(attr::types);
            const auto num_inputs = types.size();
            at::ArrayRef<IValue> inputs = last(stack, num_inputs);
            drop(stack, num_inputs);

            if (!fuser::cuda::getCudaFusionGuardMode()) {
              push(stack, IValue(true));
              return;
            }

            for (const auto i : c10::irange(num_inputs)) {
              const c10::TensorTypePtr& guard_tensor_type =
                  types[i]->cast<TensorType>();

              // TODO: maybe we should just push false and fall back
              TORCH_INTERNAL_ASSERT(inputs[i].isTensor());
              const at::Tensor& tensor = inputs[i].toTensor();

              if (!fuser::cuda::complyWith(tensor, guard_tensor_type)) {
                push(stack, IValue(false));
                return;
              }
            }

            // TODO: check type and return the right flag
            // naively return true;
            push(stack, IValue(true));
            return;
          };
        },
        aliasAnalysisFromSchema()),
});

// Infer the dynamic axis (-1) in view_sizes given tensor_sizes
bool inferViewShape(
    c10::List<int64_t> tensor_sizes,
    c10::List<int64_t> view_sizes) {
  int64_t dynamic_index = -1;
  size_t view_size_num_elements = 1;
  for (size_t idx = 0; idx < view_sizes.size(); ++idx) {
    if (view_sizes[idx] == -1) {
      TORCH_INTERNAL_ASSERT(
          dynamic_index == -1, "Only one dimension can be inferred.");
      dynamic_index = idx;
    } else {
      TORCH_INTERNAL_ASSERT(view_sizes[idx] > 0);
      view_size_num_elements *= view_sizes[idx];
    }
  }
  const size_t kNumElements = std::accumulate(
      tensor_sizes.begin(),
      tensor_sizes.end(),
      static_cast<size_t>(1),
      std::multiplies<>());

  if (kNumElements % view_size_num_elements != 0) {
    return false;
  }

  if (dynamic_index != -1) {
    view_sizes[dynamic_index] = kNumElements / view_size_num_elements;
  }

  return true;
}
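
// For illustration (hypothetical values): with tensor_sizes = [2, 3, 4] and
// view_sizes = [6, -1], the dynamic axis is inferred as 24 / 6 = 4 and the
// function returns true; with view_sizes = [5, -1] it returns false because
// 24 is not divisible by 5.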

//!
//! CudaFusionViewGuard Example Graph:
//!
//! graph(%self : __torch__.BiasViewRelu,
//!       %inputs.1 : Tensor):
//!   %2 : int = prim::Constant[value=-1]() # dynamic_bvg.py:50:40
//!   %3 : int = prim::Constant[value=1]() # dynamic_bvg.py:50:25
//!   %4 : NoneType = prim::Constant()
//!   %5 : int[] = prim::Constant[value=[2, 3]]()
//!   %6 : int[] = aten::size(%inputs.1) # dynamic_bvg.py:50:25
//!   %7 : int[] = aten::slice(%6, %4, %2, %3) # dynamic_bvg.py:50:25
//!   %view_shape.1 : int[] = aten::add(%7, %5) # dynamic_bvg.py:50:25
//!   %bias : Tensor = prim::GetAttr[name="bias"](%self)
//!   %10 : int[] = aten::size(%bias)
//!   %11 : int[] = prim::BroadcastSizes(%6, %10)
//!   %12 : bool = prim::CudaFusionGuard[types=[...]](%inputs.1, %bias)
//!   %13 : int[] = prim::Constant[value=[-1, -1, -1, 6]]()
//!   %14 : int[] = prim::Constant[value=[-1, -1, -1, 2, 3]]()
//!   %15 : bool = prim::CudaFusionViewGuard(%11, %view_shape.1, %13, %14)
//!   %16 : bool[] = prim::ListConstruct(%15, %12)
//!   %17 : bool = aten::all(%16)
//!   %18 : Tensor = prim::If(%17)
//!     block0():
//!       %19 : Tensor = prim::CudaFusionGroup_0[cache_id=0](%inputs.1, %bias)
//!       -> (%19)
//!     block1():
//!       %20 : Function = prim::Constant[name="fallback_fn", fallback=1]()
//!       %21 : (...) = prim::CallFunction(%20, %inputs.1, %bias, %view_shape.1)
//!       %22 : Float(...) = prim::TupleUnpack(%21)
//!       -> (%22)
//!   return (%18)
//! with prim::CudaFusionGroup_0 = graph(%0 : Float(...),
//!       %1 : Float(...)):
//!   %2 : int[] = prim::Constant[value=[2, 3, 4, 2, 3]]()
//!   %3 : int = prim::Constant[value=1]() # dynamic_bvg.py:50:25
//!   %o.1 : Float(...) = aten::add(%0, %1, %3) # dynamic_bvg.py:51:16
//!   %5 : Float(...) = prim::view_copy(%o.1, %2)
//!   %6 : Float(...) = aten::relu(%5) # dynamic_bvg.py:53:19
//!   return (%6)
//!
RegisterOperators view_guard({
    Operator(
        "prim::CudaFusionViewGuard(...) -> bool",
        // prim::CudaFusionViewGuard returns a fresh Boolean type without
        // aliasing. If we ever return a refined tensor, which would change
        // aliasing analysis, we should update the aliasdb pass.
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            // view_sizes_constraint - constant List[Int]
            at::ArrayRef<IValue> inputs = last(stack, 3);

            // tensor_sizes is the runtime size of the self tensor
            // tensor_sizes - dynamic size List[Int]
            TORCH_INTERNAL_ASSERT(
                inputs[0].isIntList(), "tensor_sizes needs to be an int list");
            auto tensor_sizes = inputs[0].toIntList();

            // profiled_view_sizes is the runtime view size
            // profiled_view_sizes - profile_ivalue List[Int]
            TORCH_INTERNAL_ASSERT(
                inputs[1].isIntList(),
                "profiled_view_sizes needs to be an int list");
            auto profiled_view_sizes = inputs[1].toIntList();

            // tensor_constraints is a constant List[Int]
            // used to guard tensor_sizes
            TORCH_INTERNAL_ASSERT(
                inputs[2].isIntList(),
                "tensor_constraints needs to be an int list");
            auto tensor_constraints = inputs[2].toIntList();

            // Drop the inputs only after gathering all input arguments;
            // if an argument is moved, it is destroyed when dropped from
            // the stack
            drop(stack, 3);

            auto status = inferViewShape(tensor_sizes, profiled_view_sizes);
            if (!status) {
              push(stack, IValue(false));
              return;
            }

            if (!fuser::cuda::getCudaFusionGuardMode()) {
              push(stack, IValue(true));
              return;
            }
            std::vector<int64_t> tensor_sizes_int_vec = tensor_sizes.vec();
            std::vector<int64_t> view_sizes_int_vec =
                profiled_view_sizes.vec();
            std::vector<int64_t> previous_constraints =
                tensor_constraints.vec();
            auto new_constraints =
                torch::jit::fuser::cuda::analyzeViewConstraint(
                    tensor_sizes_int_vec, view_sizes_int_vec);
            bool guard_status =
                (new_constraints.conglomerateString() == previous_constraints);
            push(stack, IValue(guard_status));
            return;
          };
        },
        aliasAnalysisFromSchema()),
});

RegisterOperators ivalue_guard({
    Operator(
        "prim::CudaFusionIvalGuard(...) -> bool",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            at::ArrayRef<IValue> inputs = last(stack, 2);
            drop(stack, 2);
            if (!fuser::cuda::getCudaFusionGuardMode()) {
              push(stack, IValue(true));
              return;
            }
            push(stack, inputs[0].equals(inputs[1]));
            return;
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_add_optional({
    Operator(
        "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            IValue input, bias;
            pop(stack, input, bias);
            if (bias.isNone()) {
              push(stack, std::move(input));
            } else {
              push(stack, at::add(input.toTensor(), bias.toTensor(), 1.0));
            }
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_permute_copy({
    Operator(
        "prim::permute_copy(Tensor(a) self, int[] dims) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "permute_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, dims;
            pop(stack, self, dims);
            // permute the dimensions, matching aten::permute semantics
            push(stack, self.toTensor().permute(dims.toIntVector()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_transpose_copy({
    Operator(
        "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "transpose_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, dim0, dim1;
            pop(stack, self, dim0, dim1);
            push(
                stack,
                at::transpose(self.toTensor(), dim0.toInt(), dim1.toInt()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_t_copy({
    Operator(
        "prim::t_copy(Tensor(a) self) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "t_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self;
            pop(stack, self);
            push(stack, at::t(self.toTensor()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_view_copy({
    Operator(
        "prim::view_copy(Tensor self, int[] size) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "view_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, size;
            pop(stack, self, size);
            push(stack, at::native::view(self.toTensor(), size.toIntVector()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_flatten_copy({
    Operator(
        "prim::flatten_copy(Tensor self, int start_dim, int end_dim) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "flatten_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, start_dim, end_dim;
            pop(stack, self, start_dim, end_dim);
            push(
                stack,
                at::native::flatten(
                    self.toTensor(), start_dim.toInt(), end_dim.toInt()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_reshape_copy({
    Operator(
        "prim::reshape_copy(Tensor self, int[] shape) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "reshape_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, shape;
            pop(stack, self, shape);
            push(
                stack,
                at::native::reshape(self.toTensor(), shape.toIntVector()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_squeeze_copy({
    Operator(
        "prim::squeeze_copy(Tensor self) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "squeeze_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self;
            pop(stack, self);
            push(stack, at::squeeze(self.toTensor()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_squeeze_dim_copy({
    Operator(
        "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "squeeze_dim_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, dim;
            pop(stack, self, dim);
            push(stack, at::squeeze(self.toTensor(), dim.toInt()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_unsqueeze_copy({
    Operator(
        "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "unsqueeze_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, dim;
            pop(stack, self, dim);
            push(stack, at::unsqueeze(self.toTensor(), dim.toInt()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_infer_unsqueeze_size({
    Operator(
        "prim::infer_unsqueeze_size(int[] a, int dim) -> int[]",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            auto dim = pop(stack).toInt();
            auto size = pop(stack).toIntVector();
            if (dim < 0) {
              dim = dim + 1 + size.size();
            }
            auto it = size.begin() + dim;
            size.insert(it, 1);
            push(stack, IValue(size));
          };
        },
        aliasAnalysisFromSchema()),
});
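
// For illustration (hypothetical values): prim::infer_unsqueeze_size([2, 3], -1)
// wraps the negative dim to 2 (dim + 1 + size.size()) and produces [2, 3, 1].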

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_infer_squeeze_dim_size({
    Operator(
        "prim::infer_squeeze_size.dim(int[] a, int dim) -> int[]",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            auto dim = pop(stack).toInt();
            auto size = pop(stack).toIntVector();
            if (dim < 0) {
              dim = dim + size.size();
            }
            auto it = size.begin() + dim;
            if (*it == 1) {
              size.erase(it);
            }
            push(stack, IValue(size));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_infer_squeeze_size({
    Operator(
        "prim::infer_squeeze_size(int[] a) -> int[]",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            auto size = pop(stack).toIntVector();

            // erase every size-1 entry; erase() returns the iterator to the
            // next element, so only advance when nothing was removed
            for (auto it = size.begin(); it != size.end();) {
              if (*it == 1) {
                it = size.erase(it);
              } else {
                ++it;
              }
            }
            push(stack, IValue(size));
          };
        },
        aliasAnalysisFromSchema()),
});
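
// For illustration (hypothetical values): prim::infer_squeeze_size([1, 2, 1, 3])
// removes every size-1 entry and produces [2, 3], while
// prim::infer_squeeze_size.dim([1, 2, 1, 3], 1) leaves the list unchanged
// because dimension 1 has size 2.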

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_expand_copy({
    Operator(
        "prim::expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "expand_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, size, implicit;
            pop(stack, self, size, implicit);
            push(stack, self.toTensor().expand(size.toIntVector()));
          };
        },
        aliasAnalysisFromSchema()),
});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_expand_as_copy({
    Operator(
        "prim::expand_as_copy(Tensor self, Tensor other) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "expand_as_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, other;
            pop(stack, self, other);
            push(
                stack,
                at::native::expand_as(self.toTensor(), other.toTensor()));
          };
        },
        aliasAnalysisFromSchema()),
});

} // namespace

} // namespace jit
} // namespace torch