#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_concat_linear.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/cat.h>
#endif

#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace {

using Tensor = at::Tensor;

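// Concatenates the constant weights/biases of frozen `aten::linear` layers
// that share the same input tensor into a single, wider linear layer, and
// then slices that layer's output back apart for each original use.
//
// Roughly (a sketch, not the exact IR the pass emits):
//   %y1 = aten::linear(%x, %w1, %b1)   // %w1: [n1, k]
//   %y2 = aten::linear(%x, %w2, %b2)   // %w2: [n2, k]
// becomes
//   %y  = aten::linear(%x, cat(%w1, %w2, dim=0), cat(%b1, %b2, dim=0))
//   %y1 = aten::slice(%y, dim=-1, start=0,  end=n1,    step=1)
//   %y2 = aten::slice(%y, dim=-1, start=n1, end=n1+n2, step=1)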
class ConcatLinearLayers {
 public:
  explicit ConcatLinearLayers(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    handleBlockAndSubblocks(graph_->block());
    return graph_modified;
  }

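  // Build the AliasDb lazily and cache it; it is only needed once we start
  // checking whether candidate nodes can be moved next to each other.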
  AliasDb* getAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

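  // Walk `b` and group every eligible `aten::linear` node (constant, CUDA
  // weight and bias) by the Value it takes as input. `ordered_tensor_inputs`
  // records the input Values in the order they are first seen so the caller
  // can process them deterministically.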
  void collectConstantLinearLayers(
      Block* b,
      std::unordered_map<Value*, std::vector<Node*>>& grouped_linear_layers,
      std::vector<Value*>& ordered_tensor_inputs) {
    // We use an ordered list so that we only have to check whether moving
    // items forward is a valid move, not backwards. Otherwise we would need
    // to rebuild the AliasDb whenever we add values.

    for (Node* n : b->nodes()) {
      // Group together all linear layers that use the same Tensor as input.
      if (n->kind() != aten::linear) {
        continue;
      }

      auto weight = n->namedInput("weight");
      auto bias = n->namedInput("bias");
      if (weight->type() == NoneType::get() ||
          bias->type() == NoneType::get()) {
        continue;
      }

      if (nonConstantParameters(n)) {
        continue;
      }
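      // Materialize the constant weight and only consider layers whose
      // weight lives on a CUDA device.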
      auto weight_tensor = constant_as<Tensor>(weight).value();
      if (!weight_tensor.device().is_cuda()) {
        continue;
      }

      Value* linear_input = n->inputs().at(0);
      if (grouped_linear_layers.find(linear_input) ==
          grouped_linear_layers.cend()) {
        grouped_linear_layers.insert({linear_input, std::vector<Node*>()});
        ordered_tensor_inputs.push_back(linear_input);
      }
      grouped_linear_layers.find(linear_input)->second.push_back(n);
    }
  }

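  // Replace every node in `compatible_layers` with a single `aten::linear`
  // whose weight/bias are the originals concatenated along dim 0; each old
  // output is then recovered by slicing the combined output.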
  void mergeLinearLayers(std::vector<Node*>& compatible_layers) {
    graph_modified = true;
    TORCH_INTERNAL_ASSERT(!compatible_layers.empty());
    Node* base_node = compatible_layers[0];

    // Scope needed to make sure we free the WithInsertPoint guard
    // and reset the insert point before we delete `base_node`.
    Node* linear_node = nullptr;
    {
      WithInsertPoint guard(base_node);
      auto weight_list = c10::fmap(compatible_layers, [](Node* n) {
        return constant_as<Tensor>(n->namedInput("weight")).value();
      });
      Tensor cat_weight = at::cat(weight_list, /*dim=*/0);
      Value* cat_weight_value = graph_->insertConstant(std::move(cat_weight));

      auto bias_list = c10::fmap(compatible_layers, [](Node* n) {
        return constant_as<Tensor>(n->namedInput("bias")).value();
      });
      Tensor cat_bias = at::cat(bias_list, /*dim=*/0);
      Value* cat_bias_value = graph_->insertConstant(std::move(cat_bias));

      auto tensor_input = base_node->inputs().at(0);
      std::vector<Value*> linear_in = {
          tensor_input, cat_weight_value, cat_bias_value};
      linear_node = graph_->create(aten::linear, linear_in);
      linear_node->insertBefore(base_node);
    }

    // Update the outputs of the nodes
    WithInsertPoint guard2(linear_node);
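    // Constants shared by all of the `aten::slice` nodes created below:
    // dim = -1 (the output feature dimension) and step = 1.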
    Value* neg1 = graph_->insertConstant(-1);
    Value* one = graph_->insertConstant(1);

    int64_t slice_start = 0;
    Value* slice_start_val = graph_->insertConstant(0);

    for (Node* orig_node : compatible_layers) {
      // For each node in the compatible_layers list, slice the output of the
      // combined linear layer and use it instead of the output of the
      // original node.

      Tensor weight_tensor =
          constant_as<Tensor>(orig_node->namedInput("weight")).value();
      int64_t slice_end = slice_start + weight_tensor.size(0);
      Value* slice_end_val = graph_->insertConstant(slice_end);

      Node* slice = graph_->create(
          aten::slice,
          {linear_node->output(), neg1, slice_start_val, slice_end_val, one});
      slice->insertAfter(linear_node);
      orig_node->replaceAllUsesWith(slice);
      orig_node->destroy();

      slice_start = slice_end;
      slice_start_val = slice_end_val;
    }
  }

  bool isNonZeroDimEqual(Tensor& tensor_a, Tensor& tensor_b) {
    if (tensor_a.dim() != tensor_b.dim()) {
      return false;
    }
    for (int64_t i = 1; i < tensor_a.dim(); i++) {
      if (tensor_a.size(i) != tensor_b.size(i)) {
        return false;
      }
    }
    return true;
  }

  // Check the linear_layer_group of a tensor to find ones that can be
  // combined.
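  // Greedily grows a group of layers whose weights/biases have matching
  // dtype, device, and non-zero-dim shapes and which the AliasDb allows to
  // be moved together, then merges each group with more than one member.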
  void collectAndMergeLinearLayers(std::vector<Node*>& linear_layer_group) {
    std::unordered_set<Node*> checked_nodes;

    for (size_t i = 0; i < linear_layer_group.size(); i++) {
      Node* base_node = linear_layer_group[i];
      if (checked_nodes.count(base_node) != 0) {
        continue;
      }

      std::vector<Node*> compatible_layers;
      compatible_layers.push_back(base_node);

      auto base_weight =
          constant_as<Tensor>(base_node->namedInput("weight")).value();
      auto base_bias =
          constant_as<Tensor>(base_node->namedInput("bias")).value();

      // Now iterate over the rest of the group to see if there is anything
      // that we can coalesce `base_node` with.
      for (size_t j = i + 1; j < linear_layer_group.size(); j++) {
        auto node = linear_layer_group[j];
        if (checked_nodes.count(node) != 0) {
          continue;
        }
        auto weight = constant_as<Tensor>(node->namedInput("weight")).value();
        auto bias = constant_as<Tensor>(node->namedInput("bias")).value();

        // For now we will just keep it simple and require matching types.
        // Type promotion might cause performance to actually decrease.
        if (base_weight.dtype() != weight.dtype() ||
            base_weight.device() != weight.device() ||
            base_bias.dtype() != bias.dtype() ||
            base_bias.device() != bias.device()) {
          continue;
        }

        if (!isNonZeroDimEqual(base_weight, weight) ||
            !isNonZeroDimEqual(base_bias, bias)) {
          continue;
        }

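        // The candidate must be safely movable (per the AliasDb) before every
        // node already in the group, since the whole group is replaced by a
        // single linear op inserted at `base_node`'s position.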
        bool can_move_before_all = true;
        for (auto n : compatible_layers) {
          can_move_before_all &=
              getAliasDb()->couldMoveBeforeTopologically(node, n);
        }
        if (!can_move_before_all) {
          continue;
        }

        // Found a node that is eligible for combination
        compatible_layers.push_back(node);
        checked_nodes.insert(node);
      }
      if (compatible_layers.size() == 1) {
        continue; // No other layers to merge
      }
      mergeLinearLayers(compatible_layers);
    }
  }

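  // Recurse into sub-blocks first, then collect and merge the eligible
  // linear layers of this block.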
  void handleBlockAndSubblocks(Block* block) {
    for (auto node : block->nodes()) {
      for (Block* subblock : node->blocks()) {
        handleBlockAndSubblocks(subblock);
      }
    }

    // Process the block itself.
    std::unordered_map<Value*, std::vector<Node*>> grouped_linear_layers;
    std::vector<Value*> ordered_tensor_inputs;
    collectConstantLinearLayers(
        block, grouped_linear_layers, ordered_tensor_inputs);

    // Reverse topological ordering is used to avoid having to update the
    // AliasDb.
    for (auto tensor_it = ordered_tensor_inputs.rbegin();
         tensor_it != ordered_tensor_inputs.rend();
         ++tensor_it) {
      collectAndMergeLinearLayers(grouped_linear_layers.at(*tensor_it));
    }
  }

 private:
  std::shared_ptr<Graph> graph_;
  bool graph_modified = false;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
};
} // namespace

TORCH_API bool FrozenConcatLinear(std::shared_ptr<Graph>& graph) {
  ConcatLinearLayers concatLayers(graph);
  GRAPH_DUMP("Before FrozenConcatLinear", graph);
  bool changed = concatLayers.run();
  if (changed) {
    GRAPH_DUMP("After FrozenConcatLinear", graph);
  }
  return changed;
}

} // namespace jit
} // namespace torch