1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_LITE_GRAPH_INFO_H_ |
16 | #define TENSORFLOW_LITE_GRAPH_INFO_H_ |
17 | |
18 | #include <stddef.h> |
19 | |
20 | #include <cstdint> |
21 | #include <utility> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/lite/c/common.h" |
25 | |
26 | namespace tflite { |
27 | |
28 | // Basic information about an inference graph, where execution nodes |
29 | // are connected via tensors. |
30 | class GraphInfo { |
31 | public: |
32 | virtual ~GraphInfo() {} |
33 | |
34 | // Total number of tensors in the graph. This should be cached when possible. |
35 | virtual size_t num_tensors() const = 0; |
36 | |
37 | // Returns a tensor given its index which is expected to be between 0 and |
38 | // num_tensors(). Use tensors() below for iteration as it is much faster. |
39 | virtual TfLiteTensor* tensor(size_t index) = 0; |
40 | |
41 | // Returns all tensors in the graph |
42 | virtual TfLiteTensor* tensors() = 0; |
43 | |
44 | // Number of nodes in the current execution plan. |
45 | virtual size_t num_execution_nodes() const = 0; |
46 | |
47 | // Total number of known nodes, which may include nodes that are no longer in |
48 | // the execution plan. This happens in case of applying multiple delegates. |
49 | // Should be >= num_execution_nodes() |
50 | virtual size_t num_total_nodes() const = 0; |
51 | |
52 | // Returns a node given its index in the execution plan, which is expected to |
53 | // be between 0 and num_execution_nodes(). |
54 | virtual const TfLiteNode& node(size_t index) const = 0; |
55 | |
56 | // Returns an implementation-specific node index which may be different from |
57 | // execution-plan index. |
58 | // Expected to be between 0 and num_total_nodes(). |
59 | virtual size_t node_index(size_t index) const = 0; |
60 | |
61 | // Returns the indices of the input tensors. |
62 | virtual const std::vector<int>& inputs() const = 0; |
63 | |
64 | // Returns the indices of the output tensors. |
65 | virtual const std::vector<int>& outputs() const = 0; |
66 | |
67 | // Returns the indices of the variable tensors. |
68 | virtual const std::vector<int>& variables() const = 0; |
69 | }; |
70 | |
// Represents a subset of nodes in a TensorFlow Lite graph.
struct NodeSubset {
  // Classification of a node subset.
  enum Type {
    kTfUnexplored = 0,  // temporarily used during creation
    // NOTE(review): presumably a subset made of nodes selected for
    // partitioning (e.g. from `nodes_to_partition`) — confirm.
    kTfPartition,
    // NOTE(review): presumably a subset of nodes not selected for
    // partitioning — confirm.
    kTfNonPartition
  };
  Type type = kTfUnexplored;
  // Nodes within this node subset.
  std::vector<int> nodes;
  // Tensors that are output by another node subset that this one depends on,
  // or global inputs to the full TensorFlow Lite graph.
  std::vector<int> input_tensors;
  // Outputs that are consumed by other node subsets or are global output
  // tensors. All output tensors of the nodes in this node subset that do not
  // appear in this list are intermediate results that can potentially be
  // elided.
  std::vector<int> output_tensors;
};
90 | |
// A single control dependency between two nodes, both identified by index:
// the node `edge.second` depends on the node `edge.first`, i.e. `first` must be
// executed before `second`.
using ControlEdge = std::pair<std::int32_t, std::int32_t>;

// An arbitrary collection of control dependencies.
using ControlEdges = std::vector<ControlEdge>;
94 | |
95 | // Partitions a list of node indices `nodes_to_partition` into node subsets. |
96 | // Each node subset is in dependency order (i.e. all members of the node subsets |
97 | // can be executed in the order they occur). Maintains the relative ordering of |
98 | // nodes that have their `might_have_side_effects` attribute set. `node_subsets` |
99 | // is assumed to be empty. |
100 | TfLiteStatus PartitionGraphIntoIndependentNodeSubsets( |
101 | const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, |
102 | std::vector<NodeSubset>* node_subsets); |
103 | |
104 | // Partitions a list of node indices `nodes_to_partition` into node subsets. |
105 | // Each node subset is in dependency order (i.e. all members of the node subset |
106 | // can be executed in the order they occur). `control_edges` specified a control |
107 | // dependency DAG on the nodes contained in `info`. The resulting partitioning |
108 | // will respect these control dependencies. This way, restrictions (in addition |
109 | // to the nodes' data dependencies) can be imposed on the ultimate execution |
110 | // order of the graph. |
111 | // |
112 | // (Example: with `control_edges.empty()` and `nodes_to_partition == {2, 3}`, |
113 | // the graph |
114 | // /------------\ |
115 | // | v |
116 | // 0 --> 1 --> 2* --> 3* 4 --> 5 |
117 | // | ^ |
118 | // \-------------------/ |
119 | // |
120 | // will be partitioned as {{0, 1, 4}, {2, 3}, {5}}, since data dependencies |
121 | // (notated '-->') allow for execution of 4 immediately after 1. |
122 | // |
123 | // With an additional control dependency `control_edges == {{3, 4}}` (notated |
124 | // '==>'), execution of node 4 requires prior execution of node 3: |
125 | // |
126 | // /------------\ |
127 | // | v |
128 | // 0 --> 1 --> 2* --> 3* ==> 4 --> 5 |
129 | // | ^ |
130 | // \-------------------/ |
131 | // |
132 | // and the partitioning will be {{0, 1}, {2, 3}, {4, 5}}.) |
133 | // |
134 | // `node_subsets` is assumed to be empty. |
135 | TfLiteStatus PartitionGraphIntoIndependentNodeSubsets( |
136 | const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, |
137 | const ControlEdges& control_edges, std::vector<NodeSubset>* node_subsets); |
138 | |
139 | } // namespace tflite |
140 | |
141 | #endif // TENSORFLOW_LITE_GRAPH_INFO_H_ |
142 | |