1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_params.cc
22 */
23#ifdef TVM_LLVM_VERSION
24
25#include "codegen_params.h"
26
27#include <llvm/ADT/ArrayRef.h>
28#include <llvm/IR/Constants.h>
29#include <llvm/IR/DerivedTypes.h>
30#include <llvm/IR/LLVMContext.h>
31#include <llvm/Support/Casting.h>
32
33#include <algorithm>
34#include <type_traits>
35#include <vector>
36
37namespace tvm {
38namespace codegen {
39
40template <typename T, typename E = void>
41struct LLVMConstantGetter {
42 static llvm::Constant* getElement(llvm::Type* ty, T t);
43};
44
45template <typename T>
46struct LLVMConstantGetter<
47 T, std::enable_if_t<(std::is_integral<T>::value && std::is_signed<T>::value)>> {
48 static llvm::Constant* getElement(llvm::Type* ty, T t) {
49 return llvm::ConstantInt::getSigned(ty, t);
50 }
51};
52
53template <typename T>
54struct LLVMConstantGetter<
55 T, std::enable_if_t<(std::is_integral<T>::value && !std::is_signed<T>::value)>> {
56 static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantInt::get(ty, t); }
57};
58
59template <typename T>
60struct LLVMConstantGetter<T, std::enable_if_t<std::is_floating_point<T>::value>> {
61 static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantFP::get(ty, t); }
62};
63
64template <typename T, typename = std::enable_if<std::is_pod<T>::value>>
65void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t num_elements,
66 std::vector<llvm::Constant*>* elements) {
67 elements->resize(num_elements, nullptr);
68 std::transform(static_cast<T*>(tensor_data), static_cast<T*>(tensor_data) + num_elements,
69 elements->begin(),
70 [&](T t) { return LLVMConstantGetter<T>::getElement(element_type, t); });
71}
72
73llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::NDArray arr) {
74 llvm::Type* element_type = nullptr;
75
76 auto arr_type = arr.DataType();
77 CHECK(arr.IsContiguous()) << "CodegenParams: only support contiguous arrays";
78 CHECK_EQ(arr->device.device_type, kDLCPU) << "CodegenParams: only support contiguous arrays";
79 CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw "
80 << arr_type.lanes();
81
82 auto shape = arr.Shape();
83 int num_elements = 1;
84 for (auto shape_elem : shape) {
85 num_elements *= shape_elem;
86 }
87
88 std::vector<llvm::Constant*> elements;
89
90 switch (arr_type.code()) {
91 case runtime::DataType::kInt:
92 CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 ||
93 arr_type.bits() == 64)
94 << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
95 << arr_type.bits() << "-bit array";
96 element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
97
98 switch (arr_type.bits()) {
99 case 8:
100 BuildLLVMVector<int8_t>(element_type, arr->data, num_elements, &elements);
101 break;
102 case 16:
103 BuildLLVMVector<int16_t>(element_type, arr->data, num_elements, &elements);
104 break;
105 case 32:
106 BuildLLVMVector<int32_t>(element_type, arr->data, num_elements, &elements);
107 break;
108 case 64:
109 BuildLLVMVector<int64_t>(element_type, arr->data, num_elements, &elements);
110 break;
111 default:
112 ICHECK(false) << "should not get here";
113 break;
114 }
115 break;
116
117 case runtime::DataType::TypeCode::kUInt:
118 CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 ||
119 arr_type.bits() == 64)
120 << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
121 << arr_type.bits() << "-bit array";
122 element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
123
124 switch (arr_type.bits()) {
125 case 8:
126 BuildLLVMVector<uint8_t>(element_type, arr->data, num_elements, &elements);
127 break;
128 case 16:
129 BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
130 break;
131 case 32:
132 BuildLLVMVector<uint32_t>(element_type, arr->data, num_elements, &elements);
133 break;
134 case 64:
135 BuildLLVMVector<uint64_t>(element_type, arr->data, num_elements, &elements);
136 break;
137 default:
138 ICHECK(false) << "should not get here";
139 break;
140 }
141 break;
142
143 case runtime::DataType::TypeCode::kFloat:
144 switch (arr_type.bits()) {
145 case 16:
146 // NOTE: float16 is treated as uint16_t.
147 element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
148 BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
149 break;
150 case 32:
151 element_type = llvm::Type::getFloatTy(*ctx);
152 BuildLLVMVector<float>(element_type, arr->data, num_elements, &elements);
153 break;
154 case 64:
155 element_type = llvm::Type::getDoubleTy(*ctx);
156 BuildLLVMVector<double>(element_type, arr->data, num_elements, &elements);
157 break;
158 default:
159 CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw "
160 << arr_type.bits() << "-bit array";
161 break;
162 }
163 break;
164
165 case runtime::DataType::TypeCode::kBFloat:
166 CHECK(arr_type.bits() == 16)
167 << "CodegenParams: only support 16-bit bfloat; saw " << arr_type.bits() << "-bit array";
168 element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
169 BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
170
171 default:
172 CHECK(false) << "Data type not supported";
173 }
174
175 return llvm::cast<llvm::ConstantArray>(llvm::ConstantArray::get(
176 llvm::ArrayType::get(element_type, num_elements), llvm::ArrayRef<llvm::Constant*>(elements)));
177}
178
179} // namespace codegen
180} // namespace tvm
181
182#endif // TVM_LLVM_VERSION
183