1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <compute_at_map.h> |
6 | #include <fusion.h> |
7 | #include <ir_all_nodes.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | // Looks through all transformations associated with view, or enforced divisible |
15 | // vectorization splits, and gathers all splits that provably don't have a |
16 | // remainder; therefore the extents of the associated IterDomains do not require |
17 | // a ceilDiv expression. Returns the gathered Split expressions as a set. |
18 | TORCH_CUDA_CU_API std::unordered_set<Split*> getAllDivisibleSplits( |
19 | Fusion* fusion); |
20 | |
21 | // Same as the overload above, but uses the provided ComputeAtMap instead of building its own. |
22 | TORCH_CUDA_CU_API std::unordered_set<Split*> getAllDivisibleSplits( |
23 | Fusion* fusion, |
24 | const ComputeAtMap* ca_map); |
25 | |
26 | } // namespace cuda |
27 | } // namespace fuser |
28 | } // namespace jit |
29 | } // namespace torch |
30 | |