1#pragma once
2
#include <c10/macros/Export.h>

#include <dispatch.h>
#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <unordered_map>
#include <vector>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
16//! Collects start and stop offsets of all split root domains. Offsets
17//! are zero unless partially split.
18class TORCH_CUDA_CU_API PartialSplitMap {
19 public:
20 void build(Fusion* fusion);
21
22 Val* getStartOffset(IterDomain* root_domain) const;
23 Val* getStopOffset(IterDomain* root_domain) const;
24
25 private:
26 std::unordered_map<IterDomain*, Val*> start_offset_map_;
27 std::unordered_map<IterDomain*, Val*> stop_offset_map_;
28};
29
30} // namespace cuda
31} // namespace fuser
32} // namespace jit
33} // namespace torch
34