1#pragma once
2
#include <ir_interface_nodes.h>
#include <ir_utils.h>

#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_set>
#include <vector>
5
6namespace torch {
7namespace jit {
8namespace fuser {
9namespace cuda {
10
/*
 * MaxInfoSpanningTree is a class that generates a path to visit TensorViews in
 * a DAG. The generated path is a maximum spanning tree of the DAG with the
 * root at the reference tensor and the DAG traversal path that preserves the
 * maximum amount of the given information (evaluated by root domain mapping).
 * The spanning tree is generated using Prim's algorithm.
 *
 * This class only generates ordered paths; it does not have any knowledge
 * about how or what information to propagate along these paths. In order to do
 * a propagation along the generated path, you need to subclass
 * MaxInfoSpanningTree::Propagator and do path.traverse(propagator);
 *
 * This class allows specifying the section of a TV graph to generate the
 * maximum spanning tree on. To do this, subclass MaxInfoSpanningTree::Selector
 * and pass it as an argument to the constructor of this class.
 *
 * MaxInfoSpanningTree is an abstract class that has no idea about what
 * "information" means. In order to use this class, you need to subclass
 * MaxInfoSpanningTree::Information and implement `operator<` which is used to
 * tell which path contains more information, and `operator bool` which is used
 * to tell if there is any information stored. You also need to implement
 * computeInfoPasC, computeInfoCasP, and computeInfoSibling, which are the
 * functions that compute information of the `to` tensor from the information
 * of the `from` tensor.
 */
36// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
37class TORCH_CUDA_CU_API MaxInfoSpanningTree {
38 public:
39 // Class to subclass in order to stop traversal, by which limits the nodes in
40 // the spanning tree.
41 struct Selector {
42 virtual bool allowC2P(TensorView* from, TensorView* to) = 0;
43 virtual bool allowP2C(TensorView* from, TensorView* to) = 0;
44 virtual bool allowSibling(TensorView* from, TensorView* to) = 0;
45 virtual ~Selector() {}
46 };
47
48 // This is the interface to implement the actual propagation
49 struct Propagator {
50 virtual void setUp() {}
51 virtual void tearDown() {}
52 virtual void propagateC2P(TensorView* from, TensorView* to) = 0;
53 virtual void propagateP2C(TensorView* from, TensorView* to) = 0;
54 virtual void propagateSibling(TensorView* from, TensorView* to) = 0;
55 virtual ~Propagator() {}
56 };
57
58 // This is the interface that specifies the structure of information used to
59 // determine if the maximum information is preserved.
60 protected:
61 struct Information {
62 // returns true if there is any info about the root domain of the reference
63 // tensor, returns false if there is no info about the root domain of the
64 // reference tensor.
65 virtual operator bool() const = 0;
66 // l < r means l contains a smaller amount of information about the starting
67 // tensor than r.
68 virtual bool operator<(const Information& r) const = 0;
69 // l > r means l contains a bigger amount of information about the starting
70 // tensor than r.
71 bool operator>(const Information& r) const;
72 // l == r means it is hard to tell which one of then contains more
73 // information
74 bool operator==(const Information& r) const;
75 // just to stop compiler warning
76 virtual ~Information() {}
77 };
78
79 private:
80 enum class NextHopType {
81 SIBLING,
82 C_AS_P,
83 P_AS_C,
84 };
85
86 // This is a helper struct that contains all the information about the next
87 // step in the Prim's algorithm
88 struct NextHop {
89 NextHopType type;
90 TensorView* from = nullptr;
91 TensorView* to;
92
93 NextHop() = default;
94 NextHop(NextHopType type_, TensorView* from_, TensorView* to_)
95 : type(type_), from(from_), to(to_) {}
96 };
97
98 struct NextHopWithInfo {
99 NextHop next_hop;
100 std::shared_ptr<Information> info_from;
101 std::shared_ptr<Information> info_to;
102
103 NextHopWithInfo() = default;
104 NextHopWithInfo(
105 NextHop n_h,
106 std::shared_ptr<Information> info_f,
107 std::shared_ptr<Information> info_t)
108 : next_hop(n_h), info_from(info_f), info_to(info_t) {}
109
110 bool operator<(const NextHopWithInfo& r) const {
111 return *info_to < *(r.info_to);
112 }
113 };
114
115 std::vector<NextHop> path_;
116 TensorView* reference_;
117 std::shared_ptr<Information> reference_info_;
118 Selector* selector_;
119
120 void compute_spanning_tree();
121
122 protected:
123 virtual std::shared_ptr<Information> computeInfoPasC(
124 TensorView* from,
125 TensorView* to,
126 std::shared_ptr<Information> from_info) const = 0;
127 virtual std::shared_ptr<Information> computeInfoCasP(
128 TensorView* from,
129 TensorView* to,
130 std::shared_ptr<Information> from_info) const = 0;
131 virtual std::shared_ptr<Information> computeInfoSibling(
132 TensorView* from,
133 TensorView* to,
134 std::shared_ptr<Information> from_info) const = 0;
135
136 public:
137 MaxInfoSpanningTree(
138 TensorView* reference,
139 std::shared_ptr<Information> reference_info,
140 Selector* selector = nullptr);
141 void traverse(Propagator* propagator);
142 virtual ~MaxInfoSpanningTree() {}
143};
144
// MaxRootDomainInfoSpanningTree is a subclass of MaxInfoSpanningTree which
// generates the maximum spanning tree that preserves the most amount of root
// domain information from the reference tensor.
//
// During the path-finding, we explicitly keep track of the information about
// which reference tensor's root ID's information is preserved, and to which
// level. This information is stored as a vector of `RootIDInfo`, where each
// item in the vector corresponds to one ID in the reference tensor's root
// domain.
class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree
    : public MaxInfoSpanningTree {
 protected:
  // This is a struct storing how the information about a root ID in the
  // starting tensor is preserved during path-finding. If during path-finding,
  // we reached a tensor called the "current" tensor, we are interested in the
  // following information:
  // - Which reference tensor's root ID's information does the current tensor
  //   contain? Each RootIDInfo object should correspond to one reference
  //   tensor's root ID, but we don't need to store this ID explicitly.
  // - For this reference tensor's root ID, what are its corresponding IDs in
  //   the current tensor's root/rfactor domain?
  // - Is the current tensor's information about this reference tensor's root
  //   ID complete?
  struct RootIDInfo {
    // Each object of this struct corresponds to one root ID in the reference
    // tensor, but we do not need to explicitly store this ID.

    // The IDs in the current tensor's root or rfactor domain that contain
    // information of the corresponding reference tensor's root ID. Whether we
    // are using the root domain or rfactor domain depends on how we reached
    // the current tensor during path-finding. `is_rfactor` tells us whether
    // the IDs contained in `mapped_ids` are from the root domain or the
    // rfactor domain.
    std::unordered_set<IterDomain*> mapped_ids;

    // Does `mapped_ids` contain all the IDs required to recompute the
    // corresponding reference tensor's root ID? For example, if we have
    //   t1 = input tensor of shape (20,)
    //   t2 = view(t1, {4, 5})
    //   t3 = sum(t2, {1})
    //   t4 = set(t3)
    // and we start the path-finding from t1, then t2 and t3's information
    // about t1 is complete, but t4 is not because one axis is missing.
    bool is_complete;

    // Is `mapped_ids` from the root domain or rfactor domain of the current
    // tensor? We only store IDs from one of them, depending on how we reach
    // the current tensor during path-finding. If we reached the current
    // tensor from a consumer, then `mapped_ids` contains IDs in the current
    // tensor's rfactor domain because the rfactor domain contains raw
    // information. If we reached the current tensor from a producer, then
    // `mapped_ids` contains IDs in the current tensor's root domain because
    // the root domain contains raw information.
    bool is_rfactor;
  };

  // Concrete Information: one RootIDInfo per root ID of the reference
  // tensor. `operator bool` / `operator<` are implemented in the .cpp.
  struct RootDomainInfo : public Information {
    std::vector<RootIDInfo> info;
    operator bool() const override;
    bool operator<(const Information& r) const override;
  };

  virtual std::shared_ptr<Information> computeInfoPasC(
      TensorView* from,
      TensorView* to,
      std::shared_ptr<Information> from_info) const override;
  virtual std::shared_ptr<Information> computeInfoCasP(
      TensorView* from,
      TensorView* to,
      std::shared_ptr<Information> from_info) const override;
  virtual std::shared_ptr<Information> computeInfoSibling(
      TensorView* from,
      TensorView* to,
      std::shared_ptr<Information> from_info) const override;

 private:
  // Build the reference tensor's own root-domain info (defined in the .cpp).
  // NOTE(review): the `leaf_pos` overload presumably restricts the info to a
  // subset of the reference's IDs determined by that leaf position -- confirm
  // against the definition.
  static std::shared_ptr<RootDomainInfo> getReferenceRootIDInfo(TensorView* tv);
  static std::shared_ptr<RootDomainInfo> getReferenceRootIDInfo(
      TensorView* tv,
      int64_t leaf_pos);

 public:
  MaxRootDomainInfoSpanningTree(
      TensorView* reference,
      std::shared_ptr<Information> reference_info,
      Selector* selector = nullptr)
      : MaxInfoSpanningTree(reference, reference_info, selector) {}
  // Convenience constructor: derives the reference info from the reference
  // tensor itself.
  MaxRootDomainInfoSpanningTree(
      TensorView* reference,
      Selector* selector = nullptr)
      : MaxRootDomainInfoSpanningTree(
            reference,
            getReferenceRootIDInfo(reference),
            selector) {}
  // Convenience constructor: derives the reference info from the reference
  // tensor at the given leaf position.
  MaxRootDomainInfoSpanningTree(
      TensorView* reference,
      int64_t leaf_pos,
      Selector* selector = nullptr)
      : MaxRootDomainInfoSpanningTree(
            reference,
            getReferenceRootIDInfo(reference, leaf_pos),
            selector) {}
};
247
// Propagator used to print a spanning tree's traversal to a stream (the
// propagate* hook bodies live in the .cpp); handy for debugging which path
// MaxInfoSpanningTree::traverse takes.
class TORCH_CUDA_CU_API SpanningTreePrinter
    : public MaxInfoSpanningTree::Propagator {
  // Output sink for the printed traversal; defaults to std::cout.
  std::ostream& stream_;

 public:
  virtual void propagateC2P(TensorView* from, TensorView* to) override;
  virtual void propagateP2C(TensorView* from, TensorView* to) override;
  virtual void propagateSibling(TensorView* from, TensorView* to) override;
  SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {}
};
258
// Simple selector for selecting subgraphs to build spanning trees. The
// selector allows propagation only to the given set of selected tensorviews,
// except for sibling propagation, which we should never block.
class TORCH_CUDA_CU_API SetSelector : public MaxInfoSpanningTree::Selector {
  // The set of TensorViews the selector restricts propagation to; exact
  // policy is in the allow* implementations in the .cpp.
  std::unordered_set<TensorView*> selected_;

 public:
  virtual bool allowC2P(TensorView* from, TensorView* to) override;
  virtual bool allowP2C(TensorView* from, TensorView* to) override;
  virtual bool allowSibling(TensorView* from, TensorView* to) override;

  // Takes the selected set by value and moves it into the member.
  SetSelector(std::unordered_set<TensorView*> selected)
      : selected_(std::move(selected)) {}

  // Read-only access to the selected set.
  const std::unordered_set<TensorView*>& selected() const {
    return selected_;
  }
};
277
278} // namespace cuda
279} // namespace fuser
280} // namespace jit
281} // namespace torch
282