1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_vhls.cc
22 */
23#include "codegen_vhls.h"
24
25#include <string>
26#include <vector>
27
28#include "../../runtime/opencl/sdaccel/sdaccel_module.h"
29#include "../build_common.h"
30
31namespace tvm {
32namespace codegen {
33
34void CodeGenVivadoHLS::Init(bool output_ssa) {
35 CodeGenC::Init(output_ssa);
36
37 this->stream << "#include <ap_int.h>\n\n";
38 this->stream << "#include <algorithm>\n\n";
39}
40
41void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
42 if (t.is_uint()) {
43 switch (t.bits()) {
44 case 8:
45 os << "unsigned char";
46 break;
47 case 16:
48 os << "unsigned short";
49 break;
50 case 32:
51 os << "unsigned int";
52 break;
53 case 64:
54 os << "unsigned long long";
55 break;
56 default:
57 os << "ap_uint<" << t.bits() << ">";
58 break;
59 }
60 } else if (t.is_int()) {
61 switch (t.bits()) {
62 case 8:
63 os << "char";
64 break;
65 case 16:
66 os << "short";
67 break;
68 case 32:
69 os << "int";
70 break;
71 case 64:
72 os << "long long";
73 break;
74 default:
75 os << "ap_int<" << t.bits() << ">";
76 break;
77 }
78 } else {
79 CodeGenC::PrintType(t, os);
80 }
81}
82
83void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" void"; }
84
85void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
86 for (size_t i = 0; i < f->params.size(); ++i) {
87 Var v = f->params[i];
88 std::string vid = GetVarID(v.get());
89 if (v.dtype().is_handle()) {
90 this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n";
91 }
92 this->stream << "#pragma HLS INTERFACE s_axilite port=" << vid << " bundle=control\n";
93 }
94 this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n";
95}
96
97template <typename T>
98inline void PrintBinaryExpr(const T* op, const char* opstr,
99 std::ostream& os, // NOLINT(*)
100 CodeGenVivadoHLS* p) {
101 os << opstr << '(';
102 p->PrintExpr(op->a, os);
103 os << ", ";
104 p->PrintExpr(op->b, os);
105 os << ')';
106}
107
108void CodeGenVivadoHLS::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
109 const char* opstr = "std::min";
110 if (op->dtype.is_float()) {
111 switch (op->dtype.bits()) {
112 case 32:
113 opstr = "fminf";
114 break;
115 case 64:
116 opstr = "fmin";
117 break;
118 }
119 }
120
121 PrintBinaryExpr(op, opstr, os, this);
122}
123
124void CodeGenVivadoHLS::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
125 const char* opstr = "std::max";
126 if (op->dtype.is_float()) {
127 switch (op->dtype.bits()) {
128 case 32:
129 opstr = "fmaxf";
130 break;
131 case 64:
132 opstr = "fmax";
133 break;
134 }
135 }
136
137 PrintBinaryExpr(op, opstr, os, this);
138}
139
140runtime::Module BuildSDAccel(IRModule mod, Target target) {
141 using tvm::runtime::Registry;
142 bool output_ssa = false;
143 CodeGenVivadoHLS cg;
144
145 // Generate source code for get_source().
146 cg.Init(output_ssa);
147
148 for (auto kv : mod->functions) {
149 ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
150 auto f = Downcast<PrimFunc>(kv.second);
151 auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
152 ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
153 << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
154 cg.AddFunction(f);
155 }
156
157 std::string whole_code = cg.Finish();
158
159 // Generate source code for compilation.
160 Array<Array<runtime::String>> kernel_info;
161
162 for (auto kv : mod->functions) {
163 ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
164 auto f = Downcast<PrimFunc>(kv.second);
165 CodeGenVivadoHLS cg;
166 cg.Init(output_ssa);
167 cg.AddFunction(f);
168 std::string code = cg.Finish();
169 if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
170 code = (*f)(code).operator std::string();
171 }
172
173 auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
174 ICHECK(global_symbol.defined())
175 << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
176 kernel_info.push_back({global_symbol.value(), code});
177 }
178
179 std::string xclbin;
180 if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
181 String device = target->GetAttr<String>("device", "").value();
182 xclbin = (*f)(kernel_info, device).operator std::string();
183 } else {
184 LOG(FATAL) << "Cannot compile Vivado HLS code.";
185 }
186 return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(mod), whole_code);
187}
188
// Register BuildSDAccel in the global function registry under the name
// "target.build.sdaccel".
// NOTE(review): presumably looked up by name when building for the sdaccel
// target - confirm against the caller.
TVM_REGISTER_GLOBAL("target.build.sdaccel").set_body_typed(BuildSDAccel);
190
191} // namespace codegen
192} // namespace tvm
193