1#include <torch/csrc/jit/ir/ir.h>
2#include <torch/csrc/jit/ir/ir_views.h>
3#include <torch/csrc/jit/jit_log.h>
4#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
5#include <torch/csrc/jit/passes/utils/optimization_utils.h>
6#include <torch/csrc/jit/runtime/graph_executor.h>
7#include <torch/csrc/jit/runtime/graph_iterator.h>
8
9#ifndef AT_PER_OPERATOR_HEADERS
10#include <ATen/Functions.h>
11#else
12#include <ATen/ops/transpose.h>
13#endif
14
15#include <iostream>
16#include <utility>
17
18namespace torch {
19namespace jit {
20namespace {
21
22using Tensor = at::Tensor;
23
24class TransposeFrozenLinear {
25 public:
26 TransposeFrozenLinear(std::shared_ptr<Graph> graph)
27 : graph_(std::move(graph)) {}
28
29 bool run() {
30 // Can't delete nodes while also iterating over it
31 DepthFirstGraphNodeIterator graph_it(graph_);
32
33 for (auto next_node = graph_it.next(); next_node != nullptr;) {
34 Node* node = next_node;
35 next_node = graph_it.next();
36
37 if (is_constant_linear_op(node)) {
38 replace_linear_with_matmul(node);
39 }
40 }
41 return graph_modified_;
42 }
43
44 bool is_constant_linear_op(Node* node) {
45 if (node->kind() != aten::linear) {
46 return false;
47 }
48
49 // This also filters out out-variants of the linear op.
50 return !nonConstantParameters(node);
51 }
52
53 void replace_linear_with_matmul(Node* node) {
54 graph_modified_ = true;
55 Node* matmul = nullptr;
56
57 {
58 WithInsertPoint insert_guard(node);
59 auto weight = node->namedInput("weight");
60
61 Tensor weight_tensor = constant_as<Tensor>(weight).value();
62 Tensor weight_t_tensor = at::transpose(weight_tensor, 1, 0)
63 .clone(at::MemoryFormat::Contiguous);
64 Value* weight_t = graph_->insertConstant(std::move(weight_t_tensor));
65 matmul = graph_->create(aten::matmul, {node->inputs()[0], weight_t});
66 matmul->insertAfter(node);
67 }
68
69 // Handle a bias if there is any
70 WithInsertPoint insert_guard(matmul);
71 auto bias = node->namedInput("bias");
72 if (bias->type() == NoneType::get()) {
73 node->replaceAllUsesWith(matmul);
74 } else {
75 Value* bias_scale = graph_->insertConstant(1);
76 Node* bias_result =
77 graph_->create(aten::add, {matmul->output(), bias, bias_scale});
78 bias_result->insertAfter(matmul);
79 node->replaceAllUsesWith(bias_result);
80 }
81 node->destroy();
82 };
83
84 void handleBlockAndSubblocks(Block* block) {}
85
86 private:
87 std::shared_ptr<Graph> graph_;
88 bool graph_modified_ = false;
89};
90} // namespace
91
92TORCH_API bool FrozenLinearTranspose(std::shared_ptr<Graph>& graph) {
93 TransposeFrozenLinear transposeWeight(graph);
94 GRAPH_DUMP("Before FrozenLinearTranspose", graph);
95 bool changed = transposeWeight.run();
96 if (changed) {
97 GRAPH_DUMP("After FrozenLinearTranspose", graph);
98 }
99 return changed;
100}
101
102} // namespace jit
103} // namespace torch
104