1 | #pragma once |
---|---|
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <compute_at_map.h> |
6 | #include <ir_all_nodes.h> |
7 | #include <kernel.h> |
8 | #include <kernel_ir.h> |
9 | #include <lower_allocation.h> |
10 | #include <lower_double_buffer.h> |
11 | #include <lower_fused_reduction.h> |
12 | #include <lower_index_hoist.h> |
13 | #include <lower_predicate.h> |
14 | #include <lower_predicate_elimination.h> |
15 | #include <lower_shift.h> |
16 | #include <lower_sync_information.h> |
17 | #include <lower_thread_predicate.h> |
18 | #include <lower_trivial_broadcast.h> |
19 | #include <lower_trivial_reductions.h> |
20 | #include <lower_warp_reduce.h> |
21 | #include <non_divisible_split.h> |
22 | #include <parallel_dimension_map.h> |
23 | #include <partial_split_map.h> |
24 | #include <root_domain_map.h> |
25 | #include <vectorization_info.h> |
26 | |
#include <memory>
#include <ostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
31 | |
32 | namespace torch { |
33 | namespace jit { |
34 | namespace fuser { |
35 | namespace cuda { |
36 | |
// TODO: we frequently use pairwise root mapping from consumers to producers.
// This information is implicitly in the computeAtMaps, but there's no isolated
// container for this information that we can reuse. Would be nice to generate
// such a structure and propagate it through lowering.
41 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
42 | class TORCH_CUDA_CU_API GpuLower : public NonCopyable { |
43 | class KernelIrMapper; |
44 | |
45 | public: |
46 | GpuLower() = delete; |
47 | |
48 | // GpuLower lowers the provided fusion into a kernel which can be translated |
49 | // into cuda code. index_type allows to compile the kernel based on int32 |
50 | // indexing instead of int64 for additional performance. |
51 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
52 | explicit GpuLower(Fusion* fusion, DataType index_type = DataType::Int) { |
53 | lower(fusion, index_type); |
54 | } |
55 | |
56 | kir::Kernel* kernel() const; |
57 | |
58 | //! Returns the currently active lowering object. |
59 | //! It's an error if no lowering is in progress. |
60 | static GpuLower* current(); |
61 | |
62 | //! Query if lowering is in progress |
63 | static bool hasCurrent(); |
64 | |
65 | std::shared_ptr<const ConcretizedBroadcastDomains> |
66 | concretizedBroadcastDomains() { |
67 | return concretized_broadcast_domains_; |
68 | } |
69 | |
70 | const ThreadPredicateMap& threadPredMap() const { |
71 | return thread_pred_map_; |
72 | } |
73 | |
74 | // Returns non-const reference. Necessary to reset a predicate flag |
75 | // when a broadcast expression is fused into a reduction. |
76 | ThreadPredicateMap& threadPredMap() { |
77 | return thread_pred_map_; |
78 | } |
79 | |
80 | std::shared_ptr<const ComputeAtMap> caMap() const { |
81 | return std::const_pointer_cast<const ComputeAtMap>(compute_at_map_); |
82 | } |
83 | |
84 | const TrivialReductionInfo& trivialReductionInfo() const { |
85 | return trivial_reduction_info_; |
86 | } |
87 | |
88 | std::shared_ptr<const HaloInfo> haloInfo() const { |
89 | return std::const_pointer_cast<const HaloInfo>(halo_info_); |
90 | } |
91 | |
92 | const ParallelDimensionMap& parallelDimensionMap() const { |
93 | return parallel_dimension_map_; |
94 | } |
95 | |
96 | ParallelDimensionMap& parallelDimensionMap() { |
97 | return parallel_dimension_map_; |
98 | } |
99 | |
100 | PredicateElimination& predicateElimination() { |
101 | return pred_elimination_; |
102 | } |
103 | |
104 | const PredicateElimination& predicateElimination() const { |
105 | return pred_elimination_; |
106 | } |
107 | |
108 | LocalAllocationInfoMap& localAllocationInfoMap() { |
109 | return local_allocation_info_map_; |
110 | } |
111 | |
112 | const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const { |
113 | return warp_pad_info_; |
114 | } |
115 | |
116 | PartialSplitMap& partialSplitMap() { |
117 | return partial_split_map_; |
118 | } |
119 | |
120 | const PartialSplitMap& partialSplitMap() const { |
121 | return partial_split_map_; |
122 | } |
123 | |
124 | auto& nonDivisibleSplitInfo() { |
125 | return non_divisible_split_info_; |
126 | } |
127 | |
128 | const auto& nonDivisibleSplitInfo() const { |
129 | return non_divisible_split_info_; |
130 | } |
131 | |
132 | const auto& divisbleSplitSet() const { |
133 | return divisible_splits_; |
134 | } |
135 | |
136 | DoubleBufferInfo& doubleBufferInfo() { |
137 | return double_buffer_info_; |
138 | } |
139 | |
140 | CommonIndexMap& commonIndexMap() { |
141 | return common_index_map_; |
142 | } |
143 | |
144 | const auto& vectorizedAccesses() const { |
145 | return vectorized_accesses_; |
146 | } |
147 | |
148 | auto& vectorizedAccesses() { |
149 | return vectorized_accesses_; |
150 | } |
151 | |
152 | const auto& vectorizedSetInfo() const { |
153 | return vectorized_set_info_; |
154 | } |
155 | |
156 | auto& vectorizedSetInfo() { |
157 | return vectorized_set_info_; |
158 | } |
159 | |
160 | FusedReductionInfo& fusedReductionInfo() { |
161 | return fused_reduction_info_; |
162 | } |
163 | |
164 | const SyncMap& syncMap() const { |
165 | return sync_map_; |
166 | } |
167 | |
168 | kir::KernelPerformanceProfile& profile() { |
169 | return profile_; |
170 | } |
171 | |
172 | // This is an interface to propagate information after expression |
173 | // replacement on the kernel IR. E.g.: |
174 | // for ... |
175 | // c = a + b (expr 0) |
176 | // after any pass that does replacement: |
177 | // for ... |
178 | // c1 = a1 + b1 (expr1) |
179 | // The previous analysis that was performed on expr0 might still |
180 | // be valid on expr1 but that info would be lost after replacement. |
181 | // This function provides an interface to manually update the info |
182 | // in any pass that performs replacement. |
183 | void propagateExprInfo(const Expr* old_expr, const Expr* new_expr); |
184 | |
185 | private: |
186 | void lower(Fusion* fusion, DataType index_type); |
187 | |
188 | // Goes through the parallelized iterdomains of the used TVs and find |
189 | // the parallel dimensions that need to be padded to a multiples of |
190 | // warp size. |
191 | void collectPaddedParallelDims(); |
192 | |
193 | private: |
194 | // Lowered Kernel IR |
195 | std::unique_ptr<kir::Kernel> kernel_; |
196 | |
197 | // Some stateful information during lowering |
198 | // TODO: A lot of this information uses a define class then call build. It |
199 | // would be safer to wrap all of these in unique pointers and remove the build |
200 | // interface and default constructor. That way they couldn't be accessed |
201 | // without being initialized. |
202 | std::shared_ptr<const ConcretizedBroadcastDomains> |
203 | concretized_broadcast_domains_; |
204 | ThreadPredicateMap thread_pred_map_; |
205 | PredicateElimination pred_elimination_; |
206 | std::shared_ptr<ComputeAtMap> compute_at_map_; |
207 | TrivialReductionInfo trivial_reduction_info_; |
208 | std::shared_ptr<HaloInfo> halo_info_; |
209 | LocalAllocationInfoMap local_allocation_info_map_; |
210 | WarpPaddedParallelInfo warp_pad_info_; |
211 | ParallelDimensionMap parallel_dimension_map_; |
212 | PartialSplitMap partial_split_map_; |
213 | NonDivisibleSplitInfo non_divisible_split_info_; |
214 | DoubleBufferInfo double_buffer_info_; |
215 | CommonIndexMap common_index_map_; |
216 | FusedReductionInfo fused_reduction_info_; |
217 | SyncMap sync_map_; |
218 | kir::KernelPerformanceProfile profile_; |
219 | std::unordered_set<Split*> divisible_splits_; |
220 | |
221 | // Track which tensor views are inputs or outputs of a vectorized operation |
222 | // and their maximum vectorized access size |
223 | // std::unordered_map<TensorView*, VectorizationInfo> vectorized_accesses_; |
224 | std::unordered_map<TensorView*, int> vectorized_accesses_; |
225 | // Info on each vectorized set op |
226 | std::vector<VectorizedSetInfo> vectorized_set_info_; |
227 | |
228 | Fusion* fusion_ = nullptr; |
229 | }; |
230 | |
231 | } // namespace cuda |
232 | } // namespace fuser |
233 | } // namespace jit |
234 | } // namespace torch |
235 |