1#pragma once
2
#include <c10/util/irange.h>

#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
9
10namespace at {
11namespace jit {
12
13// A template environment is a mapping from template variable names, e.g.,
14// identifier (corresponding to $identifier) to their expansions.
15//
16// This template environment supports storing strings, numbers and lists
17// of strings, and can be chained together (so that lookup proceeds in
18// in the top level environment, and then recurses into a parent
19// environment if the key is not found.)
20struct TemplateEnv {
21 TemplateEnv() = default;
22 TemplateEnv(TemplateEnv& parent) : parent(&parent) {}
23
24 using string_list = std::vector<std::string>;
25
26 // Add a string 'v' to the map at key 'k'.
27 void s(const std::string& k, const std::string& v) {
28 strings_[k] = v;
29 lists_.erase(k);
30 }
31
32 // Add a number 'v' to the map at key 'k'
33 template <typename T>
34 void d(const std::string& k, const T& v) {
35 strings_[k] = c10::to_string(v);
36 lists_.erase(k);
37 }
38
39 // Retrieve the string representation of the value stored at 'k' from the map.
40 // Raises an exception if the key is not found.
41 const std::string& s(const std::string& k) const {
42 if (strings_.count(k) == 0) {
43 if (parent) {
44 return parent->s(k);
45 }
46 notFound(k);
47 }
48 return strings_.at(k);
49 }
50
51 // Store a list of strings 'v' in the map at 'k'.
52 void v(const std::string& k, const string_list& v) {
53 lists_[k] = v;
54 strings_.erase(k);
55 }
56
57 // Retrieve a list of strings stored at 'k' from the map.
58 // Raises an exception if the key is not found.
59 const string_list& v(const std::string& k) const {
60 if (lists_.count(k) == 0) {
61 if (parent) {
62 return parent->v(k);
63 }
64 notFound(k);
65 }
66 return lists_.at(k);
67 }
68
69 // Test if a string 'k' is a string (as opposed to a list.)
70 bool keyIsString(const std::string& k) const {
71 if (strings_.count(k) > 0)
72 return true;
73 if (lists_.count(k) > 0)
74 return false;
75 if (parent)
76 return parent->keyIsString(k);
77 notFound(k);
78 }
79
80 private:
81 [[noreturn]] void notFound(const std::string& k) const {
82 std::stringstream ss;
83 ss << "key not found: " << k;
84 throw std::logic_error(ss.str());
85 }
86
87 std::unordered_map<std::string, std::string> strings_;
88 std::unordered_map<std::string, string_list> lists_;
89 TemplateEnv* parent{nullptr};
90};
91
92/*
93# Match $identifier or ${identifier} and replace with the value in env.
94# If this identifier is at the beginning of whitespace on a line
95# and its value is a list then it is treated as
96# block substitution by indenting all lines of all elements.
97# If the identifier is on a line starting with non-whitespace and a list
98# then it is comma separated. ${,foo} will insert a comma before the list
99# if this list is not empty and ${foo,} will insert one after.
100*/
101struct CodeTemplate {
102 /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
103
104 std::string format(const TemplateEnv& env) const {
105 std::stringstream out;
106 size_t pos = 0;
107 size_t indent = 0;
108 bool all_whitespace = true;
109 while (pos < template_text.size()) {
110 char c = template_text[pos];
111 if (c == '$') {
112 std::stringstream kss;
113 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
114 bool comma_before;
115 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
116 bool comma_after;
117 size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
118 std::string k = kss.str();
119 bool is_string = env.keyIsString(k);
120 if (all_whitespace) {
121 if (is_string)
122 emitStringWithIndents(out, indent, env.s(k));
123 else
124 emitLinesIndented(out, indent, env.v(k));
125 } else {
126 if (is_string)
127 out << env.s(k);
128 else
129 emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
130 }
131 all_whitespace = false;
132 pos = new_pos;
133 } else {
134 out << c;
135 if (!isspace(c))
136 all_whitespace = false;
137 indent++;
138 if (c == '\n') {
139 indent = 0;
140 all_whitespace = true;
141 }
142 pos++;
143 }
144 }
145 return out.str();
146 }
147
148 private:
149 using string_list = std::vector<std::string>;
150 char charAt(size_t p) const {
151 if (p >= template_text.size())
152 throw std::logic_error("EOS found in key");
153 return template_text[p];
154 }
155 size_t parseKey(
156 size_t pos,
157 std::ostream& k,
158 bool& comma_before,
159 bool& comma_after) const {
160 comma_before = false;
161 comma_after = false;
162 pos++;
163 if (charAt(pos) == '{') {
164 pos++;
165 if (charAt(pos) == ',') {
166 comma_before = true;
167 pos++;
168 }
169 pos = parseIdent(pos, k);
170 if (charAt(pos) == ',') {
171 comma_after = true;
172 pos++;
173 }
174 if (charAt(pos) != '}')
175 throw std::logic_error("missing terminating '}'");
176 pos++;
177 return pos;
178 } else {
179 return parseIdent(pos, k);
180 }
181 }
182 size_t parseIdent(size_t pos, std::ostream& k) const {
183 while (pos < template_text.size() &&
184 (isalnum(template_text[pos]) || template_text[pos] == '_')) {
185 k << template_text[pos];
186 pos++;
187 }
188 return pos;
189 }
190 void emitCommaSeparatedList(
191 std::ostream& out,
192 const string_list& strings,
193 bool comma_before,
194 bool comma_after) const {
195 if (comma_before && !strings.empty())
196 out << ", ";
197 for (const auto i : c10::irange(strings.size())) {
198 if (i > 0)
199 out << ", ";
200 out << strings[i];
201 }
202 if (comma_after && !strings.empty())
203 out << ", ";
204 }
205 // These indentation functions follow the convention that they never emit
206 // leading or trailing newlines when the input string does not have leading
207 // or trailing newlines. It's the responsibility of the calling function
208 // to indent correctly in the context.
209 void emitIndent(std::ostream& out, size_t indent) const {
210 for (const auto i : c10::irange(indent)) {
211 (void)i; // Suppress unused variable warning
212 out << " ";
213 }
214 }
215 void emitStringWithIndents(
216 std::ostream& out,
217 size_t indent,
218 const std::string& str) const {
219 for (auto c : str) {
220 out << c;
221 if (c == '\n') {
222 emitIndent(out, indent);
223 }
224 }
225 }
226 void emitLinesIndented(
227 std::stringstream& out,
228 size_t indent,
229 const string_list& strings) const {
230 for (const auto i : c10::irange(strings.size())) {
231 if (i > 0)
232 emitIndent(out, indent);
233 emitStringWithIndents(out, indent, strings[i]);
234 if (i + 1 != strings.size())
235 out << "\n";
236 }
237 }
238 std::string template_text;
239};
240
241static inline std::string format(const std::string& fmt, TemplateEnv& env) {
242 return CodeTemplate(fmt).format(env);
243}
244
245} // namespace jit
246} // namespace at
247