edocgen/template.lua

367 lines
8.7 KiB
Lua

-- module table; the public functions (set_escape, get_escape, compile)
-- are attached to it and it is returned at the end of the file
local M = {}
-- true when c is one of the two line terminator characters (LF or CR),
-- false for anything else (including nil)
local is_nl = function(c)
    if c == "\n" then
        return true
    end
    return c == "\r"
end
-- set of horizontal whitespace characters: space, tab, vertical tab and
-- form feed; newlines are deliberately absent as they are lexed separately
local wses = {}
for ch in (" \t\v\f"):gmatch(".") do
    wses[ch] = true
end
-- true when c is a member of the whitespace set above, false otherwise
local is_ws = function(c)
    return wses[c] == true
end
-- true when the string c contains an alphanumeric character or an
-- underscore, i.e. something that may appear in a block name
local is_bname = function(c)
    return c:find("[%w_]") ~= nil
end
-- advances the lexer state by one character: pulls the next character
-- from ls.stream into ls.current and returns the character that was
-- current before the call
local next_char = function(ls)
    local previous = ls.current
    local upcoming = ls.stream()
    ls.current = upcoming
    return previous
end
-- consumes one line break from the stream and bumps the line counter
--
-- ls.current must be a newline character on entry (asserted); a pair of
-- two different newline characters (CRLF or LFCR) is consumed as a single
-- line break; the returned token is always "\n", whatever the input used
local next_line = function(ls)
    local first = next_char(ls)
    assert(is_nl(first))
    -- swallow the second half of a two-character line break
    if is_nl(ls.current) and (ls.current ~= first) then
        next_char(ls)
    end
    -- count the line for every kind of line break, including CRLF/LFCR
    ls.lnum = ls.lnum + 1
    return "\n"
end
-- creates a lexer state for the character stream s (a function returning
-- one character per call); the line counter starts at 1 and the first
-- character is read right away so that ls.current is primed
local lex_new = function(s)
    local ls = { lnum = 1, stream = s }
    ls.current = ls.stream()
    return ls
end
-- raises a lexer error; the message carries the current line number, the
-- given text and the character the lexer is looking at ("EOF" when the
-- input is exhausted)
local lex_error = function(ls, msg)
    local near = ls.current or "EOF"
    error(string.format("input:%d: %s near '%s'", ls.lnum, msg, near))
end
-- characters that form a two-character delimiter together with a brace:
-- "{*" / "*}", "{%" / "%}" and "{-" / "-}"
local pair_chars = {
    ["*"] = true,
    ["%"] = true,
    ["-"] = true,
}
local lex_token
-- reads the next token from the lexer state
--
-- returns one of:
--   "\n"                     for any line break (also bumps ls.lnum)
--   the escaped character    for a backslash followed by any character
--   "{{" "{[" "{*" "{%" "{-" for an opening delimiter
--   "}}" "]}" "*}" "%}" "-}" for a closing delimiter
--   a single character       for everything else
--   nil                      at the end of input
-- "{# ... #}" comments yield no token at all; raises a lexer error on an
-- unfinished escape or comment
lex_token = function(ls)
    -- keep track of line numbers but consider newlines
    -- tokens in order to preserve format exactly as it was
    -- other whitespace needs no tracking
    if is_nl(ls.current) then
        return next_line(ls)
    end
    -- escape: the character after the backslash is returned verbatim
    if ls.current == "\\" then
        next_char(ls)
        local c = ls.current
        if not c then
            lex_error(ls, "unfinished escape")
        end
        next_char(ls)
        return c
    end
    if ls.current == "{" then
        next_char(ls)
        -- comment
        if ls.current == "#" then
            next_char(ls)
            while true do
                if not ls.current then
                    lex_error(ls, "unfinished comment")
                end
                -- a backslash hides the following character from the
                -- terminator check; line breaks are still counted
                if ls.current == "\\" then
                    next_char(ls)
                    if is_nl(ls.current) then
                        next_line(ls)
                    else
                        next_char(ls)
                    end
                end
                if is_nl(ls.current) then
                    -- line breaks inside a comment produce no output but
                    -- must advance the line counter so that later error
                    -- messages point at the right line
                    next_line(ls)
                else
                    local c = ls.current
                    next_char(ls)
                    if c == "#" and ls.current == "}" then
                        next_char(ls)
                        -- the comment itself is not a token, carry on
                        -- with whatever follows it
                        return lex_token(ls)
                    end
                end
            end
        end
        local c = ls.current
        if c == "{" or c == "[" or pair_chars[c] then
            next_char(ls)
            return "{" .. c
        end
        return "{"
    end
    -- ordinary character, possibly the first half of a closing delimiter
    local c = ls.current
    next_char(ls)
    if ls.current == "}" then
        if c == "}" or c == "]" or pair_chars[c] then
            local nc = ls.current
            next_char(ls)
            return c .. nc
        end
    end
    return c
