1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <dispatch.h> |
5 | #include <ir_builder.h> |
6 | |
7 | #include <unordered_map> |
8 | #include <vector> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | class IrContainer; |
16 | |
17 | //! Clones nodes from an exiting Fusion |
18 | //! |
19 | //! \warning IrCloner machinery is a specialized helper for implementing |
20 | //! Fusion copy operations and the and limited scope of RecomputeTv below. |
21 | //! It is not intended for any other uses. |
22 | //! |
23 | class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { |
24 | friend class Statement; |
25 | friend class IrBuilder; |
26 | |
27 | public: |
28 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
29 | explicit IrCloner(IrContainer* container); |
30 | |
31 | Statement* clone(const Statement* statement); |
32 | |
33 | template <class T> |
34 | T* clone(const T* node) { |
35 | return node ? clone(node->template as<Statement>())->template as<T>() |
36 | : nullptr; |
37 | } |
38 | |
39 | template <class T> |
40 | std::vector<T*> clone(const std::vector<T*>& container) { |
41 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
42 | std::vector<T*> copy; |
43 | copy.reserve(container.size()); |
44 | for (auto p : container) { |
45 | copy.push_back(clone(p)); |
46 | } |
47 | return copy; |
48 | } |
49 | |
50 | IrContainer* container() const { |
51 | return ir_container_; |
52 | } |
53 | |
54 | protected: |
55 | void registerClone(const Statement* src, Statement* clone); |
56 | |
57 | void handle(const Statement*) override; |
58 | void handle(const Val*) override; |
59 | void handle(const Expr*) override; |
60 | |
61 | void handle(const TensorDomain*) override; |
62 | void handle(const TensorView*) override; |
63 | void handle(const IterDomain*) override; |
64 | |
65 | void handle(const Bool*) override; |
66 | void handle(const Double*) override; |
67 | void handle(const Int*) override; |
68 | void handle(const ComplexDouble*) override; |
69 | void handle(const NamedScalar*) override; |
70 | |
71 | void handle(const FullOp*) override; |
72 | void handle(const ARangeOp*) override; |
73 | void handle(const EyeOp*) override; |
74 | void handle(const UnaryOp*) override; |
75 | void handle(const BinaryOp*) override; |
76 | void handle(const TernaryOp*) override; |
77 | void handle(const RNGOp*) override; |
78 | void handle(const BroadcastOp*) override; |
79 | void handle(const ReductionOp*) override; |
80 | void handle(const GroupedReductionOp*) override; |
81 | void handle(const WelfordOp*) override; |
82 | void handle(const LoadStoreOp*) override; |
83 | void handle(const MmaOp*) override; |
84 | void handle(const TransposeOp*) override; |
85 | void handle(const ExpandOp*) override; |
86 | void handle(const ShiftOp*) override; |
87 | void handle(const GatherOp*) override; |
88 | void handle(const ViewAsScalar*) override; |
89 | void handle(const ViewOp*) override; |
90 | |
91 | void handle(const Split*) override; |
92 | void handle(const Merge*) override; |
93 | void handle(const Swizzle2D*) override; |
94 | |
95 | protected: |
96 | // We keep track of the original -> clone map so we don't |
97 | // duplicate clones of the same object if referenced multiple times |
98 | std::unordered_map<const Statement*, Statement*> clones_map_; |
99 | |
100 | private: |
101 | // The destination Fusion container |
102 | IrContainer* ir_container_ = nullptr; |
103 | |
104 | // The dispatch interface doesn't allow returning values from |
105 | // individual `handle()` methods, so they are storing the |
106 | // result here |
107 | Statement* clone_ = nullptr; |
108 | |
109 | // Builder to make all the new nodes |
110 | IrBuilder builder_; |
111 | }; |
112 | |
113 | // Replicates all expressions used to generate the provided TensorView. Does not |
114 | // replicate inputs. Does not replicate scalar values. In other words the value |
115 | // provided will be recomputed from the inputs of the fusion. |
116 | class RecomputeTv : private IrCloner { |
117 | public: |
118 | // Replicates expressions and values in provided expressions. |
119 | static TensorView* recompute(TensorView* tv); |
120 | |
121 | private: |
122 | RecomputeTv(Fusion* fusion, std::vector<Expr*> exprs); |
123 | |
124 | void handle(const TensorDomain*) final; |
125 | |
126 | Fusion* fusion_; |
127 | }; |
128 | |
129 | } // namespace cuda |
130 | } // namespace fuser |
131 | } // namespace jit |
132 | } // namespace torch |
133 | |