1/*
2 * Copyright (c) 2021, Redis Ltd.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * * Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 * * Redistributions in binary form must reproduce the above copyright
11 * notice, this list of conditions and the following disclaimer in the
12 * documentation and/or other materials provided with the distribution.
13 * * Neither the name of Redis nor the names of its contributors may be used
14 * to endorse or promote products derived from this software without
15 * specific prior written permission.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 * POSSIBILITY OF SUCH DAMAGE.
28 */
29
30/*
 * The function_lua.c unit provides the Lua engine functionality,
 * including registering the engine and implementing the engine
 * callbacks:
34 * * Create a function from blob (usually text)
35 * * Invoke a function
36 * * Free function memory
37 * * Get memory usage
38 *
39 * Uses script_lua.c to run the Lua code.
40 */
41
42#include "functions.h"
43#include "script_lua.h"
44#include <lua.h>
45#include <lauxlib.h>
46#include <lualib.h>
47
48#define LUA_ENGINE_NAME "LUA"
49#define REGISTRY_ENGINE_CTX_NAME "__ENGINE_CTX__"
50#define REGISTRY_ERROR_HANDLER_NAME "__ERROR_HANDLER__"
51#define REGISTRY_LOAD_CTX_NAME "__LIBRARY_CTX__"
52#define LIBRARY_API_NAME "__LIBRARY_API__"
53#define GLOBALS_API_NAME "__GLOBALS_API__"
54#define LOAD_TIMEOUT_MS 500
55
56/* Lua engine ctx */
57typedef struct luaEngineCtx {
58 lua_State *lua;
59} luaEngineCtx;
60
/* Per-function context of a registered Lua function. */
typedef struct luaFunctionCtx {
    /* Reference (as returned by luaL_ref on the Lua registry) used to fetch
     * the Lua function object back from the registry. */
    int lua_function_ref;
} luaFunctionCtx;
66
67typedef struct loadCtx {
68 functionLibInfo *li;
69 monotime start_time;
70} loadCtx;
71
72typedef struct registerFunctionArgs {
73 sds name;
74 sds desc;
75 luaFunctionCtx *lua_f_ctx;
76 uint64_t f_flags;
77} registerFunctionArgs;
78
79/* Hook for FUNCTION LOAD execution.
80 * Used to cancel the execution in case of a timeout (500ms).
81 * This execution should be fast and should only register
82 * functions so 500ms should be more than enough. */
83static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) {
84 UNUSED(ar);
85 loadCtx *load_ctx = luaGetFromRegistry(lua, REGISTRY_LOAD_CTX_NAME);
86 uint64_t duration = elapsedMs(load_ctx->start_time);
87 if (duration > LOAD_TIMEOUT_MS) {
88 lua_sethook(lua, luaEngineLoadHook, LUA_MASKLINE, 0);
89
90 luaPushError(lua,"FUNCTION LOAD timeout");
91 luaError(lua);
92 }
93}
94
95/*
96 * Compile a given blob and save it on the registry.
97 * Return a function ctx with Lua ref that allows to later retrieve the
98 * function from the registry.
99 *
100 * Return NULL on compilation error and set the error to the err variable
101 */
102static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, sds *err) {
103 int ret = C_ERR;
104 luaEngineCtx *lua_engine_ctx = engine_ctx;
105 lua_State *lua = lua_engine_ctx->lua;
106
107 /* set load library globals */
108 lua_getmetatable(lua, LUA_GLOBALSINDEX);
109 lua_enablereadonlytable(lua, -1, 0); /* disable global protection */
110 lua_getfield(lua, LUA_REGISTRYINDEX, LIBRARY_API_NAME);
111 lua_setfield(lua, -2, "__index");
112 lua_enablereadonlytable(lua, LUA_GLOBALSINDEX, 1); /* enable global protection */
113 lua_pop(lua, 1); /* pop the metatable */
114
115 /* compile the code */
116 if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) {
117 *err = sdscatprintf(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1));
118 lua_pop(lua, 1); /* pops the error */
119 goto done;
120 }
121 serverAssert(lua_isfunction(lua, -1));
122
123 loadCtx load_ctx = {
124 .li = li,
125 .start_time = getMonotonicUs(),
126 };
127 luaSaveOnRegistry(lua, REGISTRY_LOAD_CTX_NAME, &load_ctx);
128
129 lua_sethook(lua,luaEngineLoadHook,LUA_MASKCOUNT,100000);
130 /* Run the compiled code to allow it to register functions */
131 if (lua_pcall(lua,0,0,0)) {
132 errorInfo err_info = {0};
133 luaExtractErrorInformation(lua, &err_info);
134 *err = sdscatprintf(sdsempty(), "Error registering functions: %s", err_info.msg);
135 lua_pop(lua, 1); /* pops the error */
136 luaErrorInformationDiscard(&err_info);
137 goto done;
138 }
139
140 ret = C_OK;
141
142done:
143 /* restore original globals */
144 lua_getmetatable(lua, LUA_GLOBALSINDEX);
145 lua_enablereadonlytable(lua, -1, 0); /* disable global protection */
146 lua_getfield(lua, LUA_REGISTRYINDEX, GLOBALS_API_NAME);
147 lua_setfield(lua, -2, "__index");
148 lua_enablereadonlytable(lua, LUA_GLOBALSINDEX, 1); /* enable global protection */
149 lua_pop(lua, 1); /* pop the metatable */
150
151 lua_sethook(lua,NULL,0,0); /* Disable hook */
152 luaSaveOnRegistry(lua, REGISTRY_LOAD_CTX_NAME, NULL);
153 return ret;
154}
155
156/*
157 * Invole the give function with the given keys and args
158 */
159static void luaEngineCall(scriptRunCtx *run_ctx,
160 void *engine_ctx,
161 void *compiled_function,
162 robj **keys,
163 size_t nkeys,
164 robj **args,
165 size_t nargs)
166{
167 luaEngineCtx *lua_engine_ctx = engine_ctx;
168 lua_State *lua = lua_engine_ctx->lua;
169 luaFunctionCtx *f_ctx = compiled_function;
170
171 /* Push error handler */
172 lua_pushstring(lua, REGISTRY_ERROR_HANDLER_NAME);
173 lua_gettable(lua, LUA_REGISTRYINDEX);
174
175 lua_rawgeti(lua, LUA_REGISTRYINDEX, f_ctx->lua_function_ref);
176
177 serverAssert(lua_isfunction(lua, -1));
178
179 luaCallFunction(run_ctx, lua, keys, nkeys, args, nargs, 0);
180 lua_pop(lua, 1); /* Pop error handler */
181}
182
183static size_t luaEngineGetUsedMemoy(void *engine_ctx) {
184 luaEngineCtx *lua_engine_ctx = engine_ctx;
185 return luaMemory(lua_engine_ctx->lua);
186}
187
/* Engine `get_function_memory_overhead` callback: return the allocation size
 * of the function ctx, as reported by zmalloc_size(). */
