1 | #pragma once |
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | |
5 | #include <c10/core/DeviceType.h> |
6 | #include <c10/util/Exception.h> |
7 | |
8 | #include <cuda.h> |
9 | |
10 | #include <torch/csrc/jit/ir/ir.h> |
11 | |
12 | #include <executor_kernel_arg.h> |
13 | #include <expr_evaluator.h> |
14 | #include <fusion.h> |
15 | #include <ir_all_nodes.h> |
16 | #include <kernel.h> |
17 | #include <kernel_expr_evaluator.h> |
18 | #include <lower2device.h> |
19 | |
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
22 | |
23 | namespace torch { |
24 | namespace jit { |
25 | namespace fuser { |
26 | namespace cuda { |
27 | namespace executor_utils { |
28 | |
// Include all the functions we might need in generated code
// (returns the source-code preamble prepended to every generated kernel).
std::string kernelPreamble();

//! Validates that the runtime inputs in `args` are consistent with the
//! inputs declared on `fusion` and reside on `device`.
void validateKernelInputs(
    Fusion* fusion,
    const KernelArgumentHolder& args,
    const c10::Device& device);

//! Validates that the allocated `outputs` tensors are consistent with the
//! outputs declared on `fusion` and reside on `device`.
void validateKernelOutputs(
    Fusion* fusion,
    const std::vector<at::Tensor>& outputs,
    const c10::Device& device);

//! Bind kernel input values to runtime values
//! When `check_consistency` is true, bindings are validated against
//! previously bound values (e.g. the same symbolic extent bound twice).
kir::ExpressionEvaluator bindKernelInputs(
    const KernelArgumentHolder& args,
    kir::Kernel* kernel,
    bool check_consistency = true);

//! Bind fusion input values to runtime values
TORCH_CUDA_CU_API ExpressionEvaluator
bindFusionInputs(const KernelArgumentHolder& args, Fusion* fusion);
51 | |
//! Pair of CUDA driver API handles produced by NVRTC compilation:
//! the loaded module and the kernel entry function within it.
struct NvrtcFunction {
  // Loaded module handle; value-initialized (null) until compilation.
  CUmodule module = CUmodule();
  // Entry-point function handle looked up inside `module`.
  CUfunction function = CUfunction();
};
56 | |
// Returns executable function and the ptxas log from compilation
// `code`: full CUDA C++ source of the kernel to compile.
// `func_name`: name of the kernel symbol to look up in the compiled module.
// `id`: presumably a fusion/kernel id used for debug dumps and caching —
//       TODO confirm against the .cpp implementation.
// `opt_block_size`: when provided, a known launch block size that may be
//       used to guide compilation (e.g. register usage); nullopt otherwise.
std::pair<NvrtcFunction, std::string> nvrtcCompile(
    const std::string& code,
    const std::string& func_name,
    int id,
    c10::optional<int> opt_block_size = c10::nullopt);
63 | |
64 | namespace caching { |
65 | // TODO: Could consider putting some of |
66 | // the logic in the common space and re-use |
67 | |
//! List of all the possible entry types in
//! `FusionExecutor` compile-time data cache.
//! Each enumerator tags one entry class defined below.
enum class CompileTimeEntryType {
  PARALLEL_BINDING_ITERDOMAINS, //!< see ParallelBindingIterDomains
  PARALLEL_ITER_EXTENT_MAP, //!< see ParallelIterExtentMap
  SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP, //!< see SimplifiedParallelIterExtentMap
  WARP_PADDED_PARALLEL_EXTENTS, //!< see WarpPaddedParallelExtents
  VECTORIZED_TENSOR_VALIDATION, //!< see VectorizedTensorValidation
  INPUT_ALIAS_INDICES, //!< see InputAliasIndices
  OUTPUT_ALIAS_INDICES //!< see OutputAliasIndices
};
79 | |
80 | //! Entry class definitions for each entry type: |
81 | //! each class defines the data type for each entry type |
82 | |
83 | //! Compile-time info to be cached in each FusionExecutor: |
84 | //! ParallelBindingIterDomains: |
85 | //! Stores all the iterdomains that are parallelized |
86 | //! on the scheduled Fusion graph. They will be used |
87 | //! in launch param iteration and their extents may |
88 | //! come from launch constraints. |
89 | class ParallelBindingIterDomains { |
90 | public: |
91 | using DataType = std::vector<IterDomain*>; |
92 | static const CompileTimeEntryType EntryType = |
93 | CompileTimeEntryType::PARALLEL_BINDING_ITERDOMAINS; |
94 | }; |
95 | |
96 | //! Compile-time info to be cached in each FusionExecutor: |
97 | //! ParallelIterExtentMap |
98 | //! Stores the symbolic extents of all the parallelized |
99 | //! iterdomains corresponding to each used parallel type. |
100 | class ParallelIterExtentMap { |
101 | public: |
102 | using DataType = |
103 | std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>; |
104 | static const CompileTimeEntryType EntryType = |
105 | CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP; |
106 | }; |
107 | |
108 | //! Compile-time info to be cached in each FusionExecutor: |
109 | //! SimplifiedParallelIterExtentMap |
110 | //! This entry type is a simplified version of ParallelIterExtentMap. |
111 | //! |
112 | //! For launch parameter binding we only need the most concrete iterdomain |
113 | //! in each disjoint set stored in CaParallelMap. This entry stores the |
114 | //! remaining list of extents for binding after this simplification. |
115 | //! |
116 | //! We still need ParallelIterExtentMap since we want to bind the concrete |
117 | //! values to the extents of all parallelized iterdomains. We would be |
118 | //! able to save these bindings if the integer machine has a notion of |
119 | //! equality and could be configured compile time. But that'd be a longer |
120 | //! term target. |
121 | class SimplifiedParallelIterExtentMap { |
122 | public: |
123 | using DataType = |
124 | std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>; |
125 | static const CompileTimeEntryType EntryType = |
126 | CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP; |
127 | }; |
128 | |
//! WarpPaddedExtentsInfo:
//!  Auxiliary data type for entry class WarpPaddedParallelExtents
struct WarpPaddedExtentsInfo {
  // Extent Vals of iterdomains that are warp padded.
  std::unordered_set<const Val*> warp_padded_extent_set;
  // Maps a warp-padded extent Val to its constant value, for extents
  // whose value is known at compile time — presumably the padded value;
  // TODO confirm against getWarpPaddedExtentsInfo() implementation.
  std::unordered_map<const Val*, int64_t> warp_padded_constant;
};
135 | |
136 | //! Compile-time info to be cached in each FusionExecutor: |
137 | //! WarpPaddedParallelExtents |
138 | //! Stores the symbolic and constant extents of warp |
139 | //! padded parallel iterdomains. |
140 | class WarpPaddedParallelExtents { |
141 | public: |
142 | using DataType = WarpPaddedExtentsInfo; |
143 | static const CompileTimeEntryType EntryType = |
144 | CompileTimeEntryType::WARP_PADDED_PARALLEL_EXTENTS; |
145 | }; |
146 | |
//! VectorizedTensorInfo:
//!  Auxiliary data type for entry class VectorizedTensorValidation.
//!  Positions below index into the fusion's input/output lists.
struct VectorizedTensorInfo {
  //! Aligned vectorized fusion inputs (positions in the input list)
  std::vector<int> aligned_vectorized_inp_tensor_pos;
  //! Aligned vectorized fusion outputs (positions in the output list)
  std::vector<int> aligned_vectorized_out_tensor_pos;
  //! Misaligned vectorized input tensors
  std::unordered_set<TensorView*> global_inp_misaligned_tv;
  //! Misaligned vectorized output tensors
  std::unordered_set<TensorView*> global_out_misaligned_tv;
  //! Positions of misaligned input tensors
  std::vector<int> inp_misaligned_tensors_pos;
  //! Positions of misaligned output tensors
  std::vector<int> out_misaligned_tensors_pos;
};
163 | |
164 | //! Compile-time info to be cached in each FusionExecutor: |
165 | //! VectorizedTensorValidation |
166 | //! Stores position info and vector word sizes of |
167 | //! vectorized input/output tensors, to be used |
168 | //! in misaligned vectorization validation. |
169 | class VectorizedTensorValidation { |
170 | public: |
171 | using DataType = VectorizedTensorInfo; |
172 | static const CompileTimeEntryType EntryType = |
173 | CompileTimeEntryType::VECTORIZED_TENSOR_VALIDATION; |
174 | }; |
175 | |
176 | //! Compile-time info to be cached in each FusionExecutor: |
177 | //! InputAliasIndices |
178 | //! Stores position info of aliased input tensors |
179 | class InputAliasIndices { |
180 | public: |
181 | using DataType = std::vector<std::pair<int, int>>; |
182 | static const CompileTimeEntryType EntryType = |
183 | CompileTimeEntryType::INPUT_ALIAS_INDICES; |
184 | }; |
185 | |
186 | //! Compile-time info to be cached in each FusionExecutor: |
187 | //! OutputAliasIndices |
188 | //! Stores position info of aliased output tensors |
189 | class OutputAliasIndices { |
190 | public: |
191 | using DataType = std::unordered_set<int>; |
192 | static const CompileTimeEntryType EntryType = |
193 | CompileTimeEntryType::OUTPUT_ALIAS_INDICES; |
194 | }; |
195 | |
196 | //! Base abstract class for unified storage in `ExecutorCompileTimeInfoCache`, |
197 | //! each entry in `ExecutorCompileTimeInfoCache` will be a subclass. |
198 | class CompileTimeInfoBase : public PolymorphicBase { |
199 | public: |
200 | CompileTimeInfoBase(CompileTimeEntryType entry_type) |
201 | : entry_type_(entry_type) {} |
202 | CompileTimeEntryType type() { |
203 | return entry_type_; |
204 | } |
205 | |
206 | private: |
207 | CompileTimeEntryType entry_type_; |
208 | }; |
209 | |
210 | // Note: Do NOT export this class. MSVC issue with exported class that contains |
211 | // std::vector<unique_ptr<xxx>>: https://godbolt.org/z/3E4e8T1P1 |
212 | //! Compile-time information cache |
213 | class ExecutorCompileTimeInfoCache { |
214 | using Entry = CompileTimeInfoBase; |
215 | using EntryOwningPtr = std::unique_ptr<Entry>; |
216 | using EntryPtr = Entry*; |
217 | using EntryType = CompileTimeEntryType; |
218 | |
219 | public: |
220 | void insert(EntryOwningPtr new_entry); |
221 | |
222 | EntryPtr at(EntryType entry_type) { |
223 | return entry_type_map_.at(entry_type); |
224 | } |
225 | |
226 | bool has(EntryType entry_type) { |
227 | return entry_type_map_.count(entry_type); |
228 | } |
229 | |
230 | private: |
231 | std::vector<EntryOwningPtr> entries_; |
232 | std::unordered_map<EntryType, EntryPtr> entry_type_map_; |
233 | }; |
234 | |
//! A utility class to facilitate accessing ExecutorCompileTimeInfoCache.
//! Wraps a single cache entry of type EntryClass: on construction the data
//! is either fetched from the cache or computed (and possibly stored); get()
//! then returns it regardless of where it lives.
template <typename EntryClass>
class ExecutorCompileTimeEntry {
  using EntryDataType = typename EntryClass::DataType;
  using EntryDataTypeOwnPtr = std::unique_ptr<EntryDataType>;
  using MakerFnType = std::function<EntryDataTypeOwnPtr()>;

 public:
  //! Creates a data entry with type defined in EntryClass,
  //!  eg. EntryClass = VectorizableInputsAndOutputs;
  //!
  //! @param data_cache, a pointer to an instantiated compile-time
  //!  info cache. The info data will be
  //!    1. read from data cache if data cache has the corresponding entry.
  //!    2. written into data cache if data cache doesn't have the entry.
  //!    3. managed by owned_data_ if data cache is nullptr
  //! @param fn:
  //!   The factory function that needs to return a owning pointer
  //!  i.e. std::unique_ptr<EntryClass::DataType>. It will only
  //!  be called either when data cache is missing an entry or when no data
  //!  cache is given.
  ExecutorCompileTimeEntry(
      ExecutorCompileTimeInfoCache* data_cache,
      MakerFnType fn);

  //! Unified interface to get actual data, either from cache
  //! or from factory function.
  EntryDataType& get() {
    return *data_ptr_;
  }

 private:
  //! Internal data owning pointer that will manage the computed
  //! data where there is no data cache (case 3 above); null otherwise.
  EntryDataTypeOwnPtr owned_data_ = nullptr;

  //! Pointer to the valid data entry that could be accessed.
  //! Points either into the cache or at *owned_data_.
  EntryDataType* data_ptr_ = nullptr;
};
274 | |
275 | } // namespace caching |
276 | |
//! Returns the vector of tensorviews that will be used to bind parallel
//! dimensions.
std::vector<IterDomain*> getParallelBindingsIterDomains(
    GpuLower* lower,
    const std::vector<TensorView*>& used_tvs);

using ParallelExtentMap =
    std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;

//! Returns the extents of all parallel binding iterdomains corresponding
//! to each parallel type.
std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
    std::vector<IterDomain*>& parallel_binding_ids);

//! Returns the simplified set of extents necessary for launch parameter
//! binding. See SimplifiedParallelIterExtentMap above for the rationale.
std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
    GpuLower* lower,
    std::vector<IterDomain*>& parallel_binding_ids);

//! Returns the symbolic or constant extents of warp padded parallel
//! iterdomains in the given vector.
//! NOTE(review): the first parameter is a kir::Kernel* despite being
//! named `lower` — consider renaming in a follow-up.
std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo(
    kir::Kernel* lower,
    std::vector<IterDomain*>& parallel_binding_ids);

//! Validates runtime inputs/outputs against the vectorization info cached
//! (or computed into `data_cache`), using `expr_eval` for extent values.
void validateVectorizedTensors(
    kir::Kernel* kernel,
    const KernelArgumentHolder& args,
    const std::vector<at::Tensor>& outputs,
    caching::ExecutorCompileTimeInfoCache* data_cache,
    kir::ExpressionEvaluator& expr_eval);
309 | |
310 | } // namespace executor_utils |
311 | } // namespace cuda |
312 | } // namespace fuser |
313 | } // namespace jit |
314 | } // namespace torch |
315 | |