1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_vhls.cc |
22 | */ |
23 | #include "codegen_vhls.h" |
24 | |
25 | #include <string> |
26 | #include <vector> |
27 | |
28 | #include "../../runtime/opencl/sdaccel/sdaccel_module.h" |
29 | #include "../build_common.h" |
30 | |
31 | namespace tvm { |
32 | namespace codegen { |
33 | |
34 | void CodeGenVivadoHLS::Init(bool output_ssa) { |
35 | CodeGenC::Init(output_ssa); |
36 | |
37 | this->stream << "#include <ap_int.h>\n\n" ; |
38 | this->stream << "#include <algorithm>\n\n" ; |
39 | } |
40 | |
41 | void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) { |
42 | if (t.is_uint()) { |
43 | switch (t.bits()) { |
44 | case 8: |
45 | os << "unsigned char" ; |
46 | break; |
47 | case 16: |
48 | os << "unsigned short" ; |
49 | break; |
50 | case 32: |
51 | os << "unsigned int" ; |
52 | break; |
53 | case 64: |
54 | os << "unsigned long long" ; |
55 | break; |
56 | default: |
57 | os << "ap_uint<" << t.bits() << ">" ; |
58 | break; |
59 | } |
60 | } else if (t.is_int()) { |
61 | switch (t.bits()) { |
62 | case 8: |
63 | os << "char" ; |
64 | break; |
65 | case 16: |
66 | os << "short" ; |
67 | break; |
68 | case 32: |
69 | os << "int" ; |
70 | break; |
71 | case 64: |
72 | os << "long long" ; |
73 | break; |
74 | default: |
75 | os << "ap_int<" << t.bits() << ">" ; |
76 | break; |
77 | } |
78 | } else { |
79 | CodeGenC::PrintType(t, os); |
80 | } |
81 | } |
82 | |
83 | void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" void" ; } |
84 | |
85 | void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { |
86 | for (size_t i = 0; i < f->params.size(); ++i) { |
87 | Var v = f->params[i]; |
88 | std::string vid = GetVarID(v.get()); |
89 | if (v.dtype().is_handle()) { |
90 | this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n" ; |
91 | } |
92 | this->stream << "#pragma HLS INTERFACE s_axilite port=" << vid << " bundle=control\n" ; |
93 | } |
94 | this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n" ; |
95 | } |
96 | |
97 | template <typename T> |
98 | inline void PrintBinaryExpr(const T* op, const char* opstr, |
99 | std::ostream& os, // NOLINT(*) |
100 | CodeGenVivadoHLS* p) { |
101 | os << opstr << '('; |
102 | p->PrintExpr(op->a, os); |
103 | os << ", " ; |
104 | p->PrintExpr(op->b, os); |
105 | os << ')'; |
106 | } |
107 | |
108 | void CodeGenVivadoHLS::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) |
109 | const char* opstr = "std::min" ; |
110 | if (op->dtype.is_float()) { |
111 | switch (op->dtype.bits()) { |
112 | case 32: |
113 | opstr = "fminf" ; |
114 | break; |
115 | case 64: |
116 | opstr = "fmin" ; |
117 | break; |
118 | } |
119 | } |
120 | |
121 | PrintBinaryExpr(op, opstr, os, this); |
122 | } |
123 | |
124 | void CodeGenVivadoHLS::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) |
125 | const char* opstr = "std::max" ; |
126 | if (op->dtype.is_float()) { |
127 | switch (op->dtype.bits()) { |
128 | case 32: |
129 | opstr = "fmaxf" ; |
130 | break; |
131 | case 64: |
132 | opstr = "fmax" ; |
133 | break; |
134 | } |
135 | } |
136 | |
137 | PrintBinaryExpr(op, opstr, os, this); |
138 | } |
139 | |
140 | runtime::Module BuildSDAccel(IRModule mod, Target target) { |
141 | using tvm::runtime::Registry; |
142 | bool output_ssa = false; |
143 | CodeGenVivadoHLS cg; |
144 | |
145 | // Generate source code for get_source(). |
146 | cg.Init(output_ssa); |
147 | |
148 | for (auto kv : mod->functions) { |
149 | ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc" ; |
150 | auto f = Downcast<PrimFunc>(kv.second); |
151 | auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); |
152 | ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) |
153 | << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch" ; |
154 | cg.AddFunction(f); |
155 | } |
156 | |
157 | std::string whole_code = cg.Finish(); |
158 | |
159 | // Generate source code for compilation. |
160 | Array<Array<runtime::String>> kernel_info; |
161 | |
162 | for (auto kv : mod->functions) { |
163 | ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc" ; |
164 | auto f = Downcast<PrimFunc>(kv.second); |
165 | CodeGenVivadoHLS cg; |
166 | cg.Init(output_ssa); |
167 | cg.AddFunction(f); |
168 | std::string code = cg.Finish(); |
169 | if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc" )) { |
170 | code = (*f)(code).operator std::string(); |
171 | } |
172 | |
173 | auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); |
174 | ICHECK(global_symbol.defined()) |
175 | << "CodeGenC: Expect PrimFunc to have the global_symbol attribute" ; |
176 | kernel_info.push_back({global_symbol.value(), code}); |
177 | } |
178 | |
179 | std::string xclbin; |
180 | if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile" )) { |
181 | String device = target->GetAttr<String>("device" , "" ).value(); |
182 | xclbin = (*f)(kernel_info, device).operator std::string(); |
183 | } else { |
184 | LOG(FATAL) << "Cannot compile Vivado HLS code." ; |
185 | } |
186 | return SDAccelModuleCreate(xclbin, "xclbin" , ExtractFuncInfo(mod), whole_code); |
187 | } |
188 | |
// Register BuildSDAccel in TVM's global function registry under the name
// "target.build.sdaccel" (presumably looked up by the target build dispatcher
// as "target.build.<target kind>" -- confirm against the caller).
TVM_REGISTER_GLOBAL("target.build.sdaccel").set_body_typed(BuildSDAccel);
190 | |
191 | } // namespace codegen |
192 | } // namespace tvm |
193 | |