1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/metadata_base.h
22 * \brief Defines types which can be used in Metadata.
23 */
24#ifndef TVM_RUNTIME_METADATA_BASE_H_
25#define TVM_RUNTIME_METADATA_BASE_H_
26
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
37
38namespace tvm {
39namespace runtime {
40namespace metadata {
41
42/*!
43 * \brief Common base class for all Metadata.
44 *
45 * This class is used in the visitor classes as a internal check to ensure that verify that all
46 * parts of the Metadata struct used in codegen are Metadata objects.
47 */
48class MetadataBaseNode : public ::tvm::runtime::Object {
49 public:
50 virtual const char* get_c_struct_name() const = 0;
51
52 static constexpr const char* _type_key = "metadata.MetadataBaseNode";
53 TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
54};
55
56/*! \brief Reference class for the common MetadataBaseNode class. */
57class MetadataBase : public ::tvm::runtime::ObjectRef {
58 public:
59 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode);
60};
61
// Forward declaration: ArrayIterator keeps a pointer back to the accessor it walks over.
template <typename C, class Ref>
class ArrayAccessor;

/*!
 * \brief An iterator implementation that lazily instantiates the C++ wrapping Metadata class.
 *
 * The iterator is a (parent accessor, index) pair. Dereferencing forwards to the parent's
 * operator[] and therefore yields a freshly built Ref by value rather than a reference to a
 * stored element. For that reason it is advertised as an input iterator.
 *
 * \tparam C Element type of the array wrapped by the parent ArrayAccessor<C, Ref>.
 * \tparam Ref Wrapper type produced on dereference.
 */
template <typename C, class Ref>
class ArrayIterator {
 public:
  // Member types consumed by std::iterator_traits, so the iterator can be handed to standard
  // algorithms and container range constructors.
  using iterator_category = std::input_iterator_tag;
  using value_type = Ref;
  using difference_type = std::ptrdiff_t;
  using pointer = void;   // no operator-> is provided
  using reference = Ref;  // operator* returns by value

  /*!
   * \brief Create an iterator positioned at `index` within `parent`.
   * \param index Position in the parent accessor; `parent->size()` denotes the end.
   * \param parent Accessor being iterated. It is not owned and must outlive the iterator.
   */
  ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent)
      : index_{index}, parent_{parent} {}

  /*! \brief Build and return the wrapper for the current element (see the parent's operator[]). */
  inline Ref operator*() const { return (*parent_)[index_]; }

  /*!
   * \brief Advance to the next element.
   *
   * The index saturates at the parent's size, so incrementing an end iterator leaves it unchanged.
   */
  inline ArrayIterator<C, Ref>& operator++() {
    if (index_ < parent_->size()) {
      ++index_;
    }

    return *this;
  }

  /*! \brief Post-increment: advance as operator++() does and return the previous position. */
  inline ArrayIterator<C, Ref> operator++(int) {
    ArrayIterator<C, Ref> previous = *this;
    ++(*this);
    return previous;
  }

  /*! \brief Two iterators are equal when they refer to the same parent and the same index. */
  inline bool operator==(const ArrayIterator<C, Ref>& other) const {
    return parent_ == other.parent_ && index_ == other.index_;
  }

  inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); }

 private:
  /*! \brief Current position within the parent accessor. */
  size_t index_;
  /*! \brief Non-owning pointer to the accessor being iterated. */
  const ArrayAccessor<C, Ref>* parent_;
};
92
93/*! \brief A span-like class which permits access to Array fields with complex elements.
94 * These array fields should be accessed from C++ using the Metadata wrapper classes. This class
95 * lazily instantiates those wrappers as they are accessed.
96 */
97template <typename C, class Ref>
98class ArrayAccessor {
99 public:
100 using value_type = Ref;
101 using iterator = ArrayIterator<C, Ref>;
102 using const_iterator = iterator;
103
104 template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type>
105 ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {}
106
107 inline size_t size() const { return num_data_; }
108
109 inline Ref operator[](size_t index) const {
110 if (index >= num_data_) {
111 throw std::runtime_error("Index out of range");
112 }
113
114 return Ref(&data_[index]);
115 }
116
117 inline ArrayIterator<C, Ref> begin() const { return ArrayIterator<C, Ref>{0, this}; }
118
119 inline ArrayIterator<C, Ref> end() const { return ArrayIterator<C, Ref>{num_data_, this}; }
120
121 private:
122 const C* data_;
123 size_t num_data_;
124};
125
126/*! \brief A specialization of ArrayAccessor for String.
127 * This class is needed because the String constructor signature is different from the typical
128 * Metadata subclass.
129 */
130template <>
131class ArrayAccessor<const char*, ::tvm::runtime::String> {
132 public:
133 using value_type = ::tvm::runtime::String;
134 using iterator = ArrayIterator<const char*, ::tvm::runtime::String>;
135 using const_iterator = iterator;
136
137 ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {}
138
139 inline size_t size() const { return num_data_; }
140
141 inline ::tvm::runtime::String operator[](size_t index) const {
142 if (index >= num_data_) {
143 throw std::runtime_error("Index out of range");
144 }
145 return ::tvm::runtime::String(data_[index]);
146 }
147
148 inline ArrayIterator<const char*, ::tvm::runtime::String> begin() const {
149 return ArrayIterator<const char*, ::tvm::runtime::String>{0, this};
150 }
151
152 inline ArrayIterator<const char*, ::tvm::runtime::String> end() const {
153 return ArrayIterator<const char*, ::tvm::runtime::String>{num_data_, this};
154 }
155
156 private:
157 const char** data_;
158 size_t num_data_;
159};
160
/*!
 * \brief Enumerates the primitive types which can be part of a Metadata instance.
 *
 * These are separate from TIR DataType because TIR does not model structs.
 *
 * The enumerator values are spelled out explicitly and the underlying type is fixed to uint8_t.
 */
