1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/lite/delegates/utils.h"
17
18#include <algorithm>
19#include <cstdint>
20#include <cstring>
21#include <string>
22#include <vector>
23
24#include "tensorflow/lite/builtin_ops.h"
25#include "tensorflow/lite/context_util.h"
26#include "tensorflow/lite/kernels/kernel_util.h"
27
28namespace tflite {
29namespace delegates {
30
31TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
32 const int original_tensor_index,
33 TfLiteType new_type,
34 TfLiteTensor** new_tensor,
35 int* new_tensor_index) {
36 TF_LITE_ENSURE_STATUS(context->AddTensors(context, 1, new_tensor_index));
37 const TfLiteTensor& original_tensor = context->tensors[original_tensor_index];
38 *new_tensor = &context->tensors[*new_tensor_index];
39 (*new_tensor)->type = new_type;
40 (*new_tensor)->allocation_type = kTfLiteArenaRw;
41 const auto* original_dims = original_tensor.dims;
42 TfLiteIntArray* dims = TfLiteIntArrayCreate(original_dims->size);
43 for (int i = 0; i < original_dims->size; ++i) {
44 dims->data[i] = original_dims->data[i];
45 }
46 if (context->ResizeTensor(context, *new_tensor, dims) != kTfLiteOk) {
47 TF_LITE_KERNEL_LOG(context, "Could not resize new delegate tensor");
48 return kTfLiteError;
49 }
50 return kTfLiteOk;
51}
52
53TfLiteStatus GraphPartitionHelper::Partition(
54 std::set<std::string>* unsupported_nodes_info) {
55 const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
56 if (prepare_status != kTfLiteOk) return prepare_status;
57
58 TfLiteDelegateParams* partition_params_array_ = nullptr;
59 int num_partitions_ = 0;
60 if (context_->PreviewDelegatePartitioning(context_, supported_nodes_,
61 &partition_params_array_,
62 &num_partitions_) != kTfLiteOk) {
63 TF_LITE_KERNEL_LOG(context_, "Unable to preview delegate partition.\n");
64 return kTfLiteError;
65 }
66
67 for (int i = 0; i < num_partitions_; ++i) {
68 partitions_.push_back(partition_params_array_ + i);
69 }
70
71 return kTfLiteOk;
72}
73
74std::vector<TfLiteDelegateParams*>
75GraphPartitionHelper::GetFirstNLargestPartitions(
76 int n, int min_nodes_per_partition) const {
77 // In general, the number of partitions in a delegate is never likely to be
78 // high enough to cause latency issues. Also considering this is generally a
79 // one-time work, we simply unconditionally sort partitions here according to
80 // the size.
81 std::vector<TfLiteDelegateParams*> sorted_partitions(partitions_);
82 std::sort(sorted_partitions.begin(), sorted_partitions.end(),
83 [](TfLiteDelegateParams* left, TfLiteDelegateParams* right) {
84 // Reverse sort
85 return left->nodes_to_replace->size >
86 right->nodes_to_replace->size;
87 });
88
89 std::vector<TfLiteDelegateParams*> results;
90 auto p_it = sorted_partitions.begin();
91 const int total = sorted_partitions.size();
92 for (int i = 0; i < std::min(total, n); ++i, ++p_it) {
93 auto* p = (*p_it);
94 if (p->nodes_to_replace->size < min_nodes_per_partition) {
95 break;
96 }
97 results.push_back(p);
98 }
99 return results;
100}
101
102std::vector<int> GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
103 int n, int min_nodes_per_partition) {
104 auto first_n_partitions =
105 GetFirstNLargestPartitions(n, min_nodes_per_partition);
106 std::vector<int> ops_to_replace;
107 for (const auto p : first_n_partitions) {
108 auto nodes = p->nodes_to_replace;
109 ops_to_replace.insert(ops_to_replace.end(), nodes->data,
110 nodes->data + nodes->size);
111 }
112 return ops_to_replace;
113}
114
115TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
116 std::set<std::string>* unsupported_nodes_info) {
117 if (!is_node_supported_fn_) return kTfLiteOk;
118
119 TfLiteIntArray* execution_plan = nullptr;
120 auto status = context_->GetExecutionPlan(context_, &execution_plan);
121 if (status != kTfLiteOk) {
122 TF_LITE_KERNEL_LOG(context_, "Unable to get graph execution plan.\n");
123 return status;
124 }
125 // context->GetExecutionPlan invalidates memory obtained from previous calls,
126 // which is dangerous if a delegate's IsNodeSupportedFn uses it anywhere.
127 // So we store a copy to ensure validity.
128 num_total_nodes_ = execution_plan->size;
129 original_execution_plan_ = TfLiteIntArrayCreate(execution_plan->size);
130 std::memcpy(original_execution_plan_->data, execution_plan->data,
131 num_total_nodes_ * sizeof(int32_t));
132
133 supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_);
134 supported_nodes_->size = 0;
135 for (int node_id : TfLiteIntArrayView(original_execution_plan_)) {
136 TfLiteNode* node;
137 TfLiteRegistration* registration;
138
139 status = context_->GetNodeAndRegistration(context_, node_id, &node,
140 &registration);
141 if (status != kTfLiteOk) {
142 TF_LITE_KERNEL_LOG(context_,
143 "Couldn't get node and registration info for op: %d\n",
144 node_id);
145 supported_nodes_->size = 0;
146 return status;
147 }
148
149 std::string unsupported_details;
150 if (IsNodeSupported(context_, node, registration, node_id,
151 &unsupported_details)) {
152 supported_nodes_->data[supported_nodes_->size++] = node_id;
153 } else if (unsupported_nodes_info) {
154 std::string node_info = GetOpNameByRegistration(*registration);
155 node_info.append(": ");
156 node_info.append(unsupported_details);
157 unsupported_nodes_info->insert(node_info);
158 }
159 }
160
161 num_supported_nodes_ = supported_nodes_->size;
162 return kTfLiteOk;
163}
164
165std::vector<int>
166FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
167 int n, int min_nodes_per_partition) {
168 std::vector<int> ops_to_replace;
169
170 if (num_supported_nodes() + constant_dequant_nodes_.size() ==
171 num_total_nodes()) {
172 // Scenario 1: Full Delegation.
173 // We delegate all nodes in this case to avoid unnecessary partitions due to
174 // FP16 DEQUANT nodes. This is safe to do since no non-delegated op needs
175 // the output of such a DEQUANT.
176 for (int node_id : TfLiteIntArrayView(original_execution_plan_)) {
177 ops_to_replace.push_back(node_id);
178 }
179 } else {
180 // Scenario 2: Partial Delegation.
181 // In this case, we just select the top 'n' applicable node subsets to
182 // delegate, devoid of any FP16 DEQUANT ops. Handling the latter is tricky
183 // in partial delegation cases & causes edge cases if non-delegated nodes
184 // consume their output. So we keep all of them on CPU.
185 auto first_n_partitions =
186 GetFirstNLargestPartitions(n, min_nodes_per_partition);
187 if (first_n_partitions.empty()) return ops_to_replace;
188 for (int i = 0; i < first_n_partitions.size(); ++i) {
189 auto nodes = first_n_partitions[i]->nodes_to_replace;
190 ops_to_replace.insert(ops_to_replace.end(), nodes->data,
191 nodes->data + nodes->size);
192 }
193 }
194
195 // Modify the inputs of relevant ops that support fp16 constants.
196 RemapFp16InputTensors(ops_to_replace);
197 return ops_to_replace;
198}
199
bool FP16GraphPartitionHelper::IsNodeSupported(
    TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
    int node_id, std::string* unsupported_details) {
  // Special-cases fp16 constant DEQUANTIZE nodes: they are recorded in the
  // dequant maps (and rejected) so that their consumers can later be remapped
  // to read the fp16 constant directly. All other nodes are checked by the
  // base implementation under a temporary fp16 input remap.
  if (registration->builtin_code == kTfLiteBuiltinDequantize) {
    // NOTE(review): reads through the `context_` member rather than the
    // `context` parameter — presumably they refer to the same object; confirm
    // with callers.
    auto& dequantize_input = context_->tensors[node->inputs->data[0]];
    if (dequantize_input.type == kTfLiteFloat16 &&
        IsConstantTensor(&dequantize_input)) {
      // Update mappings if this node is a fp16 DEQUANTIZE node that
      // works on a **constant** input tensor.
      // If the input is not a constant, the remapping that we do here will
      // cause bugs due to preceding ops such as DENSIFY.
      constant_dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
      constant_dequant_nodes_[node->outputs->data[0]] = node_id;
      // We do not accept these ops right now.
      // This is done to support use-cases where a DEQUANTIZE output might be
      // consumed by a CPU op.
      return false;
    }
  }

  // To check if a (possibly) FP16 node is supported, we temporarily point the
  // node's inputs to the original fp16 tensors. This 'mutated' node is then
  // passed to the base IsNodeSupported function for checking. After the check,
  // we remap the original node inputs, so that the TFLite graph remains the
  // same.
  std::vector<int> orig_inputs;
  if (!constant_dequant_nodes_.empty()) {
    RemapFp16InputTensors(node, &orig_inputs);
  }

  const auto is_supported = GraphPartitionHelper::IsNodeSupported(
      context, node, registration, node_id, unsupported_details);

  // `orig_inputs` is non-empty only when RemapFp16InputTensors actually
  // rewrote at least one input (it clears the snapshot otherwise).
  if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
    // Remapping happened. Restore original inputs.
    for (int j = 0; j < node->inputs->size; ++j) {
      node->inputs->data[j] = orig_inputs[j];
    }
  }
  return is_supported;
}
241
242void FP16GraphPartitionHelper::RemapFp16InputTensors(
243 const std::vector<int>& nodes) const {
244 for (int node_id : nodes) {
245 TfLiteNode* node;
246 TfLiteRegistration* registration;
247 TfLiteStatus status = context_->GetNodeAndRegistration(
248 context_, node_id, &node, &registration);
249 if (status != kTfLiteOk) {
250 TF_LITE_KERNEL_LOG(context_,
251 "Couldn't get node and registration info for op: %d\n",
252 node_id);
253 }
254 RemapFp16InputTensors(node, nullptr /* orig_inputs*/);
255 }
256}
257
258void FP16GraphPartitionHelper::RemapFp16InputTensors(
259 TfLiteNode* node, std::vector<int>* orig_inputs) const {
260 TfLiteIntArray* inputs = node->inputs;
261 auto inputs_view = TfLiteIntArrayView(inputs);
262 // Prepopulate 'orig_inputs' first and clear it if there's no input from a
263 // dequant op.
264 if (orig_inputs) {
265 orig_inputs->clear();
266 orig_inputs->reserve(inputs->size);
267 for (auto tid : inputs_view) {
268 orig_inputs->push_back(tid);
269 }
270 }
271 // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
272 // order to test if it is supported.
273 bool is_remapped = false;
274 for (int j = 0; j < inputs->size; ++j) {
275 const int input_tid = inputs->data[j];
276 const auto it = constant_dequant_map_.find(input_tid);
277 if (it != constant_dequant_map_.end()) {
278 inputs->data[j] = it->second;
279 is_remapped = true;
280 }
281 }
282 if (!is_remapped && orig_inputs) orig_inputs->clear();
283}
284
285} // namespace delegates
286} // namespace tflite
287