#pragma once

#include <ir_all_nodes.h>
#include <type.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace ir_utils {

// Replace values in fusion using ValReplacementMutator
void replaceValue(
    Fusion*,
    const std::unordered_map<Val*, Val*>& replacement_map);
template <typename FilterType, typename Iterator>
class FilterIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using value_type = FilterType*;
  using pointer = value_type*;
  using reference = value_type&;

  FilterIterator(Iterator begin, Iterator end) : current_(begin), end_(end) {
    advance();
  }

  FilterType* operator*() const {
    return (*current_)->template as<FilterType>();
  }

  FilterType* operator->() const {
    return **this;
  }

  FilterIterator& operator++() {
    ++current_;
    advance();
    return *this;
  }

  FilterIterator operator++(int) {
    const auto before_increment = *this;
    ++current_;
    advance();
    return before_increment;
  }

  bool operator==(const FilterIterator& other) const {
    TORCH_INTERNAL_ASSERT(
        end_ == other.end_,
        "Comparing two FilteredViews that originate from different containers");
    return current_ == other.current_;
  }

  bool operator!=(const FilterIterator& other) const {
    return !(*this == other);
  }

 private:
  // Skip ahead to the next element that is of FilterType, or stop at end_.
  void advance() {
    current_ = std::find_if(current_, end_, [](const auto& val) {
      return dynamic_cast<const FilterType*>(val) != nullptr;
    });
  }

 private:
  Iterator current_;
  Iterator end_;
};

// An iterable view to a given container of Val pointers. Only returns
// Vals of a given Val type.
// NOTE: Add a non-const iterator if needed.
template <typename FilterType, typename InputIt>
class FilteredView {
 public:
  using value_type = FilterType*;
  using const_iterator = FilterIterator<FilterType, InputIt>;

  FilteredView(InputIt first, InputIt last) : input_it_(first), last_(last) {}

  const_iterator cbegin() const {
    return const_iterator(input_it_, last_);
  }

  const_iterator begin() const {
    return cbegin();
  }

  const_iterator cend() const {
    return const_iterator(last_, last_);
  }

  const_iterator end() const {
    return cend();
  }

  bool empty() const {
    return begin() == end();
  }

  std::vector<value_type> vector() const {
    return std::vector<value_type>(begin(), end());
  }

 private:
  const InputIt input_it_;
  const InputIt last_;
};

template <typename FilterType, typename InputIt>
auto filterByType(InputIt first, InputIt last) {
  return FilteredView<FilterType, InputIt>(first, last);
}

template <typename FilterType, typename ContainerType>
auto filterByType(const ContainerType&& inputs) = delete;

template <typename FilterType, typename ContainerType>
auto filterByType(const ContainerType& inputs) {
  return filterByType<FilterType>(inputs.cbegin(), inputs.cend());
}
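
// Example (illustrative sketch, not part of the API): filter the TensorView
// inputs of an Expr. The names `expr` and `tv` below are hypothetical.
//
//   for (TensorView* tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
//     // Only the inputs that are actually TensorViews are visited here.
//   }
//
// Note that filterByType deliberately rejects temporary containers (the
// deleted rvalue overload above), since the view only stores iterators into
// the given container.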

//! Returns a list of new-to-old mappings.
//!
//! This function canonicalizes the dimensions and validates that no two old
//! dimensions are mapped to the same new dimension.
std::vector<int64_t> normalizeNew2Old(
    const std::vector<int64_t>& new2old_in,
    size_t ndims);

//! Returns a list of new-to-old mappings.
//!
//! The input map does not need to be complete. Missing axes are
//! assumed not to be affected.
//!
//! This is used to preprocess broadcast and transpose arguments.
//!
//! Example: (N := ndims)
//!   {{0, 1}} -> [1, 0, ...., N-1]
//!   Transposes the first two axes with no other change.
//!
//!   {{0, -1}} -> [N-1, ...., 0]
//!   Swaps the first and last axes.
std::vector<int> normalizeOld2New(
    const std::unordered_map<int, int>& old2new_in,
    size_t ndims);

// Replace all uses of reference with substitute in expr. Return the new Expr.
// Warning: Invalidates provided Expr.
// Warning: Removes connection of reference through provided Expr.
// Warning: Creates new Expr connecting substitute.
// Reference is found through direct pointer comparison.
Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute);
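
// Example (illustrative sketch; `expr`, `old_in`, and `new_in` are
// hypothetical):
//
//   Expr* new_expr = ir_utils::replaceValInExpr(expr, old_in, new_in);
//   // `expr` is invalidated by the call; use `new_expr` from here on.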

//! Replace Vals in an index Val as specified by replacement_map while
//! cloning the given index Val. The index Val is assumed to represent
//! a tensor index consisting of Ints and arithmetic expressions.
//!
//! This is similar to replaceValInExpr but differs in that the Vals are
//! cloned such that no other exprs using the same leaf Vals are
//! modified. TODO: Consider cleaning up the multiple replacement
//! routines.
Val* replaceValInIndexVal(
    Val* index,
    const std::unordered_map<Val*, Val*>& replacement_map);
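
// Example (illustrative sketch; `index`, `old_ext`, and `new_ext` are
// hypothetical Vals):
//
//   std::unordered_map<Val*, Val*> replacement_map{{old_ext, new_ext}};
//   Val* new_index = ir_utils::replaceValInIndexVal(index, replacement_map);
//
// Unlike replaceValInExpr, the expressions making up `index` are cloned, so
// other users of the original index expression are left untouched.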

// Makes rfactor generic with reduction ops and Welford
TORCH_CUDA_CU_API TensorView* rfactorHelper(
    TensorView* red_tv,
    const std::vector<int>& axes);
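
// Example (illustrative sketch; `sum_tv` is a hypothetical TensorView defined
// by a reduction or Welford op, with a reduction axis at position 1):
//
//   TensorView* rf_tv = ir_utils::rfactorHelper(sum_tv, {1});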

// Return immediate producers of val. This function can be used on any Val and
// will return producers through Exprs.
//
// Warning: returned Vals are not guaranteed to be between fusion inputs and
// outputs. This function simply uses val->definition() or val->uses(), which
// does not traverse past fusion inputs/outputs, so if val is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(Val* val);

// Return immediate consumers of val. This function can be used on any Val and
// will return consumers through Exprs.
//
// Warning: returned Vals are not guaranteed to be between fusion inputs and
// outputs. This function simply uses val->definition() or val->uses(), which
// does not traverse past fusion inputs/outputs, so if val is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(Val* val);

// Return immediate siblings of val. This function can be used on any Val and
// will return siblings through Exprs.
//
// Warning: returned Vals are not guaranteed to be between fusion inputs and
// outputs. This function simply uses val->definition() or val->uses(), which
// does not traverse past fusion inputs/outputs, so if val is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<Val*> siblingValsOf(Val* val);

// Return immediate producers of vals. This function can be used on any Vals
// and will return producers through Exprs.
//
// Warning: returned Vals are not guaranteed to be between fusion inputs and
// outputs. This function simply uses val->definition() or val->uses(), which
// does not traverse past fusion inputs/outputs, so if a val is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(
    const std::vector<Val*>& vals);

// Return immediate consumers of vals. This function can be used on any Vals
// and will return consumers through Exprs.
//
// Warning: returned Vals are not guaranteed to be between fusion inputs and
// outputs. This function simply uses val->definition() or val->uses(), which
// does not traverse past fusion inputs/outputs, so if a val is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(
    const std::vector<Val*>& vals);
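
// Example (illustrative sketch; `val` is a hypothetical Val in the fusion):
//
//   // Immediate definitions feeding `val` and immediate users of `val`.
//   std::vector<Val*> producers = ir_utils::producerValsOf(val);
//   std::vector<Val*> consumers = ir_utils::consumerValsOf(val);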

// Return immediate producers of tv. This function will return all immediate
// producers of tv through Exprs.
//
// Warning: returned TensorViews are not guaranteed to be between fusion inputs
// and outputs. This function simply uses tv->definition() or tv->uses(), which
// does not traverse past fusion inputs/outputs, so if tv is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<TensorView*> producerTvsOf(TensorView* tv);

// Return immediate consumers of tv. This function will return all immediate
// consumers of tv through Exprs.
//
// Warning: returned TensorViews are not guaranteed to be between fusion inputs
// and outputs. This function simply uses tv->definition() or tv->uses(), which
// does not traverse past fusion inputs/outputs, so if tv is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<TensorView*> consumerTvsOf(TensorView* tv);

// Return immediate siblings of tv. This function will return all immediate
// siblings of tv through Exprs.
//
// Warning: returned TensorViews are not guaranteed to be between fusion inputs
// and outputs. This function simply uses tv->definition() or tv->uses(), which
// does not traverse past fusion inputs/outputs, so if tv is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<TensorView*> siblingTvsOf(TensorView* tv);

// Return immediate producers of tvs. This function will return all immediate
// producers of tvs through Exprs.
//
// Warning: returned TensorViews are not guaranteed to be between fusion inputs
// and outputs. This function simply uses tv->definition() or tv->uses(), which
// does not traverse past fusion inputs/outputs, so if a tv is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<TensorView*> producerTvsOf(
    const std::vector<TensorView*>& tvs);

// Return immediate consumers of tvs. This function will return all immediate
// consumers of tvs through Exprs.
//
// Warning: returned TensorViews are not guaranteed to be between fusion inputs
// and outputs. This function simply uses tv->definition() or tv->uses(), which
// does not traverse past fusion inputs/outputs, so if a tv is on a path that
// isn't strictly between fusion inputs/outputs, it could effectively return
// dead code.
TORCH_CUDA_CU_API std::vector<TensorView*> consumerTvsOf(
    const std::vector<TensorView*>& tvs);
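
// Example (illustrative sketch; `tv` is a hypothetical TensorView):
//
//   // One step up and one step down the TensorView graph around `tv`.
//   std::vector<TensorView*> producer_tvs = ir_utils::producerTvsOf(tv);
//   std::vector<TensorView*> consumer_tvs = ir_utils::consumerTvsOf(tv);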

// Returns producers of tv that are inputs of fusion
TORCH_CUDA_CU_API std::vector<TensorView*> inputTvsOf(TensorView* tv);

// Returns consumers of tv that are outputs of fusion
TORCH_CUDA_CU_API std::vector<TensorView*> outputTvsOf(TensorView* tv);

// Returns producers of tvs that are inputs of fusion
TORCH_CUDA_CU_API std::vector<TensorView*> inputTvsOf(
    std::vector<TensorView*> tvs);

// Returns consumers of tvs that are outputs of fusion
TORCH_CUDA_CU_API std::vector<TensorView*> outputTvsOf(
    std::vector<TensorView*> tvs);

// Returns all tensor views in fusion that are used between outputs and inputs.
TORCH_CUDA_CU_API std::vector<TensorView*> allTvs(Fusion* fusion);

// Returns all tensor views in fusion that are used between outputs and inputs,
// except those in the specified set.
TORCH_CUDA_CU_API std::vector<TensorView*> allTvsExcept(
    Fusion* fusion,
    const std::unordered_set<TensorView*>& except);
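
// Example (illustrative sketch; `fusion` and `tv` are hypothetical):
//
//   // Fusion-boundary TensorViews reachable from `tv`.
//   std::vector<TensorView*> in_tvs = ir_utils::inputTvsOf(tv);
//   std::vector<TensorView*> out_tvs = ir_utils::outputTvsOf(tv);
//
//   // Every TensorView used between fusion inputs and outputs.
//   std::vector<TensorView*> tvs = ir_utils::allTvs(fusion);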

TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
    Fusion* fusion,
    bool ignore_trivial = true);

// Returns the initialization value of tv or nullptr if not initialized.
TORCH_CUDA_CU_API Val* getReductionInitValOf(TensorView* tv);

// Returns true if Expr is a reduction op
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);

// Returns true if Expr is a reduction op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);

// Returns all non-trivial view operations. We shouldn't have trivial view
// operations, but this function simply makes sure that if we ever do, we don't
// pull them in.
TORCH_CUDA_CU_API std::vector<ViewOp*> getViewOps(Fusion*);
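
// Example (illustrative sketch; `fusion` is a hypothetical Fusion*):
//
//   // Inspect the reduction and view expressions of a fusion.
//   for (Expr* expr : ir_utils::getReductionOps(fusion)) {
//     TORCH_INTERNAL_ASSERT(ir_utils::isReductionOp(expr));
//   }
//   std::vector<ViewOp*> view_ops = ir_utils::getViewOps(fusion);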

// Returns a comma-separated string of the given nodes' toString() outputs.
template <typename T>
std::string toString(const T& nodes) {
  std::stringstream ss;
  for (const Statement* stmt : nodes) {
    if (ss.tellp() != 0) {
      ss << ", ";
    }
    ss << stmt->toString();
  }
  return ss.str();
}
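
// Example (illustrative sketch; `tvs` is a hypothetical container of
// Statement-derived pointers, e.g. the result of allTvs(fusion)):
//
//   std::string str = ir_utils::toString(tvs);
//   // `str` holds the elements' string forms separated by ", ".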

} // namespace ir_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch