1#pragma once
2
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>

#include <ir_base_nodes.h>
#include <ir_container.h>
#include <iter_visitor.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
14
15namespace torch {
16namespace jit {
17namespace fuser {
18namespace cuda {
19
//! Usage: FusionGuard and Fusion are required user interfaces for any operation
//! underlying the code generator. A Fusion instance must be active in order to
//! create values and expressions and to generate code. It is the responsibility
//! of the user to create a Fusion instance and register it with the fusion
//! guard.
24//! The simplest example of this is:
25//!
26//! Fusion fusion;
27//! FusionGuard fg(&fusion);
28//!
29//! Once a fusion is active all values and operations will be registered with
30//! it.
31//!
//! FusionGuard and Fusion are critical to the lifetime model of the IR system.
//! FusionGuard is a convenient way to set which base container instance holds
//! the defined IR. Statements that are defined are registered through the
//! FusionGuard with a particular Fusion. FusionGuard provides convenient
//! methods to access the active fusion so it doesn't need to be passed around
//! constantly. Any IR node class derived from Statement must register with
//! Fusion to avoid memory leaks.
39//!
//! Fusion is generally thought of as a translated fusion group from the JIT. It
//! is likely a single kernel, although we don't have to stick to this in the
//! future and could, in theory, generate multiple kernels with an executor to
//! run them.
44//!
45//! Fusion also allows users to set input/output values that will allow us to
46//! figure out how to hook up runtime data to and from the JIT as well as
47//! provide us mechanisms for dependency analysis and DCE including safety
48//! checks.
49
// Forward declarations of IR types, so their full definitions are not
// required by this header.
class Fusion;
class TensorView;
class WelfordResult;

// Forward declarations used by Fusion's segmentation interface
// (friend declarations and Fusion::segment()).
class KernelArgumentHolder;
class SegmentCandidateFinder;
class SegmentedFusion;
57
58//! Fusion Guard is our "context manager". It holds the actrive fusion and
59//! allows it to be accessed anywhere through FusionGuard::getCurFusion()
60class TORCH_CUDA_CU_API FusionGuard {
61 public:
62 Fusion* prev_fusion;
63
64 //! Set the active fusion so it can be manipulated.
65 explicit FusionGuard(Fusion* fusion);
66
67 ~FusionGuard();
68
69 static Fusion* getCurFusion();
70 static void setCurFusion(Fusion* fusion);
71};
72
73//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
74//! Fusion to another. If anything like that is desired, it would require
75//! duplicating all associated values and exprs. Fusion is considered to be SSA,
76//! though this could also change in the future if there is a good reason to do
77//! so.
78//!
79//! The Fusion owns the whole IR graph (Vals and Exprs)
80//!
81// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
82class TORCH_CUDA_CU_API Fusion : public IrContainer {
83 typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;
84
85 public:
86 Fusion() = default;
87
88 Fusion(const Fusion& other);
89 Fusion(Fusion&& other) noexcept;
90
91 Fusion& operator=(const Fusion& other);
92 Fusion& operator=(Fusion&& other) noexcept;
93
94 ~Fusion();
95
96 friend void swap(Fusion& a, Fusion& b) noexcept;
97
98 void clear() noexcept;
99
100 //! Break dependency chains associated with Expr, remove references to expr
101 //! delete expr
102 void removeExpr(Expr* expr) override;
103
104 //! Completely remove val from the fusion, break all dependencies associated
105 //! with it
106 void removeVal(Val* val) override;
107
108 //! Register input as an input of the fusion
109 void addInput(Val* input);
110
111 //! Register output as an output of the fusion
112 void addOutput(Val* output);
113
114 //! Deregister input as an input of the fusion
115 void removeInput(Val* input);
116
117 //! Deregister output as an output of the fusion
118 void removeOutput(Val* output);
119
120 //! Replace output with another value
121 void replaceOutput(Val* output, Val* replacement);
122
123 //! Assert that all leaves found from outputs are registered as an input
124 void validateInputs();
125
126 //! Print this fusion to the console
127 void print();
128
129 //! Print Arith exprs
130 //! \param from_outputs_only Only print exprs reachable from outputs
131 void printMath(bool from_outputs_only = true);
132
133 //! Print transformations used in fusion (can be very verbose)
134 void printTransforms();
135
136 //! Lower the fusion and print a kernel
137 void printKernel(DataType index_type = DataType::Int);
138
139 //! Lower the fusion and evaluate bank conflict info
140 std::unordered_map<std::string, std::pair<int, int>> bankConflictInfo(
141 DataType index_type = DataType::Int);
142
143 //! Return a list of topologically sorted expressions. This only includes
144 //! exprs required to genereate registered outputs.
145 std::vector<Expr*> exprs();
146
147 //! Return a vector of fusion inputs that feed this Val
148 std::vector<Val*> inputsOf(Val* val);
149
150 //! Return all Vals in math expressions that cannot be eliminated.
151 //!
152 //! It is generally equivalent to vals that are used to generate
153 //! outputs, however, when a multi-output expression exists, and only
154 //! some of the outputs are used, the remaining unused outputs are
155 //! also included as they must show up in the final code.
156 std::vector<Val*> usedMathVals();
157
158 //! Returns all vals that are produced by used math expressions and
159 //! also do not have further consumers.
160 //!
161 //! In the case of an active multi-output expressions, the returned vector
162 //! will include the expression outputs that did not lead to an fusion
163 //! output.
164 std::vector<Val*> terminatingMathVals();
165
166 //! Return all Exprs that use val
167 std::unordered_set<Expr*> unordered_uses(const Val* val) const;
168
169 //! Return the Expr that produces val
170 Expr* definition(const Val* val) const;
171
172 //! Indicate to kernel to set itself up to generate random numbers
173 bool isStochastic();
174
175 //! Run fusion segmentation algorithm to create a segmented fusion
176 std::unique_ptr<SegmentedFusion> segment(const KernelArgumentHolder& args);
177
178 const auto& inputs() const {
179 return inputs_;
180 }
181
182 std::vector<Val*> inputsAndCreated();
183
184 const auto& outputs() const {
185 return outputs_;
186 }
187
188 std::vector<Val*> getTerminatingOutputs() const;
189
190 // Aliasing output to input value, this is a WAR to allow inplace update on
191 // input tensor.
192 // Note: this is not always safe and should be used with extra caution.
193 // Currently the only place it's used is in the running stats update for batch
194 // normalization.
195 // TODO: alias should be made aware to segmentation, so we'll always include
196 // the input tensor to the section where output is produced.
197 void aliasOutputToInput(Val* output, Val* input);
198 Val* getOutputAlias(Val* output);
199 std::unordered_set<int> getOutputAliasIndices() const;
200 std::vector<std::pair<int, int>> getInputAliasIndices() const;
201
202 // mark input at index to be permuted by permutation
203 void setPermutationOnInput(int index, std::vector<int64_t> permutation) {
204 permuted_input_map_.insert({index, permutation});
205 }
206
207 // mark output at index to be restored by permutation
208 void setPermutationOnOutput(int index, std::vector<int64_t> permutation) {
209 permuted_output_map_.insert({index, permutation});
210 }
211
212 // return a map of indices to permutation, which indicates all input tensors
213 // that needs to be permuted
214 const PermutationMap& getPermutationInputMap() const {
215 return permuted_input_map_;
216 }
217
218 // return a map of indices to permutation, which indicates all output tensors
219 // that needs to be permuted
220 const PermutationMap& getPermutationOutputMap() const {
221 return permuted_output_map_;
222 }
223
224 bool isTVUseInfoValid() {
225 return all_tv_uses_valid_;
226 }
227
228 bool isUpdatingTVUseInfo() {
229 return is_during_update_uses_;
230 }
231
232 const auto& ioAlias() const {
233 return io_alias_;
234 }
235
236 protected:
237 friend SegmentCandidateFinder;
238 friend SegmentedFusion;
239 friend class TranslateApplicableWelford;
240 friend Val;
241
242 static IrCloner copy(const Fusion* from, Fusion* to);
243
244 //! Register the Val with this fusion
245 virtual void registerVal(Val* val) override;
246
247 //! Register expr with this fusion.
248 //! When we register an expression, we want to update the dependency tracking
249 //! of Vals. If this container is a not a Kernel, it will remove previous
250 //! definitions of outputs and register this Expr as the definition. Otherwise
251 //! will update definition if not previously set, but will not remove old
252 //! definitions.
253 virtual void registerExpr(Expr* expr) override;
254
255 //! Clear Expr's from TV uses that are not required to produce outputs from
256 //! inputs. Only other place this is used (other than Fusion) is in
257 //! Val::uses()
258 void resetTvUses();
259
260 private:
261 // Determine if the two values are compatible for aliasing
262 // Same DataType, ValType, and number of dimensions
263 bool isAliasCompatible(Val* left, Val* right);
264
265 private:
266 // Fusion inputs and outputs
267 std::vector<Val*> inputs_;
268 std::vector<Val*> outputs_;
269
270 // io alias pointing from output to input
271 std::unordered_map<Val*, Val*> io_alias_;
272
273 // See Note [ Permutation support in nvfuser ]
274 // map from indices of input tensor to permutation
275 PermutationMap permuted_input_map_;
276 // map from indices of output tensor to permutation
277 PermutationMap permuted_output_map_;
278
279 // Records if the current use data in the IR nodes are valid
280 // the states are either all valid or all invalid
281 bool all_tv_uses_valid_ = false;
282 bool is_during_update_uses_ = false;
283};
284
285} // namespace cuda
286} // namespace fuser
287} // namespace jit
288} // namespace torch
289