1#include <instrumentation.h>
2#include <ir_builder.h>
3#include <ir_iostream.h>
4#include <ir_utils.h>
5#include <lower_utils.h>
6#include <root_domain_map.h>
7
8#include <lower_replace_size.h>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15namespace {
16// Going to generate a map of tensor view root domain extents to reduce the
17// number used during lowering. For example if we have:
18//
19// T2[i0, i1] = T1[i0, i1] + T2[i2, i3]
20//
21// We know it would be safe to use:
22//
23// T2[i0, i1] = T1[i0, i1] + T2[i0, i1]
24//
25// And that way we don't generate T2.size[0] and T2.size[1], instead we will
26// reuse T1.size[0] and T1.size[1]
27// This is important when doing CSE as T2 and T1 would otherwise look like
28// they're using different values, even though we know they're the same
29//
// There's some duplicate logic here that's in the computeAt map, but it's not
// concise enough there to pull out. May want to consider making this mapping
// its own class, especially as it may be useful during scheduling.
33std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
34 std::list<std::unordered_set<IterDomain*>> disjoint_root_sets;
35 std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>*>
36 id_to_disjoint_root_set;
37
38 auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set](
39 IterDomain* id0, IterDomain* id1) {
40 if (id0->isBroadcast() || id1->isBroadcast()) {
41 return;
42 }
43
44 auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0);
45 auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1);
46 bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end();
47 bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end();
48
49 if (set_0_found && set_1_found) {
50 if (disjoint_set_0_it->second == disjoint_set_1_it->second) {
51 return;
52 }
53 // merge second disjoint set into first
54 auto* set_0 = disjoint_set_0_it->second;
55 auto* set_1 = disjoint_set_1_it->second;
56 for (auto id : *set_1) {
57 set_0->emplace(id);
58 id_to_disjoint_root_set[id] = set_0;
59 }
60 // remove second set from disjoint_root_sets
61 disjoint_root_sets.erase(std::find(
62 disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1));
63 } else if (set_0_found || set_1_found) {
64 auto existing_set =
65 set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second;
66 auto to_add_id = set_0_found ? id1 : id0;
67 existing_set->emplace(to_add_id);
68 id_to_disjoint_root_set[to_add_id] = existing_set;
69 // add entry into existing set
70 } else {
71 // create new set entry
72 disjoint_root_sets.emplace_back(std::unordered_set<IterDomain*>());
73 auto* new_set = &disjoint_root_sets.back();
74 new_set->emplace(id0);
75 new_set->emplace(id1);
76 id_to_disjoint_root_set[id0] = new_set;
77 id_to_disjoint_root_set[id1] = new_set;
78 }
79 };
80
81 auto fusion_vals = fusion->usedMathVals();
82 for (auto producer_tv : ir_utils::filterByType<TensorView>(fusion_vals)) {
83 auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv);
84 for (auto consumer_tv : consumer_tvs) {
85 auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
86 auto c2p_root_map = pairwise_map.mapConsumerToProducer(
87 consumer_tv->domain(), producer_tv->domain());
88 for (auto entry : c2p_root_map) {
89 auto c_id = entry.first;
90 auto p_id = entry.second;
91 map_root_ids(p_id, c_id);
92 }
93 }
94 }
95
96 // Map each set to an input ID (if it exists) that has the smallest ->name()
97 // entry value
98 std::unordered_map<std::unordered_set<IterDomain*>*, IterDomain*>
99 set_to_input_id;
100
101 // Loop over the root domains, of the inputs to the fusion. Pick an input ID
102 // to use as the representative ID of the collected sets. Only consider inputs
103 // as those are the ones that map to values like "T0.size[1]". They are he
104 // ID's that propagated their extents into the problem. We could also check
105 // the outputs as we do have C++ examples of using output dimensions for the
106 // problem size instead of inputs. However, we don't do anything where we can
107 // translate to those kinds of kernels integrated into PyTorch.
108 for (auto input_tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
109 for (auto id :
110 TensorDomain::noReductions(input_tv->getMaybeRFactorDomain())) {
111 auto id_set_it = id_to_disjoint_root_set.find(id);
112 if (id_set_it == id_to_disjoint_root_set.end()) {
113 continue;
114 }
115 auto* id_set = id_set_it->second;
116 if (set_to_input_id.find(id_set) == set_to_input_id.end()) {
117 set_to_input_id[id_set] = id;
118 } else {
119 auto input_id_of_set = set_to_input_id.at(id_set);
120 // Swap id's if new name is less than previously set
121 bool swap_ids = id->name() < input_id_of_set->name();
122 // If new id is a const scalar but previously was'nt use the const
123 // scalar
124 swap_ids = swap_ids ||
125 (id->extent()->isConstScalar() &&
126 !input_id_of_set->extent()->isConstScalar());
127 // If previous scalar was const and new isn't, don't swap
128 swap_ids = swap_ids &&
129 !(input_id_of_set->extent()->isConstScalar() &&
130 !id->extent()->isConstScalar());
131
132 if (swap_ids) {
133 set_to_input_id[id_set] = id;
134 }
135 }
136 }
137 }
138
139 // Finally make map from ID extents to the representitive ID extent.
140 std::unordered_map<Val*, Val*> extent_to_min_input_id_extent;
141 for (auto entry : set_to_input_id) {
142 auto* set = entry.first;
143 auto input_id = entry.second;
144 for (auto id : *set) {
145 extent_to_min_input_id_extent[id->extent()] = input_id->extent();
146 }
147 }
148 return extent_to_min_input_id_extent;
149}
150
151} // namespace
152
153void replaceSymbolicSizes(Fusion* fusion) {
154 FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes");
155 std::unordered_map<Val*, Val*> tensor_dim_map;
156
157 // Grab inputs and outputs
158 std::vector<TensorView*> inputs_and_outputs;
159 for (auto val : fusion->inputs()) {
160 if (ir_utils::isTV(val)) {
161 inputs_and_outputs.push_back(val->as<TensorView>());
162 }
163 }
164 // Symbolic size is necessary for outputs if there are no inputs.
165 // Otherwise infer output sizes from the inputs via expression evaluation.
166 if (fusion->inputs().empty()) {
167 for (auto val : fusion->outputs()) {
168 if (ir_utils::isTV(val)) {
169 inputs_and_outputs.push_back(val->as<TensorView>());
170 }
171 }
172 }
173
174 // Generate map for all tensorview root domain values to map them to symbolic
175 // values. i.e. T0->getRootDomain()[0] would map to a named scalar
176 // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir.
177 for (TensorView* tv : inputs_and_outputs) {
178 // Replace the domain with one based on Ti.size[j]
179 const std::vector<IterDomain*>& root_td = tv->getRootDomain();
180
181 size_t dim = 0;
182 for (auto id : root_td) {
183 Val* orig_size = id->extent();
184 // Output sizes could have reduction axes, which isn't what gets output.
185 // NOLINTNEXTLINE(bugprone-branch-clone)
186 if (id->isReduction()) {
187 continue;
188 } else if (orig_size->isConstScalar()) {
189 dim++;
190 continue;
191 }
192
193 // Currently turn off this part for inputs of segmented fusion,
194 // since FusionKernelRuntime will provide these as integer inputs
195 if (tensor_dim_map.find(orig_size) == tensor_dim_map.end() &&
196 !orig_size->isFusionInput() && !orig_size->isConstScalar()) {
197 std::stringstream ss;
198 ss << "T" << tv->name() << ".size[" << dim++ << "]";
199 tensor_dim_map[orig_size] = IrBuilder::create<NamedScalar>(
200 ss.str(), orig_size->getDataType().value());
201 } else {
202 dim++;
203 }
204 }
205 }
206
207 // Use a minimal number of sizes from provided tensors.
208 auto extent_simplification_map = getSimplificationMap(fusion);
209 for (auto extent_entry : extent_simplification_map) {
210 auto orig_extent = extent_entry.first;
211 auto simplified_extent = extent_entry.second;
212 if (tensor_dim_map.count(orig_extent)) {
213 if (tensor_dim_map.count(simplified_extent)) {
214 tensor_dim_map[orig_extent] = tensor_dim_map[simplified_extent];
215 } else {
216 tensor_dim_map[orig_extent] = simplified_extent;
217 }
218 }
219 }
220
221 // Run mutation on the fusion with the tensor_dim_map
222 ir_utils::replaceValue(fusion, tensor_dim_map);
223}
224
225} // namespace cuda
226} // namespace fuser
227} // namespace jit
228} // namespace torch
229