#include <torch/csrc/jit/passes/batch_mm.h>

#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <ATen/ATen.h>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch {
namespace jit {

namespace {
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}
} // namespace

// This pass looks for trees in the graph, where leaves are mm ops, and the
// inner vertices are add nodes. Once we have such a tree, it can be reduced
// to two concats and a single mm (essentially a single multiply of a wide
// matrix with a tall matrix). Such patterns show up mostly in the backward
// pass of RNNs, since the derivative of many uses of matrix multiplies with
// the same weights forms exactly such a tree (note that it's usually also
// highly imbalanced, i.e. has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+ +------+   +------+ +------+   +------+
// |      | |      |   |      | |      |   |      |
// |  L1  | |  R1  | + |  L2  | |  R2  | = |  O   |
// |      | |      |   |      | |      |   |      |
// +------+ +------+   +------+ +------+   +------+
//
// can be transformed into a single MM which looks like this
// (we concat all lhs operands, concat all rhs operands, and do one mm):
//
//                 +------+
//                 |      |
//                 |  R1  |
//                 |      |
//                 +------+
//                 |      |
//                 |  R2  |
//                 |      |
//                 +------+
// +------+------+ +------+
// |      |      | |      |
// |  L1  |  L2  | |  O   |
// |      |      | |      |
// +------+------+ +------+
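//
// For illustration, in (schematic) TorchScript IR the two-leaf tree above
// looks roughly like this (value names are made up):
//
//   %p = aten::mm(%L1, %R1)
//   %q = aten::mm(%L2, %R2)
//   %o = aten::add(%p, %q, 1)
//
// and is rewritten to a single variadic node, with all lhs operands first
// and all rhs operands second (evaluated by the operator registered below):
//
//   %o = prim::MMTreeReduce(%L1, %L2, %R1, %R2)
//
// (In practice only trees with at least min_fusion_size leaves are fused.)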

// Note [Further optimizations]
// It would be straightforward to extend the TreeToken class to also detect if
// all MMs have the same lhs/rhs. In that case it's more efficient to expand
// the lhs and use bmm + sum instead of repeating it in memory via concat.
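//
// A sketch of that idea (not implemented; shapes assume lhs is (l, m) and
// every rhs_i is (m, r), with n tensors in rhs_list):
//
//   auto rhs = at::stack(rhs_list);                  // (n, m, r)
//   auto lhs_b = lhs.unsqueeze(0).expand({n, l, m}); // broadcast, no copy
//   auto out = at::bmm(lhs_b, rhs).sum(0);           // == sum_i lhs.mm(rhs_i)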

// Note [Overlapping trees]
// Additionally, it wouldn't be too hard to add support for partially
// overlapping trees. Right now this is forbidden in the algorithm (only a
// single tree can claim each node), so we might theoretically miss some
// optimization opportunities, especially since the rejected tree could be
// much larger. I didn't implement that because it's not necessary for the
// simple RNN cases I saw, so I decided to keep things simple. If we ever get
// around to implementing this, the right solution is probably to fuse MMs for
// the common part, and treat it as an input leaf for the two outer parts (I
// don't think it's beneficial to recompute, unless the subtree is very small,
// but let's not get into such details).

// The algorithm we're using is simple. We iterate through the graph in
// topological order and label nodes with TreeTokens. Then we look for the
// roots of the trees we formed and fuse them.
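//
// For example, given the (schematic) program below, the labels evolve as
// follows (value names are made up):
//
//   %x = aten::mm(%a, %b)      // token: tree_size 1, root
//   %y = aten::mm(%c, %d)      // token: tree_size 1, root
//   %s = aten::add(%x, %y, 1)  // token: tree_size 2, root; %x, %y lose root
//   %z = aten::mm(%e, %f)      // token: tree_size 1, root
//   %t = aten::add(%s, %z, 1)  // token: tree_size 3, root; %s, %z lose root
//
// Only %t's token remains a root, so %t is the tree considered for fusion
// (and it is fused only if it has at least min_fusion_size leaves).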

// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr size_t min_fusion_size = 4;

bool have_same_shape(at::TensorList inputs) {
  auto expected_sizes = inputs[0].sizes();
  return (std::all_of(
      inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
        return t.sizes() == expected_sizes;
      }));
}

bool should_be_transposed(at::TensorList inputs) {
  return (std::all_of(inputs.begin(), inputs.end(), [](const at::Tensor& t) {
    return t.stride(0) == 1 && t.stride(1) == t.size(0);
  }));
}

std::vector<at::Tensor> transpose_inputs(at::TensorList inputs) {
  return fmap(inputs, [](const at::Tensor& i) { return i.t(); });
}
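
// Heuristic: decides whether a single fused GEMM over the concatenated
// operands is likely to beat the per-pair mm + add_ fallback. Here lhs is
// (l x m) and rhs is (m x r), so each product is (l x r).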
bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
  size_t l = lhs.size(0);
  size_t m = lhs.size(1);
  size_t r = rhs.size(1);
  // Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
  return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
}

RegisterOperators mm_tree_reduction_reg({Operator(
    "prim::MMTreeReduce(...) -> Tensor",
    [](Stack& stack) {
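      // For ops with a varargs "(...)" schema, the interpreter pushes the
      // number of inputs onto the stack last, so pop it before the tensors.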
      auto num_inputs = pop(stack).toInt();
      std::vector<at::Tensor> inputs;
      inputs.reserve(num_inputs);
      for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
        inputs.push_back(std::move(*it).toTensor());
      }
      drop(stack, num_inputs);

      AT_ASSERT(!inputs.empty());
      AT_ASSERT(inputs.size() % 2 == 0);
      size_t side_num_elems = inputs.size() / 2;
      auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
      auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
      // TODO: checking this is not free, so we should stop if this keeps
      // failing
      if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
          shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
        // Sometimes lhs_inputs or rhs_inputs are not contiguous, which makes
        // at::cat go through the slow path. View them as contiguous, if
        // possible, by transposing.
        bool lhs_input_transposed = should_be_transposed(lhs_inputs);
        bool rhs_input_transposed = should_be_transposed(rhs_inputs);
        at::Tensor lhs, rhs;
        if (lhs_input_transposed) {
          std::vector<at::Tensor> lhs_contig_inputs =
              transpose_inputs(lhs_inputs);
          lhs = at::cat(lhs_contig_inputs, /*dim=*/0);
          lhs = lhs.t();
        } else {
          lhs = at::cat(lhs_inputs, /*dim=*/1);
        }
        if (rhs_input_transposed) {
          std::vector<at::Tensor> rhs_contig_inputs =
              transpose_inputs(rhs_inputs);
          rhs = at::cat(rhs_contig_inputs, /*dim=*/1);
          rhs = rhs.t();
        } else {
          rhs = at::cat(rhs_inputs, /*dim=*/0);
        }
        push(stack, at::mm(lhs, rhs));
      } else {
        auto acc = at::mm(inputs[0], inputs[side_num_elems]);
        for (const auto i : c10::irange(1, side_num_elems)) {
          acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
        }
        push(stack, std::move(acc));
      }
    },
    aliasAnalysisIsSpecialCase())});

// TreeTokens are used to label nodes of the graph that fit our mm/add tree
// pattern. Basically, we do dynamic programming on DAGs: by the time we reach
// a node N with inputs A and B, both A and B have already been processed, so
// we can try to unify their TreeTokens (if they have them) and build a larger
// tree.
struct TreeToken {
  uint64_t tree_size = 0; // NOTE: measured in number of leaves, i.e. mm ops
  Node* node = nullptr;
  bool is_root = false;

  static TreeToken mm(Node* mm) {
    TreeToken token;
    token.tree_size = 1;
    token.node = mm;
    token.is_root = true;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken transpose(Node* t, TreeToken& inp_token) {
    TreeToken token;
    if (!inp_token.node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return token;
    }
    token.tree_size = 1;
    token.node = t;
    token.is_root = true;
    inp_token.is_root = false;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
    TreeToken token;
    // See Note [Overlapping trees]
    if (&l == &r || !l.is_root || !r.is_root)
      return token;
    token.tree_size = l.tree_size + r.tree_size;
    token.node = add;
    token.is_root = true;
    // Reserve the subtrees, so they can't be used again.
    l.is_root = r.is_root = false;
    return token;
  }

  explicit operator bool() {
    return is_root;
  }

  std::vector<Node*> removeTransposesAndGatherMatmuls() {
    std::vector<Node*> matmuls;
    std::vector<Node*> queue{node};
    Graph* graph = node->owningGraph();
    while (!queue.empty()) {
      auto n = queue.back();
      queue.pop_back();
      if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
        matmuls.push_back(n);
      } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
        Node* input_node = n->input()->node();
        AT_ASSERT(input_node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
        // (AB)^T == B^T A^T, so we can rewrite t(mm(A, B)) as
        // mm(t(B), t(A)) and keep the leaf a plain mm node.
        WithInsertPoint insert_guard{input_node};
        Value* A = input_node->inputs()[0];
        Value* B = input_node->inputs()[1];
        Value* AT = graph->insert(aten::t, {A});
        Value* BT = graph->insert(aten::t, {B});
        Value* BTAT = graph->insert(aten::mm, {BT, AT});
        n->output()->replaceAllUsesWith(BTAT);
        matmuls.push_back(BTAT->node());
      } else if (
          n->matches(
              "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
        queue.push_back(n->inputs()[0]->node());
        queue.push_back(n->inputs()[1]->node());
      } else {
        AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
      }
    }
    return matmuls;
  }
};

enum class Side { LHS, RHS };

void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
  auto graph = block->owningGraph();

  // Look for trees in the block
  std::unordered_map<Node*, TreeToken> tokens;
  for (auto node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      tokens[node] = TreeToken::mm(node);
    } else if (
        node->matches("aten::t(Tensor self) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      auto input_it = tokens.find(node->input()->node());
      if (input_it != tokens.end()) {
        tokens[node] = TreeToken::transpose(node, input_it->second);
      }
    } else if (
        node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      Node* lhs = node->inputs()[0]->node();
      Node* rhs = node->inputs()[1]->node();
      auto lhs_it = tokens.find(lhs);
      auto rhs_it = tokens.find(rhs);
      // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
      // We could treat a subtree with multiple uses as if it was overlapping.
      // XXX: uses().size() == 1 is also something that guarantees that this
      // transform is valid, because we know for sure that none of these
      // operands depends on the result of the other. If we were to remove
      // this check, we would need to compute a transitive closure and
      // actually check the dependencies.
      if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
          lhs->output()->uses().size() == 1 &&
          rhs->output()->uses().size() == 1) {
        if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
          tokens[node] = token;
        }
      }
    } else {
      for (auto block : node->blocks()) {
        BatchMMTreeReduce(block, alias_db);
      }
    }
  }

  // Merge trees we've found
  for (auto& item : tokens) {
    auto& root = item.second;
    if (!root || root.tree_size < min_fusion_size)
      continue;
    auto matmuls = root.removeTransposesAndGatherMatmuls();
    WithInsertPoint insert_guard{root.node};
    Node* tree_reduce =
        graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(0));
    }
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(1));
    }
    root.node->output()->replaceAllUsesWith(tree_reduce->output());
    // NB: don't bother with cleaning up after yourself. We'll use DCE for
    // that.
  }
}

bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
  // Cutoff chosen by benchmarking on a TITAN V
  return other_side_input.numel() <= 1024 * 2048;
}
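
// prim::MMBatchSide batches many MMs that share one operand. E.g. for
// side == LHS with shared input %S (schematic IR, value names are made up):
//
//   %o1, %o2 = prim::MMBatchSide(%S, %R1, %R2)
//
// computes mm(%S, cat([%R1, %R2], dim=1)) and chunks the result back into
// %o1 and %o2, falling back to individual MMs when the shapes don't allow
// the fast path.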
RegisterOperators mm_batch_side_reg({Operator(
    prim::MMBatchSide,
    [](const Node* node) -> Operation {
      size_t num_other_side_inputs = node->inputs().size() - 1;
      Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
      return [num_other_side_inputs, single_side](Stack& stack) {
        at::Tensor side_input;
        std::vector<at::Tensor> other_side_inputs;
        other_side_inputs.reserve(num_other_side_inputs);
        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
             ++it) {
          other_side_inputs.push_back(std::move(*it).toTensor());
        }
        drop(stack, num_other_side_inputs);
        pop(stack, side_input);

        auto any_other_input = other_side_inputs[0];
        if (have_same_shape(other_side_inputs) &&
            shape_is_fast_for_side(any_other_input)) {
          auto other_side_input =
              at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
          auto mm_out = single_side == Side::LHS
              ? side_input.mm(other_side_input)
              : other_side_input.mm(side_input);
          auto outputs = at::chunk(
              mm_out,
              num_other_side_inputs,
              /*dim=*/single_side == Side::LHS ? 1 : 0);
          stack.insert(
              stack.end(),
              std::make_move_iterator(outputs.begin()),
              std::make_move_iterator(outputs.end()));
        } else {
          if (single_side == Side::LHS) {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(side_input.mm(other));
            }
          } else {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(other.mm(side_input));
            }
          }
        }
      };
    },
    aliasAnalysisIsSpecialCase())});
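
// Collects the aten::mm users of `value` within its owning block, split by
// whether `value` appears as the lhs or the rhs, and keeps only a
// topologically independent subset of each list. E.g. (schematically) for
// uses mm(v, a), mm(v, b), mm(c, v) it returns ({mm(v, a), mm(v, b)},
// {mm(c, v)}), modulo the independence filter in `postprocess`.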
std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
    Value* value,
    AliasDb& alias_db) {
  const auto postprocess = [&](std::vector<Node*> mms) {
    if (mms.empty()) {
      return mms;
    }
    std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
      return n->isBefore(m);
    });
    // Filter out dependent MMs. This algorithm might do very badly if e.g.
    // you have a lot of independent MMs that all depend on the first one,
    // but I doubt this will be a common scenario.
    for (const auto i : c10::irange(mms.size())) {
      if (mms[i] == nullptr)
        continue;
      for (size_t j = i + 1; j < mms.size(); ++j) {
        if (mms[j] == nullptr)
          continue;
        if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
          mms[j] = nullptr;
        }
      }
    }
    return c10::filter(mms, [](Node* n) { return n != nullptr; });
  };

  Block* block = value->node()->owningBlock();
  std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
  std::vector<Node*> rhses; // Like above, but rhs
  for (Use u : value->uses()) {
    if (u.user->owningBlock() == block &&
        u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(u.user)) {
      if (u.offset == 0 && u.user->inputs()[1] != value) {
        lhses.push_back(u.user);
      } else if (u.offset == 1 && u.user->inputs()[0] != value) {
        rhses.push_back(u.user);
      }
    }
  }
  return std::make_pair(
      postprocess(std::move(lhses)), postprocess(std::move(rhses)));
}

void BatchMMSide(Block* block, AliasDb& alias_db) {
  // NB: 8 is the current loop unrolling factor
  static constexpr size_t how_many_is_many = 8;
  const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
    AT_ASSERT(!mms.empty());
    for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
      bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
      AT_ASSERT(move_ok);
    }
    WithInsertPoint insert_guard{mms[0]};
    Graph* graph = mms[0]->owningGraph();
    Node* batch_mm = graph->create(
        prim::MMBatchSide,
        /*inputs=*/{},
        /*num_outputs=*/mms.size());
    graph->insertNode(batch_mm);
    batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
    Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
    batch_mm->addInput(const_side);
    for (const auto i : c10::irange(mms.size())) {
      batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
      mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
    }
  };

  std::unordered_set<Value*> considered_values;
  for (Node* node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      for (Value* input : node->inputs()) {
        if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
          continue;
        }
        auto uses_with_many = gatherIndependentMMUses(input, alias_db);
        if (uses_with_many.first.size() >= how_many_is_many) {
          batch_side(uses_with_many.first, Side::LHS);
        }
        if (uses_with_many.second.size() >= how_many_is_many) {
          batch_side(uses_with_many.second, Side::RHS);
        }
      }
    } else {
      for (Block* subblock : node->blocks()) {
        BatchMMSide(subblock, alias_db);
      }
    }
  }
}

bool hasMutableOperators(Block* block) {
  for (auto n : block->nodes()) {
    if (n->kind().is_aten() && n->schema().is_mutable())
      return true;
    for (auto b : n->blocks()) {
      if (hasMutableOperators(b))
        return true;
    }
  }
  return false;
}

bool hasMMOperators(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return true;
    }
  }
  return false;
}

void BatchMM(std::shared_ptr<Graph>& graph) {
  if (!hasMMOperators(graph)) {
    return;
  }
  AliasDb alias_db(graph);
  BatchMMTreeReduce(graph->block(), alias_db);
  BatchMMSide(graph->block(), alias_db);
  EliminateDeadCode(graph);
  // It's possible that transpose rearrangements have created sequences of
  // consecutive transposes that didn't exist before; PeepholeOptimize will
  // remove them.
  // Tensor type properties are not guaranteed to be correct at this point,
  // so disable shape peepholes.
  PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
}
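
// A minimal usage sketch (normally this pass is invoked from the JIT
// optimization pipeline rather than called directly):
//
//   std::shared_ptr<Graph> graph = ...; // e.g. obtained from tracing
//   BatchMM(graph);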

} // namespace jit
} // namespace torch