1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_arm.cc |
22 | * \brief ARM specific code generator |
23 | */ |
24 | #ifdef TVM_LLVM_VERSION |
25 | |
26 | #include <llvm/IR/Intrinsics.h> |
27 | #include <tvm/runtime/registry.h> |
28 | #if TVM_LLVM_VERSION >= 100 |
29 | #include <llvm/IR/IntrinsicsARM.h> |
30 | #endif |
31 | #include <llvm/Target/TargetMachine.h> |
32 | |
33 | #include "codegen_cpu.h" |
34 | |
35 | namespace tvm { |
36 | namespace codegen { |
37 | |
38 | // ARM specific code generator, this is used as an example on |
39 | // how to override behavior llvm code generator for specific target |
40 | class CodeGenARM final : public CodeGenCPU { |
41 | public: |
42 | CodeGenARM() = default; |
43 | virtual ~CodeGenARM() = default; |
44 | |
45 | void InitTarget() final { |
46 | // set native vector bits. |
47 | native_vector_bits_ = 16 * 8; |
48 | CodeGenCPU::InitTarget(); |
49 | } |
50 | llvm::Value* CreateIntrinsic(const CallNode* op) override; |
51 | |
52 | private: |
53 | PrimExpr ARMPopcount(const CallNode* op); |
54 | }; |
55 | |
56 | llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { |
57 | if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { |
58 | llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value); |
59 | if (id == llvm::Intrinsic::ctpop) { |
60 | PrimExpr e = ARMPopcount(op); |
61 | return CodeGenCPU::CreateIntrinsic(e.as<CallNode>()); |
62 | } |
63 | } |
64 | return CodeGenCPU::CreateIntrinsic(op); |
65 | } |
66 | |
/*!
 * \brief Build the popcount expression used in place of an LLVM ctpop call.
 *
 * When the result type is a 64-bit or 128-bit vector with lanes wider than
 * 8 bits, the operand is reinterpreted as a vector of 8-bit lanes, ctpop is
 * applied per 8-bit lane, and the per-byte counts are widened back to the
 * original lane width with successive vpaddlu (pairwise add, doubling the
 * lane width) steps. Any other result type yields a plain ctpop call.
 *
 * \param call The original ctpop intrinsic call; the operand is args[2].
 * \return The replacement expression.
 */
PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
  using namespace tir;
  const PrimExpr& operand = call->args[2];
  const DataType& result_type = call->dtype;
  const llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop;
  const llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu;

  // Wraps `arg` in a pure LLVM intrinsic call {id, 1, arg} of type `out_type`.
  // NOTE(review): the literal 1 is presumably the operand count expected by
  // the call_llvm_pure_intrin convention -- confirm against its definition.
  auto make_intrin_call = [&](const DataType& out_type, llvm::Intrinsic::ID id,
                              const PrimExpr& arg) -> PrimExpr {
    Array<PrimExpr> intrin_args;
    intrin_args.push_back(IntImm(DataType::UInt(32), id));
    intrin_args.push_back(IntImm(DataType::UInt(32), 1));
    intrin_args.push_back(arg);
    return tir::Call(out_type, builtin_call_llvm_pure_intrin_, intrin_args);
  };

  // Fall back to the default LLVM lowering rule if the input type is not a
  // full (128-bit) or half (64-bit) vector, or already has 8-bit lanes.
  const int total_bits = result_type.bits() * result_type.lanes();
  const bool full_or_half_vector = (total_bits == 128 || total_bits == 64);
  if (!result_type.is_vector() || result_type.bits() == 8 || !full_or_half_vector) {
    return make_intrin_call(result_type, ctpop_id, operand);
  }

  // Popcount lowering rule:
  //   1. Reinterpret the input vector as a vector of 8-bit values and
  //      perform popcount on each of them.
  //   2. Pairwise add adjacent elements, doubling the lane width with
  //      vpaddlu, until the original lane width is reached.

  // The divisions below are always exact (the number of bits is 64 or 128).
  // The type code is inherited from the operand at every width.
  const DataType operand_type = operand.dtype();
  const DataType lane8_type(operand_type.code(), 8,
                            operand_type.bits() * operand_type.lanes() / 8);
  const int lane8_total_bits = lane8_type.bits() * lane8_type.lanes();
  const DataType lane16_type(lane8_type.code(), 16, lane8_total_bits / 16);
  const DataType lane32_type(lane16_type.code(), 32, lane8_total_bits / 32);

  // Interpret the input as a vector of 8-bit values.
  PrimExpr input8 = reinterpret(lane8_type, operand);
  const CallNode* c0 = input8.as<CallNode>();
  ICHECK(c0 != nullptr);

  // Popcount 8-bit -> 8-bit.
  PrimExpr count8 = make_intrin_call(lane8_type, ctpop_id, input8);

  // Accumulation 8-bit -> 16-bit.
  PrimExpr count16 = make_intrin_call(lane16_type, vpaddlu_id, count8);
  if (result_type.bits() == 16) {
    return count16;
  }

  // Accumulation 16-bit -> 32-bit.
  PrimExpr count32 = make_intrin_call(lane32_type, vpaddlu_id, count16);
  if (result_type.bits() == 32) {
    return count32;
  }

  // Accumulation 32-bit -> 64-bit; the result takes the call's own dtype.
  return make_intrin_call(result_type, vpaddlu_id, count32);
}
134 | |
135 | TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm" ) |
136 | .set_body([](const TVMArgs& targs, TVMRetValue* rv) { |
137 | *rv = static_cast<void*>(new CodeGenARM()); |
138 | }); |
139 | |
140 | } // namespace codegen |
141 | } // namespace tvm |
142 | |
143 | #endif // TVM_LLVM_VERSION |
144 | |