-- edocgen/template.lua
--
-- (listing header: 255 lines, 6.1 KiB, Lua)

local M = {}

--- Test whether `c` is a line-break character (LF or CR).
-- @param c a one-character string, or nil at end of input
-- @treturn boolean
local is_nl = function(c)
    return c == "\r" or c == "\n"
end

-- set of horizontal whitespace characters (line breaks are handled apart)
local wses = {}
for _, ch in ipairs({ " ", "\t", "\v", "\f" }) do
    wses[ch] = true
end

--- Test whether `c` is whitespace other than a line break.
-- @param c a one-character string, or nil at end of input
-- @treturn boolean always a real boolean, never nil
local is_ws = function(c)
    return wses[c] == true
end
--- Advance the lexer by one character.
-- Pulls the next character from `ls.stream` into `ls.current`.
-- @param ls lexer state
-- @return the character that was current before advancing (nil at the end)
local next_char = function(ls)
    local prev = ls.current
    ls.current = ls.stream()
    return prev
end
--- Consume one line break and bump the line counter.
-- A two-character break made of two *different* newline characters
-- ("\r\n" or "\n\r") counts as a single line; "\n\n" or "\r\r" are two.
-- @param ls lexer state; `ls.current` must be a newline character
-- @return always "\n", so every line-break style yields the same token
local next_line = function(ls)
    local pc = next_char(ls)
    assert(is_nl(pc))
    if is_nl(ls.current) and ls.current ~= pc then
        -- second half of a CRLF / LFCR pair
        next_char(ls)
    end
    -- counted for two-character breaks too, so line numbers stay right on
    -- CRLF input
    ls.lnum = ls.lnum + 1
    return "\n"
end
--- Create a lexer state over a character stream.
-- @param s function returning one character per call and nil at the end
-- @return lexer state with the first character already loaded in `current`
local lex_new = function(s)
    local state = { lnum = 1, stream = s }
    -- prime the lexer so `current` holds the first character
    next_char(state)
    return state
end
--- Raise a template syntax error annotated with the current input line.
-- @param ls lexer state (reads `lnum` and `current`)
-- @param msg error description
local lex_error = function(ls, msg)
    local near = ls.current
    if near == nil then
        near = "EOF"
    end
    error(string.format("input:%d: %s near '%s'", ls.lnum, msg, near))
end
-- characters that combine with a brace into two-character delimiters:
-- "{*" / "*}", "{%" / "%}" and "{-" / "-}"
local pair_chars = {}
for _, ch in ipairs({ "*", "%", "-" }) do
    pair_chars[ch] = true
end
local lex_token

--- Skip the rest of a "{# ... #}" comment (the opening "{#" is consumed).
-- A backslash escapes the next character, so "\#}" does not end the comment.
-- Line breaks are counted so later error messages report the right line.
-- @param ls lexer state
local skip_comment = function(ls)
    while true do
        local c = ls.current
        if not c then
            lex_error(ls, "unfinished comment")
        end
        if c == "\\" then
            -- drop the backslash and whatever character it escapes
            next_char(ls)
            if is_nl(ls.current) then
                next_line(ls)
            else
                next_char(ls)
            end
        elseif is_nl(c) then
            next_line(ls)
        else
            next_char(ls)
            if c == "#" and ls.current == "}" then
                next_char(ls)
                return
            end
        end
    end
end

