#include <ir_utils.h>
#include <lower2device.h>
#include <partial_split_map.h>

#include <algorithm>
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | namespace fuser { |
8 | namespace cuda { |
9 | |
10 | void PartialSplitMap::build(Fusion* fusion) { |
11 | auto used_vals = ir_utils::allTvs(fusion); |
12 | |
13 | for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) { |
14 | auto exprs = StmtSort::getExprs( |
15 | fusion, {tv->domain()->domain().begin(), tv->domain()->domain().end()}); |
16 | for (auto split : ir_utils::filterByType<Split>(exprs)) { |
17 | // Only needs to check root domains as partial split is only |
18 | // allowed with root domains |
19 | if (std::find( |
20 | tv->getRootDomain().begin(), |
21 | tv->getRootDomain().end(), |
22 | split->in()) == tv->getRootDomain().end()) { |
23 | continue; |
24 | } |
25 | auto root_domain = split->in(); |
26 | auto start_offset = split->startOffset(); |
27 | start_offset_map_.insert({root_domain, start_offset}); |
28 | auto stop_offset = split->stopOffset(); |
29 | stop_offset_map_.insert({root_domain, stop_offset}); |
30 | } |
31 | } |
32 | } |
33 | |
34 | Val* PartialSplitMap::getStartOffset(IterDomain* root_domain) const { |
35 | auto it = start_offset_map_.find(root_domain); |
36 | if (it == start_offset_map_.end()) { |
37 | return nullptr; |
38 | } else { |
39 | return it->second; |
40 | } |
41 | } |
42 | |
43 | Val* PartialSplitMap::getStopOffset(IterDomain* root_domain) const { |
44 | auto it = stop_offset_map_.find(root_domain); |
45 | if (it == stop_offset_map_.end()) { |
46 | return nullptr; |
47 | } else { |
48 | return it->second; |
49 | } |
50 | } |
51 | |
52 | } // namespace cuda |
53 | } // namespace fuser |
54 | } // namespace jit |
55 | } // namespace torch |
56 |