1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "registry.h"
20
21#include <tvm/runtime/registry.h>
22
23namespace tvm {
24namespace datatype {
25
26using runtime::TVMArgs;
27using runtime::TVMRetValue;
28
29TVM_REGISTER_GLOBAL("runtime._datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) {
30 datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int()));
31});
32
33TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) {
34 *ret = datatype::Registry::Global()->GetTypeCode(args[0]);
35});
36
37TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) {
38 *ret = Registry::Global()->GetTypeName(args[0].operator int());
39});
40
41TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered")
42 .set_body([](TVMArgs args, TVMRetValue* ret) {
43 *ret = Registry::Global()->GetTypeRegistered(args[0].operator int());
44 });
45
46Registry* Registry::Global() {
47 static Registry inst;
48 return &inst;
49}
50
51void Registry::Register(const std::string& type_name, uint8_t type_code) {
52 ICHECK(type_code >= DataType::kCustomBegin)
53 << "Please choose a type code >= DataType::kCustomBegin for custom types";
54 code_to_name_[type_code] = type_name;
55 name_to_code_[type_name] = type_code;
56}
57
58uint8_t Registry::GetTypeCode(const std::string& type_name) {
59 ICHECK(name_to_code_.find(type_name) != name_to_code_.end())
60 << "Type name " << type_name << " not registered";
61 return name_to_code_[type_name];
62}
63
64std::string Registry::GetTypeName(uint8_t type_code) {
65 ICHECK(code_to_name_.find(type_code) != code_to_name_.end())
66 << "Type code " << static_cast<unsigned>(type_code) << " not registered";
67 return code_to_name_[type_code];
68}
69
70const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code,
71 uint8_t src_type_code) {
72 std::ostringstream ss;
73 ss << "tvm.datatype.lower.";
74 ss << target << ".";
75 ss << "Cast"
76 << ".";
77
78 if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
79 ss << datatype::Registry::Global()->GetTypeName(type_code);
80 } else {
81 ss << runtime::DLDataTypeCode2Str(static_cast<DLDataTypeCode>(type_code));
82 }
83
84 ss << ".";
85
86 if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) {
87 ss << datatype::Registry::Global()->GetTypeName(src_type_code);
88 } else {
89 ss << runtime::DLDataTypeCode2Str(static_cast<DLDataTypeCode>(src_type_code));
90 }
91 return runtime::Registry::Get(ss.str());
92}
93
94const runtime::PackedFunc* GetMinFunc(uint8_t type_code) {
95 std::ostringstream ss;
96 ss << "tvm.datatype.min.";
97 ss << datatype::Registry::Global()->GetTypeName(type_code);
98 return runtime::Registry::Get(ss.str());
99}
100
101const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) {
102 std::ostringstream ss;
103 ss << "tvm.datatype.lower.";
104 ss << target;
105 ss << ".FloatImm.";
106 ss << datatype::Registry::Global()->GetTypeName(type_code);
107 return runtime::Registry::Get(ss.str());
108}
109
110const runtime::PackedFunc* GetIntrinLowerFunc(const std::string& target, const std::string& name,
111 uint8_t type_code) {
112 std::ostringstream ss;
113 ss << "tvm.datatype.lower.";
114 ss << target;
115 ss << ".Call.intrin.";
116 ss << name;
117 ss << ".";
118 ss << datatype::Registry::Global()->GetTypeName(type_code);
119 return runtime::Registry::Get(ss.str());
120}
121
122uint64_t ConvertConstScalar(uint8_t type_code, double value) {
123 std::ostringstream ss;
124 ss << "tvm.datatype.convertconstscalar.float.";
125 ss << datatype::Registry::Global()->GetTypeName(type_code);
126 auto make_const_scalar_func = runtime::Registry::Get(ss.str());
127 return (*make_const_scalar_func)(value).operator uint64_t();
128}
129
130} // namespace datatype
131} // namespace tvm
132