1 | #pragma once |
2 | |
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>

#include <ir_base_nodes.h>
#include <ir_container.h>
#include <iter_visitor.h>

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
14 | |
15 | namespace torch { |
16 | namespace jit { |
17 | namespace fuser { |
18 | namespace cuda { |
19 | |
20 | //! Usage: FusionGuard and Fusion are required user interfaces for any operation |
21 | //! underlying the code generator. In order to create values, expressions, and |
22 | //! generate code a Fusion instance must be active. It is the responsibility of |
23 | //! the user to create a Fusion instance and register it with the fusion guard. |
24 | //! The simplest example of this is: |
25 | //! |
26 | //! Fusion fusion; |
27 | //! FusionGuard fg(&fusion); |
28 | //! |
29 | //! Once a fusion is active all values and operations will be registered with |
30 | //! it. |
31 | //! |
32 | //! FusionGuard and Fusion are critical to the lifetime model of the IR system. |
33 | //! FusionGuard is a convenient way to set what base container instance holds |
34 | //! the defined IR. Statements that are defined are registered through the |
35 | //! FusionGuard with a particular Fusion. FusionGuard provides convenient |
36 | //! methods to access the active fusion so it doesn't need to be passed around |
37 | //! constantly. Any IR node derived classes from Statement must register with |
38 | //! Fusion to avoid memory leaks. |
39 | //! |
40 | //! Fusion is generally thought of as a translated fusion group from the JIT. It |
41 | //! is likely a single kernel, although, we don't have to stick to this in the |
42 | //! future and could in theory generate multiple kernels with an executor to run |
43 | //! them. |
44 | //! |
45 | //! Fusion also allows users to set input/output values that will allow us to |
46 | //! figure out how to hook up runtime data to and from the JIT as well as |
47 | //! provide us mechanisms for dependency analysis and DCE including safety |
48 | //! checks. |
49 | |
50 | class Fusion; |
51 | class TensorView; |
52 | class WelfordResult; |
53 | |
54 | class SegmentCandidateFinder; |
55 | class SegmentedFusion; |
56 | class KernelArgumentHolder; |
57 | |
58 | //! Fusion Guard is our "context manager". It holds the actrive fusion and |
59 | //! allows it to be accessed anywhere through FusionGuard::getCurFusion() |
60 | class TORCH_CUDA_CU_API FusionGuard { |
61 | public: |
62 | Fusion* prev_fusion; |
63 | |
64 | //! Set the active fusion so it can be manipulated. |
65 | explicit FusionGuard(Fusion* fusion); |
66 | |
67 | ~FusionGuard(); |
68 | |
69 | static Fusion* getCurFusion(); |
70 | static void setCurFusion(Fusion* fusion); |
71 | }; |
72 | |
73 | //! Fusion is mutable but unique. Nodes cannot be copied in any way from one |
74 | //! Fusion to another. If anything like that is desired, it would require |
75 | //! duplicating all associated values and exprs. Fusion is considered to be SSA, |
76 | //! though this could also change in the future if there is a good reason to do |
77 | //! so. |
78 | //! |
79 | //! The Fusion owns the whole IR graph (Vals and Exprs) |
80 | //! |
81 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
82 | class TORCH_CUDA_CU_API Fusion : public IrContainer { |
83 | typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap; |
84 | |
85 | public: |
86 | Fusion() = default; |
87 | |
88 | Fusion(const Fusion& other); |
89 | Fusion(Fusion&& other) noexcept; |
90 | |
91 | Fusion& operator=(const Fusion& other); |
92 | Fusion& operator=(Fusion&& other) noexcept; |
93 | |
94 | ~Fusion(); |
95 | |
96 | friend void swap(Fusion& a, Fusion& b) noexcept; |
97 | |
98 | void clear() noexcept; |
99 | |
100 | //! Break dependency chains associated with Expr, remove references to expr |
101 | //! delete expr |
102 | void removeExpr(Expr* expr) override; |
103 | |
104 | //! Completely remove val from the fusion, break all dependencies associated |
105 | //! with it |
106 | void removeVal(Val* val) override; |
107 | |
108 | //! Register input as an input of the fusion |
109 | void addInput(Val* input); |
110 | |
111 | //! Register output as an output of the fusion |
112 | void addOutput(Val* output); |
113 | |
114 | //! Deregister input as an input of the fusion |
115 | void removeInput(Val* input); |
116 | |
117 | //! Deregister output as an output of the fusion |
118 | void removeOutput(Val* output); |
119 | |
120 | //! Replace output with another value |
121 | void replaceOutput(Val* output, Val* replacement); |
122 | |
123 | //! Assert that all leaves found from outputs are registered as an input |
124 | void validateInputs(); |
125 | |
126 | //! Print this fusion to the console |
127 | void print(); |
128 | |
129 | //! Print Arith exprs |
130 | //! \param from_outputs_only Only print exprs reachable from outputs |
131 | void printMath(bool from_outputs_only = true); |
132 | |
133 | //! Print transformations used in fusion (can be very verbose) |
134 | void printTransforms(); |
135 | |
136 | //! Lower the fusion and print a kernel |
137 | void printKernel(DataType index_type = DataType::Int); |
138 | |
139 | //! Lower the fusion and evaluate bank conflict info |
140 | std::unordered_map<std::string, std::pair<int, int>> bankConflictInfo( |
141 | DataType index_type = DataType::Int); |
142 | |
143 | //! Return a list of topologically sorted expressions. This only includes |
144 | //! exprs required to genereate registered outputs. |
145 | std::vector<Expr*> exprs(); |
146 | |
147 | //! Return a vector of fusion inputs that feed this Val |
148 | std::vector<Val*> inputsOf(Val* val); |
149 | |
150 | //! Return all Vals in math expressions that cannot be eliminated. |
151 | //! |
152 | //! It is generally equivalent to vals that are used to generate |
153 | //! outputs, however, when a multi-output expression exists, and only |
154 | //! some of the outputs are used, the remaining unused outputs are |
155 | //! also included as they must show up in the final code. |
156 | std::vector<Val*> usedMathVals(); |
157 | |
158 | //! Returns all vals that are produced by used math expressions and |
159 | //! also do not have further consumers. |
160 | //! |
161 | //! In the case of an active multi-output expressions, the returned vector |
162 | //! will include the expression outputs that did not lead to an fusion |
163 | //! output. |
164 | std::vector<Val*> terminatingMathVals(); |
165 | |
166 | //! Return all Exprs that use val |
167 | std::unordered_set<Expr*> unordered_uses(const Val* val) const; |
168 | |
169 | //! Return the Expr that produces val |
170 | Expr* definition(const Val* val) const; |
171 | |
172 | //! Indicate to kernel to set itself up to generate random numbers |
173 | bool isStochastic(); |
174 | |
175 | //! Run fusion segmentation algorithm to create a segmented fusion |
176 | std::unique_ptr<SegmentedFusion> segment(const KernelArgumentHolder& args); |
177 | |
178 | const auto& inputs() const { |
179 | return inputs_; |
180 | } |
181 | |
182 | std::vector<Val*> inputsAndCreated(); |
183 | |
184 | const auto& outputs() const { |
185 | return outputs_; |
186 | } |
187 | |
188 | std::vector<Val*> getTerminatingOutputs() const; |
189 | |
190 | // Aliasing output to input value, this is a WAR to allow inplace update on |
191 | // input tensor. |
192 | // Note: this is not always safe and should be used with extra caution. |
193 | // Currently the only place it's used is in the running stats update for batch |
194 | // normalization. |
195 | // TODO: alias should be made aware to segmentation, so we'll always include |
196 | // the input tensor to the section where output is produced. |
197 | void aliasOutputToInput(Val* output, Val* input); |
198 | Val* getOutputAlias(Val* output); |
199 | std::unordered_set<int> getOutputAliasIndices() const; |
200 | std::vector<std::pair<int, int>> getInputAliasIndices() const; |
201 | |
202 | // mark input at index to be permuted by permutation |
203 | void setPermutationOnInput(int index, std::vector<int64_t> permutation) { |
204 | permuted_input_map_.insert({index, permutation}); |
205 | } |
206 | |
207 | // mark output at index to be restored by permutation |
208 | void setPermutationOnOutput(int index, std::vector<int64_t> permutation) { |
209 | permuted_output_map_.insert({index, permutation}); |
210 | } |
211 | |
212 | // return a map of indices to permutation, which indicates all input tensors |
213 | // that needs to be permuted |
214 | const PermutationMap& getPermutationInputMap() const { |
215 | return permuted_input_map_; |
216 | } |
217 | |
218 | // return a map of indices to permutation, which indicates all output tensors |
219 | // that needs to be permuted |
220 | const PermutationMap& getPermutationOutputMap() const { |
221 | return permuted_output_map_; |
222 | } |
223 | |
224 | bool isTVUseInfoValid() { |
225 | return all_tv_uses_valid_; |
226 | } |
227 | |
228 | bool isUpdatingTVUseInfo() { |
229 | return is_during_update_uses_; |
230 | } |
231 | |
232 | const auto& ioAlias() const { |
233 | return io_alias_; |
234 | } |
235 | |
236 | protected: |
237 | friend SegmentCandidateFinder; |
238 | friend SegmentedFusion; |
239 | friend class TranslateApplicableWelford; |
240 | friend Val; |
241 | |
242 | static IrCloner copy(const Fusion* from, Fusion* to); |
243 | |
244 | //! Register the Val with this fusion |
245 | virtual void registerVal(Val* val) override; |
246 | |
247 | //! Register expr with this fusion. |
248 | //! When we register an expression, we want to update the dependency tracking |
249 | //! of Vals. If this container is a not a Kernel, it will remove previous |
250 | //! definitions of outputs and register this Expr as the definition. Otherwise |
251 | //! will update definition if not previously set, but will not remove old |
252 | //! definitions. |
253 | virtual void registerExpr(Expr* expr) override; |
254 | |
255 | //! Clear Expr's from TV uses that are not required to produce outputs from |
256 | //! inputs. Only other place this is used (other than Fusion) is in |
257 | //! Val::uses() |
258 | void resetTvUses(); |
259 | |
260 | private: |
261 | // Determine if the two values are compatible for aliasing |
262 | // Same DataType, ValType, and number of dimensions |
263 | bool isAliasCompatible(Val* left, Val* right); |
264 | |
265 | private: |
266 | // Fusion inputs and outputs |
267 | std::vector<Val*> inputs_; |
268 | std::vector<Val*> outputs_; |
269 | |
270 | // io alias pointing from output to input |
271 | std::unordered_map<Val*, Val*> io_alias_; |
272 | |
273 | // See Note [ Permutation support in nvfuser ] |
274 | // map from indices of input tensor to permutation |
275 | PermutationMap permuted_input_map_; |
276 | // map from indices of output tensor to permutation |
277 | PermutationMap permuted_output_map_; |
278 | |
279 | // Records if the current use data in the IR nodes are valid |
280 | // the states are either all valid or all invalid |
281 | bool all_tv_uses_valid_ = false; |
282 | bool is_during_update_uses_ = false; |
283 | }; |
284 | |
285 | } // namespace cuda |
286 | } // namespace fuser |
287 | } // namespace jit |
288 | } // namespace torch |
289 | |