1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/container/adt.h |
22 | * \brief Runtime ADT container types. |
23 | */ |
24 | #ifndef TVM_RUNTIME_CONTAINER_ADT_H_ |
25 | #define TVM_RUNTIME_CONTAINER_ADT_H_ |
26 | |
#include <initializer_list>
#include <iterator>
#include <utility>
#include <vector>
29 | |
30 | #include "./base.h" |
31 | |
32 | namespace tvm { |
33 | namespace runtime { |
34 | |
35 | /*! \brief An object representing a structure or enumeration. */ |
36 | class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> { |
37 | public: |
38 | /*! \brief The tag representing the constructor used. */ |
39 | int32_t tag; |
40 | /*! \brief Number of fields in the ADT object. */ |
41 | uint32_t size; |
42 | // The fields of the structure follows directly in memory. |
43 | |
44 | static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT; |
45 | static constexpr const char* _type_key = "runtime.ADT" ; |
46 | TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); |
47 | |
48 | private: |
49 | /*! |
50 | * \return The number of elements in the array. |
51 | */ |
52 | size_t GetSize() const { return size; } |
53 | |
54 | /*! |
55 | * \brief Initialize the elements in the array. |
56 | * |
57 | * \tparam Iterator Iterator type of the array. |
58 | * \param begin The begin iterator. |
59 | * \param end The end iterator. |
60 | */ |
61 | template <typename Iterator> |
62 | void Init(Iterator begin, Iterator end) { |
63 | size_t num_elems = std::distance(begin, end); |
64 | this->size = 0; |
65 | auto it = begin; |
66 | for (size_t i = 0; i < num_elems; ++i) { |
67 | InplaceArrayBase::EmplaceInit(i, *it++); |
68 | // Only increment size after the initialization succeeds |
69 | this->size++; |
70 | } |
71 | } |
72 | |
73 | friend class ADT; |
74 | friend InplaceArrayBase<ADTObj, ObjectRef>; |
75 | }; |
76 | |
77 | /*! \brief reference to algebraic data type objects. */ |
78 | class ADT : public ObjectRef { |
79 | public: |
80 | /*! |
81 | * \brief construct an ADT object reference. |
82 | * \param tag The tag of the ADT object. |
83 | * \param fields The fields of the ADT object. |
84 | * \return The constructed ADT object reference. |
85 | */ |
86 | ADT(int32_t tag, std::vector<ObjectRef> fields) : ADT(tag, fields.begin(), fields.end()){}; |
87 | |
88 | /*! |
89 | * \brief construct an ADT object reference. |
90 | * \param tag The tag of the ADT object. |
91 | * \param begin The begin iterator to the start of the fields array. |
92 | * \param end The end iterator to the end of the fields array. |
93 | * \return The constructed ADT object reference. |
94 | */ |
95 | template <typename Iterator> |
96 | ADT(int32_t tag, Iterator begin, Iterator end) { |
97 | size_t num_elems = std::distance(begin, end); |
98 | auto ptr = make_inplace_array_object<ADTObj, ObjectRef>(num_elems); |
99 | ptr->tag = tag; |
100 | ptr->Init(begin, end); |
101 | data_ = std::move(ptr); |
102 | } |
103 | |
104 | /*! |
105 | * \brief construct an ADT object reference. |
106 | * \param tag The tag of the ADT object. |
107 | * \param init The initializer list of fields. |
108 | * \return The constructed ADT object reference. |
109 | */ |
110 | ADT(int32_t tag, std::initializer_list<ObjectRef> init) : ADT(tag, init.begin(), init.end()){}; |
111 | |
112 | /*! |
113 | * \brief Access element at index. |
114 | * |
115 | * \param idx The array index |
116 | * \return const ObjectRef |
117 | */ |
118 | const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); } |
119 | |
120 | /*! |
121 | * \brief Return the ADT tag. |
122 | */ |
123 | int32_t tag() const { return operator->()->tag; } |
124 | |
125 | /*! |
126 | * \brief Return the number of fields. |
127 | */ |
128 | size_t size() const { return operator->()->size; } |
129 | |
130 | /*! |
131 | * \brief Construct a tuple object. |
132 | * |
133 | * \tparam Args Type params of tuple feilds. |
134 | * \param args Tuple fields. |
135 | * \return ADT The tuple object reference. |
136 | */ |
137 | template <typename... Args> |
138 | static ADT Tuple(Args&&... args) { |
139 | return ADT(0, std::forward<Args>(args)...); |
140 | } |
141 | |
142 | TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); |
143 | }; |
144 | } // namespace runtime |
145 | } // namespace tvm |
146 | #endif // TVM_RUNTIME_CONTAINER_ADT_H_ |
147 | |