1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include "registry.h"

#include <limits>
#include <sstream>
#include <string>

#include <tvm/runtime/registry.h>
22 | |
23 | namespace tvm { |
24 | namespace datatype { |
25 | |
26 | using runtime::TVMArgs; |
27 | using runtime::TVMRetValue; |
28 | |
29 | TVM_REGISTER_GLOBAL("runtime._datatype_register" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
30 | datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int())); |
31 | }); |
32 | |
33 | TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
34 | *ret = datatype::Registry::Global()->GetTypeCode(args[0]); |
35 | }); |
36 | |
37 | TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
38 | *ret = Registry::Global()->GetTypeName(args[0].operator int()); |
39 | }); |
40 | |
41 | TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered" ) |
42 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
43 | *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); |
44 | }); |
45 | |
46 | Registry* Registry::Global() { |
47 | static Registry inst; |
48 | return &inst; |
49 | } |
50 | |
51 | void Registry::Register(const std::string& type_name, uint8_t type_code) { |
52 | ICHECK(type_code >= DataType::kCustomBegin) |
53 | << "Please choose a type code >= DataType::kCustomBegin for custom types" ; |
54 | code_to_name_[type_code] = type_name; |
55 | name_to_code_[type_name] = type_code; |
56 | } |
57 | |
58 | uint8_t Registry::GetTypeCode(const std::string& type_name) { |
59 | ICHECK(name_to_code_.find(type_name) != name_to_code_.end()) |
60 | << "Type name " << type_name << " not registered" ; |
61 | return name_to_code_[type_name]; |
62 | } |
63 | |
64 | std::string Registry::GetTypeName(uint8_t type_code) { |
65 | ICHECK(code_to_name_.find(type_code) != code_to_name_.end()) |
66 | << "Type code " << static_cast<unsigned>(type_code) << " not registered" ; |
67 | return code_to_name_[type_code]; |
68 | } |
69 | |
70 | const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code, |
71 | uint8_t src_type_code) { |
72 | std::ostringstream ss; |
73 | ss << "tvm.datatype.lower." ; |
74 | ss << target << "." ; |
75 | ss << "Cast" |
76 | << "." ; |
77 | |
78 | if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { |
79 | ss << datatype::Registry::Global()->GetTypeName(type_code); |
80 | } else { |
81 | ss << runtime::DLDataTypeCode2Str(static_cast<DLDataTypeCode>(type_code)); |
82 | } |
83 | |
84 | ss << "." ; |
85 | |
86 | if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) { |
87 | ss << datatype::Registry::Global()->GetTypeName(src_type_code); |
88 | } else { |
89 | ss << runtime::DLDataTypeCode2Str(static_cast<DLDataTypeCode>(src_type_code)); |
90 | } |
91 | return runtime::Registry::Get(ss.str()); |
92 | } |
93 | |
94 | const runtime::PackedFunc* GetMinFunc(uint8_t type_code) { |
95 | std::ostringstream ss; |
96 | ss << "tvm.datatype.min." ; |
97 | ss << datatype::Registry::Global()->GetTypeName(type_code); |
98 | return runtime::Registry::Get(ss.str()); |
99 | } |
100 | |
101 | const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) { |
102 | std::ostringstream ss; |
103 | ss << "tvm.datatype.lower." ; |
104 | ss << target; |
105 | ss << ".FloatImm." ; |
106 | ss << datatype::Registry::Global()->GetTypeName(type_code); |
107 | return runtime::Registry::Get(ss.str()); |
108 | } |
109 | |
110 | const runtime::PackedFunc* GetIntrinLowerFunc(const std::string& target, const std::string& name, |
111 | uint8_t type_code) { |
112 | std::ostringstream ss; |
113 | ss << "tvm.datatype.lower." ; |
114 | ss << target; |
115 | ss << ".Call.intrin." ; |
116 | ss << name; |
117 | ss << "." ; |
118 | ss << datatype::Registry::Global()->GetTypeName(type_code); |
119 | return runtime::Registry::Get(ss.str()); |
120 | } |
121 | |
122 | uint64_t ConvertConstScalar(uint8_t type_code, double value) { |
123 | std::ostringstream ss; |
124 | ss << "tvm.datatype.convertconstscalar.float." ; |
125 | ss << datatype::Registry::Global()->GetTypeName(type_code); |
126 | auto make_const_scalar_func = runtime::Registry::Get(ss.str()); |
127 | return (*make_const_scalar_func)(value).operator uint64_t(); |
128 | } |
129 | |
130 | } // namespace datatype |
131 | } // namespace tvm |
132 | |