1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/support/scalars.cc
22 * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms.
23 */
24
25#include "./scalars.h"
26
27#include "tvm/relay/expr.h"
28#include "tvm/runtime/builtin_fp16.h"
29
30namespace tvm {
31namespace support {
32
33/*! \brief The standard scalar dtypes. */
34static const DataType kInt16 = DataType::Int(16);
35static const DataType kInt32 = DataType::Int(32);
36static const DataType kInt64 = DataType::Int(64);
37static const DataType kFloat16 = DataType::Float(16);
38static const DataType kFloat32 = DataType::Float(32);
39static const DataType kFloat64 = DataType::Float(64);
40static const DataType kBool = DataType::Bool();
41
42bool IsSimpleScalarDtype(DataType dtype) {
43 return dtype == kInt16 || dtype == kInt32 || dtype == kInt64 || dtype == kFloat16 ||
44 dtype == kFloat32 || dtype == kFloat64 || dtype == kBool;
45}
46
47bool IsSimpleScalar(const relay::ConstantNode* constant_node) {
48 return constant_node->is_scalar() && IsSimpleScalarDtype(DataType(constant_node->data->dtype));
49}
50
51runtime::NDArray IntImmToNDArray(const IntImm& int_imm) {
52 DLDevice dev = {DLDeviceType::kDLCPU, 0};
53 auto data = runtime::NDArray::Empty({}, int_imm->dtype, dev);
54 if (int_imm.dtype() == kInt16) {
55 auto* array = reinterpret_cast<int16_t*>(data->data);
56 array[0] = static_cast<int16_t>(int_imm->value);
57 } else if (int_imm.dtype() == kInt32) {
58 auto* array = reinterpret_cast<int32_t*>(data->data);
59 array[0] = static_cast<int32_t>(int_imm->value);
60 } else if (int_imm.dtype() == kInt64) {
61 auto* array = reinterpret_cast<int64_t*>(data->data);
62 array[0] = int_imm->value;
63 } else {
64 LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(int_imm.dtype());
65 }
66 return data;
67}
68
69runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) {
70 DLDevice dev = {DLDeviceType::kDLCPU, 0};
71 auto data = runtime::NDArray::Empty({}, float_imm->dtype, dev);
72 if (float_imm.dtype() == kFloat16) {
73 auto* array = reinterpret_cast<uint16_t*>(data->data);
74 array[0] = __gnu_f2h_ieee(static_cast<float>(float_imm->value));
75 } else if (float_imm.dtype() == kFloat32) {
76 auto* array = reinterpret_cast<float*>(data->data);
77 array[0] = static_cast<float>(float_imm->value);
78 } else if (float_imm.dtype() == kFloat64) {
79 auto* array = reinterpret_cast<double*>(data->data);
80 array[0] = float_imm->value;
81 } else {
82 LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(float_imm.dtype());
83 }
84 return data;
85}
86
87runtime::NDArray BoolToNDArray(bool value) {
88 DLDevice dev = {DLDeviceType::kDLCPU, 0};
89 auto data = runtime::NDArray::Empty({}, kBool, dev);
90 auto array = reinterpret_cast<bool*>(data->data);
91 array[0] = value;
92 return data;
93}
94
95std::string NDArrayScalarToString(const runtime::NDArray& data) {
96 std::ostringstream os;
97 DataType dtype(data->dtype);
98 ICHECK_EQ(data->device.device_type, kDLCPU) << "Scalars must reside on the CPU to be printed";
99 if (dtype == kInt16) {
100 auto value = static_cast<const int16_t*>(data->data)[0];
101 os << value << "i16";
102 } else if (dtype == kInt32) {
103 auto value = static_cast<const int32_t*>(data->data)[0];
104 os << value;
105 } else if (dtype == kInt64) {
106 auto value = static_cast<const int64_t*>(data->data)[0];
107 os << value << "i64";
108 } else if (dtype == kFloat16) {
109 auto value = __gnu_h2f_ieee(static_cast<const uint16_t*>(data->data)[0]);
110 os << value << "f16";
111 } else if (dtype == kFloat32) {
112 auto value = static_cast<const float*>(data->data)[0];
113 os << value << "f";
114 } else if (dtype == kFloat64) {
115 auto value = static_cast<const double*>(data->data)[0];
116 os << value << "f64";
117 } else if (dtype == kBool) {
118 auto value = static_cast<const uint8_t*>(data->data)[0];
119 os << (value ? "True" : "False");
120 } else {
121 LOG(FATAL) << "Unrecognized NDArray scalar dtype: " << DLDataType2String(dtype);
122 }
123 return os.str();
124}
125
126std::string IntImmToString(const IntImm& int_imm) {
127 std::ostringstream os;
128 if (int_imm->dtype == kInt16) {
129 os << int_imm->value << "i16";
130 } else if (int_imm->dtype == kInt32) {
131 os << int_imm->value;
132 } else if (int_imm->dtype == kInt64) {
133 os << int_imm->value << "i64";
134 } else if (int_imm->dtype == kBool) {
135 os << (int_imm->value ? "True" : "False");
136 } else {
137 LOG(FATAL) << "Unrecognised IntImm dtype: " << DLDataType2String(int_imm->dtype);
138 }
139 return os.str();
140}
141
142std::string FloatImmToString(const FloatImm& float_imm) {
143 std::ostringstream os;
144 if (float_imm->dtype == kFloat16) {
145 os << float_imm->value << "f16";
146 } else if (float_imm->dtype == kFloat32) {
147 os << float_imm->value << "f";
148 } else if (float_imm->dtype == kFloat64) {
149 os << float_imm->value << "f64";
150 } else {
151 LOG(FATAL) << "Unrecognised FloatImm dtype: " << DLDataType2String(float_imm->dtype);
152 }
153 return os.str();
154}
155
156IntImm ValueToIntImm(int64_t value, int width) {
157 if (width == 16) {
158 if (value < std::numeric_limits<int16_t>::min() ||
159 value > std::numeric_limits<int16_t>::max()) {
160 return {};
161 }
162 return IntImm(kInt16, value);
163 } else if (width == 32) {
164 if (value < std::numeric_limits<int32_t>::min() ||
165 value > std::numeric_limits<int32_t>::max()) {
166 return {};
167 }
168 return IntImm(kInt32, value);
169 } else if (width == 64) {
170 return IntImm(kInt64, value);
171 } else {
172 LOG(FATAL) << "Unrecognized int scalar width: " << width;
173 }
174}
175
176FloatImm ValueToFloatImm(double value, int width) {
177 if (width == 16) {
178 if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) {
179 return {};
180 }
181 return FloatImm(kFloat16, value);
182 } else if (width == 32) {
183 if (!std::isinf(value) &&
184 (value < -std::numeric_limits<float>::max() || value > std::numeric_limits<float>::max())) {
185 return {};
186 }
187 return FloatImm(kFloat32, value);
188 } else if (width == 64) {
189 return FloatImm(kFloat64, value);
190 } else {
191 LOG(FATAL) << "Unrecognized float scalar width: " << width;
192 }
193}
194
195} // namespace support
196} // namespace tvm
197