1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file install_debug_spans.cc
 * \brief Prints TIR code in memory, writes it to `<name>.tir`, and replaces all
 * spans in the module with the locations at which the ops were printed.
24 */
25
26#include "./install_debug_spans.h"
27
#include <tvm/tir/transform.h>

#include <fstream>
#include <string>
#include <utility>
32
33#include "../../relay/printer/tir_text_printer_debug.h"
34
35namespace tvm {
36namespace tir {
37
38Stmt DebugInfoInstaller::InstallInfo(const std::string& name, const Stmt& stmt) {
39 DebugInfoInstaller installer(stmt, name + ".tir");
40 return installer.VisitStmt(stmt);
41}
42
43DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& filename) {
44 // Determine the line that each stmt/expr will be printed on
45 tvm::relay::TIRTextPrinterDebug printer(false);
46
47 // Fill in the stmts and exprs' line info
48 auto result = printer.Print(stmt).str();
49
50 // Create map of the stmt/expr -> its line number in the output to later
51 // create new spans for each stmt/expr
52 const auto& stmts = printer.GetStmtsByLine();
53 VLOG(0) << "Debug printer found " << stmts.size() << " stmts after printing";
54 for (const auto& line : stmts) {
55 stmt_lines_[std::get<0>(line)] = std::get<1>(line);
56 }
57
58 const auto& exprs = printer.GetExprsByLine();
59 VLOG(0) << "Debug printer found " << exprs.size() << " exprs after printing";
60 for (const auto& line : exprs) {
61 expr_lines_[std::get<0>(line)] = std::get<1>(line);
62 }
63
64 // Output the printed TIR to the specified file
65 VLOG(0) << "Outputting TIR to " << filename;
66 filename_ = std::move(filename);
67 std::ofstream out(filename_);
68 out << result;
69 out.close();
70}
71
72PrimExpr DebugInfoInstaller::VisitExpr(const PrimExpr& expr) {
73 PrimExpr result = expr;
74 result = StmtExprMutator::VisitExpr(result);
75 return result;
76}
77
78Stmt DebugInfoInstaller::VisitStmt(const Stmt& stmt) {
79 Stmt result = stmt;
80 result = StmtExprMutator::VisitStmt(result);
81 return result;
82}
83
84Span DebugInfoInstaller::MaybeSpan(const StmtNode* op) {
85 auto entry = stmt_lines_.find(op);
86 if (entry == stmt_lines_.end()) {
87 return Span();
88 } else {
89 size_t column = 0;
90 size_t line = entry->second;
91 return Span(SourceName::Get(filename_), line, line, column, column);
92 }
93}
94
95Span DebugInfoInstaller::MaybeSpan(const PrimExprNode* op) {
96 auto entry = expr_lines_.find(op);
97 if (entry == expr_lines_.end()) {
98 return Span();
99 } else {
100 size_t column = 0;
101 size_t line = entry->second;
102 return Span(SourceName::Get(filename_), line, line, column, column);
103 }
104}
105
106#define X(TypeName) \
107 PrimExpr DebugInfoInstaller::VisitExpr_(const TypeName##Node* op) { \
108 auto new_expr = StmtExprMutator::VisitExpr_(op); \
109 auto new_type = Downcast<TypeName>(new_expr); \
110 auto new_node = new_type.CopyOnWrite(); \
111 new_node->span = MaybeSpan(op); \
112 return new_type; \
113 }
114TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS
115#undef X
116
117#define X(TypeName) \
118 Stmt DebugInfoInstaller::VisitStmt_(const TypeName##Node* op) { \
119 Stmt new_stmt = StmtExprMutator::VisitStmt_(op); \
120 auto new_type = Downcast<TypeName>(new_stmt); \
121 auto new_node = new_type.CopyOnWrite(); \
122 new_node->span = MaybeSpan(op); \
123 return new_type; \
124 }
125TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS
126#undef X
127
128namespace transform {
129
130Pass InstallDebugSpans() {
131 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
132 ICHECK(m->functions.size() == 1)
133 << "Debug info can only be added to IRModules with a single function";
134 // There is known to be only 1 function in the module at this point
135 auto entry = m->functions.begin();
136 auto name = std::get<0>(*entry)->name_hint;
137 auto* n = f.CopyOnWrite();
138
139 n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body));
140
141 return f;
142 };
143 return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans", {});
144}
145
146TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans").set_body_typed(InstallDebugSpans);
147
148} // namespace transform
149} // namespace tir
150} // namespace tvm
151