1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/graph_executor/graph_executor_factory.h
22 * \brief Graph executor factory creating graph executor.
23 */
24
25#ifndef TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_FACTORY_H_
26#define TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_FACTORY_H_
27
28#include <tvm/runtime/c_runtime_api.h>
29#include <tvm/runtime/module.h>
30#include <tvm/runtime/ndarray.h>
31#include <tvm/runtime/packed_func.h>
32
33#include <algorithm>
34#include <functional>
35#include <numeric>
36#include <string>
37#include <unordered_map>
38#include <vector>
39
40#include "./graph_executor.h"
41
42namespace tvm {
43namespace runtime {
44
45class TVM_DLL GraphExecutorFactory : public runtime::ModuleNode {
46 public:
47 /*!
48 * \brief Construct the GraphExecutorFactory.
49 * \param graph_json The execution graph.
50 * \param params The params of graph.
51 * \param module_name The module name of graph.
52 */
53 GraphExecutorFactory(const std::string& graph_json,
54 const std::unordered_map<std::string, tvm::runtime::NDArray>& params,
55 const std::string& module_name = "default");
56
57 /*!
58 * \brief Get member function to front-end
59 * \param name The name of the function.
60 * \param sptr_to_self The pointer to the module node.
61 * \return The corresponding member function.
62 */
63 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
64
65 /*!
66 * \return The type key of the executor.
67 */
68 const char* type_key() const final { return "GraphExecutorFactory"; }
69
70 /*!
71 * \brief Save the module to binary stream.
72 * \param stream The binary stream to save to.
73 */
74 void SaveToBinary(dmlc::Stream* stream) override;
75
76 /*!
77 * \brief Create a specific executor module
78 * \param devs The device of the host and devices where graph nodes will be
79 * executed on.
80 * \return created executor module
81 */
82 Module ExecutorCreate(const std::vector<Device>& devs);
83
84 /*!
85 * \brief Create a specific debug executor module
86 * \param devs The device of the host and devices where graph nodes will be
87 * executed on.
88 * \return created debug executor module
89 */
90 Module DebugExecutorCreate(const std::vector<Device>& devs);
91
92 /*!
93 * \brief Create a specific cuda graph executor module
94 * \param devs The device of the host and devices where graph nodes will be
95 * executed on.
96 * \return created cuda graph executor module
97 */
98 Module CudaGraphExecutorCreate(const std::vector<Device>& devs);
99
100 /*!
101 * \brief Set params.
102 * \param graph_executor The graph executor we want to set the params into.
103 * \param params The graph params value we want to set.
104 */
105 void SetParams(GraphExecutor* graph_executor,
106 const std::unordered_map<std::string, tvm::runtime::NDArray>& params) const {
107 std::unordered_map<std::string, tvm::runtime::NDArray> value = params;
108 // upload big arrays first to avoid memory issue in rpc mode
109 std::vector<std::string> keys;
110 for (const auto& p : value) {
111 keys.emplace_back(p.first);
112 }
113 std::sort(std::begin(keys), std::end(keys),
114 [&](const std::string& lhs, const std::string& rhs) -> bool {
115 auto lhs_size = GetDataSize(*value[lhs].operator->());
116 auto rhs_size = GetDataSize(*value[rhs].operator->());
117 return lhs_size > rhs_size;
118 });
119 for (const auto& key : keys) {
120 int in_idx = graph_executor->GetInputIndex(key);
121 if (in_idx >= 0) {
122 graph_executor->SetInput(in_idx, const_cast<DLTensor*>(value[key].operator->()));
123 }
124 }
125 }
126
127 protected:
128 /*! \brief The execution graph. */
129 std::string graph_json_;
130 /*! \brief The params. */
131 std::unordered_map<std::string, tvm::runtime::NDArray> params_;
132 /*! \brief module name */
133 std::string module_name_;
134};
135
136} // namespace runtime
137} // namespace tvm
138
139#endif // TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_FACTORY_H_
140