1 | #pragma once |
2 | #include <dynamic_type.h> |
3 | #include <executor_kernel_arg.h> |
4 | #include <executor_launch_params.h> |
5 | #include <fusion.h> |
6 | #include <ir_all_nodes.h> |
7 | #include <lower2device.h> |
8 | |
9 | #include <c10/core/DeviceType.h> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
//! This is the common space for expression evaluators in
//! fusion IR and kernel IR contexts. Many of the evaluator
//! optimizations and runtimes can share the same code
//! path, and the shared pieces are collected here.
20 | |
21 | class ExpressionEvaluator; |
22 | |
23 | namespace kir { |
24 | |
25 | class ExpressionEvaluator; |
26 | |
27 | } // namespace kir |
28 | |
29 | //! IR Contexts to be passed to generic evaluator optimizations |
30 | //! and runtimes. Defines the essential interface for the |
31 | //! generic logic to get necessary type and function info |
32 | //! from the IR nodes. Generic optimizations will assume |
33 | //! the same list of static definitions are provided |
34 | //! in each of the contexts, just FusionIR and KernelIR |
35 | //! currently. |
36 | |
37 | //! Context for using generic logic on FusionIR |
38 | class FusionIRContext { |
39 | public: |
40 | using TV_TYPE = TensorView; |
41 | using EVALUATOR_TYPE = ExpressionEvaluator; |
42 | |
43 | static BinaryOpType getOpType(BinaryOp* bop) { |
44 | return bop->getBinaryOpType(); |
45 | } |
46 | |
47 | static UnaryOpType getOpType(UnaryOp* uop) { |
48 | return uop->getUnaryOpType(); |
49 | } |
50 | }; |
51 | |
52 | //! Context for using generic logic on KernelIR |
53 | class KernelIRContext { |
54 | public: |
55 | using EVALUATOR_TYPE = kir::ExpressionEvaluator; |
56 | |
57 | static BinaryOpType getOpType(BinaryOp* bop) { |
58 | return bop->getBinaryOpType(); |
59 | } |
60 | |
61 | static UnaryOpType getOpType(UnaryOp* uop) { |
62 | return uop->getUnaryOpType(); |
63 | } |
64 | }; |
65 | |
66 | template <typename IRContext> |
67 | class PrecomputedValuesBase; |
68 | |
//! NaiveValueMachine:
//! This is an un-optimized runtime for evaluating a
//! set of values in one run. The runtime contains
//! a vector of instructions inferred from IR at compile-time
//! and it currently must be associated with an instance of
//! PrecomputedValuesBase that provides the workspace
//! holding the concrete values the instructions read and write.
template <typename IRContext>
class NaiveValueMachine {
  //! The generic types of instructions supported for this
  //! machine, currently only binary and unary.
  enum class InstructionType { UNARY_OP, BINARY_OP };

 public:
  //! Constructor lowers all the expr IR nodes stored in precomputed_values
  //! and stores them in the private state.
  NaiveValueMachine(PrecomputedValuesBase<IRContext>& precomputed_values);

  //! Runs all the instructions and writes results to the associated
  //! precomputed_values workspace.
  void run();

 private:
  //! Convert a unary IR expr to an instruction
  void makeUnaryOp(UnaryOp* uop);

  //! Convert a binary IR expr to an instruction
  void makeBinaryOp(BinaryOp* bop);

  //! Create an empty instruction with all default values
  //! and place it at the end of the instruction buffer.
  //! Returns the index of the new instruction entry.
  int makeInstructionEntry();

  //! Run a single instruction at the given index of
  //! the instruction buffer. Decodes and dispatches
  //! to the corresponding instruction handle functions.
  void runInstruction(int index);

  //! Runs a unary operation at given index of instruction buffer
  void runUnaryOp(int index);

  //! Runs a binary operation at given index of instruction buffer
  void runBinaryOp(int index);

 private:
  friend PrecomputedValuesBase<IRContext>;

  //! Reference to the PrecomputedValues workspace associated with
  //! this runtime. All the instructions will read and write the
  //! values in this workspace.
  PrecomputedValuesBase<IRContext>& precomputed_values_;

