1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <algorithm> |
16 | #include <memory> |
17 | #include <string> |
18 | #include <unordered_map> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" |
22 | #include "tensorflow/lite/toco/model.h" |
23 | #include "tensorflow/lite/toco/tooling_util.h" |
24 | #include "tensorflow/core/platform/logging.h" |
25 | |
26 | namespace toco { |
27 | |
28 | ::tensorflow::Status ResolveFakeQuantArgsFromVars::Run(Model* model, |
29 | std::size_t op_index, |
30 | bool* modified) { |
31 | *modified = false; |
32 | const auto fakequant_it = model->operators.begin() + op_index; |
33 | auto* fakequant_base_op = fakequant_it->get(); |
34 | if (fakequant_base_op->type != OperatorType::kFakeQuant) { |
35 | return ::tensorflow::OkStatus(); |
36 | } |
37 | auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op); |
38 | |
39 | if (fakequant_op->minmax) { |
40 | // Already resolved. |
41 | return ::tensorflow::OkStatus(); |
42 | } |
43 | |
44 | CHECK_EQ(fakequant_op->inputs.size(), 3); |
45 | // We need to yield until the min and max parameters have been |
46 | // resolved to constant arrays. |
47 | for (int i = 1; i <= 2; i++) { |
48 | if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) { |
49 | return ::tensorflow::OkStatus(); |
50 | } |
51 | } |
52 | |
53 | // Obtain the final min/max values |
54 | const auto& min_array = model->GetArray(fakequant_op->inputs[1]); |
55 | const auto& max_array = model->GetArray(fakequant_op->inputs[2]); |
56 | CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); |
57 | CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); |
58 | fakequant_op->minmax = std::make_unique<MinMax>(); |
59 | MinMax& minmax = *fakequant_op->minmax; |
60 | minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0]; |
61 | minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0]; |
62 | // We always want [min, max] to contain 0. |
63 | if (minmax.min > 0 || minmax.max < 0) { |
64 | LOG(WARNING) << "For " << LogName(*fakequant_op) << " the MinMax range " |
65 | << "[" << minmax.min << ", " << minmax.max |
66 | << "] does not contain 0. " |
67 | << "Proceeding by tweaking it to contain 0, which will result " |
68 | "in poor accuracy." ; |
69 | } |
70 | minmax.min = std::min(minmax.min, 0.); |
71 | minmax.max = std::max(minmax.max, 0.); |
72 | |
73 | // We won't use the input arrays that provided these min and max |
74 | // values, anymore. Delete them unless they are used by something |
75 | // else. |
76 | for (int i = 1; i <= 2; i++) { |
77 | DeleteArrayIfUnusedOutsideOfOp(fakequant_op->inputs[i], fakequant_op, |
78 | model); |
79 | } |
80 | fakequant_op->inputs.resize(1); |
81 | *modified = true; |
82 | return ::tensorflow::OkStatus(); |
83 | } |
84 | |
85 | } // namespace toco |
86 | |