1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include "tensorflow/core/common_runtime/constant_folding.h"
19#include "tensorflow/core/common_runtime/graph_constructor.h"
20#include "tensorflow/core/common_runtime/threadpool_device.h"
21#include "tensorflow/core/graph/node_builder.h"
22#include "tensorflow/core/graph/subgraph.h"
23#include "tensorflow/core/kernels/quantization_utils.h"
24#include "tensorflow/core/platform/init_main.h"
25#include "tensorflow/core/public/session.h"
26#include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28namespace tensorflow {
29namespace graph_transforms {
30
31// Converts any large float constants into eight-bit equivalents, with a
32// Dequantize op so that subsequent nodes can still access the results in a
33// float form.
34Status QuantizeWeights(const GraphDef& input_graph_def,
35 const TransformFuncContext& context,
36 GraphDef* output_graph_def) {
37 int32_t minimum_size;
38 TF_RETURN_IF_ERROR(
39 context.GetOneInt32Parameter("minimum_size", 1024, &minimum_size));
40 TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
41 input_graph_def, {"Const"},
42 [minimum_size](const NodeMatch& match,
43 const std::set<string>& input_nodes,
44 const std::set<string>& output_nodes,
45 std::vector<NodeDef>* new_nodes) {
46 const NodeDef& old_const_node = match.node;
47 if (!old_const_node.attr().count("dtype")) {
48 return errors::InvalidArgument("No 'dtype' attribute for Const node ",
49 old_const_node.name());
50 }
51 if (!old_const_node.attr().count("value")) {
52 return errors::InvalidArgument("No 'value' attribute for Const node ",
53 old_const_node.name());
54 }
55 const DataType old_dtype = old_const_node.attr().at("dtype").type();
56 Tensor old_tensor;
57 if (!old_tensor.FromProto(old_const_node.attr().at("value").tensor())) {
58 return errors::InvalidArgument("Decoding Tensor failed for node",
59 old_const_node.name());
60 }
61 const size_t num_elements = old_tensor.NumElements();
62 // If this isn't a float constant, or it's too small, then reuse the
63 // same node with no changes.
64 if ((old_dtype != DT_FLOAT) || (num_elements < minimum_size)) {
65 new_nodes->push_back(old_const_node);
66 return OkStatus();
67 }
68 const float* old_values = old_tensor.flat<float>().data();
69 float min = std::numeric_limits<float>::max();
70 float max = std::numeric_limits<float>::min();
71 for (int i = 0; i < num_elements; ++i) {
72 const float value = old_values[i];
73 min = std::min(min, value);
74 max = std::max(max, value);
75 }
76 // Make sure the quantization range includes 0.0f. Not all quantized
77 // Ops behave properly if 0.0f is not in the range.
78 min = std::min(min, 0.0f);
79 max = std::max(0.0f, max);
80 // min_value == max_value is a tricky case. It can occur for general
81 // tensors, and of course for scalars. The quantized ops cannot deal
82 // with this case, so we set max_value to something else.
83 // It's a tricky question what is the numerically best solution to
84 // deal with this degeneracy.
85 // TODO(petewarden): Better use a tolerance than a hard comparison?
86 if (min == max) {
87 if (std::abs(min) < 0.000001f) {
88 max = min + 1.0f;
89 } else if (min > 0) {
90 max = 2.0f * min;
91 } else {
92 max = min / 2.0f;
93 }
94 }
95 Tensor quantized_tensor(DT_QUINT8, old_tensor.shape());
96 FloatTensorToQuantizedInPlace<quint8>(old_tensor, min, max,
97 &quantized_tensor);
98
99 NodeDef quantized_const_node;
100 quantized_const_node.set_op("Const");
101 quantized_const_node.set_name(old_const_node.name() +
102 "_quantized_const");
103 SetNodeAttr("dtype", DT_QUINT8, &quantized_const_node);
104 SetNodeTensorAttr<float>("value", quantized_tensor,
105 &quantized_const_node);
106 new_nodes->push_back(quantized_const_node);
107
108 NodeDef min_node;
109 min_node.set_op("Const");
110 min_node.set_name(old_const_node.name() + "_quantized_min");
111 SetNodeAttr("dtype", DT_FLOAT, &min_node);
112 Tensor min_tensor(DT_FLOAT, {});
113 min_tensor.scalar<float>()() = min;
114 SetNodeTensorAttr<float>("value", min_tensor, &min_node);
115 new_nodes->push_back(min_node);
116
117 NodeDef max_node;
118 max_node.set_op("Const");
119 max_node.set_name(old_const_node.name() + "_quantized_max");
120 SetNodeAttr("dtype", DT_FLOAT, &max_node);
121 Tensor max_tensor(DT_FLOAT, {});
122 max_tensor.scalar<float>()() = max;
123 SetNodeTensorAttr<float>("value", max_tensor, &max_node);
124 new_nodes->push_back(max_node);
125
126 NodeDef dequantize_node;
127 dequantize_node.set_op("Dequantize");
128 dequantize_node.set_name(old_const_node.name());
129 SetNodeAttr("T", DT_QUINT8, &dequantize_node);
130 SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
131 AddNodeInput(quantized_const_node.name(), &dequantize_node);
132 AddNodeInput(min_node.name(), &dequantize_node);
133 AddNodeInput(max_node.name(), &dequantize_node);
134 new_nodes->push_back(dequantize_node);
135
136 return OkStatus();
137 },
138 {}, output_graph_def));
139
140 return OkStatus();
141}
142
// Makes this transform available to the transform_graph tool under the name
// "quantize_weights".
REGISTER_GRAPH_TRANSFORM("quantize_weights", QuantizeWeights);
144
145} // namespace graph_transforms
146} // namespace tensorflow
147