/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// Holds the information we need to translate from a float version of this op
// into the quantized equivalent.
struct QuantizedOpInfo {
  // The name of the float op.
  string float_name;
  // Which attributes to copy directly over.
  std::vector<string> attrs_to_copy;
  // Extra data type attributes we need to set.
  std::vector<std::pair<string, DataType>> dtypes_to_set;
  // The bit depth of the quantized inputs the op can read.
  DataType input_bit_depth;
  // The bit depth of the op's quantized outputs.
  DataType output_bit_depth;
  // Which inputs (e.g. shapes) aren't involved in the quantization process.
  std::set<int32> unquantized_inputs;
  // How the inputs are arranged, either
  // [input0, input1, min0, max0, min1, max1] for contiguous, or
  // [input0, input1, min0, min1, max0, max1] for separate.
  // The separate order is needed because it's the only way to specify unknown
  // numbers of inputs for ops like Concat.
  enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
};
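
// As an illustration of how the entries below are consumed: the "MatMul" entry
// makes the transform emit a QuantizedMatMul node whose transpose_a and
// transpose_b attrs are copied from the float op, whose T1/T2/Toutput dtype
// attrs come from dtypes_to_set, and whose inputs are the two quantized
// tensors followed by their min/max values in the CONTIGUOUS_MIN_MAX order
// described above.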

// Every op that has a quantized equivalent should be listed here, so that the
// conversion process can transform them.
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
  static const std::vector<QuantizedOpInfo> op_list = {
      {"Add",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"AvgPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"BiasAdd",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Concat",
       {"N"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {0},
       QuantizedOpInfo::SEPARATE_MIN_MAX},
      {"Conv2D",
       {"strides", "padding"},
       {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MatMul",
       {"transpose_a", "transpose_b"},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MaxPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Mul",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"ResizeBilinear",
       {"align_corners"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu6",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Reshape",
       {},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
  };
  return op_list;
}

namespace {
// Converts an input name, which may carry a control prefix ("^") or a port
// suffix (":N"), into a distinct string that is safe to embed in a node name.
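// For example, "^beta" becomes "__hat__beta" and "concat:2" becomes
// "concat__port__2".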
string UniqueNodeNameFromInput(const string& input_name) {
  string prefix;
  string node_name;
  string suffix;
  NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
  string result;
  if (prefix == "^") {
    result += "__hat__";
  }
  result += node_name;
  if (!suffix.empty()) {
    result += "__port__" + suffix.substr(1, suffix.size() - 1);
  }
  return result;
}

// Pulls two float values from the named parameters, checking that either both
// or neither of them is supplied.
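// For example, min_name "input_min" and max_name "input_max" with transform
// parameters input_min=-1 and input_max=1 yields *min_value == -1.0f,
// *max_value == 1.0f, and *has_range == true; if neither parameter is present
// only *has_range is set (to false).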
Status ExtractRangeFromParams(const TransformFuncContext& context,
                              const string& min_name, const string& max_name,
                              float* min_value, float* max_value,
                              bool* has_range) {
  // See if we've been given quantized inputs with a known range.
  const bool has_min = (context.params.count(min_name) != 0);
  const bool has_max = (context.params.count(max_name) != 0);
  *has_range = (has_min || has_max);
  if (!*has_range) {
    return OkStatus();
  }
  if (!has_min || !has_max) {
    return errors::InvalidArgument("You must pass both ", min_name, " and ",
                                   max_name, " into quantize_nodes");
  }
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
  return OkStatus();
}

}  // namespace

// Analyzes all the nodes in the graph to figure out which ones are duplicates
// apart from their names. This commonly includes identical Const nodes, but
// can also be simple operations that are repeated on multiple outputs of a
// particular node. The complexity is managed using a hash function that avoids
// the need for any O(n^2) algorithms when identifying duplicates.
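// For example, if two Const nodes hold identical tensors and neither is a
// graph input or output, the second is dropped and every reference to its
// outputs is rewired (via a "name:*" rename) to point at the first one.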
Status MergeDuplicateNodes(const GraphDef& input_graph_def,
                           const TransformFuncContext& context,
                           GraphDef* output_graph_def) {
  // Make sure we can look up inputs and outputs quickly.
  std::set<string> input_names(context.input_names.begin(),
                               context.input_names.end());
  std::set<string> output_names(context.output_names.begin(),
                                context.output_names.end());
  GraphDef current_graph_def = input_graph_def;
  // Keep running the merging until no more duplicates are found.
  bool any_duplicates_found;
  do {
    any_duplicates_found = false;
    // First arrange all of the nodes by a hash of their contents.
    std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
    for (const NodeDef& node : current_graph_def.node()) {
      NodeDef nameless_node = node;
      // The name matters if it's being used as an input or output node,
      // otherwise ignore it when looking for duplicates.
      if (!input_names.count(node.name()) &&
          !output_names.count(node.name())) {
        nameless_node.set_name("");
      }
      const uint64 hash = HashNodeDef(nameless_node);
      hashed_nodes[hash].push_back(&node);
    }
    // If we have multiple nodes with the same hash, then we know they're
    // duplicates and can be removed, unless they're stateful.
    std::map<string, string> inputs_to_rename;
    GraphDef merged_graph_def;
    for (const std::pair<const uint64, std::vector<const NodeDef*>>&
             hashed_node_info : hashed_nodes) {
      const std::vector<const NodeDef*>& hash_node_list =
          hashed_node_info.second;
      for (int i = 0; i < hash_node_list.size(); ++i) {
        const NodeDef* current_node = hash_node_list[i];
        const OpDef* op_def = nullptr;
        TF_RETURN_IF_ERROR(
            OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
        const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
        if (is_duplicate) {
          const string original_name = hash_node_list[0]->name();
          inputs_to_rename[current_node->name() + ":*"] = original_name;
          any_duplicates_found = true;
        } else {
          NodeDef* new_node = merged_graph_def.mutable_node()->Add();
          *new_node = *current_node;
        }
      }
    }
    // Update the graph so that any nodes that referred to removed inputs now
    // pull from the remaining duplicate.
    TF_RETURN_IF_ERROR(RenameNodeInputs(merged_graph_def, inputs_to_rename,
                                        std::unordered_set<string>(),
                                        &current_graph_def));
  } while (any_duplicates_found);

  *output_graph_def = current_graph_def;

  return OkStatus();
}

// Looks for the patterns that indicate there are two eight-bit ops feeding
// into each other, separated by a conversion up to float and back again. These
// occur during the initial conversion of ops to their quantized forms. Because
// we're only looking at an individual op in that phase and don't know if its
// inputs and outputs are eight-bit-capable, we start by converting the actual
// op into quantized form, but add float conversions before and after. This
// pass gets rid of those conversions if it turns out we do have adjacent ops
// capable of eight-bit processing.
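// Roughly speaking, a chain like
//   QuantizedConv2D -> ... -> Dequantize -> QuantizeV2 -> QuantizedRelu
// is rewired so the second quantized op reads the Dequantize's quantized
// inputs directly, and the Dequantize/QuantizeV2 pair is removed unless the
// float intermediate is still needed as a graph output.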
Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
                                    const TransformFuncContext& context,
                                    GraphDef* output_graph_def) {
  std::set<string> graph_outputs;
  for (const string& output_name : context.output_names) {
    graph_outputs.insert(NodeNameFromInput(output_name));
  }
  std::map<string, string> inputs_to_rename;
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"QuantizeV2",
        {
          {"Dequantize"},
          {"Min"},
          {"Max"},
        }
      },  // clang-format on
      [&inputs_to_rename, &graph_outputs](const NodeMatch& match,
                                          const std::set<string>& input_nodes,
                                          const std::set<string>& output_nodes,
                                          std::vector<NodeDef>* new_nodes) {
        const NodeDef& quantize_node = match.node;
        const NodeDef& dequantize_node = match.inputs[0].node;
        inputs_to_rename[quantize_node.name() + ":0"] =
            dequantize_node.input(0);
        inputs_to_rename[quantize_node.name() + ":1"] =
            dequantize_node.input(1);
        inputs_to_rename[quantize_node.name() + ":2"] =
            dequantize_node.input(2);

        // Are other sub-graphs using the float intermediate result? If so,
        // preserve it, but the input renaming still rewires the eight-bit ops
        // so they don't go through float.
        if (output_nodes.count(dequantize_node.name()) ||
            graph_outputs.count(dequantize_node.name())) {
          CopyOriginalMatch(match, new_nodes);
        }

        return OkStatus();
      },
      {true}, &replaced_graph_def));

  return RenameNodeInputs(replaced_graph_def, inputs_to_rename,
                          std::unordered_set<string>(), output_graph_def);
}

// If the user has passed in the input_min and input_max args, then we need to
// convert any input placeholders from float to eight bit, so quantized inputs
// can be fed directly into the graph.
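// For example (the exact tool invocation may vary), running something like
//   transform_graph --in_graph=float.pb --out_graph=eightbit.pb \
//     --inputs=input --outputs=softmax \
//     --transforms='quantize_nodes(input_min=-1, input_max=1)'
// rewrites each float Placeholder into a quint8 Placeholder followed by a
// Dequantize, so callers can feed eight-bit data directly.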
Status QuantizePlaceholders(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  float input_min;
  float input_max;
  bool has_input_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
                                            &input_min, &input_max,
                                            &has_input_range));
  if (!has_input_range) {
    *output_graph_def = input_graph_def;
    return OkStatus();
  }
  std::map<string, string> inputs_to_rename_first_pass;
  std::map<string, string> inputs_to_rename_second_pass;
  GraphDef placeholder_graph_def;
  placeholder_graph_def.Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    if (node.op() != "Placeholder") {
      *(placeholder_graph_def.mutable_node()->Add()) = node;
    } else {
      string namespace_prefix = node.name() + "_eightbit";

      NodeDef quantized_placeholder;
      quantized_placeholder = node;
      SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
      *(placeholder_graph_def.mutable_node()->Add()) = quantized_placeholder;

      NodeDef min_node;
      min_node.set_op("Const");
      min_node.set_name(namespace_prefix + "/min");
      SetNodeAttr("dtype", DT_FLOAT, &min_node);
      Tensor min_tensor(DT_FLOAT, {});
      min_tensor.flat<float>()(0) = input_min;
      SetNodeTensorAttr<float>("value", min_tensor, &min_node);
      *(placeholder_graph_def.mutable_node()->Add()) = min_node;

      NodeDef max_node;
      max_node.set_op("Const");
      max_node.set_name(namespace_prefix + "/max");
      SetNodeAttr("dtype", DT_FLOAT, &max_node);
      Tensor max_tensor(DT_FLOAT, {});
      max_tensor.flat<float>()(0) = input_max;
      SetNodeTensorAttr<float>("value", max_tensor, &max_node);
      *(placeholder_graph_def.mutable_node()->Add()) = max_node;

      const string rename_suffix = "__RENAMED_PLACEHOLDER__";
      NodeDef dequantize_node;
      dequantize_node.set_op("Dequantize");
      dequantize_node.set_name(namespace_prefix + "/dequantize");
      SetNodeAttr("T", DT_QUINT8, &dequantize_node);
      SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
      AddNodeInput(node.name() + rename_suffix, &dequantize_node);
      AddNodeInput(min_node.name(), &dequantize_node);
      AddNodeInput(max_node.name(), &dequantize_node);
      *(placeholder_graph_def.mutable_node()->Add()) = dequantize_node;

      // First make sure that any internal references to the old placeholder
      // now point to the dequantize result.
      inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
      // Then fix up the dequantize op so that it really points to the
      // placeholder.
      inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
    }
  }

  GraphDef first_pass_graph_def;
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
                       std::unordered_set<string>(), &first_pass_graph_def));
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
                       std::unordered_set<string>(), output_graph_def));

  return OkStatus();
}

// During training, FakeQuantWithMinMaxVars ops capture a good min/max range
// for an activation layer. To use these during inference, this pass converts
// those ops into Requantizes with the trained min/maxes as constant inputs.
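// Roughly, each matched
//   FakeQuantWithMinMaxVars(activation, min_const, max_const)
// becomes
//   QuantizeV2 -> Requantize(..., min_const, max_const) -> Dequantize
// with the Dequantize taking over the FakeQuant's name, so downstream
// consumers keep working unchanged.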
Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
                                     const TransformFuncContext& context,
                                     GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"FakeQuantWithMinMaxVars",
        {
          {"*"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_quant_node = match.node;
        const NodeDef& original_op_node = match.inputs[0].node;
        const NodeDef& fake_quant_min_node = match.inputs[1].node;
        const NodeDef& fake_quant_max_node = match.inputs[2].node;

        string namespace_prefix = fake_quant_node.name() + "_eightbit";

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_quant_min_node);
        new_nodes->push_back(fake_quant_max_node);

        NodeDef quantize_node;
        quantize_node.set_op("QuantizeV2");
        quantize_node.set_name(namespace_prefix + "/quantize");
        SetNodeAttr("T", DT_QINT32, &quantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
        AddNodeInput(fake_quant_node.input(0), &quantize_node);
        AddNodeInput(fake_quant_min_node.name(), &quantize_node);
        AddNodeInput(fake_quant_max_node.name(), &quantize_node);
        new_nodes->push_back(quantize_node);

        NodeDef requantize_node;
        requantize_node.set_op("Requantize");
        requantize_node.set_name(namespace_prefix + "/requantize");
        SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
        SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
        AddNodeInput(quantize_node.name() + ":0", &requantize_node);
        AddNodeInput(quantize_node.name() + ":1", &requantize_node);
        AddNodeInput(quantize_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_quant_min_node.name(), &requantize_node);
        AddNodeInput(fake_quant_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(fake_quant_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return OkStatus();
      },
      {}, output_graph_def));

  return OkStatus();
}

// We always generate Requantize ops driven by dynamic RequantizationRange
// calculations when we produce quantized ops like Conv2D or BiasAdd with
// 32-bit results. If there were FakeQuant ops already for those activation
// layers, then there will be a later Requantize op with constant min/max
// inputs, which is preferable for fast inference. This pass looks for those
// later Requantize ops, and replaces the dynamic version with them.
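// Concretely, the pattern below finds a constant-range Requantize that (via a
// QuantizeV2/Dequantize pair) is fed by a RequantizationRange-driven
// Requantize, and rewires the constant-range Requantize to read the quantized
// op's outputs directly, dropping the dynamic range calculation in between.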
Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
                                const TransformFuncContext& context,
                                GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Requantize",
        {
          {"QuantizeV2",
            {
              {"Dequantize",
                {
                  {"Requantize",
                    {
                      {"*"},
                      {"*"},
                      {"*"},
                      {"RequantizationRange"},
                      {"RequantizationRange"},
                    }
                  },
                  {"Requantize"},
                  {"Requantize"},
                }
              },
              {"Const"},
              {"Const"},
            },
          },
          {"QuantizeV2"},
          {"QuantizeV2"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_requantize_node = match.node;
        const NodeDef& original_op_node =
            match.inputs[0].inputs[0].inputs[0].inputs[0].node;
        const NodeDef& fake_requantize_min_node = match.inputs[3].node;
        const NodeDef& fake_requantize_max_node = match.inputs[4].node;

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_requantize_min_node);
        new_nodes->push_back(fake_requantize_max_node);

        NodeDef requantize_node;
        requantize_node = fake_requantize_node;
        requantize_node.mutable_input()->Clear();
        AddNodeInput(original_op_node.name() + ":0", &requantize_node);
        AddNodeInput(original_op_node.name() + ":1", &requantize_node);
        AddNodeInput(original_op_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
        AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        return OkStatus();
      },
      {}, output_graph_def));

  return OkStatus();
}

// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
// linear ops like Relu or MaxPool, several steps away from the Conv2D or
// BiasAdd op whose trained constant ranges we actually want to use. This pass
// tries to move FakeQuant ops up the input chain, so they're as close as
// possible to the 32-bit conversion and can easily be merged into the
// automatic dynamic Requantizes.
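// For example, a chain Conv2D -> Relu -> FakeQuantWithMinMaxVars becomes
// Conv2D -> FakeQuantWithMinMaxVars -> Relu, where the moved FakeQuant gets a
// "_hoisted" suffix and the Relu takes over the FakeQuant's old name so that
// downstream consumers are unaffected.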
Status HoistFakeQuants(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  const int max_depth = 3;
  for (int depth = max_depth; depth > 0; --depth) {
    OpTypePattern pattern = {"*"};
    for (int i = 0; i < depth; ++i) {
      pattern = {"*", {pattern}};
    }
    pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
    GraphDef hoisted_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [depth](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          const NodeDef& fake_quant_node = match.node;
          const NodeDef& fake_quant_min_node = match.inputs[1].node;
          const NodeDef& fake_quant_max_node = match.inputs[2].node;
          std::vector<NodeDef> linear_nodes;
          NodeMatch current_match = match;
          for (int i = 0; i <= depth; ++i) {
            linear_nodes.push_back(current_match.inputs[0].node);
            current_match = current_match.inputs[0];
          }
          NodeDef new_fake_quant_node;
          new_fake_quant_node = fake_quant_node;
          new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
          new_fake_quant_node.set_input(
              0, linear_nodes[linear_nodes.size() - 2].input(0));
          new_nodes->push_back(new_fake_quant_node);

          new_nodes->push_back(fake_quant_min_node);
          new_nodes->push_back(fake_quant_max_node);

          linear_nodes[linear_nodes.size() - 2].set_input(
              0, new_fake_quant_node.name());
          linear_nodes.front().set_name(fake_quant_node.name());
          for (const NodeDef& linear_node : linear_nodes) {
            new_nodes->push_back(linear_node);
          }

          return OkStatus();
        },
        {}, &hoisted_graph_def));
    current_graph_def = hoisted_graph_def;
  }
  *output_graph_def = current_graph_def;

  return OkStatus();
}

// Converts any float ops that have eight-bit equivalents into their quantized
// forms, so that as much calculation as possible is done in the
// lower-precision format.
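// The transform is registered below as "quantize_nodes". The optional
// parameters read in this function and its helpers are ignore_op (skip
// particular op types), input_min/input_max (quantize Placeholder inputs), and
// fallback_min/fallback_max (use fixed ranges instead of RequantizationRange
// ops); they are passed in the usual graph_transforms style, e.g. something
// like
//   --transforms='quantize_nodes(ignore_op=Concat, fallback_min=-10, fallback_max=10)'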
Status QuantizeNodes(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
  // Loop through all of the quantizable op types, and replace any occurrences
  // with equivalent sub-graphs with quantized ops at their core. For example
  // this one-input operation:
  //
  //    Input(float)
  //         |
  //         v
  //     Operation
  //         |
  //         v
  //      (float)
  //
  // will be turned into its quantized equivalent:
  //
  //   Input(float)             ReshapeDims
  //      +------v v-------------+
  //      |    Reshape
  //      |      |
  //      |      |          ReductionDims
  //      |      +-----+         |
  //      |      | +---c---------+
  //      |      v v   v v-------+
  //      |      Min   Max
  //      |  +----+      |
  //      v  v           v--------+
  //      Quantize
  //          |
  //          v
  //      QuantizedOperation
  //          |   |   |
  //          v   v   v
  //      Dequantize
  //          |
  //          v
  //       (float)
  //
  // This keeps the inputs and outputs visible to the rest of the graph in
  // float, and converts them down to quantized buffers internally for the
  // computation. The result will end up with a lot of redundant
  // dequantize/quantize pairs between adjacent quantized ops, but a later pass
  // removes these where it can.

  std::set<string> ops_to_ignore;
  if (context.params.count("ignore_op") > 0) {
    for (const string& name : context.params.at("ignore_op")) {
      ops_to_ignore.insert(name);
    }
  }

  const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
  string op_pattern;
  bool is_first = true;
  std::map<string, QuantizedOpInfo> op_map;
  for (const QuantizedOpInfo& op_info : op_list) {
    if (ops_to_ignore.count(op_info.float_name) == 0) {
      strings::StrAppend(&op_pattern, (is_first ? "" : "|"),
                         op_info.float_name);
      op_map.insert({op_info.float_name, op_info});
      is_first = false;
    }
  }

  // If input_min and input_max have been passed in, then we convert all float
  // Placeholder nodes into quantized versions, with the supplied values as
  // their range.
  GraphDef placeholder_graph_def;
  TF_RETURN_IF_ERROR(
      QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));

  // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
  // operations like Relu or MaxPool, move them up so that they're as close as
  // possible to ops with 32-bit outputs like BiasAdd or Conv2D.
  GraphDef hoisted_graph_def;
  TF_RETURN_IF_ERROR(
      HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));

  // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
  // activation layers, into Requantize ops with those ranges instead. This
  // makes it easier to replace the dynamic range calculations that are used
  // by default.
  GraphDef converted_graph_def;
  TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
                                                   &converted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));

  // If fallback_min and fallback_max are set, then we'll use hardwired ranges
  // for all the 32-bit to 8-bit requantizations.
  float fallback_min;
  float fallback_max;
  bool has_fallback_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(
      context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
      &has_fallback_range));

  // Replace all occurrences of the current float op with its quantized
  // equivalent.
  GraphDef quantized_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      converted_graph_def, {op_pattern},
      [&op_map, fallback_min, fallback_max, has_fallback_range](
          const NodeMatch& match, const std::set<string>& input_nodes,
          const std::set<string>& output_nodes,
          std::vector<NodeDef>* new_nodes) {
        const NodeDef& float_node = match.node;
        const QuantizedOpInfo& op_info = op_map[float_node.op()];

        DataTypeVector input_types;
        DataTypeVector output_types;
        TF_RETURN_IF_ERROR(
            GetInOutTypes(float_node, &input_types, &output_types));
        bool are_all_float = true;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any known non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }
          if (i >= input_types.size()) {
            LOG(ERROR) << "input_types has incorrect size "
                       << input_types.size() << " <= " << i
                       << ". Assuming everything else is floats.";
          }
          if (i < input_types.size() && input_types[i] != DT_FLOAT) {
            are_all_float = false;
          }
        }
        for (const DataType& output_type : output_types) {
          if (output_type != DT_FLOAT) {
            are_all_float = false;
          }
        }
        // This isn't a float op, so don't quantize it.
        if (!are_all_float) {
          CopyOriginalMatch(match, new_nodes);
          return OkStatus();
        }

        string namespace_prefix = float_node.name() + "_eightbit";

        // Quantize all of the inputs.
        std::vector<string> quantized_input_names;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }

          const string& input_name = float_node.input(i);
          string unique_input_name =
              namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);

          // Add some common constants we need for reshaping inputs.
          NodeDef reshape_dims;
          reshape_dims.set_op("Const");
          reshape_dims.set_name(unique_input_name + "/reshape_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reshape_dims);
          SetNodeAttr("dtype", DT_INT32, &reshape_dims);
          Tensor reshape_dims_tensor(DT_INT32, {1});
          reshape_dims_tensor.flat<int32>()(0) = -1;
          SetNodeTensorAttr<int32>("value", reshape_dims_tensor,
                                   &reshape_dims);
          new_nodes->push_back(reshape_dims);

          NodeDef reduction_dims;
          reduction_dims.set_op("Const");
          reduction_dims.set_name(unique_input_name + "/reduction_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reduction_dims);
          SetNodeAttr("dtype", DT_INT32, &reduction_dims);
          Tensor reduction_dims_tensor(DT_INT32, {1});
          reduction_dims_tensor.flat<int32>()(0) = 0;
          SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
                                   &reduction_dims);
          new_nodes->push_back(reduction_dims);

          NodeDef reshape_node;
          reshape_node.set_op("Reshape");
          reshape_node.set_name(unique_input_name + "/reshape");
          SetNodeAttr("T", DT_FLOAT, &reshape_node);
          AddNodeInput(input_name, &reshape_node);
          AddNodeInput(reshape_dims.name(), &reshape_node);
          new_nodes->push_back(reshape_node);

          NodeDef min_node;
          min_node.set_op("Min");
          min_node.set_name(unique_input_name + "/min");
          SetNodeAttr("T", DT_FLOAT, &min_node);
          SetNodeAttr("keep_dims", false, &min_node);
          AddNodeInput(reshape_node.name(), &min_node);
          AddNodeInput(reduction_dims.name(), &min_node);
          new_nodes->push_back(min_node);

          NodeDef max_node;
          max_node.set_op("Max");
          max_node.set_name(unique_input_name + "/max");
          SetNodeAttr("T", DT_FLOAT, &max_node);
          SetNodeAttr("keep_dims", false, &max_node);
          AddNodeInput(reshape_node.name(), &max_node);
          AddNodeInput(reduction_dims.name(), &max_node);
          new_nodes->push_back(max_node);

          NodeDef quantize_node;
          quantize_node.set_op("QuantizeV2");
          quantize_node.set_name(unique_input_name + "/quantize");
          SetNodeAttr("T", DT_QUINT8, &quantize_node);
          SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
          AddNodeInput(input_name, &quantize_node);
          AddNodeInput(min_node.name(), &quantize_node);
          AddNodeInput(max_node.name(), &quantize_node);
          new_nodes->push_back(quantize_node);
          quantized_input_names.push_back(quantize_node.name());
        }

        // Set up the quantized version of the current op.
        NodeDef quantized_main_node;
        quantized_main_node.set_op("Quantized" + float_node.op());
        quantized_main_node.set_name(float_node.name() + "/eightbit");
        for (const string& attr_to_copy : op_info.attrs_to_copy) {
          CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
                       &quantized_main_node);
        }
        for (const std::pair<string, DataType>& dtype_to_set :
             op_info.dtypes_to_set) {
          SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
                      &quantized_main_node);
        }
        int quantized_input_index = 0;
        for (int i = 0; i < float_node.input_size(); ++i) {
          if (op_info.unquantized_inputs.count(i)) {
            AddNodeInput(float_node.input(i), &quantized_main_node);
          } else {
            const string& quantized_input_name =
                quantized_input_names[quantized_input_index];
            AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
            ++quantized_input_index;
          }
        }
        if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        } else {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
          }
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        }
        new_nodes->push_back(quantized_main_node);

        string eight_bit_node_name;
        if (op_info.output_bit_depth == DT_QINT32) {
          // Shrink the range of the output down from 32 bits to 8.
          string requantize_min_input;
          string requantize_max_input;
          if (has_fallback_range) {
            // Use constant values for the min/max range if they were given.
            NodeDef fallback_min_node;
            fallback_min_node.set_op("Const");
            fallback_min_node.set_name(quantized_main_node.name() +
                                       "/fallback_min");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
            Tensor fallback_min_tensor(DT_FLOAT, {});
            fallback_min_tensor.flat<float>()(0) = fallback_min;
            SetNodeTensorAttr<float>("value", fallback_min_tensor,
                                     &fallback_min_node);
            new_nodes->push_back(fallback_min_node);

            NodeDef fallback_max_node;
            fallback_max_node.set_op("Const");
            fallback_max_node.set_name(quantized_main_node.name() +
                                       "/fallback_max");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
            Tensor fallback_max_tensor(DT_FLOAT, {});
            fallback_max_tensor.flat<float>()(0) = fallback_max;
            SetNodeTensorAttr<float>("value", fallback_max_tensor,
                                     &fallback_max_node);
            new_nodes->push_back(fallback_max_node);

            requantize_min_input = fallback_min_node.name();
            requantize_max_input = fallback_max_node.name();
          } else {
            // Otherwise dynamically measure the range each time.
            NodeDef requant_range_node;
            requant_range_node.set_op("RequantizationRange");
            requant_range_node.set_name(quantized_main_node.name() +
                                        "/requant_range");
            SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":0",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":1",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":2",
                         &requant_range_node);
            new_nodes->push_back(requant_range_node);

            requantize_min_input = requant_range_node.name() + ":0";
            requantize_max_input = requant_range_node.name() + ":1";
          }
          NodeDef requantize_node;
          requantize_node.set_op("Requantize");
          requantize_node.set_name(quantized_main_node.name() + "/requantize");
          SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
          SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
          AddNodeInput(requantize_min_input, &requantize_node);
          AddNodeInput(requantize_max_input, &requantize_node);
          new_nodes->push_back(requantize_node);
          eight_bit_node_name = requantize_node.name();
        } else {
          eight_bit_node_name = quantized_main_node.name();
        }

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(float_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return OkStatus();
      },
      {}, &quantized_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));

  // If we've ended up with two Requantize ops in a row (for example if there
  // was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together,
  // using the trained range from the second op.
  GraphDef merged_graph_def;
  TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
                                              &merged_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));

  // There can be duplicate quantize nodes if multiple ops pull from a single
  // input, which makes it harder to remove redundant ones, so strip them out.
  GraphDef deduped_graph_def;
  TF_RETURN_IF_ERROR(
      MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));

  // Look for Dequantizes that immediately go into Quantizes, and remove them
  // since the two together cancel each other out. This allows us to keep the
  // data flow in eight bit where two adjacent ops are in eight bit, but still
  // keep interoperability with float ops.
  TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
                                                  output_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(*output_graph_def));

  return OkStatus();
}

REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);

REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);

}  // namespace graph_transforms
}  // namespace tensorflow
958