1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_params.cc |
22 | */ |
23 | #ifdef TVM_LLVM_VERSION |
24 | |
25 | #include "codegen_params.h" |
26 | |
27 | #include <llvm/ADT/ArrayRef.h> |
28 | #include <llvm/IR/Constants.h> |
29 | #include <llvm/IR/DerivedTypes.h> |
30 | #include <llvm/IR/LLVMContext.h> |
31 | #include <llvm/Support/Casting.h> |
32 | |
33 | #include <algorithm> |
34 | #include <type_traits> |
35 | #include <vector> |
36 | |
37 | namespace tvm { |
38 | namespace codegen { |
39 | |
40 | template <typename T, typename E = void> |
41 | struct LLVMConstantGetter { |
42 | static llvm::Constant* getElement(llvm::Type* ty, T t); |
43 | }; |
44 | |
45 | template <typename T> |
46 | struct LLVMConstantGetter< |
47 | T, std::enable_if_t<(std::is_integral<T>::value && std::is_signed<T>::value)>> { |
48 | static llvm::Constant* getElement(llvm::Type* ty, T t) { |
49 | return llvm::ConstantInt::getSigned(ty, t); |
50 | } |
51 | }; |
52 | |
53 | template <typename T> |
54 | struct LLVMConstantGetter< |
55 | T, std::enable_if_t<(std::is_integral<T>::value && !std::is_signed<T>::value)>> { |
56 | static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantInt::get(ty, t); } |
57 | }; |
58 | |
59 | template <typename T> |
60 | struct LLVMConstantGetter<T, std::enable_if_t<std::is_floating_point<T>::value>> { |
61 | static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantFP::get(ty, t); } |
62 | }; |
63 | |
64 | template <typename T, typename = std::enable_if<std::is_pod<T>::value>> |
65 | void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t num_elements, |
66 | std::vector<llvm::Constant*>* elements) { |
67 | elements->resize(num_elements, nullptr); |
68 | std::transform(static_cast<T*>(tensor_data), static_cast<T*>(tensor_data) + num_elements, |
69 | elements->begin(), |
70 | [&](T t) { return LLVMConstantGetter<T>::getElement(element_type, t); }); |
71 | } |
72 | |
73 | llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::NDArray arr) { |
74 | llvm::Type* element_type = nullptr; |
75 | |
76 | auto arr_type = arr.DataType(); |
77 | CHECK(arr.IsContiguous()) << "CodegenParams: only support contiguous arrays" ; |
78 | CHECK_EQ(arr->device.device_type, kDLCPU) << "CodegenParams: only support contiguous arrays" ; |
79 | CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw " |
80 | << arr_type.lanes(); |
81 | |
82 | auto shape = arr.Shape(); |
83 | int num_elements = 1; |
84 | for (auto shape_elem : shape) { |
85 | num_elements *= shape_elem; |
86 | } |
87 | |
88 | std::vector<llvm::Constant*> elements; |
89 | |
90 | switch (arr_type.code()) { |
91 | case runtime::DataType::kInt: |
92 | CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || |
93 | arr_type.bits() == 64) |
94 | << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " |
95 | << arr_type.bits() << "-bit array" ; |
96 | element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits()); |
97 | |
98 | switch (arr_type.bits()) { |
99 | case 8: |
100 | BuildLLVMVector<int8_t>(element_type, arr->data, num_elements, &elements); |
101 | break; |
102 | case 16: |
103 | BuildLLVMVector<int16_t>(element_type, arr->data, num_elements, &elements); |
104 | break; |
105 | case 32: |
106 | BuildLLVMVector<int32_t>(element_type, arr->data, num_elements, &elements); |
107 | break; |
108 | case 64: |
109 | BuildLLVMVector<int64_t>(element_type, arr->data, num_elements, &elements); |
110 | break; |
111 | default: |
112 | ICHECK(false) << "should not get here" ; |
113 | break; |
114 | } |
115 | break; |
116 | |
117 | case runtime::DataType::TypeCode::kUInt: |
118 | CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || |
119 | arr_type.bits() == 64) |
120 | << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " |
121 | << arr_type.bits() << "-bit array" ; |
122 | element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits()); |
123 | |
124 | switch (arr_type.bits()) { |
125 | case 8: |
126 | BuildLLVMVector<uint8_t>(element_type, arr->data, num_elements, &elements); |
127 | break; |
128 | case 16: |
129 | BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements); |
130 | break; |
131 | case 32: |
132 | BuildLLVMVector<uint32_t>(element_type, arr->data, num_elements, &elements); |
133 | break; |
134 | case 64: |
135 | BuildLLVMVector<uint64_t>(element_type, arr->data, num_elements, &elements); |
136 | break; |
137 | default: |
138 | ICHECK(false) << "should not get here" ; |
139 | break; |
140 | } |
141 | break; |
142 | |
143 | case runtime::DataType::TypeCode::kFloat: |
144 | switch (arr_type.bits()) { |
145 | case 16: |
146 | // NOTE: float16 is treated as uint16_t. |
147 | element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits()); |
148 | BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements); |
149 | break; |
150 | case 32: |
151 | element_type = llvm::Type::getFloatTy(*ctx); |
152 | BuildLLVMVector<float>(element_type, arr->data, num_elements, &elements); |
153 | break; |
154 | case 64: |
155 | element_type = llvm::Type::getDoubleTy(*ctx); |
156 | BuildLLVMVector<double>(element_type, arr->data, num_elements, &elements); |
157 | break; |
158 | default: |
159 | CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw " |
160 | << arr_type.bits() << "-bit array" ; |
161 | break; |
162 | } |
163 | break; |
164 | |
165 | case runtime::DataType::TypeCode::kBFloat: |
166 | CHECK(arr_type.bits() == 16) |
167 | << "CodegenParams: only support 16-bit bfloat; saw " << arr_type.bits() << "-bit array" ; |
168 | element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits()); |
169 | BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements); |
170 | |
171 | default: |
172 | CHECK(false) << "Data type not supported" ; |
173 | } |
174 | |
175 | return llvm::cast<llvm::ConstantArray>(llvm::ConstantArray::get( |
176 | llvm::ArrayType::get(element_type, num_elements), llvm::ArrayRef<llvm::Constant*>(elements))); |
177 | } |
178 | |
179 | } // namespace codegen |
180 | } // namespace tvm |
181 | |
182 | #endif // TVM_LLVM_VERSION |
183 | |