1#pragma once
2
#include <c10/macros/Export.h>

#include <dispatch.h>
#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <array>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
16class LoopIndexing;
17
//! Halo description of a single axis: how far the axis is extended at
//! each of its two ends.
class AxisHaloInfo {
 public:
  //! Halo width of one side of the axis.
  //!
  //! pos must be 0 or 1: pos == 0 selects the side at offset zero,
  //! pos == 1 selects the opposite end of the axis.
  int width(int pos) const;

  //! Total halo width of the axis, i.e., the widths of both sides added
  //! together.
  int width() const;

  //! Per-side halo widths. Index 0 is the side at offset zero, index 1
  //! the opposite end.
  const std::array<int, 2>& widths() const {
    return widths_;
  }

  //! Set the halo width of one side of the axis.
  //!
  //! pos must be 0 or 1: pos == 0 sets the width at offset zero,
  //! pos == 1 sets the width at the opposite end.
  void setWidth(int pos, int width);

  //! Extend the halo width of side pos to account for the width other
  //! coming from another axis.
  void merge(int pos, int other);

  //! Extend the halo widths to account for another axis.
  void merge(const AxisHaloInfo& other);

  //! True when halo may be attached to the axis.
  bool hasHalo() const;

  //! String representation of this halo information.
  std::string toString() const;

