1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/runtime/metadata.cc |
22 | * \brief Defines implementations of TVM metadata which can exist in the runtime. |
23 | */ |
24 | |
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/metadata.h>
#include <tvm/runtime/registry.h>

#include <string>
#include <utility>
31 | |
32 | namespace tvm { |
33 | namespace runtime { |
34 | namespace metadata { |
35 | |
36 | TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); |
37 | |
38 | ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::inputs() { |
39 | return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->inputs, data_->num_inputs); |
40 | } |
41 | ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::outputs() { |
42 | return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->outputs, data_->num_outputs); |
43 | } |
44 | ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::workspace_pools() { |
45 | return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->workspace_pools, |
46 | data_->num_workspace_pools); |
47 | } |
48 | ArrayAccessor<struct TVMConstantInfo, ConstantInfoMetadata> MetadataNode::constant_pools() { |
49 | return ArrayAccessor<struct TVMConstantInfo, ConstantInfoMetadata>(data_->constant_pools, |
50 | data_->num_constant_pools); |
51 | } |
52 | |
53 | TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); |
54 | |
55 | MetadataArray::MetadataArray(Array<ObjectRef> array, MetadataKind kind, const char* struct_name) |
56 | : MetadataBase{make_object<MetadataArrayNode>(array, kind, struct_name)} {} |
57 | |
58 | const char* MetadataArrayNode::get_c_struct_name() const { |
59 | ICHECK(false) << "MetadataArrayNode get_c_struct_name is unimplemented" ; |
60 | return nullptr; |
61 | } |
62 | TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); |
63 | |
64 | Metadata::Metadata(const struct ::TVMMetadata* data) |
65 | : MetadataBase{make_object<MetadataNode>(data)} {} |
66 | TVM_REGISTER_OBJECT_TYPE(MetadataNode); |
67 | |
68 | const char* MetadataNode::get_c_struct_name() const { return "TVMMetadata" ; } |
69 | |
70 | TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) |
71 | : MetadataBase{make_object<TensorInfoNode>(data)} {} |
72 | TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); |
73 | |
74 | const char* TensorInfoNode::get_c_struct_name() const { return "TVMTensorInfo" ; } |
75 | |
76 | ConstantInfoMetadata::ConstantInfoMetadata(const struct ::TVMConstantInfo* data) |
77 | : MetadataBase{make_object<ConstantInfoMetadataNode>(data)} {} |
78 | TVM_REGISTER_OBJECT_TYPE(ConstantInfoMetadataNode); |
79 | |
80 | const char* ConstantInfoMetadataNode::get_c_struct_name() const { return "TVMConstantInfo" ; } |
81 | |
82 | } // namespace metadata |
83 | |
84 | class MetadataModuleNode : public ::tvm::runtime::ModuleNode { |
85 | public: |
86 | explicit MetadataModuleNode(runtime::metadata::Metadata metadata) |
87 | : metadata_{::std::move(metadata)} {} |
88 | |
89 | const char* type_key() const final { return "metadata_module" ; } |
90 | |
91 | static Module LoadFromBinary() { |
92 | return Module(make_object<MetadataModuleNode>(runtime::metadata::Metadata())); |
93 | } |
94 | |
95 | void SaveToBinary(dmlc::Stream* stream) final {} |
96 | |
97 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) { |
98 | if (name == "get_metadata" ) { |
99 | return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { |
100 | if (!metadata_.defined()) { |
101 | TVMFunctionHandle f_handle; |
102 | int32_t ret_code = TVMBackendGetFuncFromEnv(this, symbol::tvm_get_c_metadata, &f_handle); |
103 | ICHECK_EQ(ret_code, 0) << "Unable to locate " << symbol::tvm_get_c_metadata |
104 | << " PackedFunc" ; |
105 | |
106 | TVMValue ret_value; |
107 | int ret_type_code; |
108 | ret_code = TVMFuncCall(f_handle, nullptr, nullptr, 0, &ret_value, &ret_type_code); |
109 | ICHECK_EQ(ret_code, 0) << "Invoking " << symbol::tvm_get_c_metadata |
110 | << ": TVMFuncCall returned " << ret_code; |
111 | |
112 | ICHECK_EQ(ret_type_code, kTVMOpaqueHandle) |
113 | << "Expected kOpaqueHandle returned; got " << ret_type_code; |
114 | ICHECK(ret_value.v_handle != nullptr) |
115 | << symbol::tvm_get_c_metadata << " returned nullptr" ; |
116 | |
117 | metadata_ = runtime::metadata::Metadata( |
118 | static_cast<const struct ::TVMMetadata*>(ret_value.v_handle)); |
119 | } |
120 | |
121 | *rv = metadata_; |
122 | return; |
123 | }); |
124 | } |
125 | |
126 | return PackedFunc(); |
127 | } |
128 | |
129 | private: |
130 | runtime::metadata::Metadata metadata_; |
131 | }; |
132 | |
133 | Module MetadataModuleCreate(metadata::Metadata metadata) { |
134 | return Module(make_object<MetadataModuleNode>(metadata)); |
135 | } |
136 | |
137 | TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata_module" ) |
138 | .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = MetadataModuleNode::LoadFromBinary(); }); |
139 | |
140 | } // namespace runtime |
141 | } // namespace tvm |
142 | |