/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/runtime/graph_executor/graph_executor_factory.h
 * \brief Graph executor factory creating graph executor.
 */
24 | |
#ifndef TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_FACTORY_H_
#define TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_FACTORY_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

#include <algorithm>
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "./graph_executor.h"

namespace tvm {
namespace runtime {
44 | |
45 | class TVM_DLL GraphExecutorFactory : public runtime::ModuleNode { |
46 | public: |
47 | /*! |
48 | * \brief Construct the GraphExecutorFactory. |
49 | * \param graph_json The execution graph. |
50 | * \param params The params of graph. |
51 | * \param module_name The module name of graph. |
52 | */ |
53 | GraphExecutorFactory(const std::string& graph_json, |
54 | const std::unordered_map<std::string, tvm::runtime::NDArray>& params, |
55 | const std::string& module_name = "default" ); |
56 | |
57 | /*! |
58 | * \brief Get member function to front-end |
59 | * \param name The name of the function. |
60 | * \param sptr_to_self The pointer to the module node. |
61 | * \return The corresponding member function. |
62 | */ |
63 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final; |
64 | |
65 | /*! |
66 | * \return The type key of the executor. |
67 | */ |
68 | const char* type_key() const final { return "GraphExecutorFactory" ; } |
69 | |
70 | /*! |
71 | * \brief Save the module to binary stream. |
72 | * \param stream The binary stream to save to. |
73 | */ |
74 | void SaveToBinary(dmlc::Stream* stream) override; |
75 | |
76 | /*! |
77 | * \brief Create a specific executor module |
78 | * \param devs The device of the host and devices where graph nodes will be |
79 | * executed on. |
80 | * \return created executor module |
81 | */ |
82 | Module ExecutorCreate(const std::vector<Device>& devs); |
83 | |
84 | /*! |
85 | * \brief Create a specific debug executor module |
86 | * \param devs The device of the host and devices where graph nodes will be |
87 | * executed on. |
88 | * \return created debug executor module |
89 | */ |
90 | Module DebugExecutorCreate(const std::vector<Device>& devs); |
91 | |
92 | /*! |
93 | * \brief Create a specific cuda graph executor module |
94 | * \param devs The device of the host and devices where graph nodes will be |
95 | * executed on. |
96 | * \return created cuda graph executor module |
97 | */ |
98 | Module CudaGraphExecutorCreate(const std::vector<Device>& devs); |
99 | |
100 | /*! |
101 | * \brief Set params. |
102 | * \param graph_executor The graph executor we want to set the params into. |
103 | * \param params The graph params value we want to set. |
104 | */ |
105 | void SetParams(GraphExecutor* graph_executor, |
106 | const std::unordered_map<std::string, tvm::runtime::NDArray>& params) const { |
107 | std::unordered_map<std::string, tvm::runtime::NDArray> value = params; |
108 | // upload big arrays first to avoid memory issue in rpc mode |
109 | std::vector<std::string> keys; |
110 | for (const auto& p : value) { |
111 | keys.emplace_back(p.first); |
112 | } |
113 | std::sort(std::begin(keys), std::end(keys), |
114 | [&](const std::string& lhs, const std::string& rhs) -> bool { |
115 | auto lhs_size = GetDataSize(*value[lhs].operator->()); |
116 | auto rhs_size = GetDataSize(*value[rhs].operator->()); |
117 | return lhs_size > rhs_size; |
118 | }); |
119 | for (const auto& key : keys) { |
120 | int in_idx = graph_executor->GetInputIndex(key); |
121 | if (in_idx >= 0) { |
122 | graph_executor->SetInput(in_idx, const_cast<DLTensor*>(value[key].operator->())); |
123 | } |
124 | } |
125 | } |
126 | |
127 | protected: |
128 | /*! \brief The execution graph. */ |
129 | std::string graph_json_; |
130 | /*! \brief The params. */ |
131 | std::unordered_map<std::string, tvm::runtime::NDArray> params_; |
132 | /*! \brief module name */ |
133 | std::string module_name_; |
134 | }; |
135 | |
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_FACTORY_H_
140 | |