1#include <ir_utils.h>
2#include <lower2device.h>
3#include <partial_split_map.h>
4
5namespace torch {
6namespace jit {
7namespace fuser {
8namespace cuda {
9
10void PartialSplitMap::build(Fusion* fusion) {
11 auto used_vals = ir_utils::allTvs(fusion);
12
13 for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
14 auto exprs = StmtSort::getExprs(
15 fusion, {tv->domain()->domain().begin(), tv->domain()->domain().end()});
16 for (auto split : ir_utils::filterByType<Split>(exprs)) {
17 // Only needs to check root domains as partial split is only
18 // allowed with root domains
19 if (std::find(
20 tv->getRootDomain().begin(),
21 tv->getRootDomain().end(),
22 split->in()) == tv->getRootDomain().end()) {
23 continue;
24 }
25 auto root_domain = split->in();
26 auto start_offset = split->startOffset();
27 start_offset_map_.insert({root_domain, start_offset});
28 auto stop_offset = split->stopOffset();
29 stop_offset_map_.insert({root_domain, stop_offset});
30 }
31 }
32}
33
34Val* PartialSplitMap::getStartOffset(IterDomain* root_domain) const {
35 auto it = start_offset_map_.find(root_domain);
36 if (it == start_offset_map_.end()) {
37 return nullptr;
38 } else {
39 return it->second;
40 }
41}
42
43Val* PartialSplitMap::getStopOffset(IterDomain* root_domain) const {
44 auto it = stop_offset_map_.find(root_domain);
45 if (it == stop_offset_map_.end()) {
46 return nullptr;
47 } else {
48 return it->second;
49 }
50}
51
52} // namespace cuda
53} // namespace fuser
54} // namespace jit
55} // namespace torch
56