enum MetadataKind : uint8_t {
  kUint64 = 0,
  kInt64 = 1,
  kBool = 2,
  kString = 3,
  kHandle = 4,
  /*! \brief The only kind for which MetadataArrayNode reports an element struct name. */
  kMetadata = 5,
};
173
174/*! \brief Container for arrays in the metadata.
175 *
176 * Type information is needed when emitting arrays. This container augments the data field with
177 * the necessary typing information.
178 */
179class MetadataArrayNode : public MetadataBaseNode {
180 public:
181 MetadataArrayNode(Array<ObjectRef> array, MetadataKind kind, const char* type_key)
182 : array(::std::move(array)), kind{kind}, type_key{type_key} {}
183
184 const char* get_c_struct_name() const final;
185
186 std::string get_element_c_struct_name() const {
187 CHECK(kind == MetadataKind::kMetadata)
188 << "cannot get struct name for MetadataArray with kind=" << kind;
189 constexpr int prefix_size = sizeof("metadata.") - 1;
190 constexpr int suffix_size = sizeof("Node") - 1;
191 std::string type_key_str(type_key);
192 return std::string("TVM") +
193 type_key_str.substr(prefix_size, type_key_str.size() - prefix_size - suffix_size);
194 }
195
196 Array<ObjectRef> array;
197
198 /*! \brief Describes the storage class of the emitted struct member. */
199 MetadataKind kind;
200
201 /*! \brief When `kind` is Metadata, type_key of the MetadataBaseNode used with this array. */
202 const char* type_key;
203
204 static constexpr const char* _type_key = "metadata.MetadataArrayNode";
205 TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
206};
207
208/*! \brief Reference class for MetadataArray. */
209class MetadataArray : public MetadataBase {
210 public:
211 MetadataArray(Array<ObjectRef> array, MetadataKind kind, const char* struct_name);
212
213 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
214};
215
216} // namespace metadata
217} // namespace runtime
218} // namespace tvm
219
220#endif // TVM_RUNTIME_METADATA_BASE_H_
221