  //! Instruction buffer. All states are in separate vectors
  //! (struct-of-arrays), and the entries of each vector at the
  //! same index correspond to the same instruction.

  //! Total number of instructions
  int num_of_instructions_ = 0;

  //! Machine instruction type for each instruction, i.e.
  //! unary or binary
  std::vector<InstructionType> inst_type_;

  //! Unary operator type if applicable, contains a default
  //! value at each index corresponding to a binary op.
  std::vector<UnaryOpType> uop_type_;

  //! Data type for unary ops of type UnaryOpType::Cast, contains a
  //! default value at each index corresponding to other ops.
  std::vector<DataType> data_type_;

  //! Binary operator type if applicable, contains a default
  //! value at each index corresponding to a unary op.
  std::vector<BinaryOpType> bop_type_;

  //! Indexes of operands and destination of each instruction.
  //! The indexes correspond to positions in the workspace
  //! where concrete values are hosted.

  //! Operand 0 of each instruction.
  std::vector<int> src0_;

  //! Operand 1 of each instruction, a default value at
  //! each index corresponding to a unary op.
  std::vector<int> src1_;

  //! Destination of each instruction.
  std::vector<int> dest_;
};
158 | |
//! PrecomputedValuesBase:
//! A class to support optimized evaluation of values
//! at runtime.
//! At compile time all necessary values are collected
//! from given IR nodes, and a runtime (value machine) plus a
//! workspace containing the concrete values is created and
//! pre-allocated.
//! At runtime the value vm is used to evaluate all the
//! values and store them in the workspace ahead of time.
template <typename IRContext>
class PrecomputedValuesBase {
  using VALUE_MACHINE = NaiveValueMachine<IRContext>;

 public:
  explicit PrecomputedValuesBase() = default;

  //! Returns whether the workspace contains evaluated results.
  bool ready() {
    return has_valid_values_;
  }

  //! Runs the internal value machine that will compute
  //! the values allocated in the workspace.
  void evaluate();

  //! Returns the value for the given IR node if it is stored
  //! in the workspace and has been evaluated; nullopt otherwise.
  c10::optional<IntOrDouble> getMaybeValueFor(const Val* val);

  //! Debugging helper, prints all the currently known values
  void print() const;

 protected:
  //! Initialize the workspace before first use.
  //! Assumes the given list of value IR nodes has
  //! been topologically sorted.
  void initializeValueList(
      typename IRContext::EVALUATOR_TYPE& evaluator,
      const std::vector<Val*>& sorted_value_list);

  //! Bind a concrete value to the given index
  //! if the index is valid. Binding to a negative (invalid)
  //! index or to a compile-time-constant slot is a silent no-op.
  void bindValue(int index, IntOrDouble value) {
    if (index < 0 || is_constant_[index]) {
      return;
    }
    defined_[index] = true;
    values_[index] = value;
    // Record the binding so validate() can later cross-check
    // bound values against computed values.
    binding_log_.emplace_back(index, value);
  }

  //! Invalidate all computed values in the workspace.
  void invalidate();

  //! Interface for subclasses to replace symbols_
  void loadSymbols(std::vector<Val*> symbols) {
    symbols_ = std::move(symbols);
  }

  //! Interface for subclasses to access symbols_
  std::vector<Val*>& symbols() {
    return symbols_;
  }

  //! Initialize the value runtime that will
  //! infer instructions from the workspace.
  void initializeIntegerMachine() {
    value_machine_ = std::make_unique<VALUE_MACHINE>(*this);
  }

  //! Same as ready(); exposed to subclasses.
  bool hasValidValues() {
    return has_valid_values_;
  }

 private:
  //! Post evaluation check, throws if any computed value
  //! is inconsistent with its bound value
  void validate();

  //! Returns true if workspace has a computed or constant
  //! value for given index.
  //! NOTE(review): the assert rejects index 0, yet bindValue()
  //! above accepts any non-negative index — confirm whether
  //! index 0 is reserved or this should be `index >= 0`.
  bool hasValue(int index) {
    TORCH_INTERNAL_ASSERT(index > 0);
    return defined_[index] || is_constant_[index];
  }

