1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/metadata_base.h |
22 | * \brief Defines types which can be used in Metadata. |
23 | */ |
24 | #ifndef TVM_RUNTIME_METADATA_BASE_H_ |
25 | #define TVM_RUNTIME_METADATA_BASE_H_ |
26 | |
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
37 | |
38 | namespace tvm { |
39 | namespace runtime { |
40 | namespace metadata { |
41 | |
42 | /*! |
43 | * \brief Common base class for all Metadata. |
44 | * |
45 | * This class is used in the visitor classes as a internal check to ensure that verify that all |
46 | * parts of the Metadata struct used in codegen are Metadata objects. |
47 | */ |
48 | class MetadataBaseNode : public ::tvm::runtime::Object { |
49 | public: |
50 | virtual const char* get_c_struct_name() const = 0; |
51 | |
52 | static constexpr const char* _type_key = "metadata.MetadataBaseNode" ; |
53 | TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); |
54 | }; |
55 | |
56 | /*! \brief Reference class for the common MetadataBaseNode class. */ |
57 | class MetadataBase : public ::tvm::runtime::ObjectRef { |
58 | public: |
59 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode); |
60 | }; |
61 | |
template <typename C, class Ref>
class ArrayAccessor;

/*!
 * \brief An iterator implementation that lazily instantiates the C++ wrapping Metadata class.
 *
 * Dereferencing asks the parent ArrayAccessor for element `index_` and returns the resulting Ref
 * by value. Because no stored object is referenced, this is modelled as an input iterator. The
 * standard iterator typedefs are provided so it works with std algorithms and with
 * iterator-pair constructors.
 *
 * The iterator holds a raw pointer to its parent accessor and must not outlive it.
 */
template <typename C, class Ref>
class ArrayIterator {
 public:
  using iterator_category = std::input_iterator_tag;
  using value_type = Ref;
  using difference_type = std::ptrdiff_t;
  using pointer = void;
  using reference = Ref;

  /*!
   * \param index Position of the element this iterator refers to.
   * \param parent Accessor being iterated; it is not owned.
   */
  ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent)
      : index_{index}, parent_{parent} {}

  /*!
   * \brief Materialize the wrapper for the current element.
   *
   * This is const because it does not move the iterator, and the parent's operator[] is const.
   */
  inline Ref operator*() const { return (*parent_)[index_]; }

  /*! \brief Advance to the next element. Once the end position is reached, this is a no-op. */
  inline ArrayIterator<C, Ref>& operator++() {
    if (index_ < parent_->size()) {
      index_++;
    }

    return *this;
  }

  /*! \brief Post-increment: advance, then return a copy of the iterator's previous state. */
  inline ArrayIterator<C, Ref> operator++(int) {
    ArrayIterator<C, Ref> previous = *this;
    ++(*this);
    return previous;
  }

  /*! \brief Two iterators are equal when they share a parent and point at the same index. */
  inline bool operator==(const ArrayIterator<C, Ref>& other) const {
    return parent_ == other.parent_ && index_ == other.index_;
  }

  inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); }

 private:
  size_t index_;
  const ArrayAccessor<C, Ref>* parent_;
};
92 | |
93 | /*! \brief A span-like class which permits access to Array fields with complex elements. |
94 | * These array fields should be accessed from C++ using the Metadata wrapper classes. This class |
95 | * lazily instantiates those wrappers as they are accessed. |
96 | */ |
97 | template <typename C, class Ref> |
98 | class ArrayAccessor { |
99 | public: |
100 | using value_type = Ref; |
101 | using iterator = ArrayIterator<C, Ref>; |
102 | using const_iterator = iterator; |
103 | |
104 | template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type> |
105 | ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {} |
106 | |
107 | inline size_t size() const { return num_data_; } |
108 | |
109 | inline Ref operator[](size_t index) const { |
110 | if (index >= num_data_) { |
111 | throw std::runtime_error("Index out of range" ); |
112 | } |
113 | |
114 | return Ref(&data_[index]); |
115 | } |
116 | |
117 | inline ArrayIterator<C, Ref> begin() const { return ArrayIterator<C, Ref>{0, this}; } |
118 | |
119 | inline ArrayIterator<C, Ref> end() const { return ArrayIterator<C, Ref>{num_data_, this}; } |
120 | |
121 | private: |
122 | const C* data_; |
123 | size_t num_data_; |
124 | }; |
125 | |
126 | /*! \brief A specialization of ArrayAccessor for String. |
127 | * This class is needed because the String constructor signature is different from the typical |
128 | * Metadata subclass. |
129 | */ |
130 | template <> |
131 | class ArrayAccessor<const char*, ::tvm::runtime::String> { |
132 | public: |
133 | using value_type = ::tvm::runtime::String; |
134 | using iterator = ArrayIterator<const char*, ::tvm::runtime::String>; |
135 | using const_iterator = iterator; |
136 | |
137 | ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {} |
138 | |
139 | inline size_t size() const { return num_data_; } |
140 | |
141 | inline ::tvm::runtime::String operator[](size_t index) const { |
142 | if (index >= num_data_) { |
143 | throw std::runtime_error("Index out of range" ); |
144 | } |
145 | return ::tvm::runtime::String(data_[index]); |
146 | } |
147 | |
148 | inline ArrayIterator<const char*, ::tvm::runtime::String> begin() const { |
149 | return ArrayIterator<const char*, ::tvm::runtime::String>{0, this}; |
150 | } |
151 | |
152 | inline ArrayIterator<const char*, ::tvm::runtime::String> end() const { |
153 | return ArrayIterator<const char*, ::tvm::runtime::String>{num_data_, this}; |
154 | } |
155 | |
156 | private: |
157 | const char** data_; |
158 | size_t num_data_; |
159 | }; |
160 | |
/*!
 * \brief The primitive kinds that a member of a Metadata instance may have.
 *
 * TIR DataType cannot be reused here because TIR does not model structs. The enum is stored in a
 * single byte (uint8_t), and every enumerator carries an explicit value.
 */
enum MetadataKind : uint8_t {
  kUint64 = 0,
  kInt64 = 1,
  kBool = 2,
  kString = 3,
  kHandle = 4,
  kMetadata = 5
};
173 | |
174 | /*! \brief Container for arrays in the metadata. |
175 | * |
176 | * Type information is needed when emitting arrays. This container augments the data field with |
177 | * the necessary typing information. |
178 | */ |
179 | class MetadataArrayNode : public MetadataBaseNode { |
180 | public: |
181 | MetadataArrayNode(Array<ObjectRef> array, MetadataKind kind, const char* type_key) |
182 | : array(::std::move(array)), kind{kind}, type_key{type_key} {} |
183 | |
184 | const char* get_c_struct_name() const final; |
185 | |
186 | std::string get_element_c_struct_name() const { |
187 | CHECK(kind == MetadataKind::kMetadata) |
188 | << "cannot get struct name for MetadataArray with kind=" << kind; |
189 | constexpr int prefix_size = sizeof("metadata." ) - 1; |
190 | constexpr int suffix_size = sizeof("Node" ) - 1; |
191 | std::string type_key_str(type_key); |
192 | return std::string("TVM" ) + |
193 | type_key_str.substr(prefix_size, type_key_str.size() - prefix_size - suffix_size); |
194 | } |
195 | |
196 | Array<ObjectRef> array; |
197 | |
198 | /*! \brief Describes the storage class of the emitted struct member. */ |
199 | MetadataKind kind; |
200 | |
201 | /*! \brief When `kind` is Metadata, type_key of the MetadataBaseNode used with this array. */ |
202 | const char* type_key; |
203 | |
204 | static constexpr const char* _type_key = "metadata.MetadataArrayNode" ; |
205 | TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); |
206 | }; |
207 | |
208 | /*! \brief Reference class for MetadataArray. */ |
209 | class MetadataArray : public MetadataBase { |
210 | public: |
211 | MetadataArray(Array<ObjectRef> array, MetadataKind kind, const char* struct_name); |
212 | |
213 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); |
214 | }; |
215 | |
216 | } // namespace metadata |
217 | } // namespace runtime |
218 | } // namespace tvm |
219 | |
220 | #endif // TVM_RUNTIME_METADATA_BASE_H_ |
221 | |