#include <kernel_cache.h>

#include <instrumentation.h>
#include <ir_utils.h>
#include <parser.h>
#include <scheduler/debug_utils.h>
#include <scheduler/registry.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <c10/core/thread_pool.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/irange.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

#define THREAD_POOL_SIZE 10

// TODO: clean this up with some knobs
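// Shared compilation thread pool; FusionKernelRuntime::startAsyncCompile
// submits its compilation job here.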
c10::ThreadPool* getThreadPool() {
  static c10::ThreadPool pool(THREAD_POOL_SIZE);
  return &pool;
}

void encodeBuffer(size_t value, std::string& buffer) {
  const char* v = reinterpret_cast<char*>(&value);
  for (const auto i : c10::irange(sizeof(size_t))) {
    (void)i; // Suppress unused variable warning
    buffer.push_back(*(v++));
  }
}

} // namespace

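// The per-input encoding built in lookupId() is an ad-hoc byte string used
// only as a hash-map key. Roughly, for each tensor input it is:
//   sizes... 'X' strides... 'a' alignment 'd' device-index ';'
// and for each scalar input just "s;". Each integer is appended byte-by-byte
// via encodeBuffer, so the layout is an implementation detail rather than a
// stable format.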
InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId(
    const at::ArrayRef<IValue>& inputs) {
  IdLookupReturn ret;

  // lock mutex_ because we are touching encoding_
  std::lock_guard<std::mutex> guard(mutex_);
  encoding_.clear();
  for (const auto& input : inputs) {
    if (input.isTensor()) {
      auto& input_tensor = input.toTensor();

      for (auto size : input_tensor.sizes()) {
        encodeBuffer(size, encoding_);
        encoding_.push_back(' ');
      }
      encoding_.push_back('X');
      encoding_.push_back(' ');
      for (auto stride : input_tensor.strides()) {
        encodeBuffer(stride, encoding_);
        encoding_.push_back(' ');
      }
      encoding_.push_back('a');
      encodeBuffer(
          SchedulerRuntimeInfo::computeAlignmentSize(
              (size_t)input_tensor.data_ptr()),
          encoding_);
      encoding_.push_back('d');
      encodeBuffer(input_tensor.device().index(), encoding_);
    } else {
      // encode s for scalar;
      encoding_.push_back('s');
    }
    encoding_.push_back(';');
  }

  auto& entry = encoding_lookup_[encoding_];

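  // used_entry_ is kept in MRU-first order: the encoding for the current
  // inputs is (re-)inserted at the front below, and when the cache is full
  // the entry at the back, i.e. the least recently used one, is evicted.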
  if (entry.id == 0) {
    // no entry existed for the given input set; assign an id to this entry
    entry.id = current_id_++;
    if (used_entry_.size() == max_cache_size_) {
      // evict the least recently used entry;
      const auto& remove_iter = encoding_lookup_.find(used_entry_.back());
      used_entry_.pop_back();
      ret.evict_id = remove_iter->second.id;
      ret.eviction = true;
      encoding_lookup_.erase(remove_iter);
    }
  } else {
    // short-cut: leave the LRU entry as is when it is already at the front
    if (entry.lru_iter == used_entry_.begin()) {
      ret.id = entry.id;
      return ret;
    }

    used_entry_.erase(entry.lru_iter);
  }

  ret.id = entry.id;
  entry.lru_iter = used_entry_.insert(used_entry_.begin(), encoding_);
  return ret;
}

FusionExecutorCache::FusionExecutorCache(std::unique_ptr<Fusion> fusion)
    : fusion_(std::move(fusion)) {
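  // aliased_output_indices_ records fusion outputs that merely alias inputs
  // (e.g. tensors updated in place by the fusion); runFusionWithInputs()
  // drops them from the returned outputs.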
  for (const auto& indices : fusion_->getOutputAliasIndices()) {
    aliased_output_indices_.insert(indices);
  }
}

KernelArgumentHolder FusionExecutorCache::prepareInputs(
    const at::ArrayRef<IValue>& inputs) {
  FUSER_PERF_SCOPE("FusionExecutorCache::prepareInputs");

  KernelArgumentHolder args =
      KernelArgumentHolder::createKernelArgumentHolder(inputs);

  // TODO: move InputsIdLookup inside KernelArgumentHolder;
  auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs);
  if (id_lookup_ret.eviction) {
    evictCache(id_lookup_ret.evict_id);
  }

  args.setCacheId(id_lookup_ret.id);
  return args;
}

bool FusionExecutorCache::isCompiled(const at::ArrayRef<IValue>& inputs) {
  FUSER_PERF_SCOPE("FusionExecutorCache::isCompiled");

  // Access kernels associated with the common device id
  KernelArgumentHolder args = prepareInputs(inputs);

  return getKernelRuntimeFor(args)->isCompiled();
}

void FusionExecutorCache::compileFusionAsync(
    const at::ArrayRef<IValue>& inputs) {
  FUSER_PERF_SCOPE("FusionExecutorCache::compileFusionAsync");

  KernelArgumentHolder args = prepareInputs(inputs);
  auto kernel_runtime = getKernelRuntimeFor(args);

  kernel_runtime->startAsyncCompile(args);
}

// Note [ Permutation support in nvfuser ]
//
// Background:
// To support permutation in nvfuser with optimal performance, we want to
// allow dimension collapsing in generated code on channels-last tensors,
// which greatly simplifies indexing. The current codegen API only allows
// dimension collapsing on neighboring axes. The unfortunate thing is that
// the memory format in PyTorch is implicitly marked by strides, while the
// semantic meaning of the axes remains unchanged, i.e. a 4d tensor with axes
// [N, C, H, W] has the same shape in both formats, while a contiguous tensor
// carries strides [C*H*W, H*W, W, 1] and a channels-last tensor
// [H*W*C, 1, W*C, C].
//
// Approach:
// Part_1. To allow axis collapsing for permuted tensors in codegen, we
// permute each input tensor so its axes are in descending order by stride;
// it is then viewed as `contiguous` in codegen and hence collapsed to simple
// indexing.
// Part_2. To ensure correct results, computation in nvfuser must carry the
// same semantics as the TorchScript graph. We need to:
// Part_2_1. Maintain a bookkeeping where each codegen tensor is tagged with
// its permutation.
// Part_2_2. Have parsing rules handle and propagate the tag properly, e.g.
// batch normalization has special rules for `channels_last` input tensors
// and marks the output with the right permutation.
// Part_3. Codegen output tensors that have been permuted should be restored
// to their original layout before being returned to TorchScript.
//
// For details on Part_2, refer to implementation Note [ Permutation
// Bookkeeping and Propagation in Parser ]
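// Example (hypothetical shapes): an NCHW input stored channels-last with
// strides [H*W*C, 1, W*C, C] is permuted to NHWC before the kernel runs, so
// codegen sees a contiguous tensor; the corresponding output is permuted
// back to NCHW before being handed back to TorchScript.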
std::vector<at::Tensor> FusionExecutorCache::runFusionWithInputs(
    const at::ArrayRef<IValue>& inputs) {
  FUSER_PERF_SCOPE("FusionExecutorCache::runFusionWithInputs");

  // permute input tensors for kernel execution. See Part_1 in Note [
  // Permutation support in nvfuser ]
  at::ArrayRef<IValue> perm_inputs = inputs;
  const auto& to_be_permuted_inputs = fusion_->getPermutationInputMap();
  std::vector<IValue> inputs_vec;
  if (!to_be_permuted_inputs.empty()) {
    inputs_vec = inputs.vec();
    for (const auto& pair : to_be_permuted_inputs) {
      auto v = inputs_vec[pair.first];
      TORCH_CHECK(
          v.isTensor(), "input permutation can only be applied to tensors");
      auto tensor = v.toTensor();
      inputs_vec[pair.first] = tensor.permute(pair.second);
    }
    perm_inputs = inputs_vec;
  }

  KernelArgumentHolder args = prepareInputs(perm_inputs);

  auto kernel_runtime = getKernelRuntimeFor(args);
  most_recent_runtime_ = kernel_runtime;
  int seq_id = 0;
  // Record kernel input and output tensors so the profiler can construct
  // the data flow graph
  RECORD_FUNCTION(
      "run_fused_kernel",
      std::vector<c10::IValue>(inputs.begin(), inputs.end()),
      seq_id);
  auto outputs = kernel_runtime->runWithInput(args);
  RECORD_OUTPUTS(outputs);

  // permute output tensors returned by kernel execution. See Part_3 in Note [
  // Permutation support in nvfuser ]
  for (const auto& pair : fusion_->getPermutationOutputMap()) {
    if (size_t(pair.first) < outputs.size()) {
      outputs[pair.first] = outputs[pair.first].permute(pair.second);
    }
  }

  // remove aliased outputs, since those are only used to update input
  // tensors in the fusion. It is not semantically correct to actually return
  // them as outputs of the fusion.
  int offset = 0;
  for (const auto& v : aliased_output_indices_) {
    outputs.erase(outputs.begin() + v - offset);
    offset++;
  }

  return outputs;
}

void FusionExecutorCache::evictCache(size_t cache_id) {
  auto it = id_to_kernel_runtime_.find(cache_id);
  TORCH_INTERNAL_ASSERT(it != id_to_kernel_runtime_.end());
  it->second->evictCache(cache_id);
  id_to_kernel_runtime_.erase(it);
}

FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor(
    const KernelArgumentHolder& args) {
  // Check for id hit case
  auto unique_id = *args.getCacheId();
  auto id_it = id_to_kernel_runtime_.find(unique_id);
  if (id_it != id_to_kernel_runtime_.end()) {
    return id_it->second;
  }

  // Access kernels associated with the common device id
  auto& kernel_runtimes = kernel_runtimes_[args.getDeviceIndex()];

  // Check for re-use hit case
  // a kernel runtime is re-usable if all the compiled
  // kernels have the same heuristic parameters
  std::unique_ptr<FusionHeuristics> new_heuristics;

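  // Search existing runtimes on this device. getMaybeHeuristicsFor() returns
  // a value only when the new inputs can be scheduled with the same
  // heuristics the runtime was built with; in that case the runtime is
  // re-used with updated launch parameters.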
  auto reuse_it = std::find_if(
      kernel_runtimes.begin(),
      kernel_runtimes.end(),
      [&args, &new_heuristics](auto& kernel_runtime) {
        auto maybe_heuristics = kernel_runtime->getMaybeHeuristicsFor(args);
        if (!maybe_heuristics.has_value()) {
          return false;
        }
        new_heuristics = std::move(maybe_heuristics.value());
        return true;
      });

  FusionKernelRuntime* kernel_runtime = nullptr;
  if (reuse_it != kernel_runtimes.end()) {
    kernel_runtime = reuse_it->get();
    kernel_runtime->updateHeuristicsLaunchParams(new_heuristics.get());
  } else {
    // graph miss, need to re-build an optimized graph for this case
    kernel_runtimes.emplace_back(
        std::make_unique<FusionKernelRuntime>(fusion_.get(), args));
    kernel_runtime = kernel_runtimes.back().get();
    if (profiling_) {
      kernel_runtime->profile(true);
    }
  }

  id_to_kernel_runtime_[unique_id] = kernel_runtime;
  return kernel_runtime;
}

FusionKernelRuntime::FusionKernelRuntime(
    Fusion* fusion,
    const KernelArgumentHolder& args) {
  FUSER_PERF_SCOPE("FusionKernelRuntime::FusionKernelRuntime");

  // Make a copy of the fusion and do segmentation and translation
  // on this copy
  auto fusion_copy = std::make_unique<Fusion>(*fusion);

  // Run segmentation on the copied fusion
  SchedulerRuntimeInfo runtime_info(fusion_copy.get(), args, true);

  // Initialize the evaluator simplifier
  precomputed_values_ =
      std::make_unique<FusionPrecomputedValues>(fusion_copy.get());

  //! Try to schedule the complete fusion
  scheduler_debug_utils::canScheduleMessage(
      "***Runtime***: Try to schedule fusion un-segmented:\n");

  const auto maybe_complete_fusion_heuristic =
      SchedulerEntry::proposeHeuristics(fusion_copy.get(), runtime_info);

  //! Decide if this fusion is segmented or not
  const bool segmented = !maybe_complete_fusion_heuristic.has_value();

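  // If no single scheduler can handle the complete fusion, fall back to
  // segmenting it into independently schedulable groups.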
  if (segmented) {
    // Take ownership and segment transformed fusion
    segmented_fusion_ =
        SegmentCandidateFinder::segment(std::move(fusion_copy), args);
  } else {
    segmented_fusion_ = SegmentedFusion::fromCompleteFusion(
        std::move(fusion_copy), maybe_complete_fusion_heuristic.value());
  }

  heuristics_ = segmented_fusion_->makeInitialHeuristics(args);
  executors_ = std::vector<FusionExecutor>(segmented_fusion_->groups().size());
  if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) {
    segmented_fusion_->print();
  }

  // Even if we go through the segmented path we may still end up
  // with a segmented fusion with one group. This case still
  // counts as un-segmented.
  is_segmented_ = segmented_fusion_->groups().size() > 1;

  // Pre-compute the executor order so that the runtime path
  // goes directly to kernel launch.
  prepareRuntimeOrder();
}

std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
    KernelArgumentHolder& args,
    SegmentedGroup* sg) {
  FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput");
  std::lock_guard<std::mutex> guard(mutex_);
  // This function is called once for an un-segmented fusion; for a segmented
  // fusion it is called once per segment.
  // In the segmented case, a segmented group needs to be given so that a
  // kernel is compiled and run for that group.
  // In the complete-fusion case, sg = nullptr, and the original fusion
  // is compiled and run.
  TORCH_INTERNAL_ASSERT(sg, "runKernelWithInput: need valid group to run");
  auto group_id = sg->groupId();

  LaunchParams launch_params;

  auto scheduler_entry = schedulers()[group_id].get();

  // Check that the heuristics are matched, in the case of segmented fusion
  TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristic() == sg->heuristic());

  if (!executors_[group_id].compiled()) {
    FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::Compile");
    std::unique_ptr<Fusion> fusion_to_run;

    // Running a segmented group as a single kernel,
    // make a fusion to run from the segmented fusion
    fusion_to_run = segmented_fusion_->makeFusion(sg);
    FusionGuard fg(fusion_to_run.get());
    scheduler_entry->schedule(fusion_to_run.get());
    launch_params = scheduler_entry->params()->lparams;
    executors_[group_id].compileFusion(
        fusion_to_run.get(), args, launch_params);
  } else {
    launch_params = scheduler_entry->params()->lparams;
  }

  if (profiling_) {
    most_recent_executor_log_.fusion_executor = &executors_[group_id];
    most_recent_executor_log_.params = scheduler_entry->params()->clone();
  }

  auto& executor = executors_[group_id];
  if (isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) {
    executor.setMeasureKernelTimeFlag(true);
  }

  auto outputs = executor.runFusion(args, launch_params);

  // Print relevant information all at once for easy debugging of perf
  if (isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) {
    std::cout << "\nRun kernel:\n";
    if (sg) {
      segmented_fusion_->makeFusion(sg)->printMath();
    } else {
      segmented_fusion_->completeFusion()->printMath();
    }
    std::cout << "With inputs:\n";
    for (auto i : c10::irange(args.size())) {
      args[i]->print();
    }
    std::cout << "Compiler log: " << executor.compilerLog() << "\n";
    std::cout << scheduler_entry->params()->toString() << "\n";
    std::cout << "With arguments: " << executor.lastLaunchParams().toString();
    std::cout << executor.kernelName() << " " << executor.bytesProcessed()
              << " bytes/ " << std::setprecision(3) << executor.kernelTimeMs()
              << " ms "
              << ((double)executor.bytesProcessed() /
                  ((double)executor.kernelTimeMs() / 1000)) /
            (double)1.0e9
              << " GB/s" << std::endl;
    executor.setMeasureKernelTimeFlag(false);
  }

  return outputs;
}

void FusionKernelRuntime::prepareRuntimeOrder() {
  // Setup group run order:
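  // This is effectively a topological sort over segmented groups: repeatedly
  // pick any group whose inputs are all available, record it in
  // group_run_order, and mark its outputs as available.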
  std::unordered_set<Val*> available_input;

  // set up the order in which tensor dimension extents are bound
  for (const size_t i : c10::irange(segmented_fusion_->inputs().size())) {
    auto input_val = segmented_fusion_->inputs()[i];
    available_input.insert(input_val);

    if (auto input_tv = dynamic_cast<TensorView*>(input_val)) {
      auto root_dom = TensorDomain::noReductions(input_tv->getRootDomain());
      for (const size_t dim : c10::irange(root_dom.size())) {
        const auto extent = root_dom[dim]->extent();
        available_input.insert(extent);
        runtime_workspace_.group_extent_binding_order.push_back(extent);
      }
    }
  }

  // Keep track of groups that have run
  std::vector<bool> group_ran(segmented_fusion_->groups().size(), false);

  while (!std::all_of(
      group_ran.begin(), group_ran.end(), [](bool b) { return b; })) {
    bool one_ran = false;

    // Find the first segment with all inputs available to run
    for (const size_t group_i :
         c10::irange(segmented_fusion_->groups().size())) {
      auto& group = segmented_fusion_->groups()[group_i];
      if (group_ran[group_i]) {
        continue;
      }
      const auto& group_inputs = group->inputs();
      bool ready_to_run = std::all_of(
          group_inputs.begin(),
          group_inputs.end(),
          [&available_input](Val* val) { return available_input.count(val); });

      if (ready_to_run) {
        runtime_workspace_.group_run_order.push_back(group);
        const auto& group_outputs = group->outputs();

        // Insert graph segment outputs into the tensor map
        for (const size_t group_out_i : c10::irange(group_outputs.size())) {
          available_input.insert(group_outputs[group_out_i]);
        }
        group_ran[group_i] = true;
        one_ran = true;
      }
    }
    TORCH_INTERNAL_ASSERT(
        one_ran,
        "Couldn't run all groups, something must have gone wrong in segmentation.");
  }
}

// passing args by value, since we will be modifying it
void FusionKernelRuntime::startAsyncCompile(KernelArgumentHolder& args_old) {
  // only a single compilation is supported at the moment.
  std::unique_lock<std::mutex> unique_lock(mutex_, std::try_to_lock);
  TORCH_CHECK(
      unique_lock.owns_lock(),
      "Calling startAsyncCompile on a FusionKernelRuntime that's already starting a compilation thread is not supported");
  std::unique_lock<std::mutex> unique_lock2(compiling_, std::try_to_lock);
  TORCH_CHECK(
      unique_lock2.owns_lock(),
      "Calling startAsyncCompile on a FusionKernelRuntime that's already starting a compilation thread is not supported 2");
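  // Locking scheme: mutex_ is probed here (and held by runKernelWithInput)
  // to guard executors_, while compiling_ is re-acquired inside the worker
  // lambda and held for the duration of the compilation, so a second
  // startAsyncCompile call fails fast instead of queueing.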

  // For some reason unique_lock can't be moved into the lambda; it keeps
  // being copied, e.g.:
  // auto compile_fusion = [args = std::move(args_old), lock =
  // std::move(unique_lock), this] () mutable {
  auto compile_fusion = [args = std::move(args_old), this]() mutable {
    std::lock_guard<std::mutex> guard(compiling_);

    // locking mutex_ since we are touching executors_ during compilation.
    // c10::DeviceGuard dg(c10::Device(DeviceType::CUDA,
    // args.getDeviceIndex())); CUDAGuard uses the runtime API directly, which
    // is thread safe.
    c10::cuda::CUDAGuard dg(args.getDeviceIndex());

    FUSER_PERF_SCOPE("FusionKernelRuntime::startAsyncCompile");

    TORCH_INTERNAL_ASSERT(
        args.size() == segmented_fusion_->inputs().size(),
        "Inputs were not set up correctly, received ",
        args.size(),
        " inputs but expecting ",
        segmented_fusion_->inputs().size());

    c10::Device device(c10::DeviceType::CUDA, args.getDeviceIndex());
    std::unordered_map<Val*, const ArgAbstract*> tensor_map;
    mapFusionInputsToArgs(tensor_map, args);

    // TODO: compilation can happen in parallel! We can have output sizes
    // inferred on the un-compiled kernel and set up the whole tensor_map
    // prior to compilation.
    for (auto group_to_run : runtime_workspace_.group_run_order) {
      // TODO: index mode should be updated per segmented kernel
      // Prepare input vector
      KernelArgumentHolder group_runtime_inputs(args.getIndexMode());
      group_runtime_inputs.setDeviceIndex(args.getDeviceIndex());
      for (auto input : group_to_run->inputs()) {
        group_runtime_inputs.push(tensor_map.at(input));
      }

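      // compileKernel() doesn't launch anything; it compiles the segment's
      // kernel and returns inferred output metadata, which is enough to bind
      // the extents that downstream segments consume.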
      // Compile graph segment
      KernelArgumentHolder group_runtime_outputs =
          compileKernel(group_runtime_inputs, group_to_run);

      // map output args to the tensor map
      const auto& group_outputs = group_to_run->outputs();
      for (const size_t group_out_i : c10::irange(group_outputs.size())) {
        args.push(group_runtime_outputs[group_out_i]);
        tensor_map.emplace(group_outputs[group_out_i], args.back());
      }
    }
  };

  getThreadPool()->run(compile_fusion);
}

// TODO: replace the boilerplate in runKernelWithInput
KernelArgumentHolder FusionKernelRuntime::compileKernel(
    const KernelArgumentHolder& args,
    SegmentedGroup* sg) {
  FUSER_PERF_SCOPE("FusionKernelRuntime::compileKernel");
  // This function is called once for an un-segmented fusion; for a segmented
  // fusion it is called once per segment.
  // In the segmented case, a segmented group needs to be given so that a
  // kernel is compiled for that group.
  // In the complete-fusion case, sg = nullptr, and the original fusion
  // is compiled.
  TORCH_INTERNAL_ASSERT(sg, "compileKernel: need valid group to run");
  auto group_id = sg->groupId();

  LaunchParams launch_params;

  auto scheduler_entry = schedulers()[group_id].get();

  // Check that the heuristics are matched, in the case of segmented fusion
  TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristic() == sg->heuristic());

  if (!executors_[group_id].compiled()) {
    FUSER_PERF_SCOPE("FusionKernelRuntime::compileKernel::Compile");
    std::unique_ptr<Fusion> fusion_to_run;

    // Running a segmented group as a single kernel,
    // make a fusion to run from the segmented fusion
    fusion_to_run = segmented_fusion_->makeFusion(sg);
    FusionGuard fg(fusion_to_run.get());
    scheduler_entry->schedule(fusion_to_run.get());
    launch_params = scheduler_entry->params()->lparams;

    executors_[group_id].compileFusion(
        fusion_to_run.get(), args, launch_params);
  } else {
    // TODO: this is a false-negative assert, since we could legitimately be
    // compiling again for an elevated high-water mark on block size.
    TORCH_CHECK(false, "compiling an already compiled kernel");
  }

  auto& executor = executors_[group_id];

  auto outputs = executor.inferOutputSizes(args, launch_params);
  return outputs;
}

void FusionKernelRuntime::mapFusionInputsToArgs(
    std::unordered_map<Val*, const ArgAbstract*>& tensor_map,
    KernelArgumentHolder& args) {
  int extent_index = 0;
  auto original_args_size = args.size();
  // Bind args in the tensor_map
  for (const auto i : c10::irange(original_args_size)) {
    tensor_map.emplace(segmented_fusion_->inputs()[i], args[i]);
    // Bind tensorview input extents in case some segmented group
    // needs them down the road.
    // TODO: we have probably done this already up to this point;
    // should consider caching the expression evaluators, which is both
    // more convenient and safer than replicating the work.
    if (auto tensor_arg_abstract =
            dynamic_cast<const TensorArgAbstract*>(args[i])) {
      // Note: this is an ugly approach. We push every single extent into
      // args because we don't have a better place to hold them.
      auto rank = tensor_arg_abstract->getRank();
      for (const auto dim : c10::irange(rank)) {
        args.push(tensor_arg_abstract->getSize(dim));
        tensor_map.emplace(
            runtime_workspace_.group_extent_binding_order[extent_index++],
            args.back());
      }
    }
  }
}

std::vector<at::Tensor> FusionKernelRuntime::runWithInput(
    KernelArgumentHolder& args) {
  FUSER_PERF_SCOPE("FusionKernelRuntime::runWithInput");

  TORCH_INTERNAL_ASSERT(
      args.size() == segmented_fusion_->inputs().size(),
      "Inputs were not set up correctly, received ",
      args.size(),
      " inputs but expecting ",
      segmented_fusion_->inputs().size());

  c10::Device device(c10::DeviceType::CUDA, args.getDeviceIndex());

  std::unordered_map<Val*, const ArgAbstract*> tensor_map;
  mapFusionInputsToArgs(tensor_map, args);

  // TODO: we don't need this any more, since TensorArgAbstract already holds
  // a reference to the tensor
  std::unordered_map<Val*, at::Tensor> output_holder;

  if (isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) {
    std::cout << "=================RUNNING FUSION SEGMENTS================="
              << std::endl;
  }

  // all groups share the cache id of the full fusion.
  auto group_cache_id = args.getCacheId();
  for (auto group_to_run : runtime_workspace_.group_run_order) {
    // TODO: index mode should be updated per segmented kernel
    // Prepare input vector
    KernelArgumentHolder group_runtime_inputs(args.getIndexMode());
    group_runtime_inputs.setDeviceIndex(args.getDeviceIndex());
    if (group_cache_id.has_value()) {
      group_runtime_inputs.setCacheId(group_cache_id.value());
    }
    for (auto input : group_to_run->inputs()) {
      group_runtime_inputs.push(tensor_map.at(input));
    }

    // TODO: currently we are still outputting PyTorch tensors, instead of
    // something abstract. This is quite unsatisfying.

    // Run graph segment
    std::vector<at::Tensor> group_runtime_outputs =
        runKernelWithInput(group_runtime_inputs, group_to_run);

    const auto& group_outputs = group_to_run->outputs();

    // Insert graph segment outputs into the tensor map
    TORCH_INTERNAL_ASSERT(
        group_outputs.size() == group_runtime_outputs.size(),
        "output size does not match");
    for (const size_t group_out_i : c10::irange(group_outputs.size())) {
      // trivial forwarding outputs an empty tensor to save bandwidth; skip
      // the tensor_map update on those, since we want all future uses to
      // refer to the original input tensor. See Note [ trivial forwarding ]
      if (!group_outputs[group_out_i]->isFusionInput()) {
        output_holder[group_outputs[group_out_i]] =
            group_runtime_outputs[group_out_i];

        args.push(group_runtime_outputs[group_out_i]);
        tensor_map.emplace(group_outputs[group_out_i], args.back());
      }
    }
  }

  if (isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) {
    std::cout << "=============FINISHED RUNNING FUSION SEGMENTS============"
              << std::endl;
  }

  // Produce the final global outputs
  std::vector<at::Tensor> fusion_outputs;
  for (auto output : segmented_fusion_->outputs()) {
    const auto iter = output_holder.find(output);
    if (iter != output_holder.end()) {
      fusion_outputs.push_back(iter->second);
    } else if (output->isFusionInput()) {
      // Note [ trivial forwarding ]
      //
      // Background:
      // nvfuser codegen doesn't handle aliases at all. When we have a fusion
      // that forwards an input to an output without any operations on it,
      // this is a no-op for codegen and the output tensor is never written
      // to. However, the codegen cannot "forward" an input to an output,
      // since all outputs are allocated in integration. If we do not special
      // case it, we'd end up with a "fresh" tensor allocated for the
      // forwarded input.
      //
      // Approach:
      // There are two aspects of the support:
      // step 1. Codegen handles forwarding implicitly. Forwarded inputs don't
      // have any producer in the IR, hence the output argument is not used in
      // the code. But it does require an argument in the kernel as a
      // place-holder so that we map each argument correctly.
      // step 2. Integration handles the trivial forwarding of inputs. When we
      // put together `fusion_outputs` for a given fusion and an output is
      // just a fusion input, we directly return the input tensor.
      const auto iter = tensor_map.find(output);
      TORCH_INTERNAL_ASSERT(
          iter != tensor_map.end(), "Cannot find output as aliased input");
      auto arg = dynamic_cast<const TensorArgAbstract*>(iter->second);
      // See step 2 - Note [ trivial forwarding ]
      fusion_outputs.push_back(arg->getTensor());
    } else {
      bool empty_type_check = output->getDataType().has_value() &&
          output->getDataType().value() == DataType::Float;

      // Only two cases of empty tensors are supported here, since
      // this is a hot path.
      auto out_tv = output->as<TensorView>();

      // TODO: should be only one of the two once the "empty"
      // definition has been unified throughout the ops.
      bool empty_tensor_check = out_tv->isZeroDim() || out_tv->isEmptyTensor();

      // This is the check for an empty tensor;
      TORCH_INTERNAL_ASSERT(
          empty_tensor_check && empty_type_check,
          "Is empty tensor? ",
          !empty_tensor_check,
          " Is empty type check? ",
          !empty_type_check,
          " Output empty tensor check failed for tensor: ",
          out_tv->toString(),
          " In function: ",
          __FUNCTION__);

      // TODO: would need to clean up this part when
      // we have a unified and consistent way to generate
      // size-0 tensors.
      const auto tensor_options =
          at::TensorOptions().dtype(at::kFloat).device(device);
      fusion_outputs.emplace_back(at::empty({0}, tensor_options));
    }
  }
  return fusion_outputs;
}

const std::vector<FusionKernelRuntime::SchedulerEntryPtr>& FusionKernelRuntime::
    schedulers() {
  return heuristics_->heuristicsList();
}

void FusionKernelRuntime::updateHeuristicsLaunchParams(
    FusionHeuristics* update_heuristics) {
  FUSER_PERF_SCOPE("FusionKernelRuntime::updateHeuristicsLaunchParams");
  auto scheduler_list_length = heuristics_->heuristicsList().size();
  TORCH_INTERNAL_ASSERT(
      update_heuristics->heuristicsList().size() == scheduler_list_length);
  for (const auto i : c10::irange(scheduler_list_length)) {
    auto& schedulerPtr = heuristics_->heuristicsList()[i];
    schedulerPtr->updateLaunchConstraint(
        update_heuristics->heuristicsList()[i]->params()->lparams);
  }
}

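// Returns nullopt when any segment's heuristics for the given inputs would
// differ from the ones this runtime was compiled with; in that case
// FusionExecutorCache builds a new FusionKernelRuntime instead of re-using
// this one.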
c10::optional<FusionKernelRuntime::HeuristicsPtr> FusionKernelRuntime::
    getMaybeHeuristicsFor(const KernelArgumentHolder& args) {
  FUSER_PERF_SCOPE("FusionKernelRuntime::getMaybeHeuristicsFor");
  auto complete_fusion = segmented_fusion_->completeFusion();
  SchedulerRuntimeInfo runtime_info(complete_fusion, args);
  precomputed_values_->bindFusionInputs(args);
  precomputed_values_->evaluate();
  runtime_info.expressionEvaluator().bindPrecomputedValues(
      precomputed_values_.get());

  c10::optional<FusionKernelRuntime::HeuristicsPtr> ret;
  ret = std::make_unique<FusionHeuristics>();
  size_t total_groups = segmented_fusion_->groups().size();
  for (const auto group_index : c10::irange(total_groups)) {
    auto group = segmented_fusion_->groups()[group_index];

    auto maybe_scheduler_entry = group->getMaybeSchedulerEntry(runtime_info);
    if (!maybe_scheduler_entry.has_value()) {
      return c10::nullopt;
    }
    auto scheduler_entry = std::move(maybe_scheduler_entry.value());
    if (!scheduler_entry->sameAs(
            heuristics_->heuristicsList()[group_index].get())) {
      return c10::nullopt;
    }
    ret.value()->emplaceBack(std::move(scheduler_entry));
  }

  return ret;
}

void GraphCache::createFusion(const std::shared_ptr<Graph>& graph) {
  FUSER_PERF_SCOPE("GraphCache::createFusion");

  fusion_executor_cache_ =
      std::make_unique<FusionExecutorCache>(parseJitIR(graph));

  num_of_outputs_ = graph->outputs().size();
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GraphCache::GraphCache(const std::shared_ptr<Graph>& graph) {
  FUSER_PERF_SCOPE("GraphCache::GraphCache");
  TORCH_INTERNAL_ASSERT(
      IsNewExecutorEnabled(), "legacy executor is not supported by nvfuser");

  GRAPH_DEBUG("GraphCache constructor: ", this);
  GRAPH_DUMP("GraphCache created for graph", graph);
  createFusion(graph);
}

std::vector<at::Tensor> GraphCache::runGraphWithInputs(
    const at::ArrayRef<IValue>& inputs) {
  FUSER_PERF_SCOPE("GraphCache::runGraphWithInputs");

  GRAPH_DEBUG("running GraphCache: ", this);
  auto outputs = fusion_executor_cache_->runFusionWithInputs(inputs);
  TORCH_INTERNAL_ASSERT(
      outputs.size() == num_of_outputs_,
      "FusionExecutorCache returned ",
      outputs.size(),
      " outputs, which doesn't match the computational graph, which requires ",
      num_of_outputs_);

  return outputs;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch