import argparse import collections import contextlib import importlib import io import re import subprocess import sys import lark asmod = importlib.import_module("as") GRAMMAR_FILE = '/home/paulmathieu/vhdl/tools/cc.ebnf' CPP = ('cpp', '-P') class Scope: def __init__(self): self.symbols = {} self.parent = None class Variable: def __init__(self, type, name, volatile=False, addr_reg=None): self.type = type self.name = name self.volatile = volatile self.addr_reg = addr_reg def __repr__(self): return f'' @classmethod def from_def(cls, tree): volatile = False type = tree.children[0] if isinstance(type, lark.Tree): for c in type.children: if c == "volatile": volatile = True name = tree.children[1] return cls(type, name, volatile=volatile) @classmethod def from_dereference(cls, reg): return cls('deref', 'deref', addr_reg=reg) # - all registers pointing to unwritten stuff can be dropped as soon as we're # done with them: # - already fed them into the operations # - end of statement # - maybe should have a separate storeage for special registers? # need ways to: # - assign a list of registers into r0, ... rn for function call # - ... and run all callbacks # - get the register for an identifier # - dereference an expression: # - essentially turn a temp into an lvalue-address # - if read, need a second reg for its value # - mark a variable to have its address in a register # - delay reading the identifier if it's lhs # - store the value to memory when registers are claimed # - if it was modified through: # - assignment # - pre-post *crement # - retrieve the memory type: # - stack # - absolute # - register # - or just store the value when the variable is assigned # - get a temporary register # - return it # - I guess dereferencing is an upgrade def type(self, tree): print(tree) class RegBank: def __init__(self, logger=None): self.reset() self.log = logger or print def reset(self): self.available = [f'r{i}' for i in range(12)] self.vars = {} self.varregs = {} self.cleanup = collections.defaultdict(list) def take(self, reg=None): if reg is not None: if reg not in self.available: self.evict(self.var(reg)) self.available.remove(reg) return reg if not self.available: assert self.vars, "nothing to clean, no more regs :/" # storing one random var var = list(self.vars.keys())[0] self.evict(var) return self.available.pop(0) def give(self, reg): if reg in self.varregs: # Need to call evict() with the var to free it. return self.available.insert(0, reg) def loaded(self, var, reg, cleanup=None): """Tells the regbank some variable was loaded into the given register.""" self.vars[var] = reg self.varregs[reg] = var self.take(reg) if cleanup is not None: self.log(f'recording cleanup for {reg}({var.name})') self.cleanup[reg].append(cleanup) def stored(self, var): """Tells the regbank the given var was stored to memory, register can be freed.""" assert var in self.vars reg = self.vars.pop(var) del self.varregs[reg] self.give(reg) def load(self, var): """Returns the reg associated with the var, or a new reg if none was, and True if the var was created, False otherwise.""" self.log(f'vars: {self.vars}, varregs: {self.varregs}') if var not in self.vars: reg = self.take() self.vars[var] = reg self.varregs[reg] = var return reg, True return self.vars[var], False def assign(self, var, reg, cleanup=None): """Assign a previously-used register to a variable.""" if var in self.vars: self.stored(var) if reg in self.varregs: for cb in self.cleanup.pop(reg, []): cb(reg) self.stored(self.varregs[reg]) self.vars[var] = reg self.varregs[reg] = var if cleanup is not None: self.cleanup[reg].append(cleanup) def var(self, reg): return self.varregs.get(reg, None) def evict(self, var): """Runs var callbacks & frees the register.""" if var not in self.vars: return reg = self.vars[var] for cb in self.cleanup.pop(var, []): cb(reg) self.stored(var) def flush_all(self): for reg in list(self.cleanup): self.log(f'flushing {reg}({self.varregs[reg].name})') for cb in self.cleanup.pop(reg): cb(reg) self.reset() def swapped(self, reg0, reg1): var0 = self.varregs.get(reg0, None) var1 = self.varregs.get(reg1, None) if var0 is not None: self.stored(var0) elif reg0 not in self.available: self.give(reg0) if var1 is not None: self.stored(var1) elif reg1 not in self.available: self.give(reg1) if var0 is not None: self.loaded(var0, reg1) if var1 is not None: self.loaded(var1, reg0) @contextlib.contextmanager def borrow(self, howmany): regs = [self.take() for i in range(howmany)] yield regs for reg in regs: self.give(reg) class FunctionSpec: def __init__(self, fun_prot): self.return_type = fun_prot.children[0] self.name = fun_prot.children[1] self.param_types = [x.children[0] for x in fun_prot.children[2:]] def __repr__(self): params = ', '.join(self.param_types) return f'' class Function: def __init__(self, fun_prot): self.locals = {} self.spec = FunctionSpec(fun_prot) self.params = [Variable(*x.children) for x in fun_prot.children[2:]] self.ret = None self.nextstack = 0 self.statements = [] self.regs = RegBank(logger=self.log) self.deferred_ops = [] self.fun_calls = 0 self.ops = [] def log(self, line): self.ops.append(lambda: [f'// {line}']) @property def stack_usage(self): return self.nextstack + 2 def get_stack(self, size=2): stk = self.nextstack self.nextstack += size return stk def param_dict(self): return {p.name: p for p in self.params} def __repr__(self): return repr(self.spec) def synth(self): if self.fun_calls > 0: preamble = [f'store lr, [sp, -2]', f'set r4, {self.stack_usage}', f'sub sp, sp, r4'] else: preamble = [] ops = preamble for op in self.ops: ops += op() indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in ops] return [f'.global {self.spec.name}', f'{self.spec.name}:'] + indented class CcTransform(lark.visitors.Transformer): def _binary_op(litt): @lark.v_args(tree=True) def _f(self, tree): left, right = tree.children if left.data == 'litteral' and right.data == 'litteral': tree.data = 'litteral' tree.children = [litt(left.children[0], right.children[0])] return tree return _f def array_item(self, children): # transform blarg[foo] into *(blarg + foo) because reasons addop = lark.Tree('add', children) return lark.Tree('dereference', [addop]) # operations on litterals can be done by the compiler add = _binary_op(lambda a, b: a+b) sub = _binary_op(lambda a, b: a-b) mul = _binary_op(lambda a, b: a*b) shl = _binary_op(lambda a, b: a<> 8) & 0xff lo = (self.imm16 >> 0) & 0xff if hi != 0: return [f'set {reg}, {lo}', f'seth {reg}, {hi}'] else: return [f'set {reg}, {lo}'] class Swap(AsmOp): scratch_need = 1 def __init__(self, a0, a1): self.a0 = a0 self.a1 = a1 def synth(self, scratches): (sc0,) = scratches return [f'or {sc0}, {self.a0}, {self.a0}', f'or {self.a0}, {self.a1}, {self.a1}', f'or {self.a1}, {sc0}, {sc0}'] class IfOp(AsmOp): scratch_need = 1 def __init__(self, fun, op): self.fun = fun self.cond, mark, self.has_else = op self.then_mark = f'_then_{mark}' self.else_mark = f'_else_{mark}' self.endif_mark = f'_endif_{mark}' def synth(self, scratches): sc0 = scratches[0] if self.has_else: return [f'set {sc0}, 0', f'cmp {sc0}, {self.cond}', # flag if cond == 0 f'beq {self.else_mark}'] else: return [f'set {sc0}, 0', f'cmp {sc0}, {self.cond}', f'beq {self.endif_mark}'] def synth_else(self): return [f'cmp r0, r0', # trick because beq is better than "or pc, ." f'beq {self.endif_mark}', f'{self.else_mark}:'] def synth_endif(self): return [f'{self.endif_mark}:'] class WhileOp(AsmOp): scratch_need = 1 @staticmethod def synth_loop(mark): loop_mark = f'_loop_{mark}' return [f'{loop_mark}:'] def __init__(self, cond, mark): self.cond = cond self.loop_mark = f'_loop_{mark}' self.endwhile_mark = f'_endwhile_{mark}' def synth(self, scratches): sc0 = scratches[0] return [f'set {sc0}, 0', f'cmp {sc0}, {self.cond}', f'beq {self.endwhile_mark}'] def synth_endwhile(self): return [f'cmp r0, r0', f'beq {self.loop_mark}', f'{self.endwhile_mark}:'] class Delayed(AsmOp): def __init__(self, out_cb): self.out_cb = out_cb @property def out(self): return self.out_cb() def synth(self): return [] class CcInterp(lark.visitors.Interpreter): def __init__(self): self.global_scope = Scope() self.cur_scope = self.global_scope self.cur_fun = None self.funs = [] self.next_reg = 0 self.next_marker = 0 def _lookup_symbol(self, s): scope = self.cur_scope while scope is not None: if s in scope.symbols: return scope.symbols[s] scope = scope.parent return None def _get_reg(self): return self.cur_fun.regs.take() def _get_marker(self): mark = self.next_marker self.next_marker += 1 return mark def _synth(self, op): with self.cur_fun.regs.borrow(op.scratch_need) as scratches: self._log(f'{op.__class__.__name__}') self.cur_fun.ops.append(lambda: op.synth(scratches)) def _load(self, ident): s = self._lookup_symbol(ident) assert s is not None, f'unknown identifier {ident}' if isinstance(s, FunctionSpec) or s in self.global_scope.symbols: reg = self._get_reg() return SetAddr(self.cur_fun, [reg, ident]) else: if s.volatile: self._log(f'loading volatile {s}') return Load(self.cur_fun, [reg, s]) reg, created = self.cur_fun.regs.load(s) if created: return Load(self.cur_fun, [reg, s]) else: self._log(f'{s} was already in {reg}') return Reg(reg) def identifier(self, tree): # TODO: not actually load the value until it's used in an expression # could have the op.out() function have a side effect that does that # if it's an assignment, we need to make a Variable with the proper # address and assign the register to it. # If it's volatile, we need to flush it. def delayed_load(): self._log(f'delay-loading {tree.children[0]}') op = self._load(tree.children[0]) self._synth(op) return op.out tree.op = Delayed(delayed_load) tree.var = self._lookup_symbol(tree.children[0]) def litteral(self, tree): imm = tree.children[0] reg = self._get_reg() assert self.cur_fun is not None tree.op = Set16Imm(self.cur_fun, [reg, imm]) self._synth(tree.op) def _assign(self, left, right): # need to make sure there's a variable and either: # - to store it, or # - to mark a register for it if left.volatile: self._synth(Store(self.cur_fun, [right, left])) self.cur_fun.stored(left) return self._log(f'assigning {left} = {right}') # cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, left])) self.cur_fun.regs.assign(left, right) self._synth(Store(self.cur_fun, [right, left])) def assignment(self, tree): self.visit_children(tree) val = tree.children[1].op.out # left hand side is an lvalue, retrieve it var = tree.children[0].var self._assign(var, val) def global_var(self, tree): self.visit_children(tree) var = Variable.from_def(tree) self.global_scope.symbols[var.name] = var val = 0 if len(tree.children) > 2: val = tree.children[2].children[0].value def fun_decl(self, tree): fun = FunctionSpec(tree.children[0]) self.cur_scope.symbols[fun.name] = fun def _prep_fun_call(self, fn_reg, params): """Move all params to r0-rn.""" def swap(old, new): if old == new: return oldpos = params.index(old) try: newpos = params.index(new) except ValueError: params[oldpos] = new else: params[newpos], params[oldpos] = params[oldpos], params[newpos] self._synth(Swap(old, new)) if fn_reg in [f'r{i}' for i in range(len(params))]: new_fn = f'r{len(params)}' self._synth(Swap(fn_reg, new_fn)) fn_reg = new_fn for i, param in enumerate(params): new = f'r{i}' swap(param, new) return fn_reg def fun_call(self, tree): self.cur_fun.fun_calls += 1 self.visit_children(tree) fn_reg = tree.children[0].op.out param_regs = [c.op.out for c in tree.children[1:]] self.cur_fun.regs.flush_all() for i in range(len(param_regs)): self.cur_fun.regs.take(f'r{i}') fn_reg = self._prep_fun_call(fn_reg, param_regs) self.cur_fun.regs.take(fn_reg) tree.op = FnCall(self.cur_fun, [fn_reg, param_regs]) self._synth(tree.op) self.cur_fun.regs.reset() self.cur_fun.regs.take('r0') def statement(self, tree): self.visit_children(tree) if self.cur_fun.deferred_ops: self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}') for op in self.cur_fun.deferred_ops: self._synth(op) self.cur_fun.deferred_ops = [] iter_expression = statement def if_stat(self, tree): self.visit(tree.children[0]) mark = self._get_marker() has_else = len(tree.children) > 2 op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else]) self._synth(op) self.visit(tree.children[1]) if has_else: self.cur_fun.ops.append(op.synth_else) self.visit(tree.children[2]) self.cur_fun.ops.append(op.synth_endif) def while_stat(self, tree): mark = self._get_marker() self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark)) begin_vars = dict(self.cur_fun.regs.vars) self.visit(tree.children[0]) op = WhileOp(tree.children[0].op.out, mark) self._synth(op) self.visit(tree.children[1]) for v, r in begin_vars.items(): rvars = self.cur_fun.regs.vars if v not in rvars or rvars[v] != r: self._log(f'loading missing var {v}') self._synth(Load(self.cur_fun, [r, v])) self.cur_fun.ops.append(op.synth_endwhile) def for_stat(self, tree): mark = self._get_marker() self.visit(tree.children[0]) # initialization self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark)) begin_vars = dict(self.cur_fun.regs.vars) self.visit(tree.children[1]) op = WhileOp(tree.children[1].op.out, mark) self._synth(op) self.visit(tree.children[3]) self.visit(tree.children[2]) # 3rd statement for v, r in begin_vars.items(): rvars = self.cur_fun.regs.vars if v not in rvars or rvars[v] != r: self._log(f'loading missing var {v}') self._synth(Load(self.cur_fun, [r, v])) self.cur_fun.ops.append(op.synth_endwhile) def _unary_op(op): def _f(self, tree): self.visit_children(tree) operand = tree.children[0].op.out reg = self.cur_fun.regs.take() tree.op = op(self.cur_fun, [reg, operand]) self._synth(tree.op) self.cur_fun.regs.give(operand) return _f def dereference(self, tree): self.visit_children(tree) reg = tree.children[0].op.out var = Variable.from_dereference(reg) self._log(f'making var {var} from derefing reg {reg}') tree.var = var def delayed_load(): self._log(f'delay-loading {tree.children[0]}') op = Load(self.cur_fun, [reg, var]) self._synth(op) return op.out tree.op = Delayed(delayed_load) def post_increment(self, tree): self.visit_children(tree) tree.op = Reg(tree.children[0].op.out) var = tree.children[0].var reg = tree.op.out self.cur_fun.deferred_ops.append(Incr(self.cur_fun, [reg, reg])) self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var])) pre_increment = _unary_op(Incr) bool_not = _unary_op(BoolNot) def _binary_op(op): def _f(self, tree): self.visit_children(tree) left, right = (x.op.out for x in tree.children) dest = self.cur_fun.regs.take() tree.op = op(self.cur_fun, [dest, left, right]) self._synth(tree.op) self.cur_fun.regs.give(left) self.cur_fun.regs.give(right) return _f def _combo(uop, bop): def _f(self, tree): bop.__get__(self)(tree) uop.__get__(self)(tree) return _f shl = _binary_op(ShlOp) add = _binary_op(AddOp) sub = _binary_op(SubOp) mul = _binary_op(MulOp) _and = _binary_op(AndOp) # ... gt = _binary_op(GtOp) lt = _binary_op(LtOp) neq = _binary_op(NeqOp) eq = _combo(bool_not, neq) def _forward_op(self, tree): self.visit_children(tree) tree.op = tree.children[0].op def cast(self, tree): self.visit_children(tree) tree.op = tree.children[1].op def _log(self, line): self.cur_fun.log(line) def local_var(self, tree): self.visit_children(tree) assert self.cur_fun is not None assert self.cur_scope is not None var = Variable.from_def(tree) var.stackaddr = self.cur_fun.get_stack() # will have to invert self.cur_scope.symbols[var.name] = var self.cur_fun.locals[var.name] = var if len(tree.children) > 2: initval = tree.children[2].children[0].op.out cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var])) self.cur_fun.regs.assign(var, initval, cleanup) self._log(f'assigning {var} = {initval}') def fun_def(self, tree): prot, body = tree.children fun = Function(prot) assert self.cur_fun is None self.cur_fun = fun self.cur_scope.symbols[fun.spec.name] = fun.spec params = fun.param_dict() def getcleanup(var): return lambda reg: self._synth(Store(self.cur_fun, [reg, var])) for i, param in enumerate(fun.params): fun.locals[param.name] = param param.stackaddr = fun.get_stack() self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}') cleanup = getcleanup(param) fun.regs.loaded(param, f'r{i}', cleanup) fun_scope = Scope() fun_scope.parent = self.cur_scope fun_scope.symbols.update(params) body.scope = fun_scope self.visit_children(tree) if fun.ret is None: if fun.spec.name == 'main': self._synth(Set16Imm(fun, ['r0', 0])) self._synth(ReturnReg(fun, ['r0'])) elif fun.spec.return_type == 'void': self._synth(ReturnReg(fun, ['r0'])) else: assert fun.ret is not None self.cur_fun = None self.funs.append(fun) def body(self, tree): bscope = getattr(tree, 'scope', Scope()) bscope.parent = self.cur_scope self.cur_scope = bscope self.visit_children(tree) self.cur_scope = bscope.parent def return_stat(self, tree): assert self.cur_fun is not None self.cur_fun.ret = True self.visit_children(tree) expr_reg = tree.children[0].op.out tree.op = ReturnReg(self.cur_fun, [expr_reg]) self._synth(tree.op) preamble = [f'_start:', f'xor r0, r0, r0', f'xor r1, r1, r1', f'set sp, 0', f'seth sp, {0x11}', # 256 bytes of stack ought to be enough f'set r2, main', f'set r3, 2', f'add lr, pc, r3', f'or pc, r2, r2', f'or pc, pc, pc // loop forever', ] def filter_dupes(ops): dupe_re = re.compile(r'or (r\d+), \1, \1') for op in ops: if dupe_re.search(op): continue yield op def parse_tree(tree, debug=False): tr = CcTransform() tree = tr.transform(tree) if debug: print(tree.pretty()) inte = CcInterp() inte.visit(tree) out = [] for fun in inte.funs: out += fun.synth() out.append('') return '\n'.join(filter_dupes(out)) def larkparse(f, debug=False): with open(GRAMMAR_FILE) as g: asparser = lark.Lark(g.read()) data = f.read() if isinstance(data, bytes): data = data.decode() tree = asparser.parse(data) return parse_tree(tree, debug=debug) def parse_args(): parser = argparse.ArgumentParser(description='Compile.') parser.add_argument('--debug', action='store_true', help='print the AST') parser.add_argument('--assembly', '-S', action='store_true', help='output assembly') parser.add_argument('--compile', '-c', action='store_true', help='compile a single file') parser.add_argument('input', type=argparse.FileType('r'), default=sys.stdin, help='input file (default: stdin)') parser.add_argument('--output', '-o', type=argparse.FileType('wb'), default=sys.stdout.buffer, help='output file') return parser.parse_args() def preprocess(fin): p = subprocess.Popen(CPP, stdin=fin, stdout=subprocess.PIPE) return p.stdout def assemble(text, fout): fin = io.StringIO(text) asmod.write_obj(fout, *asmod.larkparse(fin)) def main(): args = parse_args() assy = larkparse(preprocess(args.input), debug=args.debug) if args.assembly: args.output.write(assy.encode() + b'\n') else: assemble(assy, args.output) if __name__ == "__main__": main()