1 | #pragma once |
2 | |
#include <c10/util/irange.h>

#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
9 | |
10 | namespace at { |
11 | namespace jit { |
12 | |
13 | // A template environment is a mapping from template variable names, e.g., |
14 | // identifier (corresponding to $identifier) to their expansions. |
15 | // |
16 | // This template environment supports storing strings, numbers and lists |
// of strings, and can be chained together (so that lookup proceeds in
// the top level environment, and then recurses into a parent
// environment if the key is not found.)
struct TemplateEnv {
  // Creates an empty environment with no parent.
  TemplateEnv() = default;

  // Creates an empty environment chained to 'parent': lookups that miss in
  // this environment continue in 'parent'. Only the address of 'parent' is
  // stored, so 'parent' must stay alive while this environment is used.
  //
  // NOTE: this signature is also that of a copy constructor, so constructing
  // a TemplateEnv from a non-const lvalue TemplateEnv yields an empty child
  // of that environment, not a copy of its entries.
  TemplateEnv(TemplateEnv& parent) : parent(&parent) {}

  using string_list = std::vector<std::string>;

  // Add a string 'v' to the map at key 'k'. A list previously stored at 'k'
  // in this environment is dropped, so a key is never both string and list.
  void s(const std::string& k, const std::string& v) {
    strings_[k] = v;
    lists_.erase(k);
  }

  // Add a number 'v' to the map at key 'k'. The number is stored as the
  // string produced by std::to_string; a list previously stored at 'k' in
  // this environment is dropped.
  template <typename T>
  void d(const std::string& k, const T& v) {
    strings_[k] = std::to_string(v);
    lists_.erase(k);
  }

  // Retrieve the string representation of the value stored at 'k' from the map.
  // The lookup starts in this environment and walks up the parent chain; the
  // first environment holding a string at 'k' wins.
  // Raises std::logic_error if no environment in the chain has the key.
  const std::string& s(const std::string& k) const {
    for (const TemplateEnv* env = this; env != nullptr; env = env->parent) {
      // Single hash lookup per environment instead of count() + at().
      const auto it = env->strings_.find(k);
      if (it != env->strings_.end()) {
        return it->second;
      }
    }
    notFound(k);
  }

  // Store a list of strings 'v' in the map at 'k'. A string previously stored
  // at 'k' in this environment is dropped.
  void v(const std::string& k, const string_list& v) {
    lists_[k] = v;
    strings_.erase(k);
  }

  // Retrieve a list of strings stored at 'k' from the map.
  // The lookup starts in this environment and walks up the parent chain; the
  // first environment holding a list at 'k' wins.
  // Raises std::logic_error if no environment in the chain has the key.
  const string_list& v(const std::string& k) const {
    for (const TemplateEnv* env = this; env != nullptr; env = env->parent) {
      const auto it = env->lists_.find(k);
      if (it != env->lists_.end()) {
        return it->second;
      }
    }
    notFound(k);
  }

  // Test if the key 'k' maps to a string (as opposed to a list.)
  // The nearest environment in the chain that defines 'k' (as either kind)
  // decides the answer. Raises std::logic_error if the key is not found.
  bool keyIsString(const std::string& k) const {
    for (const TemplateEnv* env = this; env != nullptr; env = env->parent) {
      if (env->strings_.count(k) > 0) {
        return true;
      }
      if (env->lists_.count(k) > 0) {
        return false;
      }
    }
    notFound(k);
  }

 private:
  // Reports a failed lookup of 'k'; never returns.
  [[noreturn]] void notFound(const std::string& k) const {
    throw std::logic_error("key not found: " + k);
  }

  std::unordered_map<std::string, std::string> strings_;
  std::unordered_map<std::string, string_list> lists_;
  // Non-owning pointer to the enclosing environment; null for a root.
  TemplateEnv* parent{nullptr};
};
91 | |
92 | /* |
# Match $identifier or ${identifier} and replace with the value in env.
# If the identifier is preceded only by whitespace on its line
# and its value is a list then it is treated as a
# block substitution: all lines of all elements are indented to match.
# If the identifier is preceded by non-whitespace text (or by another
# substitution) on its line and its value is a list, then the elements are
# comma separated. ${,foo} will insert a comma before the list
# if the list is not empty and ${foo,} will insert one after.
100 | */ |
101 | struct CodeTemplate { |
102 | /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {} |
103 | |
104 | std::string format(const TemplateEnv& env) const { |
105 | std::stringstream out; |
106 | size_t pos = 0; |
107 | size_t indent = 0; |
108 | bool all_whitespace = true; |
109 | while (pos < template_text.size()) { |
110 | char c = template_text[pos]; |
111 | if (c == '$') { |
112 | std::stringstream kss; |
113 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
114 | bool comma_before; |
115 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
116 | bool comma_after; |
117 | size_t new_pos = parseKey(pos, kss, comma_before, comma_after); |
118 | std::string k = kss.str(); |
119 | bool is_string = env.keyIsString(k); |
120 | if (all_whitespace) { |
121 | if (is_string) |
122 | emitStringWithIndents(out, indent, env.s(k)); |
123 | else |
124 | emitLinesIndented(out, indent, env.v(k)); |
125 | } else { |
126 | if (is_string) |
127 | out << env.s(k); |
128 | else |
129 | emitCommaSeparatedList(out, env.v(k), comma_before, comma_after); |
130 | } |
131 | all_whitespace = false; |
132 | pos = new_pos; |
133 | } else { |
134 | out << c; |
135 | if (!isspace(c)) |
136 | all_whitespace = false; |
137 | indent++; |
138 | if (c == '\n') { |
139 | indent = 0; |
140 | all_whitespace = true; |
141 | } |
142 | pos++; |
143 | } |
144 | } |
145 | return out.str(); |
146 | } |
147 | |
148 | private: |
149 | using string_list = std::vector<std::string>; |
150 | char charAt(size_t p) const { |
151 | if (p >= template_text.size()) |
152 | throw std::logic_error("EOS found in key" ); |
153 | return template_text[p]; |
154 | } |
155 | size_t parseKey( |
156 | size_t pos, |
157 | std::ostream& k, |
158 | bool& comma_before, |
159 | bool& comma_after) const { |
160 | comma_before = false; |
161 | comma_after = false; |
162 | pos++; |
163 | if (charAt(pos) == '{') { |
164 | pos++; |
165 | if (charAt(pos) == ',') { |
166 | comma_before = true; |
167 | pos++; |
168 | } |
169 | pos = parseIdent(pos, k); |
170 | if (charAt(pos) == ',') { |
171 | comma_after = true; |
172 | pos++; |
173 | } |
174 | if (charAt(pos) != '}') |
175 | throw std::logic_error("missing terminating '}'" ); |
176 | pos++; |
177 | return pos; |
178 | } else { |
179 | return parseIdent(pos, k); |
180 | } |
181 | } |
182 | size_t parseIdent(size_t pos, std::ostream& k) const { |
183 | while (pos < template_text.size() && |
184 | (isalnum(template_text[pos]) || template_text[pos] == '_')) { |
185 | k << template_text[pos]; |
186 | pos++; |
187 | } |
188 | return pos; |
189 | } |
190 | void emitCommaSeparatedList( |
191 | std::ostream& out, |
192 | const string_list& strings, |
193 | bool comma_before, |
194 | bool comma_after) const { |
195 | if (comma_before && !strings.empty()) |
196 | out << ", " ; |
197 | for (const auto i : c10::irange(strings.size())) { |
198 | if (i > 0) |
199 | out << ", " ; |
200 | out << strings[i]; |
201 | } |
202 | if (comma_after && !strings.empty()) |
203 | out << ", " ; |
204 | } |
205 | // These indentation functions follow the convention that they never emit |
206 | // leading or trailing newlines when the input string does not have leading |
207 | // or trailing newlines. It's the responsibility of the calling function |
208 | // to indent correctly in the context. |
209 | void emitIndent(std::ostream& out, size_t indent) const { |
210 | for (const auto i : c10::irange(indent)) { |
211 | (void)i; // Suppress unused variable warning |
212 | out << " " ; |
213 | } |
214 | } |
215 | void emitStringWithIndents( |
216 | std::ostream& out, |
217 | size_t indent, |
218 | const std::string& str) const { |
219 | for (auto c : str) { |
220 | out << c; |
221 | if (c == '\n') { |
222 | emitIndent(out, indent); |
223 | } |
224 | } |
225 | } |
226 | void emitLinesIndented( |
227 | std::stringstream& out, |
228 | size_t indent, |
229 | const string_list& strings) const { |
230 | for (const auto i : c10::irange(strings.size())) { |
231 | if (i > 0) |
232 | emitIndent(out, indent); |
233 | emitStringWithIndents(out, indent, strings[i]); |
234 | if (i + 1 != strings.size()) |
235 | out << "\n" ; |
236 | } |
237 | } |
238 | std::string template_text; |
239 | }; |
240 | |
241 | static inline std::string format(const std::string& fmt, TemplateEnv& env) { |
242 | return CodeTemplate(fmt).format(env); |
243 | } |
244 | |
245 | } // namespace jit |
246 | } // namespace at |
247 | |