1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_c.cc
22 */
#include "codegen_c.h"

#include <tvm/arith/analyzer.h>

#include <cctype>
#include <iomanip>
#include <limits>

#include "../../arith/pattern_match.h"
#include "codegen_params.h"
32
33namespace tvm {
34namespace codegen {
35
36using namespace tir;
37
38void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; }
39
40void CodeGenC::InitFuncState(const PrimFunc& f) {
41 alloc_storage_scope_.clear();
42 handle_data_type_.clear();
43 CodeGenSourceBase::ClearFuncState();
44}
45
46void CodeGenC::ReserveKeywordsAsUnique() {
47 // skip the first underscore, so SSA variable starts from _1
48 name_supply_->ReserveName("_");
49 name_supply_->ReserveName("extern");
50 name_supply_->ReserveName("void");
51 name_supply_->ReserveName("int");
52 name_supply_->ReserveName("float");
53 name_supply_->ReserveName("double");
54 name_supply_->ReserveName("char");
55 name_supply_->ReserveName("unsigned");
56 name_supply_->ReserveName("short");
57 name_supply_->ReserveName("long");
58 name_supply_->ReserveName("if");
59 name_supply_->ReserveName("else");
60 name_supply_->ReserveName("switch");
61 name_supply_->ReserveName("case");
62 name_supply_->ReserveName("default");
63 name_supply_->ReserveName("for");
64 name_supply_->ReserveName("do");
65 name_supply_->ReserveName("while");
66 name_supply_->ReserveName("goto");
67 name_supply_->ReserveName("register");
68 name_supply_->ReserveName("continue");
69 name_supply_->ReserveName("break");
70 name_supply_->ReserveName("typedef");
71 name_supply_->ReserveName("struct");
72 name_supply_->ReserveName("enum");
73 name_supply_->ReserveName("union");
74 name_supply_->ReserveName("return");
75}
76
77void CodeGenC::AddFunction(const PrimFunc& f) {
78 // clear previous generated state.
79 this->InitFuncState(f);
80 // reserve keywords
81 ReserveKeywordsAsUnique();
82
83 auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
84 ICHECK(global_symbol.defined())
85 << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
86 bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
87
88 this->PrintFuncPrefix(stream);
89 this->PrintExtraAttrs(f);
90 this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
91
92 for (size_t i = 0; i < f->params.size(); ++i) {
93 tir::Var v = f->params[i];
94 std::string vid = AllocVarID(v.get());
95 if (i != 0) stream << ", ";
96 if (v.dtype().is_handle()) {
97 auto it = alloc_storage_scope_.find(v.get());
98 if (it != alloc_storage_scope_.end()) {
99 PrintStorageScope(it->second, stream);
100 }
101
102 PrintType(GetType(v), stream);
103 // Register handle data type
104 // TODO(tvm-team): consider simply keep type info in the
105 // type annotation(via a normalizing rewriting).
106 if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
107 if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
108 RegisterHandleType(v.get(), prim->dtype);
109 }
110 }
111
112 if (no_alias) {
113 PrintRestrict(v, stream);
114 }
115 } else {
116 PrintType(GetType(v), stream);
117 }
118 stream << ' ' << vid;
119 }
120 stream << ") {\n";
121 this->PreFunctionBody(f);
122 int func_scope = this->BeginScope();
123 this->PrintStmt(f->body);
124 this->PrintFinalReturn();
125 this->EndScope(func_scope);
126 this->PrintIndent();
127 this->stream << "}\n\n";
128}
129
130void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }
131
132void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
133
134void CodeGenC::PrintFinalReturn() {}
135
136std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
137
138void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*)
139 if (print_ssa_form_) {
140 std::ostringstream temp;
141 VisitExpr(n, temp);
142 os << SSAGetID(temp.str(), n.dtype());
143 } else {
144 VisitExpr(n, os);
145 }
146}
147
148static bool CheckOutermostBracketMatch(const std::string& s);
149
150void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) {
151 PrintType(t, stream);
152 stream << ' ' << target << " = ";
153 if (CheckOutermostBracketMatch(src)) {
154 stream << src.substr(1, src.length() - 2);
155 } else {
156 stream << src;
157 }
158 stream << ";\n";
159}
160
/*!
 * \brief Print a reference expression to element `index` of `buffer`, accessed as dtype `t`.
 *
 * The data pointer is cast (with `volatile` and, when IsScopePartOfType(), the storage
 * scope) whenever the handle's registered dtype differs from the buffer element dtype or
 * the buffer is volatile. The result takes one of three forms:
 *  - 4-bit types and 1-bit ints: pointer arithmetic on the raw handle with the index
 *    divided by `div_factor`;
 *  - `t` equal to the buffer element dtype: `buf[index]`;
 *  - any other `t`: the (possibly cast) pointer offset by `index`, reinterpreted as `t*`
 *    and dereferenced.
 * \return The expression text.
 */
