#pragma once

#include <c10/macros/Export.h>

#include <compute_at_map.h>
#include <ir_all_nodes.h>
#include <kernel.h>
#include <kernel_ir.h>
#include <lower_allocation.h>
#include <lower_double_buffer.h>
#include <lower_fused_reduction.h>
#include <lower_index_hoist.h>
#include <lower_predicate.h>
#include <lower_predicate_elimination.h>
#include <lower_shift.h>
#include <lower_sync_information.h>
#include <lower_thread_predicate.h>
#include <lower_trivial_broadcast.h>
#include <lower_trivial_reductions.h>
#include <lower_warp_reduce.h>
#include <non_divisible_split.h>
#include <parallel_dimension_map.h>
#include <partial_split_map.h>
#include <root_domain_map.h>
#include <vectorization_info.h>

#include <memory>
#include <ostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
31
32namespace torch {
33namespace jit {
34namespace fuser {
35namespace cuda {
36
37// TODO: we frequently use pairwise root mapping from consumers to producers.
38// This information is implicitly in the computeAtMaps, but there's no isolated
39// container for this information that we can reuse. Would be nice to generate
40// such a structure and propagate it through lowering.
41// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
42class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
43 class KernelIrMapper;
44
45 public:
46 GpuLower() = delete;
47
48 // GpuLower lowers the provided fusion into a kernel which can be translated
49 // into cuda code. index_type allows to compile the kernel based on int32
50 // indexing instead of int64 for additional performance.
51 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
52 explicit GpuLower(Fusion* fusion, DataType index_type = DataType::Int) {
53 lower(fusion, index_type);
54 }
55
56 kir::Kernel* kernel() const;
57
58 //! Returns the currently active lowering object.
59 //! It's an error if no lowering is in progress.
60 static GpuLower* current();
61
62 //! Query if lowering is in progress
63 static bool hasCurrent();
64
65 std::shared_ptr<const ConcretizedBroadcastDomains>
66 concretizedBroadcastDomains() {
67 return concretized_broadcast_domains_;
68 }
69
70 const ThreadPredicateMap& threadPredMap() const {
71 return thread_pred_map_;
72 }
73
74 // Returns non-const reference. Necessary to reset a predicate flag
75 // when a broadcast expression is fused into a reduction.
76 ThreadPredicateMap& threadPredMap() {
77 return thread_pred_map_;
78 }
79
80 std::shared_ptr<const ComputeAtMap> caMap() const {
81 return std::const_pointer_cast<const ComputeAtMap>(compute_at_map_);
82 }
83
84 const TrivialReductionInfo& trivialReductionInfo() const {
85 return trivial_reduction_info_;
86 }
87
88 std::shared_ptr<const HaloInfo> haloInfo() const {
89 return std::const_pointer_cast<const HaloInfo>(halo_info_);
90 }
91
92 const ParallelDimensionMap& parallelDimensionMap() const {
93 return parallel_dimension_map_;
94 }
95
96 ParallelDimensionMap& parallelDimensionMap() {
97 return parallel_dimension_map_;
98 }
99
100 PredicateElimination& predicateElimination() {
101 return pred_elimination_;
102 }
103
104 const PredicateElimination& predicateElimination() const {
105 return pred_elimination_;
106 }
107
108 LocalAllocationInfoMap& localAllocationInfoMap() {
109 return local_allocation_info_map_;
110 }
111
112 const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const {
113 return warp_pad_info_;
114 }
115
116 PartialSplitMap& partialSplitMap() {
117 return partial_split_map_;
118 }
119
120 const PartialSplitMap& partialSplitMap() const {
121 return partial_split_map_;
122 }
123
124 auto& nonDivisibleSplitInfo() {
125 return non_divisible_split_info_;
126 }
127
128 const auto& nonDivisibleSplitInfo() const {
129 return non_divisible_split_info_;
130 }
131
132 const auto& divisbleSplitSet() const {
133 return divisible_splits_;
134 }
135
136 DoubleBufferInfo& doubleBufferInfo() {
137 return double_buffer_info_;
138 }
139
140 CommonIndexMap& commonIndexMap() {
141 return common_index_map_;
142 }
143
144 const auto& vectorizedAccesses() const {
145 return vectorized_accesses_;
146 }
147
148 auto& vectorizedAccesses() {
149 return vectorized_accesses_;
150 }
151
152 const auto& vectorizedSetInfo() const {
153 return vectorized_set_info_;
154 }
155
156 auto& vectorizedSetInfo() {
157 return vectorized_set_info_;
158 }
159
160 FusedReductionInfo& fusedReductionInfo() {
161 return fused_reduction_info_;
162 }
163
164 const SyncMap& syncMap() const {
165 return sync_map_;
166 }
167
168 kir::KernelPerformanceProfile& profile() {
169 return profile_;
170 }
171
172 // This is an interface to propagate information after expression
173 // replacement on the kernel IR. E.g.:
174 // for ...
175 // c = a + b (expr 0)
176 // after any pass that does replacement:
177 // for ...
178 // c1 = a1 + b1 (expr1)
179 // The previous analysis that was performed on expr0 might still
180 // be valid on expr1 but that info would be lost after replacement.
181 // This function provides an interface to manually update the info
182 // in any pass that performs replacement.
183 void propagateExprInfo(const Expr* old_expr, const Expr* new_expr);
184
185 private:
186 void lower(Fusion* fusion, DataType index_type);
187
188 // Goes through the parallelized iterdomains of the used TVs and find
189 // the parallel dimensions that need to be padded to a multiples of
190 // warp size.
191 void collectPaddedParallelDims();
192
193 private:
194 // Lowered Kernel IR
195 std::unique_ptr<kir::Kernel> kernel_;
196
197 // Some stateful information during lowering
198 // TODO: A lot of this information uses a define class then call build. It
199 // would be safer to wrap all of these in unique pointers and remove the build
200 // interface and default constructor. That way they couldn't be accessed
201 // without being initialized.
202 std::shared_ptr<const ConcretizedBroadcastDomains>
203 concretized_broadcast_domains_;
204 ThreadPredicateMap thread_pred_map_;
205 PredicateElimination pred_elimination_;
206 std::shared_ptr<ComputeAtMap> compute_at_map_;
207 TrivialReductionInfo trivial_reduction_info_;
208 std::shared_ptr<HaloInfo> halo_info_;
209 LocalAllocationInfoMap local_allocation_info_map_;
210 WarpPaddedParallelInfo warp_pad_info_;
211 ParallelDimensionMap parallel_dimension_map_;
212 PartialSplitMap partial_split_map_;
213 NonDivisibleSplitInfo non_divisible_split_info_;
214 DoubleBufferInfo double_buffer_info_;
215 CommonIndexMap common_index_map_;
216 FusedReductionInfo fused_reduction_info_;
217 SyncMap sync_map_;
218 kir::KernelPerformanceProfile profile_;
219 std::unordered_set<Split*> divisible_splits_;
220
221 // Track which tensor views are inputs or outputs of a vectorized operation
222 // and their maximum vectorized access size
223 // std::unordered_map<TensorView*, VectorizationInfo> vectorized_accesses_;
224 std::unordered_map<TensorView*, int> vectorized_accesses_;
225 // Info on each vectorized set op
226 std::vector<VectorizedSetInfo> vectorized_set_info_;
227
228 Fusion* fusion_ = nullptr;
229};
230
231} // namespace cuda
232} // namespace fuser
233} // namespace jit
234} // namespace torch
235