#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_concat_linear.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/cat.h>
#endif

#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace {

using Tensor = at::Tensor;

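// Merges frozen aten::linear layers that share the same input tensor into a
// single linear op with concatenated weight and bias, then slices the combined
// output back apart for each original consumer. One larger GEMM is typically
// cheaper than several small ones; the pass currently only considers constant
// CUDA weights (see collectConstantLinearLayers below).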
class ConcatLinearLayers {
 public:
  explicit ConcatLinearLayers(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    handleBlockAndSubblocks(graph_->block());
    return graph_modified;
  }

  AliasDb* getAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  void collectConstantLinearLayers(
      Block* b,
      std::unordered_map<Value*, std::vector<Node*>>& grouped_linear_layers,
      std::vector<Value*>& ordered_tensor_inputs) {
    // We are using an ordered list so that we only have to
    // check if moving items forward is a valid move, not
    // backwards. Otherwise we would need to rebuild the aliasDb when we add
    // values.

    for (Node* n : b->nodes()) {
      // Grouping together all linear layers that use the same Tensor for input
      if (n->kind() != aten::linear) {
        continue;
      }

      auto weight = n->namedInput("weight");
      auto bias = n->namedInput("bias");
      if (weight->type() == NoneType::get() ||
          bias->type() == NoneType::get()) {
        continue;
      }

      if (nonConstantParameters(n)) {
        continue;
      }
      auto weight_tensor = constant_as<Tensor>(weight).value();
      if (!weight_tensor.device().is_cuda()) {
        continue;
      }

      Value* linear_input = n->inputs().at(0);
      if (grouped_linear_layers.find(linear_input) ==
          grouped_linear_layers.cend()) {
        grouped_linear_layers.insert({linear_input, std::vector<Node*>()});
        ordered_tensor_inputs.push_back(linear_input);
      }
      grouped_linear_layers.find(linear_input)->second.push_back(n);
    }
  }

  void mergeLinearLayers(std::vector<Node*>& compatible_layers) {
    graph_modified = true;
    assert(!compatible_layers.empty());
    Node* base_node = compatible_layers[0];

    // Scope needed to make sure we free the WithInsertPoint guard
    // and reset the insert point before we delete `base_node`
    Node* linear_node = nullptr;
    {
      WithInsertPoint guard(base_node);
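      // aten::linear stores its weight as [out_features, in_features], so
      // concatenating the weights along dim 0 stacks each layer's output
      // features while keeping the shared input dimension unchanged.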
      auto weight_list = c10::fmap(compatible_layers, [](Node* n) {
        return constant_as<Tensor>(n->namedInput("weight")).value();
      });
      Tensor cat_weight = at::cat(weight_list, /*dim=*/0);
      Value* cat_weight_value = graph_->insertConstant(std::move(cat_weight));

      auto bias_list = c10::fmap(compatible_layers, [](Node* n) {
        return constant_as<Tensor>(n->namedInput("bias")).value();
      });
      Tensor cat_bias = at::cat(bias_list, /*dim=*/0);
      Value* cat_bias_value = graph_->insertConstant(std::move(cat_bias));

      auto tensor_input = base_node->inputs().at(0);
      std::vector<Value*> linear_in = {
          tensor_input, cat_weight_value, cat_bias_value};
      linear_node = graph_->create(aten::linear, linear_in);
      linear_node->insertBefore(base_node);
    }

    // Update the outputs of the nodes
    WithInsertPoint guard2(linear_node);
    Value* neg1 = graph_->insertConstant(-1);
    Value* one = graph_->insertConstant(1);

    int64_t slice_start = 0;
    Value* slice_start_val = graph_->insertConstant(0);

    for (Node* orig_node : compatible_layers) {
      // For each node in the compatible_layers list,
      // slice the output of the combined linear layer
      // and use it instead of the output of the original node.
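      // (e.g. if the merged layers had out_features 3 and 5, the fused output
      // has size 8 in its last dimension: the first original output becomes
      // output[..., 0:3] and the second becomes output[..., 3:8]).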

      Tensor weight_tensor =
          constant_as<Tensor>(orig_node->namedInput("weight")).value();
      int64_t slice_end = slice_start + weight_tensor.size(0);
      Value* slice_end_val = graph_->insertConstant(slice_end);

      Node* slice = graph_->create(
          aten::slice,
          {linear_node->output(), neg1, slice_start_val, slice_end_val, one});
      slice->insertAfter(linear_node);
      orig_node->replaceAllUsesWith(slice);
      orig_node->destroy();

      slice_start = slice_end;
      slice_start_val = slice_end_val;
    }
  }

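  // Weights (and biases) can be concatenated along dim 0 only if all of their
  // remaining dimensions match, so compare every dimension except the first.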
  bool isNonZeroDimEqual(Tensor& tensor_a, Tensor& tensor_b) {
    if (tensor_a.dim() != tensor_b.dim()) {
      return false;
    }
    for (int64_t i = 1; i < tensor_a.dim(); i++) {
      if (tensor_a.size(i) != tensor_b.size(i)) {
        return false;
      }
    }
    return true;
  }

  // Check the linear_layer_group of a tensor to find ones that can be
  // combined
  void collectAndMergeLinearLayers(std::vector<Node*>& linear_layer_group) {
    std::unordered_set<Node*> checked_nodes;

    for (size_t i = 0; i < linear_layer_group.size(); i++) {
      Node* base_node = linear_layer_group[i];
      if (checked_nodes.count(base_node) != 0) {
        continue;
      }

      std::vector<Node*> compatible_layers;
      compatible_layers.push_back(base_node);

      auto base_weight =
          constant_as<Tensor>(base_node->namedInput("weight")).value();
      auto base_bias =
          constant_as<Tensor>(base_node->namedInput("bias")).value();

      // Now iterate over the rest of the users of the set to
      // see if there is anything that we can coalesce `base_node` with.
      for (size_t j = i + 1; j < linear_layer_group.size(); j++) {
        auto node = linear_layer_group[j];
        if (checked_nodes.count(node) != 0) {
          continue;
        }
        auto weight = constant_as<Tensor>(node->namedInput("weight")).value();
        auto bias = constant_as<Tensor>(node->namedInput("bias")).value();

        // For now we will just keep it simple and require matching types
        // Type promotion might cause performance to actually decrease.
        if (base_weight.dtype() != weight.dtype() ||
            base_weight.device() != weight.device() ||
            base_bias.dtype() != bias.dtype() ||
            base_bias.device() != bias.device()) {
          continue;
        }

        if (!isNonZeroDimEqual(base_weight, weight) ||
            !isNonZeroDimEqual(base_bias, bias)) {
          continue;
        }

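        // Merging will effectively move every layer in the group up to
        // `base_node`'s position, so only accept `node` if alias analysis
        // says it could be moved before each layer collected so far.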
        bool can_move_before_all = true;
        for (auto n : compatible_layers) {
          can_move_before_all &=
              getAliasDb()->couldMoveBeforeTopologically(node, n);
        }
        if (!can_move_before_all) {
          continue;
        }

        // Found a node that is eligible for combination
        compatible_layers.push_back(node);
        checked_nodes.insert(node);
      }
      if (compatible_layers.size() == 1) {
        continue; // No other layers to merge
      }
      mergeLinearLayers(compatible_layers);
    }
  }

  void handleBlockAndSubblocks(Block* block) {
    for (auto node : block->nodes()) {
      for (Block* subblock : node->blocks()) {
        handleBlockAndSubblocks(subblock);
      }
    }

    // Processing for the block itself
    std::unordered_map<Value*, std::vector<Node*>> grouped_linear_layers;
    std::vector<Value*> ordered_tensor_inputs;
    collectConstantLinearLayers(
        block, grouped_linear_layers, ordered_tensor_inputs);

    // Reverse topological ordering is used to prevent the need to
    // update the aliasDb
    for (auto tensor_it = ordered_tensor_inputs.rbegin();
         tensor_it != ordered_tensor_inputs.rend();
         ++tensor_it) {
      collectAndMergeLinearLayers(grouped_linear_layers.at(*tensor_it));
    }
  }

 private:
  std::shared_ptr<Graph> graph_;
  bool graph_modified = false;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
};
} // namespace

TORCH_API bool FrozenConcatLinear(std::shared_ptr<Graph>& graph) {
  ConcatLinearLayers concatLayers(graph);
  GRAPH_DUMP("Before FrozenConcatLinear", graph);
  bool changed = concatLayers.run();
  if (changed) {
    GRAPH_DUMP("After FrozenConcatLinear", graph);
  }
  return changed;
}

} // namespace jit
} // namespace torch