1 | #pragma once |
2 | |
#include <c10/macros/Export.h>

#include <dispatch.h>
#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <array>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | class LoopIndexing; |
17 | |
//! Auxiliary class to represent information about the halo of an axis.
//!
//! An axis may have halo on either side: side 0 is the side at offset
//! zero and side 1 is the other end of the axis (see widths_).
class AxisHaloInfo {
 public:
  //! Width of the halo on one side of the axis.
  //!
  //! pos is either 0 or 1. The width of the halo at offset zero is
  //! returned when pos is 0; the width of the halo at the other end of
  //! the axis is returned when pos is 1.
  int width(int pos) const;

  //! Sum of the halo widths of both sides, i.e., width(0) + width(1).
  int width() const;

  //! Halo widths of both sides, indexed by pos (see width(int)).
  const std::array<int, 2>& widths() const {
    return widths_;
  }

  //! Set the halo width of either side.
  //!
  //! pos is either 0 or 1. The width of the halo at offset zero is set
  //! when pos is 0.
  void setWidth(int pos, int width);

  //! Extend the halo width of side pos to account for another axis
  //! whose halo width on that side is other.
  //!
  //! NOTE(review): presumably the resulting width is the larger of the
  //! current width and other -- confirm against the definition.
  void merge(int pos, int other);

  //! Extend the halo widths to account for another axis.
  void merge(const AxisHaloInfo& other);

  //! True when halo may be attached to this axis.
  bool hasHalo() const;

  //! String representation of this object, presumably for debug
  //! printing.
  std::string toString() const;