 private:
  friend VALUE_MACHINE;

  //! Marks if an evaluation has finished
  bool has_valid_values_ = false;

  //! The size of the workspace; -1 before initialization
  int num_of_values_ = -1;

  //! Marks if a value has been bound or
  //! computed at each index.
  std::vector<bool> defined_;

  //! Marks if a value is a compile-time constant
  //! at each index.
  std::vector<bool> is_constant_;

  //! Stores the concrete values at each index.
  std::vector<IntOrDouble> values_;

  //! Stores the IR nodes corresponding to each index.
  std::vector<Val*> symbols_;

  //! An internal log to keep track of all the bindings
  //! used in each evaluation cycle. To be used for
  //! consistency check.
  std::vector<std::pair<int, IntOrDouble>> binding_log_;

  //! Integer runtime for realizing the values computations.
  std::unique_ptr<VALUE_MACHINE> value_machine_;
};
275 | |
//! PrecomputedValues workspace in Fusion IR context,
//! defines the set of values to be collected in each
//! fusion graph and the input value binding given each
//! fusion runtime input.
class FusionPrecomputedValues : public PrecomputedValuesBase<FusionIRContext> {
  using precomputedValuesBaseType = PrecomputedValuesBase<FusionIRContext>;

 public:
  //! Builds the workspace from the given fusion (defined out-of-line).
  FusionPrecomputedValues(Fusion* fusion);

  //! Bind concrete values from fusion runtime inputs
  void bindFusionInputs(const KernelArgumentHolder& args);

 private:
  //! Binds metadata of a tensor input (presumably its sizes/strides)
  //! from the runtime tensor argument — see definition for details.
  void bindTensorMetaData(
      TensorView* tv,
      const TensorArgAbstract* tensor_arg_abstract);

 private:
  //! The fusion this workspace was built from (non-owning).
  Fusion* fusion_ = nullptr;
};
//! PrecomputedValues workspace in kernel IR context,
//! defines the set of values to be collected in each
//! kernel IR sequence and the input value binding given each
//! fusion runtime input and launch constraints.
class KernelPrecomputedValues : public PrecomputedValuesBase<KernelIRContext> {
  using precomputedValuesBaseType = PrecomputedValuesBase<KernelIRContext>;

 public:
  //! Map from parallel type to the extent Vals associated with it
  //! (used when binding launch constraints).
  using ParallelExtentMap =
      std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;

  //! Builds the workspace from the given kernel (defined out-of-line).
  KernelPrecomputedValues(kir::Kernel* kernel);

  //! Bind concrete values from fusion runtime inputs
  void bindKernelInputs(kir::Kernel* kernel, const KernelArgumentHolder& args);

  //! Bind concrete values from launch constraints
  void bindParallelExtents(
      const ParallelExtentMap& parallel_extents,
      const LaunchParams& launch_constraint);

  //! Bind the NamedScalars corresponding to the
  //! concrete parallel dimension sizes after the
  //! actual value has been resolved.
  void bindConcreteParallelTypeValue(ParallelType pt, int64_t value);

 private:
  //! Binds metadata of a tensor input (presumably its sizes/strides)
  //! from the runtime tensor argument — see definition for details.
  void bindTensorMetaData(
      TensorView* tv,
      const TensorArgAbstract* tensor_arg_abstract);

  //! Iterate through all the named scalars corresponding
  //! to thread sizes and pre-group them by their parallel
  //! types.
  void initializeNamedScalars();

 private:
  //! Contains, per parallel type, the workspace indices of all
  //! named scalars that correspond to that type's thread size.
  std::unordered_map<ParallelType, std::unique_ptr<std::vector<int>>, TypeHash>
      thread_dim_value_indices_;
};
339 | |
340 | } // namespace cuda |
341 | } // namespace fuser |
342 | } // namespace jit |
343 | } // namespace torch |
344 | |