1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <algorithm>
#include <cmath>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
20 | |
21 | namespace tensorflow { |
22 | namespace graph_transforms { |
23 | |
// One min/max observation recovered from a single line of a requantization
// range log.
struct MinMaxRecord {
  // Name of the node the observation was logged for, i.e. the logged name with
  // its "__print__" suffix removed.
  string name;
  // Lower end of the observed range. Zero-initialized so that a
  // default-constructed record never carries an indeterminate value.
  float min = 0.0f;
  // Upper end of the observed range. Zero-initialized for the same reason.
  float max = 0.0f;
};
29 | |
30 | // Try to parse a log file containing loosely-structured lines, some of which |
31 | // are the min/max logs we want. |
32 | Status (const string& log_file_name, |
33 | std::vector<MinMaxRecord>* records) { |
34 | string file_data; |
35 | TF_RETURN_IF_ERROR( |
36 | ReadFileToString(Env::Default(), log_file_name, &file_data)); |
37 | const string print_suffix("__print__" ); |
38 | const string requant_prefix("__requant_min_max:" ); |
39 | std::vector<string> file_lines = str_util::Split(file_data, '\n'); |
40 | for (const string& file_line : file_lines) { |
41 | // We expect to find a line with components separated by semicolons, so to |
42 | // start make sure that the basic structure is in place/ |
43 | if (!absl::StrContains(file_line, print_suffix + ";" + requant_prefix)) { |
44 | continue; |
45 | } |
46 | std::vector<string> line_parts = str_util::Split(file_line, ';'); |
47 | if (line_parts.size() < 2) { |
48 | continue; |
49 | } |
50 | // Now we want to figure out which components have the name and min max |
51 | // values by scanning for the prefix we expect. |
52 | bool min_max_found = false; |
53 | int min_max_index; |
54 | for (int i = 1; i < line_parts.size(); ++i) { |
55 | if (absl::StartsWith(line_parts[i], requant_prefix)) { |
56 | min_max_found = true; |
57 | min_max_index = i; |
58 | } |
59 | } |
60 | if (!min_max_found) { |
61 | continue; |
62 | } |
63 | // Finally we need to break out the values from the strings, and parse them |
64 | // into a form we can use. |
65 | string min_max_string = line_parts[min_max_index]; |
66 | std::vector<string> min_max_parts = str_util::Split(min_max_string, '['); |
67 | if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) { |
68 | continue; |
69 | } |
70 | string min_string = min_max_parts[1]; |
71 | std::vector<string> min_string_parts = str_util::Split(min_string, ']'); |
72 | if (min_string_parts.size() != 2) { |
73 | continue; |
74 | } |
75 | string min_number_string = min_string_parts[0]; |
76 | float min; |
77 | if (!strings::safe_strtof(min_number_string.c_str(), &min)) { |
78 | continue; |
79 | } |
80 | string max_string = min_max_parts[2]; |
81 | std::vector<string> max_string_parts = str_util::Split(max_string, ']'); |
82 | if (max_string_parts.size() != 2) { |
83 | continue; |
84 | } |
85 | string max_number_string = max_string_parts[0]; |
86 | float max; |
87 | if (!strings::safe_strtof(max_number_string.c_str(), &max)) { |
88 | continue; |
89 | } |
90 | StringPiece name_string = line_parts[min_max_index - 1]; |
91 | if (!str_util::EndsWith(name_string, print_suffix)) { |
92 | continue; |
93 | } |
94 | string name( |
95 | name_string.substr(0, name_string.size() - print_suffix.size())); |
96 | records->push_back({name, min, max}); |
97 | } |
98 | return OkStatus(); |
99 | } |
100 | |
101 | // Uses the observed min/max values for requantization captured in a log file to |
102 | // replace costly RequantizationRange ops with simple Consts. |
103 | Status FreezeRequantizationRanges(const GraphDef& input_graph_def, |
104 | const TransformFuncContext& context, |
105 | GraphDef* output_graph_def) { |
106 | string min_max_log_file; |
107 | TF_RETURN_IF_ERROR( |
108 | context.GetOneStringParameter("min_max_log_file" , "" , &min_max_log_file)); |
109 | if (min_max_log_file.empty()) { |
110 | return errors::InvalidArgument( |
111 | "You must pass a file name to min_max_log_file" ); |
112 | } |
113 | float min_percentile; |
114 | TF_RETURN_IF_ERROR( |
115 | context.GetOneFloatParameter("min_percentile" , 5.0f, &min_percentile)); |
116 | float max_percentile; |
117 | TF_RETURN_IF_ERROR( |
118 | context.GetOneFloatParameter("max_percentile" , 5.0f, &max_percentile)); |
119 | |
120 | std::vector<MinMaxRecord> records; |
121 | TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records)); |
122 | if (records.empty()) { |
123 | return errors::InvalidArgument( |
124 | "No min/max range logs were found in the log file" ); |
125 | } |
126 | |
127 | std::map<string, const NodeDef*> node_map; |
128 | MapNamesToNodes(input_graph_def, &node_map); |
129 | bool any_missing_nodes = false; |
130 | std::map<string, std::vector<MinMaxRecord>> records_by_node; |
131 | for (const MinMaxRecord& record : records) { |
132 | records_by_node[record.name].push_back(record); |
133 | if (!node_map.count(record.name)) { |
134 | any_missing_nodes = true; |
135 | LOG(WARNING) << "Node from log not found in graph: " << record.name; |
136 | } |
137 | } |
138 | if (any_missing_nodes) { |
139 | return errors::InvalidArgument( |
140 | "Nodes were found in the log file that aren't present in the graph" ); |
141 | } |
142 | |
143 | // Now find out the largest and smallest min/max values for the node. |
144 | std::map<string, std::pair<float, float>> range_for_nodes; |
145 | for (const auto& record_info : records_by_node) { |
146 | const string& name = record_info.first; |
147 | const std::vector<MinMaxRecord> records = record_info.second; |
148 | std::vector<float> mins; |
149 | std::vector<float> maxs; |
150 | for (const MinMaxRecord& record : records) { |
151 | mins.push_back(record.min); |
152 | maxs.push_back(record.max); |
153 | } |
154 | std::sort(mins.begin(), mins.end()); |
155 | std::sort(maxs.begin(), maxs.end()); |
156 | int min_index = std::round(mins.size() * (min_percentile / 100.0f)); |
157 | if (min_index < 0) { |
158 | min_index = 0; |
159 | } |
160 | int max_index = |
161 | std::round(maxs.size() * (1.0f - (max_percentile / 100.0f))); |
162 | if (max_index > (maxs.size() - 1)) { |
163 | max_index = maxs.size() - 1; |
164 | } |
165 | const float min = mins[min_index]; |
166 | const float max = maxs[max_index]; |
167 | range_for_nodes[name] = {min, max}; |
168 | } |
169 | std::map<string, string> inputs_to_rename; |
170 | GraphDef frozen_graph_def; |
171 | for (const NodeDef& node : input_graph_def.node()) { |
172 | if (range_for_nodes.count(node.name())) { |
173 | if (node.op() != "RequantizationRange" ) { |
174 | return errors::InvalidArgument( |
175 | "Node is expected to be a RequantizationRange op: " , node.name(), |
176 | ", but is: " , node.op()); |
177 | } |
178 | const float min_value = range_for_nodes.at(node.name()).first; |
179 | NodeDef* min_node = frozen_graph_def.mutable_node()->Add(); |
180 | min_node->set_op("Const" ); |
181 | min_node->set_name(node.name() + "/frozen_min" ); |
182 | SetNodeAttr("dtype" , DT_FLOAT, min_node); |
183 | Tensor min_tensor(DT_FLOAT, {}); |
184 | min_tensor.flat<float>()(0) = min_value; |
185 | SetNodeTensorAttr<float>("value" , min_tensor, min_node); |
186 | inputs_to_rename[node.name() + ":0" ] = min_node->name() + ":0" ; |
187 | |
188 | const float max_value = range_for_nodes.at(node.name()).second; |
189 | NodeDef* max_node = frozen_graph_def.mutable_node()->Add(); |
190 | max_node->set_op("Const" ); |
191 | max_node->set_name(node.name() + "/frozen_max" ); |
192 | SetNodeAttr("dtype" , DT_FLOAT, max_node); |
193 | Tensor max_tensor(DT_FLOAT, {}); |
194 | max_tensor.flat<float>()(0) = max_value; |
195 | SetNodeTensorAttr<float>("value" , max_tensor, max_node); |
196 | inputs_to_rename[node.name() + ":1" ] = max_node->name() + ":0" ; |
197 | } else { |
198 | NodeDef* new_node = frozen_graph_def.mutable_node()->Add(); |
199 | *new_node = node; |
200 | } |
201 | } |
202 | return RenameNodeInputs(frozen_graph_def, inputs_to_rename, |
203 | std::unordered_set<string>(), output_graph_def); |
204 | } |
205 | |
206 | REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges" , |
207 | FreezeRequantizationRanges); |
208 | |
209 | } // namespace graph_transforms |
210 | } // namespace tensorflow |
211 | |