#pragma once
#include <dynamic_type.h>
#include <executor_kernel_arg.h>
#include <executor_launch_params.h>
#include <fusion.h>
#include <ir_all_nodes.h>
#include <lower2device.h>

#include <c10/core/DeviceType.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! This is the common space for expression evaluators in
//! the fusion IR and kernel IR contexts. Much of the evaluator
//! optimization and runtime logic can share the same code path,
//! and that shared code is collected here.

class ExpressionEvaluator;

namespace kir {

class ExpressionEvaluator;

} // namespace kir

//! IR contexts to be passed to generic evaluator optimizations
//! and runtimes. Each context defines the essential interface the
//! generic logic uses to get the necessary type and function info
//! from the IR nodes. Generic optimizations assume the same list
//! of static definitions is provided in each of the contexts,
//! currently just FusionIR and KernelIR.

//! Context for using generic logic on FusionIR
class FusionIRContext {
 public:
  using TV_TYPE = TensorView;
  using EVALUATOR_TYPE = ExpressionEvaluator;

  static BinaryOpType getOpType(BinaryOp* bop) {
    return bop->getBinaryOpType();
  }

  static UnaryOpType getOpType(UnaryOp* uop) {
    return uop->getUnaryOpType();
  }
};

//! Context for using generic logic on KernelIR
class KernelIRContext {
 public:
  using EVALUATOR_TYPE = kir::ExpressionEvaluator;

  static BinaryOpType getOpType(BinaryOp* bop) {
    return bop->getBinaryOpType();
  }

  static UnaryOpType getOpType(UnaryOp* uop) {
    return uop->getUnaryOpType();
  }
};
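
//! Generic evaluator code can then be written once against either
//! context. A minimal sketch (the helper name is illustrative only and
//! not part of this interface):
//!
//!   template <typename IRContext>
//!   void handleBinaryOp(BinaryOp* bop) {
//!     BinaryOpType op_type = IRContext::getOpType(bop);
//!     // ... dispatch on op_type ...
//!   }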

template <typename IRContext>
class PrecomputedValuesBase;

//! NaiveValueMachine:
//! This is an un-optimized runtime for evaluating a
//! set of values in one run. The runtime contains
//! a vector of instructions inferred from the IR at compile time,
//! and it currently must be associated with an instance of
//! PrecomputedValuesBase that provides the workspace
//! holding the concrete values.
template <typename IRContext>
class NaiveValueMachine {
  //! The generic types of instructions supported for this
  //! machine, currently only binary and unary.
  enum class InstructionType { UNARY_OP, BINARY_OP };

 public:
  //! The constructor lowers all the expr IR nodes stored in
  //! precomputed_values and stores them in the private state.
  NaiveValueMachine(PrecomputedValuesBase<IRContext>& precomputed_values);

  //! Runs all the instructions and writes the results to the associated
  //! precomputed_values.
  void run();

 private:
  //! Convert a unary IR expr to an instruction
  void makeUnaryOp(UnaryOp* uop);

  //! Convert a binary IR expr to an instruction
  void makeBinaryOp(BinaryOp* bop);

  //! Create an empty instruction with all default values
  //! and place it at the end of the instruction buffer.
  int makeInstructionEntry();

  //! Run a single instruction at the given index of
  //! the instruction buffer. Decodes and dispatches
  //! to the corresponding instruction handler functions.
  void runInstruction(int index);

  //! Runs a unary operation at the given index of the instruction buffer
  void runUnaryOp(int index);

  //! Runs a binary operation at the given index of the instruction buffer
  void runBinaryOp(int index);

 private:
  friend PrecomputedValuesBase<IRContext>;

  //! Reference to the PrecomputedValues workspace associated with
  //! this runtime. All the instructions read and write the
  //! values in this workspace.
  PrecomputedValuesBase<IRContext>& precomputed_values_;

  //! Instruction buffer. The state is split across separate vectors, and
  //! the entries of all vectors at the same index correspond to
  //! the same instruction.
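  //!
  //! As an illustration (schematic; the op names are just examples),
  //! lowering `d = a + b` followed by `e = abs(d)` appends two entries:
  //!   inst_type_ = {BINARY_OP,  UNARY_OP }
  //!   bop_type_  = {Add,        <default>}
  //!   uop_type_  = {<default>,  Abs      }
  //!   src0_      = {index(a),   index(d) }
  //!   src1_      = {index(b),   <default>}
  //!   dest_      = {index(d),   index(e) }
  //! where index(x) is the position of x's concrete value in the
  //! associated PrecomputedValuesBase workspace.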

  //! Total number of instructions
  int num_of_instructions_ = 0;

  //! Machine instruction type for each instruction, i.e.
  //! unary or binary
  std::vector<InstructionType> inst_type_;

  //! Unary operator type if applicable, contains a default
  //! value at each index corresponding to a binary op.
  std::vector<UnaryOpType> uop_type_;

  //! Data type for unary ops of type UnaryOpType::Cast, contains a default
  //! value at each index corresponding to other ops.
  std::vector<DataType> data_type_;

  //! Binary operator type if applicable, contains a default
  //! value at each index corresponding to a unary op.
  std::vector<BinaryOpType> bop_type_;

  //! Indexes of the operands and destination of each instruction.
  //! The indexes correspond to positions in the workspace
  //! where the concrete values are hosted.

  //! Operand 0 of each instruction.
  std::vector<int> src0_;

  //! Operand 1 of each instruction, a default value at
  //! each index corresponding to a unary op.
  std::vector<int> src1_;

  //! Destination of each instruction.
  std::vector<int> dest_;
};

//! PrecomputedValuesBase:
//! A class to support optimized evaluation of values
//! at runtime.
//! At compile time all the necessary values are collected
//! from the given IR nodes, and a runtime together with a workspace
//! holding the concrete values is created and pre-allocated.
//! At runtime the value machine evaluates all the values
//! and stores them in the workspace ahead of time.
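//!
//! Typical lifecycle (a sketch; the concrete bind* entry points live on
//! the subclasses below):
//!  1. Construct a subclass, which collects the values to evaluate and
//!     sets up the workspace and value machine (initializeValueList(),
//!     initializeIntegerMachine()).
//!  2. Bind the concrete inputs for this run (e.g. tensor extents).
//!  3. Call evaluate() to run the value machine.
//!  4. Query results with getMaybeValueFor().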
template <typename IRContext>
class PrecomputedValuesBase {
  using VALUE_MACHINE = NaiveValueMachine<IRContext>;

 public:
  explicit PrecomputedValuesBase() = default;

  //! Returns true if the workspace contains evaluated results.
  bool ready() {
    return has_valid_values_;
  }

  //! Runs the internal value machine that will compute
  //! the values allocated in the workspace.
  void evaluate();

  //! Returns the value for the given IR node if it is stored
  //! in the workspace and has been evaluated.
  c10::optional<IntOrDouble> getMaybeValueFor(const Val* val);

  //! Debugging helper, prints all the currently known values
  void print() const;

 protected:
  //! Initialize the workspace before first use.
  //! Assumes the IR nodes in the given value list
  //! have been topologically sorted.
  void initializeValueList(
      typename IRContext::EVALUATOR_TYPE& evaluator,
      const std::vector<Val*>& sorted_value_list);

  //! Bind a concrete value to the given index
  //! if the index is valid.
  void bindValue(int index, IntOrDouble value) {
    if (index < 0 || is_constant_[index]) {
      return;
    }
    defined_[index] = true;
    values_[index] = value;
    binding_log_.emplace_back(index, value);
  }

  //! Invalidate all computed values in the workspace.
  void invalidate();

  //! Interface for subclasses to access symbols_
  void loadSymbols(std::vector<Val*> symbols) {
    symbols_ = std::move(symbols);
  }

  //! Interface for subclasses to access symbols_
  std::vector<Val*>& symbols() {
    return symbols_;
  }

  //! Initialize the value runtime that will
  //! infer instructions from the workspace.
  void initializeIntegerMachine() {
    value_machine_ = std::make_unique<VALUE_MACHINE>(*this);
  }

  bool hasValidValues() {
    return has_valid_values_;
  }

 private:
  //! Post-evaluation check, throws if any computed value
  //! is inconsistent with its bound value
  void validate();

  //! Returns true if the workspace has a computed or constant
  //! value for the given index.
  bool hasValue(int index) {
    TORCH_INTERNAL_ASSERT(index > 0);
    return defined_[index] || is_constant_[index];
  }

 private:
  friend VALUE_MACHINE;

  //! Marks if an evaluation has finished
  bool has_valid_values_ = false;

  //! The size of the workspace
  int num_of_values_ = -1;

  //! Marks if a value has been bound or
  //! computed at each index.
  std::vector<bool> defined_;

  //! Marks if a value is a compile-time constant
  //! at each index.
  std::vector<bool> is_constant_;

  //! Stores the concrete values at each index.
  std::vector<IntOrDouble> values_;

  //! Stores the IR nodes corresponding to each index.
  std::vector<Val*> symbols_;

  //! An internal log to keep track of all the bindings
  //! used in each evaluation cycle. Used for the
  //! consistency check.
  std::vector<std::pair<int, IntOrDouble>> binding_log_;

  //! Value machine runtime for realizing the value computations.
  std::unique_ptr<VALUE_MACHINE> value_machine_;
};

//! PrecomputedValues workspace in the fusion IR context,
//! defines the set of values to be collected from each
//! fusion graph and how input values are bound from each
//! fusion runtime input.
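//!
//! Example usage (a minimal sketch; `fusion`, `args`, and `some_val` are
//! placeholders provided by the caller):
//!
//!   FusionPrecomputedValues precomputed(fusion);
//!   precomputed.bindFusionInputs(args);
//!   precomputed.evaluate();
//!   auto maybe_value = precomputed.getMaybeValueFor(some_val);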
class FusionPrecomputedValues : public PrecomputedValuesBase<FusionIRContext> {
  using precomputedValuesBaseType = PrecomputedValuesBase<FusionIRContext>;

 public:
  FusionPrecomputedValues(Fusion* fusion);

  //! Bind concrete values from fusion runtime inputs
  void bindFusionInputs(const KernelArgumentHolder& args);

 private:
  void bindTensorMetaData(
      TensorView* tv,
      const TensorArgAbstract* tensor_arg_abstract);

 private:
  Fusion* fusion_ = nullptr;
};
//! PrecomputedValues workspace in the kernel IR context,
//! defines the set of values to be collected from each
//! kernel IR sequence and how input values are bound from each
//! fusion runtime input and the launch constraints.
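//!
//! Example usage (a minimal sketch; `kernel`, `args`, `extent_map`,
//! `launch_params`, and `some_val` are placeholders provided by the
//! caller):
//!
//!   KernelPrecomputedValues precomputed(kernel);
//!   precomputed.bindKernelInputs(kernel, args);
//!   precomputed.bindParallelExtents(extent_map, launch_params);
//!   precomputed.evaluate();
//!   auto maybe_value = precomputed.getMaybeValueFor(some_val);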
class KernelPrecomputedValues : public PrecomputedValuesBase<KernelIRContext> {
  using precomputedValuesBaseType = PrecomputedValuesBase<KernelIRContext>;

 public:
  using ParallelExtentMap =
      std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;

  KernelPrecomputedValues(kir::Kernel* kernel);

  //! Bind concrete values from fusion runtime inputs
  void bindKernelInputs(kir::Kernel* kernel, const KernelArgumentHolder& args);

  //! Bind concrete values from launch constraints
  void bindParallelExtents(
      const ParallelExtentMap& parallel_extents,
      const LaunchParams& launch_constraint);

  //! Bind the NamedScalars corresponding to the
  //! concrete parallel dimension sizes after the
  //! actual value has been resolved.
  void bindConcreteParallelTypeValue(ParallelType pt, int64_t value);

 private:
  void bindTensorMetaData(
      TensorView* tv,
      const TensorArgAbstract* tensor_arg_abstract);

  //! Iterate through all the named scalars corresponding
  //! to thread sizes and pre-group them by their parallel
  //! types.
  void initializeNamedScalars();

 private:
  //! Contains all the named scalars corresponding
  //! to the thread size of each parallel type.
  std::unordered_map<ParallelType, std::unique_ptr<std::vector<int>>, TypeHash>
      thread_dim_value_indices_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch