1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/aot_executor/aot_executor_factory.h
 * \brief AOT executor factory that creates AOT executor modules.
23 */
24
25#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_
26#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_
27
28#include <tvm/runtime/c_runtime_api.h>
29#include <tvm/runtime/module.h>
30#include <tvm/runtime/ndarray.h>
31#include <tvm/runtime/packed_func.h>
32
33#include <algorithm>
34#include <functional>
35#include <numeric>
36#include <string>
37#include <unordered_map>
38#include <vector>
39
40#include "./aot_executor.h"
41
42namespace tvm {
43namespace runtime {
44
45class TVM_DLL AotExecutorFactory : public runtime::ModuleNode {
46 public:
47 /*!
48 * \brief Construct the AotExecutorFactory.
49 * \param params The params of aot.
50 * \param module_name The module name of aot.
51 */
52 AotExecutorFactory(const std::unordered_map<std::string, tvm::runtime::NDArray>& params,
53 const std::string& module_name);
54
55 /*!
56 * \brief Get member function to front-end
57 * \param name The name of the function.
58 * \param sptr_to_self The pointer to the module node.
59 * \return The corresponding member function.
60 */
61 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
62
63 /*!
64 * \return The type key of the executor.
65 */
66 const char* type_key() const final { return "AotExecutorFactory"; }
67
68 /*!
69 * \brief Save the module to binary stream.
70 * \param stream The binary stream to save to.
71 */
72 void SaveToBinary(dmlc::Stream* stream) override;
73
74 /*!
75 * \brief Create a specific executor module
76 * \param devs The device of the host and devices where the model will be
77 * executed.
78 * \return created executor module
79 */
80 Module ExecutorCreate(const std::vector<Device>& devs);
81
82 /*!
83 * \brief Set params.
84 * \param aot_executor The aot executor we want to set the params into.
85 * \param params The aot params value we want to set.
86 */
87 void SetParams(AotExecutor* aot_executor,
88 const std::unordered_map<std::string, tvm::runtime::NDArray>& params) const {
89 std::unordered_map<std::string, tvm::runtime::NDArray> value = params;
90 // upload big arrays first to avoid memory issue in rpc mode
91 std::vector<std::string> keys;
92 for (const auto& p : value) {
93 keys.emplace_back(p.first);
94 }
95 std::sort(std::begin(keys), std::end(keys),
96 [&](const std::string& lhs, const std::string& rhs) -> bool {
97 auto lhs_size = GetDataSize(*value[lhs].operator->());
98 auto rhs_size = GetDataSize(*value[rhs].operator->());
99 return lhs_size > rhs_size;
100 });
101 for (const auto& key : keys) {
102 int in_idx = aot_executor->GetInputIndex(key);
103 if (in_idx >= 0) {
104 aot_executor->SetInput(in_idx, const_cast<DLTensor*>(value[key].operator->()));
105 }
106 }
107 }
108
109 protected:
110 /*! \brief The params. */
111 std::unordered_map<std::string, tvm::runtime::NDArray> params_;
112 /*! \brief module name */
113 std::string module_name_;
114};
115
116} // namespace runtime
117} // namespace tvm
118
119#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_
120