1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/runtime/container.cc
22 * \brief Implementations of common containers.
23 */
24#include <tvm/runtime/container/adt.h>
25#include <tvm/runtime/container/array.h>
26#include <tvm/runtime/container/closure.h>
27#include <tvm/runtime/container/map.h>
28#include <tvm/runtime/container/shape_tuple.h>
29#include <tvm/runtime/container/string.h>
30#include <tvm/runtime/memory.h>
31#include <tvm/runtime/object.h>
32#include <tvm/runtime/registry.h>
33
34namespace tvm {
35namespace runtime {
36
37// Array
38TVM_REGISTER_OBJECT_TYPE(ArrayNode);
39
40TVM_REGISTER_GLOBAL("runtime.Array").set_body([](TVMArgs args, TVMRetValue* ret) {
41 std::vector<ObjectRef> data;
42 for (int i = 0; i < args.size(); ++i) {
43 if (args[i].type_code() != kTVMNullptr) {
44 data.push_back(args[i].operator ObjectRef());
45 } else {
46 data.push_back(ObjectRef(nullptr));
47 }
48 }
49 *ret = Array<ObjectRef>(data);
50});
51
52TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
53 int64_t i = args[1];
54 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
55 Object* ptr = static_cast<Object*>(args[0].value().v_handle);
56 ICHECK(ptr->IsInstance<ArrayNode>());
57 auto* n = static_cast<const ArrayNode*>(ptr);
58 ICHECK_LT(static_cast<size_t>(i), n->size()) << "out of bound of array";
59 *ret = n->at(i);
60});
61
62TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) {
63 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
64 Object* ptr = static_cast<Object*>(args[0].value().v_handle);
65 ICHECK(ptr->IsInstance<ArrayNode>());
66 *ret = static_cast<int64_t>(static_cast<const ArrayNode*>(ptr)->size());
67});
68
69// ADT
70
71TVM_REGISTER_OBJECT_TYPE(ADTObj);
72
73TVM_REGISTER_GLOBAL("runtime.GetADTTag").set_body([](TVMArgs args, TVMRetValue* rv) {
74 ObjectRef obj = args[0];
75 const auto& adt = Downcast<ADT>(obj);
76 *rv = static_cast<int64_t>(adt.tag());
77});
78
79TVM_REGISTER_GLOBAL("runtime.GetADTSize").set_body([](TVMArgs args, TVMRetValue* rv) {
80 ObjectRef obj = args[0];
81 const auto& adt = Downcast<ADT>(obj);
82 *rv = static_cast<int64_t>(adt.size());
83});
84
85TVM_REGISTER_GLOBAL("runtime.GetADTFields").set_body([](TVMArgs args, TVMRetValue* rv) {
86 ObjectRef obj = args[0];
87 int idx = args[1];
88 const auto& adt = Downcast<ADT>(obj);
89 ICHECK_LT(idx, adt.size());
90 *rv = adt[idx];
91});
92
93TVM_REGISTER_GLOBAL("runtime.Tuple").set_body([](TVMArgs args, TVMRetValue* rv) {
94 std::vector<ObjectRef> fields;
95 for (auto i = 0; i < args.size(); ++i) {
96 fields.push_back(args[i]);
97 }
98 *rv = ADT::Tuple(fields);
99});
100
101TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) {
102 int itag = args[0];
103 size_t tag = static_cast<size_t>(itag);
104 std::vector<ObjectRef> fields;
105 for (int i = 1; i < args.size(); i++) {
106 fields.push_back(args[i]);
107 }
108 *rv = ADT(tag, fields);
109});
110
111// String
112TVM_REGISTER_OBJECT_TYPE(StringObj);
113
114TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) {
115 return String(std::move(str));
116});
117
118TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) {
119 return std::string(str);
120});
121
122// Map
123TVM_REGISTER_OBJECT_TYPE(MapNode);
124
125TVM_REGISTER_GLOBAL("runtime.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
126 ICHECK_EQ(args.size() % 2, 0);
127 std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> data;
128 for (int i = 0; i < args.num_args; i += 2) {
129 ObjectRef k =
130 String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef();
131 ObjectRef v = args[i + 1];
132 data.emplace(std::move(k), std::move(v));
133 }
134 *ret = Map<ObjectRef, ObjectRef>(std::move(data));
135});
136
137TVM_REGISTER_GLOBAL("runtime.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) {
138 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
139 Object* ptr = static_cast<Object*>(args[0].value().v_handle);
140 ICHECK(ptr->IsInstance<MapNode>());
141 auto* n = static_cast<const MapNode*>(ptr);
142 *ret = static_cast<int64_t>(n->size());
143});
144
145TVM_REGISTER_GLOBAL("runtime.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
146 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
147 Object* ptr = static_cast<Object*>(args[0].value().v_handle);
148 ICHECK(ptr->IsInstance<MapNode>());
149
150 auto* n = static_cast<const MapNode*>(ptr);
151 auto it = n->find(String::CanConvertFrom(args[1]) ? args[1].operator String()
152 : args[1].operator ObjectRef());
153 ICHECK(it != n->end()) << "cannot find the corresponding key in the Map";
154 *ret = (*it).second;
155});
156
157TVM_REGISTER_GLOBAL("runtime.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) {
158 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
159 Object* ptr = static_cast<Object*>(args[0].value().v_handle);
160 ICHECK(ptr->IsInstance<MapNode>());
161 const MapNode* n = static_cast<const MapNode*>(ptr);
162 int64_t cnt = n->count(String::CanConvertFrom(args[1]) ? args[1].operator String()
163 : args[1].operator ObjectRef());
164 *ret = cnt;
165});
166
167TVM_REGISTER_GLOBAL("runtime.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) {
168 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
169 Object* ptr = static_cast<Object*>(args[0].value().v_handle);
170 auto* n = static_cast<const MapNode*>(ptr);
171 Array<ObjectRef> rkvs;
172 for (const auto& kv : *n) {
173 if (kv.first->IsInstance<StringObj>()) {
174 rkvs.push_back(Downcast<String>(kv.first));
175 } else {
176 rkvs.push_back(kv.first);
177 }
178 rkvs.push_back(kv.second);
179 }
180 *ret = std::move(rkvs);
181});
182
183// Closure
184TVM_REGISTER_OBJECT_TYPE(ClosureObj);
185
186// ShapeTuple
187TVM_REGISTER_OBJECT_TYPE(ShapeTupleObj);
188
189TVM_REGISTER_GLOBAL("runtime.ShapeTuple").set_body([](TVMArgs args, TVMRetValue* rv) {
190 std::vector<ShapeTuple::index_type> shape;
191 for (int i = 0; i < args.size(); i++) {
192 shape.push_back(args[i]);
193 }
194 *rv = ShapeTuple(shape);
195});
196
197TVM_REGISTER_GLOBAL("runtime.GetShapeTupleSize").set_body_typed([](ShapeTuple shape) {
198 return static_cast<int64_t>(shape.size());
199});
200
201TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple shape, int idx) {
202 ICHECK_LT(idx, shape.size());
203 return shape[idx];
204});
205} // namespace runtime
206} // namespace tvm
207