--- Read the next token from the template source.
-- Tokens are: "\n" for any line break; the character following a "\"
-- escape; the delimiters "{{", "{[", "{*", "{%", "{-" and "}}", "]}", "*}",
-- "%}", "-}"; otherwise single characters. "{# ... #}" comments are skipped.
-- @param ls lexer state
-- @return token string, or nil at end of input
lex_token = function(ls)
    -- line breaks are tokens (and are counted) so the output keeps its
    -- original layout; other whitespace passes through as plain characters
    if is_nl(ls.current) then
        return next_line(ls)
    end
    if ls.current == "\\" then
        next_char(ls)
        local c = ls.current
        if not c then
            lex_error(ls, "unfinished escape")
        end
        next_char(ls)
        return c
    end
    if ls.current == "{" then
        next_char(ls)
        local c = ls.current
        if c == "#" then
            next_char(ls)
            skip_comment(ls)
            -- tail call: consecutive comments do not grow the stack
            return lex_token(ls)
        end
        if c == "{" or c == "[" or pair_chars[c] then
            next_char(ls)
            return "{" .. c
        end
        -- a lone brace is literal text; the character after it has not been
        -- consumed and is returned by the next call
        return "{"
    end
    local c = next_char(ls)
    if ls.current == "}" and (c == "}" or c == "]" or pair_chars[c]) then
        next_char(ls)
        return c .. "}"
    end
    return c
end
--- Return the next token, preferring one already buffered by `lookahead`.
-- @param ls lexer state
-- @return token string, or nil at end of input
local lex = function(ls)
    local buffered = ls.lookahead
    if not buffered then
        return lex_token(ls)
    end
    ls.lookahead = nil
    return buffered
end
--- Peek at the next token without consuming it.
-- The token is buffered in `ls.lookahead` and handed out again by the next
-- `lex` call. Peeking twice returns the same token; it does not lex (and
-- lose) another one.
-- @param ls lexer state
-- @return token string, or nil at end of input
local lookahead = function(ls)
    local tok = ls.lookahead
    if tok == nil then
        tok = lex_token(ls)
        ls.lookahead = tok
    end
    return tok
