1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_x86_64.cc
22 * \brief X86-64 specific code generator
23 */
24#ifdef TVM_LLVM_VERSION
25
26#include <llvm/IR/DerivedTypes.h>
27#include <llvm/IR/Function.h>
28#include <llvm/IR/Intrinsics.h>
29#if TVM_LLVM_VERSION >= 100
30#include <llvm/IR/IntrinsicsX86.h>
31#endif
32#include <llvm/MC/MCSubtargetInfo.h>
33#include <llvm/Support/Casting.h>
34#include <llvm/Target/TargetMachine.h>
35#include <tvm/runtime/registry.h>
36
37#include <string>
38#include <vector>
39
40#include "codegen_cpu.h"
41#include "llvm_instance.h"
42
43namespace tvm {
44namespace codegen {
45
46namespace {
47bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) {
48 // MCSubTargetInfo::checkFeatures was added in LLVM 6.0
49#if TVM_LLVM_VERSION >= 60
50 const auto* MCInfo = tm.getMCSubtargetInfo();
51 return MCInfo->checkFeatures(std::string("+") + feature);
52#else
53 return false;
54 // TODO(tulloch) - enable this block, need to figure out how to reimplement
55 // this given visibility constraints, similar to
56 // https://github.com/rust-lang/rust/pull/31709
57
58 // Copied from
59 // https://github.com/llvm-mirror/llvm/blob/5136df4/lib/MC/MCSubtargetInfo.cpp#L78-L88.
60
61 // auto checkFeatures = [&](const std::string FS) {
62 // llvm::SubtargetFeatures T(FS);
63 // llvm::FeatureBitset Set, All;
64 // for (std::string F : T.getFeatures()) {
65 // llvm::SubtargetFeatures::ApplyFeatureFlag(Set, F, MCInfo->ProcFeatures);
66 // if (F[0] == '-') {
67 // F[0] = '+';
68 // }
69 // llvm::SubtargetFeatures::ApplyFeatureFlag(All, F, MCInfo->ProcFeatures);
70 // }
71 // return (MCInfo->getFeatureBits() & All) == Set;
72 // };
73 // return checkFeatures(MCInfo, std::string("+") + feature);
74#endif
75}
76} // namespace
77
78class CodeGenX86_64 final : public CodeGenCPU {
79 public:
80 llvm::Value* VisitExpr_(const CastNode* op) override;
81
82 private:
83 llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
84 const std::vector<llvm::Value*>& args);
85};
86
87llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
88 // LLVM does not automatically generate the correct instruction sequences for
89 // half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
90 // vcvtph2ps), so we explicitly generate them ourselves.
91 const auto from = op->value.dtype();
92 const auto to = op->dtype;
93 if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
94 ICHECK_EQ(from.lanes(), to.lanes());
95 llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
96
97 const auto has_avx512 = TargetHasFeature(*tm, "avx512f");
98
99 if (from.lanes() >= 16 && has_avx512) {
100 return CallVectorIntrin(
101 llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
102 DTypeToLLVMType(DataType::Float(32, from.lanes())),
103 {
104 MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(),
105 {op->value})),
106 MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())),
107 /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
108 /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
109 });
110 }
111
112#if TVM_LLVM_VERSION <= 100
113 // The intrinsic x86_vcvtph2ps_256 was removed in LLVM 11.
114 const auto has_f16c = TargetHasFeature(*tm, "f16c");
115
116 if (from.lanes() >= 8 && has_f16c) {
117 return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8,
118 DTypeToLLVMType(DataType::Float(32, from.lanes())),
119 {MakeValue(tir::Call(DataType::Int(16, from.lanes()),
120 tir::builtin::reinterpret(), {op->value}))});
121 }
122#endif
123 }
124
125 return CodeGenCPU::VisitExpr_(op);
126}
127
128llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes,
129 llvm::Type* result_ty,
130 const std::vector<llvm::Value*>& args) {
131 llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
132#if TVM_LLVM_VERSION >= 120
133 size_t num_elems = llvm::cast<llvm::FixedVectorType>(result_ty)->getNumElements();
134#else
135 size_t num_elems = llvm::cast<llvm::VectorType>(result_ty)->getNumElements();
136#endif
137 if (intrin_lanes == num_elems) {
138 return builder_->CreateCall(f, args);
139 }
140
141 // Otherwise, we split the vector into intrin_lanes sized elements (widening where necessary),
142 // compute each result, and then concatenate the vectors (slicing the result if necessary).
143 ICHECK_LT(intrin_lanes, num_elems);
144 std::vector<llvm::Value*> split_results;
145 for (size_t i = 0; i < num_elems; i += intrin_lanes) {
146 std::vector<llvm::Value*> split_args;
147 for (const auto& v : args) {
148 if (v->getType()->isVectorTy()) {
149 ICHECK_EQ(GetVectorNumElements(v), num_elems);
150 split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
151 } else {
152 split_args.push_back(v);
153 }
154 }
155#if TVM_LLVM_VERSION >= 110
156 llvm::Type* type = llvm::FixedVectorType::get(result_ty->getScalarType(), intrin_lanes);
157#else
158 llvm::Type* type = llvm::VectorType::get(result_ty->getScalarType(), intrin_lanes);
159#endif
160 split_results.push_back(CallVectorIntrin(id, intrin_lanes, type, split_args));
161 }
162 return CreateVecSlice(CreateVecConcat(split_results), 0, num_elems);
163}
164
165TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
166 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
167 *rv = static_cast<void*>(new CodeGenX86_64());
168 });
169
170} // namespace codegen
171} // namespace tvm
172
173#endif // TVM_LLVM_VERSION
174