std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) {
  const VarNode* buffer_var = buffer->data.get();
  std::ostringstream os;
  std::string vid = GetVarID(buffer_var);
  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }
  bool is_vol = IsVolatile(buffer_var);

  // Builds a C pointer-cast prefix "(<volatile> <scope> T*)" for the given pointee type.
  auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
    std::ostringstream ptr_os;
    ptr_os << "(";
    if (is_vol) {
      ptr_os << "volatile ";
    }
    if (!scope.empty() && IsScopePartOfType()) {
      PrintStorageScope(scope, ptr_os);
    }
    PrintType(pointed_to, ptr_os);
    ptr_os << "*)";
    return ptr_os.str();
  };

  DataType buffer_element_dtype = buffer->dtype;

  // Only cast the data pointer when its known type does not already match.
  std::string buffer_str = vid;
  if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
    std::stringstream temp;
    temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
    buffer_str = temp.str();
  }

  std::string index_str = PrintExpr(index);
  if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
    // This is a special case, because CodegenCUDA::PrintType()
    // returns "int" for bool and for 4-bit integers. In most cases,
    // we divide by the number of lanes to determine the index.
    // However, the backing type for scalar int4 and scalar bool is
    // int32. Therefore, we need to divide by the ratio of their
    // sizes in that case.
    // NOTE(review): this branch casts the raw `vid` rather than `buffer_str`; confirm that
    // is intended.
    int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();

    os << "*("
       << "(" << ptr_cast(t) << vid << ")"
       << " + " << index_str << " / " << div_factor << ")";
  } else if (t == buffer_element_dtype) {
    os << buffer_str << "[" << index_str << "]";
  } else {
    os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
  }

  return os.str();
}
216
217// Print a reference expression to a buffer.
218std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index,
219 int kind) {
220 if (kind < builtin::kArrKindBound_) {
221 std::ostringstream os;
222 os << "(((DLTensor*)";
223 this->PrintExpr(buffer, os);
224 os << ")";
225 if (kind == builtin::kArrAddr) {
226 os << " + ";
227 this->PrintExpr(index, os);
228 os << ")";
229 return os.str();
230 }
231 os << '[';
232 this->PrintExpr(index, os);
233 os << "].";
234 // other case: get fields.
235 switch (kind) {
236 case builtin::kArrData:
237 os << "data";
238 break;
239 case builtin::kArrShape:
240 os << "shape";
241 break;
242 case builtin::kArrStrides:
243 os << "strides";
244 break;
245 case builtin::kArrNDim:
246 os << "ndim";
247 break;
248 case builtin::kArrTypeCode:
249 os << "dtype.code";
250 break;
251 case builtin::kArrTypeBits:
252 os << "dtype.bits";
253 break;
254 case builtin::kArrByteOffset:
255 os << "byte_offset";
256 break;
257 case builtin::kArrTypeLanes:
258 os << "dtype.lanes";
259 break;
260 case builtin::kArrDeviceId:
261 os << "device.device_id";
262 break;
263 case builtin::kArrDeviceType:
264 os << "device.device_type";
265 break;
266 default:
267 LOG(FATAL) << "unknown field code";
268 }
269 os << ')';
270 return os.str();
271 } else {
272 ICHECK_LT(kind, builtin::kTVMValueKindBound_);
273 std::ostringstream os;
274 os << "(((TVMValue*)";
275 this->PrintExpr(buffer, os);
276 os << ")[" << index << "].";
277 if (t.is_handle()) {
278 os << "v_handle";
279 } else if (t.is_float()) {
280 os << "v_float64";
281 } else if (t.is_int()) {
282 os << "v_int64";
283 } else {
284 LOG(FATAL) << "Do not know how to handle type" << t;
285 }
286 os << ")";
287 return os.str();
288 }
289}
290
291bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const {
292 auto it = handle_data_type_.find(buf_var);
293 if (it == handle_data_type_.end()) return false;
294 return it->second == t;
295}
296
297void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) {
298 auto it = handle_data_type_.find(buf_var);
299 if (it == handle_data_type_.end()) {
300 handle_data_type_[buf_var] = t;
301 } else {
302 ICHECK(it->second == t) << "conflicting buf var type";
303 }
304}
305
306void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i,
307 std::ostream& os) { // NOLINT(*)
308 os << vec << ".s" << std::hex << i << std::dec;
309}
310
311void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i,
312 const std::string& value) {
313 this->PrintIndent();
314 stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec;
315}
316
317std::string CodeGenC::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) {
318 return GetBufferRef(t, buffer, base);
319}
320
321void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
322 const std::string& value) {
323 std::string ref = GetBufferRef(t, buffer, base);
324 this->PrintIndent();
325 stream << ref << " = " << value << ";\n";
326}
327
328std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) {
329 if (from == target) return value;
330 std::ostringstream os;
331 os << "((";
332 this->PrintType(target, os);
333 os << ")" << value << ")";
334 return os.str();
335}
336
337void CodeGenC::BindThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented"; }
338
339void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*)
340}
341
342void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
343 ICHECK_EQ(scope, "global");
344}
345
346inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
347 if (op->dtype == DataType::Int(32)) {
348 std::ostringstream temp;
349 temp << op->value;
350 p->MarkConst(temp.str());
351 os << temp.str();
352 } else {
353 os << "(";
354 p->PrintType(op->dtype, os);
355 os << ")" << op->value;
356 }
357}
358
359inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os,
360 CodeGenC* p) { // NOLINT(*)
361 if (dtype == DataType::UInt(32)) {
362 std::ostringstream temp;
363 temp << val << "U";
364 p->MarkConst(temp.str());
365 os << temp.str();
366 } else {
367 os << "(";
368 p->PrintType(dtype, os);
369 os << ")" << val;
370 }
371}
372
373inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
374 switch (op->dtype.bits()) {
375 case 64:
376 case 32: {
377 std::ostringstream temp;
378 temp << std::scientific << op->value;
379 if (op->dtype.bits() == 32) temp << 'f';
380 p->MarkConst(temp.str());
381 os << temp.str();
382 break;
383 }
384 case 16: {
385 os << '(';
386 p->PrintType(op->dtype, os);
387 os << ')' << std::scientific << op->value << 'f';
388 break;
389 }
390 default:
391 LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
392 }
393}
394
395void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*)
396 PrintConst(op, os, this);
397}
398
399void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
400 PrintConst(op, os, this);
401}
402void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
403 os << "\"" << op->value << "\"";
404}
405
406template <typename T>
407inline void PrintBinaryExpr(const T* op, const char* opstr,
408 std::ostream& os, // NOLINT(*)
409 CodeGenC* p) {
410 if (op->dtype.lanes() == 1) {
411 if (isalpha(opstr[0])) {
412 os << opstr << '(';
413 p->PrintExpr(op->a, os);
414 os << ", ";
415 p->PrintExpr(op->b, os);
416 os << ')';
417 } else {
418 os << '(';
419 p->PrintExpr(op->a, os);
420 os << ' ' << opstr << ' ';
421 p->PrintExpr(op->b, os);
422 os << ')';
423 }
424 } else {
425 p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
426 }
427}
428
429inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr,
430 std::ostream& os, // NOLINT(*)
431 CodeGenC* p) {
432 if (op->dtype.lanes() == 1) {
433 ICHECK_EQ(op->args.size(), 2U);
434 os << '(';
435 p->PrintExpr(op->args[0], os);
436 os << opstr;
437 p->PrintExpr(op->args[1], os);
438 os << ')';
439 } else {
440 p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os);
441 }
442}
443void CodeGenC::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*)
444 std::stringstream value;
445 this->PrintExpr(op->value, value);
446 os << CastFromTo(value.str(), op->value.dtype(), op->dtype);
447}
448void CodeGenC::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*)
449 os << GetVarID(op);
450}
451void CodeGenC::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*)
452 PrintBinaryExpr(op, "+", os, this);
453}
454void CodeGenC::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*)
455 PrintBinaryExpr(op, "-", os, this);
456}
457void CodeGenC::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*)
458 PrintBinaryExpr(op, "*", os, this);
459}
460void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*)
461 PrintBinaryExpr(op, "/", os, this);
462}
463void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*)
464 if (op->dtype.is_int() || op->dtype.is_uint()) {
465 PrintBinaryExpr(op, "%", os, this);
466 } else {
467 ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got "
468 << op->dtype;
469 if (op->dtype.bits() == 32) {
470 PrintBinaryExpr(op, "fmodf", os, this);
471 } else if (op->dtype.bits() == 64) {
472 PrintBinaryExpr(op, "fmod", os, this);
473 } else {
474 ICHECK(false)
475 << "Non single or double precision floating point in Mod, expected 32 or 64 bits but got "
476 << op->dtype.bits() << " bits.";
477 }
478 }
479}
480void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
481 PrintBinaryExpr(op, "min", os, this);
482}
483void CodeGenC::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
484 PrintBinaryExpr(op, "max", os, this);
485}
486void CodeGenC::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*)
487 PrintBinaryExpr(op, "==", os, this);
488}
489void CodeGenC::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*)
490 PrintBinaryExpr(op, "!=", os, this);
491}
492void CodeGenC::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*)
493 PrintBinaryExpr(op, "<", os, this);
494}
495void CodeGenC::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*)
496 PrintBinaryExpr(op, "<=", os, this);
497}
498void CodeGenC::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*)
499 PrintBinaryExpr(op, ">", os, this);
500}
501void CodeGenC::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*)
502 PrintBinaryExpr(op, ">=", os, this);
503}
504void CodeGenC::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*)
505 PrintBinaryExpr(op, "&&", os, this);
506}
507void CodeGenC::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*)
508 PrintBinaryExpr(op, "||", os, this);
509}
510void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*)
511 os << '!';
512 PrintExpr(op->a, os);
513}
514
515void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
516 bool skip_first_arg, std::ostream& os) { // NOLINT(*)
517 os << global_symbol << "(";
518 for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
519 this->PrintExpr(args[i], os);
520 if (i < args.size() - 1) {
521 os << ", ";
522 }
523 }
524 os << ")";
525}
526
527void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
528 if (auto* ptr_op = op->op.as<OpNode>()) {
529 auto call_op = GetRef<Op>(ptr_op);
530
531 if (op->op.same_as(builtin::tvm_check_return())) {
532 const CallNode* call = op->args[2].as<CallNode>();
533 os << "if (";
534 VisitExpr_(call, os);
535 os << " != ";
536 PrintExpr(op->args[0], os);
537 os << " ) return ";
538 PrintExpr(op->args[1], os);
539 } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
540 ICHECK_GE(op->args.size(), 1U);
541 auto func = Downcast<StringImm>(op->args[0]);
542 this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
543 this->GenerateForwardFunctionDeclarations(func->value, op->args);
544 } else if (op_attr_global_symbol_.count(call_op)) {
545 // call extern if the op itself have a global symbol.
546 this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
547 op->args, false, os);
548 } else if (op->op.same_as(builtin::bitwise_and())) {
549 PrintBinaryIntrinsic(op, " & ", os, this);
550 } else if (op->op.same_as(builtin::large_uint_imm())) {
551 ICHECK_EQ(op->args.size(), 2U);
552 uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
553 uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
554 uint64_t val = (high << 32U) | low;
555 PrintUIntConst(op->dtype, val, os, this);
556 } else if (op->op.same_as(builtin::bitwise_xor())) {
557 PrintBinaryIntrinsic(op, " ^ ", os, this);
558 } else if (op->op.same_as(builtin::bitwise_or())) {
559 PrintBinaryIntrinsic(op, " | ", os, this);
560 } else if (op->op.same_as(builtin::bitwise_not())) {
561 ICHECK_EQ(op->args.size(), 1U);
562 os << "(~";
563 this->PrintExpr(op->args[0], os);
564 os << ')';
565 } else if (op->op.same_as(builtin::shift_left())) {
566 PrintBinaryIntrinsic(op, " << ", os, this);
567 } else if (op->op.same_as(builtin::shift_right())) {
568 PrintBinaryIntrinsic(op, " >> ", os, this);
569 } else if (op->op.same_as(builtin::if_then_else())) {
570 os << "(";
571 PrintExpr(op->args[0], os);
572 os << " ? ";
573 PrintExpr(op->args[1], os);
574 os << " : ";
575 PrintExpr(op->args[2], os);
576 os << ")";
577 } else if (op->op.same_as(builtin::address_of())) {
578 const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
579 ICHECK(op->args.size() == 1 && load);
580 ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations.";
581 os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))";
582 } else if (op->op.same_as(builtin::tvm_struct_get())) {
583 ICHECK_EQ(op->args.size(), 3U);
584 os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
585 } else if (op->op.same_as(builtin::isnullptr())) {
586 ICHECK_EQ(op->args.size(), 1U);
587 os << "(";
588 this->PrintExpr(op->args[0], os);
589 os << " == NULL)";
590 } else if (op->op.same_as(builtin::reinterpret())) {
591 int ssa_scope = BeginScope();
592 std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
593 os << "(*(";
594 this->PrintType(op->dtype, os);
595 os << " *)(&(" << rhs << ")))";
596 EndScope(ssa_scope);
597 } else if (op->op.same_as(builtin::isnan())) {
598 os << "(";
599 this->PrintExpr(op->args[0], os);
600 os << " != ";
601 this->PrintExpr(op->args[0], os);
602 os << ")";
603 } else if (op->op.same_as(builtin::lookup_param())) {
604 ICHECK_EQ(op->args.size(), 1);
605 const StringImmNode* str = op->args[0].as<StringImmNode>();
606 ICHECK(str != nullptr);
607 os << "__tvm_param__" << str->value;
608 } else {
609 LOG(FATAL) << "Unresolved call " << op->op;
610 }
611 } else {
612 ICHECK(op->op.as<GlobalVarNode>());
613 LOG(FATAL) << "Do not yet support cross function call";
614 }
615}
616
617void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
618 std::ostream& os) { // NOLINT(*)
619 if (isalpha(op[0])) {
620 os << op << "(";
621 this->PrintExpr(lhs, os);
622 os << ", ";
623 this->PrintExpr(rhs, os);
624 os << ")";
625 } else {
626 os << "(";
627 this->PrintExpr(lhs, os);
628 os << ' ' << op << ' ';
629 this->PrintExpr(rhs, os);
630 os << ")";
631 }
632}
633
634void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
635 std::string symbol_name = op->buffer_var->name_hint;
636 int64_t num_elements = 1;
637 const auto& data = op->data.value();
638
639 for (int64_t dim : data.Shape()) {
640 num_elements *= dim;
641 }
642
643 decl_stream << "\n"
644 << "#ifdef __cplusplus\n"
645 << "extern \"C\" {\n"
646 << "#endif\n"
647 << "static const ";
648
649 PrintType(data.DataType(), decl_stream);
650
651 // Allocate the global static variable
652 decl_stream << " __attribute__((section(\".rodata.tvm\"), "
653 << "aligned(" << constants_byte_alignment_->value << "))) " << symbol_name << "["
654 << num_elements << "] = {\n";
655 NDArrayDataToC(data, 4, decl_stream);
656
657 decl_stream << "};\n"
658 << "#ifdef __cplusplus\n"
659 << "} // extern \"C\"\n"
660 << "#endif\n";
661 var_idmap_[op->buffer_var.operator->()] = symbol_name;
662 this->PrintStmt(op->body);
663}
664
665void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); }
666
667void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
668 LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead.";
669}
670
/*!
 * \brief Print a load from a flat (1-D) buffer.
 *
 * The strategy depends on the load's lane count relative to the buffer element's:
 *  - equal lanes: a single reference built by GetBufferRef;
 *  - the index is a unit-stride Ramp whose base is provably a multiple of the lane count
 *    (checked via the modular set of the base): one vector load via GetVecLoad;
 *  - otherwise: the index is materialized with SSAGetID, each lane is loaded separately,
 *    and PrintVecElemLoadExpr combines the per-lane values.
 *
 * NOTE(review): unlike GetBufferRef, the per-lane cast in the fallback prints the storage
 * scope without consulting IsScopePartOfType() and does not apply `volatile`; confirm
 * this is intended.
 */
