1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/node_def.pb.h"
17#include "tensorflow/core/lib/strings/str_util.h"
18#include "tensorflow/core/platform/env.h"
19#include "tensorflow/tools/graph_transforms/transform_utils.h"
20
21namespace tensorflow {
22namespace graph_transforms {
23
// One min/max range observation recovered from a log line.
struct MinMaxRecord {
  // Text that preceded the "__print__" suffix on the log line; later looked up
  // as a node name in the graph.
  string name;
  // Logged minimum of the range. Zero-initialized so that a default-constructed
  // record never carries an indeterminate value.
  float min = 0.0f;
  // Logged maximum of the range. Zero-initialized for the same reason.
  float max = 0.0f;
};
29
30// Try to parse a log file containing loosely-structured lines, some of which
31// are the min/max logs we want.
32Status ExtractMinMaxRecords(const string& log_file_name,
33 std::vector<MinMaxRecord>* records) {
34 string file_data;
35 TF_RETURN_IF_ERROR(
36 ReadFileToString(Env::Default(), log_file_name, &file_data));
37 const string print_suffix("__print__");
38 const string requant_prefix("__requant_min_max:");
39 std::vector<string> file_lines = str_util::Split(file_data, '\n');
40 for (const string& file_line : file_lines) {
41 // We expect to find a line with components separated by semicolons, so to
42 // start make sure that the basic structure is in place/
43 if (!absl::StrContains(file_line, print_suffix + ";" + requant_prefix)) {
44 continue;
45 }
46 std::vector<string> line_parts = str_util::Split(file_line, ';');
47 if (line_parts.size() < 2) {
48 continue;
49 }
50 // Now we want to figure out which components have the name and min max
51 // values by scanning for the prefix we expect.
52 bool min_max_found = false;
53 int min_max_index;
54 for (int i = 1; i < line_parts.size(); ++i) {
55 if (absl::StartsWith(line_parts[i], requant_prefix)) {
56 min_max_found = true;
57 min_max_index = i;
58 }
59 }
60 if (!min_max_found) {
61 continue;
62 }
63 // Finally we need to break out the values from the strings, and parse them
64 // into a form we can use.
65 string min_max_string = line_parts[min_max_index];
66 std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
67 if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
68 continue;
69 }
70 string min_string = min_max_parts[1];
71 std::vector<string> min_string_parts = str_util::Split(min_string, ']');
72 if (min_string_parts.size() != 2) {
73 continue;
74 }
75 string min_number_string = min_string_parts[0];
76 float min;
77 if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
78 continue;
79 }
80 string max_string = min_max_parts[2];
81 std::vector<string> max_string_parts = str_util::Split(max_string, ']');
82 if (max_string_parts.size() != 2) {
83 continue;
84 }
85 string max_number_string = max_string_parts[0];
86 float max;
87 if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
88 continue;
89 }
90 StringPiece name_string = line_parts[min_max_index - 1];
91 if (!str_util::EndsWith(name_string, print_suffix)) {
92 continue;
93 }
94 string name(
95 name_string.substr(0, name_string.size() - print_suffix.size()));
96 records->push_back({name, min, max});
97 }
98 return OkStatus();
99}
100
101// Uses the observed min/max values for requantization captured in a log file to
102// replace costly RequantizationRange ops with simple Consts.
103Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
104 const TransformFuncContext& context,
105 GraphDef* output_graph_def) {
106 string min_max_log_file;
107 TF_RETURN_IF_ERROR(
108 context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
109 if (min_max_log_file.empty()) {
110 return errors::InvalidArgument(
111 "You must pass a file name to min_max_log_file");
112 }
113 float min_percentile;
114 TF_RETURN_IF_ERROR(
115 context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
116 float max_percentile;
117 TF_RETURN_IF_ERROR(
118 context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
119
120 std::vector<MinMaxRecord> records;
121 TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
122 if (records.empty()) {
123 return errors::InvalidArgument(
124 "No min/max range logs were found in the log file");
125 }
126
127 std::map<string, const NodeDef*> node_map;
128 MapNamesToNodes(input_graph_def, &node_map);
129 bool any_missing_nodes = false;
130 std::map<string, std::vector<MinMaxRecord>> records_by_node;
131 for (const MinMaxRecord& record : records) {
132 records_by_node[record.name].push_back(record);
133 if (!node_map.count(record.name)) {
134 any_missing_nodes = true;
135 LOG(WARNING) << "Node from log not found in graph: " << record.name;
136 }
137 }
138 if (any_missing_nodes) {
139 return errors::InvalidArgument(
140 "Nodes were found in the log file that aren't present in the graph");
141 }
142
143 // Now find out the largest and smallest min/max values for the node.
144 std::map<string, std::pair<float, float>> range_for_nodes;
145 for (const auto& record_info : records_by_node) {
146 const string& name = record_info.first;
147 const std::vector<MinMaxRecord> records = record_info.second;
148 std::vector<float> mins;
149 std::vector<float> maxs;
150 for (const MinMaxRecord& record : records) {
151 mins.push_back(record.min);
152 maxs.push_back(record.max);
153 }
154 std::sort(mins.begin(), mins.end());
155 std::sort(maxs.begin(), maxs.end());
156 int min_index = std::round(mins.size() * (min_percentile / 100.0f));
157 if (min_index < 0) {
158 min_index = 0;
159 }
160 int max_index =
161 std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
162 if (max_index > (maxs.size() - 1)) {
163 max_index = maxs.size() - 1;
164 }
165 const float min = mins[min_index];
166 const float max = maxs[max_index];
167 range_for_nodes[name] = {min, max};
168 }
169 std::map<string, string> inputs_to_rename;
170 GraphDef frozen_graph_def;
171 for (const NodeDef& node : input_graph_def.node()) {
172 if (range_for_nodes.count(node.name())) {
173 if (node.op() != "RequantizationRange") {
174 return errors::InvalidArgument(
175 "Node is expected to be a RequantizationRange op: ", node.name(),
176 ", but is: ", node.op());
177 }
178 const float min_value = range_for_nodes.at(node.name()).first;
179 NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
180 min_node->set_op("Const");
181 min_node->set_name(node.name() + "/frozen_min");
182 SetNodeAttr("dtype", DT_FLOAT, min_node);
183 Tensor min_tensor(DT_FLOAT, {});
184 min_tensor.flat<float>()(0) = min_value;
185 SetNodeTensorAttr<float>("value", min_tensor, min_node);
186 inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
187
188 const float max_value = range_for_nodes.at(node.name()).second;
189 NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
190 max_node->set_op("Const");
191 max_node->set_name(node.name() + "/frozen_max");
192 SetNodeAttr("dtype", DT_FLOAT, max_node);
193 Tensor max_tensor(DT_FLOAT, {});
194 max_tensor.flat<float>()(0) = max_value;
195 SetNodeTensorAttr<float>("value", max_tensor, max_node);
196 inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
197 } else {
198 NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
199 *new_node = node;
200 }
201 }
202 return RenameNodeInputs(frozen_graph_def, inputs_to_rename,
203 std::unordered_set<string>(), output_graph_def);
204}
205
// Registers FreezeRequantizationRanges under the transform name
// "freeze_requantization_ranges".
// NOTE(review): presumably this is the name users pass to the graph transform
// tool to select this transform -- confirm against REGISTER_GRAPH_TRANSFORM.
REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
 FreezeRequantizationRanges);
208
209} // namespace graph_transforms
210} // namespace tensorflow
211