1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/common_runtime/constant_folding.h" |
19 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
20 | #include "tensorflow/core/common_runtime/threadpool_device.h" |
21 | #include "tensorflow/core/graph/node_builder.h" |
22 | #include "tensorflow/core/graph/subgraph.h" |
23 | #include "tensorflow/core/kernels/quantization_utils.h" |
24 | #include "tensorflow/core/platform/init_main.h" |
25 | #include "tensorflow/core/public/session.h" |
26 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace graph_transforms { |
30 | |
// Holds the information we need to translate from a float version of this op
// into the quantized equivalent.
struct QuantizedOpInfo {
  // The name of the float op.
  string float_name;
  // Which attributes to copy directly over from the float node to the
  // quantized node (e.g. "strides", "padding").
  std::vector<string> attrs_to_copy;
  // Extra data type attributes we need to set on the quantized node, as
  // (attribute name, type) pairs.
  std::vector<std::pair<string, DataType>> dtypes_to_set;
  // What depth of inputs the op can read in (e.g. DT_QUINT8).
  DataType input_bit_depth;
  // The depth of the op's quantized outputs (e.g. DT_QINT32 for ops that
  // accumulate, like Conv2D or MatMul).
  DataType output_bit_depth;
  // Which inputs (e.g. shapes) aren't involved in the quantization process,
  // identified by their zero-based input index.
  std::set<int32> unquantized_inputs;
  // How the outputs are arranged, either
  // [input0, input1, min0, max0, min1, max1] for contiguous, or
  // [input0, input1, min0, min1, max0, max1] for separate.
  // The separate order is needed because it's the only way to specify unknown
  // numbers of inputs for ops like Concat.
  enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
};
53 | |
// Every op that has a quantized equivalent should be listed here, so that the
// conversion process can transform them.
// Each entry's braced fields follow QuantizedOpInfo's declaration order:
// {float_name, attrs_to_copy, dtypes_to_set, input_bit_depth,
//  output_bit_depth, unquantized_inputs, min_max_order}.
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
  // Built once on first use; callers keep a reference to this static list.
  static const std::vector<QuantizedOpInfo> op_list = {
      {"Add",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"AvgPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"BiasAdd",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      // Concat takes a variable number of inputs (attribute "N"), so it needs
      // the separate min/max ordering. Input 0 (presumably the concat
      // dimension) is left unquantized.
      {"Concat",
       {"N"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {0},
       QuantizedOpInfo::SEPARATE_MIN_MAX},
      {"Conv2D",
       {"strides", "padding"},
       {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MatMul",
       {"transpose_a", "transpose_b"},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MaxPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Mul",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      // Input 1 (presumably the output size tensor) is a shape, not data, so
      // it stays unquantized.
      {"ResizeBilinear",
       {"align_corners"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu6",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      // Input 1 (presumably the new shape) stays unquantized.
      {"Reshape",
       {},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
  };
  return op_list;
}
145 | |
146 | namespace { |
147 | // Replaces invalid characters in input names to get a unique node name. |
148 | string UniqueNodeNameFromInput(const string& input_name) { |
149 | string prefix; |
150 | string node_name; |
151 | string suffix; |
152 | NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix); |
153 | string result; |
154 | if (prefix == "^" ) { |
155 | result += "__hat__" ; |
156 | } |
157 | result += node_name; |
158 | if (!suffix.empty()) { |
159 | result += "__port__" + suffix.substr(1, suffix.size() - 1); |
160 | } |
161 | return result; |
162 | } |
163 | |
164 | // Pulls two float values from the named parameters, with a lot of checking. |
165 | Status (const TransformFuncContext& context, |
166 | const string& min_name, const string& max_name, |
167 | float* min_value, float* max_value, |
168 | bool* has_range) { |
169 | // See if we've been given quantized inputs with a known range. |
170 | const bool has_min = (context.params.count(min_name) != 0); |
171 | const bool has_max = (context.params.count(max_name) != 0); |
172 | *has_range = (has_min || has_max); |
173 | if (!*has_range) { |
174 | return OkStatus(); |
175 | } |
176 | if (!has_min || !has_max) { |
177 | return errors::InvalidArgument("You must pass both " , min_name, " and " , |
178 | max_name, " into quantize_nodes" ); |
179 | } |
180 | TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value)); |
181 | TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value)); |
182 | return OkStatus(); |
183 | } |
184 | |
185 | } // namespace |
186 | |
// Analyzes all the nodes in the graph to figure out which ones are duplicates
// apart from their names. This commonly includes identical Const nodes, but can
// also be simple operations that are repeated on multiple outputs of a
// particular node. The complexity is managed using a hash function that avoids
// the need for any O(n^2) algorithms when identifying duplicates.
Status MergeDuplicateNodes(const GraphDef& input_graph_def,
                           const TransformFuncContext& context,
                           GraphDef* output_graph_def) {
  // Make sure we can look up inputs and outputs quickly.
  std::set<string> input_names(context.input_names.begin(),
                               context.input_names.end());
  std::set<string> output_names(context.output_names.begin(),
                                context.output_names.end());
  GraphDef current_graph_def = input_graph_def;
  // Keep running the merging until no more duplicates are found: removing one
  // layer of duplicates can make nodes that referred to them become identical
  // in turn on the next iteration.
  bool any_duplicates_found;
  do {
    any_duplicates_found = false;
    // First arrange all of the nodes by a hash of their contents.
    std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
    for (const NodeDef& node : current_graph_def.node()) {
      NodeDef nameless_node = node;
      // The name matters if it's being used as an input or output node,
      // otherwise ignore it when looking for duplicates.
      if (!input_names.count(node.name()) && !output_names.count(node.name())) {
        nameless_node.set_name("");
      }
      const uint64 hash = HashNodeDef(nameless_node);
      hashed_nodes[hash].push_back(&node);
    }
    // If we have multiple nodes with the same hash, then we know they're
    // duplicates and can be removed, unless they're stateful.
    std::map<string, string> inputs_to_rename;
    GraphDef merged_graph_def;
    for (const std::pair<const uint64, std::vector<const NodeDef*>>&
             hashed_node_info : hashed_nodes) {
      const std::vector<const NodeDef*>& hash_node_list =
          hashed_node_info.second;
      for (int i = 0; i < hash_node_list.size(); ++i) {
        const NodeDef* current_node = hash_node_list[i];
        const OpDef* op_def = nullptr;
        TF_RETURN_IF_ERROR(
            OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
        // The first node in each bucket is always kept as the survivor, and
        // stateful ops are never merged even when their contents match.
        const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
        if (is_duplicate) {
          // Redirect every output port (the ":*" wildcard) of the duplicate
          // to the surviving node, and drop the duplicate from the graph.
          const string original_name = hash_node_list[0]->name();
          inputs_to_rename[current_node->name() + ":*"] = original_name;
          any_duplicates_found = true;
        } else {
          NodeDef* new_node = merged_graph_def.mutable_node()->Add();
          *new_node = *current_node;
        }
      }
    }
    // Update the graph so that any nodes that referred to removed inputs now
    // pull from the remaining duplicate.
    TF_RETURN_IF_ERROR(RenameNodeInputs(merged_graph_def, inputs_to_rename,
                                        std::unordered_set<string>(),
                                        &current_graph_def));
  } while (any_duplicates_found);

  *output_graph_def = current_graph_def;

  return OkStatus();
}
252 | |
// Looks for the patterns that indicate there are two eight-bit ops feeding into
// each other, separated by a conversion up to float and back again. These occur
// during the initial conversion of ops to their quantized forms. Because we're
// only looking at an individual op in that phase and don't know if its inputs
// and outputs are eight-bit-capable, we start by converting the actual op into
// quantized form, but add float conversions before and after. This pass gets
// rid of those conversions if it turns out we do have adjacent ops capable of
// eight-bit processing.
Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
                                    const TransformFuncContext& context,
                                    GraphDef* output_graph_def) {
  // Record the graph-level output node names, so we never delete a Dequantize
  // whose float result is an overall output.
  std::set<string> graph_outputs;
  for (const string& output_name : context.output_names) {
    graph_outputs.insert(NodeNameFromInput(output_name));
  }
  std::map<string, string> inputs_to_rename;
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"QuantizeV2",
        {
          {"Dequantize"},
          {"Min"},
          {"Max"},
        }
      },  // clang-format on
      [&inputs_to_rename, &graph_outputs](const NodeMatch& match,
                                          const std::set<string>& input_nodes,
                                          const std::set<string>& output_nodes,
                                          std::vector<NodeDef>* new_nodes) {
        const NodeDef& quantize_node = match.node;
        const NodeDef& dequantize_node = match.inputs[0].node;
        // Redirect consumers of the QuantizeV2's three outputs (value, min,
        // max ports) straight to the Dequantize's corresponding quantized
        // inputs, bypassing the float round-trip entirely.
        inputs_to_rename[quantize_node.name() + ":0"] =
            dequantize_node.input(0);
        inputs_to_rename[quantize_node.name() + ":1"] =
            dequantize_node.input(1);
        inputs_to_rename[quantize_node.name() + ":2"] =
            dequantize_node.input(2);

        // Are other sub-graphs using the float intermediate result? If so,
        // preserve it, but the input renaming still rewires the eight-bit ops
        // so they don't go through float.
        if (output_nodes.count(dequantize_node.name()) ||
            graph_outputs.count(dequantize_node.name())) {
          CopyOriginalMatch(match, new_nodes);
        }

        return OkStatus();
      },
      // NOTE(review): the {true} option presumably relaxes consistency checks
      // during replacement — confirm against ReplaceMatchingOpTypes' options.
      {true}, &replaced_graph_def));

  return RenameNodeInputs(replaced_graph_def, inputs_to_rename,
                          std::unordered_set<string>(), output_graph_def);
}
307 | |
// If the user has passed in the input_min and input_max args, then we need to
// convert any input placeholders from float to eight bit, so quantized inputs
// can be fed directly into the graph.
Status QuantizePlaceholders(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  float input_min;
  float input_max;
  bool has_input_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
                                            &input_min, &input_max,
                                            &has_input_range));
  // Without an explicit range this pass is a no-op.
  if (!has_input_range) {
    *output_graph_def = input_graph_def;
    return OkStatus();
  }
  std::map<string, string> inputs_to_rename_first_pass;
  std::map<string, string> inputs_to_rename_second_pass;
  GraphDef placeholder_graph_def;
  placeholder_graph_def.Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    if (node.op() != "Placeholder") {
      // Non-placeholder nodes are copied through unchanged.
      *(placeholder_graph_def.mutable_node()->Add()) = node;
    } else {
      string namespace_prefix = node.name() + "_eightbit";

      // Re-type the placeholder itself to produce quint8, keeping its
      // original name so external feeds still target it.
      NodeDef quantized_placeholder;
      quantized_placeholder = node;
      SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
      *(placeholder_graph_def.mutable_node()->Add()) = quantized_placeholder;

      // Constants holding the user-supplied float range for the input.
      NodeDef min_node;
      min_node.set_op("Const");
      min_node.set_name(namespace_prefix + "/min");
      SetNodeAttr("dtype", DT_FLOAT, &min_node);
      Tensor min_tensor(DT_FLOAT, {});
      min_tensor.flat<float>()(0) = input_min;
      SetNodeTensorAttr<float>("value", min_tensor, &min_node);
      *(placeholder_graph_def.mutable_node()->Add()) = min_node;

      NodeDef max_node;
      max_node.set_op("Const");
      max_node.set_name(namespace_prefix + "/max");
      SetNodeAttr("dtype", DT_FLOAT, &max_node);
      Tensor max_tensor(DT_FLOAT, {});
      max_tensor.flat<float>()(0) = input_max;
      SetNodeTensorAttr<float>("value", max_tensor, &max_node);
      *(placeholder_graph_def.mutable_node()->Add()) = max_node;

      // A Dequantize converts the eight-bit placeholder output back to float
      // for consumers that still expect float. Its first input temporarily
      // carries a sentinel suffix so the first renaming pass below doesn't
      // rewrite this edge along with all the other references.
      const string rename_suffix = "__RENAMED_PLACEHOLDER__";
      NodeDef dequantize_node;
      dequantize_node.set_op("Dequantize");
      dequantize_node.set_name(namespace_prefix + "/dequantize");
      SetNodeAttr("T", DT_QUINT8, &dequantize_node);
      SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
      AddNodeInput(node.name() + rename_suffix, &dequantize_node);
      AddNodeInput(min_node.name(), &dequantize_node);
      AddNodeInput(max_node.name(), &dequantize_node);
      *(placeholder_graph_def.mutable_node()->Add()) = dequantize_node;

      // First make sure that any internal references to the old placeholder
      // now point to the dequantize result.
      inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
      // Then fix up the dequantize op so that it really points to the
      // placeholder.
      inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
    }
  }

  GraphDef first_pass_graph_def;
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
                       std::unordered_set<string>(), &first_pass_graph_def));
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
                       std::unordered_set<string>(), output_graph_def));

  return OkStatus();
}
387 | |
// During training, FakeQuantWithMinMaxVars ops capture a good min/max range for
// an activation layer. To use these during inference, this pass converts those
// ops into Requantizes with the trained min/maxes as constant inputs.
Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
                                     const TransformFuncContext& context,
                                     GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      // Match any op feeding a FakeQuant whose min/max are constants.
      {"FakeQuantWithMinMaxVars",
        {
          {"*"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_quant_node = match.node;
        const NodeDef& original_op_node = match.inputs[0].node;
        const NodeDef& fake_quant_min_node = match.inputs[1].node;
        const NodeDef& fake_quant_max_node = match.inputs[2].node;

        string namespace_prefix = fake_quant_node.name() + "_eightbit";

        // Keep the upstream op and the trained range constants unchanged.
        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_quant_min_node);
        new_nodes->push_back(fake_quant_max_node);

        // Quantize the float activation into 32 bits using the trained range.
        NodeDef quantize_node;
        quantize_node.set_op("QuantizeV2");
        quantize_node.set_name(namespace_prefix + "/quantize");
        SetNodeAttr("T", DT_QINT32, &quantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
        AddNodeInput(fake_quant_node.input(0), &quantize_node);
        AddNodeInput(fake_quant_min_node.name(), &quantize_node);
        AddNodeInput(fake_quant_max_node.name(), &quantize_node);
        new_nodes->push_back(quantize_node);

        // Requantize down to eight bits, with the trained constants as the
        // requested output range. This is the constant-range Requantize that
        // MergeAdjacentRequantizes later looks for.
        NodeDef requantize_node;
        requantize_node.set_op("Requantize");
        requantize_node.set_name(namespace_prefix + "/requantize");
        SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
        SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
        AddNodeInput(quantize_node.name() + ":0", &requantize_node);
        AddNodeInput(quantize_node.name() + ":1", &requantize_node);
        AddNodeInput(quantize_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_quant_min_node.name(), &requantize_node);
        AddNodeInput(fake_quant_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        // Convert the 8-bit result back into float for the final output. It
        // takes over the FakeQuant's name so downstream consumers are
        // unaffected.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(fake_quant_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return OkStatus();
      },
      {}, output_graph_def));

  return OkStatus();
}
456 | |
// We always generate Requantize ops driven by dynamic RequantizationRange
// calculations when we produce quantized ops like Conv2D or BiasAdd with
// 32-bit results. If there were FakeQuant ops already for those activation
// layers, then there will be a later Requantize op with constant min/max
// inputs, which is preferable for fast inference. This pass looks for those
// later Requantize ops, and replaces the dynamic version with them.
Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
                                const TransformFuncContext& context,
                                GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      // The chain produced by ConvertFakeQuantsToRequantize feeding off a
      // dynamically-ranged quantized op, from the outside in:
      // Requantize <- QuantizeV2 <- Dequantize <- Requantize <- original op.
      // Repeated pattern names at the same level correspond to the min/max
      // output ports of the same matched node.
      {"Requantize",
        {
          {"QuantizeV2",
            {
              {"Dequantize",
                {
                  {"Requantize",
                    {
                      {"*"},
                      {"*"},
                      {"*"},
                      {"RequantizationRange"},
                      {"RequantizationRange"},
                    }
                  },
                  {"Requantize"},
                  {"Requantize"},
                }
              },
              {"Const"},
              {"Const"},
            },
          },
          {"QuantizeV2"},
          {"QuantizeV2"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        // The outer Requantize carries the trained constant range as its
        // inputs 3 and 4; the quantized op driving the whole chain is four
        // levels down through the first input at each step.
        const NodeDef& fake_requantize_node = match.node;
        const NodeDef& original_op_node =
            match.inputs[0].inputs[0].inputs[0].inputs[0].node;
        const NodeDef& fake_requantize_min_node = match.inputs[3].node;
        const NodeDef& fake_requantize_max_node = match.inputs[4].node;

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_requantize_min_node);
        new_nodes->push_back(fake_requantize_max_node);

        // Re-point the constant-range Requantize directly at the original
        // op's three outputs (value, min, max), dropping the dynamic
        // RequantizationRange and the float round-trip entirely.
        NodeDef requantize_node;
        requantize_node = fake_requantize_node;
        requantize_node.mutable_input()->Clear();
        AddNodeInput(original_op_node.name() + ":0", &requantize_node);
        AddNodeInput(original_op_node.name() + ":1", &requantize_node);
        AddNodeInput(original_op_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
        AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        return OkStatus();
      },
      {}, output_graph_def));

  return OkStatus();
}
526 | |
// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
// linear ops like Relu, MaxPool, etc, several steps from the Conv2D or BiasAdd
// op that we want to apply the trained constant conversions to. This pass tries
// to move FakeQuant ops up the input chain, so they're as close as possible to
// the 32-bit conversion, and so can be easily merged into the automatic dynamic
// Requantizes.
Status HoistFakeQuants(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  // Try hoisting across progressively shorter chains, deepest first.
  const int max_depth = 3;
  for (int depth = max_depth; depth > 0; --depth) {
    // Build a pattern of a FakeQuantWithMinMaxVars (with Const min/max) fed
    // by a chain of depth + 1 wildcard single-input ops.
    OpTypePattern pattern = {"*"};
    for (int i = 0; i < depth; ++i) {
      pattern = {"*", {pattern}};
    }
    pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
    GraphDef hoisted_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [depth](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          const NodeDef& fake_quant_node = match.node;
          const NodeDef& fake_quant_min_node = match.inputs[1].node;
          const NodeDef& fake_quant_max_node = match.inputs[2].node;
          // Collect the chain between the FakeQuant and the deepest matched
          // node, shallowest (direct FakeQuant input) first.
          std::vector<NodeDef> linear_nodes;
          NodeMatch current_match = match;
          for (int i = 0; i <= depth; ++i) {
            linear_nodes.push_back(current_match.inputs[0].node);
            current_match = current_match.inputs[0];
          }
          // Make a renamed copy of the FakeQuant that reads what the
          // second-deepest node used to read, i.e. the deepest matched node's
          // output — hoisting it to the bottom of the chain.
          NodeDef new_fake_quant_node;
          new_fake_quant_node = fake_quant_node;
          new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
          new_fake_quant_node.set_input(
              0, linear_nodes[linear_nodes.size() - 2].input(0));
          new_nodes->push_back(new_fake_quant_node);

          new_nodes->push_back(fake_quant_min_node);
          new_nodes->push_back(fake_quant_max_node);

          // Splice the hoisted FakeQuant into the chain, then rename the top
          // of the chain to the old FakeQuant's name so that downstream
          // consumers of the original FakeQuant output are unaffected.
          linear_nodes[linear_nodes.size() - 2].set_input(
              0, new_fake_quant_node.name());
          linear_nodes.front().set_name(fake_quant_node.name());
          for (const NodeDef& linear_node : linear_nodes) {
            new_nodes->push_back(linear_node);
          }

          return OkStatus();
        },
        {}, &hoisted_graph_def));
    current_graph_def = hoisted_graph_def;
  }
  *output_graph_def = current_graph_def;

  return OkStatus();
}
585 | |
586 | // Converts any float ops that have eight-bit equivalents into their quantized |
587 | // forms, so that as much calculation as possible is done in the lower-precision |
588 | // format. |
589 | Status QuantizeNodes(const GraphDef& input_graph_def, |
590 | const TransformFuncContext& context, |
591 | GraphDef* output_graph_def) { |
592 | // Loop through all of the quantizable op types, and replace any occurrences |
593 | // with equivalent sub-graphs with quantized ops at their core. For example |
594 | // this one-input operation: |
595 | // |
596 | // Input(float) |
597 | // | |
598 | // v |
599 | // Operation |
600 | // | |
601 | // v |
602 | // (float) |
603 | // |
604 | // Will be turned into it's quantized equivalent: |
605 | // |
606 | // Input(float) ReshapeDims |
607 | // +------v v-------------+ |
608 | // | Reshape |
609 | // | | |
610 | // | | ReductionDims |
611 | // | +-----+ | |
612 | // | | +---c---------+ |
613 | // | v v v v-------+ |
614 | // | Min Max |
615 | // | +----+ | |
616 | // v v v--------+ |
617 | // Quantize |
618 | // | |
619 | // v |
620 | // QuantizedOperation |
621 | // | | | |
622 | // v v v |
623 | // Dequantize |
624 | // | |
625 | // v |
626 | // (float) |
627 | // |
628 | // This keeps the inputs and outputs visible to the rest of the graph in |
629 | // float |
630 | // and converts them down to quantized buffers internally for the |
631 | // computation. |
632 | // The result will end up with a lot of redundant dequantize/quantize pairs |
633 | // between adjacent quantized ops, but a later pass removes these where it |
634 | // can. |
635 | |
636 | std::set<string> ops_to_ignore; |
637 | if (context.params.count("ignore_op" ) > 0) { |
638 | for (const string& name : context.params.at("ignore_op" )) { |
639 | ops_to_ignore.insert(name); |
640 | } |
641 | } |
642 | |
643 | const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList(); |
644 | string op_pattern; |
645 | bool is_first = true; |
646 | std::map<string, QuantizedOpInfo> op_map; |
647 | for (const QuantizedOpInfo& op_info : op_list) { |
648 | if (ops_to_ignore.count(op_info.float_name) == 0) { |
649 | strings::StrAppend(&op_pattern, (is_first ? "" : "|" ), |
650 | op_info.float_name); |
651 | op_map.insert({op_info.float_name, op_info}); |
652 | is_first = false; |
653 | } |
654 | } |
655 | |
656 | // If input_min and input max have been passed in, then we convert all float |
657 | // Placeholder nodes into quantized versions, with the supplied values as |
658 | // their range. |
659 | GraphDef placeholder_graph_def; |
660 | TF_RETURN_IF_ERROR( |
661 | QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def)); |
662 | TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def)); |
663 | |
664 | // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear |
665 | // operations like Relu or MaxPool, move them up so that they're as close as |
666 | // possible to ops with 32-bit outputs like BiasAdd or Conv2D. |
667 | GraphDef hoisted_graph_def; |
668 | TF_RETURN_IF_ERROR( |
669 | HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def)); |
670 | TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def)); |
671 | |
672 | // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of |
673 | // activation layers, into Requantize ops with those ranges instead. This |
674 | // makes it easier to replace the dynamic range calculations that are used |
675 | // by default. |
676 | GraphDef converted_graph_def; |
677 | TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context, |
678 | &converted_graph_def)); |
679 | TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def)); |
680 | |
681 | // If fallback_min and fallback_max are set, then we'll use hardwired ranges |
682 | // for all the 32-bit to 8-bit requantizations. |
683 | float fallback_min; |
684 | float fallback_max; |
685 | bool has_fallback_range; |
686 | TF_RETURN_IF_ERROR(ExtractRangeFromParams( |
687 | context, "fallback_min" , "fallback_max" , &fallback_min, &fallback_max, |
688 | &has_fallback_range)); |
689 | |
690 | // Replace all occurrences of the current float op with its quantized |
691 | // equivalent. |
692 | GraphDef quantized_graph_def; |
693 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
694 | converted_graph_def, {op_pattern}, |
695 | [&op_map, fallback_min, fallback_max, has_fallback_range]( |
696 | const NodeMatch& match, const std::set<string>& input_nodes, |
697 | const std::set<string>& output_nodes, |
698 | std::vector<NodeDef>* new_nodes) { |
699 | const NodeDef& float_node = match.node; |
700 | const QuantizedOpInfo& op_info = op_map[float_node.op()]; |
701 | |
702 | DataTypeVector input_types; |
703 | DataTypeVector output_types; |
704 | TF_RETURN_IF_ERROR( |
705 | GetInOutTypes(float_node, &input_types, &output_types)); |
706 | bool are_all_float = true; |
707 | for (int i = 0; i < float_node.input_size(); ++i) { |
708 | // Skip any known non-float inputs. |
709 | if (op_info.unquantized_inputs.count(i)) { |
710 | continue; |
711 | } |
712 | if (i >= input_types.size()) { |
713 | LOG(ERROR) << "input_types has incorrect size " |
714 | << input_types.size() << " <= " << i |
715 | << ". Assuming everything else is floats." ; |
716 | } |
717 | if (i < input_types.size() && input_types[i] != DT_FLOAT) { |
718 | are_all_float = false; |
719 | } |
720 | } |
721 | for (const DataType& output_type : output_types) { |
722 | if (output_type != DT_FLOAT) { |
723 | are_all_float = false; |
724 | } |
725 | } |
726 | // This isn't a float op, so don't quantize it. |
727 | if (!are_all_float) { |
728 | CopyOriginalMatch(match, new_nodes); |
729 | return OkStatus(); |
730 | } |
731 | |
732 | string namespace_prefix = float_node.name() + "_eightbit" ; |
733 | |
734 | // Quantize all of the inputs. |
735 | std::vector<string> quantized_input_names; |
736 | for (int i = 0; i < float_node.input_size(); ++i) { |
737 | // Skip any non-float inputs. |
738 | if (op_info.unquantized_inputs.count(i)) { |
739 | continue; |
740 | } |
741 | |
742 | const string& input_name = float_node.input(i); |
743 | string unique_input_name = |
744 | namespace_prefix + "/" + UniqueNodeNameFromInput(input_name); |
745 | |
746 | // Add some common constants we need for reshaping inputs. |
747 | NodeDef reshape_dims; |
748 | reshape_dims.set_op("Const" ); |
749 | reshape_dims.set_name(unique_input_name + "/reshape_dims" ); |
750 | AddNodeInput("^" + NodeNameFromInput(input_name), &reshape_dims); |
751 | SetNodeAttr("dtype" , DT_INT32, &reshape_dims); |
752 | Tensor reshape_dims_tensor(DT_INT32, {1}); |
753 | reshape_dims_tensor.flat<int32>()(0) = -1; |
754 | SetNodeTensorAttr<int32>("value" , reshape_dims_tensor, &reshape_dims); |
755 | new_nodes->push_back(reshape_dims); |
756 | |
757 | NodeDef reduction_dims; |
758 | reduction_dims.set_op("Const" ); |
759 | reduction_dims.set_name(unique_input_name + "/reduction_dims" ); |
760 | AddNodeInput("^" + NodeNameFromInput(input_name), &reduction_dims); |
761 | SetNodeAttr("dtype" , DT_INT32, &reduction_dims); |
762 | Tensor reduction_dims_tensor(DT_INT32, {1}); |
763 | reduction_dims_tensor.flat<int32>()(0) = 0; |
764 | SetNodeTensorAttr<int32>("value" , reduction_dims_tensor, |
765 | &reduction_dims); |
766 | new_nodes->push_back(reduction_dims); |
767 | |
768 | NodeDef reshape_node; |
769 | reshape_node.set_op("Reshape" ); |
770 | reshape_node.set_name(unique_input_name + "/reshape" ); |
771 | SetNodeAttr("T" , DT_FLOAT, &reshape_node); |
772 | AddNodeInput(input_name, &reshape_node); |
773 | AddNodeInput(reshape_dims.name(), &reshape_node); |
774 | new_nodes->push_back(reshape_node); |
775 | |
776 | NodeDef min_node; |
777 | min_node.set_op("Min" ); |
778 | min_node.set_name(unique_input_name + "/min" ); |
779 | SetNodeAttr("T" , DT_FLOAT, &min_node); |
780 | SetNodeAttr("keep_dims" , false, &min_node); |
781 | AddNodeInput(reshape_node.name(), &min_node); |
782 | AddNodeInput(reduction_dims.name(), &min_node); |
783 | new_nodes->push_back(min_node); |
784 | |
785 | NodeDef max_node; |
786 | max_node.set_op("Max" ); |
787 | max_node.set_name(unique_input_name + "/max" ); |
788 | SetNodeAttr("T" , DT_FLOAT, &max_node); |
789 | SetNodeAttr("keep_dims" , false, &max_node); |
790 | AddNodeInput(reshape_node.name(), &max_node); |
791 | AddNodeInput(reduction_dims.name(), &max_node); |
792 | new_nodes->push_back(max_node); |
793 | |
794 | NodeDef quantize_node; |
795 | quantize_node.set_op("QuantizeV2" ); |
796 | quantize_node.set_name(unique_input_name + "/quantize" ); |
797 | SetNodeAttr("T" , DT_QUINT8, &quantize_node); |
798 | SetNodeAttr("mode" , "MIN_FIRST" , &quantize_node); |
799 | AddNodeInput(input_name, &quantize_node); |
800 | AddNodeInput(min_node.name(), &quantize_node); |
801 | AddNodeInput(max_node.name(), &quantize_node); |
802 | new_nodes->push_back(quantize_node); |
803 | quantized_input_names.push_back(quantize_node.name()); |
804 | } |
805 | |
806 | // Set up the quantized version of the current op. |
807 | NodeDef quantized_main_node; |
808 | quantized_main_node.set_op("Quantized" + float_node.op()); |
809 | quantized_main_node.set_name(float_node.name() + "/eightbit" ); |
810 | for (const string& attr_to_copy : op_info.attrs_to_copy) { |
811 | CopyNodeAttr(float_node, attr_to_copy, attr_to_copy, |
812 | &quantized_main_node); |
813 | } |
814 | for (const std::pair<string, DataType>& dtype_to_set : |
815 | op_info.dtypes_to_set) { |
816 | SetNodeAttr(dtype_to_set.first, dtype_to_set.second, |
817 | &quantized_main_node); |
818 | } |
819 | int quantized_input_index = 0; |
820 | for (int i = 0; i < float_node.input_size(); ++i) { |
821 | if (op_info.unquantized_inputs.count(i)) { |
822 | AddNodeInput(float_node.input(i), &quantized_main_node); |
823 | } else { |
824 | const string& quantized_input_name = |
825 | quantized_input_names[quantized_input_index]; |
826 | AddNodeInput(quantized_input_name + ":0" , &quantized_main_node); |
827 | ++quantized_input_index; |
828 | } |
829 | } |
830 | if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) { |
831 | for (const string& quantized_input_name : quantized_input_names) { |
832 | AddNodeInput(quantized_input_name + ":1" , &quantized_main_node); |
833 | AddNodeInput(quantized_input_name + ":2" , &quantized_main_node); |
834 | } |
835 | } else { |
836 | for (const string& quantized_input_name : quantized_input_names) { |
837 | AddNodeInput(quantized_input_name + ":1" , &quantized_main_node); |
838 | } |
839 | for (const string& quantized_input_name : quantized_input_names) { |
840 | AddNodeInput(quantized_input_name + ":2" , &quantized_main_node); |
841 | } |
842 | } |
843 | new_nodes->push_back(quantized_main_node); |
844 | |
845 | string eight_bit_node_name; |
846 | if (op_info.output_bit_depth == DT_QINT32) { |
847 | // Shrink the range of the output down from 32 bits to 8. |
848 | string requantize_min_input; |
849 | string requantize_max_input; |
850 | if (has_fallback_range) { |
851 | // Use constant values for the min/max range if they were given. |
852 | NodeDef fallback_min_node; |
853 | fallback_min_node.set_op("Const" ); |
854 | fallback_min_node.set_name(quantized_main_node.name() + |
855 | "/fallback_min" ); |
856 | SetNodeAttr("dtype" , DT_FLOAT, &fallback_min_node); |
857 | Tensor fallback_min_tensor(DT_FLOAT, {}); |
858 | fallback_min_tensor.flat<float>()(0) = fallback_min; |
859 | SetNodeTensorAttr<float>("value" , fallback_min_tensor, |
860 | &fallback_min_node); |
861 | new_nodes->push_back(fallback_min_node); |
862 | |
863 | NodeDef fallback_max_node; |
864 | fallback_max_node.set_op("Const" ); |
865 | fallback_max_node.set_name(quantized_main_node.name() + |
866 | "/fallback_max" ); |
867 | SetNodeAttr("dtype" , DT_FLOAT, &fallback_max_node); |
868 | Tensor fallback_max_tensor(DT_FLOAT, {}); |
869 | fallback_max_tensor.flat<float>()(0) = fallback_max; |
870 | SetNodeTensorAttr<float>("value" , fallback_max_tensor, |
871 | &fallback_max_node); |
872 | new_nodes->push_back(fallback_max_node); |
873 | |
874 | requantize_min_input = fallback_min_node.name(); |
875 | requantize_max_input = fallback_max_node.name(); |
876 | } else { |
877 | // Otherwise dynamically measure the range each time. |
878 | NodeDef requant_range_node; |
879 | requant_range_node.set_op("RequantizationRange" ); |
880 | requant_range_node.set_name(quantized_main_node.name() + |
881 | "/requant_range" ); |
882 | SetNodeAttr("Tinput" , DT_QINT32, &requant_range_node); |
883 | AddNodeInput(quantized_main_node.name() + ":0" , |
884 | &requant_range_node); |
885 | AddNodeInput(quantized_main_node.name() + ":1" , |
886 | &requant_range_node); |
887 | AddNodeInput(quantized_main_node.name() + ":2" , |
888 | &requant_range_node); |
889 | new_nodes->push_back(requant_range_node); |
890 | |
891 | requantize_min_input = requant_range_node.name() + ":0" ; |
892 | requantize_max_input = requant_range_node.name() + ":1" ; |
893 | } |
894 | NodeDef requantize_node; |
895 | requantize_node.set_op("Requantize" ); |
896 | requantize_node.set_name(quantized_main_node.name() + "/requantize" ); |
897 | SetNodeAttr("Tinput" , DT_QINT32, &requantize_node); |
898 | SetNodeAttr("out_type" , DT_QUINT8, &requantize_node); |
899 | AddNodeInput(quantized_main_node.name() + ":0" , &requantize_node); |
900 | AddNodeInput(quantized_main_node.name() + ":1" , &requantize_node); |
901 | AddNodeInput(quantized_main_node.name() + ":2" , &requantize_node); |
902 | AddNodeInput(requantize_min_input, &requantize_node); |
903 | AddNodeInput(requantize_max_input, &requantize_node); |
904 | new_nodes->push_back(requantize_node); |
905 | eight_bit_node_name = requantize_node.name(); |
906 | } else { |
907 | eight_bit_node_name = quantized_main_node.name(); |
908 | } |
909 | |
910 | // Convert the 8-bit result back into float for the final output. |
911 | NodeDef dequantize_node; |
912 | dequantize_node.set_op("Dequantize" ); |
913 | dequantize_node.set_name(float_node.name()); |
914 | SetNodeAttr("T" , DT_QUINT8, &dequantize_node); |
915 | SetNodeAttr("mode" , "MIN_FIRST" , &dequantize_node); |
916 | AddNodeInput(eight_bit_node_name + ":0" , &dequantize_node); |
917 | AddNodeInput(eight_bit_node_name + ":1" , &dequantize_node); |
918 | AddNodeInput(eight_bit_node_name + ":2" , &dequantize_node); |
919 | new_nodes->push_back(dequantize_node); |
920 | |
921 | return OkStatus(); |
922 | }, |
923 | {}, &quantized_graph_def)); |
924 | TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def)); |
925 | |
926 | // If we've ended up with two Requantize ops in a row (for example if there |
927 | // was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together, |
928 | // using the trained range from the second op. |
929 | GraphDef merged_graph_def; |
930 | TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context, |
931 | &merged_graph_def)); |
932 | TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def)); |
933 | |
934 | // There can be duplicate quantize nodes if multiple ops pull from a single |
935 | // input, which makes it harder to remove redundant ones, so strip them out. |
936 | GraphDef deduped_graph_def; |
937 | TF_RETURN_IF_ERROR( |
938 | MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def)); |
939 | TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def)); |
940 | |
941 | // Look for Dequantizes that immediately go into Quantizes, and remove them |
942 | // since the two together cancel each other out. This allows us to keep the |
943 | // data flow in eight bit where two adjacent ops are in eight bit, but still |
944 | // keep interoperability with float ops. |
945 | TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context, |
946 | output_graph_def)); |
947 | TF_RETURN_IF_ERROR(IsGraphValid(*output_graph_def)); |
948 | |
949 | return OkStatus(); |
950 | } |
951 | |
// Exposes the quantization pass to the transform_graph tool under the name
// "quantize_nodes". The pass rewrites eligible float ops into their
// "Quantized*" equivalents, wrapping inputs in QuantizeV2 and outputs in
// Dequantize, then cleans up redundant conversions (see the lambda above).
REGISTER_GRAPH_TRANSFORM("quantize_nodes" , QuantizeNodes);

// Also registers the duplicate-node merging pass (used internally by
// QuantizeNodes to dedupe repeated quantize subgraphs) as a standalone
// transform so it can be invoked on its own.
REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes" , MergeDuplicateNodes);
955 | |
956 | } // namespace graph_transforms |
957 | } // namespace tensorflow |
958 | |