1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/quantize_training.h"
17
#include <algorithm>
#include <atomic>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"
33
34namespace tensorflow {
35namespace {
36
37// TODO(suharshs): If desired, make these values configurable.
38const uint32 kAllowedInputs = 2;
39const float kEMADecay = 0.999;
40
41// Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
42const auto* nodes_to_rewrite =
43 new std::unordered_set<string, StringPieceHasher>{"MatMul", "Conv2D"};
44
45// Contains necessary parameters to convert an edge.
46struct EdgeToConvert {
47 // edge is not owned here.
48 const Edge* edge;
49 int32 num_bits;
50 bool signed_input;
51 bool range_given;
52 float input_min;
53 float input_max;
54
55 EdgeToConvert(const Edge* e, int32_t bits, bool sign, bool range, float min,
56 float max)
57 : edge(e),
58 num_bits(bits),
59 signed_input(sign),
60 range_given(range),
61 input_min(min),
62 input_max(max) {}
63};
64
65// Decide if a node is in backward pass by checking if its name is led by
66// "gradients".
67// TODO(jmchen): Make this check more robust as it is not guaranteed that the
68// forward node will not be named with a leading "gradients".
69inline bool IsGradientNode(const Graph* graph, const Node* node) {
70 static const string tag = "gradients";
71 return (node->name().compare(0, tag.size(), tag) == 0);
72}
73
74// Find the type of the input to set the parameters for the
75// quantize_and_dequantize op.
76// Returns true if the root tensor op type is known, false otherwise.
77bool FindType(const Graph* graph, const Node* node, bool* signed_input,
78 bool* range_given, float* input_min, float* input_max) {
79 const string& src_op = node->type_string();
80 if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2") {
81 *signed_input = true;
82 *range_given = false;
83 } else if (src_op == "Relu") {
84 // Range is not given for Relu.
85 *signed_input = false;
86 *range_given = false;
87 } else if (src_op == "Relu6") {
88 // TODO(suharshs): Also the theoretical min and max is 0 and 6, if the
89 // actual activations are somewhere in within this range, we can quantize
90 // this even further. This is true for other activations like Sigmoid6 too.
91 *signed_input = false;
92 *range_given = true;
93 *input_min = 0;
94 *input_max = 6;
95 } else if (src_op == "Sigmoid") {
96 *signed_input = false;
97 *range_given = true;
98 *input_min = 0;
99 *input_max = 1;
100 } else if (src_op == "Tanh") {
101 *signed_input = true;
102 *range_given = true;
103 *input_min = -1;
104 *input_max = 1;
105 } else if (src_op == "Reshape" || src_op == "ConcatV2") {
106 // Reshape has 2 inputs and the first one is the tensor.
107 // ConcatV2 has many inputs but they should all have the same activation
108 // function (i.e. Inception). So we just recurse on the first input.
109 for (const Edge* edge : node->in_edges()) {
110 if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
111 FindType(graph, edge->src(), signed_input, range_given, input_min,
112 input_max);
113 }
114 }
115 } else if (src_op == "Identity" || src_op == "MaxPool" ||
116 src_op == "AvgPool" || src_op == "MaxPool3D" ||
117 src_op == "AvgPool3D") {
118 // All these Ops only have 1 data input.
119 for (const Edge* edge : node->in_edges()) {
120 if (edge->src_output() != Graph::kControlSlot) {
121 FindType(graph, edge->src(), signed_input, range_given, input_min,
122 input_max);
123 }
124 }
125 } else {
126 // Unknown type, could be the model input examples.
127 // TODO(jmchen): Set the params for input with user's hint.
128 *signed_input = true;
129 *range_given = false;
130 return false;
131 }
132
133 return true;
134}
135
136// Find the Save op and inputs.
137Status FindSaveOp(const Graph* graph, Node** save_op,
138 std::vector<const Edge*>* in_edges, bool* found) {
139 *found = false;
140 for (Node* node : graph->op_nodes()) {
141 if (node->type_string() == "SaveV2") {
142 // We found multiple save ops.
143 if (*found) {
144 return errors::InvalidArgument("Input graph has multiple SaveV2 ops.");
145 }
146 *save_op = node;
147 *found = true;
148 TF_RETURN_IF_ERROR(node->input_edges(in_edges));
149 }
150 }
151 return OkStatus();
152}
153
154Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
155 for (Node* node : graph->op_nodes()) {
156 // The restore_all op should have the same prefix of the save_op.
157 if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
158 return node;
159 }
160 }
161 return nullptr;
162}
163
164// Strips the last "/suffix" from a name.
165// We use this to construct the name of restore ops in the same way they are
166// constructed by the Saver.
167StringPiece GetNodeNamePrefix(const Node* node) {
168 StringPiece name = node->name();
169 return name.substr(0, name.rfind('/'));
170}
171
172void FillStringTensor(Tensor* dst, const Tensor& src) {
173 auto dst_flat = dst->flat<tstring>();
174 auto src_flat = src.flat<tstring>();
175 for (int i = 0; i < src.NumElements(); i++) {
176 dst_flat(i) = src_flat(i);
177 }
178}
179
180// Add the added_variables as an inputs to the Save op.
181// We change the inputs of the SaveV2 op to include the names of the added
182// variables. We also add the variables as inputs to the save op.
183Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
184 const std::vector<const Edge*>& in_edges,
185 const std::vector<Node*>& added_variables) {
186 Node* tensor_names_op = in_edges[1]->src();
187 Node* shape_and_slices_op = in_edges[2]->src();
188
189 // Get the tensor_names and shape_and_slices tensors from the const op.
190 Tensor tensor_names;
191 Tensor shape_and_slices;
192 TF_RETURN_IF_ERROR(
193 GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
194 TF_RETURN_IF_ERROR(
195 GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));
196
197 int tn_size = tensor_names.NumElements();
198 int var_size = added_variables.size();
199
200 // Create a new save_op that has inputs to all the new variables.
201 NodeBuilder save_op_builder =
202 NodeBuilder(save_op->name(), save_op->type_string());
203 // The first three inputs are prefix, tensor_names, and shapes_and_slices.
204 for (int i = 0; i < 3; i++) {
205 save_op_builder = save_op_builder.Input(in_edges[i]->src());
206 }
207 std::vector<NodeBuilder::NodeOut> var_nodeouts;
208 var_nodeouts.reserve(tn_size + var_size);
209 // The rest of the inputs need to be used the construct the tensor list arg.
210 for (int i = 3; i < in_edges.size(); i++) {
211 var_nodeouts.emplace_back(in_edges[i]->src());
212 }
213
214 // Add the new values to the tensors and the op input.
215 Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size}));
216 Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size}));
217 FillStringTensor(&new_tensor_names, tensor_names);
218 FillStringTensor(&new_shape_and_slices, shape_and_slices);
219 for (int i = 0; i < var_size; i++) {
220 Node* var = added_variables[i];
221 new_tensor_names.flat<tstring>()(tn_size + i) = var->name();
222 new_shape_and_slices.flat<tstring>()(tn_size + i) = "";
223 var_nodeouts.emplace_back(var);
224 }
225 save_op_builder = save_op_builder.Input(var_nodeouts);
226
227 // Update the attrs.
228 tensor_names_op->AddAttr("value", new_tensor_names);
229 shape_and_slices_op->AddAttr("value", new_shape_and_slices);
230
231 // Remove the old save_op and add the new one.
232 Node* new_save_op;
233 TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op));
234 // Add outputs to the new_save_op, all outputs are control edges.
235 for (const Edge* edge : save_op->out_edges()) {
236 graph->AddControlEdge(new_save_op, edge->dst());
237 }
238 graph->RemoveNode(save_op);
239
240 return OkStatus();
241}
242
243// Add a restore subgraph for each variable and connect to the restore_all op.
244// For each variable we add the following subgraph:
245// Assign----restore_all
246// | |
247// RestoreV2 Variable
248Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
249 const std::vector<const Edge*>& in_edges,
250 const std::vector<Node*>& variables) {
251 Node* prefix_op = in_edges[0]->src();
252 StringPiece name_prefix = GetNodeNamePrefix(save_op);
253 Node* restore_all = FindRestoreAllOp(graph, name_prefix);
254 if (restore_all == nullptr) {
255 return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp");
256 }
257 const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
258 const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
259 for (Node* var : variables) {
260 // Add an extra prefix after calling graph->NewName because the "unique"
261 // name may conflict with names generated for Send nodes.
262 // TODO(b/77547936): fix this more generally and get rid of the extra prefix
263 // here.
264 string new_restore_op_name =
265 strings::StrCat(graph->NewName(restore_op_name), "_qt");
266 string new_assign_op_name =
267 strings::StrCat(graph->NewName(assign_op_name), "_qt");
268 string tensor_names_op_name =
269 strings::StrCat(new_restore_op_name, "/tensor_names");
270 string shape_and_slices_op_name =
271 strings::StrCat(new_restore_op_name, "/shape_and_slices");
272
273 // Construct the tensor_names input with the variable name.
274 Node* tensor_names;
275 Tensor tensor_names_val(DT_STRING, TensorShape({1}));
276 tensor_names_val.flat<tstring>()(0) = var->name();
277 TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
278 .Attr("dtype", DT_STRING)
279 .Attr("value", tensor_names_val)
280 .Finalize(graph, &tensor_names));
281
282 // Construct the shape_and_slices input with empty string.
283 Node* shape_and_slices;
284 Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
285 shape_and_slices_val.flat<tstring>()(0) = "";
286 TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
287 .Attr("dtype", DT_STRING)
288 .Attr("value", shape_and_slices_val)
289 .Finalize(graph, &shape_and_slices));
290
291 // Build the new Restore op for this variable.
292 Node* restore_op;
293 TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
294 .Input(prefix_op)
295 .Input(tensor_names)
296 .Input(shape_and_slices)
297 .Attr("dtypes", {DT_FLOAT})
298 .Finalize(graph, &restore_op));
299
300 // Create Assign op, attaching the variable and Restore op to it.
301 Node* assign_op;
302 TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
303 .Input(var)
304 .Input(restore_op)
305 .Finalize(graph, &assign_op));
306
307 // Add a control edge from the assign op to restore_all op.
308 graph->AddControlEdge(assign_op, restore_all);
309 }
310 return OkStatus();
311}
312
313// Adds new variables to save and restore ops matching the Save and Restore
314// graphs created in tensorflow/python/training/saver.py.
315Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) {
316 Node* save_op = nullptr;
317 std::vector<const Edge*> in_edges;
318 bool found = false;
319 TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found));
320 if (found) {
321 TF_RETURN_IF_ERROR(
322 AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables));
323 TF_RETURN_IF_ERROR(
324 ConnectVariablesToSaveOp(graph, save_op, in_edges, variables));
325 }
326 return OkStatus();
327}
328
329// Sets output to the Node that computes reduction axes corresponding to all
330// dimensions of input and return.
331Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
332 Node** output) {
333 name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
334 Node* start;
335 Tensor zero_tensor(DT_INT32, TensorShape());
336 zero_tensor.flat<int32>()(0) = 0;
337 TF_RETURN_IF_ERROR(
338 NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
339 .Attr("dtype", DT_INT32)
340 .Attr("value", zero_tensor)
341 .Finalize(graph, &start));
342 Node* delta;
343 Tensor one_tensor(DT_INT32, TensorShape());
344 one_tensor.flat<int32>()(0) = 1;
345 TF_RETURN_IF_ERROR(
346 NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
347 .Attr("dtype", DT_INT32)
348 .Attr("value", one_tensor)
349 .Finalize(graph, &delta));
350 Node* rank;
351 TF_RETURN_IF_ERROR(
352 NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
353 .Input(input)
354 .Finalize(graph, &rank));
355 TF_RETURN_IF_ERROR(
356 NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
357 .Input(start)
358 .Input(rank)
359 .Input(delta)
360 .Finalize(graph, output));
361 return OkStatus();
362}
363
364// Computes the exponential moving average of input, updated in update_variable.
365Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
366 const NodeBuilder::NodeOut& input,
367 Node* decay, Node* update_variable,
368 Node** assign_value) {
369 // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
370 name_prefix = strings::StrCat(name_prefix, "/EMA");
371 Node* one;
372 Tensor one_tensor(DT_FLOAT, TensorShape());
373 one_tensor.flat<float>()(0) = 1.0;
374 TF_RETURN_IF_ERROR(
375 NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
376 .Attr("dtype", DT_FLOAT)
377 .Attr("value", one_tensor)
378 .Finalize(graph, &one));
379 Node* decay_complement;
380 TF_RETURN_IF_ERROR(
381 NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
382 .Input(one)
383 .Input(decay)
384 .Finalize(graph, &decay_complement));
385
386 Node* value_diff;
387 TF_RETURN_IF_ERROR(
388 NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
389 .Input(update_variable)
390 .Input(input)
391 .Finalize(graph, &value_diff));
392 Node* update_value;
393 TF_RETURN_IF_ERROR(
394 NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
395 .Input(value_diff)
396 .Input(decay_complement)
397 .Finalize(graph, &update_value));
398
399 TF_RETURN_IF_ERROR(
400 NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
401 .Input(update_variable)
402 .Input(update_value)
403 .Finalize(graph, assign_value));
404 return OkStatus();
405}
406
407// Creates an automatically initialized exponential moving average variable.
408// This uses a switch op to assign a value to the variable on the first run,
409// and update with the moving average for all other runs:
410// init_val
411// |
412// var--is_init--switch
413// | true / \ false
414// | | |
415// | EMA init_val
416// | \ /
417// +----------- assign
418Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
419 Node* init_val,
420 std::vector<Node*>* added_variables,
421 Node** var) {
422 // TODO(suharshs): Update this to use ResourceVariables when they are ready.
423 TF_RETURN_IF_ERROR(
424 NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
425 .Attr("shape", TensorShape())
426 .Attr("dtype", DT_FLOAT)
427 .Finalize(graph, var));
428 added_variables->push_back(*var);
429
430 Node* is_initialized;
431 TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
432 "IsVariableInitialized")
433 .Input(*var)
434 .Finalize(graph, &is_initialized));
435 Node* switch_node;
436 TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
437 .Input(init_val)
438 .Input(is_initialized)
439 .Finalize(graph, &switch_node));
440 NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
441 NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);
442
443 Node* ema_value;
444 TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
445 decay, *var, &ema_value));
446
447 Node* assign_value;
448 TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
449 .Input({output_false, ema_value})
450 .Finalize(graph, &assign_value));
451
452 TF_RETURN_IF_ERROR(
453 NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
454 .Input(*var)
455 .Input(assign_value)
456 .Finalize(graph, var));
457 return OkStatus();
458}
459
460// Computes the min and max EMA of input and stores them in min_var and max_var.
461Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
462 std::vector<Node*>* added_variables, Node** min_var,
463 Node** max_var) {
464 // TODO(suharshs): The decay will be constant, so we could make only one for
465 // all quantize_and_dequantize ops to share, this would have to live outside
466 // this function.
467 Tensor decay_tensor(DT_FLOAT, TensorShape());
468 decay_tensor.flat<float>()(0) = kEMADecay;
469 Node* decay;
470 TF_RETURN_IF_ERROR(
471 NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
472 .Attr("dtype", DT_FLOAT)
473 .Attr("value", decay_tensor)
474 .Finalize(graph, &decay));
475
476 Node* reduction_axes;
477 TF_RETURN_IF_ERROR(
478 MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
479 Node* min;
480 string min_name = strings::StrCat(name_prefix, "/Min");
481 TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
482 .Input(input)
483 .Input(reduction_axes)
484 .Finalize(graph, &min));
485 Node* max;
486 string max_name = strings::StrCat(name_prefix, "/Max");
487 TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
488 .Input(input)
489 .Input(reduction_axes)
490 .Finalize(graph, &max));
491 TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min,
492 added_variables, min_var));
493 TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max,
494 added_variables, max_var));
495 return OkStatus();
496}
497
498// Makes an input min and max constant if the range is given. Otherwise, makes
499// min and max variables that are updated by an EMA.
500Status MakeInputMinMax(Graph* graph, const string& name_prefix,
501 const EdgeToConvert& edge,
502 std::vector<Node*>* added_variables, Node** input_min,
503 Node** input_max) {
504 if (edge.range_given) {
505 // Make constant nodes for the input_min and input_max if the range is
506 // provided.
507 Tensor input_min_tensor(DT_FLOAT, TensorShape());
508 input_min_tensor.flat<float>()(0) = edge.input_min;
509 TF_RETURN_IF_ERROR(
510 NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
511 .Attr("dtype", DT_FLOAT)
512 .Attr("value", input_min_tensor)
513 .Finalize(graph, input_min));
514 Tensor input_max_tensor(DT_FLOAT, TensorShape());
515 input_max_tensor.flat<float>()(0) = edge.input_max;
516 TF_RETURN_IF_ERROR(
517 NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
518 .Attr("dtype", DT_FLOAT)
519 .Attr("value", input_max_tensor)
520 .Finalize(graph, input_max));
521 } else {
522 // If the range is not given, estimate the range with EMA variables.
523 TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
524 added_variables, input_min,
525 input_max));
526 }
527
528 return OkStatus();
529}
530
531// Adds a QuantizeAndDequantizeV2 or FakeQuantizeWithMinMaxVars op
532// (and required input nodes) based on edge.
533// The result is stored in convert_node.
534Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
535 const string& quant_op_type, const EdgeToConvert& edge,
536 std::vector<Node*>* added_variables,
537 Node** convert_node) {
538 Node* input_min;
539 Node* input_max;
540 TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
541 &input_min, &input_max));
542 string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
543 if (quant_op_type == "QuantizeAndDequantizeV2") {
544 TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
545 .Input(edge.edge->src())
546 .Input(input_min)
547 .Input(input_max)
548 .Attr("signed_input", edge.signed_input)
549 .Attr("num_bits", edge.num_bits)
550 .Attr("range_given", true)
551 .Finalize(graph, convert_node));
552 } else if (quant_op_type == "FakeQuantWithMinMaxVars") {
553 TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
554 .Input(edge.edge->src())
555 .Input(input_min)
556 .Input(input_max)
557 .Attr("num_bits", edge.num_bits)
558 .Finalize(graph, convert_node));
559 } else {
560 return errors::InvalidArgument("Unknown quant op type: ", quant_op_type);
561 }
562 return OkStatus();
563}
564
565// Insert conversion op, connect it to the graph and remove the old edge.
566Status ProcessTargetEdges(Graph* graph, const string& quant_op_type,
567 const std::vector<EdgeToConvert>& target_edges) {
568 // Remember previously converted ops to avoid duplicated conversion on the
569 // same input.
570 std::unordered_map<string, Node*, StringPieceHasher> name_index;
571 std::vector<Node*> added_variables;
572 for (const EdgeToConvert edge : target_edges) {
573 Node* convert_node;
574 string name_prefix = edge.edge->src()->name();
575
576 auto iter = name_index.find(name_prefix);
577 if (iter == name_index.end()) {
578 TF_RETURN_IF_ERROR(MakeQuantizeOp(graph, name_prefix, quant_op_type, edge,
579 &added_variables, &convert_node));
580 name_index[name_prefix] = convert_node;
581 } else {
582 convert_node = iter->second;
583 }
584
585 graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
586 graph->RemoveEdge(edge.edge);
587 }
588
589 TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables));
590
591 return OkStatus();
592}
593
594} // namespace
595
596Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type,
597 Graph* graph) {
598 if (graph == nullptr) {
599 return errors::InvalidArgument("Cannot accept empty graph pointer.");
600 }
601
602 if (num_bits < 1 || num_bits > 63) {
603 return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
604 num_bits);
605 }
606 int potential_input = 0;
607 std::vector<EdgeToConvert> target_edges;
608 for (Node* node : graph->nodes()) {
609 if (nodes_to_rewrite->find(node->type_string()) !=
610 nodes_to_rewrite->end() &&
611 !IsGradientNode(graph, node)) {
612 // Find out which types are the inputs and convert them accordingly.
613 // 1. Const/Variable OP: This is quantized as signed tensors with no given
614 // range.
615 // 2. Activation OP: Set the range accordingly for different types of
616 // activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh}
617 // 3. Identity OP: The quantization parameters depend on its input.
618 // 4. Pooling OPs: various pooling ops. Also depends on its input.
619 // 5. Reshape OP: Also depends on the first input to this op.
620 // 6. Not-Listed-Above OP: If there is only 1 such op, consider it as the
621 // model input. However, if there are >1 unknown ops, then returns an
622 // error for now to avoid unexpected behavior.
623 // Note: The list above might not be a complete list. Please let us
624 // know if you see the error so we can handle your case.
625 for (const Edge* edge : node->in_edges()) {
626 if (edge->src_output() == Graph::kControlSlot) {
627 // Skip the control dependency input.
628 continue;
629 } else {
630 bool signed_input = false;
631 bool range_given = false;
632 float input_min = 0;
633 float input_max = 0;
634 bool known_op = FindType(graph, edge->src(), &signed_input,
635 &range_given, &input_min, &input_max);
636 if (!known_op) {
637 // Unknown op is considered as input.
638 potential_input++;
639 if (potential_input > kAllowedInputs) {
640 return errors::Unimplemented(
641 "Found an unknown op: ", edge->src()->name(),
642 " with type: ", edge->src()->type_string(),
643 "; Unknown ops are considered as model input for now and "
644 "only ",
645 kAllowedInputs, " inputs are supported currently.");
646 }
647 }
648
649 target_edges.emplace_back(EdgeToConvert(
650 edge, num_bits, signed_input, range_given, input_min, input_max));
651 }
652 }
653 }
654 }
655
656 TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges));
657
658 return OkStatus();
659}
660
661Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
662 int32_t num_bits,
663 const string& quant_op_type,
664 GraphDef* result_graphdef) {
665 Graph graph(OpRegistry::Global());
666 GraphConstructorOptions opts;
667 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));
668
669 // Call the rewriter on the graph.
670 TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));
671
672 // Convert the result graph back to a GraphDef.
673 graph.ToGraphDef(result_graphdef);
674 return OkStatus();
675}
676
677Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
678 int32_t num_bits,
679 const string& quant_op_type,
680 string* result_graph_string) {
681 // First create the graph from the GraphDef.
682 GraphDef input_graphdef;
683 if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
684 return errors::InvalidArgument(
685 "input_graph_string is not a serialized GraphDef protocol buffer");
686 }
687 GraphDef output_graphdef;
688 TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
689 input_graphdef, num_bits, quant_op_type, &output_graphdef));
690
691 if (!output_graphdef.SerializeToString(result_graph_string)) {
692 return errors::Internal(
693 "quantize training transformation resulted in invalid GraphDef");
694 }
695 return OkStatus();
696}
697
698} // namespace tensorflow
699