#include <torch/csrc/jit/passes/batch_mm.h>

#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <ATen/ATen.h>
#include <algorithm>
#include <unordered_map>
#include <utility>

namespace torch {
namespace jit {

namespace {
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}
} // namespace
// This pass looks for trees in the graph, where leaves are mm ops, and the
// inner vertices are add nodes. Once we have such a tree, it can be reduced to
// two concats and a single mm (basically into a single multiply of a wide
// matrix with a tall matrix). Such patterns show up mostly in the backward of
// RNNs, since the derivative of many uses of matrix multiplies with the same
// weights forms exactly such a tree (note that it's usually also highly
// imbalanced, i.e. has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+ +------+   +------+ +------+   +------+
// |      | |      |   |      | |      |   |      |
// |  L1  | |  R1  | + |  L2  | |  R2  | = |  O   |
// |      | |      |   |      | |      |   |      |
// +------+ +------+   +------+ +------+   +------+
//
// can be basically transformed into a single MM which looks like this
// (we concat all lhs operands, concat rhs operands, do mm):
//
//                 +------+
//                 |      |
//                 |  R1  |
//                 |      |
//                 +------+
//                 |      |
//                 |  R2  |
//                 |      |
//                 +------+
// +------+------+ +------+
// |      |      | |      |
// |  L1  |  L2  | |  O   |
// |      |      | |      |
// +------+------+ +------+
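//
// In ATen terms the identity being exploited is just block matrix
// multiplication (a hedged sketch, assuming L1 is [n, k1], L2 is [n, k2],
// R1 is [k1, m] and R2 is [k2, m]; the names are made up for illustration):
//
//   at::Tensor lhs = at::cat({L1, L2}, /*dim=*/1); // [n, k1 + k2]
//   at::Tensor rhs = at::cat({R1, R2}, /*dim=*/0); // [k1 + k2, m]
//   // lhs.mm(rhs) equals L1.mm(R1) + L2.mm(R2)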

// Note [Further optimizations]
// It would be straightforward to extend the TreeToken class to also detect if
// all MMs had the same lhs/rhs. In such a case it's more efficient to expand
// the lhs and use bmm + sum instead of repeating it in memory via concat.
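//
// A hedged sketch of that bmm + sum variant (assuming a single shared lhs L of
// shape [a, b] and a stack of rhses `rhs_stack` of shape [k, b, c]; the names
// are made up for illustration):
//
//   at::bmm(L.expand({k, L.size(0), L.size(1)}), rhs_stack).sum(/*dim=*/0)
//
// which yields the same [a, c] result as summing the k individual mms.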

// Note [Overlapping trees]
// Additionally, it wouldn't be too hard to add support for partially
// overlapping trees. Right now it's forbidden in the algorithm (only a single
// tree will be allowed), so theoretically we might miss some optimization
// opportunities, especially since the rejected tree could be much larger. I
// didn't implement that because it's not necessary for the simple RNN cases I
// saw, so I decided to keep stuff simple. If we ever get around to
// implementing this, the right solution is probably to fuse MMs for the common
// part, and assume it's an input leaf for the outer two parts (I don't think
// it's beneficial to recompute, unless the subtree is super small, but let's
// not get into such details).

// The algorithm we're using is simple. We iterate through the graph in
// topological order, labeling nodes with TreeTokens. Then we look for roots of
// the trees we formed and fuse them.
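//
// For example (a hedged sketch in TorchScript IR with made-up value names,
// shown with only two mm leaves for brevity, even though the pass only fuses
// trees with at least min_fusion_size leaves):
//
//   %ab = aten::mm(%a, %b)
//   %cd = aten::mm(%c, %d)
//   %o = aten::add(%ab, %cd, %alpha)
//
// is rewritten into a single variadic node that takes all lhs operands
// followed by all rhs operands:
//
//   %o = prim::MMTreeReduce(%a, %c, %b, %d)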

// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr size_t min_fusion_size = 4;

bool have_same_shape(at::TensorList inputs) {
  auto expected_sizes = inputs[0].sizes();
  return (std::all_of(
      inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
        return t.sizes() == expected_sizes;
      }));
}

bool should_be_transposed(at::TensorList inputs) {
  return (std::all_of(inputs.begin(), inputs.end(), [](const at::Tensor& t) {
    return t.stride(0) == 1 && t.stride(1) == t.size(0);
  }));
}

std::vector<at::Tensor> transpose_inputs(at::TensorList inputs) {
  return fmap(inputs, [](const at::Tensor& i) { return i.t(); });
}

bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
  size_t l = lhs.size(0);
  size_t m = lhs.size(1);
  size_t r = rhs.size(1);
  // Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
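  // i.e. either the contraction dimension is small, or both output dimensions
  // are small, or both are large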
  return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
}

RegisterOperators mm_tree_reduction_reg({Operator(
    "prim::MMTreeReduce(...) -> Tensor",
    [](Stack& stack) {
      auto num_inputs = pop(stack).toInt();
      std::vector<at::Tensor> inputs;
      inputs.reserve(num_inputs);
      for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
        inputs.push_back(std::move(*it).toTensor());
      }
      drop(stack, num_inputs);

      AT_ASSERT(!inputs.empty());
      AT_ASSERT(inputs.size() % 2 == 0);
      size_t side_num_elems = inputs.size() / 2;
      auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
      auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
      // TODO: checking this is not free, so we should stop if this keeps
      // failing
      if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
          shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
        // Sometimes lhs_inputs or rhs_inputs are not contiguous, and that
        // causes at::cat to go through a slow path. View them as contiguous if
        // possible by transposing.
        bool lhs_input_transposed = should_be_transposed(lhs_inputs);
        bool rhs_input_transposed = should_be_transposed(rhs_inputs);
        at::Tensor lhs, rhs;
        if (lhs_input_transposed) {
          std::vector<at::Tensor> lhs_contig_inputs =
              transpose_inputs(lhs_inputs);
          lhs = at::cat(lhs_contig_inputs, /*dim=*/0);
          lhs = lhs.t();
        } else {
          lhs = at::cat(lhs_inputs, /*dim=*/1);
        }
        if (rhs_input_transposed) {
          std::vector<at::Tensor> rhs_contig_inputs =
              transpose_inputs(rhs_inputs);
          rhs = at::cat(rhs_contig_inputs, /*dim=*/1);
          rhs = rhs.t();
        } else {
          rhs = at::cat(rhs_inputs, /*dim=*/0);
        }
        push(stack, at::mm(lhs, rhs));
      } else {
        auto acc = at::mm(inputs[0], inputs[side_num_elems]);
        for (const auto i : c10::irange(1, side_num_elems)) {
          acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
        }
        push(stack, std::move(acc));
      }
    },
    aliasAnalysisIsSpecialCase())});

// TreeTokens will be used to label nodes of the graph, if the nodes fit our
// mm/add tree pattern. Basically we do dynamic programming on DAGs, where by
// the time we reach node N with inputs A and B, both A and B have already been
// processed, and we can try to unify their TreeTokens (if they have them) and
// build a larger tree.
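//
// For example (a hedged walk-through with made-up names), on a left-leaning
// tree of three mms the labeling proceeds as:
//
//   mm1, mm2, mm3      -> tokens with tree_size = 1, each marked as a root
//   add1 = mm1 + mm2   -> token with tree_size = 2; mm1, mm2 stop being roots
//   add2 = add1 + mm3  -> token with tree_size = 3; add1, mm3 stop being roots
//
// Only add2 survives as a root, so only it is considered for fusion below.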
struct TreeToken {
  uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
  Node* node = nullptr;
  bool is_root = false;

  static TreeToken mm(Node* mm) {
    TreeToken token;
    token.tree_size = 1;
    token.node = mm;
    token.is_root = true;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken transpose(Node* t, TreeToken& inp_token) {
    TreeToken token;
    if (!inp_token.node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return token;
    }
    token.tree_size = 1;
    token.node = t;
    token.is_root = true;
    inp_token.is_root = false;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
    TreeToken token;
    // See Note [Overlapping trees]
    if (&l == &r || !l.is_root || !r.is_root)
      return token;
    token.tree_size = l.tree_size + r.tree_size;
    token.node = add;
    token.is_root = true;
    // Reserve the subtrees, so they can't be used again.
    l.is_root = r.is_root = false;
    return token;
  }

  explicit operator bool() {
    return is_root;
  }

  std::vector<Node*> removeTransposesAndGatherMatmuls() {
    std::vector<Node*> matmuls;
    std::vector<Node*> queue{node};
    Graph* graph = node->owningGraph();
    while (!queue.empty()) {
      auto n = queue.back();
      queue.pop_back();
      if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
        matmuls.push_back(n);
      } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
        Node* input_node = n->input()->node();
        AT_ASSERT(input_node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
        // (AB)^T == B^T A^T
        WithInsertPoint insert_guard{input_node};
        Value* A = input_node->inputs()[0];
        Value* B = input_node->inputs()[1];
        Value* AT = graph->insert(aten::t, {A});
        Value* BT = graph->insert(aten::t, {B});
        Value* BTAT = graph->insert(aten::mm, {BT, AT});
        n->output()->replaceAllUsesWith(BTAT);
        matmuls.push_back(BTAT->node());
      } else if (
          n->matches(
              "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
        queue.push_back(n->inputs()[0]->node());
        queue.push_back(n->inputs()[1]->node());
      } else {
        AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
      }
    }
    return matmuls;
  }
};

enum class Side { LHS, RHS };

void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
  auto graph = block->owningGraph();

  // Look for trees in the block
  std::unordered_map<Node*, TreeToken> tokens;
  for (auto node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      tokens[node] = TreeToken::mm(node);
    } else if (
        node->matches("aten::t(Tensor self) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      auto input_it = tokens.find(node->input()->node());
      if (input_it != tokens.end()) {
        tokens[node] = TreeToken::transpose(node, input_it->second);
      }
    } else if (
        node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      Node* lhs = node->inputs()[0]->node();
      Node* rhs = node->inputs()[1]->node();
      auto lhs_it = tokens.find(lhs);
      auto rhs_it = tokens.find(rhs);
      // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
      // We could treat a subtree with multiple uses as if it were overlapping.
      // XXX: uses().size() == 1 is also something that guarantees that this
      // transform is valid, because we know for sure that none of these
      // operands depends on the result of the other. If we were to remove
      // this, we would need to compute a transitive closure and actually check
      // the dependencies.
      if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
          lhs->output()->uses().size() == 1 &&
          rhs->output()->uses().size() == 1) {
        if (auto token =
                TreeToken::add(node, lhs_it->second, rhs_it->second)) {
          tokens[node] = token;
        }
      }
    } else {
      for (auto block : node->blocks()) {
        BatchMMTreeReduce(block, alias_db);
      }
    }
  }

  // Merge trees we've found
  for (auto& item : tokens) {
    auto& root = item.second;
    if (!root || root.tree_size < min_fusion_size)
      continue;
    auto matmuls = root.removeTransposesAndGatherMatmuls();
    WithInsertPoint insert_guard{root.node};
    Node* tree_reduce =
        graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(0));
    }
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(1));
    }
    root.node->output()->replaceAllUsesWith(tree_reduce->output());
    // NB: don't bother with cleaning up after yourself. We'll use DCE for
    // that.
  }
}

bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
  // Cutoff chosen by benchmarking on a TITAN V
  return other_side_input.numel() <= 1024 * 2048;
}

RegisterOperators mm_batch_side_reg({Operator(
    prim::MMBatchSide,
    [](const Node* node) -> Operation {
      size_t num_other_side_inputs = node->inputs().size() - 1;
      Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
      return [num_other_side_inputs, single_side](Stack& stack) {
        at::Tensor side_input;
        std::vector<at::Tensor> other_side_inputs;
        other_side_inputs.reserve(num_other_side_inputs);
        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
             ++it) {
          other_side_inputs.push_back(std::move(*it).toTensor());
        }
        drop(stack, num_other_side_inputs);
        pop(stack, side_input);

        auto any_other_input = other_side_inputs[0];
        if (have_same_shape(other_side_inputs) &&
            shape_is_fast_for_side(other_side_inputs[0])) {
          auto other_side_input =
              at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
          auto mm_out = single_side == Side::LHS
              ? side_input.mm(other_side_input)
              : other_side_input.mm(side_input);
          auto outputs = at::chunk(
              mm_out,
              num_other_side_inputs,
              /*dim=*/single_side == Side::LHS ? 1 : 0);
          stack.insert(
              stack.end(),
              std::make_move_iterator(outputs.begin()),
              std::make_move_iterator(outputs.end()));
        } else {
          if (single_side == Side::LHS) {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(side_input.mm(other));
            }
          } else {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(other.mm(side_input));
            }
          }
        }
      };
    },
    aliasAnalysisIsSpecialCase())});

std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
    Value* value,
    AliasDb& alias_db) {
  const auto postprocess = [&](std::vector<Node*> mms) {
    if (mms.empty()) {
      return mms;
    }
    std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
      return n->isBefore(m);
    });
    // Filter out dependent MMs. This algorithm might do very badly if e.g. you
    // have a lot of otherwise-independent MMs that all depend on the first
    // one, but I doubt this will be a common scenario.
    for (const auto i : c10::irange(mms.size())) {
      if (mms[i] == nullptr)
        continue;
      for (size_t j = i + 1; j < mms.size(); ++j) {
        if (mms[j] == nullptr)
          continue;
        if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
          mms[j] = nullptr;
        }
      }
    }
    return c10::filter(mms, [](Node* n) { return n != nullptr; });
  };

  Block* block = value->node()->owningBlock();
  std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
  std::vector<Node*> rhses; // Like above, but rhs
  for (Use u : value->uses()) {
    if (u.user->owningBlock() == block &&
        u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(u.user)) {
      if (u.offset == 0 && u.user->inputs()[1] != value) {
        lhses.push_back(u.user);
      } else if (u.offset == 1 && u.user->inputs()[0] != value) {
        rhses.push_back(u.user);
      }
    }
  }
  return std::make_pair(
      postprocess(std::move(lhses)), postprocess(std::move(rhses)));
}

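// For example (a hedged IR sketch with made-up names), if a single weight %W
// is used as the lhs of many independent mms:
//
//   %y1 = aten::mm(%W, %x1)
//   ...
//   %yk = aten::mm(%W, %xk)
//
// BatchMMSide rewrites them into one node whose "side" attribute records
// which operand is the shared one:
//
//   %y1, ..., %yk = prim::MMBatchSide(%W, %x1, ..., %xk)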
void BatchMMSide(Block* block, AliasDb& alias_db) {
  // NB: 8 is the current loop unrolling factor
  static constexpr size_t how_many_is_many = 8;
  const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
    AT_ASSERT(!mms.empty());
    for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
      bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
      AT_ASSERT(move_ok);
    }
    WithInsertPoint insert_guard{mms[0]};
    Graph* graph = mms[0]->owningGraph();
    Node* batch_mm = graph->create(
        prim::MMBatchSide,
        /*inputs=*/{},
        /*num_outputs=*/mms.size());
    graph->insertNode(batch_mm);
    batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
    Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
    batch_mm->addInput(const_side);
    for (const auto i : c10::irange(mms.size())) {
      batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
      mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
    }
  };

  std::unordered_set<Value*> considered_values;
  for (Node* node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      for (Value* input : node->inputs()) {
        if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
          continue;
        }
        auto uses_with_many = gatherIndependentMMUses(input, alias_db);
        if (uses_with_many.first.size() >= how_many_is_many) {
          batch_side(uses_with_many.first, Side::LHS);
        }
        if (uses_with_many.second.size() >= how_many_is_many) {
          batch_side(uses_with_many.second, Side::RHS);
        }
      }
    } else {
      for (Block* subblock : node->blocks()) {
        BatchMMSide(subblock, alias_db);
      }
    }
  }
}

bool hasMutableOperators(Block* block) {
  for (auto n : block->nodes()) {
    if (n->kind().is_aten() && n->schema().is_mutable())
      return true;
    for (auto b : n->blocks()) {
      if (hasMutableOperators(b))
        return true;
    }
  }
  return false;
}

bool hasMMOperators(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return true;
    }
  }
  return false;
}

void BatchMM(std::shared_ptr<Graph>& graph) {
  if (!hasMMOperators(graph)) {
    return;
  }
  AliasDb alias_db(graph);
  BatchMMTreeReduce(graph->block(), alias_db);
  BatchMMSide(graph->block(), alias_db);
  EliminateDeadCode(graph);
  // It's possible that transpose rearrangements have created sequences of
  // consecutive transposes that didn't exist before; the peephole pass below
  // cleans those up.

  // Tensor type properties are not guaranteed to be correct at this point, so
  // disable shape-dependent peepholes.
  PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
}

} // namespace jit
} // namespace torch