summaryrefslogblamecommitdiff
path: root/template.lua
blob: eb9a288ca527076abee74ebefabe7383089a604f (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14













                                                             



                                   














































































                                                                               
                  



























                                                     







                                             

                       
                                               


                           
                                                          

           
                                 

   


































                                                      








                                              


                     





                                                    
                                                  





                                          
                                                                           












                                               








                                                      
                      
                     
                                                                       
             


                                              




                                                           
                                            

                                       












































                                                                         
                               





                    



                             
                 
                         



                                    





                                  





                              








                                                                          
                                 




                 






                              
           
                   








                                                    

                                            
                               

                                                    
                  
                                                           
                               
                          








                                          
              
                            



        
-- module table; the public functions (set_escape, get_escape, compile)
-- are attached to it further down and it is returned at the end of the file
local M = {}

-- Tell whether `c` is a line terminator character (LF or CR).
-- Returns a boolean; any other value (including nil) gives false.
local is_nl = function(c)
    if c == "\n" then
        return true
    end
    return c == "\r"
end

-- set of horizontal whitespace characters; newlines are deliberately
-- absent because they are handled separately by the lexer
local wses = {}
for _, ch in ipairs({ " ", "\t", "\v", "\f" }) do
    wses[ch] = true
end

-- Tell whether `c` is a non-newline whitespace character.
-- Always returns a real boolean.
local is_ws = function(c)
    return wses[c] == true
end

-- Tell whether token `c` may be part of a block name, i.e. whether it
-- contains an alphanumeric character or an underscore.
-- Non-string input (notably nil, which the lexer yields at end of input)
-- is never part of a name; returning false here lets the caller report a
-- proper lexer error instead of crashing on `c:match`.
local is_bname = function(c)
    if type(c) ~= "string" then
        return false
    end
    return c:match("[%w_]") ~= nil
end

-- Advance the lexer state `ls` by one character.
-- Pulls the next character from `ls.stream` into `ls.current` and returns
-- the character that was current before the call.
local next_char = function(ls)
    local previous = ls.current
    local upcoming = ls.stream()
    ls.current = upcoming
    return previous
end

-- Consume one line terminator from `ls` and return it as the token "\n".
-- `ls.current` must be a newline character on entry. A two-character
-- terminator made of two different newline characters ("\r\n" or "\n\r")
-- is consumed as a single line break. The line counter `ls.lnum` is
-- incremented exactly once per line break in both cases.
local next_line = function(ls)
    local pc = next_char(ls)
    assert(is_nl(pc))
    -- swallow the second half of a two-character line terminator
    if is_nl(ls.current) and (ls.current ~= pc) then
        next_char(ls)
    end
    ls.lnum = ls.lnum + 1
    return "\n"
end

-- Create a lexer state reading from the character iterator `s`.
-- The state starts at line 1 and is primed so that `current` already
-- holds the first character of the stream (or nil if it is empty).
local lex_new = function(s)
    local ls = {}
    ls.stream = s
    ls.lnum = 1
    next_char(ls)
    return ls
end

-- Raise a lexer error for state `ls`.
-- The message carries the current line number, the text `msg` and the
-- character the lexer is positioned at ("EOF" when input is exhausted).
local lex_error = function(ls, msg)
    local near = ls.current or "EOF"
    local text = string.format("input:%d: %s near '%s'", ls.lnum, msg, near)
    error(text)
end

-- characters that form a two-character tag together with a brace:
-- after "{" they open one ("{*", "{%", "{-"), before "}" they close one
local pair_chars = {
    ["*"] = true, ["%"] = true, ["-"] = true
}

local lex_token

-- Read and return the next token from lexer state `ls`, or nil at end of
-- input. Tokens are:
--   * "\n" for every line break (line number is tracked by next_line),
--   * the escaped character for a backslash escape,
--   * two-character tag openers/closers ("{{", "{[", "{*", "{%", "{-",
--     "}}", "]}", "*}", "%}", "-}"),
--   * any other single character.
-- Comments of the form {# ... #} are skipped entirely; a backslash inside
-- a comment escapes the following character so that "\#}" does not end it.
-- Raises a lexer error on an unfinished escape or comment.
lex_token = function(ls)
    -- keep track of line numbers but consider newlines
    -- tokens in order to preserve format exactly as it was
    -- other whitespace needs no tracking
    if is_nl(ls.current) then
        return next_line(ls)
    end
    if ls.current == "\\" then
        next_char(ls)
        local c = ls.current
        if not c then
            lex_error(ls, "unfinished escape")
        end
        next_char(ls)
        return c
    end
    if ls.current == "{" then
        next_char(ls)
        -- comment
        if ls.current == "#" then
            next_char(ls)
            while true do
                local c = ls.current
                if not c then
                    lex_error(ls, "unfinished comment")
                elseif c == "\\" then
                    -- drop the backslash and whatever it escapes; every
                    -- escape is honoured, including several in a row
                    next_char(ls)
                    if is_nl(ls.current) then
                        next_line(ls)
                    else
                        next_char(ls)
                    end
                elseif is_nl(c) then
                    -- keep line numbers accurate across multi-line comments
                    next_line(ls)
                else
                    next_char(ls)
                    if c == "#" and ls.current == "}" then
                        next_char(ls)
                        -- proper tail call: continue with the token that
                        -- follows the comment
                        return lex_token(ls)
                    end
                end
            end
        end
        local c = ls.current
        if c == "{" or c == "[" or pair_chars[c] then
            next_char(ls)
            return "{" .. c
        end
        return "{"
    end
    local c = ls.current
    next_char(ls)
    -- a closer character directly followed by "}" forms a closing tag
    if ls.current == "}" then
        if c == "}" or c == "]" or pair_chars[c] then
            local nc = ls.current
            next_char(ls)
            return c .. nc
        end
    end
    return c
end

-- Return the next token of `ls`.
-- A token previously stored by lookahead() is handed out first (and the
-- slot is cleared); otherwise a fresh token is read from the stream.
local lex = function(ls)
    local pending = ls.lookahead
    if not pending then
        return lex_token(ls)
    end
    ls.lookahead = nil
    return pending
end

-- Read one token ahead of the current position.
-- The token is stored in `ls.lookahead` so that the next lex() call
-- returns it, and it is also returned to the caller.
local lookahead = function(ls)
    local peeked = lex_token(ls)
    ls.lookahead = peeked
    return peeked
end

-- Tell whether `tok` equals `tok1` or, when given, `tok2`.
-- `ls` is accepted for signature uniformity but is not used.
local is_toks = function(ls, tok, tok1, tok2)
    if tok == tok1 then
        return true
    end
    if not tok2 then
        return false
    end
    return tok == tok2
end

-- Collect tokens from `ls` up to (not including) a terminating token.
-- The terminator is `etok1` or, if given, `etok2`.
-- Returns the concatenated text of the collected tokens and the
-- terminator that was actually found. Raises a lexer error naming
-- `etok1` if the input ends first.
local get_until = function(ls, etok1, etok2)
    local pieces = {}
    local tok = lex(ls)
    while true do
        if is_toks(ls, tok, etok1, etok2) then
            return table.concat(pieces), tok
        end
        pieces[#pieces + 1] = tok
        tok = lex(ls)
        if not tok then
            lex_error(ls, ("'%s' expected"):format(etok1))
        end
    end
end

-- Read a block name from `ls`: a run of name tokens (see is_bname)
-- that must be terminated by "-}".
-- Returns the name as a string. Raises a lexer error if anything other
-- than "-}" follows the name, including end of input (a nil token is
-- never handed to is_bname).
local get_bname = function(ls)
    local tok = lex(ls)
    local buf = {}
    while tok ~= nil and is_bname(tok) do
        buf[#buf + 1] = tok
        tok = lex(ls)
    end
    if tok ~= "-}" then
        lex_error(ls, "'-}' expected")
    end
    return table.concat(buf)
end

-- Read a named block from `ls`, starting right after its "{-" opener.
-- The block runs until a "{-name-}" tag carrying the same name; tags with
-- a different name are kept verbatim as part of the block body.
-- Returns the body text and the block name. Raises a lexer error if the
-- input ends before the closing tag.
local get_block = function(ls)
    local bname = get_bname(ls)
    local tok = lex(ls)
    local body = {}
    while tok do
        if tok ~= "{-" then
            body[#body + 1] = tok
        else
            local ebname = get_bname(ls)
            if ebname == bname then
                return table.concat(body), bname
            end
            -- some other tag: put it back into the body as text
            body[#body + 1] = "{-" .. ebname .. "-}"
        end
        tok = lex(ls)
    end
    lex_error(ls, "unfinished block")
end

-- Append to the code buffer `cbuf` a statement that writes the value of
-- the Lua expression source `exp`.
local save_exp = function(cbuf, exp)
    local stmt = " write(" .. exp .. ") "
    table.insert(cbuf, stmt)
end

-- Append the raw Lua source `code` to the code buffer `cbuf`, padded with
-- a space on each side so adjacent fragments cannot run together.
local save_code = function(cbuf, code)
    local padded = " " .. code .. " "
    table.insert(cbuf, padded)
end

-- Flush accumulated literal text into the code buffer.
-- When `acc` holds text, it is joined, quoted with %q and emitted as a
-- write() statement, and a fresh empty accumulator is returned.
-- An empty `acc` is returned unchanged and nothing is emitted.
local save_acc = function(cbuf, acc)
    if #acc ~= 0 then
        local literal = table.concat(acc)
        save_exp(cbuf, ("%q"):format(literal))
        return {}
    end
    return acc
end


-- Translate the token stream of lexer state `ls` into Lua source code.
-- Literal template text is gathered in `acc` and flushed as quoted
-- write() statements; tags are turned into code fragments. All fragments
-- are collected in `cbuf`, which starts with a line binding the chunk's
-- vararg to the local `context`. Returns the generated source as a string.
local parse = function(ls)
    local acc, cbuf = {}, {"local context = ... "}
    local tok = lex(ls)
    -- true while the tokens since the last newline were whitespace only
    local was_nl = false
    while tok ~= nil do
        -- NOTE(review): skip_tok is assigned below but never read
        local cnl, skip_tok = false, false
        if tok == "{{" then
            -- {{ expr }}: write the expression through the default escape
            acc = save_acc(cbuf, acc)
            save_exp(cbuf, "escape.default(" .. get_until(ls, "}}") .. ")")
        elseif tok == "{*" then
            -- {* expr *}: write the expression as is
            acc = save_acc(cbuf, acc)
            save_exp(cbuf, get_until(ls, "*}"))
        elseif tok  == "{%" then
            -- {% code %}: emit the code verbatim
            local code = get_until(ls, "%}")
            local lah = lookahead(ls)
            skip_tok = true
            if not was_nl or (lah ~= "\n") then
                acc = save_acc(cbuf, acc)
            else
                -- the tag has only whitespace before it on its line and a
                -- newline right after it: drop the pending accumulator
                -- entry (the preceding newline plus indentation) so the
                -- tag's line does not show up in the output
                acc[#acc] = nil
            end
            save_code(cbuf, code)
        elseif tok == "{[" then
            -- {[ expr ]} or {[ expr, ctx ]}: compile the template given by
            -- expr and run it with the current write function
            acc = save_acc(cbuf, acc)
            -- TODO: not ideal, allow commas in expr
            local exp, ftok = get_until(ls, "]}", ",")
            -- inherit context by default
            local ctx = "context"
            if ftok == "," then
                ctx = get_until(ls, "]}")
            end
            save_code(
                cbuf,
                "template.compile(" .. exp .. ")(" .. ctx .. ", write)"
            )
        elseif tok == "{-" then
            -- {-name-} ... {-name-}: named block
            acc = save_acc(cbuf, acc)
            local block, bname = get_block(ls)
            if bname == "raw" or bname == "verbatim" then
                -- emitted as literal text, not interpreted as a template
                save_exp(cbuf, ("%q"):format(block))
            else
                -- compiled as a template of its own; the second argument
                -- `true` is passed in place of a write function and the
                -- call's result is stored in blocks[name]
                save_code(cbuf, (
                    "blocks[\"%s\"] = template.compile(%q)"
                        .. "(context, true)"
                ):format(bname, block))
            end
        elseif was_nl and is_ws(tok) then
            -- indentation: glue it onto the newline entry so both can be
            -- dropped together by the {% ... %} case above
            acc[#acc] = acc[#acc] .. tok
            cnl = true
        else
            cnl = (tok == "\n")
            if cnl then
                -- flush first so the newline starts a fresh accumulator
                -- entry of its own
                acc = save_acc(cbuf, acc)
            end
            acc[#acc + 1] = tok
        end
        was_nl = cnl
        tok = lex(ls)
    end
    -- flush whatever literal text is left at end of input
    save_acc(cbuf, acc)
    return table.concat(cbuf)
end

-- Turn input into a character stream: a function returning one character
-- per call and nil once the input is exhausted.
-- `str` may be
--   * a string naming a file that can be opened: the file is read,
--   * any other string: the string itself is the input,
--   * a non-string value: used as an already open file-like object, i.e.
--     something with a `read(n)` method.
-- A file opened here is closed as soon as its end is reached; objects
-- passed in by the caller are never closed.
local make_stream = function(str)
    local f
    -- whether we opened the handle ourselves and thus have to close it
    local owned = false
    -- if a string, test if it's a file path, otherwise consider a stream
    if type(str) == "string" then
        f = io.open(str)
        owned = (f ~= nil)
    else
        f = str
    end
    -- a string and cannot be opened, treat as it is
    if not f then
        return str:gmatch(".")
    end
    -- otherwise turn into a special stream reading the file in chunks
    local chunksize = 64
    local ss
    -- set once an owned file has been closed; it must not be read again
    local done = false
    -- fetch the next chunk; returns false when nothing could be read
    local refill = function()
        local ip = f:read(chunksize)
        if not ip or #ip == 0 then
            if owned then
                f:close()
                done = true
            end
            return false
        end
        ss = ip:gmatch(".")
        return true
    end
    -- but if we can't read anything, return empty
    if not refill() then
        return function() end
    end
    return function()
        if done then
            return nil
        end
        local c = ss()
        if not c then
            if not refill() then
                return nil
            end
            c = ss()
        end
        return c
    end
end

-- identity escape: hands its argument back untouched
local esc_def = function(str)
    return str
end

-- registry of escape functions by style name; it is exposed to compiled
-- templates as `escape`, and "default" is the style used for {{ ... }}
local escapes = {
    default = esc_def
}

-- Register the escape function used for `style`.
-- `func` may be a function, or a string naming an already registered
-- style whose function is then reused. If that leaves no function (nil,
-- false or an unknown style name), the identity escape is installed.
-- Returns the function previously registered for `style`, if any.
M.set_escape = function(style, func)
    local previous = escapes[style]
    local resolved = func
    if type(resolved) == "string" then
        resolved = escapes[resolved]
    end
    escapes[style] = resolved or esc_def
    return previous
end

-- Look up the escape function registered for `style`.
-- Returns nil when no such style has been registered.
M.get_escape = function(style)
    local func = escapes[style]
    return func
end

-- metatable of the environment compiled templates run in
local default_mt = {
    -- not invoked for actual fields of the environment, but when invoked,
    -- look up in context first (that way we can access context variables
    -- without prefixing every time) and if that fails, global environment
    --
    -- `context` and `__genv` are fetched with rawget: when no context has
    -- been set the field is absent, and a plain `self.context` would
    -- re-enter this very function and recurse until the stack overflows
    __index = function(self, n)
        local ctx = rawget(self, "context")
        if ctx ~= nil then
            local cf = ctx[n]
            if cf ~= nil then
                return cf
            end
        end
        local genv = rawget(self, "__genv")
        if genv == nil then
            return nil
        end
        return genv[n]
    end
}

-- Join the collected output pieces in `acc` into one string.
-- Returns nil when there is no accumulator, i.e. when output was written
-- through a write function instead of being collected.
local dump_ret = function(acc)
    if acc then
        return table.concat(acc)
    end
    return nil
end

-- loads(str, env): compile the Lua source `str` into a function whose
-- global environment is the table `env`. Raises the compiler's message
-- if `str` does not compile.
local loads
if setfenv then
    -- setfenv is available: compile first, then swap the environment
    loads = function(str, env)
        local chunk = assert(loadstring(str))
        return setfenv(chunk, env)
    end
else
    -- no setfenv: load() takes the environment directly; text chunks only
    loads = function(str, env)
        local chunk = assert(load(str, str, "t", env))
        return chunk
    end
end

-- Compile a template into a render function.
-- `str` is handed to make_stream, so it may be a file path, template
-- text or an open file-like object.
-- The returned function takes (ctx, write):
--   * ctx   - the context; its fields are visible as globals inside the
--             template and it is passed to the chunk as `context`,
--   * write - a function receiving each output piece; defaults to
--             io.write. Passing `true` collects the output instead and
--             makes the render function return it as a string.
-- The environment (including `blocks`) is created once per compile call
-- and shared by every invocation of the returned function.
M.compile = function(str)
    -- environment of the module itself, used as fallback for globals
    local genv = getfenv and getfenv(1) or _ENV
    -- environment is nicely encapsulated...
    local denv = setmetatable({
        __genv = genv, template = M, blocks = {}, escape = escapes
    }, default_mt)
    local source = parse(lex_new(make_stream(str)))
    local f = loads(source, denv)
    return function(ctx, write)
        denv.context = ctx
        local acc
        if write ~= true then
            denv.write = write or io.write
        else
            -- collect output pieces to return them as one string
            acc = {}
            denv.write = function(s)
                acc[#acc + 1] = s
            end
        end
        f(ctx)
        return dump_ret(acc)
    end
end

return M