1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_ |
17 | #define TENSORFLOW_LITE_DELEGATES_UTILS_H_ |
18 | |
19 | // Utility functions and classes for implementing delegates. |
20 | |
21 | #include <functional> |
22 | #include <limits> |
23 | #include <set> |
24 | #include <string> |
25 | #include <unordered_map> |
26 | #include <utility> |
27 | #include <vector> |
28 | |
29 | #include "tensorflow/lite/c/common.h" |
30 | #include "tensorflow/lite/util.h" |
31 | |
32 | namespace tflite { |
33 | namespace delegates { |
34 | |
// Creates a new Read/Write tensor having the same shape as the original, but
// with a different type. Note that this might void existing references to
// tensors.
// On success, '*new_tensor' points at the newly created tensor and
// '*new_tensor_index' receives its index (presumably the slot in the
// context's tensor list — implementation lives in the .cc; verify there).
// Returns kTfLiteOk on success, or an error status from the context.
TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
                                              const int original_tensor_index,
                                              TfLiteType new_type,
                                              TfLiteTensor** new_tensor,
                                              int* new_tensor_index);
43 | |
// Predicate that decides whether a delegate can handle a single node.
// Returns true if the (node, registration) pair is supported. When returning
// false, the callee may append a human-readable reason to
// 'unsupported_details' if it is non-null.
using IsNodeSupportedFn =
    std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*,
                       std::string* unsupported_details)>;
