1#pragma once
2
3#include <index_compute.h>
4#include <kernel_ir.h>
5#include <lower_thread_predicate.h>
6#include <lower_utils.h>
7#include <root_domain_map.h>
8
9namespace torch {
10namespace jit {
11namespace fuser {
12namespace cuda {
13
14class PredicateCompute {
15 public:
16 // ignore_internal_syncthread_ops will prevent creation of predicates on
17 // block/grid broadcast/reduce as these have syncthread calls within them
18 // so all threads need to execute the function.
19 static Bool* getInlinePredicate(
20 const Expr* expr,
21 const std::vector<kir::ForLoop*>& loops,
22 Bool* thread_pred,
23 PredicateType pred_type);
24};
25
26//! Parallelized domains may need to be predicated with threading
27//! indices and IterDomain extents. For example, if a domain is
28//! parallelized by TIDx, when TIDx is not exact, i.e., it can be
29//! larger than the extents of domains parallelized by TIDx,
30//! threadIdx.x may be larger than the IterDomain extent. This can be
31//! harmless for Local tensors, however, for it would
32//! result in out-of-bounds access for Shared tensors as they are
33//! allocated based on tensor shapes rather than threading
34//! dimensions.
35class ParallelizedDomainPredicate {
36 public:
37 //! Predicate information for parallelized domains
38 class PredicateInfo {
39 public:
40 explicit PredicateInfo(ParallelType pt) : pt_(pt) {}
41
42 //! Adds a domain that is parallized by the same paralell type
43 bool addDomain(IterDomain* id);
44
45 const std::vector<IterDomain*>& ids() const {
46 return ids_;
47 }
48
49 //! Generates a predicate Val from predicate information
50 Bool* getPredicate() const;
51
52 private:
53 ParallelType pt_;
54 //! Domains parallelized by the same parallel type
55 std::vector<IterDomain*> ids_;
56 };
57
58 //! Returns a predicate Val for parallelied domains of an expression.
59 static Bool* getPredicate(
60 const Expr* expr,
61 const std::vector<kir::ForLoop*>& loops);
62
63 //! Returns predicate information for parallelied domains of an
64 //! expression.
65 static std::unordered_map<ParallelType, PredicateInfo, TypeHash>
66 getPredicateMap(
67 const Expr* expr,
68 const std::vector<kir::ForLoop*>& loops,
69 kir::ForLoop* unswitched_loop = nullptr);
70};
71
72//! Keys to identify unique unswitch predicates. Just consists of a
73//! predicated concrete domain if not parallelized. If parallelized,
74//! pick one for each different parallelization. When the same
75//! parallel type is used for different concrete domains, they are
76//! considered different predicates and are included in the unswitch
77//! condition lists.
78class UnswitchPredicateKey {
79 public:
80 UnswitchPredicateKey();
81
82 UnswitchPredicateKey(
83 IterDomain* predicated_consumer_id,
84 TensorView* consumer_tv,
85 IterDomain* predicated_concrete_id);
86
87 bool operator==(const UnswitchPredicateKey& other) const {
88 return predicated_concrete_id_ == other.predicated_concrete_id_ &&
89 parallel_concrete_ids_ == other.parallel_concrete_ids_;
90 }
91
92 const auto& predicatedId() const {
93 return predicated_concrete_id_;
94 }
95
96 const auto& parallelConcreteIds() const {
97 return parallel_concrete_ids_;
98 }
99
100 IterDomain* parallelId(ParallelType pt) const {
101 auto it = parallelConcreteIds().find(pt);
102 if (it == parallelConcreteIds().end()) {
103 return nullptr;
104 } else {
105 return it->second;
106 }
107 }
108
109 std::string toString() const;
110
111 private:
112 //! Predicated concrete domain
113 IterDomain* predicated_concrete_id_ = nullptr;
114 //! Store parallelized concrete domains
115 std::unordered_map<ParallelType, IterDomain*, TypeHash>
116 parallel_concrete_ids_;
117};
118
119struct UnswitchPredicateKeyHash {
120 std::size_t operator()(const UnswitchPredicateKey& key) const;
121};
122
123class TORCH_CUDA_CU_API UnswitchPredicate {
124 public:
125 static Bool* get(
126 const std::vector<kir::ForLoop*>& outer_loops,
127 kir::ForLoop* unrolled_loop);
128
129 private:
130 //! Predicate information for each UnswitchPredicateKey.
131 struct MergedPredicates {
132 //! Predicate information for the start and stop predicates.
133 struct Info {
134 //! Most restrictive static predicate. Nullptr if no static
135 //! predicate found.
136 Bool* static_pred = nullptr;
137 //! The offset value of static_pred
138 int64_t static_offset = 0;
139 //! List of dynamic predicates.
140 std::vector<Bool*> dynamic_preds;
141 };
142 UnswitchPredicateKey predicate_key;
143 Info start;
144 Info stop;
145 };
146
147 UnswitchPredicate(
148 std::vector<kir::ForLoop*> outer_loops,
149 kir::ForLoop* unrolled_loop);
150
151 void predicateOn(Expr*);
152
153 void openLoop(kir::ForLoop*);
154
155 void openIte(kir::IfThenElse*);
156
157 //! Generates the final predicates from the predicated_keys map
158 void finalize();
159
160 //! Merge predicates as much as possible. If a predicate offset is
161 //! static, only pick the most restrictive one, e.g., the one with the
162 //! minimum offset for the start predication.
163 void mergeUnswitchPredicateOffsets(
164 Bool* predicate,
165 Val* offset,
166 MergedPredicates::Info& merged_predicate_info,
167 bool is_start);
168
169 //! Adds new predicates for parallelized domains
170 void addParallelizedDomainPredicates(Expr*);
171
172 private:
173 //! Track which iter domains have been predicated
174 std::unordered_set<UnswitchPredicateKey, UnswitchPredicateKeyHash>
175 predicated_keys_;
176
177 //! The predicates that have been recorded but not yet finalized
178 std::vector<MergedPredicates> pending_predicates_;
179
180 //! Track which parallelized domains have been predicated
181 std::unordered_map<
182 ParallelType,
183 ParallelizedDomainPredicate::PredicateInfo,
184 TypeHash>
185 parallelized_dom_predicates_;
186
187 //! The predicates that have been generated.
188 std::vector<Bool*> predicates_;
189
190 std::vector<kir::ForLoop*> for_loops_;
191
192 kir::ForLoop* unrolled_loop_;
193};
194
195} // namespace cuda
196} // namespace fuser
197} // namespace jit
198} // namespace torch
199