1#pragma once
2
3#include <c10/macros/Export.h>
4#include <dispatch.h>
5#include <ir_builder.h>
6
7#include <unordered_map>
8#include <vector>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15class IrContainer;
16
//! Clones nodes from an existing Fusion
//!
//! \warning IrCloner machinery is a specialized helper for implementing
//! Fusion copy operations and the limited scope of RecomputeTv below.
//! It is not intended for any other uses.
//!
23class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
24 friend class Statement;
25 friend class IrBuilder;
26
27 public:
28 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
29 explicit IrCloner(IrContainer* container);
30
31 Statement* clone(const Statement* statement);
32
33 template <class T>
34 T* clone(const T* node) {
35 return node ? clone(node->template as<Statement>())->template as<T>()
36 : nullptr;
37 }
38
39 template <class T>
40 std::vector<T*> clone(const std::vector<T*>& container) {
41 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
42 std::vector<T*> copy;
43 copy.reserve(container.size());
44 for (auto p : container) {
45 copy.push_back(clone(p));
46 }
47 return copy;
48 }
49
50 IrContainer* container() const {
51 return ir_container_;
52 }
53
54 protected:
55 void registerClone(const Statement* src, Statement* clone);
56
57 void handle(const Statement*) override;
58 void handle(const Val*) override;
59 void handle(const Expr*) override;
60
61 void handle(const TensorDomain*) override;
62 void handle(const TensorView*) override;
63 void handle(const IterDomain*) override;
64
65 void handle(const Bool*) override;
66 void handle(const Double*) override;
67 void handle(const Int*) override;
68 void handle(const ComplexDouble*) override;
69 void handle(const NamedScalar*) override;
70
71 void handle(const FullOp*) override;
72 void handle(const ARangeOp*) override;
73 void handle(const EyeOp*) override;
74 void handle(const UnaryOp*) override;
75 void handle(const BinaryOp*) override;
76 void handle(const TernaryOp*) override;
77 void handle(const RNGOp*) override;
78 void handle(const BroadcastOp*) override;
79 void handle(const ReductionOp*) override;
80 void handle(const GroupedReductionOp*) override;
81 void handle(const WelfordOp*) override;
82 void handle(const LoadStoreOp*) override;
83 void handle(const MmaOp*) override;
84 void handle(const TransposeOp*) override;
85 void handle(const ExpandOp*) override;
86 void handle(const ShiftOp*) override;
87 void handle(const GatherOp*) override;
88 void handle(const ViewAsScalar*) override;
89 void handle(const ViewOp*) override;
90
91 void handle(const Split*) override;
92 void handle(const Merge*) override;
93 void handle(const Swizzle2D*) override;
94
95 protected:
96 // We keep track of the original -> clone map so we don't
97 // duplicate clones of the same object if referenced multiple times
98 std::unordered_map<const Statement*, Statement*> clones_map_;
99
100 private:
101 // The destination Fusion container
102 IrContainer* ir_container_ = nullptr;
103
104 // The dispatch interface doesn't allow returning values from
105 // individual `handle()` methods, so they are storing the
106 // result here
107 Statement* clone_ = nullptr;
108
109 // Builder to make all the new nodes
110 IrBuilder builder_;
111};
112
113// Replicates all expressions used to generate the provided TensorView. Does not
114// replicate inputs. Does not replicate scalar values. In other words the value
115// provided will be recomputed from the inputs of the fusion.
116class RecomputeTv : private IrCloner {
117 public:
118 // Replicates expressions and values in provided expressions.
119 static TensorView* recompute(TensorView* tv);
120
121 private:
122 RecomputeTv(Fusion* fusion, std::vector<Expr*> exprs);
123
124 void handle(const TensorDomain*) final;
125
126 Fusion* fusion_;
127};
128
129} // namespace cuda
130} // namespace fuser
131} // namespace jit
132} // namespace torch
133