static size_t luaEngineFunctionMemoryOverhead(void *compiled_function) {
    size_t overhead = zmalloc_size(compiled_function);
    return overhead;
}
191
/* Engine `get_engine_memory_overhead` callback: return the allocation size of
 * the engine ctx itself, as reported by zmalloc_size(). */
static size_t luaEngineMemoryOverhead(void *engine_ctx) {
    return zmalloc_size(engine_ctx);
}
196
197static void luaEngineFreeFunction(void *engine_ctx, void *compiled_function) {
198 luaEngineCtx *lua_engine_ctx = engine_ctx;
199 lua_State *lua = lua_engine_ctx->lua;
200 luaFunctionCtx *f_ctx = compiled_function;
201 lua_unref(lua, f_ctx->lua_function_ref);
202 zfree(f_ctx);
203}
204
205static void luaRegisterFunctionArgsInitialize(registerFunctionArgs *register_f_args,
206 sds name,
207 sds desc,
208 luaFunctionCtx *lua_f_ctx,
209 uint64_t flags)
210{
211 *register_f_args = (registerFunctionArgs){
212 .name = name,
213 .desc = desc,
214 .lua_f_ctx = lua_f_ctx,
215 .f_flags = flags,
216 };
217}
218
219static void luaRegisterFunctionArgsDispose(lua_State *lua, registerFunctionArgs *register_f_args) {
220 sdsfree(register_f_args->name);
221 if (register_f_args->desc) sdsfree(register_f_args->desc);
222 lua_unref(lua, register_f_args->lua_f_ctx->lua_function_ref);
223 zfree(register_f_args->lua_f_ctx);
224}
225
226/* Read function flags located on the top of the Lua stack.
227 * On success, return C_OK and set the flags to 'flags' out parameter
228 * Return C_ERR if encounter an unknown flag. */
229static int luaRegisterFunctionReadFlags(lua_State *lua, uint64_t *flags) {
230 int j = 1;
231 int ret = C_ERR;
232 int f_flags = 0;
233 while(1) {
234 lua_pushnumber(lua,j++);
235 lua_gettable(lua,-2);
236 int t = lua_type(lua,-1);
237 if (t == LUA_TNIL) {
238 lua_pop(lua,1);
239 break;
240 }
241 if (!lua_isstring(lua, -1)) {
242 lua_pop(lua,1);
243 goto done;
244 }
245
246 const char *flag_str = lua_tostring(lua, -1);
247 int found = 0;
248 for (scriptFlag *flag = scripts_flags_def; flag->str ; ++flag) {
249 if (!strcasecmp(flag->str, flag_str)) {
250 f_flags |= flag->flag;
251 found = 1;
252 break;
253 }
254 }
255 /* pops the value to continue the iteration */
256 lua_pop(lua,1);
257 if (!found) {
258 /* flag not found */
259 goto done;
260 }
261 }
262
263 *flags = f_flags;
264 ret = C_OK;
265
266done:
267 return ret;
268}
269
270static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
271 char *err = NULL;
272 sds name = NULL;
273 sds desc = NULL;
274 luaFunctionCtx *lua_f_ctx = NULL;
275 uint64_t flags = 0;
276 if (!lua_istable(lua, 1)) {
277 err = "calling redis.register_function with a single argument is only applicable to Lua table (representing named arguments).";
278 goto error;
279 }
280
281 /* Iterating on all the named arguments */
282 lua_pushnil(lua);
283 while (lua_next(lua, -2)) {
284 /* Stack now: table, key, value */
285 if (!lua_isstring(lua, -2)) {
286 err = "named argument key given to redis.register_function is not a string";
287 goto error;
288 }
289 const char *key = lua_tostring(lua, -2);
290 if (!strcasecmp(key, "function_name")) {
291 if (!(name = luaGetStringSds(lua, -1))) {
292 err = "function_name argument given to redis.register_function must be a string";
293 goto error;
294 }
295 } else if (!strcasecmp(key, "description")) {
296 if (!(desc = luaGetStringSds(lua, -1))) {
297 err = "description argument given to redis.register_function must be a string";
298 goto error;
299 }
300 } else if (!strcasecmp(key, "callback")) {
301 if (!lua_isfunction(lua, -1)) {
302 err = "callback argument given to redis.register_function must be a function";
303 goto error;
304 }
305 int lua_function_ref = luaL_ref(lua, LUA_REGISTRYINDEX);
306
307 lua_f_ctx = zmalloc(sizeof(*lua_f_ctx));
308 lua_f_ctx->lua_function_ref = lua_function_ref;
309 continue; /* value was already popped, so no need to pop it out. */
310 } else if (!strcasecmp(key, "flags")) {
311 if (!lua_istable(lua, -1)) {
312 err = "flags argument to redis.register_function must be a table representing function flags";
313 goto error;
314 }
315 if (luaRegisterFunctionReadFlags(lua, &flags) != C_OK) {
316 err = "unknown flag given";
317 goto error;
318 }
319 } else {
320 /* unknown argument was given, raise an error */
321 err = "unknown argument given to redis.register_function";
322 goto error;
323 }
324 lua_pop(lua, 1); /* pop the value to continue the iteration */
325 }
326
327 if (!name) {
328 err = "redis.register_function must get a function name argument";
329 goto error;
330 }
331
332 if (!lua_f_ctx) {
333 err = "redis.register_function must get a callback argument";
334 goto error;
335 }
336
337 luaRegisterFunctionArgsInitialize(register_f_args, name, desc, lua_f_ctx, flags);
338
339 return C_OK;
340
341error:
342 if (name) sdsfree(name);
343 if (desc) sdsfree(desc);
344 if (lua_f_ctx) {
345 lua_unref(lua, lua_f_ctx->lua_function_ref);
346 zfree(lua_f_ctx);
347 }
348 luaPushError(lua, err);
349 return C_ERR;
350}
351
352static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
353 char *err = NULL;
354 sds name = NULL;
355 sds desc = NULL;
356 luaFunctionCtx *lua_f_ctx = NULL;
357 if (!(name = luaGetStringSds(lua, 1))) {
358 err = "first argument to redis.register_function must be a string";
359 goto error;
360 }
361
362 if (!lua_isfunction(lua, 2)) {
363 err = "second argument to redis.register_function must be a function";
364 goto error;
365 }
366
367 int lua_function_ref = luaL_ref(lua, LUA_REGISTRYINDEX);
368
369 lua_f_ctx = zmalloc(sizeof(*lua_f_ctx));
370 lua_f_ctx->lua_function_ref = lua_function_ref;
371
372 luaRegisterFunctionArgsInitialize(register_f_args, name, NULL, lua_f_ctx, 0);
373
374 return C_OK;
375
376error:
377 if (name) sdsfree(name);
378 if (desc) sdsfree(desc);
379 luaPushError(lua, err);
380 return C_ERR;
381}
382
383static int luaRegisterFunctionReadArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
384 int argc = lua_gettop(lua);
385 if (argc < 1 || argc > 2) {
386 luaPushError(lua, "wrong number of arguments to redis.register_function");
387 return C_ERR;
388 }
389
390 if (argc == 1) {
391 return luaRegisterFunctionReadNamedArgs(lua, register_f_args);
392 } else {
393 return luaRegisterFunctionReadPositionalArgs(lua, register_f_args);
394 }
395}
396
397static int luaRegisterFunction(lua_State *lua) {
398 registerFunctionArgs register_f_args = {0};
399
400 loadCtx *load_ctx = luaGetFromRegistry(lua, REGISTRY_LOAD_CTX_NAME);
401 if (!load_ctx) {
402 luaPushError(lua, "redis.register_function can only be called on FUNCTION LOAD command");
403 return luaError(lua);
404 }
405
406 if (luaRegisterFunctionReadArgs(lua, &register_f_args) != C_OK) {
407 return luaError(lua);
408 }
409
410 sds err = NULL;
411 if (functionLibCreateFunction(register_f_args.name, register_f_args.lua_f_ctx, load_ctx->li, register_f_args.desc, register_f_args.f_flags, &err) != C_OK) {
412 luaRegisterFunctionArgsDispose(lua, &register_f_args);
413 luaPushError(lua, err);
414 sdsfree(err);
415 return luaError(lua);
416 }
417
418 return 0;
419}
420
421/* Initialize Lua engine, should be called once on start. */
422int luaEngineInitEngine() {
423 luaEngineCtx *lua_engine_ctx = zmalloc(sizeof(*lua_engine_ctx));
424 lua_engine_ctx->lua = lua_open();
425
426 luaRegisterRedisAPI(lua_engine_ctx->lua);
427
428 /* Register the library commands table and fields and store it to registry */
429 lua_newtable(lua_engine_ctx->lua); /* load library globals */
430 lua_newtable(lua_engine_ctx->lua); /* load library `redis` table */
431
432 lua_pushstring(lua_engine_ctx->lua, "register_function");
433 lua_pushcfunction(lua_engine_ctx->lua, luaRegisterFunction);
434 lua_settable(lua_engine_ctx->lua, -3);
435
436 luaRegisterLogFunction(lua_engine_ctx->lua);
437 luaRegisterVersion(lua_engine_ctx->lua);
438
439 luaSetErrorMetatable(lua_engine_ctx->lua);
440 lua_setfield(lua_engine_ctx->lua, -2, REDIS_API_NAME);
441
442 luaSetErrorMetatable(lua_engine_ctx->lua);
443 luaSetTableProtectionRecursively(lua_engine_ctx->lua); /* protect load library globals */
444 lua_setfield(lua_engine_ctx->lua, LUA_REGISTRYINDEX, LIBRARY_API_NAME);
445
446 /* Save error handler to registry */
447 lua_pushstring(lua_engine_ctx->lua, REGISTRY_ERROR_HANDLER_NAME);
448 char *errh_func = "local dbg = debug\n"
449 "debug = nil\n"
450 "local error_handler = function (err)\n"
451 " local i = dbg.getinfo(2,'nSl')\n"
452 " if i and i.what == 'C' then\n"
453 " i = dbg.getinfo(3,'nSl')\n"
454 " end\n"
455 " if type(err) ~= 'table' then\n"
456 " err = {err='ERR ' .. tostring(err)}"
457 " end"
458 " if i then\n"
459 " err['source'] = i.source\n"
460 " err['line'] = i.currentline\n"
461 " end"
462 " return err\n"
463 "end\n"
464 "return error_handler";
465 luaL_loadbuffer(lua_engine_ctx->lua, errh_func, strlen(errh_func), "@err_handler_def");
466 lua_pcall(lua_engine_ctx->lua,0,1,0);
467 lua_settable(lua_engine_ctx->lua, LUA_REGISTRYINDEX);
468
469 lua_pushvalue(lua_engine_ctx->lua, LUA_GLOBALSINDEX);
470 luaSetErrorMetatable(lua_engine_ctx->lua);
471 luaSetTableProtectionRecursively(lua_engine_ctx->lua); /* protect globals */
472 lua_pop(lua_engine_ctx->lua, 1);
473
474 /* Save default globals to registry */
475 lua_pushvalue(lua_engine_ctx->lua, LUA_GLOBALSINDEX);
476 lua_setfield(lua_engine_ctx->lua, LUA_REGISTRYINDEX, GLOBALS_API_NAME);
477
478 /* save the engine_ctx on the registry so we can get it from the Lua interpreter */
479 luaSaveOnRegistry(lua_engine_ctx->lua, REGISTRY_ENGINE_CTX_NAME, lua_engine_ctx);
480
481 /* Create new empty table to be the new globals, we will be able to control the real globals
482 * using metatable */
483 lua_newtable(lua_engine_ctx->lua); /* new globals */
484 lua_newtable(lua_engine_ctx->lua); /* new globals metatable */
485 lua_pushvalue(lua_engine_ctx->lua, LUA_GLOBALSINDEX);
486 lua_setfield(lua_engine_ctx->lua, -2, "__index");
487 lua_enablereadonlytable(lua_engine_ctx->lua, -1, 1); /* protect the metatable */
488 lua_setmetatable(lua_engine_ctx->lua, -2);
489 lua_enablereadonlytable(lua_engine_ctx->lua, -1, 1); /* protect the new global table */
490 lua_replace(lua_engine_ctx->lua, LUA_GLOBALSINDEX); /* set new global table as the new globals */
491
492
493 engine *lua_engine = zmalloc(sizeof(*lua_engine));
494 *lua_engine = (engine) {
495 .engine_ctx = lua_engine_ctx,
496 .create = luaEngineCreate,
497 .call = luaEngineCall,
498 .get_used_memory = luaEngineGetUsedMemoy,
499 .get_function_memory_overhead = luaEngineFunctionMemoryOverhead,
500 .get_engine_memory_overhead = luaEngineMemoryOverhead,
501 .free_function = luaEngineFreeFunction,
502 };
503 return functionsRegisterEngine(LUA_ENGINE_NAME, lua_engine);
504}
505