void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) {  // NOLINT(*)
  ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported.";

  DataType value_dtype = op->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  DataType element_dtype = op->buffer->dtype;

  int lanes = op->dtype.lanes();
  // Same lane count as the buffer element: a direct element reference suffices.
  if (value_dtype.lanes() == element_dtype.lanes()) {
    std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
    HandleVolatileLoads(ref, op, os);
  } else {
    bool can_vector_load = false;
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
      const RampNode* ramp = index.as<RampNode>();
      ICHECK(ramp);
      arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
      // The condition: {k * coeff + base} divisible by the alignment for any k
      if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() == 0) {
        can_vector_load = true;
      }
    }

    if (can_vector_load) {
      std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
      HandleVolatileLoads(ref, op, os);
    } else {
      // Scalarized fallback: gather each lane through the corresponding index lane.
      std::ostringstream svalue_expr;
      std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
      std::string vid = GetVarID(buffer_var.get());
      DataType elem_type = op->dtype.element_of();
      for (int i = 0; i < lanes; ++i) {
        std::ostringstream value_temp;
        // Cast the data pointer to the scalar element type when its known type differs.
        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
          value_temp << "((";
          if (buffer_var.get()->dtype.is_handle()) {
            auto it = alloc_storage_scope_.find(buffer_var.get());
            if (it != alloc_storage_scope_.end()) {
              PrintStorageScope(it->second, value_temp);
            }
          }
          PrintType(elem_type, value_temp);
          value_temp << "*)" << vid << ')';
        } else {
          value_temp << vid;
        }
        value_temp << '[';
        PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
        value_temp << ']';
        PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
      }
      os << svalue_expr.str();
    }
  }
}
729
730void CodeGenC::VisitStmt_(const StoreNode* op) {
731 LOG(FATAL) << "Unexpected deprecated StoreNode. Use BufferStoreNode instead.";
732}
733
/*!
 * \brief Print a store to a flat (1-D) buffer.
 *
 *  - Value lanes equal to the buffer element lanes: "<GetBufferRef(...)> = <value>;".
 *  - The index is a unit-stride Ramp: delegated to PrintVecStore at the ramp base.
 *  - Otherwise: inside a fresh scope, the index and value vectors are each materialized
 *    with SSAGetID and one scalar assignment is emitted per lane.
 *
 * NOTE(review): unlike the BufferLoad path, the unit-stride Ramp case does not check the
 * alignment of the ramp base before using a vector store; confirm callers guarantee it.
 */
