1 | #pragma once |
2 | |
#include <ir_all_nodes.h>
#include <type.h>

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | namespace ir_utils { |
15 | |
16 | // Replace values in fusion using ValReplacementMutator |
17 | void replaceValue( |
18 | Fusion*, |
19 | const std::unordered_map<Val*, Val*>& replacement_map); |
20 | |
// Forward iterator over a range of polymorphic pointers that yields only
// the elements dynamically of type FilterType; all other elements are
// silently skipped.
//
// FilterType: the node subtype to filter for.
// Iterator:   the underlying iterator over pointers (e.g. over Val*).
template <typename FilterType, typename Iterator>
class FilterIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using value_type = FilterType*;
  using pointer = value_type*;
  using reference = value_type&;

  // Advances immediately so a freshly constructed iterator already points
  // at the first matching element (or at end if there is none).
  FilterIterator(Iterator begin, Iterator end) : current_(begin), end_(end) {
    advance();
  }

  // Returns the current element downcast to FilterType*.
  FilterType* operator*() const {
    return (*current_)->template as<FilterType>();
  }

  FilterType* operator->() const {
    // Bug fix: this previously did `return (*this);`, which attempts to
    // convert the iterator object itself to FilterType* and fails to
    // compile whenever operator-> is instantiated. Dereference instead.
    return **this;
  }

  FilterIterator& operator++() {
    ++current_;
    advance();
    return *this;
  }

  FilterIterator operator++(int) {
    const auto before_increment = *this;
    ++current_;
    advance();
    return before_increment;
  }

  bool operator==(const FilterIterator& other) const {
    // Comparing iterators from different underlying ranges is a usage bug.
    TORCH_INTERNAL_ASSERT(
        end_ == other.end_,
        "Comparing two FilteredViews that originate from different containers");
    return current_ == other.current_;
  }

  bool operator!=(const FilterIterator& other) const {
    return !(*this == other);
  }

 private:
  // Moves current_ forward to the next element that is a FilterType, or to
  // end_ if no matching element remains.
  void advance() {
    current_ = std::find_if(current_, end_, [](const auto& val) {
      return dynamic_cast<const FilterType*>(val) != nullptr;
    });
  }

 private:
  Iterator current_;
  Iterator end_;
};
77 | |
78 | // An iterable view to a given container of Val pointers. Only returns |
79 | // Vals of a given Val type. |
80 | // NOTE: Add a non-const iterator if needed. |
81 | template <typename FilterType, typename InputIt> |
82 | class FilteredView { |
83 | public: |
84 | using value_type = FilterType*; |
85 | using const_iterator = FilterIterator<FilterType, InputIt>; |
86 | |
87 | FilteredView(InputIt first, InputIt last) : input_it_(first), last_(last) {} |
88 | |
89 | const_iterator cbegin() const { |
90 | return const_iterator(input_it_, last_); |
91 | } |
92 | |
93 | const_iterator begin() const { |
94 | return cbegin(); |
95 | } |
96 | |
97 | const_iterator cend() const { |
98 | return const_iterator(last_, last_); |
99 | } |
100 | |
101 | const_iterator end() const { |
102 | return cend(); |
103 | } |
104 | |
105 | bool empty() const { |
106 | return begin() == end(); |
107 | } |
108 | |
109 | std::vector<value_type> vector() const { |
110 | return std::vector<value_type>(begin(), end()); |
111 | } |
112 | |
113 | private: |
114 | const InputIt input_it_; |
115 | const InputIt last_; |
116 | }; |
117 | |
118 | template <typename FilterType, typename InputIt> |
119 | auto filterByType(InputIt first, InputIt last) { |
120 | return FilteredView<FilterType, InputIt>(first, last); |
121 | } |
122 | |
// Deleted overload: a FilteredView only stores iterators into the input
// container, so binding to an rvalue (temporary) container would leave the
// view dangling as soon as the temporary is destroyed.
template <typename FilterType, typename ContainerType>
auto filterByType(const ContainerType&& inputs) = delete;
125 | |
// Filters an entire (lvalue) container by the requested Val type.
// The rvalue overload is deleted, so the view cannot outlive its container.
template <typename FilterType, typename ContainerType>
auto filterByType(const ContainerType& inputs) {
  auto first = inputs.cbegin();
  auto last = inputs.cend();
  return filterByType<FilterType>(first, last);
}
130 | |
131 | //! Returns a list of new-to-old mappings. |
132 | //! |
//! This function canonicalizes the dimensions and validates that multiple
//! old dimensions are not mapped to the same new dimension.
135 | std::vector<int64_t> normalizeNew2Old( |
136 | const std::vector<int64_t>& new2old_in, |
137 | size_t ndims); |
138 | |
139 | //! Returns a list of new-to-old mappings. |
140 | //! |
141 | //! The input map does not need to be complete. Missing axes are |
142 | //! assumed not to be affected. |
143 | //! |
144 | //! This is used to preprocess broadcast and transpose arguments. |
145 | //! |
146 | //! Example: (N := ndims) |
147 | //! {{0, 1}} -> [1, 0, ...., N-1] |
148 | //! Transposes the first two axes with no other change. |
149 | //! |
150 | //! {{0, -1}} -> [N-1, ...., 0] |
151 | //! Swaps the first and last axes. |
152 | std::vector<int> normalizeOld2New( |
153 | const std::unordered_map<int, int>& old2new_in, |
154 | size_t ndims); |
155 | |
156 | // Replace all uses of reference with substitute in expr. Return the Expr. |
157 | // Warning: Invalidates provided Expr. |
158 | // Warning: Removes connection of reference through provided Expr. |
// Warning: Creates new Expr connecting substitute.
160 | // Reference is found through direct pointer comparison. |
161 | Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute); |
162 | |
163 | //! Replace Vals in an index Val as specified by replacement_map while |
164 | //! cloning the given index Val. The index val is assumed to represent |
165 | //! a tensor index consisting of Ints and arithmetic expressions. |
166 | //! |
167 | //! This is similar to replaceValInExpr but is different as Vals are |
//! cloned such that no other exprs using the same leaf Vals are
169 | //! modified. TODO: Consider cleaning up the multiple replacement |
170 | //! routines. |
171 | Val* replaceValInIndexVal( |
172 | Val* index, |
173 | const std::unordered_map<Val*, Val*>& replacement_map); |
174 | |
175 | // Makes rfactor generic with reduction ops and Welford |
176 | TORCH_CUDA_CU_API TensorView* rfactorHelper( |
177 | TensorView* red_tv, |
178 | const std::vector<int>& axes); |
179 | |
180 | // Return immediate producers of val, this function can be used on any Val and |
181 | // will return producers through Exprs. |
182 | // |
183 | // Warning: returned val's are not guaranteed to be between fusion inputs and |
184 | // outputs. This function simply uses val->definition() or val->uses() which is |
185 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
186 | // strictly between fusion inputs/outputs, it could effectively return dead |
187 | // code. |
188 | TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(Val* val); |
189 | |
190 | // Return immediate consumers of val, this function can be used on any Val and |
191 | // will return consumers through Exprs. |
192 | // |
193 | // Warning: returned val's are not guaranteed to be between fusion inputs and |
194 | // outputs. This function simply uses val->definition() or val->uses() which is |
195 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
196 | // strictly between fusion inputs/outputs, it could effectively return dead |
197 | // code. |
198 | TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(Val* val); |
199 | |
200 | // Return immediate siblings of val, this function can be used on any Val and |
201 | // will return siblings through Exprs. |
202 | // |
203 | // Warning: returned val's are not guaranteed to be between fusion inputs and |
204 | // outputs. This function simply uses val->definition() or val->uses() which is |
205 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
206 | // strictly between fusion inputs/outputs, it could effectively return dead |
207 | // code. |
208 | TORCH_CUDA_CU_API std::vector<Val*> siblingValsOf(Val* val); |
209 | |
210 | // Return immediate producers of vals, this function can be used on any vals and |
211 | // will return producers through Exprs. |
212 | // |
213 | // Warning: returned val's are not guaranteed to be between fusion inputs and |
214 | // outputs. This function simply uses val->definition() or val->uses() which is |
215 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
216 | // strictly between fusion inputs/outputs, it could effectively return dead |
217 | // code. |
218 | TORCH_CUDA_CU_API std::vector<Val*> producerValsOf( |
219 | const std::vector<Val*>& vals); |
220 | |
221 | // Return immediate consumers of vals, this function can be used on any vals and |
222 | // will return consumers through Exprs. |
223 | // |
224 | // Warning: returned val's are not guaranteed to be between fusion inputs and |
225 | // outputs. This function simply uses val->definition() or val->uses() which is |
226 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
227 | // strictly between fusion inputs/outputs, it could effectively return dead |
228 | // code. |
229 | TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf( |
230 | const std::vector<Val*>& vals); |
231 | |
232 | // Return immediate producers of tv, this function will return all immediate |
233 | // producers of tv through Exprs. |
234 | // |
235 | // Warning: returned tv's are not guaranteed to be between fusion inputs and |
236 | // outputs. This function simply uses tv->definition() or tv->uses() which is |
237 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
238 | // strictly between fusion inputs/outputs, it could effectively return dead |
239 | // code. |
240 | TORCH_CUDA_CU_API std::vector<TensorView*> producerTvsOf(TensorView* tv); |
241 | |
242 | // Return immediate consumers of tv, this function will return all immediate |
243 | // consumers of tv through Exprs. |
244 | // |
245 | // Warning: returned tv's are not guaranteed to be between fusion inputs and |
246 | // outputs. This function simply uses tv->definition() or tv->uses() which is |
247 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
248 | // strictly between fusion inputs/outputs, it could effectively return dead |
249 | // code. |
250 | TORCH_CUDA_CU_API std::vector<TensorView*> consumerTvsOf(TensorView* tv); |
251 | |
252 | // Return immediate siblings of tv, this function will return all immediate |
253 | // siblings of tv through Exprs. |
254 | // |
255 | // Warning: returned tv's are not guaranteed to be between fusion inputs and |
256 | // outputs. This function simply uses tv->definition() or tv->uses() which is |
257 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
258 | // strictly between fusion inputs/outputs, it could effectively return dead |
259 | // code. |
260 | TORCH_CUDA_CU_API std::vector<TensorView*> siblingTvsOf(TensorView* tv); |
261 | |
262 | // Return immediate producers of tvs, this function will return all immediate |
263 | // producers of tvs through Exprs. |
264 | // |
265 | // Warning: returned tv's are not guaranteed to be between fusion inputs and |
266 | // outputs. This function simply uses tv->definition() or tv->uses() which is |
267 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
268 | // strictly between fusion inputs/outputs, it could effectively return dead |
269 | // code. |
270 | TORCH_CUDA_CU_API std::vector<TensorView*> producerTvsOf( |
271 | const std::vector<TensorView*>& tvs); |
272 | |
273 | // Return immediate consumers of tvs, this function will return all immediate |
274 | // consumers of tvs through Exprs. |
275 | // |
276 | // Warning: returned tv's are not guaranteed to be between fusion inputs and |
277 | // outputs. This function simply uses tv->definition() or tv->uses() which is |
278 | // limited to not go through fusion inputs/outputs, but if on a path that isn't |
279 | // strictly between fusion inputs/outputs, it could effectively return dead |
280 | // code. |
281 | TORCH_CUDA_CU_API std::vector<TensorView*> consumerTvsOf( |
282 | const std::vector<TensorView*>& tvs); |
283 | |
284 | // Returns producers of tv that are inputs of fusion |
285 | TORCH_CUDA_CU_API std::vector<TensorView*> inputTvsOf(TensorView* tv); |
286 | |
287 | // Returns consumers of tv that are outputs of fusion |
288 | TORCH_CUDA_CU_API std::vector<TensorView*> outputTvsOf(TensorView* tv); |
289 | |
290 | // Returns producers of tvs that are inputs of fusion |
291 | TORCH_CUDA_CU_API std::vector<TensorView*> inputTvsOf( |
292 | std::vector<TensorView*> tvs); |
293 | |
294 | // Returns consumers of tvs that are outputs of fusion |
295 | TORCH_CUDA_CU_API std::vector<TensorView*> outputTvsOf( |
296 | std::vector<TensorView*> tvs); |
297 | |
298 | // returns all tensor views in fusion that are used between outputs and inputs. |
299 | TORCH_CUDA_CU_API std::vector<TensorView*> allTvs(Fusion* fusion); |
300 | |
301 | // returns all tensor views in fusion that are used between outputs and inputs |
302 | // except the specified set. |
303 | TORCH_CUDA_CU_API std::vector<TensorView*> allTvsExcept( |
304 | Fusion* fusion, |
305 | const std::unordered_set<TensorView*>& except); |
306 | |
307 | TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps( |
308 | Fusion* fusion, |
309 | bool ignore_trivial = true); |
310 | |
311 | // Returns the initialization value of tv or nullptr if not initialized. |
312 | TORCH_CUDA_CU_API Val* getReductionInitValOf(TensorView* tv); |
313 | |
314 | // Returns if Expr is a reduction op |
315 | TORCH_CUDA_CU_API bool isReductionOp(const Expr*); |
316 | |
317 | // Returns if Expr is a reduction op with TensorView or TensorIndex |
318 | TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*); |
319 | |
320 | // Returns all non-trivial view operations. We shouldn't have trivial view |
321 | // operations but this function is to simply make sure if we ever do we don't |
322 | // pull them in. |
323 | TORCH_CUDA_CU_API std::vector<ViewOp*> getViewOps(Fusion*); |
324 | |
325 | template <typename T> |
326 | std::string toString(const T& nodes) { |
327 | std::stringstream ss; |
328 | for (const Statement* stmt : nodes) { |
329 | if (ss.tellp() != 0) { |
330 | ss << ", " ; |
331 | } |
332 | ss << stmt->toString(); |
333 | } |
334 | return ss.str(); |
335 | } |
336 | |
337 | } // namespace ir_utils |
338 | } // namespace cuda |
339 | } // namespace fuser |
340 | } // namespace jit |
341 | } // namespace torch |
342 | |