/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_hybrid.cc
 */
#include "codegen_hybrid.h"

#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>

#include <cctype>
#include <cstring>
#include <iomanip>

namespace tvm {
namespace contrib {

using runtime::TVMArgs;
using runtime::TVMRetValue;

using namespace tir;

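// Replaces every '.' in a name with '_', so that thread tags such as
// "threadIdx.x" become valid hybrid-script identifiers (e.g. "threadIdx_x").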
std::string dot_to_underscore(std::string s) {
  for (auto& ch : s)
    if (ch == '.') ch = '_';
  return s;
}

std::string CodeGenHybrid::Finish() { return stream.str(); }

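// Prints the hybrid-script name of a scalar dtype, e.g. "int32", "uint8",
// "float16" or "bfloat16"; only the bit widths checked below are supported.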
void CodeGenHybrid::PrintType(DataType t, std::ostream& os) {
  if (t.is_float()) {
    os << "float";
    ICHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
  } else if (t.is_int()) {
    os << "int";
    ICHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
  } else if (t.is_bfloat16()) {
    os << "bfloat";
    ICHECK(t.bits() == 16);
  } else {
    ICHECK(t.is_uint()) << "Unsupported type " << t;
    os << "uint";
    ICHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
  }
  os << t.bits();
}

void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
  os << op->value;
}

void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
  PrintType(op->dtype, os);
  os << "(" << std::setprecision(20) << op->value << ")";
}
void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) {  // NOLINT(*)
  os << "'" << op->value << "'";
}

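// Prints a binary expression. Alphabetic operator names are emitted in call
// form and symbolic ones as infix, roughly:
//   min stays min(a, b), while a && b becomes (a and b).
// Vectorized (multi-lane) operands are not supported.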
template <typename T>
inline void PrintBinaryExpr(const T* op, const char* opstr,
                            std::ostream& os,  // NOLINT(*)
                            CodeGenHybrid* p) {
  ICHECK(op->dtype.lanes() == 1) << "vec bin op not implemented";
  if (isalpha(opstr[0])) {
    os << opstr << '(';
    p->PrintExpr(op->a, os);
    os << ", ";
    p->PrintExpr(op->b, os);
    os << ')';
  } else {
    os << '(';
    p->PrintExpr(op->a, os);
    if (!strcmp(opstr, "&&")) opstr = "and";
    if (!strcmp(opstr, "||")) opstr = "or";
    os << ' ' << opstr << ' ';
    p->PrintExpr(op->b, os);
    os << ')';
  }
}

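// Prints a two-argument intrinsic call (bitwise ops and shifts) as an infix
// expression, e.g. bitwise_and(a, b) -> (a&b).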
inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr,
                                 std::ostream& os,  // NOLINT(*)
                                 CodeGenHybrid* p) {
  ICHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented";
  ICHECK_EQ(op->args.size(), 2U);
  os << '(';
  p->PrintExpr(op->args[0], os);
  os << opstr;
  p->PrintExpr(op->args[1], os);
  os << ')';
}

void CodeGenHybrid::VisitExpr_(const CastNode* op, std::ostream& os) {  // NOLINT(*)
  if (op->dtype == op->value.dtype()) {
    PrintExpr(op->value, os);
  } else {
    PrintType(op->dtype, os);
    os << "(";
    PrintExpr(op->value, os);
    os << ")";
  }
}

void CodeGenHybrid::VisitExpr_(const VarNode* op, std::ostream& os) {  // NOLINT(*)
  os << GetVarID(op);
}
void CodeGenHybrid::VisitExpr_(const AddNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "+", os, this);
}
void CodeGenHybrid::VisitExpr_(const SubNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "-", os, this);
}
void CodeGenHybrid::VisitExpr_(const MulNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "*", os, this);
}

void CodeGenHybrid::VisitExpr_(const DivNode* op, std::ostream& os) {  // NOLINT(*)
  if (op->dtype.is_int())
    PrintBinaryExpr(op, "//", os, this);
  else
    PrintBinaryExpr(op, "/", os, this);
}

void CodeGenHybrid::VisitExpr_(const FloorDivNode* op, std::ostream& os) {  // NOLINT(*)
  if (op->dtype.is_int())
    PrintBinaryExpr(op, "//", os, this);
  else
    PrintBinaryExpr(op, "/", os, this);
}

void CodeGenHybrid::VisitExpr_(const ModNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "%", os, this);
}

void CodeGenHybrid::VisitExpr_(const FloorModNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "%", os, this);
}
void CodeGenHybrid::VisitExpr_(const MinNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "min", os, this);
}
void CodeGenHybrid::VisitExpr_(const MaxNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "max", os, this);
}
void CodeGenHybrid::VisitExpr_(const EQNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "==", os, this);
}
void CodeGenHybrid::VisitExpr_(const NENode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "!=", os, this);
}
void CodeGenHybrid::VisitExpr_(const LTNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "<", os, this);
}
void CodeGenHybrid::VisitExpr_(const LENode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "<=", os, this);
}
void CodeGenHybrid::VisitExpr_(const GTNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, ">", os, this);
}
void CodeGenHybrid::VisitExpr_(const GENode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, ">=", os, this);
}
void CodeGenHybrid::VisitExpr_(const AndNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "&&", os, this);
}
void CodeGenHybrid::VisitExpr_(const OrNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "||", os, this);
}
void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) {  // NOLINT(*)
  os << "not ";
  PrintExpr(op->a, os);
}

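// Prints a read from a tensor as a subscript expression, e.g. A[i, j].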
void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) {  // NOLINT(*)
  auto tensor = Downcast<Tensor>(op->producer);

  os << GetTensorID(tensor);
  os << "[";
  for (size_t i = 0; i < op->indices.size(); ++i) {
    if (i) os << ", ";
    std::stringstream idx;
    PrintExpr(op->indices[i], idx);
    os << idx.str();
  }
  os << "]";
}
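// Prints calls: bitwise builtins become infix operators, if_then_else becomes
// a conditional expression ("t if cond else f"), extern calls keep their
// function name, and any remaining "tir.*" op is printed with the "tir."
// prefix dropped, e.g. tir.exp(x) -> exp(x).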
void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
  if (op->op.same_as(builtin::bitwise_and())) {
    PrintBinaryIntrinsic(op, "&", os, this);
  } else if (op->op.same_as(builtin::bitwise_xor())) {
    PrintBinaryIntrinsic(op, "^", os, this);
  } else if (op->op.same_as(builtin::bitwise_or())) {
    PrintBinaryIntrinsic(op, "|", os, this);
  } else if (op->op.same_as(builtin::shift_left())) {
    PrintBinaryIntrinsic(op, "<<", os, this);
  } else if (op->op.same_as(builtin::shift_right())) {
    PrintBinaryIntrinsic(op, ">>", os, this);
  } else if (op->op.same_as(builtin::bitwise_not())) {
    ICHECK_EQ(op->args.size(), 1U);
    os << "(~";
    PrintExpr(op->args[0], os);
    os << ')';
  } else if (op->op.same_as(builtin::if_then_else())) {
    PrintExpr(op->args[1], os);
    os << " if ";
    PrintExpr(op->args[0], os);
    os << " else ";
    PrintExpr(op->args[2], os);
  } else if (op->op.same_as(builtin::call_pure_extern()) ||
             op->op.same_as(builtin::call_extern())) {
    StringImm fname = Downcast<StringImm>(op->args[0]);
    os << fname << "(";
    for (size_t i = 1; i < op->args.size(); i++) {
      PrintExpr(op->args[i], os);
      if (i < op->args.size() - 1) {
        os << ", ";
      }
    }
    os << ")";
  } else {
    auto* ptr_op = op->op.as<OpNode>();
    ICHECK(ptr_op != nullptr);
    std::string name = ptr_op->name;
    ICHECK_EQ(name.compare(0, 4, "tir."), 0);
    os << name.substr(4) << "(";
    for (size_t i = 0; i < op->args.size(); i++) {
      PrintExpr(op->args[i], os);
      if (i < op->args.size() - 1) {
        os << ", ";
      }
    }
    os << ")";
  }
}

void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
  LOG(FATAL) << "Phase 0 has no Load(s)!";
}

void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; }

void CodeGenHybrid::VisitExpr_(const BufferLoadNode* op, std::ostream& os) {  // NOLINT(*)
  LOG(FATAL) << "Phase 0 has no BufferLoad(s)!";
}

void CodeGenHybrid::VisitStmt_(const BufferStoreNode* op) {
  LOG(FATAL) << "Phase 0 has no BufferStore(s)!";
}

void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
  LOG(FATAL) << "Phase 0 has no Let(s)!";
}

void CodeGenHybrid::VisitStmt_(const AllocateNode* op) {
  LOG(FATAL) << "Phase 0 has no Allocate(s)!";
}

void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) {  // NOLINT(*)
  LOG(FATAL) << "Ramp is not supported yet";
}

void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
  LOG(FATAL) << "Broadcast is not supported yet";
}

void CodeGenHybrid::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
  PrintExpr(op->true_value, os);
  os << " if ";
  PrintExpr(op->condition, os);
  os << " else ";
  PrintExpr(op->false_value, os);
  os << "\n";
}

void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) {
  std::string value = PrintExpr(op->value);
  stream << GetVarID(op->var.get()) << " = " << value << ";\n";
  PrintStmt(op->body);
}

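// Loops bound to thread axes are printed as "bind" loops; all other
// attributes are ignored for now and only their body is emitted. The emitted
// form is roughly:
//   for threadIdx_x in bind('threadIdx.x', extent):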
void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == tir::attr::thread_extent) {
    auto iter_var = op->node.as<IterVarNode>();
    ICHECK(iter_var);
    binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint);
    PrintIndent();
    stream << "for " << binds_[iter_var->var.get()] << " in bind('" << iter_var->var->name_hint
           << "', ";
    PrintExpr(op->value, stream);
    stream << "):\n";
    indent_ += tab_;
    PrintStmt(op->body);
    indent_ -= tab_;
  } else {
    // For now we ignore the unsupported AttrStmt
    PrintStmt(op->body);
  }
}

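// A realize node that carries a storage scope becomes an explicit allocate()
// call, roughly:
//   B = allocate((n, m), 'float32', 'shared')
// The extra trailing comma for 1-D shapes keeps the printed shape a tuple.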
void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) {
  auto tensor = Downcast<Tensor>(op->producer);
  if (!op->storage_scope.empty()) {
    PrintIndent();
    stream << GetTensorID(tensor) << " = allocate((";
    for (size_t i = 0; i < op->bounds.size(); ++i) {
      if (i) stream << ", ";
      stream << PrintExpr(op->bounds[i]->extent);
    }
    if (op->bounds.size() == 1) stream << ", ";
    stream << "), '";
    PrintType(tensor->dtype, stream);
    stream << "', '";
    stream << op->storage_scope << "')\n";
  }
  PrintStmt(op->body);
}

void CodeGenHybrid::VisitStmt_(const AssertStmtNode* op) {
  PrintIndent();
  stream << "assert ";
  PrintExpr(op->condition, stream);
  stream << ", ";
  PrintExpr(op->message, stream);
  stream << "\n";
  PrintStmt(op->body);
}

void CodeGenHybrid::VisitStmt_(const ProducerStoreNode* op) {
  auto tensor = Downcast<Tensor>(op->producer);
  PrintIndent();
  stream << GetTensorID(tensor);
  stream << "[";
  for (size_t i = 0; i < op->indices.size(); ++i) {
    if (i) stream << ", ";
    PrintExpr(op->indices[i], stream);
  }
  stream << "] = ";
  PrintExpr(op->value, stream);
  stream << "\n";
}

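// A serial loop is printed as a plain range loop, e.g. "for i in range(n):".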
void CodeGenHybrid::VisitStmt_(const ForNode* op) {
  std::string extent = PrintExpr(op->extent);
  PrintIndent();
  std::string vid = GetVarID(op->loop_var.get());
  stream << "for " << vid << " in "
         << "range(" << extent << "):\n";
  indent_ += tab_;
  PrintStmt(op->body);
  indent_ -= tab_;
}

bool is_noop(const Stmt& stmt) {
  if (!stmt.defined()) return true;
  if (auto eval = stmt.as<EvaluateNode>()) return is_const_int(eval->value);
  return false;
}

void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) {
  std::string cond = PrintExpr(op->condition);
  PrintIndent();
  stream << "if " << cond << ":\n";
  indent_ += tab_;
  PrintStmt(op->then_case);
  indent_ -= tab_;

  if (op->else_case && !is_noop(op->else_case.value())) {
    PrintIndent();
    stream << "else:\n";
    indent_ += tab_;
    PrintStmt(op->else_case.value());
    indent_ -= tab_;
  }
}

void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
  for (Stmt stmt : op->seq) {
    PrintStmt(stmt);
  }
}

void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
  if (is_const_int(op->value)) return;
  std::string str = PrintExpr(op->value);
  if (!str.empty()) stream << str << "\n";
}

void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); }

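// Returns the printed name of a variable: thread axes registered by
// VisitStmt_(AttrStmtNode) use their sanitized tag from binds_, everything
// else gets a fresh name derived from the variable's name hint.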
std::string CodeGenHybrid::GetVarID(const VarNode* v) {
  if (binds_.count(v)) return binds_[v];
  auto key = std::make_pair(static_cast<const Object*>(v), 0);
  if (id_map_.count(key)) {
    return id_map_[key];
  }
  return id_map_[key] = ids_allocated->FreshName(v->name_hint);
}

std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) {
  auto key = std::make_pair(tensor->op.get(), tensor->value_index);
  if (id_map_.count(key)) {
    return id_map_[key];
  }
  std::string name_hint = tensor->op->name;
  if (tensor->op->num_outputs() > 1) {
    name_hint += "_v" + std::to_string(tensor->value_index);
  }
  return id_map_[key] = ids_allocated->FreshName(name_hint);
}

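// Reserves hybrid-script keywords, intrinsic names, dtype names and thread
// tags so that FreshName never hands them out as identifiers.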
void CodeGenHybrid::ReserveKeywords() {
  ids_allocated->ReserveName("def");
  ids_allocated->ReserveName("for");
  ids_allocated->ReserveName("in");
  ids_allocated->ReserveName("range");
  ids_allocated->ReserveName("True");
  ids_allocated->ReserveName("False");
  ids_allocated->ReserveName("unroll");
  ids_allocated->ReserveName("const_range");
  ids_allocated->ReserveName("parallel");
  ids_allocated->ReserveName("vectorize");
  ids_allocated->ReserveName("bind");
  ids_allocated->ReserveName("threadIdx.x");
  ids_allocated->ReserveName("threadIdx.y");
  ids_allocated->ReserveName("threadIdx.z");
  ids_allocated->ReserveName("blockIdx.x");
  ids_allocated->ReserveName("blockIdx.y");
  ids_allocated->ReserveName("blockIdx.z");
  ids_allocated->ReserveName("vthread");
  ids_allocated->ReserveName("allocate");
  ids_allocated->ReserveName("output_tensor");
  ids_allocated->ReserveName("sqrt");
  ids_allocated->ReserveName("log");
  ids_allocated->ReserveName("tanh");
  ids_allocated->ReserveName("power");
  ids_allocated->ReserveName("exp");
  ids_allocated->ReserveName("sigmoid");
  ids_allocated->ReserveName("popcount");
  ids_allocated->ReserveName("likely");
  ids_allocated->ReserveName("int8");
  ids_allocated->ReserveName("int16");
  ids_allocated->ReserveName("int32");
  ids_allocated->ReserveName("int64");
  ids_allocated->ReserveName("uint8");
  ids_allocated->ReserveName("uint16");
  ids_allocated->ReserveName("uint32");
  ids_allocated->ReserveName("uint64");
  ids_allocated->ReserveName("float16");
  ids_allocated->ReserveName("float32");
  ids_allocated->ReserveName("float64");
  ids_allocated->ReserveName("ceil_div");
  ids_allocated->ReserveName("max_num_threads");
}

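// Emits the whole hybrid function: a "def" header listing the inputs, one
// output_tensor(...) allocation per output, the lowered body, and a final
// return of the outputs. A rough sketch of the generated text (identifiers
// depend on the stage being dumped):
//   def func(a, b):
//       c = output_tensor((n,), 'float32')
//       for i in range(n):
//           c[i] = a[i] + b[i]
//       return c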
void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs,
                             const Array<Tensor>& outputs, const std::string& name) {
  ReserveKeywords();
  ids_allocated->ReserveName(name);

  stream << "def " << name << "(";
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (i) stream << ", ";
    if (auto tensor = inputs[i].as<TensorNode>()) {
      stream << GetTensorID(GetRef<Tensor>(tensor));
    } else {
      auto var = inputs[i].as<VarNode>();
      ICHECK(var) << "Input should either be a tensor or a variable!";
      stream << GetVarID(var);
    }
  }
  stream << "):\n";
  indent_ += tab_;
  for (size_t i = 0; i < outputs.size(); ++i) {
    PrintIndent();
    stream << GetTensorID(outputs[i]) << " = output_tensor((";
    for (size_t j = 0; j < outputs[i]->shape.size(); ++j) {
      if (j) stream << ", ";
      PrintExpr(outputs[i]->shape[j], stream);
    }
    if (outputs[i]->shape.size() == 1) stream << ", ";
    stream << "), '" << outputs[i]->dtype << "')\n";
  }
  PrintStmt(stmt);
  PrintIndent();
  stream << "return ";
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (i) stream << ", ";
    stream << GetTensorID(outputs[i]);
  }
  stream << "\n";
}

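// Packed function exposed as "hybrid._Dump". The optional fourth argument
// overrides the generated function's name; with three arguments the default
// name declared in codegen_hybrid.h is used.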
TVM_REGISTER_GLOBAL("hybrid._Dump").set_body([](TVMArgs args, TVMRetValue* rv) {
  CodeGenHybrid codegen;
  if (args.size() == 4)
    codegen.DumpStmt(args[0], args[1], args[2], args[3]);
  else
    codegen.DumpStmt(args[0], args[1], args[2]);
  *rv = codegen.Finish();
});
}  // namespace contrib
}  // namespace tvm