1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/support/scalars.cc |
22 | * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms. |
23 | */ |
24 | |
25 | #include "./scalars.h" |
26 | |
27 | #include "tvm/relay/expr.h" |
28 | #include "tvm/runtime/builtin_fp16.h" |
29 | |
30 | namespace tvm { |
31 | namespace support { |
32 | |
33 | /*! \brief The standard scalar dtypes. */ |
34 | static const DataType kInt16 = DataType::Int(16); |
35 | static const DataType kInt32 = DataType::Int(32); |
36 | static const DataType kInt64 = DataType::Int(64); |
37 | static const DataType kFloat16 = DataType::Float(16); |
38 | static const DataType kFloat32 = DataType::Float(32); |
39 | static const DataType kFloat64 = DataType::Float(64); |
40 | static const DataType kBool = DataType::Bool(); |
41 | |
42 | bool IsSimpleScalarDtype(DataType dtype) { |
43 | return dtype == kInt16 || dtype == kInt32 || dtype == kInt64 || dtype == kFloat16 || |
44 | dtype == kFloat32 || dtype == kFloat64 || dtype == kBool; |
45 | } |
46 | |
47 | bool IsSimpleScalar(const relay::ConstantNode* constant_node) { |
48 | return constant_node->is_scalar() && IsSimpleScalarDtype(DataType(constant_node->data->dtype)); |
49 | } |
50 | |
51 | runtime::NDArray IntImmToNDArray(const IntImm& int_imm) { |
52 | DLDevice dev = {DLDeviceType::kDLCPU, 0}; |
53 | auto data = runtime::NDArray::Empty({}, int_imm->dtype, dev); |
54 | if (int_imm.dtype() == kInt16) { |
55 | auto* array = reinterpret_cast<int16_t*>(data->data); |
56 | array[0] = static_cast<int16_t>(int_imm->value); |
57 | } else if (int_imm.dtype() == kInt32) { |
58 | auto* array = reinterpret_cast<int32_t*>(data->data); |
59 | array[0] = static_cast<int32_t>(int_imm->value); |
60 | } else if (int_imm.dtype() == kInt64) { |
61 | auto* array = reinterpret_cast<int64_t*>(data->data); |
62 | array[0] = int_imm->value; |
63 | } else { |
64 | LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(int_imm.dtype()); |
65 | } |
66 | return data; |
67 | } |
68 | |
69 | runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) { |
70 | DLDevice dev = {DLDeviceType::kDLCPU, 0}; |
71 | auto data = runtime::NDArray::Empty({}, float_imm->dtype, dev); |
72 | if (float_imm.dtype() == kFloat16) { |
73 | auto* array = reinterpret_cast<uint16_t*>(data->data); |
74 | array[0] = __gnu_f2h_ieee(static_cast<float>(float_imm->value)); |
75 | } else if (float_imm.dtype() == kFloat32) { |
76 | auto* array = reinterpret_cast<float*>(data->data); |
77 | array[0] = static_cast<float>(float_imm->value); |
78 | } else if (float_imm.dtype() == kFloat64) { |
79 | auto* array = reinterpret_cast<double*>(data->data); |
80 | array[0] = float_imm->value; |
81 | } else { |
82 | LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(float_imm.dtype()); |
83 | } |
84 | return data; |
85 | } |
86 | |
87 | runtime::NDArray BoolToNDArray(bool value) { |
88 | DLDevice dev = {DLDeviceType::kDLCPU, 0}; |
89 | auto data = runtime::NDArray::Empty({}, kBool, dev); |
90 | auto array = reinterpret_cast<bool*>(data->data); |
91 | array[0] = value; |
92 | return data; |
93 | } |
94 | |
95 | std::string NDArrayScalarToString(const runtime::NDArray& data) { |
96 | std::ostringstream os; |
97 | DataType dtype(data->dtype); |
98 | ICHECK_EQ(data->device.device_type, kDLCPU) << "Scalars must reside on the CPU to be printed" ; |
99 | if (dtype == kInt16) { |
100 | auto value = static_cast<const int16_t*>(data->data)[0]; |
101 | os << value << "i16" ; |
102 | } else if (dtype == kInt32) { |
103 | auto value = static_cast<const int32_t*>(data->data)[0]; |
104 | os << value; |
105 | } else if (dtype == kInt64) { |
106 | auto value = static_cast<const int64_t*>(data->data)[0]; |
107 | os << value << "i64" ; |
108 | } else if (dtype == kFloat16) { |
109 | auto value = __gnu_h2f_ieee(static_cast<const uint16_t*>(data->data)[0]); |
110 | os << value << "f16" ; |
111 | } else if (dtype == kFloat32) { |
112 | auto value = static_cast<const float*>(data->data)[0]; |
113 | os << value << "f" ; |
114 | } else if (dtype == kFloat64) { |
115 | auto value = static_cast<const double*>(data->data)[0]; |
116 | os << value << "f64" ; |
117 | } else if (dtype == kBool) { |
118 | auto value = static_cast<const uint8_t*>(data->data)[0]; |
119 | os << (value ? "True" : "False" ); |
120 | } else { |
121 | LOG(FATAL) << "Unrecognized NDArray scalar dtype: " << DLDataType2String(dtype); |
122 | } |
123 | return os.str(); |
124 | } |
125 | |
126 | std::string IntImmToString(const IntImm& int_imm) { |
127 | std::ostringstream os; |
128 | if (int_imm->dtype == kInt16) { |
129 | os << int_imm->value << "i16" ; |
130 | } else if (int_imm->dtype == kInt32) { |
131 | os << int_imm->value; |
132 | } else if (int_imm->dtype == kInt64) { |
133 | os << int_imm->value << "i64" ; |
134 | } else if (int_imm->dtype == kBool) { |
135 | os << (int_imm->value ? "True" : "False" ); |
136 | } else { |
137 | LOG(FATAL) << "Unrecognised IntImm dtype: " << DLDataType2String(int_imm->dtype); |
138 | } |
139 | return os.str(); |
140 | } |
141 | |
142 | std::string FloatImmToString(const FloatImm& float_imm) { |
143 | std::ostringstream os; |
144 | if (float_imm->dtype == kFloat16) { |
145 | os << float_imm->value << "f16" ; |
146 | } else if (float_imm->dtype == kFloat32) { |
147 | os << float_imm->value << "f" ; |
148 | } else if (float_imm->dtype == kFloat64) { |
149 | os << float_imm->value << "f64" ; |
150 | } else { |
151 | LOG(FATAL) << "Unrecognised FloatImm dtype: " << DLDataType2String(float_imm->dtype); |
152 | } |
153 | return os.str(); |
154 | } |
155 | |
156 | IntImm ValueToIntImm(int64_t value, int width) { |
157 | if (width == 16) { |
158 | if (value < std::numeric_limits<int16_t>::min() || |
159 | value > std::numeric_limits<int16_t>::max()) { |
160 | return {}; |
161 | } |
162 | return IntImm(kInt16, value); |
163 | } else if (width == 32) { |
164 | if (value < std::numeric_limits<int32_t>::min() || |
165 | value > std::numeric_limits<int32_t>::max()) { |
166 | return {}; |
167 | } |
168 | return IntImm(kInt32, value); |
169 | } else if (width == 64) { |
170 | return IntImm(kInt64, value); |
171 | } else { |
172 | LOG(FATAL) << "Unrecognized int scalar width: " << width; |
173 | } |
174 | } |
175 | |
176 | FloatImm ValueToFloatImm(double value, int width) { |
177 | if (width == 16) { |
178 | if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) { |
179 | return {}; |
180 | } |
181 | return FloatImm(kFloat16, value); |
182 | } else if (width == 32) { |
183 | if (!std::isinf(value) && |
184 | (value < -std::numeric_limits<float>::max() || value > std::numeric_limits<float>::max())) { |
185 | return {}; |
186 | } |
187 | return FloatImm(kFloat32, value); |
188 | } else if (width == 64) { |
189 | return FloatImm(kFloat64, value); |
190 | } else { |
191 | LOG(FATAL) << "Unrecognized float scalar width: " << width; |
192 | } |
193 | } |
194 | |
195 | } // namespace support |
196 | } // namespace tvm |
197 | |