1#pragma once
2
3#include <ATen/core/ivalue.h>
4
5#include <c10/core/DeviceType.h>
6#include <c10/util/Exception.h>
7
8#include <cuda.h>
9
10#include <torch/csrc/jit/ir/ir.h>
11
12#include <executor_kernel_arg.h>
13#include <expr_evaluator.h>
14#include <fusion.h>
15#include <ir_all_nodes.h>
16#include <kernel.h>
17#include <kernel_expr_evaluator.h>
18#include <lower2device.h>
19
20#include <string>
21#include <vector>
22
23namespace torch {
24namespace jit {
25namespace fuser {
26namespace cuda {
27namespace executor_utils {
28
//! Returns the source-code preamble prepended to every generated kernel:
//! all the helper functions/definitions the generated code might need.
std::string kernelPreamble();

//! Checks the runtime arguments held in `args` against `fusion`'s declared
//! inputs and the target `device`. Expected to raise on a mismatch —
//! exact checks live in the definition, which is not visible in this header.
void validateKernelInputs(
    Fusion* fusion,
    const KernelArgumentHolder& args,
    const c10::Device& device);

//! Checks the produced `outputs` tensors against `fusion`'s declared
//! outputs and the target `device`. Expected to raise on a mismatch.
void validateKernelOutputs(
    Fusion* fusion,
    const std::vector<at::Tensor>& outputs,
    const c10::Device& device);

//! Bind kernel input values to runtime values; returns an expression
//! evaluator pre-loaded with those bindings.
//! @param check_consistency when true, presumably verifies that repeated
//!   bindings of the same symbol agree — TODO confirm in the definition.
kir::ExpressionEvaluator bindKernelInputs(
    const KernelArgumentHolder& args,
    kir::Kernel* kernel,
    bool check_consistency = true);

//! Bind fusion input values to runtime values; returns an expression
//! evaluator pre-loaded with those bindings.
TORCH_CUDA_CU_API ExpressionEvaluator
bindFusionInputs(const KernelArgumentHolder& args, Fusion* fusion);
51
52struct NvrtcFunction {
53 CUmodule module = CUmodule();
54 CUfunction function = CUfunction();
55};
56
//! Compiles `code` with NVRTC and resolves the entry point named
//! `func_name`; returns the executable function and the ptxas log from
//! compilation.
//! @param id NOTE(review): presumably a fusion/kernel identifier used for
//!   logging or cache keys — confirm in the definition.
//! @param opt_block_size optional block-size hint for compilation (e.g. to
//!   set launch bounds) — TODO confirm in the definition.
std::pair<NvrtcFunction, std::string> nvrtcCompile(
    const std::string& code,
    const std::string& func_name,
    int id,
    c10::optional<int> opt_block_size = c10::nullopt);
63
64namespace caching {
65// TODO: Could consider putting some of
66// the logic in the common space and re-use
67
//! List of all the possible entry types in
//! `FusionExecutor` compile-time data cache.
//! Each enumerator tags a matching entry class defined below that fixes the
//! cached data type (e.g. PARALLEL_ITER_EXTENT_MAP <-> ParallelIterExtentMap).
enum class CompileTimeEntryType {
  //! See ParallelBindingIterDomains.
  PARALLEL_BINDING_ITERDOMAINS,
  //! See ParallelIterExtentMap.
  PARALLEL_ITER_EXTENT_MAP,
  //! See SimplifiedParallelIterExtentMap.
  SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP,
  //! See WarpPaddedParallelExtents.
  WARP_PADDED_PARALLEL_EXTENTS,
  //! See VectorizedTensorValidation.
  VECTORIZED_TENSOR_VALIDATION,
  //! See InputAliasIndices.
  INPUT_ALIAS_INDICES,
  //! See OutputAliasIndices.
  OUTPUT_ALIAS_INDICES
};
79
80//! Entry class definitions for each entry type:
81//! each class defines the data type for each entry type
82
83//! Compile-time info to be cached in each FusionExecutor:
84//! ParallelBindingIterDomains:
85//! Stores all the iterdomains that are parallelized
86//! on the scheduled Fusion graph. They will be used
87//! in launch param iteration and their extents may
88//! come from launch constraints.
89class ParallelBindingIterDomains {
90 public:
91 using DataType = std::vector<IterDomain*>;
92 static const CompileTimeEntryType EntryType =
93 CompileTimeEntryType::PARALLEL_BINDING_ITERDOMAINS;
94};
95
96//! Compile-time info to be cached in each FusionExecutor:
97//! ParallelIterExtentMap
98//! Stores the symbolic extents of all the parallelized
99//! iterdomains corresponding to each used parallel type.
100class ParallelIterExtentMap {
101 public:
102 using DataType =
103 std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;
104 static const CompileTimeEntryType EntryType =
105 CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP;
106};
107
108//! Compile-time info to be cached in each FusionExecutor:
109//! SimplifiedParallelIterExtentMap
110//! This entry type is a simplified version of ParallelIterExtentMap.
111//!
112//! For launch parameter binding we only need the most concrete iterdomain
113//! in each disjoint set stored in CaParallelMap. This entry stores the
114//! remaining list of extents for binding after this simplification.
115//!
116//! We still need ParallelIterExtentMap since we want to bind the concrete
117//! values to the extents of all parallelized iterdomains. We would be
118//! able to save these bindings if the integer machine has a notion of
119//! equality and could be configured compile time. But that'd be a longer
120//! term target.
121class SimplifiedParallelIterExtentMap {
122 public:
123 using DataType =
124 std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;
125 static const CompileTimeEntryType EntryType =
126 CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP;
127};
128
//! WarpPaddedExtentsInfo:
//! Auxiliary data type for entry class WarpPaddedParallelExtents.
struct WarpPaddedExtentsInfo {
  //! Extent Vals of warp-padded parallel iterdomains.
  std::unordered_set<const Val*> warp_padded_extent_set;
  //! Maps warp-padded extent Vals to a constant value.
  //! NOTE(review): presumably the known constant extent for the padded
  //! iterdomain — confirm against getWarpPaddedExtentsInfo's definition.
  std::unordered_map<const Val*, int64_t> warp_padded_constant;
};
135
136//! Compile-time info to be cached in each FusionExecutor:
137//! WarpPaddedParallelExtents
138//! Stores the symbolic and constant extents of warp
139//! padded parallel iterdomains.
140class WarpPaddedParallelExtents {
141 public:
142 using DataType = WarpPaddedExtentsInfo;
143 static const CompileTimeEntryType EntryType =
144 CompileTimeEntryType::WARP_PADDED_PARALLEL_EXTENTS;
145};
146
//! VectorizedTensorInfo:
//! Auxiliary data type for entry class VectorizedTensorValidation.
//! Collects the positions/sets of vectorized fusion inputs and outputs,
//! split by aligned vs. misaligned access.
struct VectorizedTensorInfo {
  //! Argument positions of aligned vectorized fusion inputs
  std::vector<int> aligned_vectorized_inp_tensor_pos;
  //! Argument positions of aligned vectorized fusion outputs
  std::vector<int> aligned_vectorized_out_tensor_pos;
  //! Misaligned vectorized input tensors
  std::unordered_set<TensorView*> global_inp_misaligned_tv;
  //! Misaligned vectorized output tensors
  std::unordered_set<TensorView*> global_out_misaligned_tv;
  //! Positions of misaligned input tensors
  std::vector<int> inp_misaligned_tensors_pos;
  //! Positions of misaligned output tensors
  std::vector<int> out_misaligned_tensors_pos;
};
163
164//! Compile-time info to be cached in each FusionExecutor:
165//! VectorizedTensorValidation
166//! Stores position info and vector word sizes of
167//! vectorized input/output tensors, to be used
168//! in misaligned vectorization validation.
169class VectorizedTensorValidation {
170 public:
171 using DataType = VectorizedTensorInfo;
172 static const CompileTimeEntryType EntryType =
173 CompileTimeEntryType::VECTORIZED_TENSOR_VALIDATION;
174};
175
176//! Compile-time info to be cached in each FusionExecutor:
177//! InputAliasIndices
178//! Stores position info of aliased input tensors
179class InputAliasIndices {
180 public:
181 using DataType = std::vector<std::pair<int, int>>;
182 static const CompileTimeEntryType EntryType =
183 CompileTimeEntryType::INPUT_ALIAS_INDICES;
184};
185
186//! Compile-time info to be cached in each FusionExecutor:
187//! OutputAliasIndices
188//! Stores position info of aliased output tensors
189class OutputAliasIndices {
190 public:
191 using DataType = std::unordered_set<int>;
192 static const CompileTimeEntryType EntryType =
193 CompileTimeEntryType::OUTPUT_ALIAS_INDICES;
194};
195
196//! Base abstract class for unified storage in `ExecutorCompileTimeInfoCache`,
197//! each entry in `ExecutorCompileTimeInfoCache` will be a subclass.
198class CompileTimeInfoBase : public PolymorphicBase {
199 public:
200 CompileTimeInfoBase(CompileTimeEntryType entry_type)
201 : entry_type_(entry_type) {}
202 CompileTimeEntryType type() {
203 return entry_type_;
204 }
205
206 private:
207 CompileTimeEntryType entry_type_;
208};
209
210// Note: Do NOT export this class. MSVC issue with exported class that contains
211// std::vector<unique_ptr<xxx>>: https://godbolt.org/z/3E4e8T1P1
212//! Compile-time information cache
213class ExecutorCompileTimeInfoCache {
214 using Entry = CompileTimeInfoBase;
215 using EntryOwningPtr = std::unique_ptr<Entry>;
216 using EntryPtr = Entry*;
217 using EntryType = CompileTimeEntryType;
218
219 public:
220 void insert(EntryOwningPtr new_entry);
221
222 EntryPtr at(EntryType entry_type) {
223 return entry_type_map_.at(entry_type);
224 }
225
226 bool has(EntryType entry_type) {
227 return entry_type_map_.count(entry_type);
228 }
229
230 private:
231 std::vector<EntryOwningPtr> entries_;
232 std::unordered_map<EntryType, EntryPtr> entry_type_map_;
233};
234
//! A utility class to facilitate accessing ExecutorCompileTimeInfoCache.
template <typename EntryClass>
class ExecutorCompileTimeEntry {
  using EntryDataType = typename EntryClass::DataType;
  using EntryDataTypeOwnPtr = std::unique_ptr<EntryDataType>;
  using MakerFnType = std::function<EntryDataTypeOwnPtr()>;

 public:
  //! Creates a data entry with type defined in EntryClass,
  //! eg. EntryClass = VectorizableInputsAndOutputs;
  //!
  //! @param data_cache, a pointer to an instantiated compile-time
  //!  info cache. The info data will be
  //!    1. read from data cache if data cache has the corresponding entry.
  //!    2. written into data cache if data cache doesn't have the entry.
  //!    3. managed by owned_data_ if data cache is nullptr
  //! @param fn:
  //!   The factory function that needs to return an owning pointer
  //!  i.e. std::unique_ptr<EntryClass::DataType>. It will only
  //!  be called either when data cache is missing an entry or when no data
  //!  cache is given.
  //! (Constructor defined out of line — the cache interaction described
  //! above is not visible in this header.)
  ExecutorCompileTimeEntry(
      ExecutorCompileTimeInfoCache* data_cache,
      MakerFnType fn);

  //! Unified interface to get actual data, either from cache
  //! or from factory function.
  EntryDataType& get() {
    return *data_ptr_;
  }

 private:
  //! Internal data owning pointer that will manage the computed
  //! data where there is no data cache.
  EntryDataTypeOwnPtr owned_data_ = nullptr;

  //! Pointer to the valid data entry that could be accessed;
  //! set by the constructor (points into the cache or at owned_data_,
  //! per the constructor contract above).
  EntryDataType* data_ptr_ = nullptr;
};
274
275} // namespace caching
276
//! Returns the vector of iterdomains that will be used to bind parallel
//! dimensions. (Original comment said "tensorviews", but the return type is
//! std::vector<IterDomain*>; `used_tvs` is the tensorviews scanned.)
std::vector<IterDomain*> getParallelBindingsIterDomains(
    GpuLower* lower,
    const std::vector<TensorView*>& used_tvs);

//! Map from each parallel type to the extents (symbolic Vals) of the
//! iterdomains bound to it; shared by the getters below.
using ParallelExtentMap =
    std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;
285
//! Returns the extents of all parallel binding iterdomains corresponding
//! to each parallel type.
std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
    std::vector<IterDomain*>& parallel_binding_ids);

//! Returns the simplified set of extents necessary for launch parameter
//! binding (see SimplifiedParallelIterExtentMap above for the
//! simplification rationale).
std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
    GpuLower* lower,
    std::vector<IterDomain*>& parallel_binding_ids);
296
//! Returns the symbolic or constant extents of warp padded parallel
//! iterdomains in the given vector.
//! NOTE(review): the first parameter is named `lower` but is typed
//! kir::Kernel* — the name looks stale; confirm against the definition.
std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo(
    kir::Kernel* lower,
    std::vector<IterDomain*>& parallel_binding_ids);

//! Validates vectorized kernel inputs/outputs against the cached
//! VectorizedTensorValidation info; `data_cache` may supply precomputed
//! position data, `expr_eval` the runtime bindings. Expected to raise on a
//! violation — exact checks live in the definition, not in this header.
void validateVectorizedTensors(
    kir::Kernel* kernel,
    const KernelArgumentHolder& args,
    const std::vector<at::Tensor>& outputs,
    caching::ExecutorCompileTimeInfoCache* data_cache,
    kir::ExpressionEvaluator& expr_eval);
309
310} // namespace executor_utils
311} // namespace cuda
312} // namespace fuser
313} // namespace jit
314} // namespace torch
315