1#pragma once
2
3#include <ir_all_nodes.h>
4
5#include <functional>
6#include <unordered_map>
7#include <vector>
8
// Hoisting common index subexpressions
//
// Class CommonIndexMap is updated during the lowering as new indices
// are inserted. An index is uniquely identified with CommonIndexKey,
// which consists of the concrete ID of the indexed/predicated domain,
// the for-loops used in the index, and the index vals of the used
// for-loops.
//
// Once all indices are inserted to CommonIndexMap, allocations of
// the hoisted indices are inserted by allocateCommonIndices. Note
// that this assumes that the CUDA code generator does not inline a
// scalar Val with allocation (PR #1434).
21
22namespace torch {
23namespace jit {
24namespace fuser {
25namespace cuda {
26
27//! Class to represent unique indexed domains for index
28//! hoisting. Uniquenesss is determined with the indexed domain
29//! itself, the for-loops and their index values.
30class TORCH_CUDA_CU_API CommonIndexKey {
31 friend struct CommonIndexKeyHash;
32
33 public:
34 //! \param consumer_indexed_id Indexed consumer domain
35 //! \param consumer_td TensorDomain of consumer_indexed_id
36 //! \param ref_td Reference domain at the time of indexing
37 //! \param ref_index_map Index map of the reference domain
38 //! \param loops Loop structure where this id is indexed
39 CommonIndexKey(
40 IterDomain* consumer_indexed_id,
41 TensorDomain* consumer_td,
42 TensorDomain* ref_td,
43 const std::unordered_map<IterDomain*, Val*>& ref_index_map,
44 const std::vector<kir::ForLoop*>& loops);
45
46 //! \param consumer_indexed_id Indexed consumer domain
47 //! \param consumer_td TensorDomain of consumer_indexed_id
48 //! \param loop_domains Resolved vector of iterdomain corresponding to loops
49 //! \param loop_index_map Index mapping generated from the loop nest.
50 //! \param loops Loop structure where this id is indexed
51 //! Duplicate of above, but without a reference domain. TODO: Remove other
52 //! implementation.
53 CommonIndexKey(
54 IterDomain* consumer_indexed_id,
55 TensorDomain* consumer_td,
56 const std::vector<IterDomain*>& loop_domains,
57 const std::unordered_map<IterDomain*, Val*>& loop_index_map,
58 const std::vector<kir::ForLoop*>& loops);
59
60 const IterDomain* concreteIndexedId() const {
61 return concrete_indexed_id_;
62 }
63
64 const std::vector<kir::ForLoop*>& usedLoops() const {
65 return used_loops_;
66 }
67
68 const std::vector<Val*>& loopIndexVals() const {
69 return loop_index_vals_;
70 }
71
72 bool operator==(const CommonIndexKey& other) const;
73
74 std::string toString() const;
75
76 private:
77 //! Concrete domain of indexed domain
78 IterDomain* concrete_indexed_id_ = nullptr;
79 //! Loops used for the index
80 std::vector<kir::ForLoop*> used_loops_;
81 //! Loop index vals for the used loops
82 std::vector<Val*> loop_index_vals_;
83};
84
85struct CommonIndexKeyHash {
86 std::size_t operator()(const CommonIndexKey& key) const {
87 auto h = std::hash<const IterDomain*>{}(key.concrete_indexed_id_);
88 // NOTE: do not use other fields as the pointers can be different
89 // even when two keys can share the same index
90 return h;
91 }
92};
93
94//! Map to hold hoisted common indices
95class TORCH_CUDA_CU_API CommonIndexMap {
96 public:
97 //! Register an indexd consumer domain to hoist
98 //!
99 //! Returns a corresponding hoisted index and a flag indicating if a
100 //! new index is inserted.
101 //!
102 //! Consumer domains are used even for producer indexing since
103 //! producer domains in producer indexing are temporary replay
104 //! domains.
105 std::pair<Val*, bool> insert(
106 IterDomain* indexed_consumer_id,
107 TensorDomain* consumer_td,
108 TensorDomain* ref_td,
109 const std::unordered_map<IterDomain*, Val*>& ref_index_map,
110 const std::vector<kir::ForLoop*>& loops,
111 Val* index);
112
113 //! Duplicate of above, but without a reference domain. TODO: Remove other
114 //! implementation.
115 std::pair<Val*, bool> insert(
116 IterDomain* indexed_consumer_id,
117 TensorDomain* consumer_td,
118 const std::vector<IterDomain*>& loop_domains,
119 const std::unordered_map<IterDomain*, Val*>& loop_index_map,
120 const std::vector<kir::ForLoop*>& loops,
121 Val* index);
122
123 const auto& commonIndexMap() const {
124 return common_index_map_;
125 }
126
127 const auto& useCounts() const {
128 return use_counts_;
129 }
130
131 private:
132 //! Utility method to insert a key into common index
133 //! map. Returns a pair of an IR node and a boolean value.
134 //! The IR node will be the previously inserted index if
135 //! the key found a match, or will be the original index
136 //! if this is new key and the key will be stored.
137 //! The boolean value will be true if the key is stored,
138 //! i.e. first time it is inserted.
139 std::pair<Val*, bool> tryInsertNewIndex(CommonIndexKey key, Val* index);
140
141 private:
142 //! Map to hold hoisted common indices
143 std::unordered_map<CommonIndexKey, Val*, CommonIndexKeyHash>
144 common_index_map_;
145 std::unordered_map<CommonIndexKey, int, CommonIndexKeyHash> use_counts_;
146};
147
148//! Insert allocations of hoisted indices. Must be called after
149//! collecting all common indices.
150std::vector<Expr*> allocateCommonIndices(const std::vector<Expr*>& exprs);
151
152} // namespace cuda
153} // namespace fuser
154} // namespace jit
155} // namespace torch
156