1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/container/adt.h
22 * \brief Runtime ADT container types.
23 */
24#ifndef TVM_RUNTIME_CONTAINER_ADT_H_
25#define TVM_RUNTIME_CONTAINER_ADT_H_
26
#include <initializer_list>
#include <iterator>
#include <utility>
#include <vector>

#include "./base.h"
31
32namespace tvm {
33namespace runtime {
34
35/*! \brief An object representing a structure or enumeration. */
36class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
37 public:
38 /*! \brief The tag representing the constructor used. */
39 int32_t tag;
40 /*! \brief Number of fields in the ADT object. */
41 uint32_t size;
42 // The fields of the structure follows directly in memory.
43
44 static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT;
45 static constexpr const char* _type_key = "runtime.ADT";
46 TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
47
48 private:
49 /*!
50 * \return The number of elements in the array.
51 */
52 size_t GetSize() const { return size; }
53
54 /*!
55 * \brief Initialize the elements in the array.
56 *
57 * \tparam Iterator Iterator type of the array.
58 * \param begin The begin iterator.
59 * \param end The end iterator.
60 */
61 template <typename Iterator>
62 void Init(Iterator begin, Iterator end) {
63 size_t num_elems = std::distance(begin, end);
64 this->size = 0;
65 auto it = begin;
66 for (size_t i = 0; i < num_elems; ++i) {
67 InplaceArrayBase::EmplaceInit(i, *it++);
68 // Only increment size after the initialization succeeds
69 this->size++;
70 }
71 }
72
73 friend class ADT;
74 friend InplaceArrayBase<ADTObj, ObjectRef>;
75};
76
77/*! \brief reference to algebraic data type objects. */
78class ADT : public ObjectRef {
79 public:
80 /*!
81 * \brief construct an ADT object reference.
82 * \param tag The tag of the ADT object.
83 * \param fields The fields of the ADT object.
84 * \return The constructed ADT object reference.
85 */
86 ADT(int32_t tag, std::vector<ObjectRef> fields) : ADT(tag, fields.begin(), fields.end()){};
87
88 /*!
89 * \brief construct an ADT object reference.
90 * \param tag The tag of the ADT object.
91 * \param begin The begin iterator to the start of the fields array.
92 * \param end The end iterator to the end of the fields array.
93 * \return The constructed ADT object reference.
94 */
95 template <typename Iterator>
96 ADT(int32_t tag, Iterator begin, Iterator end) {
97 size_t num_elems = std::distance(begin, end);
98 auto ptr = make_inplace_array_object<ADTObj, ObjectRef>(num_elems);
99 ptr->tag = tag;
100 ptr->Init(begin, end);
101 data_ = std::move(ptr);
102 }
103
104 /*!
105 * \brief construct an ADT object reference.
106 * \param tag The tag of the ADT object.
107 * \param init The initializer list of fields.
108 * \return The constructed ADT object reference.
109 */
110 ADT(int32_t tag, std::initializer_list<ObjectRef> init) : ADT(tag, init.begin(), init.end()){};
111
112 /*!
113 * \brief Access element at index.
114 *
115 * \param idx The array index
116 * \return const ObjectRef
117 */
118 const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); }
119
120 /*!
121 * \brief Return the ADT tag.
122 */
123 int32_t tag() const { return operator->()->tag; }
124
125 /*!
126 * \brief Return the number of fields.
127 */
128 size_t size() const { return operator->()->size; }
129
130 /*!
131 * \brief Construct a tuple object.
132 *
133 * \tparam Args Type params of tuple feilds.
134 * \param args Tuple fields.
135 * \return ADT The tuple object reference.
136 */
137 template <typename... Args>
138 static ADT Tuple(Args&&... args) {
139 return ADT(0, std::forward<Args>(args)...);
140 }
141
142 TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
143};
144} // namespace runtime
145} // namespace tvm
146#endif // TVM_RUNTIME_CONTAINER_ADT_H_
147