1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_source_base.cc |
22 | */ |
23 | #include "codegen_source_base.h" |
24 | |
25 | #include <algorithm> |
26 | |
27 | namespace tvm { |
28 | namespace codegen { |
29 | |
30 | void CodeGenSourceBase::ClearFuncState() { |
31 | name_supply_ = NameSupply("" ); |
32 | ssa_assign_map_.clear(); |
33 | var_idmap_.clear(); |
34 | scope_mark_.clear(); |
35 | } |
36 | |
37 | std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { |
38 | if (name_supply_->ContainsName(src)) return src; |
39 | auto it = ssa_assign_map_.find(src); |
40 | if (it != ssa_assign_map_.end()) { |
41 | if (scope_mark_.at(it->second.scope_id)) { |
42 | return it->second.vid; |
43 | } |
44 | } |
45 | SSAEntry e; |
46 | e.vid = name_supply_->FreshName("_" ); |
47 | e.scope_id = static_cast<int>(scope_mark_.size() - 1); |
48 | ssa_assign_map_[src] = e; |
49 | this->PrintIndent(); |
50 | PrintSSAAssign(e.vid, src, t); |
51 | return e.vid; |
52 | } |
53 | |
54 | std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { |
55 | ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; |
56 | std::string key = v->name_hint; |
57 | std::string vid = name_supply_->FreshName(key); |
58 | std::replace(vid.begin(), vid.end(), ':', '_'); |
59 | std::replace(vid.begin(), vid.end(), '-', '_'); |
60 | std::replace(vid.begin(), vid.end(), '.', '_'); |
61 | var_idmap_[v] = vid; |
62 | return vid; |
63 | } |
64 | |
65 | std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const { |
66 | auto it = var_idmap_.find(v); |
67 | ICHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; |
68 | return it->second; |
69 | } |
70 | |
71 | void CodeGenSourceBase::PrintIndent() { |
72 | for (int i = 0; i < indent_; ++i) { |
73 | this->stream << ' '; |
74 | } |
75 | } |
76 | |
77 | void CodeGenSourceBase::MarkConst(std::string vid) { |
78 | auto it = ssa_assign_map_.find(vid); |
79 | if (it == ssa_assign_map_.end()) { |
80 | SSAEntry e; |
81 | e.vid = vid; |
82 | e.scope_id = 0; |
83 | ssa_assign_map_[vid] = e; |
84 | } else { |
85 | ICHECK_EQ(it->second.vid, vid); |
86 | } |
87 | } |
88 | |
89 | int CodeGenSourceBase::BeginScope() { |
90 | int sid = static_cast<int>(scope_mark_.size()); |
91 | scope_mark_.push_back(true); |
92 | indent_ += 2; |
93 | return sid; |
94 | } |
95 | |
96 | void CodeGenSourceBase::EndScope(int scope_id) { |
97 | scope_mark_[scope_id] = false; |
98 | indent_ -= 2; |
99 | } |
100 | |
101 | void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT(*) |
102 | ICHECK_EQ(type.lanes(), 1) << "do not yet support vector types" ; |
103 | if (type.is_handle()) { |
104 | os << "void*" ; |
105 | return; |
106 | } |
107 | if (type.is_void()) { |
108 | os << "void" ; |
109 | return; |
110 | } |
111 | if (type.is_float()) { |
112 | if (type.bits() == 32) { |
113 | os << "float" ; |
114 | return; |
115 | } |
116 | if (type.bits() == 64) { |
117 | os << "double" ; |
118 | return; |
119 | } |
120 | } else if (type.is_uint()) { |
121 | switch (type.bits()) { |
122 | case 8: |
123 | case 16: |
124 | case 32: |
125 | case 64: { |
126 | os << "uint" << type.bits() << "_t" ; |
127 | return; |
128 | } |
129 | case 1: |
130 | os << "int" ; |
131 | return; |
132 | } |
133 | } else if (type.is_int()) { |
134 | switch (type.bits()) { |
135 | case 8: |
136 | case 16: |
137 | case 32: |
138 | case 64: { |
139 | os << "int" << type.bits() << "_t" ; |
140 | return; |
141 | } |
142 | } |
143 | } |
144 | LOG(FATAL) << "Cannot convert type " << type << " to C type" ; |
145 | } |
146 | |
147 | void CodeGenSourceBase::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) |
148 | if (auto* ptr = type.as<PrimTypeNode>()) { |
149 | return PrintType(ptr->dtype, os); |
150 | } else if (auto* ptr = type.as<PointerTypeNode>()) { |
151 | PrintType(ptr->element_type, os); |
152 | os << '*'; |
153 | } else if (IsVoidType(type)) { |
154 | os << "void" ; |
155 | } else { |
156 | LOG(FATAL) << "Type " << type << " does not have a corresponding C Type" ; |
157 | } |
158 | } |
159 | |
160 | } // namespace codegen |
161 | } // namespace tvm |
162 | |