end
--- Collect raw token text up to (not including) the delimiter `etok`.
-- The delimiter itself is consumed. Raises a lexer error when the input
-- ends before `etok` appears.
-- @param ls lexer state
-- @param etok closing delimiter token, e.g. "}}"
-- @treturn string concatenation of every token before `etok`
local get_until = function(ls, etok)
    local parts = {}
    local tok = lex(ls)
    while tok ~= etok do
        parts[#parts + 1] = tok
        tok = lex(ls)
        if tok == nil then
            local msg = string.format("'%s' expected", etok)
            lex_error(ls, msg)
        end
    end
    return table.concat(parts)
end
-- append one chunk of generated Lua source to the code buffer
local push_chunk = function(cbuf, chunk)
    cbuf[#cbuf + 1] = chunk
end

--- Emit code that writes the value of the Lua expression `exp`.
local save_exp = function(cbuf, exp)
    push_chunk(cbuf, " write(" .. exp .. ") ")
end

--- Emit a verbatim chunk of user Lua code (from "{% ... %}").
local save_code = function(cbuf, code)
    push_chunk(cbuf, " " .. code .. " ")
end

--- Flush accumulated literal text as a `write` call with a quoted string.
-- @return a fresh, empty accumulator to replace `acc`
local save_acc = function(cbuf, acc)
    local literal = string.format("%q", table.concat(acc))
    save_exp(cbuf, literal)
    return {}
end
--- Translate the token stream into the Lua source of a render function.
-- Literal text becomes `write("...")` calls, "{{ e }}" becomes
-- `write(escape(e))`, "{* e *}" becomes `write(e)`, and "{% code %}" is
-- copied verbatim. When a "{% %}" is the only thing on its line (apart from
-- indentation) and is followed by a line break, the line's leading newline
-- and indentation are dropped, so the statement leaves no blank line.
-- @param ls lexer state
-- @treturn string Lua chunk that receives `context, escape` as varargs
local parse = function(ls)
    local code = { "local context, escape = ... " }
    local text = {}
    -- true while only a line break and whitespace have been seen on this line
    local at_bol = false
    local tok = lex(ls)
    while tok do
        local bol_after = false
        if tok == "{{" then
            text = save_acc(code, text)
            save_exp(code, "escape(" .. get_until(ls, "}}") .. ")")
        elseif tok == "{*" then
            text = save_acc(code, text)
            save_exp(code, get_until(ls, "*}"))
        elseif tok == "{%" then
            local stmt = get_until(ls, "%}")
            local following = lookahead(ls)
            if at_bol and following == "\n" then
                -- discard the pending "\n" + indentation before the statement
                text[#text] = nil
            else
                text = save_acc(code, text)
            end
            save_code(code, stmt)
        elseif at_bol and is_ws(tok) then
            -- fold indentation into the pending line-break chunk
            text[#text] = text[#text] .. tok
            bol_after = true
        elseif tok == "\n" then
            -- every line starts a new text chunk that begins with its newline
            text = save_acc(code, text)
            text[#text + 1] = tok
            bol_after = true
        else
            text[#text + 1] = tok
        end
        at_bol = bol_after
        tok = lex(ls)
    end
    save_acc(code, text)
    return table.concat(code)
end
--- Turn a template source into a character stream.
-- A string that names a readable file is read from that file in chunks, and
-- the file is closed once exhausted. Any other string is the template text
-- itself. A non-string is used as an already-open file handle; it is read
-- but left open for the caller to close.
-- @param str template text, file path, or file handle
-- @return function yielding one character per call, then nil
local make_stream = function(str)
    local f, owned = str, false
    -- a string may be a file path: try to open it
    if type(str) == "string" then
        f = io.open(str)
        owned = true
    end
    -- a string that cannot be opened is treated as the template text
    if not f then
        return str:gmatch(".")
    end
    local chunksize = 64
    local done = false
    -- stop reading; close the file only when it was opened here
    local finish = function()
        done = true
        if owned then
            f:close()
        end
    end
    -- nothing readable: return an empty stream
    local ip = f:read(chunksize)
    if not ip or #ip == 0 then
        finish()
        return function() end
    end
    local ss = ip:gmatch(".")
    return function()
        -- the lexer may poll again after EOF; never touch a closed file
        if done then
            return nil
        end
        local c = ss()
        if not c then
            ip = f:read(chunksize)
            if not ip or #ip == 0 then
                finish()
                return nil
            end
            -- iterate over the freshly read chunk (the old iterator is spent)
            ss = ip:gmatch(".")
            c = ss()
        end
        return c
    end
end
--- Default escape function for "{{ ... }}": passes its argument through.
-- @param str value of the expression
-- @return `str` unchanged
local function default_esc(str)
    return str
end
--- Metatable for a compiled template's environment.
-- `__index` only runs for names that are not real fields of the environment
-- (such as `write` and `template`). It looks the name up in the current
-- context first, so context fields can be used without a prefix. If that
-- fails, it falls back to the global environment.
local default_mt = {
    __index = function(self, n)
        -- rawget: `self.context` would re-enter __index while no context is
        -- set (a render call with a nil ctx) and recurse until the stack
        -- overflows
        local ctx = rawget(self, "context")
        if ctx ~= nil then
            local cf = ctx[n]
            if cf ~= nil then
                return cf
            end
        end
        -- getfenv exists only in Lua 5.1/LuaJIT; use _G elsewhere
        return (getfenv and getfenv(0) or _G)[n]
    end
}
--- Compile a template into a render function.
-- @param str template text, a path to a template file, or an open file
--   handle to read the template from
-- @return function(ctx, escape, write) rendering the template: fields of
--   `ctx` are visible as variables, `escape` filters "{{ }}" values
--   (identity by default) and `write` receives the output (io.write by
--   default)
M.compile = function(str)
    local source = parse(lex_new(make_stream(str)))
    local chunk = assert(loadstring(source))
    -- the template runs in its own environment: real fields first, then
    -- the context, then the globals (see default_mt)
    local denv = setmetatable({ template = M }, default_mt)
    setfenv(chunk, denv)
    return function(ctx, escape, write)
        denv.context = ctx
        denv.write = write or io.write
        return chunk(ctx, escape or default_esc)
    end
end
return M