end
-- returns the next token; a token previously stored by lookahead() is
-- handed out (and cleared) first, otherwise a fresh one is lexed
local lex = function(ls)
    local pending = ls.lookahead
    if not pending then
        return lex_token(ls)
    end
    ls.lookahead = nil
    return pending
end
-- lexes one token ahead, remembers it in ls.lookahead so that the next
-- lex() call returns it, and also returns it to the caller
local lookahead = function(ls)
    ls.lookahead = lex_token(ls)
    return ls.lookahead
end
-- true when tok equals tok1, or equals tok2 if a second candidate was
-- given; ls is accepted for signature uniformity but not used
local is_toks = function(ls, tok, tok1, tok2)
    if tok == tok1 then
        return true
    end
    if tok2 then
        return tok == tok2
    end
    return false
end
-- collects tokens up to (not including) the terminator etok1, or the
-- optional alternative terminator etok2
--
-- returns the concatenated text and the terminator that ended it; raises
-- a lexer error naming etok1 when the input ends first
local get_until = function(ls, etok1, etok2)
    local pieces = {}
    local tok = lex(ls)
    while true do
        if is_toks(ls, tok, etok1, etok2) then
            return table.concat(pieces), tok
        end
        pieces[#pieces + 1] = tok
        tok = lex(ls)
        if not tok then
            lex_error(ls, ("'%s' expected"):format(etok1))
        end
    end
end
-- reads a block name: a run of name tokens (see is_bname) that must be
-- followed directly by "-}"
--
-- returns the name as a string; raises a lexer error when anything other
-- than "-}" terminates the name, including the end of input
local get_bname = function(ls)
    local tok = lex(ls)
    local buf = {}
    -- tok is nil at the end of input; stop there instead of handing nil
    -- to is_bname, so that the proper lexer error below is reported
    while tok and is_bname(tok) do
        buf[#buf + 1] = tok
        tok = lex(ls)
    end
    if tok ~= "-}" then
        lex_error(ls, "'-}' expected")
    end
    return table.concat(buf)
end
-- reads a named block; the opening "{-" has already been consumed, so
-- this starts at the block name and gathers everything up to the
-- matching "{-name-}" terminator
--
-- a "{-other-}" marker with a different name is kept in the body as
-- text; returns the body and the block name, raises a lexer error when
-- the input ends before the terminator
local get_block = function(ls)
    local bname = get_bname(ls)
    local body = {}
    local tok = lex(ls)
    while tok do
        if tok ~= "{-" then
            body[#body + 1] = tok
        else
            local ebname = get_bname(ls)
            if ebname == bname then
                return table.concat(body), bname
            end
            -- not our terminator, put it back the way it was written
            body[#body + 1] = "{-" .. ebname .. "-}"
        end
        tok = lex(ls)
    end
    lex_error(ls, "unfinished block")
end
-- appends a statement to the code buffer that writes the value of the
-- Lua expression exp to the output
local save_exp = function(cbuf, exp)
    local stmt = " write(" .. exp .. ") "
    table.insert(cbuf, stmt)
end
-- appends a piece of raw Lua code to the code buffer, padded with a
-- space on each side so that adjacent pieces cannot run together
local save_code = function(cbuf, code)
    local stmt = " " .. code .. " "
    table.insert(cbuf, stmt)
end
-- flushes the accumulated literal text into the code buffer as one
-- quoted write() statement
--
-- returns the table to accumulate into next: the same (empty) table when
-- there was nothing to flush, a new empty one otherwise
local save_acc = function(cbuf, acc)
    if #acc ~= 0 then
        local literal = string.format("%q", table.concat(acc))
        save_exp(cbuf, literal)
        return {}
    end
    return acc
end
-- translates the token stream of a template into Lua source code
--
-- literal text is accumulated and emitted as write("...") statements;
-- the template constructs are translated as follows:
--   {{ expr }}        write(escape.default(expr))
--   {* expr *}        write(expr)
--   {% code %}        code, inserted verbatim
--   {[ expr, ctx ]}   template.compile(expr)(ctx, write); ctx defaults
--                     to the current context
--   {-name-}...{-name-}  raw/verbatim blocks are written out as they
--                     are, other blocks are compiled and their output
--                     is stored in blocks[name]
-- returns the generated source as a single string
local parse = function(ls)
    local out = {"local context = ... "}
    local pending = {}
    -- true while only a line break and whitespace have been seen on the
    -- current line; in that state the last pending entry is that break
    -- plus the indentation collected so far
    local after_nl = false
    local tok = lex(ls)
    while tok ~= nil do
        local on_nl = false
        if tok == "{{" then
            pending = save_acc(out, pending)
            local expr = get_until(ls, "}}")
            save_exp(out, "escape.default(" .. expr .. ")")
        elseif tok == "{*" then
            pending = save_acc(out, pending)
            local expr = get_until(ls, "*}")
            save_exp(out, expr)
        elseif tok == "{%" then
            local code = get_until(ls, "%}")
            local following = lookahead(ls)
            if after_nl and (following == "\n") then
                -- the code tag sits alone on its line: drop the line
                -- break and indentation in front of it
                pending[#pending] = nil
            else
                pending = save_acc(out, pending)
            end
            save_code(out, code)
        elseif tok == "{[" then
            pending = save_acc(out, pending)
            -- TODO: not ideal, allow commas in expr
            local expr, stop = get_until(ls, "]}", ",")
            -- inherit context by default
            local ctx = "context"
            if stop == "," then
                ctx = get_until(ls, "]}")
            end
            local call = "template.compile(" .. expr .. ")("
                .. ctx .. ", write)"
            save_code(out, call)
        elseif tok == "{-" then
            pending = save_acc(out, pending)
            local block, bname = get_block(ls)
            if bname == "raw" or bname == "verbatim" then
                save_exp(out, ("%q"):format(block))
            else
                local fmt = "blocks[\"%s\"] = template.compile(%q)"
                    .. "(context, true)"
                save_code(out, fmt:format(bname, block))
            end
        elseif after_nl and is_ws(tok) then
            -- indentation is glued onto the line break entry
            pending[#pending] = pending[#pending] .. tok
            on_nl = true
        else
            on_nl = (tok == "\n")
            if on_nl then
                pending = save_acc(out, pending)
            end
            pending[#pending + 1] = tok
        end
        after_nl = on_nl
        tok = lex(ls)
    end
    save_acc(out, pending)
    return table.concat(out)
end
-- turn input into a character stream
--
-- str is either a string or an already opened file handle; a string is
-- first tried as a file path and, when it cannot be opened, used as the
-- template text itself
--
-- returns a function yielding one character per call and nil once the
-- input is exhausted; a file opened here is closed again as soon as its
-- end is reached, a handle passed in by the caller is never closed
local make_stream = function(str)
    local f
    -- whether the handle was opened by us and is thus ours to close
    local owned = false
    -- if a string, test if it's a file path, otherwise consider a stream
    if type(str) == "string" then
        f = io.open(str)
        owned = (f ~= nil)
    else
        f = str
    end
    -- a string that cannot be opened, treat it as the input itself
    if not f then
        return str:gmatch(".")
    end
    -- otherwise turn into a special stream that reads in small chunks
    -- but if we can't read anything, return empty
    local chunksize = 64
    local ip = f:read(chunksize)
    if not ip or #ip == 0 then
        if owned then
            f:close()
        end
        return function() end
    end
    local ss = ip:gmatch(".")
    return function()
        -- our own file was exhausted and closed earlier; the lexer may
        -- still ask for more characters after the end, keep saying nil
        if not f then
            return nil
        end
        local c = ss()
        if not c then
            ip = f:read(chunksize)
            if not ip or #ip == 0 then
                if owned then
                    f:close()
                    f = nil
                end
                return nil
            end
            ss = ip:gmatch(".")
            c = ss()
        end
        return c
    end
end
-- fallback escape function: hands the string back untouched
local function esc_def(str)
    return str
end
-- registry of escape functions by style name; "default" is the one the
-- generated code applies to {{ ... }} expressions
local escapes = {
    default = esc_def
}
-- registers the escape function used for the given style
--
-- func is either a function or the name of an already registered style
-- whose function is to be reused; when it is missing or names an unknown
-- style, the pass-through default is installed
--
-- returns the function previously registered for the style, if any
M.set_escape = function(style, func)
    local previous = escapes[style]
    local handler = func
    if type(handler) == "string" then
        handler = escapes[handler]
    end
    escapes[style] = handler or esc_def
    return previous
end
-- returns the escape function registered for the given style, or nil
-- when there is none
M.get_escape = function(style)
    local handler = escapes[style]
    return handler
end
local default_mt = {
    -- not invoked for actual fields of the environment, but when invoked,
    -- look up in context first (that way we can access context variables
    -- without prefixing every time) and if that fails, global environment
    __index = function(self, n)
        -- fetch the context with rawget: when no context is set the field
        -- is absent, and a plain self.context would re-enter this very
        -- function over and over until the stack overflows
        local ctx = rawget(self, "context")
        if ctx ~= nil then
            local cf = ctx[n]
            if cf ~= nil then
                return cf
            end
        end
        return getfenv(0)[n]
    end
}
-- joins the collected output pieces into one string; yields nil when
-- there is no collection table (output went to a writer instead)
local dump_ret = function(acc)
    if acc then
        return table.concat(acc)
    end
    return nil
end
-- compiles a template into a render function
--
-- str is a file path, the template text itself or an open file handle
-- (see make_stream); raises when the template cannot be lexed or the
-- generated Lua code does not load
--
-- the returned function takes (ctx, write): ctx is the context the
-- template sees, write decides where the output goes:
--   true          output is collected and returned as a string
--   a function    it is called with every piece of output
--   nil/false     io.write is used
-- nothing meaningful is returned unless write is true
M.compile = function(str)
    -- environment is nicely encapsulated...
    local denv = setmetatable({
        escape = escapes, blocks = {}, template = M
    }, default_mt)
    local source = parse(lex_new(make_stream(str)))
    local chunk = assert(loadstring(source))
    local f = setfenv(chunk, denv)
    return function(ctx, write)
        denv.context = ctx
        -- only collect the output when explicitly asked to
        local collected = (write == true) and {} or nil
        if collected then
            denv.write = function(s)
                collected[#collected + 1] = s
            end
        else
            denv.write = write or io.write
        end
        f(ctx)
        return dump_ret(collected)
    end
end
-- hand the module table to whoever loads this file
return M