1 | #pragma once |
---|---|
2 | |
#include <c10/macros/Export.h>

#include <dispatch.h>
#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <unordered_map>
#include <vector>
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | //! Collects start and stop offsets of all split root domains. Offsets |
17 | //! are zero unless partially split. |
18 | class TORCH_CUDA_CU_API PartialSplitMap { |
19 | public: |
20 | void build(Fusion* fusion); |
21 | |
22 | Val* getStartOffset(IterDomain* root_domain) const; |
23 | Val* getStopOffset(IterDomain* root_domain) const; |
24 | |
25 | private: |
26 | std::unordered_map<IterDomain*, Val*> start_offset_map_; |
27 | std::unordered_map<IterDomain*, Val*> stop_offset_map_; |
28 | }; |
29 | |
30 | } // namespace cuda |
31 | } // namespace fuser |
32 | } // namespace jit |
33 | } // namespace torch |
34 |