1#pragma once
2
3#include <fusion.h>
4#include <index_compute.h>
5
6namespace torch {
7namespace jit {
8namespace fuser {
9namespace cuda {
10
// Struct to hold useful information from an index pass on iterdomain graph.
// Used to return the IndexCompute structure back to the indexing calls in
// index_compute.cpp. Other structures are required to resolve the actual
// indexing math there.
struct IndexFromIdGraph {
  // Indexing math resolved on the tensor's own iterdomains.
  IndexCompute index;
  // Indexing math resolved on the exact concrete-mapped iterdomains.
  IndexCompute concrete_index;
  // Initial index Val assigned to each concrete iterdomain, used to seed the
  // indexing traversal.
  std::unordered_map<IterDomain*, Val*> initial_concrete_index_map;
  // Iterdomains corresponding to the loop nest the indexing was resolved for
  // (stored from the constructor's loop_domains argument).
  std::vector<IterDomain*> resolved_loop_domains;

  explicit IndexFromIdGraph(
      IndexCompute index,
      IndexCompute concrete_index,
      std::unordered_map<IterDomain*, Val*> initial_concrete_index_map,
      std::vector<IterDomain*> loop_domains);
};
27
//! Indexing interface, returns IndexFromIdGraph which the IndexCompute object
//! can be queried from directly for the produced indexing. If producer_tv !=
//! nullptr producer will be indexed, if producer_tv == nullptr consumer will be
//! indexed. If is_global global indexing will be done, else shared memory or
//! local indexing will be performed.
//!
//! \param loops Loop nest to generate the indexing within.
//! \param consumer_tv Consumer tensor of the expression being indexed.
//! \param producer_tv Optional producer tensor; when non-null the producer is
//!   indexed instead of the consumer.
//! \param is_global When true, global-memory indexing is produced; otherwise
//!   shared-memory or local indexing.
//! \param c2p_map Optional consumer-to-producer iterdomain map.
IndexFromIdGraph getTensorIndexFromIdGraph(
    const std::vector<kir::ForLoop*>& loops,
    const TensorView* consumer_tv,
    const TensorView* producer_tv = nullptr,
    bool is_global = true,
    std::unordered_map<IterDomain*, IterDomain*> c2p_map = {});
39
//! Indexing interface for calculating predicate index returns IndexFromIdGraph
//! which the IndexCompute object can be queried from directly for the produced
//! indexing. If is_start_predicate, will produce indexing math for the start
//! predicates.
//!
//! \param loops Loop nest the predicate is generated for.
//! \param consumer_tv Consumer tensor of the expression being predicated.
//! \param unswitch_or_vec_loop The unswitched or vectorized loop this
//!   predicate is associated with, if any.
//! \param double_buffer_axis Double-buffer axis of consumer_tv, if any.
//! \param is_start_predicate When true, produce indexing math for the start
//!   predicates.
IndexFromIdGraph getPredicateIndexingFromIdGraph(
    const std::vector<kir::ForLoop*>& loops,
    TensorView* consumer_tv,
    kir::ForLoop* unswitch_or_vec_loop,
    IterDomain* double_buffer_axis,
    bool is_start_predicate);
50
//! getTensorIndexFromIdGraph is the function that index_compute will call very
//! straightforwardly. However, for implementing the new indexing logic that
//! starts to abstract some of the indexing away from index_compute we need to
//! move quite a bit of the intertwined indexing logic away from the
//! index_compute file and the index_reference_replay file. This is because we
//! want to separate out what has to be done on the fly, from what analysis we
//! can do early on with the iter domain graph and associated properties.
//!
//! getTensorIndexFromIdGraph places this analysis internally in
//! LoopIndexingAnalysis. LoopIndexingAnalysis though has to communicate to:
//!   1) index_compute.cpp::IndexCompute to tell IndexCompute which expressions
//!      it needs to traverse to compute the indexing math.
//!   2) lower_shift.cpp::HaloInfo::buildConcreteHaloExtentMap to build the halo
//!      extent map used in indexing.
//!
//! LoopIndexing is nothing but a mechanism for this communication.
//!
//! Holds information needed to produce indexing math. In the current version of
//! indexing pass, the iter domains combined with the loop nests are the source
//! of truth in terms of resolving the actual integer indexing math from the
//! sequence of iterdomain transforms.
//!
//! This information is critical in resolving indexing associated with complex
//! broadcast patterns. Check FusionComplexBCast* test cases as well as
//! FusionAdvancedIndexing* for examples where resolving indices from IterDomain
//! transformations can be challenging.
//!
//! The source of this challenge is due to inlining patterns where the
//! IterDomains responsible for control flow are not local to a particular
//! TensorView. Broadcast, operations like view/reshape, and gather/shift can
//! make indexing local buffers complex because of the complex effects inlining
//! into other TensorViews produce.
//!
//! TODO:
//!   The first iteration tries to match the semantics of reference
//!   replay without any new logic. In a follow up iteration will
//!   need to revisit a few further pathological patterns.
//!
//! Note:
//!   The current implementation of loop indexing pass works on
//!   equivalence classes defined by ComputeAt exact map. The
//!   list of expressions stored in this class form a "reference" graph of
//!   iterdomain expressions when all of their inputs and outputs are replaced
//!   with their exact concrete mapped id's.
//!
//!   Here an invariant in a graph of iterdomain expressions is that
//!   each iterdomain is produced exactly once and is either a leaf domain
//!   or has been consumed exactly once by another expression. This makes sure
//!   that a well defined indexing can be generated for each of the concrete ids
//!   whenever we either forward or backward traverse the graph.
class LoopIndexing {
 public:
  //! Returns the original loop nest.
  const auto& loops() const {
    return loops_;
  }

  //! Returns the vector of Iterdomains
  //!  that match the original loop pattern.
  const auto& loopDomains() const {
    return loop_domains_;
  }

  //! Returns the consumer tv that the view info
  //!  was derived from.
  auto consumerTv() const {
    return consumer_tv_;
  }

  //! Returns the set of Iterdomain transforms that
  //!  define the correct indexing path, in forward
  //!  topological order.
  std::vector<Expr*> getForwardExprList() const;

  //! Returns the set of Iterdomain transforms that
  //!  define the correct indexing path, in backward
  //!  topological order.
  std::vector<Expr*> getBackwardExprList() const;

  //! Returns the set of out of line expressions in
  //!  reverse topological order.
  const std::vector<Expr*>& getBackwardOutOfLineExprList() const {
    return out_of_line_exprs_;
  }

  //! Returns all exact concrete id's that were produced
  //!  or consumed in the selected indexing expressions
  std::unordered_set<IterDomain*> getAllExactConcreteIdSet() const;

 private:
  // LoopIndexingAnalysis is the intended producer of all the private state
  // below; LoopIndexing itself is a read-only result holder.
  friend class LoopIndexingAnalysis;

  //! The loop nest that this loop indexing is derived from.
  std::vector<kir::ForLoop*> loops_;

  //! Consumer tv, where the view related info was derived from.
  const TensorView* consumer_tv_;

  //! The source iterdomains that all the Iterdomain transforms
  //!  in this loop nest originated from.
  std::vector<IterDomain*> loop_root_;

  //! The leaf iterdomains that the original loop nests correspond
  //!  to. May be longer than loops_ with the dangling iterdomains
  //!  appended towards the end.
  std::vector<IterDomain*> loop_domains_;

  //! The selected sequence of expressions that should represent
  //!  the correct indexing math from the given loop nest.
  std::vector<Expr*> index_exprs_;

  //! The subset of sequence of expressions that can be resolved
  //!  with only the iterdomains on the right of consumer tv's ca
  //!  axis.
  //! Expressions are ordered in reverse topological order.
  std::vector<Expr*> out_of_line_exprs_;
};
168
// When indexing there is sometimes an option to propagate an index down
// multiple paths. This will return the IterDomains in the history of the
// reference domain and mark which paths should be taken (if there's a
// preference) to reach the roots.
// NOTE(review): an earlier version of this comment referred to a
// "preferred_roots" argument, which is not part of this signature — the
// preference presumably comes from original_tv/loop_indexing; confirm against
// the definition.
//
// \param original_tv Tensor the preferred path is computed for.
// \param loop_indexing Loop indexing info (see LoopIndexing above).
// \param use_replay_map Optional; see definition for exact semantics.
// \param p2c_map Optional producer-to-consumer iterdomain map.
std::unordered_set<IterDomain*> buildLoopIndexingPreferredPath(
    const TensorView* original_tv,
    const LoopIndexing& loop_indexing,
    bool use_replay_map = false,
    std::unordered_map<IterDomain*, IterDomain*> p2c_map = {});
178
// Get an rfactor IterDomain that is mapped with an IterDomain. If
// multiple such IDs exist, select one whose input IDs are mapped with
// the consumer IDs. This is to ensure the path from the leaf
// IterDomains to the root matches with the consumer tensor.
//
// \param id IterDomain to find a mapped rfactor IterDomain for.
// \param consumer_all_ids All IDs of the consumer, used to disambiguate when
//   multiple rfactor IDs map to id.
IterDomain* getRfactorIDToTraverse(
    IterDomain* id,
    const std::vector<Val*>& consumer_all_ids);
186
187} // namespace cuda
188} // namespace fuser
189} // namespace jit
190} // namespace torch
191