1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_source_base.cc
22 */
23#include "codegen_source_base.h"
24
25#include <algorithm>
26
27namespace tvm {
28namespace codegen {
29
30void CodeGenSourceBase::ClearFuncState() {
31 name_supply_ = NameSupply("");
32 ssa_assign_map_.clear();
33 var_idmap_.clear();
34 scope_mark_.clear();
35}
36
37std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
38 if (name_supply_->ContainsName(src)) return src;
39 auto it = ssa_assign_map_.find(src);
40 if (it != ssa_assign_map_.end()) {
41 if (scope_mark_.at(it->second.scope_id)) {
42 return it->second.vid;
43 }
44 }
45 SSAEntry e;
46 e.vid = name_supply_->FreshName("_");
47 e.scope_id = static_cast<int>(scope_mark_.size() - 1);
48 ssa_assign_map_[src] = e;
49 this->PrintIndent();
50 PrintSSAAssign(e.vid, src, t);
51 return e.vid;
52}
53
54std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) {
55 ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint;
56 std::string key = v->name_hint;
57 std::string vid = name_supply_->FreshName(key);
58 std::replace(vid.begin(), vid.end(), ':', '_');
59 std::replace(vid.begin(), vid.end(), '-', '_');
60 std::replace(vid.begin(), vid.end(), '.', '_');
61 var_idmap_[v] = vid;
62 return vid;
63}
64
65std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const {
66 auto it = var_idmap_.find(v);
67 ICHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint;
68 return it->second;
69}
70
71void CodeGenSourceBase::PrintIndent() {
72 for (int i = 0; i < indent_; ++i) {
73 this->stream << ' ';
74 }
75}
76
77void CodeGenSourceBase::MarkConst(std::string vid) {
78 auto it = ssa_assign_map_.find(vid);
79 if (it == ssa_assign_map_.end()) {
80 SSAEntry e;
81 e.vid = vid;
82 e.scope_id = 0;
83 ssa_assign_map_[vid] = e;
84 } else {
85 ICHECK_EQ(it->second.vid, vid);
86 }
87}
88
89int CodeGenSourceBase::BeginScope() {
90 int sid = static_cast<int>(scope_mark_.size());
91 scope_mark_.push_back(true);
92 indent_ += 2;
93 return sid;
94}
95
96void CodeGenSourceBase::EndScope(int scope_id) {
97 scope_mark_[scope_id] = false;
98 indent_ -= 2;
99}
100
101void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT(*)
102 ICHECK_EQ(type.lanes(), 1) << "do not yet support vector types";
103 if (type.is_handle()) {
104 os << "void*";
105 return;
106 }
107 if (type.is_void()) {
108 os << "void";
109 return;
110 }
111 if (type.is_float()) {
112 if (type.bits() == 32) {
113 os << "float";
114 return;
115 }
116 if (type.bits() == 64) {
117 os << "double";
118 return;
119 }
120 } else if (type.is_uint()) {
121 switch (type.bits()) {
122 case 8:
123 case 16:
124 case 32:
125 case 64: {
126 os << "uint" << type.bits() << "_t";
127 return;
128 }
129 case 1:
130 os << "int";
131 return;
132 }
133 } else if (type.is_int()) {
134 switch (type.bits()) {
135 case 8:
136 case 16:
137 case 32:
138 case 64: {
139 os << "int" << type.bits() << "_t";
140 return;
141 }
142 }
143 }
144 LOG(FATAL) << "Cannot convert type " << type << " to C type";
145}
146
147void CodeGenSourceBase::PrintType(const Type& type, std::ostream& os) { // NOLINT(*)
148 if (auto* ptr = type.as<PrimTypeNode>()) {
149 return PrintType(ptr->dtype, os);
150 } else if (auto* ptr = type.as<PointerTypeNode>()) {
151 PrintType(ptr->element_type, os);
152 os << '*';
153 } else if (IsVoidType(type)) {
154 os << "void";
155 } else {
156 LOG(FATAL) << "Type " << type << " does not have a corresponding C Type";
157 }
158}
159
160} // namespace codegen
161} // namespace tvm
162