#pragma once

#include <c10/macros/Export.h>

#include <compute_at_map.h>
#include <ir_all_nodes.h>
#include <kernel_ir.h>
#include <parallel_type_bitmap.h>

#include <bitset>
#include <map>

// Provides utilities for dealing with nested ForLoop and IfThenElse scopes

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class ThreadPredicateMap;

using IterDomainMap = std::unordered_map<IterDomain*, IterDomain*>;

namespace scope_utils {

//! Create an **empty** ForLoop and copy the metadata.
kir::ForLoop* cloneForLoop(kir::ForLoop* for_loop);

//! Create an **empty** IfThenElse and copy the metadata.
kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite);

} // namespace scope_utils

namespace ir_utils {

// Sometimes we want to temporarily view a tensorview with another
// tensordomain. This isn't a permanent transformation, but in indexing we want
// to index producers with a consumer set of indices, so we need to view the
// producer as if it were transformed like the consumer while we index. This
// guard sets the tv to use td for the lifetime of this context guard.
class TORCH_CUDA_CU_API TVDomainGuard {
 private:
  TensorView* tv_;
  TensorDomain* prev_domain_;

 public:
  explicit TVDomainGuard(TensorView* tv, TensorDomain* td);
  TVDomainGuard(const TVDomainGuard&) = delete;
  TVDomainGuard(TVDomainGuard&&);

  //! A utility to access the tensordomain before the temporary
  //! view. This is used to retrieve information, such as swizzle
  //! information, that can only be reliably kept at the original domain.
  const TensorDomain* prevDomain() const {
    return prev_domain_;
  }

  ~TVDomainGuard();
};
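
// A minimal usage sketch (names are hypothetical): the guard installs `td` on
// `producer_tv` in its constructor and restores the previous domain in its
// destructor, so indexing code typically scopes it to a block.
//
//   {
//     ir_utils::TVDomainGuard domain_guard(producer_tv, consumer_like_td);
//     // ... index producer_tv as if its domain were consumer_like_td ...
//   } // previous domain restored here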

// Create a TVDomainGuard that temporarily views a tensorview with the
// specified all-true or all-false contiguity.
TORCH_CUDA_CU_API ir_utils::TVDomainGuard overrideContiguityGuard(
    TensorView* tv,
    bool contiguity);

//! Return inputs of the provided IterDomains that are themselves
//! IterDomains. A list of input IterDomains can optionally be given;
//! otherwise, IterDomains with no defining expression are returned.
std::vector<IterDomain*> iterDomainInputsOf(
    const std::vector<IterDomain*>& input_ids,
    const std::vector<IterDomain*>& all_inputs = {});

// Return inputs of the provided IterDomains that are IterDomains, ordered as
// in the second provided vector.
std::vector<IterDomain*> iterDomainInputsOfOrderedAs(
    const std::vector<IterDomain*>& of,
    const std::vector<IterDomain*>& order);

// Returns true if the Val is a TensorView or TensorIndex
TORCH_CUDA_CU_API bool isTV(const Val* const);

// Returns true if the Expr is a TensorView or TensorIndex Expr.
TORCH_CUDA_CU_API bool isTvOp(const Expr*);

// Returns the first output of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvOutput(const Expr*);

// Returns the first input of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvInput(const Expr*);

//! Returns the IterDomain that maps to the thread dimension grouped
//! to warps. Returns nullopt if the reduction is not to be lowered to
//! a warp reduction.
c10::optional<IterDomain*> getMaybeWarpReductionDim(
    const Val* output,
    const Val* input);
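
// A minimal usage sketch (names are hypothetical): callers check the optional
// before treating the reduction as a warp reduction.
//
//   if (auto warp_dim = getMaybeWarpReductionDim(out_tv, in_tv)) {
//     // *warp_dim is the thread-parallel reduction axis to lower as a
//     // warp reduction
//   }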

bool isScalarOp(const Expr*);

//! Get a TensorView, potentially via kir::TensorIndex. Returns nullptr if
//! the cast fails.
TensorView* getTv(Val*);
const TensorView* getTv(const Val*);

//! Get only the TensorViews, potentially via kir::TensorIndex.
std::vector<TensorView*> getTvs(const std::vector<Val*>& vals);

//! Return true if axis is derived from a root axis that is an input
//! to a CA leaf axis.
bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis);

std::unordered_map<ParallelType, IterDomain*, TypeHash> getParallelDomains(
    const Val* val);

//! Returns true if the expression will be lowered to
//! an ldmatrix intrinsic.
bool isLdMatrixOp(const Expr* expr);

//! Returns true if the expression will be lowered to
//! a cp.async intrinsic.
bool isCpAsyncOp(const Expr* expr);

//! Short-cut for detecting initialization for a cp.async op.
bool isCpAsyncInit(const Expr* expr);

//! Short-cut for matching a singleton expr in an if statement,
//! which likely becomes a predicated instruction in PTX, e.g.:
//!   if(...) {expr;}
//! Returns the expr if it matches this pattern.
//! Returns nullopt if the pattern doesn't match.
c10::optional<Expr*> getMaybePredicatedSingleton(Expr* expr);

//! Short-cut for checking if the expression loads from global memory.
bool isGlobalLoad(const Expr* expr);

//! Short-cut for checking if the given expression initializes buffers
//! for a global memory load.
bool isGlobalLoadInit(const Expr* expr);

//! Returns true if the given expression fills the output
//! tensor with a single scalar.
bool isTensorScalarFillOp(const Expr* expr);

//! Flattens all the scoped exprs, i.e. ForLoop and IfThenElse,
//! and returns all the exprs in all scopes in the original
//! linear textual order.
TORCH_CUDA_CU_API std::vector<Expr*> flattenScopedExprs(
    const std::vector<Expr*>& loop_nests);
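
// Illustration only (hypothetical IR): given a nest such as
//
//   FOR i0:
//     IF pred:
//       expr0;
//     FOR i1:
//       expr1;
//
// flattenScopedExprs returns {expr0, expr1}: the leaf expressions from all
// scopes, in their original textual order, with the enclosing ForLoop and
// IfThenElse scopes stripped.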

//! Returns all swizzle ops between the set of iterdomains
//! in `from` and `to`.
std::vector<Expr*> getAllSwizzlesBetween(
    std::vector<IterDomain*> from,
    std::vector<IterDomain*> to);

// Replace value pass on Kernel IR.
// Replaces each use of any Val* that appears in the given `replacement_map`.
// Keeps the predicate carried by each expr.
//
// Warning: blindly replaces all uses based on pointer equality.
// Warning: may invalidate indexing if replacing uses of allocated values.
std::vector<Expr*> replaceInputsInExpr(
    const std::vector<Expr*>& exprs,
    const std::unordered_map<Val*, Val*>& replacement_map);
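
// A minimal usage sketch (names are hypothetical): substitute `new_val` for
// every use of `old_val` in a list of already-lowered exprs.
//
//   std::unordered_map<Val*, Val*> replacement_map;
//   replacement_map[old_val] = new_val;
//   std::vector<Expr*> replaced = replaceInputsInExpr(exprs, replacement_map);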

// Go through all expressions and compute a local ordering of loops. operator<
// is implemented based on the concrete_id_dependencies analysis done. If
// there's no dependency between two IDs then the order doesn't matter;
// otherwise we can tell which is innermost by checking the dependency
// relationships.
//
// Dependency relationships in concrete_id_dependencies have a "global" view
// of the fusion, so ordering can be resolved by looking only at IDs and the
// dependency map.
//
// For example, two expressions may have domains [I0] and [I1], yet we
// won't know the ordering unless we see a domain with [I0, I1]. This happened
// in the advancedIndexing9 test (also see AdvancedLowering6) when merging T5
// with the group containing T10 (cache of T5, which is the post-broadcast
// output) and T6 (the pre-broadcast output).
// T5 had the domain [0, 1, 2, 3, 4] produce at 3
// T6 had the domain [0, 3, 4] compute at 3
// Merging [0, 1, 2] and [0, 3, 4] resulted in the domain [0, 3, 4, 1, 2]
//
// If IDs are not in the filter, we don't care about their ordering and ignore
// them. This is because we're only focused on loops we will have to merge
// across groups. If the domain is not at a produce-at position in the producer
// edges, or at a compute-at position in the consumer edges, the expressions we
// look at may not have a unique ordering.

struct TORCH_CUDA_CU_API IterDomainDependencySorter {
  IterDomainDependencySorter(
      const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
          concrete_id_dependencies,
      std::shared_ptr<const ComputeAtMap> compute_at_map)
      : concrete_id_dependencies_(concrete_id_dependencies),
        compute_at_map_(compute_at_map) {}

  // Return true if id0 should be before id1
  // Orders such that if x maps to {y}, x comes before y in the final ordering.
  inline bool operator()(IterDomain* id0, IterDomain* id1) {
    auto concrete_id_0 =
        compute_at_map_->getConcreteMappedID(id0, IdMappingMode::LOOP);
    auto concrete_id_1 =
        compute_at_map_->getConcreteMappedID(id1, IdMappingMode::LOOP);

    if (concrete_id_dependencies_.find(concrete_id_0) !=
        concrete_id_dependencies_.end()) {
      const auto& dependencies_0 = concrete_id_dependencies_.at(concrete_id_0);
      // If id0 depends on id1, it means id1 is inside id0, so id0 < id1
      if (dependencies_0.count(concrete_id_1)) {
        return true;
      }
    }

    return false;
  }

  const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
      concrete_id_dependencies_;
  const std::shared_ptr<const ComputeAtMap> compute_at_map_;
};
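
// A minimal usage sketch (not part of this header; names are hypothetical):
// the struct is intended as a comparator for standard sorting algorithms, so
// that an ID whose dependency set contains another ID is placed before that
// ID, i.e. outer loops before the loops nested inside them.
//
//   std::vector<IterDomain*> loop_ids = ...;
//   std::stable_sort(
//       loop_ids.begin(),
//       loop_ids.end(),
//       IterDomainDependencySorter(concrete_id_dependencies, compute_at_map));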

} // namespace ir_utils

namespace lower_utils {

bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map);

// Allocate a global buffer for grid communication calls, i.e. grid reduce,
// grid welford reduce, grid broadcast.
kir::Allocate* allocGlobalBufferForGridComm(
    Val* buffer_size,
    DataType dtype,
    bool zero_init);
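
// A minimal usage sketch (names are hypothetical): allocate a zero-initialized
// global work buffer for a grid reduction.
//
//   Val* buffer_size = ...; // number of elements needed by the grid op
//   kir::Allocate* work_buffer =
//       allocGlobalBufferForGridComm(buffer_size, DataType::Float, true);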

struct BasicAllocInfo {
  // The for loop that the initialization of this allocation must be
  // placed in, nullptr if not within a loop
  kir::ForLoop* init_for_loop = nullptr;

  // Keep track of the actual allocation loop. This can be different
  // from init_for_loop only with unswitched shared memory allocations,
  // which are moved to outer loops to avoid duplicated allocations. This means
  // that the alloc position may be outside what's expected. Most uses
  // outside lower_allocation are likely looking for init_for_loop, which is
  // more directly related to how large an allocation is and how it's used
  // (see issue #1133).
  kir::ForLoop* alloc_for_loop = nullptr;

  // The allocation position relative to buffer IDs. It could be outside the
  // compute-at position if it's shared memory with a compute-at position
  // inside an unswitch.
  size_t alloc_pos = 0;
};

// Fill the above allocation struct based on provided information. id_map is
// used if we're looking at a producer tensor but the loops are on a consumer
// tensor.
BasicAllocInfo getAllocInformation(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map = {},
    bool use_id_map = false);
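
// A minimal usage sketch (names are hypothetical): when placing the allocation
// of `tv` within the current loop nest `for_loops`, the returned info tells
// the lowering pass which loop should hold the allocation and its
// initialization.
//
//   BasicAllocInfo info = getAllocInformation(tv, for_loops);
//   kir::ForLoop* init_loop = info.init_for_loop; // may be nullptr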

//! Returns true if the expression has a variant that takes a predicate
//! as an inline argument.
bool supportInlinePredicate(Expr* expr);

} // namespace lower_utils

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch