1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/common_runtime/constant_folding.h" |
19 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
20 | #include "tensorflow/core/common_runtime/threadpool_device.h" |
21 | #include "tensorflow/core/graph/node_builder.h" |
22 | #include "tensorflow/core/graph/subgraph.h" |
23 | #include "tensorflow/core/kernels/quantization_utils.h" |
24 | #include "tensorflow/core/platform/init_main.h" |
25 | #include "tensorflow/core/public/session.h" |
26 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace graph_transforms { |
30 | |
31 | // Rounds any large float constants to the specified number of levels. |
32 | Status RoundWeights(const GraphDef& input_graph_def, |
33 | const TransformFuncContext& context, |
34 | GraphDef* output_graph_def) { |
35 | int32_t num_steps; |
36 | TF_RETURN_IF_ERROR( |
37 | context.GetOneInt32Parameter("num_steps" , 256, &num_steps)); |
38 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
39 | input_graph_def, {"Const" }, |
40 | [num_steps](const NodeMatch& match, const std::set<string>& input_nodes, |
41 | const std::set<string>& output_nodes, |
42 | std::vector<NodeDef>* new_nodes) { |
43 | const NodeDef& old_const_node = match.node; |
44 | if (!old_const_node.attr().count("dtype" )) { |
45 | return errors::InvalidArgument("No 'dtype' attribute for Const node " , |
46 | old_const_node.name()); |
47 | } |
48 | if (!old_const_node.attr().count("value" )) { |
49 | return errors::InvalidArgument("No 'value' attribute for Const node " , |
50 | old_const_node.name()); |
51 | } |
52 | const DataType old_dtype = old_const_node.attr().at("dtype" ).type(); |
53 | Tensor old_tensor; |
54 | if (!old_tensor.FromProto(old_const_node.attr().at("value" ).tensor())) { |
55 | return errors::InvalidArgument("Decoding Tensor failed for node" , |
56 | old_const_node.name()); |
57 | } |
58 | const size_t num_elements = old_tensor.NumElements(); |
59 | // If this isn't a float constant, or it's too small, then reuse the |
60 | // same node with no changes. The size is important because small |
61 | // constants tend to be used for more accuracy-sensitive calculations, |
62 | // and the benefit of shrinking them is very marginal. |
63 | if ((old_dtype != DT_FLOAT) || (num_elements < 16)) { |
64 | new_nodes->push_back(old_const_node); |
65 | return OkStatus(); |
66 | } |
67 | const float* old_values = old_tensor.flat<float>().data(); |
68 | float min = std::numeric_limits<float>::max(); |
69 | float max = std::numeric_limits<float>::min(); |
70 | for (int i = 0; i < num_elements; ++i) { |
71 | const float value = old_values[i]; |
72 | min = std::min(min, value); |
73 | max = std::max(max, value); |
74 | } |
75 | // min_value == max_value is a tricky case. It can occur for general |
76 | // tensors, and of course for scalars. The quantized ops cannot deal |
77 | // with this case, so we set max_value to something else. |
78 | // It's a tricky question what is the numerically best solution to |
79 | // deal with this degeneracy. |
80 | // TODO(petewarden): Better use a tolerance than a hard comparison? |
81 | if (min == max) { |
82 | if (std::abs(min) < 0.000001f) { |
83 | max = min + 1.0f; |
84 | } else if (min > 0) { |
85 | max = 2.0f * min; |
86 | } else { |
87 | min = 2.0f * max; |
88 | } |
89 | } |
90 | Tensor rounded_tensor(DT_FLOAT, old_tensor.shape()); |
91 | float* rounded_values = rounded_tensor.flat<float>().data(); |
92 | const float bucket_width = (max - min) / num_steps; |
93 | for (int i = 0; i < num_elements; ++i) { |
94 | const int32_t bucket = |
95 | std::floor((old_values[i] - min) / bucket_width); |
96 | rounded_values[i] = min + (bucket_width * (bucket + 0.5f)); |
97 | } |
98 | |
99 | NodeDef rounded_const_node; |
100 | rounded_const_node.set_op("Const" ); |
101 | rounded_const_node.set_name(old_const_node.name()); |
102 | SetNodeAttr("dtype" , DT_FLOAT, &rounded_const_node); |
103 | SetNodeTensorAttr<float>("value" , rounded_tensor, &rounded_const_node); |
104 | new_nodes->push_back(rounded_const_node); |
105 | |
106 | return OkStatus(); |
107 | }, |
108 | {}, output_graph_def)); |
109 | |
110 | return OkStatus(); |
111 | } |
112 | |
// Registers this transform under the public name "round_weights" so the
// graph-transform tooling can look it up by that string.
REGISTER_GRAPH_TRANSFORM("round_weights" , RoundWeights);
114 | |
115 | } // namespace graph_transforms |
116 | } // namespace tensorflow |
117 | |