1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/runtime/metadata.cc
22 * \brief Defines implementations of TVM metadata which can exist in the runtime.
23 */
24
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/metadata.h>
#include <tvm/runtime/registry.h>

#include <string>
#include <utility>
31
32namespace tvm {
33namespace runtime {
34namespace metadata {
35
// Register MetadataBaseNode with the TVM runtime object type registry
// (via TVM_REGISTER_OBJECT_TYPE) so the base of all metadata node types is known at runtime.
TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode);
37
38ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::inputs() {
39 return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->inputs, data_->num_inputs);
40}
41ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::outputs() {
42 return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->outputs, data_->num_outputs);
43}
44ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::workspace_pools() {
45 return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->workspace_pools,
46 data_->num_workspace_pools);
47}
48ArrayAccessor<struct TVMConstantInfo, ConstantInfoMetadata> MetadataNode::constant_pools() {
49 return ArrayAccessor<struct TVMConstantInfo, ConstantInfoMetadata>(data_->constant_pools,
50 data_->num_constant_pools);
51}
52
53TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode);
54
55MetadataArray::MetadataArray(Array<ObjectRef> array, MetadataKind kind, const char* struct_name)
56 : MetadataBase{make_object<MetadataArrayNode>(array, kind, struct_name)} {}
57
58const char* MetadataArrayNode::get_c_struct_name() const {
59 ICHECK(false) << "MetadataArrayNode get_c_struct_name is unimplemented";
60 return nullptr;
61}
62TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode);
63
64Metadata::Metadata(const struct ::TVMMetadata* data)
65 : MetadataBase{make_object<MetadataNode>(data)} {}
66TVM_REGISTER_OBJECT_TYPE(MetadataNode);
67
68const char* MetadataNode::get_c_struct_name() const { return "TVMMetadata"; }
69
70TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data)
71 : MetadataBase{make_object<TensorInfoNode>(data)} {}
72TVM_REGISTER_OBJECT_TYPE(TensorInfoNode);
73
74const char* TensorInfoNode::get_c_struct_name() const { return "TVMTensorInfo"; }
75
76ConstantInfoMetadata::ConstantInfoMetadata(const struct ::TVMConstantInfo* data)
77 : MetadataBase{make_object<ConstantInfoMetadataNode>(data)} {}
78TVM_REGISTER_OBJECT_TYPE(ConstantInfoMetadataNode);
79
80const char* ConstantInfoMetadataNode::get_c_struct_name() const { return "TVMConstantInfo"; }
81
82} // namespace metadata
83
84class MetadataModuleNode : public ::tvm::runtime::ModuleNode {
85 public:
86 explicit MetadataModuleNode(runtime::metadata::Metadata metadata)
87 : metadata_{::std::move(metadata)} {}
88
89 const char* type_key() const final { return "metadata_module"; }
90
91 static Module LoadFromBinary() {
92 return Module(make_object<MetadataModuleNode>(runtime::metadata::Metadata()));
93 }
94
95 void SaveToBinary(dmlc::Stream* stream) final {}
96
97 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
98 if (name == "get_metadata") {
99 return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
100 if (!metadata_.defined()) {
101 TVMFunctionHandle f_handle;
102 int32_t ret_code = TVMBackendGetFuncFromEnv(this, symbol::tvm_get_c_metadata, &f_handle);
103 ICHECK_EQ(ret_code, 0) << "Unable to locate " << symbol::tvm_get_c_metadata
104 << " PackedFunc";
105
106 TVMValue ret_value;
107 int ret_type_code;
108 ret_code = TVMFuncCall(f_handle, nullptr, nullptr, 0, &ret_value, &ret_type_code);
109 ICHECK_EQ(ret_code, 0) << "Invoking " << symbol::tvm_get_c_metadata
110 << ": TVMFuncCall returned " << ret_code;
111
112 ICHECK_EQ(ret_type_code, kTVMOpaqueHandle)
113 << "Expected kOpaqueHandle returned; got " << ret_type_code;
114 ICHECK(ret_value.v_handle != nullptr)
115 << symbol::tvm_get_c_metadata << " returned nullptr";
116
117 metadata_ = runtime::metadata::Metadata(
118 static_cast<const struct ::TVMMetadata*>(ret_value.v_handle));
119 }
120
121 *rv = metadata_;
122 return;
123 });
124 }
125
126 return PackedFunc();
127 }
128
129 private:
130 runtime::metadata::Metadata metadata_;
131};
132
133Module MetadataModuleCreate(metadata::Metadata metadata) {
134 return Module(make_object<MetadataModuleNode>(metadata));
135}
136
137TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata_module")
138 .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = MetadataModuleNode::LoadFromBinary(); });
139
140} // namespace runtime
141} // namespace tvm
142