#include <expr_evaluator.h>
#include <instrumentation.h>
#include <ir_utils.h>
#include <kernel_expr_evaluator.h>
#include <lower2device.h>

#include <evaluator_common.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

template <typename VALTYPE>
std::vector<VALTYPE*> getImmediateProducers(VALTYPE* val) {
  if (val->definition()) {
    auto expr = val->definition();
    return expr->inputs();
  } else {
    return {};
  }
}

//! IR-generic utility: collects all of the producers required for the
//! given list of IR values and returns them, together with the original
//! list, in topological order.
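//!
//! For illustration (hypothetical values): given scalars `a` and `b`
//! with `c = a + b`, an input list {c, a} is returned as {a, b, c};
//! `b` is pulled in as a producer of `c` even though it was not in the
//! original list, and every producer precedes its consumers.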
template <typename VALTYPE>
std::vector<VALTYPE*> makeSortedEvaluationList(std::vector<VALTYPE*> input) {
  // Deduplicate
  std::vector<VALTYPE*> to_sort;
  std::unordered_set<VALTYPE*> visited;
  for (auto val : input) {
    if (!visited.count(val)) {
      to_sort.push_back(val);
      visited.insert(val);
    }
  }

  std::vector<VALTYPE*> sorted;
  visited.clear();

  // Topological sort
  // Note: producers that are not in the original list are not explicitly
  // excluded; this is acceptable for the intended use.
  while (!to_sort.empty()) {
    auto top_val = to_sort.back();
    if (visited.count(top_val)) {
      to_sort.pop_back();
    } else {
      bool ready_to_pop = true;
      for (auto producer : getImmediateProducers(top_val)) {
        if (!visited.count(producer)) {
          ready_to_pop = false;
          to_sort.push_back(producer);
        }
      }
      if (ready_to_pop) {
        visited.insert(top_val);
        sorted.push_back(top_val);
        to_sort.pop_back();
      }
    }
  }

  return sorted;
}

//! Kernel IR utility, collects all the symbolic values
//! used in allocation nodes.
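//! Recurses into nested ForLoop and IfThenElse scopes so that sizes of
//! buffers allocated inside loops and conditionals are captured as well.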
void collectBufferSizes(
    std::vector<Val*>& into,
    const std::vector<Expr*>& exprs) {
  for (auto expr : exprs) {
    if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) {
      into.push_back(allocate->size());
    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      collectBufferSizes(into, for_loop->body().exprs());
    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
      collectBufferSizes(into, ite->thenBody().exprs());
      collectBufferSizes(into, ite->elseBody().exprs());
    }
  }
}

//! Kernel IR utility, collects all the kernel symbolic
//! values we will need at runtime, i.e. after the
//! generated cuda kernel has already been compiled.
//! The values are used for runtime logic, such as
//! `computeLaunchParams`.
std::vector<Val*> collectRuntimeUsedValues(kir::Kernel* kernel) {
  std::vector<Val*> ret;
  auto all_tvs = ir_utils::allTvs(kernel);
  // Collect iterdomain extents and scalar inputs
  for (auto tv : all_tvs) {
    for (auto id : tv->domain()->domain()) {
      ret.push_back(id->extent());
    }
  }
  for (auto inp : kernel->inputs()) {
    if (inp->isA<Int>() || inp->isA<Double>()) {
      ret.push_back(inp);
    }
  }
  // Collect allocation sizes:
  collectBufferSizes(ret, kernel->topLevelExprs());
  return makeSortedEvaluationList(ret);
}

std::vector<Val*> collectRuntimeUsedValues(Fusion* fusion) {
  std::vector<Val*> ret;
  auto all_tvs = ir_utils::allTvs(fusion);
  // Collect iterdomain extents and scalar inputs
  for (auto tv : all_tvs) {
    for (auto id : tv->domain()->domain()) {
      ret.push_back(id->extent());
    }
  }
  for (auto inp : fusion->inputs()) {
    if (inp->isA<Int>() || inp->isA<Double>()) {
      ret.push_back(inp);
    }
  }
  return makeSortedEvaluationList(ret);
}

} // namespace

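//! Initializes the evaluation workspace: assigns every symbol its slot
//! in the value table and eagerly folds values that are already
//! constant. For example, a constant extent such as 128 is stored up
//! front, while a symbolic extent is left to be bound or computed at
//! runtime.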
template <typename IRContext>
void PrecomputedValuesBase<IRContext>::initializeValueList(
    typename IRContext::EVALUATOR_TYPE& const_evaluator,
    const std::vector<Val*>& sorted_value_list) {
  // Initialize workspace
  num_of_values_ = sorted_value_list.size();
  defined_ = std::vector<bool>(num_of_values_, false);
  is_constant_ = std::vector<bool>(num_of_values_, false);
  values_ = std::vector<IntOrDouble>(num_of_values_, -1);

  // Fill in constants and assign evaluator indices
  for (const auto i : c10::irange(num_of_values_)) {
    // Use an expression evaluator to test if value is const
    auto const_val = const_evaluator.evaluate(sorted_value_list[i]);
    if (const_val.has_value()) {
      is_constant_[i] = true;
      values_[i] = const_val.value();
    }
    sorted_value_list[i]->setEvaluatorIndex(i);
  }
}

template <typename IRContext>
c10::optional<IntOrDouble> PrecomputedValuesBase<IRContext>::getMaybeValueFor(
    const Val* val) {
  auto index = val->evaluatorIndex();
  if (index < 0) {
    return c10::nullopt;
  }
  if (!defined_[index] && !is_constant_[index]) {
    return c10::nullopt;
  }
  return values_[index];
}

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::print() const {
  std::cout << "Precomputed Values:\n";
  for (auto i : c10::irange(symbols_.size())) {
    if (defined_[i]) {
      std::cout << symbols_[i]->toInlineString() << " = " << values_[i]
                << std::endl;
    }
  }
}

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::evaluate() {
  FUSER_PERF_SCOPE("PrecomputedValues::Evaluate");
  value_machine_->run();
  validate();
}

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::invalidate() {
  // clear binding values
  binding_log_.clear();

  // invalidate value entries
  std::fill(defined_.begin(), defined_.end(), false);

  // invalidate flag
  has_valid_values_ = false;
}

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::validate() {
  FUSER_PERF_SCOPE("PrecomputedValues::Validate");
  for (auto it : binding_log_) {
    TORCH_INTERNAL_ASSERT(
        values_[it.first] == it.second,
        "Precomputed values failed to validate.",
        "\nSomething unexpected changed between the compilation and execution.\n",
        values_[it.first],
        " != ",
        it.second);
  }
  has_valid_values_ = true;
}

template <typename IRContext>
NaiveValueMachine<IRContext>::NaiveValueMachine(
    PrecomputedValuesBase<IRContext>& precomputed_values)
    : precomputed_values_(precomputed_values) {
  num_of_instructions_ = 0;
  for (auto val : precomputed_values_.symbols_) {
    auto def = val->definition();
    if (def) {
      if (auto uop = dynamic_cast<UnaryOp*>(def)) {
        makeUnaryOp(uop);
      } else if (auto bop = dynamic_cast<BinaryOp*>(def)) {
        makeBinaryOp(bop);
      } else {
        TORCH_INTERNAL_ASSERT(false, "Unsupported expr");
      }
    }
  }
}

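// Note: symbols_ is topologically sorted (see makeSortedEvaluationList),
// so the generated instructions are already in dependency order and a
// single straight-line pass over them computes every reachable value.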
template <typename IRContext>
void NaiveValueMachine<IRContext>::run() {
  for (const auto i : c10::irange(num_of_instructions_)) {
    // Skip this instruction if the dest location
    // has already been computed or is constant.
    if (precomputed_values_.defined_[dest_[i]] ||
        precomputed_values_.is_constant_[dest_[i]]) {
      continue;
    }
    runInstruction(i);
  }
}

template <typename IRContext>
void NaiveValueMachine<IRContext>::makeUnaryOp(UnaryOp* uop) {
  int in = uop->inputs()[0]->evaluatorIndex();
  int out = uop->outputs()[0]->evaluatorIndex();
  TORCH_INTERNAL_ASSERT(in >= 0, "Integer Machine: unknown input: ", uop);
  TORCH_INTERNAL_ASSERT(out >= 0, "Integer Machine: unknown out: ", uop);

  int index = makeInstructionEntry();
  inst_type_[index] = InstructionType::UNARY_OP;
  uop_type_[index] = IRContext::getOpType(uop);
  if (uop_type_[index] == UnaryOpType::Cast) {
    data_type_[index] = uop->out()->getDataType().value();
  }
  src0_[index] = in;
  dest_[index] = out;
}

template <typename IRContext>
void NaiveValueMachine<IRContext>::makeBinaryOp(BinaryOp* bop) {
  int in0 = bop->inputs()[0]->evaluatorIndex();
  int in1 = bop->inputs()[1]->evaluatorIndex();
  int out = bop->outputs()[0]->evaluatorIndex();

  TORCH_INTERNAL_ASSERT(in0 >= 0, "Integer Machine: unknown lhs: ", bop);
  TORCH_INTERNAL_ASSERT(in1 >= 0, "Integer Machine: unknown rhs: ", bop);
  TORCH_INTERNAL_ASSERT(out >= 0, "Integer Machine: unknown out: ", bop);

  int index = makeInstructionEntry();
  inst_type_[index] = InstructionType::BINARY_OP;
  bop_type_[index] = IRContext::getOpType(bop);
  src0_[index] = in0;
  src1_[index] = in1;
  dest_[index] = out;
}

template <typename IRContext>
int NaiveValueMachine<IRContext>::makeInstructionEntry() {
  int index = num_of_instructions_++;
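  // All per-instruction fields get placeholder defaults here; makeUnaryOp
  // and makeBinaryOp then overwrite the ones that apply, so e.g. the
  // UnaryOpType::Abs below is never executed as-is.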
  inst_type_.push_back(InstructionType::UNARY_OP);
  uop_type_.push_back(UnaryOpType::Abs);
  bop_type_.push_back(BinaryOpType::Add);
  data_type_.push_back(DataType::Null);
  src0_.push_back(-1);
  src1_.push_back(-1);
  dest_.push_back(-1);
  return index;
}

template <typename IRContext>
void NaiveValueMachine<IRContext>::runInstruction(int index) {
  switch (inst_type_[index]) {
    case InstructionType::UNARY_OP:
      runUnaryOp(index);
      break;
    case InstructionType::BINARY_OP:
      runBinaryOp(index);
      break;
  }
}

template <typename IRContext>
void NaiveValueMachine<IRContext>::runUnaryOp(int index) {
  using namespace IntOrDouble_functions;
  int src_index = src0_[index];
  bool src_defined = precomputed_values_.defined_[src_index];
  bool src_is_const = precomputed_values_.is_constant_[src_index];
  if (!src_defined && !src_is_const) {
    return;
  }

  int dest_index = dest_[index];

  auto& src = precomputed_values_.values_[src_index];
  auto& dest = precomputed_values_.values_[dest_index];

  switch (uop_type_[index]) {
    case UnaryOpType::Neg:
      dest = -src;
      break;
    case UnaryOpType::Set:
      dest = src;
      break;
    case UnaryOpType::Cast:
      if (data_type_[index] == DataType::Double) {
        dest = src.template cast<double>();
      } else if (data_type_[index] == DataType::Int) {
        dest = src.template cast<int64_t>();
      } else {
        TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
      }
      break;
    case UnaryOpType::Abs:
      dest = abs(src);
      break;
    default:
      TORCH_CHECK(false, "Unexpected operator type ", uop_type_[index]);
  }

  precomputed_values_.defined_[dest_index] = true;
}

template <typename IRContext>
void NaiveValueMachine<IRContext>::runBinaryOp(int index) {
  using namespace IntOrDouble_functions;
  int src0_index = src0_[index];
  int src1_index = src1_[index];
  bool src0_is_const = precomputed_values_.is_constant_[src0_index];
  bool src1_is_const = precomputed_values_.is_constant_[src1_index];

  bool src_defined =
      (precomputed_values_.defined_[src0_index] || src0_is_const) &&
      (precomputed_values_.defined_[src1_index] || src1_is_const);

  if (!src_defined) {
    return;
  }
  int dest_index = dest_[index];

  auto& lhs = precomputed_values_.values_[src0_index];
  auto& rhs = precomputed_values_.values_[src1_index];
  auto& dest = precomputed_values_.values_[dest_index];

  switch (bop_type_[index]) {
    case BinaryOpType::Add:
      dest = lhs + rhs;
      break;
    case BinaryOpType::Sub:
      dest = lhs - rhs;
      break;
    case BinaryOpType::Mul:
      dest = lhs * rhs;
      break;
    case BinaryOpType::Div:
      TORCH_CHECK(rhs != 0);
      dest = lhs / rhs;
      break;
    case BinaryOpType::Mod:
      TORCH_CHECK(rhs != 0);
      dest = lhs % rhs;
      break;
    case BinaryOpType::CeilDiv:
      TORCH_CHECK(rhs != 0);
      dest = ceildiv(lhs, rhs);
      break;
    case BinaryOpType::And:
      dest = Int::ScalarType(lhs && rhs);
      break;
    case BinaryOpType::Max:
      dest = lhs > rhs ? lhs : rhs;
      break;
    case BinaryOpType::Min:
      dest = lhs < rhs ? lhs : rhs;
      break;
    default:
      TORCH_CHECK(false, "Unexpected operator type");
  }

  precomputed_values_.defined_[dest_index] = true;
}

KernelPrecomputedValues::KernelPrecomputedValues(kir::Kernel* kernel) {
  loadSymbols(collectRuntimeUsedValues(kernel));
  kir::ExpressionEvaluator evaluator;
  initializeValueList(evaluator, symbols());
  initializeNamedScalars();
  initializeIntegerMachine();
}
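
// Typical usage, as a sketch only (exact call sites vary; `kernel`,
// `args` and `some_val` are assumed to exist):
//   KernelPrecomputedValues precomputed(kernel);
//   precomputed.bindKernelInputs(kernel, args); // bind runtime input sizes
//   precomputed.evaluate();                     // run the value machine
//   auto extent = precomputed.getMaybeValueFor(some_val);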

// TODO: move this into the base class
void KernelPrecomputedValues::bindTensorMetaData(
    TensorView* tv,
    const TensorArgAbstract* tensor_arg_abstract) {
  const auto root_domain =
      TensorDomain::noReductions(tv->domain()->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
      tensor_arg_abstract->getRank() == static_cast<int>(root_domain.size()),
      "Something went wrong configuring launch. Inputs do not match.");

  for (const auto dim : c10::irange(root_domain.size())) {
    auto extent = root_domain[dim]->extent();
    auto value = tensor_arg_abstract->getSize(dim);
    bindValue(extent->evaluatorIndex(), value);
  }
}

namespace {

//! Compares the name of the given scalar with the thread size strings
//! and returns the corresponding parallel type if a match is found.
c10::optional<ParallelType> getMaybeThreadSizeParallelType(
    NamedScalar* named_scalar) {
  auto& var_name = named_scalar->name();
  for (auto ptype : kParallelTypeThreads) {
    if (var_name == stringifyThreadSize(ptype)) {
      return ptype;
    }
  }
  return c10::nullopt;
}

} // namespace

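//! Records, for each thread-size NamedScalar found among the symbols
//! (e.g. a scalar named "blockDim.x" maps to ParallelType::TIDx), the
//! evaluator indices at which the concrete launch value can later be
//! bound via bindConcreteParallelTypeValue().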
void KernelPrecomputedValues::initializeNamedScalars() {
  for (auto val : symbols()) {
    if (auto named_scalar = dynamic_cast<NamedScalar*>(val)) {
      auto maybe_parallel_type = getMaybeThreadSizeParallelType(named_scalar);
      if (maybe_parallel_type.has_value()) {
        auto& index_list =
            thread_dim_value_indices_[maybe_parallel_type.value()];
        if (!index_list) {
          index_list = std::make_unique<std::vector<int>>();
        }
        index_list->push_back(val->evaluatorIndex());
      }
    }
  }
}

// TODO: merge this one with the above.
void KernelPrecomputedValues::bindKernelInputs(
    kir::Kernel* kernel,
    const KernelArgumentHolder& args) {
  if (hasValidValues()) {
    invalidate();
  }

  const auto& inputs = kernel->inputs();
  TORCH_INTERNAL_ASSERT(
      args.size() == inputs.size(), "kernel inputs size does not match args");

  for (const auto i : c10::irange(inputs.size())) {
    auto arg = args[i];
    const auto input = inputs[i];
    if (auto tensor_input = dynamic_cast<TensorView*>(input)) {
      if (const auto& tensor_arg_abstract =
              dynamic_cast<const TensorArgAbstract*>(arg)) {
        bindTensorMetaData(tensor_input, tensor_arg_abstract);
      } else {
        // TODO: a cpu scalar of int type should be bound as a scalar int as
        // well
        TORCH_CHECK(
            arg->isType(ArgType::CpuScalarTensor),
            "binding input to TensorView expects input arg to be of tensor type");
      }
    } else if (input->isScalar()) {
      if (input->dtype() == DataType::Int) {
        TORCH_CHECK(
            arg->isType(ArgType::Long),
            "binding input to integer type expects input arg to be a scalar of Long type");
        precomputedValuesBaseType::bindValue(
            input->evaluatorIndex(), *static_cast<const int64_t*>(arg->arg()));
      } else if (input->dtype() == DataType::Double) {
        TORCH_CHECK(
            arg->isType(ArgType::Double),
            "binding input to double type expects input arg to be a scalar of Double type");
        precomputedValuesBaseType::bindValue(
            input->evaluatorIndex(), *static_cast<const double*>(arg->arg()));
      }
    }
  }
}

void KernelPrecomputedValues::bindParallelExtents(
    const ParallelExtentMap& parallel_extents,
    const LaunchParams& launch_constraint) {
  // Bind the extents of parallelized iterdomains from launch_constraint
  // where applicable. Consistency will be checked at validate().
  for (const auto& it : parallel_extents) {
    auto raw_val = launch_constraint.getRawVal(it.first);
    if (raw_val > 0) {
      for (auto extent : it.second) {
        bindValue(extent->evaluatorIndex(), raw_val);
      }
    }
  }
}

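//! Binds the concrete launch value to every recorded use of the
//! matching thread-size scalar. For example (values assumed), once
//! blockDim.x is fixed at 128, bindConcreteParallelTypeValue(
//! ParallelType::TIDx, 128) fills in each "blockDim.x" entry collected
//! by initializeNamedScalars().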
void KernelPrecomputedValues::bindConcreteParallelTypeValue(
    ParallelType pt,
    int64_t value) {
  auto index_list_it = thread_dim_value_indices_.find(pt);
  if (index_list_it != thread_dim_value_indices_.end()) {
    for (auto index : *(index_list_it->second)) {
      bindValue(index, value);
    }
  }
}

FusionPrecomputedValues::FusionPrecomputedValues(Fusion* fusion)
    : fusion_(fusion) {
  loadSymbols(collectRuntimeUsedValues(fusion));
  ExpressionEvaluator evaluator(fusion);
  initializeValueList(evaluator, symbols());
  initializeIntegerMachine();
}
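
// Typical usage, as a sketch only (exact call sites vary; `fusion` and
// `args` are assumed to exist):
//   FusionPrecomputedValues precomputed(fusion);
//   precomputed.bindFusionInputs(args); // bind runtime input sizes
//   precomputed.evaluate();             // fold everything derivable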

// TODO: move this into the base class
void FusionPrecomputedValues::bindTensorMetaData(
    TensorView* tv,
    const TensorArgAbstract* tensor_arg_abstract) {
  const auto root_domain =
      TensorDomain::noReductions(tv->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
      tensor_arg_abstract->getRank() == static_cast<int>(root_domain.size()),
      "Something went wrong configuring launch. Inputs do not match.");

  for (const auto dim : c10::irange(root_domain.size())) {
    auto extent = root_domain[dim]->extent();
    auto value = tensor_arg_abstract->getSize(dim);
    precomputedValuesBaseType::bindValue(extent->evaluatorIndex(), value);
  }
}

void FusionPrecomputedValues::bindFusionInputs(
    const KernelArgumentHolder& args) {
  if (hasValidValues()) {
    precomputedValuesBaseType::invalidate();
  }

  const auto& inputs = fusion_->inputs();
  TORCH_INTERNAL_ASSERT(
      args.size() == inputs.size(), "fusion inputs size does not match args");

  for (const auto i : c10::irange(inputs.size())) {
    const auto input = inputs[i];
    const ArgAbstract* arg = args[i];
    if (auto tensor_input = dynamic_cast<TensorView*>(input)) {
      if (const auto& tensor_arg_abstract =
              dynamic_cast<const TensorArgAbstract*>(arg)) {
        bindTensorMetaData(tensor_input, tensor_arg_abstract);
      } else {
        TORCH_CHECK(
            arg->isType(ArgType::CpuScalarTensor),
            "binding input to TensorView expects input arg to be of tensor type");
      }
    } else if (input->isScalar()) {
      if (input->getDataType() == DataType::Int) {
        TORCH_CHECK(
            arg->isType(ArgType::Long),
            "binding input to integer type expects input arg to be a scalar of Long type");
        precomputedValuesBaseType::bindValue(
            input->evaluatorIndex(), *static_cast<const int64_t*>(arg->arg()));
      } else if (input->getDataType() == DataType::Double) {
        TORCH_CHECK(
            arg->isType(ArgType::Double),
            "binding input to double type expects input arg to be a scalar of Double type");
        precomputedValuesBaseType::bindValue(
            input->evaluatorIndex(), *static_cast<const double*>(arg->arg()));
      }
    }
  }
}

template class PrecomputedValuesBase<FusionIRContext>;
template class PrecomputedValuesBase<KernelIRContext>;

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch