1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/container/shape_tuple.h
22 * \brief Runtime ShapeTuple container types.
23 */
24#ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
25#define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
26
27#include <utility>
28#include <vector>
29
30#include "./base.h"
31
32namespace tvm {
33namespace runtime {
34
35/*! \brief An object representing a shape tuple. */
36class ShapeTupleObj : public Object {
37 public:
38 /*! \brief The type of shape index element. */
39 using index_type = int64_t;
40 /*! \brief The pointer to shape tuple data. */
41 index_type* data;
42 /*! \brief The size of the shape tuple object. */
43 uint64_t size;
44
45 static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeShapeTuple;
46 static constexpr const char* _type_key = "runtime.ShapeTuple";
47 TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTupleObj, Object);
48
49 private:
50 /*! \brief ShapeTuple object which is moved from std::vector container. */
51 class FromStd;
52
53 friend class ShapeTuple;
54};
55
56/*! \brief An object representing shape tuple moved from std::vector. */
57class ShapeTupleObj::FromStd : public ShapeTupleObj {
58 public:
59 /*! \brief The type of shape index element. */
60 using index_type = ShapeTupleObj::index_type;
61 /*!
62 * \brief Construct a new FromStd object
63 *
64 * \param other The moved/copied std::vector object
65 *
66 * \note If user passes const reference, it will trigger copy. If it's rvalue,
67 * it will be moved into other.
68 */
69 explicit FromStd(std::vector<index_type> other) : data_container{other} {}
70
71 private:
72 /*! \brief Container that holds the memory. */
73 std::vector<index_type> data_container;
74
75 friend class ShapeTuple;
76};
77
78/*!
79 * \brief Reference to shape tuple objects.
80 */
81class ShapeTuple : public ObjectRef {
82 public:
83 /*! \brief The type of shape index element. */
84 using index_type = ShapeTupleObj::index_type;
85
86 /*!
87 * \brief Construct an empty shape tuple.
88 */
89 ShapeTuple() : ShapeTuple(std::vector<index_type>()) {}
90
91 /*!
92 * \brief Constructor from iterator
93 * \param begin begin of iterator
94 * \param end end of iterator
95 * \tparam IterType The type of iterator
96 */
97 template <typename IterType>
98 ShapeTuple(IterType begin, IterType end) : ShapeTuple(std::vector<index_type>(begin, end)) {}
99
100 /*!
101 * \brief constructor from initializer list
102 * \param shape The initializer list
103 */
104 ShapeTuple(std::initializer_list<index_type> shape) : ShapeTuple(shape.begin(), shape.end()) {}
105
106 /*!
107 * \brief Construct a new ShapeTuple object
108 *
109 * \param shape The moved/copied std::vector object
110 *
111 * \note If user passes const reference, it will trigger copy. If it's rvalue,
112 * it will be moved into other.
113 */
114 ShapeTuple(std::vector<index_type> shape); // NOLINT(*)
115
116 /*!
117 * \brief Return the data pointer
118 *
119 * \return const index_type* data pointer
120 */
121 const index_type* data() const { return get()->data; }
122
123 /*!
124 * \brief Return the size of the shape tuple
125 *
126 * \return size_t shape tuple size
127 */
128 size_t size() const { return get()->size; }
129
130 /*!
131 * \brief Immutably read i-th element from the shape tuple.
132 * \param idx The index
133 * \return the i-th element.
134 */
135 index_type operator[](size_t idx) const {
136 ICHECK(idx < this->size()) << "IndexError: indexing " << idx << " on an array of size "
137 << this->size();
138 return this->data()[idx];
139 }
140
141 /*!
142 * \brief Immutably read i-th element from the shape tuple.
143 * \param idx The index
144 * \return the i-th element.
145 */
146 index_type at(size_t idx) const { return this->operator[](idx); }
147
148 /*! \return Whether shape tuple is empty */
149 bool empty() const { return size() == 0; }
150
151 /*! \return The first element of the shape tuple */
152 index_type front() const { return this->at(0); }
153
154 /*! \return The last element of the shape tuple */
155 index_type back() const { return this->at(this->size() - 1); }
156
157 /*! \return begin iterator */
158 const index_type* begin() const { return get()->data; }
159
160 /*! \return end iterator */
161 const index_type* end() const { return (get()->data + size()); }
162
163 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeTuple, ObjectRef, ShapeTupleObj);
164};
165
166inline ShapeTuple::ShapeTuple(std::vector<index_type> shape) {
167 auto ptr = make_object<ShapeTupleObj::FromStd>(std::move(shape));
168 ptr->size = ptr->data_container.size();
169 ptr->data = ptr->data_container.data();
170 data_ = std::move(ptr);
171}
172
173} // namespace runtime
174
175// expose the functions to the root namespace.
176using runtime::ShapeTuple;
177using runtime::ShapeTupleObj;
178} // namespace tvm
179
180#endif // TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
181