#pragma once

#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <cstdint>
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>

9namespace torch {
10namespace jit {
11namespace fuser {
12namespace cuda {
13
14//! Maps TID/BID to its dimension. It is by default blockDim/gridDim,
15//! but if use of a ParallelType is mapped to a unique constant
16//! extent, the constant value is used instead since presumably it's
17//! more efficient.
18class TORCH_CUDA_CU_API ParallelDimensionMap {
19 public:
20 void build(Fusion* fusion);
21
22 //! Returns the dimension of a ParallelType. nullptr is returned if
23 //! a ParallelType is unused.
24 Val* get(ParallelType pt) const;
25
26 //! True if the dimension of a ParallelType is known to be exact
27 bool isExact(ParallelType pt) const;
28
29 std::string toString() const;
30
31 //! Symbolically analyze if two extent vals are equal
32 static bool equalDim(Val* dim1, Val* dim2);
33
34 private:
35 //! Register the extent of an IterDomain if its constant
36 void registerConstantExtent(IterDomain* id);
37
38 void handleParallelDomain(IterDomain* id);
39
40 void populateDimensionMapWithSingleCASet(
41 ParallelType pt,
42 const std::unordered_set<IterDomain*>& dom_set);
43
44 void populateDimensionMapWithMultipleCASet(
45 ParallelType pt,
46 const std::unordered_set<IterDomain*>& dom_set);
47
48 //! TIDx may need to be marked as non-exact as it may be padded to a
49 //! multiple of the warp size.
50 void adjustMappingsForWarpPadding();
51
52 static IterDomain* getCAMappedConcreteDomain(IterDomain* id);
53
54 private:
55 //! Maps from parallel types to dimensions, which are constant if
56 //! a unique value is found.
57 std::unordered_map<ParallelType, Val*, TypeHash> dim_map_;
58 //! Set of parallel types whose dimensions are identified to be
59 //! exactly the same as extents of mapped domains.
60 std::unordered_set<ParallelType, TypeHash> exact_types_;
61
62 // Below are temporary maps to build the ParallelType-to-dimension
63 // map. Only used during build().
64
65 //! Map from a parallel type to a set of concrete domains where the
66 //! parallel type is used.
67 std::unordered_map<ParallelType, std::unordered_set<IterDomain*>, TypeHash>
68 concrete_dom_map_;
69 //! Keep track of constant extents found for a CA domain set
70 //! represented by the concrete domain.
71 std::unordered_map<IterDomain*, std::unordered_set<int64_t>>
72 constant_extent_map_;
73};
74
75} // namespace cuda
76} // namespace fuser
77} // namespace jit
78} // namespace torch
79