1 | #include <instrumentation.h> |
2 | #include <ir_builder.h> |
3 | #include <ir_iostream.h> |
4 | #include <ir_utils.h> |
5 | #include <lower_utils.h> |
6 | #include <root_domain_map.h> |
7 | |
8 | #include <lower_replace_size.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | namespace { |
16 | // Going to generate a map of tensor view root domain extents to reduce the |
17 | // number used during lowering. For example if we have: |
18 | // |
19 | // T2[i0, i1] = T1[i0, i1] + T2[i2, i3] |
20 | // |
21 | // We know it would be safe to use: |
22 | // |
23 | // T2[i0, i1] = T1[i0, i1] + T2[i0, i1] |
24 | // |
25 | // And that way we don't generate T2.size[0] and T2.size[1], instead we will |
26 | // reuse T1.size[0] and T1.size[1] |
27 | // This is important when doing CSE as T2 and T1 would otherwise look like |
28 | // they're using different values, even though we know they're the same |
29 | // |
30 | // There's some duplicate logic here that's in computeAt map, but it's not so |
// concise there to pull out. May want to consider making this mapping its own
32 | // class especially as it may be useful during scheduling. |
// Returns a map from each root-domain extent Val to the extent of a chosen
// representative IterDomain in the same "same-extent" equivalence class.
std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
  // Disjoint sets of root IterDomains known to share the same extent.
  // A std::list is used (rather than a vector) so the raw pointers stored in
  // id_to_disjoint_root_set stay valid as new sets are appended.
  std::list<std::unordered_set<IterDomain*>> disjoint_root_sets;
  std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>*>
      id_to_disjoint_root_set;

  // Union-find style merge: record that id0 and id1 have equal extents by
  // placing them into the same disjoint set, merging sets as needed.
  auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set](
                          IterDomain* id0, IterDomain* id1) {
    // Broadcast dimensions may map to non-broadcast dimensions of a different
    // extent, so never unify extents through a broadcast ID.
    if (id0->isBroadcast() || id1->isBroadcast()) {
      return;
    }

    auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0);
    auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1);
    bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end();
    bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end();

    if (set_0_found && set_1_found) {
      if (disjoint_set_0_it->second == disjoint_set_1_it->second) {
        // Already in the same set, nothing to do.
        return;
      }
      // merge second disjoint set into first
      auto* set_0 = disjoint_set_0_it->second;
      auto* set_1 = disjoint_set_1_it->second;
      for (auto id : *set_1) {
        set_0->emplace(id);
        id_to_disjoint_root_set[id] = set_0;
      }
      // remove second set from disjoint_root_sets
      disjoint_root_sets.erase(std::find(
          disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1));
    } else if (set_0_found || set_1_found) {
      auto existing_set =
          set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second;
      auto to_add_id = set_0_found ? id1 : id0;
      existing_set->emplace(to_add_id);
      id_to_disjoint_root_set[to_add_id] = existing_set;
      // add entry into existing set
    } else {
      // create new set entry
      disjoint_root_sets.emplace_back(std::unordered_set<IterDomain*>());
      auto* new_set = &disjoint_root_sets.back();
      new_set->emplace(id0);
      new_set->emplace(id1);
      id_to_disjoint_root_set[id0] = new_set;
      id_to_disjoint_root_set[id1] = new_set;
    }
  };

  // Unify extents across every producer/consumer pair in the fusion using
  // the pairwise root-domain mapping.
  auto fusion_vals = fusion->usedMathVals();
  for (auto producer_tv : ir_utils::filterByType<TensorView>(fusion_vals)) {
    auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv);
    for (auto consumer_tv : consumer_tvs) {
      auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
      auto c2p_root_map = pairwise_map.mapConsumerToProducer(
          consumer_tv->domain(), producer_tv->domain());
      for (auto entry : c2p_root_map) {
        auto c_id = entry.first;
        auto p_id = entry.second;
        map_root_ids(p_id, c_id);
      }
    }
  }

  // Map each set to an input ID (if it exists) that has the smallest ->name()
  // entry value
  std::unordered_map<std::unordered_set<IterDomain*>*, IterDomain*>
      set_to_input_id;

  // Loop over the root domains, of the inputs to the fusion. Pick an input ID
  // to use as the representative ID of the collected sets. Only consider
  // inputs as those are the ones that map to values like "T0.size[1]". They
  // are the IDs that propagated their extents into the problem. We could also
  // check the outputs as we do have C++ examples of using output dimensions
  // for the problem size instead of inputs. However, we don't do anything
  // where we can translate to those kinds of kernels integrated into PyTorch.
  for (auto input_tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
    for (auto id :
         TensorDomain::noReductions(input_tv->getMaybeRFactorDomain())) {
      auto id_set_it = id_to_disjoint_root_set.find(id);
      if (id_set_it == id_to_disjoint_root_set.end()) {
        // This input ID never mapped to anything; no set to represent.
        continue;
      }
      auto* id_set = id_set_it->second;
      if (set_to_input_id.find(id_set) == set_to_input_id.end()) {
        // First input ID seen for this set becomes the initial candidate.
        set_to_input_id[id_set] = id;
      } else {
        auto input_id_of_set = set_to_input_id.at(id_set);
        // Swap id's if new name is less than previously set
        bool swap_ids = id->name() < input_id_of_set->name();
        // If new id is a const scalar but previously wasn't, use the const
        // scalar
        swap_ids = swap_ids ||
            (id->extent()->isConstScalar() &&
             !input_id_of_set->extent()->isConstScalar());
        // If previous scalar was const and new isn't, don't swap
        swap_ids = swap_ids &&
            !(input_id_of_set->extent()->isConstScalar() &&
              !id->extent()->isConstScalar());

        if (swap_ids) {
          set_to_input_id[id_set] = id;
        }
      }
    }
  }

  // Finally make map from ID extents to the representative ID extent.
  std::unordered_map<Val*, Val*> extent_to_min_input_id_extent;
  for (auto entry : set_to_input_id) {
    auto* set = entry.first;
    auto input_id = entry.second;
    for (auto id : *set) {
      extent_to_min_input_id_extent[id->extent()] = input_id->extent();
    }
  }
  return extent_to_min_input_id_extent;
}
150 | |
151 | } // namespace |
152 | |
153 | void replaceSymbolicSizes(Fusion* fusion) { |
154 | FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes" ); |
155 | std::unordered_map<Val*, Val*> tensor_dim_map; |
156 | |
157 | // Grab inputs and outputs |
158 | std::vector<TensorView*> inputs_and_outputs; |
159 | for (auto val : fusion->inputs()) { |
160 | if (ir_utils::isTV(val)) { |
161 | inputs_and_outputs.push_back(val->as<TensorView>()); |
162 | } |
163 | } |
164 | // Symbolic size is necessary for outputs if there are no inputs. |
165 | // Otherwise infer output sizes from the inputs via expression evaluation. |
166 | if (fusion->inputs().empty()) { |
167 | for (auto val : fusion->outputs()) { |
168 | if (ir_utils::isTV(val)) { |
169 | inputs_and_outputs.push_back(val->as<TensorView>()); |
170 | } |
171 | } |
172 | } |
173 | |
174 | // Generate map for all tensorview root domain values to map them to symbolic |
175 | // values. i.e. T0->getRootDomain()[0] would map to a named scalar |
176 | // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir. |
177 | for (TensorView* tv : inputs_and_outputs) { |
178 | // Replace the domain with one based on Ti.size[j] |
179 | const std::vector<IterDomain*>& root_td = tv->getRootDomain(); |
180 | |
181 | size_t dim = 0; |
182 | for (auto id : root_td) { |
183 | Val* orig_size = id->extent(); |
184 | // Output sizes could have reduction axes, which isn't what gets output. |
185 | // NOLINTNEXTLINE(bugprone-branch-clone) |
186 | if (id->isReduction()) { |
187 | continue; |
188 | } else if (orig_size->isConstScalar()) { |
189 | dim++; |
190 | continue; |
191 | } |
192 | |
193 | // Currently turn off this part for inputs of segmented fusion, |
194 | // since FusionKernelRuntime will provide these as integer inputs |
195 | if (tensor_dim_map.find(orig_size) == tensor_dim_map.end() && |
196 | !orig_size->isFusionInput() && !orig_size->isConstScalar()) { |
197 | std::stringstream ss; |
198 | ss << "T" << tv->name() << ".size[" << dim++ << "]" ; |
199 | tensor_dim_map[orig_size] = IrBuilder::create<NamedScalar>( |
200 | ss.str(), orig_size->getDataType().value()); |
201 | } else { |
202 | dim++; |
203 | } |
204 | } |
205 | } |
206 | |
207 | // Use a minimal number of sizes from provided tensors. |
208 | auto extent_simplification_map = getSimplificationMap(fusion); |
209 | for (auto extent_entry : extent_simplification_map) { |
210 | auto orig_extent = extent_entry.first; |
211 | auto simplified_extent = extent_entry.second; |
212 | if (tensor_dim_map.count(orig_extent)) { |
213 | if (tensor_dim_map.count(simplified_extent)) { |
214 | tensor_dim_map[orig_extent] = tensor_dim_map[simplified_extent]; |
215 | } else { |
216 | tensor_dim_map[orig_extent] = simplified_extent; |
217 | } |
218 | } |
219 | } |
220 | |
221 | // Run mutation on the fusion with the tensor_dim_map |
222 | ir_utils::replaceValue(fusion, tensor_dim_map); |
223 | } |
224 | |
225 | } // namespace cuda |
226 | } // namespace fuser |
227 | } // namespace jit |
228 | } // namespace torch |
229 | |