void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
  ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";

  DataType value_dtype = op->value.dtype();
  DataType element_dtype = op->buffer->dtype;
  PrimExpr index_expr = op->indices[0];
  Var buffer_var = op->buffer->data;

  if (value_dtype.lanes() == element_dtype.lanes()) {
    // The value and reference are printed before indenting, since printing them may
    // emit statements of their own.
    std::string value = this->PrintExpr(op->value);
    std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr);
    this->PrintIndent();
    stream << ref << " = " << value << ";\n";
  } else {
    arith::PVar<PrimExpr> base;

    if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) {
      std::string value = this->PrintExpr(op->value);
      this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
    } else {
      // The assignment below introduces side-effect, and the resulting value cannot
      // be reused across multiple expression, thus a new scope is needed
      int vec_scope = BeginScope();

      // store elements separately
      std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype());
      std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
      std::string vid = GetVarID(buffer_var.get());
      for (int i = 0; i < value_dtype.lanes(); ++i) {
        this->PrintIndent();
        DataType elem_type = value_dtype.element_of();
        // Cast the data pointer to the scalar element type when its known type differs.
        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
          stream << "((";
          if (buffer_var.get()->dtype.is_handle()) {
            auto it = alloc_storage_scope_.find(buffer_var.get());
            if (it != alloc_storage_scope_.end()) {
              PrintStorageScope(it->second, stream);
            }
          }
          PrintType(elem_type, stream);
          stream << "*)" << vid << ')';
        } else {
          stream << vid;
        }
        stream << '[';
        PrintVecElemLoad(index, index_expr.dtype(), i, stream);
        stream << "] = ";
        PrintVecElemLoad(value, op->value.dtype(), i, stream);
        stream << ";\n";
      }
      EndScope(vec_scope);
    }
  }
}
788
789void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
790 auto it = let_binding_.find(op->var);
791 if (it != let_binding_.end()) {
792 ICHECK(deep_equal_(it->second->value, op->value))
793 << "Let cannot bind the same var to two different values";
794 } else {
795 let_binding_[op->var] = op;
796 }
797 std::string value = PrintExpr(op->value);
798 var_idmap_[op->var.get()] = value;
799 os << PrintExpr(op->body);
800}
801
802void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
803 // constraint of current logic
804 ICHECK_EQ(op->base.dtype(), DataType::Int(32));
805 os << "((int" << op->lanes << ")(";
806 for (int i = 0; i < op->lanes; i++) {
807 os << "(" << PrintExpr(op->base) << ")"
808 << "+(" << PrintExpr(op->stride) << "*" << i << ")";
809 if (i != op->lanes - 1) os << ", ";
810 }
811 os << "))";
812}
813
814void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
815 LOG(FATAL) << "Shuffle: not supported ";
816}
817
818void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
819 LOG(FATAL) << "Broadcast: not supported ";
820}
821
822void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
823 os << "(";
824 PrintExpr(op->condition, os);
825 os << " ? ";
826 PrintExpr(op->true_value, os);
827 os << " : ";
828 PrintExpr(op->false_value, os);
829 os << ")";
830}
831
832void CodeGenC::VisitStmt_(const LetStmtNode* op) {
833 std::string value = PrintExpr(op->value);
834 if (print_ssa_form_) {
835 ICHECK(!var_idmap_.count(op->var.get()));
836 var_idmap_[op->var.get()] = value;
837 } else {
838 PrintIndent();
839 if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) {
840 PrintType(handle_data_type_.at(op->var.get()), stream);
841 stream << "* " << AllocVarID(op->var.get()) << " = (";
842 PrintType(handle_data_type_.at(op->var.get()), stream);
843 stream << "*)" << value << ";\n";
844 } else {
845 PrintType(op->var.dtype(), this->stream);
846 this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n";
847 }
848 }
849 PrintStmt(op->body);
850}
851
852void CodeGenC::VisitStmt_(const AllocateNode* op) {
853 ICHECK(!is_zero(op->condition));
854 std::string vid = AllocVarID(op->buffer_var.get());
855
856 this->PrintIndent();
857 size_t constant_size = op->ConstantAllocationSize();
858 ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
859
860 auto scope = GetPtrStorageScope(op->buffer_var);
861 alloc_storage_scope_[op->buffer_var.get()] = scope;
862 PrintStorageScope(scope, stream);
863
864 PrintType(op->dtype, stream);
865 stream << ' ' << vid << '[' << constant_size << "];\n";
866
867 RegisterHandleType(op->buffer_var.get(), op->dtype);
868 this->PrintStmt(op->body);
869}
870
871void CodeGenC::VisitStmt_(const AttrStmtNode* op) {
872 if (op->attr_key == tir::attr::thread_extent) {
873 IterVar iv = Downcast<IterVar>(op->node);
874 if (iv->thread_tag.length() != 0) {
875 if (!var_idmap_.count(iv->var.get())) {
876 BindThreadIndex(iv);
877 }
878 }
879 } else if (op->attr_key == tir::attr::volatile_scope) {
880 const VarNode* v = op->node.as<VarNode>();
881 ICHECK(v);
882 volatile_buf_.insert(v);
883 } else if (op->attr_key == tir::attr::pragma_import_c) {
884 const StringImmNode* value = op->value.as<StringImmNode>();
885 ICHECK(value != nullptr);
886 decl_stream << value->value;
887 }
888 this->PrintStmt(op->body);
889}
890
891void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
892 std::string cond = PrintExpr(op->condition);
893 PrintIndent();
894 if (const auto* str = op->message.as<StringImmNode>()) {
895 // GLOG style check
896 stream << "ICHECK(" << cond << ") << \"" << str->value << "\";\n";
897 } else {
898 stream << "assert(" << cond << ");\n";
899 }
900 this->PrintStmt(op->body);
901}
902
903void CodeGenC::VisitStmt_(const ForNode* op) {
904 std::string extent = PrintExpr(op->extent);
905 PrintIndent();
906 std::string vid = AllocVarID(op->loop_var.get());
907 ICHECK(is_zero(op->min));
908 stream << "for (";
909 PrintType(op->loop_var.dtype(), stream);
910 stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n";
911 int for_scope = BeginScope();
912 PrintStmt(op->body);
913 this->EndScope(for_scope);
914 PrintIndent();
915 stream << "}\n";
916}
917
918void CodeGenC::VisitStmt_(const WhileNode* op) {
919 PrintIndent();
920 stream << "while (" << PrintExpr(op->condition) << ") {\n";
921 int while_scope = BeginScope();
922 PrintStmt(op->body);
923 this->EndScope(while_scope);
924 PrintIndent();
925 stream << "}\n";
926}
927
928void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
929 std::string cond = PrintExpr(op->condition);
930 PrintIndent();
931 if (cond[0] == '(' && cond[cond.length() - 1] == ')') {
932 stream << "if " << cond << " {\n";
933 } else {
934 stream << "if (" << cond << ") {\n";
935 }
936 int then_scope = BeginScope();
937 PrintStmt(op->then_case);
938 this->EndScope(then_scope);
939
940 if (op->else_case) {
941 PrintIndent();
942 stream << "} else {\n";
943 int else_scope = BeginScope();
944 PrintStmt(op->else_case.value());
945 this->EndScope(else_scope);
946 }
947 PrintIndent();
948 stream << "}\n";
949}
950
951void CodeGenC::VisitStmt_(const SeqStmtNode* op) {
952 for (Stmt stmt : op->seq) {
953 PrintStmt(stmt);
954 }
955}
956
957void CodeGenC::VisitStmt_(const EvaluateNode* op) {
958 if (is_const_int(op->value)) return;
959 const CallNode* call = op->value.as<CallNode>();
960 if (call) {
961 if (call->op.same_as(builtin::tvm_storage_sync())) {
962 this->PrintStorageSync(call);
963 return;
964 } else if (call->op.same_as(builtin::tvm_struct_set())) {
965 ICHECK_EQ(call->args.size(), 4);
966 int kind = call->args[2].as<IntImmNode>()->value;
967 std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], kind);
968 std::string value = PrintExpr(call->args[3]);
969 std::string cast;
970 if (kind == builtin::kArrStrides) {
971 // cast void* to int64_t*
972 cast = call->args[3]->dtype.is_handle() ? "(int64_t*)" : "";
973 } else if (kind == builtin::kArrDeviceType) {
974 // cast int to enum
975 cast = "(DLDeviceType)";
976 }
977 this->PrintIndent();
978 this->stream << ref << " = " << cast << value << ";\n";
979 return;
980 }
981 }
982 std::string vid = this->PrintExpr(op->value);
983 if (vid != "") {
984 this->PrintIndent();
985 this->stream << vid << ";\n";
986 }
987}
988
989void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) {
990 ICHECK_GT(t.lanes(), 1);
991 if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
992 if (i != 0) {
993 os << "|";
994 }
995 os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
996 return;
997 }
998
999 if (i == 0) {
1000 os << "((";
1001 PrintType(t, os);
1002 os << ")(";
1003 }
1004 os << value;
1005 if (i != t.lanes() - 1) {
1006 os << ",";
1007 } else {
1008 os << "))";
1009 }
1010 return;
1011}
1012
1013void CodeGenC::PrintRestrict(const Var& v, std::ostream& os) {
1014 if (restrict_keyword_.length() != 0) {
1015 os << ' ' << restrict_keyword_;
1016 }
1017}
1018
/*!
 * \brief Whether `s` is entirely enclosed by one matching pair of parentheses.
 *
 * "(a + b)" and "((a) * (b))" qualify; "(a) + (b)", "(a", "" and "a" do not.
 */
static bool CheckOutermostBracketMatch(const std::string& s) {
  // A qualifying string needs at least the two enclosing characters.
  if (s.size() < 2 || s.front() != '(' || s.back() != ')') return false;
  int depth = 0;
  for (std::size_t pos = 0; pos < s.size(); ++pos) {
    const char c = s[pos];
    depth += static_cast<int>(c == '(') - static_cast<int>(c == ')');
    // The first point where the opening '(' is closed must be the final character.
    if (depth == 0) return pos + 1 == s.size();
  }
  return false;
}
1036
1037} // namespace codegen
1038} // namespace tvm
1039