1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifndef TVM_RUNTIME_GRAPH_EXECUTOR_DEBUG_GRAPH_EXECUTOR_DEBUG_H_ |
21 | #define TVM_RUNTIME_GRAPH_EXECUTOR_DEBUG_GRAPH_EXECUTOR_DEBUG_H_ |
22 | |
23 | #include <tvm/runtime/profiling.h> |
24 | |
25 | #include <string> |
26 | #include <vector> |
27 | |
28 | #include "../graph_executor.h" |
29 | |
30 | namespace tvm { |
31 | namespace runtime { |
32 | |
33 | /*! |
34 | * \brief Graph executor with debug . |
35 | * |
36 | * This is the extension of GraphExecutor class used for debugging |
37 | * TVM runtime PackedFunc API. |
38 | */ |
39 | class GraphExecutorDebug : public GraphExecutor { |
40 | public: |
41 | /*! |
42 | * \brief Run each operation in the graph and get the time per op for all ops. |
43 | * \param number The number of times to run this function for taking average. |
44 | * \param repeat The number of times to repeat the measurement. |
45 | * In total, the function will be invoked (1 + number x repeat) times, |
46 | * where the first one is warmed up and will be discarded in case |
47 | * there is lazy initialization. |
48 | * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. |
49 | * By default, one `repeat` contains `number` runs. If this parameter is set, |
50 | * the parameters `number` will be dynamically adjusted to meet the |
51 | * minimum duration requirement of one `repeat`. |
52 | * \param limit_zero_time_iterations The maximum number of repeats when |
53 | * measured time is equal to 0. It helps to avoid hanging during |
54 | * measurements. |
55 | * \param cooldown_interval_ms The cooldown interval in milliseconds between the number of repeats |
56 | * defined by `repeats_to_cooldown`. |
57 | * \param repeats_to_cooldown The number of repeats before the |
58 | * cooldown is activated. |
59 | * \return Returns a string with an encoded byte array. Where the first 8 bytes are int64_t |
60 | * representing the number of layers. Next the encoded real numbers are float32_t in the number of |
61 | * repeat multiplied by the number of layers. |
62 | */ |
63 | std::string RunIndividual(int number, int repeat, int min_repeat_ms, |
64 | int limit_zero_time_iterations, int cooldown_interval_ms, |
65 | int repeats_to_cooldown); |
66 | |
67 | std::string RunIndividualNode(int node_index, int number, int repeat, int min_repeat_ms, |
68 | int limit_zero_time_iterations, int cooldown_interval_ms, |
69 | int repeats_to_cooldown); |
70 | |
71 | std::vector<double> RunOpRPC(int index, int number, int repeat, int min_repeat_ms, |
72 | int limit_zero_time_iterations, int cooldown_interval_ms, |
73 | int repeats_to_cooldown); |
74 | |
75 | Timer RunOpHost(int index); |
76 | |
77 | /*! |
78 | * \brief GetFunction Get the function based on input. |
79 | * \param name The function which needs to be invoked. |
80 | * \param sptr_to_self Packed function pointer. |
81 | */ |
82 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self); |
83 | |
84 | /*! |
85 | * \brief Get the node index given the name of node. |
86 | * \param name The name of the node. |
87 | * \return The index of node. |
88 | */ |
89 | int GetNodeIndex(const std::string& name) const; |
90 | |
91 | /*! |
92 | * \brief Execute index-th node in the network. |
93 | * |
94 | * This method will do a partial run of the graph |
95 | * up to index-th node. |
96 | * |
97 | * \param node: The index of the node. |
98 | */ |
99 | void ExecuteNode(int node); |
100 | |
101 | /*! |
102 | * \brief Returns index-th output of node. |
103 | * |
104 | * This method will return index-th out_ind output |
105 | * of index-th node in the network. |
106 | * |
107 | * \param node: The index of the node. |
108 | * \param out_ind: The index of the output. |
109 | * \return Output array. |
110 | */ |
111 | NDArray GetNodeOutput(int node, int out_ind); |
112 | |
113 | /*! |
114 | * \brief Copy index-th node to data_out. |
115 | * |
116 | * This method will do a partial run of the graph |
117 | * from begining upto the index-th node and return output of index-th node. |
118 | * This is costly operation and suggest to use only for debug porpose. |
119 | * |
120 | * \param index: The index of the node. |
121 | * \param data_out the node data. |
122 | */ |
123 | void DebugGetNodeOutput(int index, DLTensor* data_out); |
124 | |
125 | /*! |
126 | * \brief Profile execution time of the module. |
127 | * |
128 | * We run the entire module while recording overall and per-op timing |
129 | * information. The module may be run multiple times to ensure everything is |
130 | * warmed up. This function is a more correct reflection of actual runtime of |
131 | * the module compared to GraphRuntimeDebug::RunIndividual as it runs the |
132 | * entire graph in order. |
133 | * |
134 | * \param collectors Optional user defined `MetricCollector`s to use with this profiling run. |
135 | * |
136 | * \returns A table of per-op runtimes and total times. |
137 | */ |
138 | profiling::Report Profile(Array<profiling::MetricCollector> collectors); |
139 | |
140 | private: |
141 | int last_executed_node_ = -1; |
142 | }; |
143 | |
144 | } // namespace runtime |
145 | } // namespace tvm |
146 | |
147 | #endif // TVM_RUNTIME_GRAPH_EXECUTOR_DEBUG_GRAPH_EXECUTOR_DEBUG_H_ |
148 | |