1 | #pragma once |
2 | |
#include <ir_interface_nodes.h>
#include <ir_utils.h>

#include <iostream>
#include <memory>
#include <unordered_set>
#include <vector>
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | namespace fuser { |
9 | namespace cuda { |
10 | |
11 | /* |
12 | * MaxInfoSpanningTree is class that generates a path to visit TensorViews in a |
13 | * DAG. The generated path is a maximum spanning tree of the DAG with the root |
14 | * at the reference tensor and the DAG traversal path that preserves the maximum |
15 | * amount of the given information (evaluated by root domain mapping). The |
16 | * spanning tree is generated using the Prim's algorithm. |
17 | * |
18 | * This class only generates ordered paths, it does not have any knowledge about |
19 | * what how or what information to propagate along these paths. In order to do a |
20 | * propagation along the generated path, you need to subclass |
21 | * MaxInfoSpanningTree::Propagator and do path.traverse(propagator); |
22 | * |
23 | * This class allows specifying the section of a TV graph to generate the |
24 | * maximum spanning on. To do this, subclass MaxInfoSpanningTree::Selector and |
25 | * pass it as an argument to the constructor of this class. |
26 | * |
27 | * MaxInfoSpanningTree is an abstract class that has no idea about what |
28 | * "information" means. In order to use this class, you needs to subclass |
29 | * MaxInfoSpanningTree::Information and implement `operator<` which is used to |
30 | * tell which path contains more information, and `operator bool` which is used |
31 | * to tell if there is any information stored. You also need to implement |
32 | * computeInfoPasC, computeInfoCasP, and computeInfoSibling, which are the |
33 | * functions that compute information of the `to` tensor from the information of |
34 | * the `from` tensor. |
35 | */ |
36 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
37 | class TORCH_CUDA_CU_API MaxInfoSpanningTree { |
38 | public: |
  // Class to subclass in order to stop traversal at certain tensors, thereby
  // limiting the nodes in the spanning tree.
41 | struct Selector { |
42 | virtual bool allowC2P(TensorView* from, TensorView* to) = 0; |
43 | virtual bool allowP2C(TensorView* from, TensorView* to) = 0; |
44 | virtual bool allowSibling(TensorView* from, TensorView* to) = 0; |
45 | virtual ~Selector() {} |
46 | }; |
47 | |
48 | // This is the interface to implement the actual propagation |
49 | struct Propagator { |
50 | virtual void setUp() {} |
51 | virtual void tearDown() {} |
52 | virtual void propagateC2P(TensorView* from, TensorView* to) = 0; |
53 | virtual void propagateP2C(TensorView* from, TensorView* to) = 0; |
54 | virtual void propagateSibling(TensorView* from, TensorView* to) = 0; |
55 | virtual ~Propagator() {} |
56 | }; |
57 | |
58 | // This is the interface that specifies the structure of information used to |
59 | // determine if the maximum information is preserved. |
60 | protected: |
61 | struct Information { |
    // Returns true if there is any info about the root domain of the
    // reference tensor, and false otherwise.
65 | virtual operator bool() const = 0; |
66 | // l < r means l contains a smaller amount of information about the starting |
67 | // tensor than r. |
68 | virtual bool operator<(const Information& r) const = 0; |
    // l > r means l contains a larger amount of information about the starting
70 | // tensor than r. |
71 | bool operator>(const Information& r) const; |
    // l == r means it is hard to tell which one of them contains more
73 | // information |
74 | bool operator==(const Information& r) const; |
75 | // just to stop compiler warning |
76 | virtual ~Information() {} |
77 | }; |
78 | |
79 | private: |
80 | enum class NextHopType { |
81 | SIBLING, |
82 | C_AS_P, |
83 | P_AS_C, |
84 | }; |
85 | |
86 | // This is a helper struct that contains all the information about the next |
  // step in Prim's algorithm
88 | struct NextHop { |
89 | NextHopType type; |
90 | TensorView* from = nullptr; |
91 | TensorView* to; |
92 | |
93 | NextHop() = default; |
94 | NextHop(NextHopType type_, TensorView* from_, TensorView* to_) |
95 | : type(type_), from(from_), to(to_) {} |
96 | }; |
97 | |
98 | struct NextHopWithInfo { |
99 | NextHop next_hop; |
100 | std::shared_ptr<Information> info_from; |
101 | std::shared_ptr<Information> info_to; |
102 | |
103 | NextHopWithInfo() = default; |
104 | NextHopWithInfo( |
105 | NextHop n_h, |
106 | std::shared_ptr<Information> info_f, |
107 | std::shared_ptr<Information> info_t) |
108 | : next_hop(n_h), info_from(info_f), info_to(info_t) {} |
109 | |
110 | bool operator<(const NextHopWithInfo& r) const { |
111 | return *info_to < *(r.info_to); |
112 | } |
113 | }; |
114 | |
115 | std::vector<NextHop> path_; |
116 | TensorView* reference_; |
117 | std::shared_ptr<Information> reference_info_; |
118 | Selector* selector_; |
119 | |
120 | void compute_spanning_tree(); |
121 | |
122 | protected: |
123 | virtual std::shared_ptr<Information> computeInfoPasC( |
124 | TensorView* from, |
125 | TensorView* to, |
126 | std::shared_ptr<Information> from_info) const = 0; |
127 | virtual std::shared_ptr<Information> computeInfoCasP( |
128 | TensorView* from, |
129 | TensorView* to, |
130 | std::shared_ptr<Information> from_info) const = 0; |
131 | virtual std::shared_ptr<Information> computeInfoSibling( |
132 | TensorView* from, |
133 | TensorView* to, |
134 | std::shared_ptr<Information> from_info) const = 0; |
135 | |
136 | public: |
137 | MaxInfoSpanningTree( |
138 | TensorView* reference, |
139 | std::shared_ptr<Information> reference_info, |
140 | Selector* selector = nullptr); |
141 | void traverse(Propagator* propagator); |
142 | virtual ~MaxInfoSpanningTree() {} |
143 | }; |
144 | |
145 | // MaxRootDomainInfoSpanningTree is a subclass of MaxInfoSpanningTree which |
// generates the maximum spanning tree that preserves as much root domain
// information from the reference tensor as possible.
//
149 | // During the path-finding, we explicitly keep track of the information about |
150 | // which reference tensor's root ID's information is preserved, and to which |
151 | // level. This information is stored as a vector of `RootIDInfo`, where each |
152 | // item in the vector corresponds to one ID in the reference tensor's root |
153 | // domain. |
154 | class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree |
155 | : public MaxInfoSpanningTree { |
156 | protected: |
157 | // This is a struct storing how the information about a root ID in the |
  // starting tensor is preserved during path-finding. If, during path-finding,
  // we reach a tensor called the "current" tensor, we are interested in the
  // following information:
  // - Which reference tensor's root ID's information does the current tensor
  //   contain? Each RootIDInfo object should correspond to one reference
  //   tensor's root ID, but we don't need to store this ID explicitly.
164 | // - For this reference tensor's root ID, what are its corresponding IDs in |
165 | // the current tensor's root/rfactor domain? |
166 | // - Is the current tensor's information about this reference tensor's root ID |
167 | // complete? |
168 | struct RootIDInfo { |
    // Each object of this class corresponds to one root ID in the reference
170 | // tensor, but we do not need to explicitly store this ID. |
171 | |
    // The IDs in the current tensor's root or rfactor domain that contain
173 | // information of the corresponding reference tensor's root ID. Whether we |
174 | // are using root domain or rfactor domain depends on how we reached the |
175 | // current tensor during path-finding. `is_rfactor` tells us whether the IDs |
176 | // contained in `mapped_ids` are from the root domain or the rfactor domain. |
177 | std::unordered_set<IterDomain*> mapped_ids; |
178 | |
179 | // Does `mapped_ids` contain all the IDs required to recompute the |
180 | // corresponding reference tensor's root ID? For example, if we have |
181 | // t1 = input tensor of shape (20,) |
182 | // t2 = view(t1, {4, 5}) |
183 | // t3 = sum(t2, {1}) |
184 | // t4 = set(t3) |
185 | // and we start the path-finding from t1, then t2 and t3's information about |
186 | // t1 is complete, but t4 is not because one axis is missing. |
187 | bool is_complete; |
188 | |
189 | // Is `mapped_ids` from the root domain or rfactor domain of the current |
190 | // tensor? We only store IDs from one of them, depending on how we reach the |
191 | // current tensor during path-finding. If we reached the current tensor from |
    // a consumer, then `mapped_ids` contains IDs in the current tensor's
    // rfactor domain because the rfactor domain contains raw information. If we
    // reached the current tensor from a producer, then `mapped_ids` contains
195 | // IDs in the current tensor's root domain because the root domain contains |
196 | // raw information. |
197 | bool is_rfactor; |
198 | }; |
199 | |
200 | struct RootDomainInfo : public Information { |
201 | std::vector<RootIDInfo> info; |
202 | operator bool() const override; |
203 | bool operator<(const Information& r) const override; |
204 | }; |
205 | |
206 | virtual std::shared_ptr<Information> computeInfoPasC( |
207 | TensorView* from, |
208 | TensorView* to, |
209 | std::shared_ptr<Information> from_info) const override; |
210 | virtual std::shared_ptr<Information> computeInfoCasP( |
211 | TensorView* from, |
212 | TensorView* to, |
213 | std::shared_ptr<Information> from_info) const override; |
214 | virtual std::shared_ptr<Information> computeInfoSibling( |
215 | TensorView* from, |
216 | TensorView* to, |
217 | std::shared_ptr<Information> from_info) const override; |
218 | |
219 | private: |
220 | static std::shared_ptr<RootDomainInfo> getReferenceRootIDInfo(TensorView* tv); |
221 | static std::shared_ptr<RootDomainInfo> getReferenceRootIDInfo( |
222 | TensorView* tv, |
223 | int64_t leaf_pos); |
224 | |
225 | public: |
226 | MaxRootDomainInfoSpanningTree( |
227 | TensorView* reference, |
228 | std::shared_ptr<Information> reference_info, |
229 | Selector* selector = nullptr) |
230 | : MaxInfoSpanningTree(reference, reference_info, selector) {} |
231 | MaxRootDomainInfoSpanningTree( |
232 | TensorView* reference, |
233 | Selector* selector = nullptr) |
234 | : MaxRootDomainInfoSpanningTree( |
235 | reference, |
236 | getReferenceRootIDInfo(reference), |
237 | selector) {} |
238 | MaxRootDomainInfoSpanningTree( |
239 | TensorView* reference, |
240 | int64_t leaf_pos, |
241 | Selector* selector = nullptr) |
242 | : MaxRootDomainInfoSpanningTree( |
243 | reference, |
244 | getReferenceRootIDInfo(reference, leaf_pos), |
245 | selector) {} |
246 | }; |
247 | |
248 | class TORCH_CUDA_CU_API SpanningTreePrinter |
249 | : public MaxInfoSpanningTree::Propagator { |
250 | std::ostream& stream_; |
251 | |
252 | public: |
253 | virtual void propagateC2P(TensorView* from, TensorView* to) override; |
254 | virtual void propagateP2C(TensorView* from, TensorView* to) override; |
255 | virtual void propagateSibling(TensorView* from, TensorView* to) override; |
256 | SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {} |
257 | }; |
258 | |
// Simple selector for selecting subgraphs to build spanning trees. The
// selector allows propagation only to the given set of selected TensorViews,
// except for sibling propagation, which we should never block.
262 | class TORCH_CUDA_CU_API SetSelector : public MaxInfoSpanningTree::Selector { |
263 | std::unordered_set<TensorView*> selected_; |
264 | |
265 | public: |
266 | virtual bool allowC2P(TensorView* from, TensorView* to) override; |
267 | virtual bool allowP2C(TensorView* from, TensorView* to) override; |
268 | virtual bool allowSibling(TensorView* from, TensorView* to) override; |
269 | |
270 | SetSelector(std::unordered_set<TensorView*> selected) |
271 | : selected_(std::move(selected)) {} |
272 | |
273 | const std::unordered_set<TensorView*>& selected() const { |
274 | return selected_; |
275 | } |
276 | }; |
277 | |
278 | } // namespace cuda |
279 | } // namespace fuser |
280 | } // namespace jit |
281 | } // namespace torch |
282 | |