1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/metadata_utils.cc
22 * \brief Defines utility functions and classes for emitting metadata.
23 */
24#include "metadata_utils.h"
25
26namespace tvm {
27namespace codegen {
28namespace metadata {
29
/*!
 * \brief Join address components with '_' separators.
 *
 * A separator is emitted before every component except the first, even when a component is
 * empty (so {"", "x"} yields "_x" and {"a", ""} yields "a_"). An empty list yields "".
 *
 * \param parts The components to join, in order.
 * \return The joined address string.
 */
std::string AddressFromParts(const std::vector<std::string>& parts) {
  std::string joined;
  bool is_first = true;
  for (const std::string& part : parts) {
    if (!is_first) {
      joined += '_';
    }
    joined += part;
    is_first = false;
  }
  return joined;
}
40
41DiscoverArraysVisitor::DiscoverArraysVisitor(std::vector<DiscoveredArray>* queue) : queue_{queue} {}
42
43void DiscoverArraysVisitor::Visit(const char* key, double* value) {}
44void DiscoverArraysVisitor::Visit(const char* key, int64_t* value) {}
45void DiscoverArraysVisitor::Visit(const char* key, uint64_t* value) {}
46void DiscoverArraysVisitor::Visit(const char* key, int* value) {}
47void DiscoverArraysVisitor::Visit(const char* key, bool* value) {}
48void DiscoverArraysVisitor::Visit(const char* key, std::string* value) {}
49void DiscoverArraysVisitor::Visit(const char* key, DataType* value) {}
50void DiscoverArraysVisitor::Visit(const char* key, runtime::NDArray* value) {}
51void DiscoverArraysVisitor::Visit(const char* key, void** value) {}
52
53void DiscoverArraysVisitor::Visit(const char* key, ObjectRef* value) {
54 address_parts_.push_back(key);
55 if (value->as<runtime::metadata::MetadataBaseNode>() != nullptr) {
56 auto metadata = Downcast<runtime::metadata::MetadataBase>(*value);
57 const runtime::metadata::MetadataArrayNode* arr =
58 value->as<runtime::metadata::MetadataArrayNode>();
59 if (arr != nullptr) {
60 for (unsigned int i = 0; i < arr->array.size(); i++) {
61 ObjectRef o = arr->array[i];
62 if (o.as<runtime::metadata::MetadataBaseNode>() != nullptr) {
63 std::stringstream ss;
64 ss << i;
65 address_parts_.push_back(ss.str());
66 runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(o);
67 ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
68 address_parts_.pop_back();
69 }
70 }
71
72 queue_->push_back(std::make_tuple(AddressFromParts(address_parts_),
73 Downcast<runtime::metadata::MetadataArray>(metadata)));
74 } else {
75 ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
76 }
77 }
78 address_parts_.pop_back();
79}
80
81void DiscoverComplexTypesVisitor::Visit(const char* key, double* value) {}
82void DiscoverComplexTypesVisitor::Visit(const char* key, int64_t* value) {}
83void DiscoverComplexTypesVisitor::Visit(const char* key, uint64_t* value) {}
84void DiscoverComplexTypesVisitor::Visit(const char* key, int* value) {}
85void DiscoverComplexTypesVisitor::Visit(const char* key, bool* value) {}
86void DiscoverComplexTypesVisitor::Visit(const char* key, std::string* value) {}
87void DiscoverComplexTypesVisitor::Visit(const char* key, DataType* value) {}
88void DiscoverComplexTypesVisitor::Visit(const char* key, runtime::NDArray* value) {}
89void DiscoverComplexTypesVisitor::Visit(const char* key, void** value) {}
90
91bool DiscoverComplexTypesVisitor::DiscoverType(std::string type_key) {
92 VLOG(2) << "DiscoverType " << type_key;
93 auto position_it = type_key_to_position_.find(type_key);
94 if (position_it != type_key_to_position_.end()) {
95 return false;
96 }
97
98 queue_->emplace_back(tvm::runtime::metadata::MetadataBase());
99 type_key_to_position_[type_key] = queue_->size() - 1;
100 return true;
101}
102
103void DiscoverComplexTypesVisitor::DiscoverInstance(runtime::metadata::MetadataBase md) {
104 auto position_it = type_key_to_position_.find(md->GetTypeKey());
105 ICHECK(position_it != type_key_to_position_.end())
106 << "DiscoverInstance requires that DiscoverType has already been called: type_key="
107 << md->GetTypeKey();
108
109 int queue_position = (*position_it).second;
110 if (!(*queue_)[queue_position].defined() && md.defined()) {
111 VLOG(2) << "DiscoverInstance " << md->GetTypeKey() << ":" << md;
112 (*queue_)[queue_position] = md;
113 }
114}
115
116void DiscoverComplexTypesVisitor::Visit(const char* key, ObjectRef* value) {
117 ICHECK_NOTNULL(value->as<runtime::metadata::MetadataBaseNode>());
118
119 auto metadata = Downcast<runtime::metadata::MetadataBase>(*value);
120 const runtime::metadata::MetadataArrayNode* arr =
121 value->as<runtime::metadata::MetadataArrayNode>();
122
123 if (arr == nullptr) {
124 VLOG(2) << "No array, object-traversing " << metadata->GetTypeKey();
125 ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
126 DiscoverType(metadata->GetTypeKey());
127 DiscoverInstance(metadata);
128 return;
129 }
130
131 if (arr->kind != tvm::runtime::metadata::MetadataKind::kMetadata) {
132 return;
133 }
134
135 bool needs_instance = DiscoverType(arr->type_key);
136 for (unsigned int i = 0; i < arr->array.size(); i++) {
137 tvm::runtime::metadata::MetadataBase o =
138 Downcast<tvm::runtime::metadata::MetadataBase>(arr->array[i]);
139 if (needs_instance) {
140 DiscoverInstance(o);
141 needs_instance = false;
142 }
143 ReflectionVTable::Global()->VisitAttrs(o.operator->(), this);
144 }
145}
146
147void DiscoverComplexTypesVisitor::Discover(runtime::metadata::MetadataBase metadata) {
148 ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
149 DiscoverType(metadata->GetTypeKey());
150 DiscoverInstance(metadata);
151}
152
153} // namespace metadata
154} // namespace codegen
155} // namespace tvm
156