 private:
  //! Sizes of the halo regions of the two sides. Both values are zero
  //! for axes with no halo. When an axis has halo at offset zero,
  //! widths_[0] is non-zero and designates the size of that halo.
  //! Similarly, a non-zero widths_[1] means the axis has halo at the
  //! other end of the axis.
  std::array<int, 2> widths_ = {0, 0};
};
58 | |
59 | //! Helper class for lowering tensors with halo. Only valid at the |
60 | //! lowering time. |
61 | class TORCH_CUDA_CU_API HaloInfo { |
62 | public: |
63 | //! Scan a fusion and collect all information for lowering |
64 | HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map); |
65 | |
66 | //! Almost exact duplicate of build(TensorDomain* td), except that |
67 | //! the traversal was done on loop indexing expressions. |
68 | std::unordered_map<IterDomain*, Val*> buildConcreteHaloExtentMap( |
69 | const LoopIndexing& loop_indexing) const; |
70 | |
71 | //! Returns true if id has the root halo information set by |
72 | //! setRootAxisInfo. |
73 | bool hasRootAxisInfo(IterDomain* id) const; |
74 | |
75 | //! Returns the registed AxisHaloInfo of a root axis. |
76 | //! |
77 | //! This is only for root axes. It is an error to query with |
78 | //! non-root axes. |
79 | const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const; |
80 | |
81 | //! Query if an axis has a halo width. |
82 | //! |
83 | //! See the comment at halo_width_map_. |
84 | bool hasHaloWidth(IterDomain* id) const; |
85 | |
86 | //! Return the halo width of an axis. |
87 | //! |
88 | //! It's an error if queried for an axis with no halo width |
89 | //! information. |
90 | int getHaloWidth(IterDomain* id) const; |
91 | |
92 | //! Returns an extent if id is extended for halo. Nullptr is |
93 | //! returned otherwise. |
94 | Val* getExtent(IterDomain* id) const; |
95 | |
96 | //! Returns all child domains of a root domain that inherits the |
97 | //! halo of the root domain. |
98 | //! |
99 | //! If a root domain is split, only the inner domain inherits the |
100 | //! halo, so the inner domain is included but not the outer domain. |
101 | const std::unordered_set<IterDomain*>& getChildDomains( |
102 | IterDomain* root_id) const; |
103 | |
104 | //! Returns all root domains from which the halo of a domain |
105 | //! originates. |
106 | std::unordered_set<IterDomain*> getRootDomains(IterDomain* id) const; |
107 | |
108 | //! Returns true if a domain inherits halo associated with a root |
109 | //! domain. |
110 | bool isHaloInherited(IterDomain* root_id, IterDomain* id) const; |
111 | |
112 | // True when the extent of id1 is guaranteed to be lesser than or |
113 | // equal to id2. False when it *may* not. |
114 | bool extentLessEqual(IterDomain* id1, IterDomain* id2) const; |
115 | // True when the extent of id1 is guaranteed to be equal to |
116 | // id2. False when it *may* not. |
117 | bool extentEqual(IterDomain* id1, IterDomain* id2) const; |
118 | |
119 | //! Check if expr must be predicated based on boundary conditions |
120 | //! directly or indirectly induced by shift expressions. |
121 | //! |
122 | //! When yes, the expression needs two predications: one for |
123 | //! interior and another for padding. Predicate insertion is done in |
124 | //! the ShiftPredicateInserter class below. |
125 | bool needsShiftPredicate(Expr* expr) const; |
126 | |
127 | std::string toString() const; |
128 | |
129 | private: |
130 | //! Build mappings of extent information of a TensorDomain |
131 | void build(TensorDomain* td); |
132 | |
133 | //! Propagate root axis information from outputs to inputs of an |
134 | //! expression |
135 | void propagateRootAxisInfo(Expr* expr); |
136 | |
137 | //! Set initial AxisHaloInfo of a root axis |
138 | //! |
139 | //! The axis does not need to be a root domain in the case of |
140 | //! reference tensors. Reference tensors get halo information from |
141 | //! consumer root domains, which may correspond to rfactor domains |
142 | //! of tensors from which reference tensors are derived. |
143 | void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info); |
144 | |
145 | //! Adds a domain to the halo inheritance map. |
146 | //! |
147 | //! A domain, child, is added to the same set as domain parent. Both |
148 | //! domains must be part of TensorDomain td. |
149 | void insertToInheritanceMap( |
150 | TensorDomain* td, |
151 | IterDomain* parent, |
152 | IterDomain* child); |
153 | |
154 | //! Propagate root axis information from consumer to producer |
155 | void propagateRootAxisInfo( |
156 | TensorView* producer, |
157 | TensorView* consumer, |
158 | Expr* expr); |
159 | |
160 | //! Initialize mappings for a given root domain. The given domain |
161 | //! must be previously given to setRootAxisInfo. |
162 | void initializeFromRootAxisInfo(IterDomain* id); |
163 | |
164 | //! Validate shift usage |
165 | void validate(TensorView* td, std::shared_ptr<const ComputeAtMap> ca_map) |
166 | const; |
167 | |
168 | void setHaloWidth(IterDomain* id, int halo_width); |
169 | |
170 | private: |
171 | // Copy the permissive map from the passed in compute at map |
172 | const DisjointSets<IterDomain*> permissive_map_; |
173 | |
174 | //! Halo information of root axes |
175 | std::unordered_map<IterDomain*, AxisHaloInfo> root_axis_map_; |
176 | |
177 | //! Halo-extended extents. No mapping for axes without halo extension |
178 | std::unordered_map<IterDomain*, Val*> extent_map_; |
179 | |
180 | //! The halo width of an axis. |
181 | //! |
182 | //! The mapped value is a sum of two widths of both sizes of an |
183 | //! axis. For root axes, it is equivalent to AxisHaloInfo.widths_[0] |
184 | //! + AxisHaloInfo.widths_[1] (or AxisHaloInfo.width()). For |
185 | //! example, when a root axis is extended by 1 for both sides, it'd |
186 | //! be mapped to 2. For axes with no halo, they are mapped to zero. |
187 | //! |
188 | //! When an axis is split, its halo is only propagated to the inner |
189 | //! output axis, so the value of this map for the inner output is |
190 | //! the same as the input of split, while the outer output is mapped |
191 | //! to zero. |
192 | //! |
193 | //! When an axis is merged, no mapping is created for its |
194 | //! output at this point primarly because it isn't clear what the |
195 | //! "halo width" for a merged axis should mean. Perhaps, a merged |
196 | //! axis of (N+a)*(M+b), where N and M correspond to the original |
197 | //! extens of two axes, and a and b correspond to their halo widths, |
198 | //! it might make sense to set the halo width of this merged axis as |
199 | //! (N+a)*(M+b)-N*M. Currently, however, this isn't necessary, so no |
200 | //! particular mapping is created for merged axes. |
201 | //! |
202 | //! This is currently used only for conservatively comparing the |
203 | //! overall extents of axes. See HaloInfo::extentLessEqual and |
204 | //! HaloInfo::extentEqual. |
205 | //! |
206 | //! Example: Suppose a root axis has {0, 1} of |
207 | //! AxisHaloInfo.widths_. The root axis is mapped to 1. When it is |
208 | //! split, say, by 4, the output axes, [N / 4] and [4], where N is |
209 | //! the extent of the root axis, the outer axis is mapped to 0, |
210 | //! whereas the inner axis is mapped to 1. Further, suppose the |
211 | //! inner axis is merged with another axis of extent M, we know that |
212 | //! the extent of the resulting output axis is 5*M, but we don't |
213 | //! create its mapping. |
214 | std::unordered_map<IterDomain*, int> halo_width_map_; |
215 | |
216 | //! Mappings from root domains to child domains that inherit halo |
217 | std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>> |
218 | inheritance_map_; |
219 | }; |
220 | |
221 | class ShiftPredicateInserter { |
222 | public: |
223 | //! Works mostly the same way as |
224 | //! PredicateCompute::getInlinePredicate but does the insertion of |
225 | //! the generated predicate. The branch structure is different from |
226 | //! the usual predicated expression, so the insertion is also done |
227 | //! here. |
228 | static Expr* insert( |
229 | Expr* expr, |
230 | const std::vector<kir::ForLoop*>& loops, |
231 | Bool* thread_pred, |
232 | bool within_unswitch); |
233 | }; |
234 | |
235 | } // namespace cuda |
236 | } // namespace fuser |
237 | } // namespace jit |
238 | } // namespace torch |
239 | |