1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/quantize_training.h" |
17 | |
18 | #include <algorithm> |
19 | #include <atomic> |
20 | #include <set> |
21 | #include <unordered_map> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
25 | #include "tensorflow/core/common_runtime/memory_types.h" |
26 | #include "tensorflow/core/framework/log_memory.h" |
27 | #include "tensorflow/core/framework/op_kernel.h" |
28 | #include "tensorflow/core/graph/algorithm.h" |
29 | #include "tensorflow/core/graph/node_builder.h" |
30 | #include "tensorflow/core/graph/subgraph.h" |
31 | #include "tensorflow/core/lib/strings/strcat.h" |
32 | #include "tensorflow/core/public/session_options.h" |
33 | |
34 | namespace tensorflow { |
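
// Rewrites a training graph for quantized ("fake quant") training: for each
// MatMul and Conv2D node outside the backward pass, it inserts a
// QuantizeAndDequantizeV2 or FakeQuantWithMinMaxVars op on every data input,
// creates EMA min/max variables where no static range is known, and hooks
// those variables into the existing Save/Restore subgraphs.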
35 | namespace { |
36 | |
37 | // TODO(suharshs): If desired, make these values configurable. |
38 | const uint32 kAllowedInputs = 2; |
39 | const float kEMADecay = 0.999; |
40 | |
// Node types to rewrite. Insert a quantize_and_dequantize op on each of
// their inputs.
const auto* nodes_to_rewrite =
    new std::unordered_set<string, StringPieceHasher>{"MatMul", "Conv2D"};
44 | |
45 | // Contains necessary parameters to convert an edge. |
46 | struct EdgeToConvert { |
47 | // edge is not owned here. |
48 | const Edge* edge; |
49 | int32 num_bits; |
50 | bool signed_input; |
51 | bool range_given; |
52 | float input_min; |
53 | float input_max; |
54 | |
55 | EdgeToConvert(const Edge* e, int32_t bits, bool sign, bool range, float min, |
56 | float max) |
57 | : edge(e), |
58 | num_bits(bits), |
59 | signed_input(sign), |
60 | range_given(range), |
61 | input_min(min), |
62 | input_max(max) {} |
63 | }; |
64 | |
// Decides whether a node is in the backward pass by checking whether its name
// starts with "gradients".
// TODO(jmchen): Make this check more robust; it is not guaranteed that a
// forward node will never be named with a leading "gradients".
inline bool IsGradientNode(const Graph* graph, const Node* node) {
  static const string tag = "gradients";
  return (node->name().compare(0, tag.size(), tag) == 0);
}
73 | |
// Finds the type of the input so we can set the parameters of the
// quantize_and_dequantize op appropriately.
// Returns true if the root tensor's op type is known, false otherwise.
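// For example, tracing MatMul <- Identity <- Relu6 recurses through the
// Identity and lands on Relu6, which yields an unsigned input with the
// static range [0, 6].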
77 | bool FindType(const Graph* graph, const Node* node, bool* signed_input, |
78 | bool* range_given, float* input_min, float* input_max) { |
79 | const string& src_op = node->type_string(); |
80 | if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2" ) { |
81 | *signed_input = true; |
82 | *range_given = false; |
83 | } else if (src_op == "Relu" ) { |
84 | // Range is not given for Relu. |
85 | *signed_input = false; |
86 | *range_given = false; |
87 | } else if (src_op == "Relu6" ) { |
88 | // TODO(suharshs): Also the theoretical min and max is 0 and 6, if the |
89 | // actual activations are somewhere in within this range, we can quantize |
90 | // this even further. This is true for other activations like Sigmoid6 too. |
91 | *signed_input = false; |
92 | *range_given = true; |
93 | *input_min = 0; |
94 | *input_max = 6; |
95 | } else if (src_op == "Sigmoid" ) { |
96 | *signed_input = false; |
97 | *range_given = true; |
98 | *input_min = 0; |
99 | *input_max = 1; |
100 | } else if (src_op == "Tanh" ) { |
101 | *signed_input = true; |
102 | *range_given = true; |
103 | *input_min = -1; |
104 | *input_max = 1; |
105 | } else if (src_op == "Reshape" || src_op == "ConcatV2" ) { |
106 | // Reshape has 2 inputs and the first one is the tensor. |
107 | // ConcatV2 has many inputs but they should all have the same activation |
108 | // function (i.e. Inception). So we just recurse on the first input. |
109 | for (const Edge* edge : node->in_edges()) { |
110 | if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) { |
111 | FindType(graph, edge->src(), signed_input, range_given, input_min, |
112 | input_max); |
113 | } |
114 | } |
115 | } else if (src_op == "Identity" || src_op == "MaxPool" || |
116 | src_op == "AvgPool" || src_op == "MaxPool3D" || |
117 | src_op == "AvgPool3D" ) { |
118 | // All these Ops only have 1 data input. |
119 | for (const Edge* edge : node->in_edges()) { |
120 | if (edge->src_output() != Graph::kControlSlot) { |
121 | FindType(graph, edge->src(), signed_input, range_given, input_min, |
122 | input_max); |
123 | } |
124 | } |
125 | } else { |
    // Unknown op type; this could be one of the model's input examples.
    // TODO(jmchen): Set the params for inputs from a user hint.
128 | *signed_input = true; |
129 | *range_given = false; |
130 | return false; |
131 | } |
132 | |
133 | return true; |
134 | } |
135 | |
// Finds the SaveV2 op and its input edges.
137 | Status FindSaveOp(const Graph* graph, Node** save_op, |
138 | std::vector<const Edge*>* in_edges, bool* found) { |
139 | *found = false; |
140 | for (Node* node : graph->op_nodes()) { |
141 | if (node->type_string() == "SaveV2" ) { |
142 | // We found multiple save ops. |
143 | if (*found) { |
144 | return errors::InvalidArgument("Input graph has multiple SaveV2 ops." ); |
145 | } |
146 | *save_op = node; |
147 | *found = true; |
148 | TF_RETURN_IF_ERROR(node->input_edges(in_edges)); |
149 | } |
150 | } |
151 | return OkStatus(); |
152 | } |
153 | |
154 | Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) { |
155 | for (Node* node : graph->op_nodes()) { |
    // The restore_all op should have the same prefix as the save op.
    if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
158 | return node; |
159 | } |
160 | } |
161 | return nullptr; |
162 | } |
163 | |
164 | // Strips the last "/suffix" from a name. |
165 | // We use this to construct the name of restore ops in the same way they are |
166 | // constructed by the Saver. |
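// For example, a save op named "save/SaveV2" yields the prefix "save",
// matching the "save/restore_all" NoOp created by the Saver.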
167 | StringPiece GetNodeNamePrefix(const Node* node) { |
168 | StringPiece name = node->name(); |
169 | return name.substr(0, name.rfind('/')); |
170 | } |
171 | |
172 | void FillStringTensor(Tensor* dst, const Tensor& src) { |
173 | auto dst_flat = dst->flat<tstring>(); |
174 | auto src_flat = src.flat<tstring>(); |
175 | for (int i = 0; i < src.NumElements(); i++) { |
176 | dst_flat(i) = src_flat(i); |
177 | } |
178 | } |
179 | |
// Adds the added variables as inputs to the Save op.
// We change the tensor_names and shape_and_slices inputs of the SaveV2 op to
// include the names of the added variables, and add the variables themselves
// as inputs to the op.
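// Illustrative sketch (tensor names are hypothetical): a SaveV2 op with
// inputs
//   (prefix, tensor_names, shape_and_slices, w, b)
// plus two added EMA variables min_var and max_var is rebuilt with inputs
//   (prefix, tensor_names', shape_and_slices', w, b, min_var, max_var),
// where tensor_names' and shape_and_slices' extend the original const
// tensors with the variables' names and empty slice specs.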
183 | Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, |
184 | const std::vector<const Edge*>& in_edges, |
185 | const std::vector<Node*>& added_variables) { |
186 | Node* tensor_names_op = in_edges[1]->src(); |
187 | Node* shape_and_slices_op = in_edges[2]->src(); |
188 | |
189 | // Get the tensor_names and shape_and_slices tensors from the const op. |
190 | Tensor tensor_names; |
191 | Tensor shape_and_slices; |
  TF_RETURN_IF_ERROR(
      GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));
196 | |
197 | int tn_size = tensor_names.NumElements(); |
198 | int var_size = added_variables.size(); |
199 | |
200 | // Create a new save_op that has inputs to all the new variables. |
201 | NodeBuilder save_op_builder = |
202 | NodeBuilder(save_op->name(), save_op->type_string()); |
203 | // The first three inputs are prefix, tensor_names, and shapes_and_slices. |
204 | for (int i = 0; i < 3; i++) { |
205 | save_op_builder = save_op_builder.Input(in_edges[i]->src()); |
206 | } |
207 | std::vector<NodeBuilder::NodeOut> var_nodeouts; |
208 | var_nodeouts.reserve(tn_size + var_size); |
  // The rest of the inputs are used to construct the tensor-list argument.
210 | for (int i = 3; i < in_edges.size(); i++) { |
211 | var_nodeouts.emplace_back(in_edges[i]->src()); |
212 | } |
213 | |
214 | // Add the new values to the tensors and the op input. |
215 | Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size})); |
216 | Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size})); |
217 | FillStringTensor(&new_tensor_names, tensor_names); |
218 | FillStringTensor(&new_shape_and_slices, shape_and_slices); |
219 | for (int i = 0; i < var_size; i++) { |
220 | Node* var = added_variables[i]; |
221 | new_tensor_names.flat<tstring>()(tn_size + i) = var->name(); |
    new_shape_and_slices.flat<tstring>()(tn_size + i) = "";
223 | var_nodeouts.emplace_back(var); |
224 | } |
225 | save_op_builder = save_op_builder.Input(var_nodeouts); |
226 | |
227 | // Update the attrs. |
  tensor_names_op->AddAttr("value", new_tensor_names);
  shape_and_slices_op->AddAttr("value", new_shape_and_slices);
230 | |
231 | // Remove the old save_op and add the new one. |
232 | Node* new_save_op; |
233 | TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op)); |
  // Add the outputs to the new_save_op; all outputs are control edges.
235 | for (const Edge* edge : save_op->out_edges()) { |
236 | graph->AddControlEdge(new_save_op, edge->dst()); |
237 | } |
238 | graph->RemoveNode(save_op); |
239 | |
240 | return OkStatus(); |
241 | } |
242 | |
243 | // Add a restore subgraph for each variable and connect to the restore_all op. |
244 | // For each variable we add the following subgraph: |
245 | // Assign----restore_all |
246 | // | | |
247 | // RestoreV2 Variable |
248 | Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, |
249 | const std::vector<const Edge*>& in_edges, |
250 | const std::vector<Node*>& variables) { |
251 | Node* prefix_op = in_edges[0]->src(); |
252 | StringPiece name_prefix = GetNodeNamePrefix(save_op); |
253 | Node* restore_all = FindRestoreAllOp(graph, name_prefix); |
254 | if (restore_all == nullptr) { |
255 | return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp" ); |
256 | } |
257 | const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2" ); |
258 | const string assign_op_name = strings::StrCat(name_prefix, "/Assign" ); |
259 | for (Node* var : variables) { |
260 | // Add an extra prefix after calling graph->NewName because the "unique" |
261 | // name may conflict with names generated for Send nodes. |
262 | // TODO(b/77547936): fix this more generally and get rid of the extra prefix |
263 | // here. |
    string new_restore_op_name =
        strings::StrCat(graph->NewName(restore_op_name), "_qt");
    string new_assign_op_name =
        strings::StrCat(graph->NewName(assign_op_name), "_qt");
    string tensor_names_op_name =
        strings::StrCat(new_restore_op_name, "/tensor_names");
    string shape_and_slices_op_name =
        strings::StrCat(new_restore_op_name, "/shape_and_slices");
272 | |
273 | // Construct the tensor_names input with the variable name. |
274 | Node* tensor_names; |
275 | Tensor tensor_names_val(DT_STRING, TensorShape({1})); |
276 | tensor_names_val.flat<tstring>()(0) = var->name(); |
    TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
                           .Attr("dtype", DT_STRING)
                           .Attr("value", tensor_names_val)
280 | .Finalize(graph, &tensor_names)); |
281 | |
282 | // Construct the shape_and_slices input with empty string. |
283 | Node* shape_and_slices; |
284 | Tensor shape_and_slices_val(DT_STRING, TensorShape({1})); |
    shape_and_slices_val.flat<tstring>()(0) = "";
    TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
                           .Attr("dtype", DT_STRING)
                           .Attr("value", shape_and_slices_val)
289 | .Finalize(graph, &shape_and_slices)); |
290 | |
291 | // Build the new Restore op for this variable. |
292 | Node* restore_op; |
    TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
                           .Input(prefix_op)
                           .Input(tensor_names)
                           .Input(shape_and_slices)
                           .Attr("dtypes", {DT_FLOAT})
298 | .Finalize(graph, &restore_op)); |
299 | |
300 | // Create Assign op, attaching the variable and Restore op to it. |
301 | Node* assign_op; |
    TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
303 | .Input(var) |
304 | .Input(restore_op) |
305 | .Finalize(graph, &assign_op)); |
306 | |
307 | // Add a control edge from the assign op to restore_all op. |
308 | graph->AddControlEdge(assign_op, restore_all); |
309 | } |
310 | return OkStatus(); |
311 | } |
312 | |
313 | // Adds new variables to save and restore ops matching the Save and Restore |
314 | // graphs created in tensorflow/python/training/saver.py. |
315 | Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) { |
316 | Node* save_op = nullptr; |
317 | std::vector<const Edge*> in_edges; |
318 | bool found = false; |
319 | TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found)); |
320 | if (found) { |
321 | TF_RETURN_IF_ERROR( |
322 | AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables)); |
323 | TF_RETURN_IF_ERROR( |
324 | ConnectVariablesToSaveOp(graph, save_op, in_edges, variables)); |
325 | } |
326 | return OkStatus(); |
327 | } |
328 | |
// Sets *output to a Node that computes the reduction axes covering all
// dimensions of input, i.e. Range(0, Rank(input), 1).
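// For a rank-4 input this produces [0, 1, 2, 3], so the Min/Max ops built on
// top of it reduce their input to a scalar.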
331 | Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input, |
332 | Node** output) { |
333 | name_prefix = strings::StrCat(name_prefix, "/ReductionAxes" ); |
334 | Node* start; |
335 | Tensor zero_tensor(DT_INT32, TensorShape()); |
336 | zero_tensor.flat<int32>()(0) = 0; |
337 | TF_RETURN_IF_ERROR( |
338 | NodeBuilder(strings::StrCat(name_prefix, "/RangeStart" ), "Const" ) |
339 | .Attr("dtype" , DT_INT32) |
340 | .Attr("value" , zero_tensor) |
341 | .Finalize(graph, &start)); |
342 | Node* delta; |
343 | Tensor one_tensor(DT_INT32, TensorShape()); |
344 | one_tensor.flat<int32>()(0) = 1; |
345 | TF_RETURN_IF_ERROR( |
346 | NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta" ), "Const" ) |
347 | .Attr("dtype" , DT_INT32) |
348 | .Attr("value" , one_tensor) |
349 | .Finalize(graph, &delta)); |
350 | Node* rank; |
351 | TF_RETURN_IF_ERROR( |
352 | NodeBuilder(strings::StrCat(name_prefix, "/InputRank" ), "Rank" ) |
353 | .Input(input) |
354 | .Finalize(graph, &rank)); |
355 | TF_RETURN_IF_ERROR( |
356 | NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes" ), "Range" ) |
357 | .Input(start) |
358 | .Input(rank) |
359 | .Input(delta) |
360 | .Finalize(graph, output)); |
361 | return OkStatus(); |
362 | } |
363 | |
364 | // Computes the exponential moving average of input, updated in update_variable. |
365 | Status MakeExponentialMovingAverage(Graph* graph, string name_prefix, |
366 | const NodeBuilder::NodeOut& input, |
367 | Node* decay, Node* update_variable, |
368 | Node** assign_value) { |
369 | // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)] |
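  // For example, with decay = 0.999, variable = 5.0 and value = 6.0, the new
  // value is 5.0 - (5.0 - 6.0) * 0.001 = 5.001.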
370 | name_prefix = strings::StrCat(name_prefix, "/EMA" ); |
371 | Node* one; |
372 | Tensor one_tensor(DT_FLOAT, TensorShape()); |
373 | one_tensor.flat<float>()(0) = 1.0; |
374 | TF_RETURN_IF_ERROR( |
375 | NodeBuilder(strings::StrCat(name_prefix, "/OneConst" ), "Const" ) |
376 | .Attr("dtype" , DT_FLOAT) |
377 | .Attr("value" , one_tensor) |
378 | .Finalize(graph, &one)); |
379 | Node* decay_complement; |
380 | TF_RETURN_IF_ERROR( |
381 | NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement" ), "Sub" ) |
382 | .Input(one) |
383 | .Input(decay) |
384 | .Finalize(graph, &decay_complement)); |
385 | |
386 | Node* value_diff; |
387 | TF_RETURN_IF_ERROR( |
388 | NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff" ), "Sub" ) |
389 | .Input(update_variable) |
390 | .Input(input) |
391 | .Finalize(graph, &value_diff)); |
392 | Node* update_value; |
393 | TF_RETURN_IF_ERROR( |
394 | NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue" ), "Mul" ) |
395 | .Input(value_diff) |
396 | .Input(decay_complement) |
397 | .Finalize(graph, &update_value)); |
398 | |
399 | TF_RETURN_IF_ERROR( |
400 | NodeBuilder(strings::StrCat(name_prefix, "/EMAValue" ), "Sub" ) |
401 | .Input(update_variable) |
402 | .Input(update_value) |
403 | .Finalize(graph, assign_value)); |
404 | return OkStatus(); |
405 | } |
406 | |
407 | // Creates an automatically initialized exponential moving average variable. |
// This uses a switch op to assign a value to the variable on the first run,
// and updates it with the moving average on all subsequent runs:
410 | // init_val |
411 | // | |
412 | // var--is_init--switch |
413 | // | true / \ false |
414 | // | | | |
415 | // | EMA init_val |
416 | // | \ / |
417 | // +----------- assign |
418 | Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay, |
419 | Node* init_val, |
420 | std::vector<Node*>* added_variables, |
421 | Node** var) { |
422 | // TODO(suharshs): Update this to use ResourceVariables when they are ready. |
423 | TF_RETURN_IF_ERROR( |
424 | NodeBuilder(strings::StrCat(name, "/Variable" ), "VariableV2" ) |
425 | .Attr("shape" , TensorShape()) |
426 | .Attr("dtype" , DT_FLOAT) |
427 | .Finalize(graph, var)); |
428 | added_variables->push_back(*var); |
429 | |
430 | Node* is_initialized; |
431 | TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized" ), |
432 | "IsVariableInitialized" ) |
433 | .Input(*var) |
434 | .Finalize(graph, &is_initialized)); |
435 | Node* switch_node; |
436 | TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch" ), "Switch" ) |
437 | .Input(init_val) |
438 | .Input(is_initialized) |
439 | .Finalize(graph, &switch_node)); |
440 | NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0); |
441 | NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1); |
442 | |
443 | Node* ema_value; |
444 | TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true, |
445 | decay, *var, &ema_value)); |
446 | |
447 | Node* assign_value; |
448 | TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge" ), "Merge" ) |
449 | .Input({output_false, ema_value}) |
450 | .Finalize(graph, &assign_value)); |
451 | |
452 | TF_RETURN_IF_ERROR( |
453 | NodeBuilder(strings::StrCat(name, "/AssignValue" ), "Assign" ) |
454 | .Input(*var) |
455 | .Input(assign_value) |
456 | .Finalize(graph, var)); |
457 | return OkStatus(); |
458 | } |
459 | |
460 | // Computes the min and max EMA of input and stores them in min_var and max_var. |
461 | Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input, |
462 | std::vector<Node*>* added_variables, Node** min_var, |
463 | Node** max_var) { |
  // TODO(suharshs): The decay is constant, so we could create a single decay
  // const for all quantize_and_dequantize ops to share; that would have to
  // live outside this function.
467 | Tensor decay_tensor(DT_FLOAT, TensorShape()); |
468 | decay_tensor.flat<float>()(0) = kEMADecay; |
469 | Node* decay; |
470 | TF_RETURN_IF_ERROR( |
471 | NodeBuilder(strings::StrCat(name_prefix, "/Decay" ), "Const" ) |
472 | .Attr("dtype" , DT_FLOAT) |
473 | .Attr("value" , decay_tensor) |
474 | .Finalize(graph, &decay)); |
475 | |
476 | Node* reduction_axes; |
477 | TF_RETURN_IF_ERROR( |
478 | MakeReductionAxes(graph, name_prefix, input, &reduction_axes)); |
479 | Node* min; |
  string min_name = strings::StrCat(name_prefix, "/Min");
  TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
482 | .Input(input) |
483 | .Input(reduction_axes) |
484 | .Finalize(graph, &min)); |
485 | Node* max; |
  string max_name = strings::StrCat(name_prefix, "/Max");
  TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
488 | .Input(input) |
489 | .Input(reduction_axes) |
490 | .Finalize(graph, &max)); |
491 | TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min, |
492 | added_variables, min_var)); |
493 | TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max, |
494 | added_variables, max_var)); |
495 | return OkStatus(); |
496 | } |
497 | |
498 | // Makes an input min and max constant if the range is given. Otherwise, makes |
499 | // min and max variables that are updated by an EMA. |
500 | Status MakeInputMinMax(Graph* graph, const string& name_prefix, |
501 | const EdgeToConvert& edge, |
502 | std::vector<Node*>* added_variables, Node** input_min, |
503 | Node** input_max) { |
504 | if (edge.range_given) { |
505 | // Make constant nodes for the input_min and input_max if the range is |
506 | // provided. |
507 | Tensor input_min_tensor(DT_FLOAT, TensorShape()); |
508 | input_min_tensor.flat<float>()(0) = edge.input_min; |
509 | TF_RETURN_IF_ERROR( |
510 | NodeBuilder(strings::StrCat(name_prefix, "/InputMin" ), "Const" ) |
511 | .Attr("dtype" , DT_FLOAT) |
512 | .Attr("value" , input_min_tensor) |
513 | .Finalize(graph, input_min)); |
514 | Tensor input_max_tensor(DT_FLOAT, TensorShape()); |
515 | input_max_tensor.flat<float>()(0) = edge.input_max; |
516 | TF_RETURN_IF_ERROR( |
517 | NodeBuilder(strings::StrCat(name_prefix, "/InputMax" ), "Const" ) |
518 | .Attr("dtype" , DT_FLOAT) |
519 | .Attr("value" , input_max_tensor) |
520 | .Finalize(graph, input_max)); |
521 | } else { |
522 | // If the range is not given, estimate the range with EMA variables. |
523 | TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(), |
524 | added_variables, input_min, |
525 | input_max)); |
526 | } |
527 | |
528 | return OkStatus(); |
529 | } |
530 | |
// Adds a QuantizeAndDequantizeV2 or FakeQuantWithMinMaxVars op
// (plus the required input nodes) based on edge.
// The result is stored in convert_node.
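// For example (names are illustrative), a source node "conv1/weights" with
// quant_op_type "QuantizeAndDequantizeV2" yields an op named
// "conv1/weights/QuantizeAndDequantizeV2".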
534 | Status MakeQuantizeOp(Graph* graph, const string& name_prefix, |
535 | const string& quant_op_type, const EdgeToConvert& edge, |
536 | std::vector<Node*>* added_variables, |
537 | Node** convert_node) { |
538 | Node* input_min; |
539 | Node* input_max; |
540 | TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables, |
541 | &input_min, &input_max)); |
  string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
  if (quant_op_type == "QuantizeAndDequantizeV2") {
544 | TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type) |
545 | .Input(edge.edge->src()) |
546 | .Input(input_min) |
547 | .Input(input_max) |
548 | .Attr("signed_input" , edge.signed_input) |
549 | .Attr("num_bits" , edge.num_bits) |
550 | .Attr("range_given" , true) |
551 | .Finalize(graph, convert_node)); |
  } else if (quant_op_type == "FakeQuantWithMinMaxVars") {
553 | TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type) |
554 | .Input(edge.edge->src()) |
555 | .Input(input_min) |
556 | .Input(input_max) |
557 | .Attr("num_bits" , edge.num_bits) |
558 | .Finalize(graph, convert_node)); |
559 | } else { |
560 | return errors::InvalidArgument("Unknown quant op type: " , quant_op_type); |
561 | } |
562 | return OkStatus(); |
563 | } |
564 | |
// Inserts a conversion op for each target edge, connects it to the graph and
// removes the old edge.
566 | Status ProcessTargetEdges(Graph* graph, const string& quant_op_type, |
567 | const std::vector<EdgeToConvert>& target_edges) { |
568 | // Remember previously converted ops to avoid duplicated conversion on the |
569 | // same input. |
570 | std::unordered_map<string, Node*, StringPieceHasher> name_index; |
571 | std::vector<Node*> added_variables; |
  for (const EdgeToConvert& edge : target_edges) {
573 | Node* convert_node; |
574 | string name_prefix = edge.edge->src()->name(); |
575 | |
576 | auto iter = name_index.find(name_prefix); |
577 | if (iter == name_index.end()) { |
578 | TF_RETURN_IF_ERROR(MakeQuantizeOp(graph, name_prefix, quant_op_type, edge, |
579 | &added_variables, &convert_node)); |
580 | name_index[name_prefix] = convert_node; |
581 | } else { |
582 | convert_node = iter->second; |
583 | } |
584 | |
585 | graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input()); |
586 | graph->RemoveEdge(edge.edge); |
587 | } |
588 | |
589 | TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables)); |
590 | |
591 | return OkStatus(); |
592 | } |
593 | |
594 | } // namespace |
595 | |
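// Rewrites `graph` in place for quantized training. A minimal usage sketch
// (assuming an already-constructed training graph):
//
//   Graph graph(OpRegistry::Global());
//   // ... build or import the model ...
//   TF_CHECK_OK(
//       DoQuantizeTraining(/*num_bits=*/8, "QuantizeAndDequantizeV2", &graph));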
596 | Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type, |
597 | Graph* graph) { |
598 | if (graph == nullptr) { |
599 | return errors::InvalidArgument("Cannot accept empty graph pointer." ); |
600 | } |
601 | |
602 | if (num_bits < 1 || num_bits > 63) { |
    return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
                              num_bits);
605 | } |
606 | int potential_input = 0; |
607 | std::vector<EdgeToConvert> target_edges; |
608 | for (Node* node : graph->nodes()) { |
609 | if (nodes_to_rewrite->find(node->type_string()) != |
610 | nodes_to_rewrite->end() && |
611 | !IsGradientNode(graph, node)) { |
612 | // Find out which types are the inputs and convert them accordingly. |
613 | // 1. Const/Variable OP: This is quantized as signed tensors with no given |
614 | // range. |
615 | // 2. Activation OP: Set the range accordingly for different types of |
616 | // activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh} |
617 | // 3. Identity OP: The quantization parameters depend on its input. |
618 | // 4. Pooling OPs: various pooling ops. Also depends on its input. |
619 | // 5. Reshape OP: Also depends on the first input to this op. |
      // 6. Not-Listed-Above OP: Such ops are treated as model inputs. If
      //    more than kAllowedInputs of them are found, we return an error
      //    for now to avoid unexpected behavior.
623 | // Note: The list above might not be a complete list. Please let us |
624 | // know if you see the error so we can handle your case. |
625 | for (const Edge* edge : node->in_edges()) { |
626 | if (edge->src_output() == Graph::kControlSlot) { |
627 | // Skip the control dependency input. |
628 | continue; |
629 | } else { |
630 | bool signed_input = false; |
631 | bool range_given = false; |
632 | float input_min = 0; |
633 | float input_max = 0; |
634 | bool known_op = FindType(graph, edge->src(), &signed_input, |
635 | &range_given, &input_min, &input_max); |
636 | if (!known_op) { |
637 | // Unknown op is considered as input. |
638 | potential_input++; |
639 | if (potential_input > kAllowedInputs) { |
            return errors::Unimplemented(
                "Found an unknown op: ", edge->src()->name(),
                " with type: ", edge->src()->type_string(),
                "; Unknown ops are considered as model input for now and "
                "only ",
                kAllowedInputs, " inputs are supported currently.");
646 | } |
647 | } |
648 | |
          target_edges.emplace_back(edge, num_bits, signed_input, range_given,
                                    input_min, input_max);
651 | } |
652 | } |
653 | } |
654 | } |
655 | |
656 | TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges)); |
657 | |
658 | return OkStatus(); |
659 | } |
660 | |
661 | Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, |
662 | int32_t num_bits, |
663 | const string& quant_op_type, |
664 | GraphDef* result_graphdef) { |
665 | Graph graph(OpRegistry::Global()); |
666 | GraphConstructorOptions opts; |
667 | TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph)); |
668 | |
669 | // Call the rewriter on the graph. |
670 | TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph)); |
671 | |
672 | // Convert the result graph back to a GraphDef. |
673 | graph.ToGraphDef(result_graphdef); |
674 | return OkStatus(); |
675 | } |
676 | |
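// Same rewrite for a serialized GraphDef. A usage sketch (variable names are
// illustrative):
//
//   string rewritten_graph;
//   TF_CHECK_OK(DoQuantizeTrainingOnSerializedGraphDef(
//       serialized_graph, /*num_bits=*/8, "FakeQuantWithMinMaxVars",
//       &rewritten_graph));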
677 | Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string, |
678 | int32_t num_bits, |
679 | const string& quant_op_type, |
680 | string* result_graph_string) { |
681 | // First create the graph from the GraphDef. |
682 | GraphDef input_graphdef; |
683 | if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) { |
    return errors::InvalidArgument(
        "input_graph_string is not a serialized GraphDef protocol buffer");
686 | } |
687 | GraphDef output_graphdef; |
688 | TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef( |
689 | input_graphdef, num_bits, quant_op_type, &output_graphdef)); |
690 | |
691 | if (!output_graphdef.SerializeToString(result_graph_string)) { |
    return errors::Internal(
        "quantize training transformation resulted in invalid GraphDef");
694 | } |
695 | return OkStatus(); |
696 | } |
697 | |
698 | } // namespace tensorflow |
699 | |