1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_c.cc |
22 | */ |
#include "codegen_c.h"

#include <tvm/arith/analyzer.h>

#include <cctype>
#include <iomanip>
#include <limits>

#include "../../arith/pattern_match.h"
#include "codegen_params.h"
32 | |
33 | namespace tvm { |
34 | namespace codegen { |
35 | |
36 | using namespace tir; |
37 | |
38 | void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; } |
39 | |
40 | void CodeGenC::InitFuncState(const PrimFunc& f) { |
41 | alloc_storage_scope_.clear(); |
42 | handle_data_type_.clear(); |
43 | CodeGenSourceBase::ClearFuncState(); |
44 | } |
45 | |
46 | void CodeGenC::ReserveKeywordsAsUnique() { |
47 | // skip the first underscore, so SSA variable starts from _1 |
48 | name_supply_->ReserveName("_" ); |
49 | name_supply_->ReserveName("extern" ); |
50 | name_supply_->ReserveName("void" ); |
51 | name_supply_->ReserveName("int" ); |
52 | name_supply_->ReserveName("float" ); |
53 | name_supply_->ReserveName("double" ); |
54 | name_supply_->ReserveName("char" ); |
55 | name_supply_->ReserveName("unsigned" ); |
56 | name_supply_->ReserveName("short" ); |
57 | name_supply_->ReserveName("long" ); |
58 | name_supply_->ReserveName("if" ); |
59 | name_supply_->ReserveName("else" ); |
60 | name_supply_->ReserveName("switch" ); |
61 | name_supply_->ReserveName("case" ); |
62 | name_supply_->ReserveName("default" ); |
63 | name_supply_->ReserveName("for" ); |
64 | name_supply_->ReserveName("do" ); |
65 | name_supply_->ReserveName("while" ); |
66 | name_supply_->ReserveName("goto" ); |
67 | name_supply_->ReserveName("register" ); |
68 | name_supply_->ReserveName("continue" ); |
69 | name_supply_->ReserveName("break" ); |
70 | name_supply_->ReserveName("typedef" ); |
71 | name_supply_->ReserveName("struct" ); |
72 | name_supply_->ReserveName("enum" ); |
73 | name_supply_->ReserveName("union" ); |
74 | name_supply_->ReserveName("return" ); |
75 | } |
76 | |
77 | void CodeGenC::AddFunction(const PrimFunc& f) { |
78 | // clear previous generated state. |
79 | this->InitFuncState(f); |
80 | // reserve keywords |
81 | ReserveKeywordsAsUnique(); |
82 | |
83 | auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); |
84 | ICHECK(global_symbol.defined()) |
85 | << "CodeGenC: Expect PrimFunc to have the global_symbol attribute" ; |
86 | bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); |
87 | |
88 | this->PrintFuncPrefix(stream); |
89 | this->PrintExtraAttrs(f); |
90 | this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(" ; |
91 | |
92 | for (size_t i = 0; i < f->params.size(); ++i) { |
93 | tir::Var v = f->params[i]; |
94 | std::string vid = AllocVarID(v.get()); |
95 | if (i != 0) stream << ", " ; |
96 | if (v.dtype().is_handle()) { |
97 | auto it = alloc_storage_scope_.find(v.get()); |
98 | if (it != alloc_storage_scope_.end()) { |
99 | PrintStorageScope(it->second, stream); |
100 | } |
101 | |
102 | PrintType(GetType(v), stream); |
103 | // Register handle data type |
104 | // TODO(tvm-team): consider simply keep type info in the |
105 | // type annotation(via a normalizing rewriting). |
106 | if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) { |
107 | if (auto* prim = ptr->element_type.as<PrimTypeNode>()) { |
108 | RegisterHandleType(v.get(), prim->dtype); |
109 | } |
110 | } |
111 | |
112 | if (no_alias) { |
113 | PrintRestrict(v, stream); |
114 | } |
115 | } else { |
116 | PrintType(GetType(v), stream); |
117 | } |
118 | stream << ' ' << vid; |
119 | } |
120 | stream << ") {\n" ; |
121 | this->PreFunctionBody(f); |
122 | int func_scope = this->BeginScope(); |
123 | this->PrintStmt(f->body); |
124 | this->PrintFinalReturn(); |
125 | this->EndScope(func_scope); |
126 | this->PrintIndent(); |
127 | this->stream << "}\n\n" ; |
128 | } |
129 | |
130 | void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void" ; } |
131 | |
132 | void CodeGenC::(const PrimFunc& f) {} |
133 | |
134 | void CodeGenC::PrintFinalReturn() {} |
135 | |
136 | std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } |
137 | |
138 | void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) |
139 | if (print_ssa_form_) { |
140 | std::ostringstream temp; |
141 | VisitExpr(n, temp); |
142 | os << SSAGetID(temp.str(), n.dtype()); |
143 | } else { |
144 | VisitExpr(n, os); |
145 | } |
146 | } |
147 | |
148 | static bool CheckOutermostBracketMatch(const std::string& s); |
149 | |
150 | void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) { |
151 | PrintType(t, stream); |
152 | stream << ' ' << target << " = " ; |
153 | if (CheckOutermostBracketMatch(src)) { |
154 | stream << src.substr(1, src.length() - 2); |
155 | } else { |
156 | stream << src; |
157 | } |
158 | stream << ";\n" ; |
159 | } |
160 | |
161 | // Print a reference expression to a buffer. |
162 | std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) { |
163 | const VarNode* buffer_var = buffer->data.get(); |
164 | std::ostringstream os; |
165 | std::string vid = GetVarID(buffer_var); |
166 | std::string scope; |
167 | if (alloc_storage_scope_.count(buffer_var)) { |
168 | scope = alloc_storage_scope_.at(buffer_var); |
169 | } |
170 | bool is_vol = IsVolatile(buffer_var); |
171 | |
172 | auto ptr_cast = [this, is_vol, scope](DataType pointed_to) { |
173 | std::ostringstream ptr_os; |
174 | ptr_os << "(" ; |
175 | if (is_vol) { |
176 | ptr_os << "volatile " ; |
177 | } |
178 | if (!scope.empty() && IsScopePartOfType()) { |
179 | PrintStorageScope(scope, ptr_os); |
180 | } |
181 | PrintType(pointed_to, ptr_os); |
182 | ptr_os << "*)" ; |
183 | return ptr_os.str(); |
184 | }; |
185 | |
186 | DataType buffer_element_dtype = buffer->dtype; |
187 | |
188 | std::string buffer_str = vid; |
189 | if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) { |
190 | std::stringstream temp; |
191 | temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")" ; |
192 | buffer_str = temp.str(); |
193 | } |
194 | |
195 | std::string index_str = PrintExpr(index); |
196 | if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { |
197 | // This is a special case, because CodegenCUDA::PrintType() |
198 | // returns "int" for bool and for 4-bit integers. In most cases, |
199 | // we divide by the number of lanes to determine the index. |
200 | // However, the backing type for scalar int4 and scalar bool is |
201 | // int32. Therefore, we need to divide by the ratio of their |
202 | // sizes in that case. |
203 | int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes(); |
204 | |
205 | os << "*(" |
206 | << "(" << ptr_cast(t) << vid << ")" |
207 | << " + " << index_str << " / " << div_factor << ")" ; |
208 | } else if (t == buffer_element_dtype) { |
209 | os << buffer_str << "[" << index_str << "]" ; |
210 | } else { |
211 | os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")" ; |
212 | } |
213 | |
214 | return os.str(); |
215 | } |
216 | |
217 | // Print a reference expression to a buffer. |
218 | std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, |
219 | int kind) { |
220 | if (kind < builtin::kArrKindBound_) { |
221 | std::ostringstream os; |
222 | os << "(((DLTensor*)" ; |
223 | this->PrintExpr(buffer, os); |
224 | os << ")" ; |
225 | if (kind == builtin::kArrAddr) { |
226 | os << " + " ; |
227 | this->PrintExpr(index, os); |
228 | os << ")" ; |
229 | return os.str(); |
230 | } |
231 | os << '['; |
232 | this->PrintExpr(index, os); |
233 | os << "]." ; |
234 | // other case: get fields. |
235 | switch (kind) { |
236 | case builtin::kArrData: |
237 | os << "data" ; |
238 | break; |
239 | case builtin::kArrShape: |
240 | os << "shape" ; |
241 | break; |
242 | case builtin::kArrStrides: |
243 | os << "strides" ; |
244 | break; |
245 | case builtin::kArrNDim: |
246 | os << "ndim" ; |
247 | break; |
248 | case builtin::kArrTypeCode: |
249 | os << "dtype.code" ; |
250 | break; |
251 | case builtin::kArrTypeBits: |
252 | os << "dtype.bits" ; |
253 | break; |
254 | case builtin::kArrByteOffset: |
255 | os << "byte_offset" ; |
256 | break; |
257 | case builtin::kArrTypeLanes: |
258 | os << "dtype.lanes" ; |
259 | break; |
260 | case builtin::kArrDeviceId: |
261 | os << "device.device_id" ; |
262 | break; |
263 | case builtin::kArrDeviceType: |
264 | os << "device.device_type" ; |
265 | break; |
266 | default: |
267 | LOG(FATAL) << "unknown field code" ; |
268 | } |
269 | os << ')'; |
270 | return os.str(); |
271 | } else { |
272 | ICHECK_LT(kind, builtin::kTVMValueKindBound_); |
273 | std::ostringstream os; |
274 | os << "(((TVMValue*)" ; |
275 | this->PrintExpr(buffer, os); |
276 | os << ")[" << index << "]." ; |
277 | if (t.is_handle()) { |
278 | os << "v_handle" ; |
279 | } else if (t.is_float()) { |
280 | os << "v_float64" ; |
281 | } else if (t.is_int()) { |
282 | os << "v_int64" ; |
283 | } else { |
284 | LOG(FATAL) << "Do not know how to handle type" << t; |
285 | } |
286 | os << ")" ; |
287 | return os.str(); |
288 | } |
289 | } |
290 | |
291 | bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const { |
292 | auto it = handle_data_type_.find(buf_var); |
293 | if (it == handle_data_type_.end()) return false; |
294 | return it->second == t; |
295 | } |
296 | |
297 | void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) { |
298 | auto it = handle_data_type_.find(buf_var); |
299 | if (it == handle_data_type_.end()) { |
300 | handle_data_type_[buf_var] = t; |
301 | } else { |
302 | ICHECK(it->second == t) << "conflicting buf var type" ; |
303 | } |
304 | } |
305 | |
306 | void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i, |
307 | std::ostream& os) { // NOLINT(*) |
308 | os << vec << ".s" << std::hex << i << std::dec; |
309 | } |
310 | |
311 | void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i, |
312 | const std::string& value) { |
313 | this->PrintIndent(); |
314 | stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec; |
315 | } |
316 | |
317 | std::string CodeGenC::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) { |
318 | return GetBufferRef(t, buffer, base); |
319 | } |
320 | |
321 | void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, |
322 | const std::string& value) { |
323 | std::string ref = GetBufferRef(t, buffer, base); |
324 | this->PrintIndent(); |
325 | stream << ref << " = " << value << ";\n" ; |
326 | } |
327 | |
328 | std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) { |
329 | if (from == target) return value; |
330 | std::ostringstream os; |
331 | os << "((" ; |
332 | this->PrintType(target, os); |
333 | os << ")" << value << ")" ; |
334 | return os.str(); |
335 | } |
336 | |
337 | void CodeGenC::BindThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented" ; } |
338 | |
339 | void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*) |
340 | } |
341 | |
342 | void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) |
343 | ICHECK_EQ(scope, "global" ); |
344 | } |
345 | |
346 | inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) |
347 | if (op->dtype == DataType::Int(32)) { |
348 | std::ostringstream temp; |
349 | temp << op->value; |
350 | p->MarkConst(temp.str()); |
351 | os << temp.str(); |
352 | } else { |
353 | os << "(" ; |
354 | p->PrintType(op->dtype, os); |
355 | os << ")" << op->value; |
356 | } |
357 | } |
358 | |
359 | inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, |
360 | CodeGenC* p) { // NOLINT(*) |
361 | if (dtype == DataType::UInt(32)) { |
362 | std::ostringstream temp; |
363 | temp << val << "U" ; |
364 | p->MarkConst(temp.str()); |
365 | os << temp.str(); |
366 | } else { |
367 | os << "(" ; |
368 | p->PrintType(dtype, os); |
369 | os << ")" << val; |
370 | } |
371 | } |
372 | |
373 | inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) |
374 | switch (op->dtype.bits()) { |
375 | case 64: |
376 | case 32: { |
377 | std::ostringstream temp; |
378 | temp << std::scientific << op->value; |
379 | if (op->dtype.bits() == 32) temp << 'f'; |
380 | p->MarkConst(temp.str()); |
381 | os << temp.str(); |
382 | break; |
383 | } |
384 | case 16: { |
385 | os << '('; |
386 | p->PrintType(op->dtype, os); |
387 | os << ')' << std::scientific << op->value << 'f'; |
388 | break; |
389 | } |
390 | default: |
391 | LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n" ; |
392 | } |
393 | } |
394 | |
395 | void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) |
396 | PrintConst(op, os, this); |
397 | } |
398 | |
399 | void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) |
400 | PrintConst(op, os, this); |
401 | } |
402 | void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) |
403 | os << "\"" << op->value << "\"" ; |
404 | } |
405 | |
406 | template <typename T> |
407 | inline void PrintBinaryExpr(const T* op, const char* opstr, |
408 | std::ostream& os, // NOLINT(*) |
409 | CodeGenC* p) { |
410 | if (op->dtype.lanes() == 1) { |
411 | if (isalpha(opstr[0])) { |
412 | os << opstr << '('; |
413 | p->PrintExpr(op->a, os); |
414 | os << ", " ; |
415 | p->PrintExpr(op->b, os); |
416 | os << ')'; |
417 | } else { |
418 | os << '('; |
419 | p->PrintExpr(op->a, os); |
420 | os << ' ' << opstr << ' '; |
421 | p->PrintExpr(op->b, os); |
422 | os << ')'; |
423 | } |
424 | } else { |
425 | p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); |
426 | } |
427 | } |
428 | |
429 | inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr, |
430 | std::ostream& os, // NOLINT(*) |
431 | CodeGenC* p) { |
432 | if (op->dtype.lanes() == 1) { |
433 | ICHECK_EQ(op->args.size(), 2U); |
434 | os << '('; |
435 | p->PrintExpr(op->args[0], os); |
436 | os << opstr; |
437 | p->PrintExpr(op->args[1], os); |
438 | os << ')'; |
439 | } else { |
440 | p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os); |
441 | } |
442 | } |
443 | void CodeGenC::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) |
444 | std::stringstream value; |
445 | this->PrintExpr(op->value, value); |
446 | os << CastFromTo(value.str(), op->value.dtype(), op->dtype); |
447 | } |
448 | void CodeGenC::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*) |
449 | os << GetVarID(op); |
450 | } |
451 | void CodeGenC::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*) |
452 | PrintBinaryExpr(op, "+" , os, this); |
453 | } |
454 | void CodeGenC::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*) |
455 | PrintBinaryExpr(op, "-" , os, this); |
456 | } |
457 | void CodeGenC::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*) |
458 | PrintBinaryExpr(op, "*" , os, this); |
459 | } |
460 | void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*) |
461 | PrintBinaryExpr(op, "/" , os, this); |
462 | } |
463 | void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) |
464 | if (op->dtype.is_int() || op->dtype.is_uint()) { |
465 | PrintBinaryExpr(op, "%" , os, this); |
466 | } else { |
467 | ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got " |
468 | << op->dtype; |
469 | if (op->dtype.bits() == 32) { |
470 | PrintBinaryExpr(op, "fmodf" , os, this); |
471 | } else if (op->dtype.bits() == 64) { |
472 | PrintBinaryExpr(op, "fmod" , os, this); |
473 | } else { |
474 | ICHECK(false) |
475 | << "Non single or double precision floating point in Mod, expected 32 or 64 bits but got " |
476 | << op->dtype.bits() << " bits." ; |
477 | } |
478 | } |
479 | } |
480 | void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) |
481 | PrintBinaryExpr(op, "min" , os, this); |
482 | } |
483 | void CodeGenC::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) |
484 | PrintBinaryExpr(op, "max" , os, this); |
485 | } |
486 | void CodeGenC::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*) |
487 | PrintBinaryExpr(op, "==" , os, this); |
488 | } |
489 | void CodeGenC::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*) |
490 | PrintBinaryExpr(op, "!=" , os, this); |
491 | } |
492 | void CodeGenC::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*) |
493 | PrintBinaryExpr(op, "<" , os, this); |
494 | } |
495 | void CodeGenC::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*) |
496 | PrintBinaryExpr(op, "<=" , os, this); |
497 | } |
498 | void CodeGenC::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*) |
499 | PrintBinaryExpr(op, ">" , os, this); |
500 | } |
501 | void CodeGenC::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*) |
502 | PrintBinaryExpr(op, ">=" , os, this); |
503 | } |
504 | void CodeGenC::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*) |
505 | PrintBinaryExpr(op, "&&" , os, this); |
506 | } |
507 | void CodeGenC::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*) |
508 | PrintBinaryExpr(op, "||" , os, this); |
509 | } |
510 | void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) |
511 | os << '!'; |
512 | PrintExpr(op->a, os); |
513 | } |
514 | |
515 | void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, |
516 | bool skip_first_arg, std::ostream& os) { // NOLINT(*) |
517 | os << global_symbol << "(" ; |
518 | for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) { |
519 | this->PrintExpr(args[i], os); |
520 | if (i < args.size() - 1) { |
521 | os << ", " ; |
522 | } |
523 | } |
524 | os << ")" ; |
525 | } |
526 | |
527 | void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) |
528 | if (auto* ptr_op = op->op.as<OpNode>()) { |
529 | auto call_op = GetRef<Op>(ptr_op); |
530 | |
531 | if (op->op.same_as(builtin::tvm_check_return())) { |
532 | const CallNode* call = op->args[2].as<CallNode>(); |
533 | os << "if (" ; |
534 | VisitExpr_(call, os); |
535 | os << " != " ; |
536 | PrintExpr(op->args[0], os); |
537 | os << " ) return " ; |
538 | PrintExpr(op->args[1], os); |
539 | } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { |
540 | ICHECK_GE(op->args.size(), 1U); |
541 | auto func = Downcast<StringImm>(op->args[0]); |
542 | this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os); |
543 | this->GenerateForwardFunctionDeclarations(func->value, op->args); |
544 | } else if (op_attr_global_symbol_.count(call_op)) { |
545 | // call extern if the op itself have a global symbol. |
546 | this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op], |
547 | op->args, false, os); |
548 | } else if (op->op.same_as(builtin::bitwise_and())) { |
549 | PrintBinaryIntrinsic(op, " & " , os, this); |
550 | } else if (op->op.same_as(builtin::large_uint_imm())) { |
551 | ICHECK_EQ(op->args.size(), 2U); |
552 | uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value); |
553 | uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value); |
554 | uint64_t val = (high << 32U) | low; |
555 | PrintUIntConst(op->dtype, val, os, this); |
556 | } else if (op->op.same_as(builtin::bitwise_xor())) { |
557 | PrintBinaryIntrinsic(op, " ^ " , os, this); |
558 | } else if (op->op.same_as(builtin::bitwise_or())) { |
559 | PrintBinaryIntrinsic(op, " | " , os, this); |
560 | } else if (op->op.same_as(builtin::bitwise_not())) { |
561 | ICHECK_EQ(op->args.size(), 1U); |
562 | os << "(~" ; |
563 | this->PrintExpr(op->args[0], os); |
564 | os << ')'; |
565 | } else if (op->op.same_as(builtin::shift_left())) { |
566 | PrintBinaryIntrinsic(op, " << " , os, this); |
567 | } else if (op->op.same_as(builtin::shift_right())) { |
568 | PrintBinaryIntrinsic(op, " >> " , os, this); |
569 | } else if (op->op.same_as(builtin::if_then_else())) { |
570 | os << "(" ; |
571 | PrintExpr(op->args[0], os); |
572 | os << " ? " ; |
573 | PrintExpr(op->args[1], os); |
574 | os << " : " ; |
575 | PrintExpr(op->args[2], os); |
576 | os << ")" ; |
577 | } else if (op->op.same_as(builtin::address_of())) { |
578 | const BufferLoadNode* load = op->args[0].as<BufferLoadNode>(); |
579 | ICHECK(op->args.size() == 1 && load); |
580 | ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations." ; |
581 | os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))" ; |
582 | } else if (op->op.same_as(builtin::tvm_struct_get())) { |
583 | ICHECK_EQ(op->args.size(), 3U); |
584 | os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value); |
585 | } else if (op->op.same_as(builtin::isnullptr())) { |
586 | ICHECK_EQ(op->args.size(), 1U); |
587 | os << "(" ; |
588 | this->PrintExpr(op->args[0], os); |
589 | os << " == NULL)" ; |
590 | } else if (op->op.same_as(builtin::reinterpret())) { |
591 | int ssa_scope = BeginScope(); |
592 | std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype); |
593 | os << "(*(" ; |
594 | this->PrintType(op->dtype, os); |
595 | os << " *)(&(" << rhs << ")))" ; |
596 | EndScope(ssa_scope); |
597 | } else if (op->op.same_as(builtin::isnan())) { |
598 | os << "(" ; |
599 | this->PrintExpr(op->args[0], os); |
600 | os << " != " ; |
601 | this->PrintExpr(op->args[0], os); |
602 | os << ")" ; |
603 | } else if (op->op.same_as(builtin::lookup_param())) { |
604 | ICHECK_EQ(op->args.size(), 1); |
605 | const StringImmNode* str = op->args[0].as<StringImmNode>(); |
606 | ICHECK(str != nullptr); |
607 | os << "__tvm_param__" << str->value; |
608 | } else { |
609 | LOG(FATAL) << "Unresolved call " << op->op; |
610 | } |
611 | } else { |
612 | ICHECK(op->op.as<GlobalVarNode>()); |
613 | LOG(FATAL) << "Do not yet support cross function call" ; |
614 | } |
615 | } |
616 | |
617 | void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, |
618 | std::ostream& os) { // NOLINT(*) |
619 | if (isalpha(op[0])) { |
620 | os << op << "(" ; |
621 | this->PrintExpr(lhs, os); |
622 | os << ", " ; |
623 | this->PrintExpr(rhs, os); |
624 | os << ")" ; |
625 | } else { |
626 | os << "(" ; |
627 | this->PrintExpr(lhs, os); |
628 | os << ' ' << op << ' '; |
629 | this->PrintExpr(rhs, os); |
630 | os << ")" ; |
631 | } |
632 | } |
633 | |
634 | void CodeGenC::VisitStmt_(const AllocateConstNode* op) { |
635 | std::string symbol_name = op->buffer_var->name_hint; |
636 | int64_t num_elements = 1; |
637 | const auto& data = op->data.value(); |
638 | |
639 | for (int64_t dim : data.Shape()) { |
640 | num_elements *= dim; |
641 | } |
642 | |
643 | decl_stream << "\n" |
644 | << "#ifdef __cplusplus\n" |
645 | << "extern \"C\" {\n" |
646 | << "#endif\n" |
647 | << "static const " ; |
648 | |
649 | PrintType(data.DataType(), decl_stream); |
650 | |
651 | // Allocate the global static variable |
652 | decl_stream << " __attribute__((section(\".rodata.tvm\"), " |
653 | << "aligned(" << constants_byte_alignment_->value << "))) " << symbol_name << "[" |
654 | << num_elements << "] = {\n" ; |
655 | NDArrayDataToC(data, 4, decl_stream); |
656 | |
657 | decl_stream << "};\n" |
658 | << "#ifdef __cplusplus\n" |
659 | << "} // extern \"C\"\n" |
660 | << "#endif\n" ; |
661 | var_idmap_[op->buffer_var.operator->()] = symbol_name; |
662 | this->PrintStmt(op->body); |
663 | } |
664 | |
665 | void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); } |
666 | |
667 | void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) |
668 | LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead." ; |
669 | } |
670 | |
671 | void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) |
672 | ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported." ; |
673 | |
674 | DataType value_dtype = op->dtype; |
675 | PrimExpr index = op->indices[0]; |
676 | Var buffer_var = op->buffer->data; |
677 | DataType element_dtype = op->buffer->dtype; |
678 | |
679 | int lanes = op->dtype.lanes(); |
680 | // delcare type. |
681 | if (value_dtype.lanes() == element_dtype.lanes()) { |
682 | std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); |
683 | HandleVolatileLoads(ref, op, os); |
684 | } else { |
685 | bool can_vector_load = false; |
686 | arith::PVar<PrimExpr> base; |
687 | if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { |
688 | const RampNode* ramp = index.as<RampNode>(); |
689 | ICHECK(ramp); |
690 | arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); |
691 | // The condition: {k * coeff + base} divisible by the alignment for any k |
692 | if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() == 0) { |
693 | can_vector_load = true; |
694 | } |
695 | } |
696 | |
697 | if (can_vector_load) { |
698 | std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval()); |
699 | HandleVolatileLoads(ref, op, os); |
700 | } else { |
701 | std::ostringstream svalue_expr; |
702 | std::string sindex = SSAGetID(PrintExpr(index), index.dtype()); |
703 | std::string vid = GetVarID(buffer_var.get()); |
704 | DataType elem_type = op->dtype.element_of(); |
705 | for (int i = 0; i < lanes; ++i) { |
706 | std::ostringstream value_temp; |
707 | if (!HandleTypeMatch(buffer_var.get(), elem_type)) { |
708 | value_temp << "((" ; |
709 | if (buffer_var.get()->dtype.is_handle()) { |
710 | auto it = alloc_storage_scope_.find(buffer_var.get()); |
711 | if (it != alloc_storage_scope_.end()) { |
712 | PrintStorageScope(it->second, value_temp); |
713 | } |
714 | } |
715 | PrintType(elem_type, value_temp); |
716 | value_temp << "*)" << vid << ')'; |
717 | } else { |
718 | value_temp << vid; |
719 | } |
720 | value_temp << '['; |
721 | PrintVecElemLoad(sindex, index.dtype(), i, value_temp); |
722 | value_temp << ']'; |
723 | PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr); |
724 | } |
725 | os << svalue_expr.str(); |
726 | } |
727 | } |
728 | } |
729 | |
730 | void CodeGenC::VisitStmt_(const StoreNode* op) { |
731 | LOG(FATAL) << "Unexpected deprecated StoreNode. Use BufferStoreNode instead." ; |
732 | } |
733 | |
734 | void CodeGenC::VisitStmt_(const BufferStoreNode* op) { |
735 | ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported." ; |
736 | |
737 | DataType value_dtype = op->value.dtype(); |
738 | DataType element_dtype = op->buffer->dtype; |
739 | PrimExpr index_expr = op->indices[0]; |
740 | Var buffer_var = op->buffer->data; |
741 | |
742 | if (value_dtype.lanes() == element_dtype.lanes()) { |
743 | std::string value = this->PrintExpr(op->value); |
744 | std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr); |
745 | this->PrintIndent(); |
746 | stream << ref << " = " << value << ";\n" ; |
747 | } else { |
748 | arith::PVar<PrimExpr> base; |
749 | |
750 | if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) { |
751 | std::string value = this->PrintExpr(op->value); |
752 | this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value); |
753 | } else { |
754 | // The assignment below introduces side-effect, and the resulting value cannot |
755 | // be reused across multiple expression, thus a new scope is needed |
756 | int vec_scope = BeginScope(); |
757 | |
758 | // store elements separately |
759 | std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype()); |
760 | std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype()); |
761 | std::string vid = GetVarID(buffer_var.get()); |
762 | for (int i = 0; i < value_dtype.lanes(); ++i) { |
763 | this->PrintIndent(); |
764 | DataType elem_type = value_dtype.element_of(); |
765 | if (!HandleTypeMatch(buffer_var.get(), elem_type)) { |
766 | stream << "((" ; |
767 | if (buffer_var.get()->dtype.is_handle()) { |
768 | auto it = alloc_storage_scope_.find(buffer_var.get()); |
769 | if (it != alloc_storage_scope_.end()) { |
770 | PrintStorageScope(it->second, stream); |
771 | } |
772 | } |
773 | PrintType(elem_type, stream); |
774 | stream << "*)" << vid << ')'; |
775 | } else { |
776 | stream << vid; |
777 | } |
778 | stream << '['; |
779 | PrintVecElemLoad(index, index_expr.dtype(), i, stream); |
780 | stream << "] = " ; |
781 | PrintVecElemLoad(value, op->value.dtype(), i, stream); |
782 | stream << ";\n" ; |
783 | } |
784 | EndScope(vec_scope); |
785 | } |
786 | } |
787 | } |
788 | |
789 | void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) |
790 | auto it = let_binding_.find(op->var); |
791 | if (it != let_binding_.end()) { |
792 | ICHECK(deep_equal_(it->second->value, op->value)) |
793 | << "Let cannot bind the same var to two different values" ; |
794 | } else { |
795 | let_binding_[op->var] = op; |
796 | } |
797 | std::string value = PrintExpr(op->value); |
798 | var_idmap_[op->var.get()] = value; |
799 | os << PrintExpr(op->body); |
800 | } |
801 | |
802 | void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) |
803 | // constraint of current logic |
804 | ICHECK_EQ(op->base.dtype(), DataType::Int(32)); |
805 | os << "((int" << op->lanes << ")(" ; |
806 | for (int i = 0; i < op->lanes; i++) { |
807 | os << "(" << PrintExpr(op->base) << ")" |
808 | << "+(" << PrintExpr(op->stride) << "*" << i << ")" ; |
809 | if (i != op->lanes - 1) os << ", " ; |
810 | } |
811 | os << "))" ; |
812 | } |
813 | |
814 | void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { |
815 | LOG(FATAL) << "Shuffle: not supported " ; |
816 | } |
817 | |
818 | void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) |
819 | LOG(FATAL) << "Broadcast: not supported " ; |
820 | } |
821 | |
822 | void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) |
823 | os << "(" ; |
824 | PrintExpr(op->condition, os); |
825 | os << " ? " ; |
826 | PrintExpr(op->true_value, os); |
827 | os << " : " ; |
828 | PrintExpr(op->false_value, os); |
829 | os << ")" ; |
830 | } |
831 | |
832 | void CodeGenC::VisitStmt_(const LetStmtNode* op) { |
833 | std::string value = PrintExpr(op->value); |
834 | if (print_ssa_form_) { |
835 | ICHECK(!var_idmap_.count(op->var.get())); |
836 | var_idmap_[op->var.get()] = value; |
837 | } else { |
838 | PrintIndent(); |
839 | if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) { |
840 | PrintType(handle_data_type_.at(op->var.get()), stream); |
841 | stream << "* " << AllocVarID(op->var.get()) << " = (" ; |
842 | PrintType(handle_data_type_.at(op->var.get()), stream); |
843 | stream << "*)" << value << ";\n" ; |
844 | } else { |
845 | PrintType(op->var.dtype(), this->stream); |
846 | this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n" ; |
847 | } |
848 | } |
849 | PrintStmt(op->body); |
850 | } |
851 | |
852 | void CodeGenC::VisitStmt_(const AllocateNode* op) { |
853 | ICHECK(!is_zero(op->condition)); |
854 | std::string vid = AllocVarID(op->buffer_var.get()); |
855 | |
856 | this->PrintIndent(); |
857 | size_t constant_size = op->ConstantAllocationSize(); |
858 | ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now" ; |
859 | |
860 | auto scope = GetPtrStorageScope(op->buffer_var); |
861 | alloc_storage_scope_[op->buffer_var.get()] = scope; |
862 | PrintStorageScope(scope, stream); |
863 | |
864 | PrintType(op->dtype, stream); |
865 | stream << ' ' << vid << '[' << constant_size << "];\n" ; |
866 | |
867 | RegisterHandleType(op->buffer_var.get(), op->dtype); |
868 | this->PrintStmt(op->body); |
869 | } |
870 | |
871 | void CodeGenC::VisitStmt_(const AttrStmtNode* op) { |
872 | if (op->attr_key == tir::attr::thread_extent) { |
873 | IterVar iv = Downcast<IterVar>(op->node); |
874 | if (iv->thread_tag.length() != 0) { |
875 | if (!var_idmap_.count(iv->var.get())) { |
876 | BindThreadIndex(iv); |
877 | } |
878 | } |
879 | } else if (op->attr_key == tir::attr::volatile_scope) { |
880 | const VarNode* v = op->node.as<VarNode>(); |
881 | ICHECK(v); |
882 | volatile_buf_.insert(v); |
883 | } else if (op->attr_key == tir::attr::pragma_import_c) { |
884 | const StringImmNode* value = op->value.as<StringImmNode>(); |
885 | ICHECK(value != nullptr); |
886 | decl_stream << value->value; |
887 | } |
888 | this->PrintStmt(op->body); |
889 | } |
890 | |
891 | void CodeGenC::VisitStmt_(const AssertStmtNode* op) { |
892 | std::string cond = PrintExpr(op->condition); |
893 | PrintIndent(); |
894 | if (const auto* str = op->message.as<StringImmNode>()) { |
895 | // GLOG style check |
896 | stream << "ICHECK(" << cond << ") << \"" << str->value << "\";\n" ; |
897 | } else { |
898 | stream << "assert(" << cond << ");\n" ; |
899 | } |
900 | this->PrintStmt(op->body); |
901 | } |
902 | |
903 | void CodeGenC::VisitStmt_(const ForNode* op) { |
904 | std::string extent = PrintExpr(op->extent); |
905 | PrintIndent(); |
906 | std::string vid = AllocVarID(op->loop_var.get()); |
907 | ICHECK(is_zero(op->min)); |
908 | stream << "for (" ; |
909 | PrintType(op->loop_var.dtype(), stream); |
910 | stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n" ; |
911 | int for_scope = BeginScope(); |
912 | PrintStmt(op->body); |
913 | this->EndScope(for_scope); |
914 | PrintIndent(); |
915 | stream << "}\n" ; |
916 | } |
917 | |
918 | void CodeGenC::VisitStmt_(const WhileNode* op) { |
919 | PrintIndent(); |
920 | stream << "while (" << PrintExpr(op->condition) << ") {\n" ; |
921 | int while_scope = BeginScope(); |
922 | PrintStmt(op->body); |
923 | this->EndScope(while_scope); |
924 | PrintIndent(); |
925 | stream << "}\n" ; |
926 | } |
927 | |
928 | void CodeGenC::VisitStmt_(const IfThenElseNode* op) { |
929 | std::string cond = PrintExpr(op->condition); |
930 | PrintIndent(); |
931 | if (cond[0] == '(' && cond[cond.length() - 1] == ')') { |
932 | stream << "if " << cond << " {\n" ; |
933 | } else { |
934 | stream << "if (" << cond << ") {\n" ; |
935 | } |
936 | int then_scope = BeginScope(); |
937 | PrintStmt(op->then_case); |
938 | this->EndScope(then_scope); |
939 | |
940 | if (op->else_case) { |
941 | PrintIndent(); |
942 | stream << "} else {\n" ; |
943 | int else_scope = BeginScope(); |
944 | PrintStmt(op->else_case.value()); |
945 | this->EndScope(else_scope); |
946 | } |
947 | PrintIndent(); |
948 | stream << "}\n" ; |
949 | } |
950 | |
951 | void CodeGenC::VisitStmt_(const SeqStmtNode* op) { |
952 | for (Stmt stmt : op->seq) { |
953 | PrintStmt(stmt); |
954 | } |
955 | } |
956 | |
957 | void CodeGenC::VisitStmt_(const EvaluateNode* op) { |
958 | if (is_const_int(op->value)) return; |
959 | const CallNode* call = op->value.as<CallNode>(); |
960 | if (call) { |
961 | if (call->op.same_as(builtin::tvm_storage_sync())) { |
962 | this->PrintStorageSync(call); |
963 | return; |
964 | } else if (call->op.same_as(builtin::tvm_struct_set())) { |
965 | ICHECK_EQ(call->args.size(), 4); |
966 | int kind = call->args[2].as<IntImmNode>()->value; |
967 | std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], kind); |
968 | std::string value = PrintExpr(call->args[3]); |
969 | std::string cast; |
970 | if (kind == builtin::kArrStrides) { |
971 | // cast void* to int64_t* |
972 | cast = call->args[3]->dtype.is_handle() ? "(int64_t*)" : "" ; |
973 | } else if (kind == builtin::kArrDeviceType) { |
974 | // cast int to enum |
975 | cast = "(DLDeviceType)" ; |
976 | } |
977 | this->PrintIndent(); |
978 | this->stream << ref << " = " << cast << value << ";\n" ; |
979 | return; |
980 | } |
981 | } |
982 | std::string vid = this->PrintExpr(op->value); |
983 | if (vid != "" ) { |
984 | this->PrintIndent(); |
985 | this->stream << vid << ";\n" ; |
986 | } |
987 | } |
988 | |
989 | void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { |
990 | ICHECK_GT(t.lanes(), 1); |
991 | if (t.bits() == 8 && (t.is_int() || t.is_uint())) { |
992 | if (i != 0) { |
993 | os << "|" ; |
994 | } |
995 | os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))" ; |
996 | return; |
997 | } |
998 | |
999 | if (i == 0) { |
1000 | os << "((" ; |
1001 | PrintType(t, os); |
1002 | os << ")(" ; |
1003 | } |
1004 | os << value; |
1005 | if (i != t.lanes() - 1) { |
1006 | os << "," ; |
1007 | } else { |
1008 | os << "))" ; |
1009 | } |
1010 | return; |
1011 | } |
1012 | |
1013 | void CodeGenC::PrintRestrict(const Var& v, std::ostream& os) { |
1014 | if (restrict_keyword_.length() != 0) { |
1015 | os << ' ' << restrict_keyword_; |
1016 | } |
1017 | } |
1018 | |
// Returns true iff `s` is wrapped in one pair of parentheses whose opening
// '(' is the first character and whose matching ')' is the last character.
// Examples: "(a)" and "((a + b))" -> true; "(a) && (b)", "a", "" -> false.
static bool CheckOutermostBracketMatch(const std::string& s) {
  // Quick reject: the string must at least look like "( ... )".
  if (s.empty() || s.front() != '(' || s.back() != ')') {
    return false;
  }
  int depth = 0;
  for (size_t pos = 0; pos < s.size(); ++pos) {
    const char ch = s[pos];
    if (ch == '(') {
      ++depth;
    } else if (ch == ')') {
      --depth;
    }
    // The leading '(' closes here; it matches only if this is the last char.
    if (depth == 0) {
      return pos + 1 == s.size();
    }
  }
  // The leading '(' was never closed.
  return false;
}
1036 | |
1037 | } // namespace codegen |
1038 | } // namespace tvm |
1039 | |