1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/target/metadata_utils.cc |
22 | * \brief Defines utility functions and classes for emitting metadata. |
23 | */ |
24 | #include "metadata_utils.h" |
25 | |
26 | namespace tvm { |
27 | namespace codegen { |
28 | namespace metadata { |
29 | |
/*!
 * \brief Join address components with '_' separators.
 * \param parts The components, outermost first.
 * \return The components joined by "_"; an empty string when \p parts is empty.
 */
std::string AddressFromParts(const std::vector<std::string>& parts) {
  std::stringstream joined;
  // No separator before the first component, "_" before every subsequent one.
  const char* separator = "";
  for (const std::string& part : parts) {
    joined << separator << part;
    separator = "_";
  }
  return joined.str();
}
40 | |
41 | DiscoverArraysVisitor::DiscoverArraysVisitor(std::vector<DiscoveredArray>* queue) : queue_{queue} {} |
42 | |
43 | void DiscoverArraysVisitor::Visit(const char* key, double* value) {} |
44 | void DiscoverArraysVisitor::Visit(const char* key, int64_t* value) {} |
45 | void DiscoverArraysVisitor::Visit(const char* key, uint64_t* value) {} |
46 | void DiscoverArraysVisitor::Visit(const char* key, int* value) {} |
47 | void DiscoverArraysVisitor::Visit(const char* key, bool* value) {} |
48 | void DiscoverArraysVisitor::Visit(const char* key, std::string* value) {} |
49 | void DiscoverArraysVisitor::Visit(const char* key, DataType* value) {} |
50 | void DiscoverArraysVisitor::Visit(const char* key, runtime::NDArray* value) {} |
51 | void DiscoverArraysVisitor::Visit(const char* key, void** value) {} |
52 | |
53 | void DiscoverArraysVisitor::Visit(const char* key, ObjectRef* value) { |
54 | address_parts_.push_back(key); |
55 | if (value->as<runtime::metadata::MetadataBaseNode>() != nullptr) { |
56 | auto metadata = Downcast<runtime::metadata::MetadataBase>(*value); |
57 | const runtime::metadata::MetadataArrayNode* arr = |
58 | value->as<runtime::metadata::MetadataArrayNode>(); |
59 | if (arr != nullptr) { |
60 | for (unsigned int i = 0; i < arr->array.size(); i++) { |
61 | ObjectRef o = arr->array[i]; |
62 | if (o.as<runtime::metadata::MetadataBaseNode>() != nullptr) { |
63 | std::stringstream ss; |
64 | ss << i; |
65 | address_parts_.push_back(ss.str()); |
66 | runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(o); |
67 | ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); |
68 | address_parts_.pop_back(); |
69 | } |
70 | } |
71 | |
72 | queue_->push_back(std::make_tuple(AddressFromParts(address_parts_), |
73 | Downcast<runtime::metadata::MetadataArray>(metadata))); |
74 | } else { |
75 | ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); |
76 | } |
77 | } |
78 | address_parts_.pop_back(); |
79 | } |
80 | |
81 | void DiscoverComplexTypesVisitor::Visit(const char* key, double* value) {} |
82 | void DiscoverComplexTypesVisitor::Visit(const char* key, int64_t* value) {} |
83 | void DiscoverComplexTypesVisitor::Visit(const char* key, uint64_t* value) {} |
84 | void DiscoverComplexTypesVisitor::Visit(const char* key, int* value) {} |
85 | void DiscoverComplexTypesVisitor::Visit(const char* key, bool* value) {} |
86 | void DiscoverComplexTypesVisitor::Visit(const char* key, std::string* value) {} |
87 | void DiscoverComplexTypesVisitor::Visit(const char* key, DataType* value) {} |
88 | void DiscoverComplexTypesVisitor::Visit(const char* key, runtime::NDArray* value) {} |
89 | void DiscoverComplexTypesVisitor::Visit(const char* key, void** value) {} |
90 | |
91 | bool DiscoverComplexTypesVisitor::DiscoverType(std::string type_key) { |
92 | VLOG(2) << "DiscoverType " << type_key; |
93 | auto position_it = type_key_to_position_.find(type_key); |
94 | if (position_it != type_key_to_position_.end()) { |
95 | return false; |
96 | } |
97 | |
98 | queue_->emplace_back(tvm::runtime::metadata::MetadataBase()); |
99 | type_key_to_position_[type_key] = queue_->size() - 1; |
100 | return true; |
101 | } |
102 | |
103 | void DiscoverComplexTypesVisitor::DiscoverInstance(runtime::metadata::MetadataBase md) { |
104 | auto position_it = type_key_to_position_.find(md->GetTypeKey()); |
105 | ICHECK(position_it != type_key_to_position_.end()) |
106 | << "DiscoverInstance requires that DiscoverType has already been called: type_key=" |
107 | << md->GetTypeKey(); |
108 | |
109 | int queue_position = (*position_it).second; |
110 | if (!(*queue_)[queue_position].defined() && md.defined()) { |
111 | VLOG(2) << "DiscoverInstance " << md->GetTypeKey() << ":" << md; |
112 | (*queue_)[queue_position] = md; |
113 | } |
114 | } |
115 | |
116 | void DiscoverComplexTypesVisitor::Visit(const char* key, ObjectRef* value) { |
117 | ICHECK_NOTNULL(value->as<runtime::metadata::MetadataBaseNode>()); |
118 | |
119 | auto metadata = Downcast<runtime::metadata::MetadataBase>(*value); |
120 | const runtime::metadata::MetadataArrayNode* arr = |
121 | value->as<runtime::metadata::MetadataArrayNode>(); |
122 | |
123 | if (arr == nullptr) { |
124 | VLOG(2) << "No array, object-traversing " << metadata->GetTypeKey(); |
125 | ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); |
126 | DiscoverType(metadata->GetTypeKey()); |
127 | DiscoverInstance(metadata); |
128 | return; |
129 | } |
130 | |
131 | if (arr->kind != tvm::runtime::metadata::MetadataKind::kMetadata) { |
132 | return; |
133 | } |
134 | |
135 | bool needs_instance = DiscoverType(arr->type_key); |
136 | for (unsigned int i = 0; i < arr->array.size(); i++) { |
137 | tvm::runtime::metadata::MetadataBase o = |
138 | Downcast<tvm::runtime::metadata::MetadataBase>(arr->array[i]); |
139 | if (needs_instance) { |
140 | DiscoverInstance(o); |
141 | needs_instance = false; |
142 | } |
143 | ReflectionVTable::Global()->VisitAttrs(o.operator->(), this); |
144 | } |
145 | } |
146 | |
147 | void DiscoverComplexTypesVisitor::Discover(runtime::metadata::MetadataBase metadata) { |
148 | ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); |
149 | DiscoverType(metadata->GetTypeKey()); |
150 | DiscoverInstance(metadata); |
151 | } |
152 | |
153 | } // namespace metadata |
154 | } // namespace codegen |
155 | } // namespace tvm |
156 | |