1/*
2 * Copyright 2021 Google Inc. All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "bfbs_gen_lua.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Ensure no includes to flatc internals. bfbs_gen.h and generator.h are OK.
#include "bfbs_gen.h"
#include "flatbuffers/bfbs_generator.h"

// The intermediate representation schema.
#include "flatbuffers/reflection_generated.h"
32
33namespace flatbuffers {
34namespace {
35
36// To reduce typing
37namespace r = ::reflection;
38
39class LuaBfbsGenerator : public BaseBfbsGenerator {
40 public:
41 explicit LuaBfbsGenerator(const std::string &flatc_version)
42 : BaseBfbsGenerator(),
43 keywords_(),
44 requires_(),
45 current_obj_(nullptr),
46 current_enum_(nullptr),
47 flatc_version_(flatc_version) {
48 static const char *const keywords[] = {
49 "and", "break", "do", "else", "elseif", "end", "false", "for",
50 "function", "goto", "if", "in", "local", "nil", "not", "or",
51 "repeat", "return", "then", "true", "until", "while"
52 };
53 keywords_.insert(std::begin(keywords), std::end(keywords));
54 }
55
56 GeneratorStatus GenerateFromSchema(const r::Schema *schema)
57 FLATBUFFERS_OVERRIDE {
58 if (!GenerateEnums(schema->enums())) { return FAILED; }
59 if (!GenerateObjects(schema->objects(), schema->root_table())) {
60 return FAILED;
61 }
62 return OK;
63 }
64
65 uint64_t SupportedAdvancedFeatures() const FLATBUFFERS_OVERRIDE {
66 return 0xF;
67 }
68
69 protected:
70 bool GenerateEnums(
71 const flatbuffers::Vector<flatbuffers::Offset<r::Enum>> *enums) {
72 ForAllEnums(enums, [&](const r::Enum *enum_def) {
73 std::string code;
74
75 StartCodeBlock(enum_def);
76
77 std::string ns;
78 const std::string enum_name =
79 NormalizeName(Denamespace(enum_def->name(), ns));
80
81 GenerateDocumentation(enum_def->documentation(), "", code);
82 code += "local " + enum_name + " = {\n";
83
84 ForAllEnumValues(enum_def, [&](const reflection::EnumVal *enum_val) {
85 GenerateDocumentation(enum_val->documentation(), " ", code);
86 code += " " + NormalizeName(enum_val->name()) + " = " +
87 NumToString(enum_val->value()) + ",\n";
88 });
89 code += "}\n";
90 code += "\n";
91
92 EmitCodeBlock(code, enum_name, ns, enum_def->declaration_file()->str());
93 });
94 return true;
95 }
96
97 bool GenerateObjects(
98 const flatbuffers::Vector<flatbuffers::Offset<r::Object>> *objects,
99 const r::Object *root_object) {
100 ForAllObjects(objects, [&](const r::Object *object) {
101 std::string code;
102
103 StartCodeBlock(object);
104
105 // Register the main flatbuffers module.
106 RegisterRequires("flatbuffers", "flatbuffers");
107
108 std::string ns;
109 const std::string object_name =
110 NormalizeName(Denamespace(object->name(), ns));
111
112 GenerateDocumentation(object->documentation(), "", code);
113
114 code += "local " + object_name + " = {}\n";
115 code += "local mt = {}\n";
116 code += "\n";
117 code += "function " + object_name + ".New()\n";
118 code += " local o = {}\n";
119 code += " setmetatable(o, {__index = mt})\n";
120 code += " return o\n";
121 code += "end\n";
122 code += "\n";
123
124 if (object == root_object) {
125 code += "function " + object_name + ".GetRootAs" + object_name +
126 "(buf, offset)\n";
127 code += " if type(buf) == \"string\" then\n";
128 code += " buf = flatbuffers.binaryArray.New(buf)\n";
129 code += " end\n";
130 code += "\n";
131 code += " local n = flatbuffers.N.UOffsetT:Unpack(buf, offset)\n";
132 code += " local o = " + object_name + ".New()\n";
133 code += " o:Init(buf, n + offset)\n";
134 code += " return o\n";
135 code += "end\n";
136 code += "\n";
137 }
138
139 // Generates a init method that receives a pre-existing accessor object,
140 // so that objects can be reused.
141
142 code += "function mt:Init(buf, pos)\n";
143 code += " self.view = flatbuffers.view.New(buf, pos)\n";
144 code += "end\n";
145 code += "\n";
146
147 // Create all the field accessors.
148 ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
149 // Skip writing deprecated fields altogether.
150 if (field->deprecated()) { return; }
151
152 const std::string field_name = NormalizeName(field->name());
153 const std::string field_name_camel_case = MakeCamelCase(field_name);
154 const r::BaseType base_type = field->type()->base_type();
155
156 // Generate some fixed strings so we don't repeat outselves later.
157 const std::string getter_signature =
158 "function mt:" + field_name_camel_case + "()\n";
159 const std::string offset_prefix = "local o = self.view:Offset(" +
160 NumToString(field->offset()) + ")\n";
161 const std::string offset_prefix_2 = "if o ~= 0 then\n";
162
163 GenerateDocumentation(field->documentation(), "", code);
164
165 if (IsScalar(base_type)) {
166 code += getter_signature;
167
168 if (object->is_struct()) {
169 // TODO(derekbailey): it would be nice to modify the view:Get to
170 // just pass in the offset and not have to add it its own
171 // self.view.pos.
172 code += " return " + GenerateGetter(field->type()) +
173 "self.view.pos + " + NumToString(field->offset()) + ")\n";
174 } else {
175 // Table accessors
176 code += " " + offset_prefix;
177 code += " " + offset_prefix_2;
178
179 std::string getter =
180 GenerateGetter(field->type()) + "self.view.pos + o)";
181 if (IsBool(base_type)) { getter = "(" + getter + " ~=0)"; }
182 code += " return " + getter + "\n";
183 code += " end\n";
184 code += " return " + DefaultValue(field) + "\n";
185 }
186 code += "end\n";
187 code += "\n";
188 } else {
189 switch (base_type) {
190 case r::String: {
191 code += getter_signature;
192 code += " " + offset_prefix;
193 code += " " + offset_prefix_2;
194 code += " return " + GenerateGetter(field->type()) +
195 "self.view.pos + o)\n";
196 code += " end\n";
197 code += "end\n";
198 code += "\n";
199 break;
200 }
201 case r::Obj: {
202 if (object->is_struct()) {
203 code += "function mt:" + field_name_camel_case + "(obj)\n";
204 code += " obj:Init(self.view.bytes, self.view.pos + " +
205 NumToString(field->offset()) + ")\n";
206 code += " return obj\n";
207 code += "end\n";
208 code += "\n";
209 } else {
210 code += getter_signature;
211 code += " " + offset_prefix;
212 code += " " + offset_prefix_2;
213
214 const r::Object *field_object = GetObject(field->type());
215 if (!field_object) {
216 // TODO(derekbailey): this is an error condition. we
217 // should report it better.
218 return;
219 }
220 code += " local x = " +
221 std::string(
222 field_object->is_struct()
223 ? "self.view.pos + o\n"
224 : "self.view:Indirect(self.view.pos + o)\n");
225 const std::string require_name = RegisterRequires(field);
226 code += " local obj = " + require_name + ".New()\n";
227 code += " obj:Init(self.view.bytes, x)\n";
228 code += " return obj\n";
229 code += " end\n";
230 code += "end\n";
231 code += "\n";
232 }
233 break;
234 }
235 case r::Union: {
236 code += getter_signature;
237 code += " " + offset_prefix;
238 code += " " + offset_prefix_2;
239 code +=
240 " local obj = "
241 "flatbuffers.view.New(flatbuffers.binaryArray.New("
242 "0), 0)\n";
243 code += " " + GenerateGetter(field->type()) + "obj, o)\n";
244 code += " return obj\n";
245 code += " end\n";
246 code += "end\n";
247 code += "\n";
248 break;
249 }
250 case r::Array:
251 case r::Vector: {
252 const r::BaseType vector_base_type = field->type()->element();
253 int32_t element_size = field->type()->element_size();
254 code += "function mt:" + field_name_camel_case + "(j)\n";
255 code += " " + offset_prefix;
256 code += " " + offset_prefix_2;
257
258 if (IsStructOrTable(vector_base_type)) {
259 code += " local x = self.view:Vector(o)\n";
260 code +=
261 " x = x + ((j-1) * " + NumToString(element_size) + ")\n";
262 if (IsTable(field->type(), /*use_element=*/true)) {
263 code += " x = self.view:Indirect(x)\n";
264 } else {
265 // Vector of structs are inline, so we need to query the
266 // size of the struct.
267 const reflection::Object *obj =
268 GetObjectByIndex(field->type()->index());
269 element_size = obj->bytesize();
270 }
271
272 // Include the referenced type, thus we need to make sure
273 // we set `use_element` to true.
274 const std::string require_name =
275 RegisterRequires(field, /*use_element=*/true);
276 code += " local obj = " + require_name + ".New()\n";
277 code += " obj:Init(self.view.bytes, x)\n";
278 code += " return obj\n";
279 } else {
280 code += " local a = self.view:Vector(o)\n";
281 code += " return " + GenerateGetter(field->type()) +
282 "a + ((j-1) * " + NumToString(element_size) + "))\n";
283 }
284 code += " end\n";
285 // Only generate a default value for those types that are
286 // supported.
287 if (!IsStructOrTable(vector_base_type)) {
288 code +=
289 " return " +
290 std::string(vector_base_type == r::String ? "''\n" : "0\n");
291 }
292 code += "end\n";
293 code += "\n";
294
295 // If the vector is composed of single byte values, we
296 // generate a helper function to get it as a byte string in
297 // Lua.
298 if (IsSingleByte(vector_base_type)) {
299 code += "function mt:" + field_name_camel_case +
300 "AsString(start, stop)\n";
301 code += " return self.view:VectorAsString(" +
302 NumToString(field->offset()) + ", start, stop)\n";
303 code += "end\n";
304 code += "\n";
305 }
306
307 // We also make a new accessor to query just the length of the
308 // vector.
309 code += "function mt:" + field_name_camel_case + "Length()\n";
310 code += " " + offset_prefix;
311 code += " " + offset_prefix_2;
312 code += " return self.view:VectorLen(o)\n";
313 code += " end\n";
314 code += " return 0\n";
315 code += "end\n";
316 code += "\n";
317 break;
318 }
319 default: {
320 return;
321 }
322 }
323 }
324 return;
325 });
326
327 // Create all the builders
328 if (object->is_struct()) {
329 code += "function " + object_name + ".Create" + object_name +
330 "(builder" + GenerateStructBuilderArgs(object) + ")\n";
331 code += AppendStructBuilderBody(object);
332 code += " return builder:Offset()\n";
333 code += "end\n";
334 code += "\n";
335 } else {
336 // Table builders
337 code += "function " + object_name + ".Start(builder)\n";
338 code += " builder:StartObject(" +
339 NumToString(object->fields()->size()) + ")\n";
340 code += "end\n";
341 code += "\n";
342
343 ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
344 if (field->deprecated()) { return; }
345
346 const std::string field_name = NormalizeName(field->name());
347
348 code += "function " + object_name + ".Add" +
349 MakeCamelCase(field_name) + "(builder, " +
350 MakeCamelCase(field_name, false) + ")\n";
351 code += " builder:Prepend" + GenerateMethod(field) + "Slot(" +
352 NumToString(field->id()) + ", " +
353 MakeCamelCase(field_name, false) + ", " +
354 DefaultValue(field) + ")\n";
355 code += "end\n";
356 code += "\n";
357
358 if (IsVector(field->type()->base_type())) {
359 code += "function " + object_name + ".Start" +
360 MakeCamelCase(field_name) + "Vector(builder, numElems)\n";
361
362 const int32_t element_size = field->type()->element_size();
363 int32_t alignment = 0;
364 if (IsStruct(field->type(), /*use_element=*/true)) {
365 alignment = GetObjectByIndex(field->type()->index())->minalign();
366 } else {
367 alignment = element_size;
368 }
369
370 code += " return builder:StartVector(" +
371 NumToString(element_size) + ", numElems, " +
372 NumToString(alignment) + ")\n";
373 code += "end\n";
374 code += "\n";
375 }
376 });
377
378 code += "function " + object_name + ".End(builder)\n";
379 code += " return builder:EndObject()\n";
380 code += "end\n";
381 code += "\n";
382 }
383
384 EmitCodeBlock(code, object_name, ns, object->declaration_file()->str());
385 });
386 return true;
387 }
388
389 private:
390 void GenerateDocumentation(
391 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>
392 *documentation,
393 std::string indent, std::string &code) const {
394 flatbuffers::ForAllDocumentation(
395 documentation, [&](const flatbuffers::String *str) {
396 code += indent + "--" + str->str() + "\n";
397 });
398 }
399
400 std::string GenerateStructBuilderArgs(const r::Object *object,
401 std::string prefix = "") const {
402 std::string signature;
403 ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
404 if (IsStructOrTable(field->type()->base_type())) {
405 const r::Object *field_object = GetObject(field->type());
406 signature += GenerateStructBuilderArgs(
407 field_object, prefix + NormalizeName(field->name()) + "_");
408 } else {
409 signature +=
410 ", " + prefix + MakeCamelCase(NormalizeName(field->name()), false);
411 }
412 });
413 return signature;
414 }
415
416 std::string AppendStructBuilderBody(const r::Object *object,
417 std::string prefix = "") const {
418 std::string code;
419 code += " builder:Prep(" + NumToString(object->minalign()) + ", " +
420 NumToString(object->bytesize()) + ")\n";
421
422 // We need to reverse the order we iterate over, since we build the
423 // buffer backwards.
424 ForAllFields(object, /*reverse=*/true, [&](const r::Field *field) {
425 const int32_t num_padding_bytes = field->padding();
426 if (num_padding_bytes) {
427 code += " builder:Pad(" + NumToString(num_padding_bytes) + ")\n";
428 }
429 if (IsStructOrTable(field->type()->base_type())) {
430 const r::Object *field_object = GetObject(field->type());
431 code += AppendStructBuilderBody(
432 field_object, prefix + NormalizeName(field->name()) + "_");
433 } else {
434 code += " builder:Prepend" + GenerateMethod(field) + "(" + prefix +
435 MakeCamelCase(NormalizeName(field->name()), false) + ")\n";
436 }
437 });
438
439 return code;
440 }
441
442 std::string GenerateMethod(const r::Field *field) const {
443 const r::BaseType base_type = field->type()->base_type();
444 if (IsScalar(base_type)) { return MakeCamelCase(GenerateType(base_type)); }
445 if (IsStructOrTable(base_type)) { return "Struct"; }
446 return "UOffsetTRelative";
447 }
448
449 std::string GenerateGetter(const r::Type *type,
450 bool element_type = false) const {
451 switch (element_type ? type->element() : type->base_type()) {
452 case r::String: return "self.view:String(";
453 case r::Union: return "self.view:Union(";
454 case r::Vector: return GenerateGetter(type, true);
455 default:
456 return "self.view:Get(flatbuffers.N." +
457 MakeCamelCase(GenerateType(type, element_type)) + ", ";
458 }
459 }
460
461 std::string GenerateType(const r::Type *type,
462 bool element_type = false) const {
463 const r::BaseType base_type =
464 element_type ? type->element() : type->base_type();
465 if (IsScalar(base_type)) { return GenerateType(base_type); }
466 switch (base_type) {
467 case r::String: return "string";
468 case r::Vector: return GenerateGetter(type, true);
469 case r::Obj: {
470 const r::Object *obj = GetObject(type);
471 return NormalizeName(Denamespace(obj->name()));
472 };
473 default: return "*flatbuffers.Table";
474 }
475 }
476
477 std::string GenerateType(const r::BaseType base_type) const {
478 // Need to override the default naming to match the Lua runtime libraries.
479 // TODO(derekbailey): make overloads in the runtime libraries to avoid this.
480 switch (base_type) {
481 case r::None: return "uint8";
482 case r::UType: return "uint8";
483 case r::Byte: return "int8";
484 case r::UByte: return "uint8";
485 case r::Short: return "int16";
486 case r::UShort: return "uint16";
487 case r::Int: return "int32";
488 case r::UInt: return "uint32";
489 case r::Long: return "int64";
490 case r::ULong: return "uint64";
491 case r::Float: return "Float32";
492 case r::Double: return "Float64";
493 default: return r::EnumNameBaseType(base_type);
494 }
495 }
496
497 std::string DefaultValue(const r::Field *field) const {
498 const r::BaseType base_type = field->type()->base_type();
499 if (IsFloatingPoint(base_type)) {
500 return NumToString(field->default_real());
501 }
502 if (IsBool(base_type)) {
503 return field->default_integer() ? "true" : "false";
504 }
505 if (IsScalar(base_type)) { return NumToString((field->default_integer())); }
506 // represents offsets
507 return "0";
508 }
509
510 std::string NormalizeName(const std::string name) const {
511 return keywords_.find(name) == keywords_.end() ? name : "_" + name;
512 }
513
514 std::string NormalizeName(const flatbuffers::String *name) const {
515 return NormalizeName(name->str());
516 }
517
518 void StartCodeBlock(const reflection::Enum *enum_def) {
519 current_enum_ = enum_def;
520 current_obj_ = nullptr;
521 requires_.clear();
522 }
523
524 void StartCodeBlock(const reflection::Object *object) {
525 current_obj_ = object;
526 current_enum_ = nullptr;
527 requires_.clear();
528 }
529
530 std::string RegisterRequires(const r::Field *field,
531 bool use_element = false) {
532 std::string type_name;
533
534 const r::BaseType type =
535 use_element ? field->type()->element() : field->type()->base_type();
536
537 if (IsStructOrTable(type)) {
538 const r::Object *object = GetObjectByIndex(field->type()->index());
539 if (object == current_obj_) { return Denamespace(object->name()); }
540 type_name = object->name()->str();
541 } else {
542 const r::Enum *enum_def = GetEnumByIndex(field->type()->index());
543 if (enum_def == current_enum_) { return Denamespace(enum_def->name()); }
544 type_name = enum_def->name()->str();
545 }
546
547 // Prefix with double __ to avoid name clashing, since these are defined
548 // at the top of the file and have lexical scoping. Replace '.' with '_'
549 // so it can be a legal identifier.
550 std::string name = "__" + type_name;
551 std::replace(name.begin(), name.end(), '.', '_');
552
553 return RegisterRequires(name, type_name);
554 }
555
556 std::string RegisterRequires(const std::string &local_name,
557 const std::string &requires_name) {
558 requires_[local_name] = requires_name;
559 return local_name;
560 }
561
562 void EmitCodeBlock(const std::string &code_block, const std::string &name,
563 const std::string &ns,
564 const std::string &declaring_file) const {
565 const std::string root_type = schema_->root_table()->name()->str();
566 const std::string root_file =
567 schema_->root_table()->declaration_file()->str();
568 const std::string full_qualified_name = ns.empty() ? name : ns + "." + name;
569
570 std::string code = "--[[ " + full_qualified_name + "\n\n";
571 code +=
572 " Automatically generated by the FlatBuffers compiler, do not "
573 "modify.\n";
574 code += " Or modify. I'm a message, not a cop.\n";
575 code += "\n";
576 code += " flatc version: " + flatc_version_ + "\n";
577 code += "\n";
578 code += " Declared by : " + declaring_file + "\n";
579 code += " Rooting type : " + root_type + " (" + root_file + ")\n";
580 code += "\n--]]\n\n";
581
582 if (!requires_.empty()) {
583 for (auto it = requires_.cbegin(); it != requires_.cend(); ++it) {
584 code += "local " + it->first + " = require('" + it->second + "')\n";
585 }
586 code += "\n";
587 }
588
589 code += code_block;
590 code += "return " + name;
591
592 // Namespaces are '.' deliminted, so replace it with the path separator.
593 std::string path = ns;
594
595 if (path.empty()) {
596 path = ".";
597 } else {
598 std::replace(path.begin(), path.end(), '.', '/');
599 }
600
601 // TODO(derekbailey): figure out a save file without depending on util.h
602 EnsureDirExists(path);
603 const std::string file_name = path + "/" + name + ".lua";
604 SaveFile(file_name.c_str(), code, false);
605 }
606
607 std::unordered_set<std::string> keywords_;
608 std::map<std::string, std::string> requires_;
609 const r::Object *current_obj_;
610 const r::Enum *current_enum_;
611 const std::string flatc_version_;
612};
613} // namespace
614
615std::unique_ptr<BfbsGenerator> NewLuaBfbsGenerator(
616 const std::string &flatc_version) {
617 return std::unique_ptr<LuaBfbsGenerator>(new LuaBfbsGenerator(flatc_version));
618}
619
620} // namespace flatbuffers