#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/transpose.h>
#endif

#include <iostream>
#include <utility>
namespace torch {
namespace jit {
namespace {

using Tensor = at::Tensor;

24 | class TransposeFrozenLinear { |
25 | public: |
26 | TransposeFrozenLinear(std::shared_ptr<Graph> graph) |
27 | : graph_(std::move(graph)) {} |
28 | |
29 | bool run() { |
30 | // Can't delete nodes while also iterating over it |
31 | DepthFirstGraphNodeIterator graph_it(graph_); |
32 | |
33 | for (auto next_node = graph_it.next(); next_node != nullptr;) { |
34 | Node* node = next_node; |
35 | next_node = graph_it.next(); |
36 | |
37 | if (is_constant_linear_op(node)) { |
38 | replace_linear_with_matmul(node); |
39 | } |
40 | } |
41 | return graph_modified_; |
42 | } |
43 | |
44 | bool is_constant_linear_op(Node* node) { |
45 | if (node->kind() != aten::linear) { |
46 | return false; |
47 | } |
48 | |
49 | // This also filters out out-variants of the linear op. |
50 | return !nonConstantParameters(node); |
51 | } |
52 | |
53 | void replace_linear_with_matmul(Node* node) { |
54 | graph_modified_ = true; |
55 | Node* matmul = nullptr; |
56 | |
57 | { |
58 | WithInsertPoint insert_guard(node); |
59 | auto weight = node->namedInput("weight"); |
60 | |
61 | Tensor weight_tensor = constant_as<Tensor>(weight).value(); |
62 | Tensor weight_t_tensor = at::transpose(weight_tensor, 1, 0) |
63 | .clone(at::MemoryFormat::Contiguous); |
64 | Value* weight_t = graph_->insertConstant(std::move(weight_t_tensor)); |
65 | matmul = graph_->create(aten::matmul, {node->inputs()[0], weight_t}); |
66 | matmul->insertAfter(node); |
67 | } |
68 | |
69 | // Handle a bias if there is any |
70 | WithInsertPoint insert_guard(matmul); |
71 | auto bias = node->namedInput("bias"); |
72 | if (bias->type() == NoneType::get()) { |
73 | node->replaceAllUsesWith(matmul); |
74 | } else { |
75 | Value* bias_scale = graph_->insertConstant(1); |
76 | Node* bias_result = |
77 | graph_->create(aten::add, {matmul->output(), bias, bias_scale}); |
78 | bias_result->insertAfter(matmul); |
79 | node->replaceAllUsesWith(bias_result); |
80 | } |
81 | node->destroy(); |
82 | }; |
83 | |
84 | void handleBlockAndSubblocks(Block* block) {} |
85 | |
86 | private: |
87 | std::shared_ptr<Graph> graph_; |
88 | bool graph_modified_ = false; |
89 | }; |
} // namespace

92 | TORCH_API bool FrozenLinearTranspose(std::shared_ptr<Graph>& graph) { |
93 | TransposeFrozenLinear transposeWeight(graph); |
94 | GRAPH_DUMP("Before FrozenLinearTranspose", graph); |
95 | bool changed = transposeWeight.run(); |
96 | if (changed) { |
97 | GRAPH_DUMP("After FrozenLinearTranspose", graph); |
98 | } |
99 | return changed; |
100 | } |

} // namespace jit
} // namespace torch