 private:
  //! Halo sizes at the two ends of the axis. widths_[0] is the halo at
  //! offset zero and widths_[1] the halo at the other end; a non-zero
  //! entry means the axis has halo on that side. Both entries are zero
  //! for an axis without halo.
  std::array<int, 2> widths_{0, 0};
};
58
59//! Helper class for lowering tensors with halo. Only valid at the
60//! lowering time.
61class TORCH_CUDA_CU_API HaloInfo {
62 public:
63 //! Scan a fusion and collect all information for lowering
64 HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map);
65
66 //! Almost exact duplicate of build(TensorDomain* td), except that
67 //! the traversal was done on loop indexing expressions.
68 std::unordered_map<IterDomain*, Val*> buildConcreteHaloExtentMap(
69 const LoopIndexing& loop_indexing) const;
70
71 //! Returns true if id has the root halo information set by
72 //! setRootAxisInfo.
73 bool hasRootAxisInfo(IterDomain* id) const;
74
75 //! Returns the registed AxisHaloInfo of a root axis.
76 //!
77 //! This is only for root axes. It is an error to query with
78 //! non-root axes.
79 const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const;
80
81 //! Query if an axis has a halo width.
82 //!
83 //! See the comment at halo_width_map_.
84 bool hasHaloWidth(IterDomain* id) const;
85
86 //! Return the halo width of an axis.
87 //!
88 //! It's an error if queried for an axis with no halo width
89 //! information.
90 int getHaloWidth(IterDomain* id) const;
91
92 //! Returns an extent if id is extended for halo. Nullptr is
93 //! returned otherwise.
94 Val* getExtent(IterDomain* id) const;
95
96 //! Returns all child domains of a root domain that inherits the
97 //! halo of the root domain.
98 //!
99 //! If a root domain is split, only the inner domain inherits the
100 //! halo, so the inner domain is included but not the outer domain.
101 const std::unordered_set<IterDomain*>& getChildDomains(
102 IterDomain* root_id) const;
103
104 //! Returns all root domains from which the halo of a domain
105 //! originates.
106 std::unordered_set<IterDomain*> getRootDomains(IterDomain* id) const;
107
108 //! Returns true if a domain inherits halo associated with a root
109 //! domain.
110 bool isHaloInherited(IterDomain* root_id, IterDomain* id) const;
111
112 // True when the extent of id1 is guaranteed to be lesser than or
113 // equal to id2. False when it *may* not.
114 bool extentLessEqual(IterDomain* id1, IterDomain* id2) const;
115 // True when the extent of id1 is guaranteed to be equal to
116 // id2. False when it *may* not.
117 bool extentEqual(IterDomain* id1, IterDomain* id2) const;
118
119 //! Check if expr must be predicated based on boundary conditions
120 //! directly or indirectly induced by shift expressions.
121 //!
122 //! When yes, the expression needs two predications: one for
123 //! interior and another for padding. Predicate insertion is done in
124 //! the ShiftPredicateInserter class below.
125 bool needsShiftPredicate(Expr* expr) const;
126
127 std::string toString() const;
128
129 private:
130 //! Build mappings of extent information of a TensorDomain
131 void build(TensorDomain* td);
132
133 //! Propagate root axis information from outputs to inputs of an
134 //! expression
135 void propagateRootAxisInfo(Expr* expr);
136
137 //! Set initial AxisHaloInfo of a root axis
138 //!
139 //! The axis does not need to be a root domain in the case of
140 //! reference tensors. Reference tensors get halo information from
141 //! consumer root domains, which may correspond to rfactor domains
142 //! of tensors from which reference tensors are derived.
143 void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info);
144
145 //! Adds a domain to the halo inheritance map.
146 //!
147 //! A domain, child, is added to the same set as domain parent. Both
148 //! domains must be part of TensorDomain td.
149 void insertToInheritanceMap(
150 TensorDomain* td,
151 IterDomain* parent,
152 IterDomain* child);
153
154 //! Propagate root axis information from consumer to producer
155 void propagateRootAxisInfo(
156 TensorView* producer,
157 TensorView* consumer,
158 Expr* expr);
159
160 //! Initialize mappings for a given root domain. The given domain
161 //! must be previously given to setRootAxisInfo.
162 void initializeFromRootAxisInfo(IterDomain* id);
163
164 //! Validate shift usage
165 void validate(TensorView* td, std::shared_ptr<const ComputeAtMap> ca_map)
166 const;
167
168 void setHaloWidth(IterDomain* id, int halo_width);
169
170 private:
171 // Copy the permissive map from the passed in compute at map
172 const DisjointSets<IterDomain*> permissive_map_;
173
174 //! Halo information of root axes
175 std::unordered_map<IterDomain*, AxisHaloInfo> root_axis_map_;
176
177 //! Halo-extended extents. No mapping for axes without halo extension
178 std::unordered_map<IterDomain*, Val*> extent_map_;
179
180 //! The halo width of an axis.
181 //!
182 //! The mapped value is a sum of two widths of both sizes of an
183 //! axis. For root axes, it is equivalent to AxisHaloInfo.widths_[0]
184 //! + AxisHaloInfo.widths_[1] (or AxisHaloInfo.width()). For
185 //! example, when a root axis is extended by 1 for both sides, it'd
186 //! be mapped to 2. For axes with no halo, they are mapped to zero.
187 //!
188 //! When an axis is split, its halo is only propagated to the inner
189 //! output axis, so the value of this map for the inner output is
190 //! the same as the input of split, while the outer output is mapped
191 //! to zero.
192 //!
193 //! When an axis is merged, no mapping is created for its
194 //! output at this point primarly because it isn't clear what the
195 //! "halo width" for a merged axis should mean. Perhaps, a merged
196 //! axis of (N+a)*(M+b), where N and M correspond to the original
197 //! extens of two axes, and a and b correspond to their halo widths,
198 //! it might make sense to set the halo width of this merged axis as
199 //! (N+a)*(M+b)-N*M. Currently, however, this isn't necessary, so no
200 //! particular mapping is created for merged axes.
201 //!
202 //! This is currently used only for conservatively comparing the
203 //! overall extents of axes. See HaloInfo::extentLessEqual and
204 //! HaloInfo::extentEqual.
205 //!
206 //! Example: Suppose a root axis has {0, 1} of
207 //! AxisHaloInfo.widths_. The root axis is mapped to 1. When it is
208 //! split, say, by 4, the output axes, [N / 4] and [4], where N is
209 //! the extent of the root axis, the outer axis is mapped to 0,
210 //! whereas the inner axis is mapped to 1. Further, suppose the
211 //! inner axis is merged with another axis of extent M, we know that
212 //! the extent of the resulting output axis is 5*M, but we don't
213 //! create its mapping.
214 std::unordered_map<IterDomain*, int> halo_width_map_;
215
216 //! Mappings from root domains to child domains that inherit halo
217 std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
218 inheritance_map_;
219};
220
221class ShiftPredicateInserter {
222 public:
223 //! Works mostly the same way as
224 //! PredicateCompute::getInlinePredicate but does the insertion of
225 //! the generated predicate. The branch structure is different from
226 //! the usual predicated expression, so the insertion is also done
227 //! here.
228 static Expr* insert(
229 Expr* expr,
230 const std::vector<kir::ForLoop*>& loops,
231 Bool* thread_pred,
232 bool within_unswitch);
233};
234
235} // namespace cuda
236} // namespace fuser
237} // namespace jit
238} // namespace torch
239