47 | |
48 | // A utility class to help model graph parition. |
49 | // Note the class *needs* to be used in TfLiteDelegate::Prepare. |
50 | class GraphPartitionHelper { |
51 | public: |
52 | GraphPartitionHelper(TfLiteContext* context, |
53 | IsNodeSupportedFn is_node_supported_fn) |
54 | : context_(context), is_node_supported_fn_(is_node_supported_fn) {} |
55 | |
56 | GraphPartitionHelper(TfLiteContext* context, |
57 | const std::vector<int>& supported_node_indices) |
58 | : context_(context), |
59 | num_total_nodes_(supported_node_indices.size()), |
60 | supported_nodes_( |
61 | ConvertVectorToTfLiteIntArray(supported_node_indices)) {} |
62 | |
63 | virtual ~GraphPartitionHelper() { |
64 | TfLiteIntArrayFree(supported_nodes_); |
65 | TfLiteIntArrayFree(original_execution_plan_); |
66 | } |
67 | |
68 | // Partition the graph into node subsets such that each subset could be |
69 | // replaced with one delegate kernel (i.e. a kTfLiteBuiltinDelegate op). |
70 | // If 'unsupported_nodes_info' is provided, it will be populated with |
71 | // information about all different unsupported nodes. |
72 | virtual TfLiteStatus Partition(std::set<std::string>* unsupported_nodes_info); |
73 | |
74 | // Returns the first n largest partitions or all if #partitions is less than |
75 | // 'n' and each parition has at least (>=) 'min_nodes_per_partition' nodes. |
76 | // Note that partitions are ranked according to the number of nodes that |
77 | // a partition has, and the returned TfLiteDelegateParams objects are *owned* |
78 | // by the TfLite runtime. |
79 | // TODO(b/156707497): remove this and use GetNodesOfFirstNLargestPartitions |
80 | std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions( |
81 | int n = std::numeric_limits<int>::max(), |
82 | int min_nodes_per_partition = 0) const; |
83 | |
84 | // Returns a list of node indices of all nodes from the first n largest |
85 | // partitions. If there are fewer paritions than n, all nodes will be |
86 | // returned. The partition is ranked according to the number of nodes. |
87 | std::vector<int> GetNodesOfFirstNLargestPartitions( |
88 | int n = std::numeric_limits<int>::max(), |
89 | int min_nodes_per_partition = 0) { |
90 | // Separated implementation that can be overrided, to preserve default value |
91 | return GetNodesOfFirstNLargestPartitionsImpl(n, min_nodes_per_partition); |
92 | } |
93 | |
94 | int num_total_nodes() const { return num_total_nodes_; } |
95 | int num_supported_nodes() const { return num_supported_nodes_; } |
96 | int num_partitions() const { return partitions_.size(); } |
97 | |
98 | protected: |
99 | virtual bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, |
100 | TfLiteRegistration* registration, int node_id, |
101 | std::string* unsupported_details) { |
102 | return is_node_supported_fn_(context, node, registration, |
103 | unsupported_details); |
104 | } |
105 | virtual std::vector<int> GetNodesOfFirstNLargestPartitionsImpl( |
106 | int n, int min_nodes_per_partition); |
107 | |
108 | TfLiteContext* const context_ = nullptr; |
109 | |
110 | // Doesn't own the memory of each TfLiteDelegateParams object as it's |
111 | // managed by the TfLite runtime itself. See |
112 | // TfLiteContext::PreviewDelegatePartitioning for details. |
113 | std::vector<TfLiteDelegateParams*> partitions_; |
114 | |
115 | // Copy of (pre-delegation) execution plan obtained from TfLiteContext in |
116 | // PrepareSupportedNodes |
117 | TfLiteIntArray* original_execution_plan_ = nullptr; |
118 | |
119 | private: |
120 | // Generate a list of supported nodes (i.e. populating 'supported_nodes_') by |
121 | // iterating over all nodes (i,e. those listed in the execution_plan |
122 | // associated w/ 'context_'). |
123 | // If 'unsupported_nodes_info' is provided, it will be populated with |
124 | // information about all different unsupported nodes. |
125 | TfLiteStatus PrepareSupportedNodes( |
126 | std::set<std::string>* unsupported_nodes_info = nullptr); |
127 | |
128 | // The number of total nodes passed in for partitioning (i.e. the |
129 | // execution_plan size associated w/ 'context_') |
130 | int num_total_nodes_ = 0; |
131 | |
132 | int num_supported_nodes_ = 0; |
133 | |
134 | // Tells if a node is supported as it could be delegated. |
135 | const IsNodeSupportedFn is_node_supported_fn_ = nullptr; |
136 | |
137 | // Contains an array of supported node indices. |
138 | TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory |
139 | }; |
140 | |
// Specialized partitioner for graphs that possibly contain fp16 tensors.
//
// From nodes that accept fp16 inputs, this delegates the following:
// 1. All nodes (except DEQUANTIZE) that are supported with constant fp16 inputs
// by the delegate (in the TFLite graph, these nodes take in dequantized FP32
// outputs).
// 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
// delegated partition. This is because TFLite's partitioning algorithm
// greedily puts all such nodes in the first partition.
class FP16GraphPartitionHelper : public GraphPartitionHelper {
 public:
  // Moves the callback into the base class; partitioning behavior otherwise
  // follows GraphPartitionHelper.
  FP16GraphPartitionHelper(TfLiteContext* context,
                           IsNodeSupportedFn is_node_supported_fn)
      : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}

 protected:
  // Specialized function to handle fp16 nodes: overrides the base-class
  // per-node check (implementation in the .cc file).
  bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
                       TfLiteRegistration* registration, int node_id,
                       std::string* unsupported_details) override;

  // This will remap input tensors by removing FP16 to FP32 dequantized
  // tensors, in addition to the base-class selection of the n largest
  // partitions.
  std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
      int n, int min_nodes_per_partition) override;

 private:
  // This remaps fp32 inputs of the given node to their corresponding fp16
  // version, if applicable. Can be summarized as:
  // fp16 -> DEQUANTIZE -> fp32 -> OP -> output
  // becomes
  // fp16 -> OP -> output
  // If 'orig_inputs' is non-null it records the node's pre-remap input
  // tensor indices.
  void RemapFp16InputTensors(TfLiteNode* node,
                             std::vector<int>* orig_inputs) const;

  // Performs the above remapping for all nodes in the given list, without
  // tracking the original inputs.
  void RemapFp16InputTensors(const std::vector<int>& nodes) const;

  // ('dequantize' here refers to fp16 DEQUANTIZE)
  // Mapping of dequantize nodes' output tensor-id to its node id.
  // TODO(b/156707497): Use absl hash_maps here.
  std::unordered_map<int, int> constant_dequant_nodes_;
  // Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
  std::unordered_map<int, int> constant_dequant_map_;
};
186 | |
187 | } // namespace delegates |
188 | } // namespace tflite |
189 | |
190 | #endif // TENSORFLOW_LITE_DELEGATES_UTILS_H_ |
191 | |