1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_arm.cc
22 * \brief ARM specific code generator
23 */
24#ifdef TVM_LLVM_VERSION
25
26#include <llvm/IR/Intrinsics.h>
27#include <tvm/runtime/registry.h>
28#if TVM_LLVM_VERSION >= 100
29#include <llvm/IR/IntrinsicsARM.h>
30#endif
31#include <llvm/Target/TargetMachine.h>
32
33#include "codegen_cpu.h"
34
35namespace tvm {
36namespace codegen {
37
38// ARM specific code generator, this is used as an example on
39// how to override behavior llvm code generator for specific target
40class CodeGenARM final : public CodeGenCPU {
41 public:
42 CodeGenARM() = default;
43 virtual ~CodeGenARM() = default;
44
45 void InitTarget() final {
46 // set native vector bits.
47 native_vector_bits_ = 16 * 8;
48 CodeGenCPU::InitTarget();
49 }
50 llvm::Value* CreateIntrinsic(const CallNode* op) override;
51
52 private:
53 PrimExpr ARMPopcount(const CallNode* op);
54};
55
56llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
57 if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
58 llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
59 if (id == llvm::Intrinsic::ctpop) {
60 PrimExpr e = ARMPopcount(op);
61 return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
62 }
63 }
64 return CodeGenCPU::CreateIntrinsic(op);
65}
66
67PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
68 using namespace tir;
69 const PrimExpr& e = call->args[2];
70 llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop;
71 llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu;
72
73 // Fallback to default llvm lowering rule if input type not a full vector or half vector length
74 int total_size = call->dtype.bits() * call->dtype.lanes();
75 if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
76 (total_size != 128 && total_size != 64)) {
77 Array<PrimExpr> vcnt_args;
78 vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
79 vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
80 vcnt_args.push_back(e);
81 return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args);
82 }
83
84 // Popcount lowering rule:
85 // Reinterpret input vector as a vector of 8bit values and preform popcount
86 // Pairwise add between adjacent elements and double width with vpaddlu
87 // to return back to original input type
88
89 // Dvisions are always divisible (number of bits = 64 or 128)
90 DataType uint8_type = DataType(e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8);
91 DataType uint16_type =
92 DataType(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
93 DataType uint32_type =
94 DataType(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);
95
96 // Interpret input as vector of 8bit values
97 PrimExpr input8 = reinterpret(uint8_type, e);
98 // Popcount 8bit->8bit
99 const CallNode* c0 = input8.as<CallNode>();
100 ICHECK(c0 != nullptr);
101 Array<PrimExpr> vcnt8_args;
102 vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
103 vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
104 vcnt8_args.push_back(input8);
105 PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args);
106
107 // Accumulation 8->16bit
108 Array<PrimExpr> vcnt16_args;
109 vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
110 vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
111 vcnt16_args.push_back(vcnt8);
112 PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args);
113 if (call->dtype.bits() == 16) {
114 return vcnt16;
115 }
116
117 // Accumulation 16->32bit
118 Array<PrimExpr> vcnt32_args;
119 vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
120 vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
121 vcnt32_args.push_back(vcnt16);
122 PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args);
123 if (call->dtype.bits() == 32) {
124 return vcnt32;
125 }
126
127 // Accumulation 32->64bit
128 Array<PrimExpr> vcnt64_args;
129 vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
130 vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
131 vcnt64_args.push_back(vcnt32);
132 return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args);
133}
134
135TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
136 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
137 *rv = static_cast<void*>(new CodeGenARM());
138 });
139
140} // namespace codegen
141} // namespace tvm
142
143#endif // TVM_LLVM_VERSION
144