"""Small C-like compiler driver.

Pipeline (see main()): run the input through the C preprocessor (CPP),
parse it with the lark grammar in GRAMMAR_FILE, lower the tree to assembly
text with CcTransform + CcInterp, then either write the assembly (-S) or
hand it to the sibling `as` module to produce an object file.

NOTE(review): this copy of the file was recovered from a version whose line
breaks and indentation had been collapsed; statement nesting was re-inferred
and a span of source near CcTransform.shl is missing (see the note there).
Diff against version control before relying on it.
"""
import argparse
import ast
import collections
import contextlib
from dataclasses import dataclass
import importlib
import inspect
import io
import re
import struct
import subprocess
import sys

import lark

# `as` is a Python keyword, so the sibling assembler module cannot be
# imported with a plain `import` statement.
asmod = importlib.import_module("as")

# NOTE(review): machine-specific absolute path to the grammar file.
GRAMMAR_FILE = '/home/paulmathieu/vhdl/tools/cc.ebnf'
# Preprocessor command line; `-P` tells cpp not to emit linemarkers.
CPP = ('cpp', '-P')


class Scope:
    """One lexical scope: a symbol table, a struct table and a link to the
    enclosing scope (walked by CcInterp._lookup_symbol)."""

    def __init__(self):
        self.symbols = {}   # name -> Variable / FunctionSpec
        self.parent = None  # enclosing Scope, None for the global scope
        self.structs = {}   # struct name -> Struct


class Variable:
    """A named value the compiler can load/store.

    `type` is the C type object, `name` the identifier. When `addr_reg` is
    set, the variable's address lives in that register (see
    from_dereference); locals and parameters additionally get a
    `stackaddr` attribute assigned by CcInterp.
    """

    def __init__(self, type, name, addr_reg=None):
        self.type = type
        self.name = name
        self.addr_reg = addr_reg

    def __repr__(self):
        # NOTE(review): this f-string is empty in this copy; its content
        # (likely a `<...>` form) appears to have been stripped. Restore it.
        return f''

    def __str__(self):
        return self.name

    @classmethod
    def from_def(cls, tree):
        """Build a Variable from a definition node: children[0] is the type,
        children[1] the name."""
        type = tree.children[0]
        name = tree.children[1]
        return cls(type, name)

    @classmethod
    def from_dereference(cls, reg, tree):
        """Build the pseudo-variable `*expr`, where `reg` holds the address
        computed for `tree`; its type is the pointee type of `tree`."""
        return cls(tree.type.pointed, 'deref', addr_reg=reg)

    # Design notes on register handling:
    # - all registers pointing to unwritten stuff can be dropped as soon as we're
    # done with them:
    # - already fed them into the operations
    # - end of statement
    # - maybe should have a separate storage for special registers?
    # need ways to:
    # - assign a list of registers into r0, ... rn for function call
    # - ... and run all callbacks
    # - get the register for an identifier
    # - dereference an expression:
    # - essentially turn a temp into an lvalue-address
    # - if read, need a second reg for its value
    # - mark a variable to have its address in a register
    # - delay reading the identifier if it's lhs
    # - store the value to memory when registers are claimed
    # - if it was modified through:
    # - assignment
    # - pre-post *crement
    # - retrieve the memory type:
    # - stack
    # - absolute
    # - register
    # - or just store the value when the variable is assigned
    # - get a temporary register
    # - return it
    # - I guess dereferencing is an upgrade

    # NOTE(review): placed inside Variable because a module-level `type`
    # would shadow the builtin that fun_call calls with three arguments.
    # Instances set `self.type` in __init__, which hides this method, so it
    # looks like leftover debug code.
    def type(self, tree):
        print(tree)


def allregs():
    """Return a generator over the allocatable register names r0..r11."""
    return (f'r{i}' for i in range(12))


class RegBank:
    """Compile-time register allocator for one function.

    State:
      available -- free registers, in preference order
      vars      -- Variable -> register currently holding it
      varregs   -- register -> Variable (inverse of `vars`)
      cleanup   -- register -> callbacks run (with the register) on eviction,
                   e.g. to store a modified value back to memory
      taken     -- register -> caller frame info recorded by take(), used only
                   in the "overtake" error message
    """

    def __init__(self, logger=None):
        self.reset()
        self.log = logger or print

    def reset(self):
        """Forget every allocation: all registers free, no variables."""
        self.available = list(allregs())
        self.vars = {}
        self.varregs = {}
        self.cleanup = collections.defaultdict(list)
        self.taken = {}

    def __str__(self):
        return f'regs(avail={self.available}, vars={self.vars})'

    def take(self, reg=None, orwhatever=False):
        """Claim `reg` (or any free register if None) and return its name.

        See _take for the eviction/fallback rules. The caller's frame is
        recorded in self.taken for diagnostics.
        """
        out = self._take(reg=reg, orwhatever=orwhatever)
        self.taken[out] = inspect.stack()[1]
        return out

    def _take(self, reg=None, orwhatever=False):
        """Allocation logic behind take().

        With an explicit `reg`: if it holds a variable, that variable is
        evicted first; if it holds a temporary, either fall back to any free
        register (`orwhatever=True`) or raise RuntimeError. Without `reg`:
        pop the first free register, evicting a variable when none is free.
        """
        if reg is not None:
            if reg not in self.available:
                if reg in self.varregs:
                    self.evict(self.var(reg))
                elif orwhatever:
                    return self._take()
                else:
                    raise RuntimeError(
                        f'trying to overtake {reg}, '
                        f'varregs: {self.varregs}, vars: {self.vars}, '
                        f'taken in {self.taken[reg]}')
            if reg in self.available:
                self.available.remove(reg)
            return reg
        if not self.available:
            assert self.vars, "nothing to clean, no more regs :/"
            # evict the oldest-registered var (dict insertion order)
            var = list(self.vars.keys())[0]
            self.evict(var)
        return self.available.pop(0)

    def give(self, reg):
        """Return a temporary register to the free list.

        Registers bound to a variable are left untouched; use evict() or
        stored() to release those.
        """
        assert reg is not None
        if reg in self.varregs:
            # Need to call evict() with the var to free it.
            return
        self.available.insert(0, reg)

    def loaded(self, var, reg, cleanup=None):
        """Record that `var` now lives in `reg`, claim `reg`, and optionally
        register a cleanup callback for it."""
        assert reg is not None
        assert var is not None
        self.vars[var] = reg
        self.varregs[reg] = var
        self.take(reg)
        if cleanup is not None:
            self.log(f'recording cleanup for {reg}({var})')
            self.cleanup[reg].append(cleanup)

    def stored(self, var):
        """Record that `var` was written back to memory; unbind it and free
        its register. Does not run cleanup callbacks (see evict)."""
        assert var in self.vars
        reg = self.vars.pop(var)
        del self.varregs[reg]
        self.give(reg)

    def load(self, var):
        """Return `(reg, created)` for `var`.

        If `var` already has a register, return it with False. Otherwise
        claim a new register, bind it to `var`, and return it with True (the
        caller must emit the actual load).
        """
        self.log(f'vars: {self.vars}, varregs: {self.varregs}')
        if var not in self.vars:
            reg = self.take()
            self.vars[var] = reg
            self.varregs[reg] = var
            return reg, True
        return self.vars[var], False

    def assign(self, var, reg, cleanup=None):
        """Rebind an already-claimed register `reg` to `var`.

        Any register `var` previously had is released; if `reg` held
        another variable, that variable's cleanups run and it is unbound.
        """
        assert reg is not None
        if var in self.vars:
            self.stored(var)
        if reg in self.varregs:
            for cb in self.cleanup.pop(reg, []):
                cb(reg)
            self.stored(self.varregs[reg])
        self.vars[var] = reg
        self.varregs[reg] = var
        if cleanup is not None:
            self.cleanup[reg].append(cleanup)

    def var(self, reg):
        """Return the Variable held in `reg`, or None."""
        return self.varregs.get(reg, None)

    def evict_all(self):
        """Evict every variable (running all cleanup callbacks)."""
        for var in list(self.vars):
            self.evict(var)

    def evict(self, var):
        """Run `var`'s register cleanup callbacks, then free the register."""
        assert var in self.vars, f'trying to evict {var}'
        self.log(f'evicting {var}')
        # NOTE(review): unreachable after the assert unless run with -O.
        if var not in self.vars:
            return
        reg = self.vars[var]
        for cb in self.cleanup.pop(reg, []):
            cb(reg)
        self.stored(var)

    def drop_temps(self):
        """Free every claimed register that is not bound to a variable."""
        for reg in allregs():
            if reg not in self.available and reg not in self.varregs:
                self.give(reg)

    def swapped(self, reg0, reg1):
        """Update bookkeeping after the contents of reg0 and reg1 were
        exchanged: each bound variable moves to the other register.

        Cleanup callbacks are not carried over to the new registers.
        """
        var0 = self.varregs.get(reg0, None)
        var1 = self.varregs.get(reg1, None)
        if var0 is not None:
            self.stored(var0)
        elif reg0 not in self.available:
            self.give(reg0)
        if var1 is not None:
            self.stored(var1)
        elif reg1 not in self.available:
            self.give(reg1)
        if var0 is not None:
            self.loaded(var0, reg1)
        if var1 is not None:
            self.loaded(var1, reg0)

    def snap(self):
        """Return `(vars, cleanup)` as shallow copies.

        The cleanup callback lists themselves are shared with the live bank,
        and the copied cleanup dict is a plain dict (no defaultdict default).
        """
        return dict(self.vars), dict(self.cleanup)

    def restore(self, snap):
        """Reset, then re-bind the variables (and their cleanups) from a
        snap()-style tuple. Temporaries are not restored."""
        self.reset()
        vardict, cleanup = snap
        for var, reg in vardict.items():
            assert reg is not None, f'bad snapshot: {snap}'
            self.loaded(var, reg)
            self.cleanup[reg] = cleanup.get(reg, [])

    @contextlib.contextmanager
    def borrow(self, howmany):
        """Claim `howmany` registers for the duration of the with-block.

        NOTE(review): not wrapped in try/finally, so the registers are not
        given back if the block raises.
        """
        regs = [self.take() for i in range(howmany)]
        yield regs
        for reg in regs:
            self.give(reg)


class FunctionSpec:
    """Signature of a function, built from a prototype node:
    children[0] = return type, children[1] = name, children[2:] = params."""

    def __init__(self, fun_prot):
        self.return_type = fun_prot.children[0]
        self.name = fun_prot.children[1]
        self.param_types = [x.children[0] for x in fun_prot.children[2:]]

    def __repr__(self):
        # NOTE(review): if param_types hold CType objects (as produced by
        # CcTransform.type), str.join raises TypeError here. The returned
        # f-string is also empty in this copy (content likely stripped).
        params = ', '.join(self.param_types)
        return f''

    @property
    def type(self):
        """The function's FunType (return type + parameter types)."""
        return FunType(ret=self.return_type, params=self.param_types)


class Function:
    """Per-function compilation state.

    `ops` is a list of zero-argument callables, each returning a list of
    assembly lines; they are only evaluated by synth(). Stack slots are
    handed out by get_stack() in steps of 2.
    """

    def __init__(self, fun_prot):
        self.locals = {}          # name -> Variable (params and locals)
        self.spec = FunctionSpec(fun_prot)
        self.params = [Variable(*x.children) for x in fun_prot.children[2:]]
        self.ret = None           # set to True once a return statement is seen
        self.nextstack = 0        # next free stack offset
        self.topstack = 0         # high-water mark across tempstack() scopes
        self.statements = []
        self.regs = RegBank(logger=self.log)
        self.deferred_ops = []    # ops flushed at the end of a statement
        self.fun_calls = 0        # number of calls made by this function
        self.ops = []

    def log(self, line):
        """Queue an assembly comment line (`// line`) in the output.

        NOTE(review): when called from an op's synth() (e.g. Store.synth)
        during Function.synth(), the comment is appended to self.ops while it
        is being iterated, so it ends up at the end of the function listing.
        """
        self.ops.append(lambda: [f'// {line}'])

    @property
    def stack_usage(self):
        """Frame size: the largest stack offset used, plus 2 when the
        function makes calls (the slot where synth() saves lr)."""
        top_usage = max(self.topstack, self.nextstack)
        if self.fun_calls > 0:
            return top_usage + 2
        return top_usage

    def get_stack(self, size=2):
        """Reserve `size` units of stack and return the slot's offset."""
        stk = self.nextstack
        self.nextstack += size
        return stk

    @contextlib.contextmanager
    def tempstack(self):
        """Scope for temporary stack slots: offsets reserved inside are
        released on exit, but still count toward the frame size."""
        stk = self.nextstack
        yield
        if self.nextstack > self.topstack:
            self.topstack = self.nextstack
        self.nextstack = stk

    def param_dict(self):
        """Return {name: Variable} for the parameters."""
        return {p.name: p for p in self.params}

    def __repr__(self):
        return repr(self.spec)

    def synth(self):
        """Emit the function's assembly: `.global` + label, prologue (save
        lr if it calls anything, allocate the frame), then every queued op."""
        ops = []
        if self.fun_calls > 0:
            ops += [f'store lr, [sp, -2]']
        if self.stack_usage > 0:
            ops += [f'set r4, {self.stack_usage}', f'sub sp, sp, r4']
        for op in self.ops:
            ops += op()
        # NOTE(review): both branches yield the same string in this copy;
        # labels (lines ending in ':') were presumably meant to be indented
        # less than instructions and the whitespace was collapsed. Confirm.
        indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in ops]
        return [f'.global {self.spec.name}', f'{self.spec.name}:'] + indented


class Struct:
    """A struct definition: its name and its field Variables."""

    def __init__(self, struct_def):
        self.name = struct_def.children[0]
        self.fields = [Variable(*x.children) for x in struct_def.children[1].children]

    def offset(self, field):
        """Offset of `field`: every field occupies 2 units regardless of its
        type. Raises ValueError if the field does not exist."""
        return 2 * [x.name for x in self.fields].index(field)

    def field_type(self, field):
        """Type of `field`, or None if the struct has no such field."""
        for f in self.fields:
            if f.name == field:
                return f.type
        return None


@dataclass
class CType:
    """Base for C type descriptions; carries the cv-qualifiers."""
    volatile: bool = False
    const: bool = False


@dataclass
class StructType(CType):
    """`struct <name>`; the Struct is looked up by this name."""
    struct: str = ''


@dataclass
class PointerType(CType):
    """Pointer to `pointed`."""
    pointed: CType = None


@dataclass
class FunType(CType):
    """Function type: return type plus parameter types."""
    ret: CType = None
    params: [CType] = None


@dataclass
class PodType(CType):
    """Scalar type named by `pod` (e.g. 'int', 'char', 'uint8_t')."""
    pod: str = 'int'
    # NOTE(review): no annotation, so this is a class attribute and not a
    # dataclass field (not settable via __init__, ignored by __eq__).
    signed = True


@dataclass
class VoidType(CType):
    """`void`."""
    pass


class CcTransform(lark.visitors.Transformer):
    """Bottom-up tree rewrite run before CcInterp: builds CType objects,
    flattens comma expressions, desugars `a[i]`, and folds arithmetic on
    two literal operands."""

    def _binary_op(litt):
        """Class-body factory: build a handler that replaces a binary node
        whose two children are both literals with a single literal holding
        `litt(left, right)`; other nodes are returned unchanged."""
        @lark.v_args(tree=True)
        def _f(self, tree):
            # assumes both children are Trees (accesses `.data`)
            left, right = tree.children
            if left.data == 'literal' and right.data == 'literal':
                tree.data = 'literal'
                tree.children = [litt(left.children[0], right.children[0])]
            return tree
        return _f

    def field(self, children):
        (field,) = children
        return field

    def struct_type(self, children):
        (name,) = children
        return StructType(struct=name)

    def pointer(self, children):
        volat = 'volatile' in children
        pointed = children[-1]
        return PointerType(volatile=volat, pointed=pointed)

    def funptr_type(self, children):
        ret, *params = children
        return FunType(ret=ret, params=params)

    def comma_expr(self, children):
        """Flatten nested comma expressions into one n-ary comma_expr."""
        c1, c2 = children
        if c1.data == 'comma_expr':
            c1.children.append(c2)
            return c1
        return lark.Tree('comma_expr', [c1, c2])

    # qualifier rules collapse to marker strings checked by type()/pointer()
    volatile = lambda *_: 'volatile'
    const = lambda *_: 'const'

    def type(self, children):
        """Turn a type node into a CType. A string base type becomes
        VoidType ('void') or PodType; anything else passes through.
        VoidType ignores the qualifiers."""
        volat = 'volatile' in children
        const = 'const' in children
        typ = children[-1]
        if isinstance(typ, str):
            if typ == 'void':
                return VoidType()
            return PodType(volatile=volat, const=const, pod=typ)
        else:
            return typ

    def array_item(self, children):
        # transform blarg[foo] into *(blarg + foo) because reasons
        # TODO: bring this over to the main parser, we need to know the type
        # of blarg in order to compute the actual offset
        addop = lark.Tree('add', children)
        return lark.Tree('dereference', [addop])

    # operations on literals can be done by the compiler
    add = _binary_op(lambda a, b: a+b)
    sub = _binary_op(lambda a, b: a-b)
    mul = _binary_op(lambda a, b: a*b)
    # NOTE(review): the source text between `a<` on the next line and the
    # ` 0:` that follows it is missing from this copy (it looks like a
    # `<...>` span was stripped as markup). The lost region presumably held
    # the rest of CcTransform and the AsmOp base class plus the op classes
    # used below but not defined in this copy (Reg, FnCall, Incr, Decr,
    # BoolNot, ShlOp, ShrOp, AddOp, SubOp, MulOp, AndOp, OrOp, GtOp, LtOp,
    # NeqOp) and the head of ReturnReg. Restore it from version control.
    shl = _binary_op(lambda a, b: a< 0:
            # (tail of what is presumably ReturnReg.synth: release the frame,
            # move the return value into r0, then return via lr or via the
            # saved return address at [sp, -2])
            ops += [f'set {sc0}, {stack_usage}', f'add sp, sp, {sc0}']
        if self.fun.ret is not None:
            ret = self.ret_reg
            ops += [f'or r0, {ret}, {ret}']
        if self.fun.fun_calls == 0:
            ops += [f'or pc, lr, lr']
        else:
            ops += [f'load pc, [sp, -2] // return']
        return ops


class Load(AsmOp):
    """Load `var` into register `dest`: from its stack slot for locals,
    through `var.addr_reg` for dereferences, or from its symbol address
    otherwise."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.dest, self.var = ops
        fun.log(f'8bit load? {self.var.type} {self.is8bit}')

    @property
    def is8bit(self):
        """True for 8-bit scalar types (uint8_t, int8_t, char)."""
        if not isinstance(self.var.type, PodType):
            return False
        if self.var.type.pod in ['uint8_t', 'int8_t', 'char']:
            return True
        return False

    def _maybecast8bit(self, reg):
        # presumably the byte sits in the high half of the loaded word,
        # hence the shift right by 8 -- confirm against the ISA
        if self.is8bit:
            return [f'shr {reg}, {reg}, 8']
        return []

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        reg = self.dest
        if self.var.name in self.fun.locals:
            src = self.var.stackaddr
            return [f'load {reg}, [sp, {src}]'] # no 8bit cast to/from stack
        elif self.var.addr_reg is not None:
            return [f'load {reg}, [{self.var.addr_reg}]'] + self._maybecast8bit(reg)
        else:
            return [f'set {reg}, {self.var.name}',
                    f'nop // in case we load a far global',
                    f'load {reg}, [{reg}]'] # unsure if we need a cast here


class Push(AsmOp):
    """Spill `reg` to the stack slot at `stackaddr`."""

    def __init__(self, reg, stackaddr):
        self.reg = reg
        self.stackaddr = stackaddr

    def synth(self, scratches):
        return [f'store {self.reg}, [sp, {self.stackaddr}]']


class Pop(AsmOp):
    """Reload `reg` from the stack slot at `stackaddr` (inverse of Push)."""

    def __init__(self, reg, stackaddr):
        self.reg = reg
        self.stackaddr = stackaddr

    @property
    def out(self):
        return self.reg

    def synth(self, scratches):
        return [f'load {self.reg}, [sp, {self.stackaddr}]']


class Store(AsmOp):
    """Store register `src` into `var` (stack slot, through `addr_reg`, or
    at its symbol address using one scratch register)."""
    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        self.src, self.var = ops

    def synth(self, scratches):
        (sc,) = scratches
        reg = self.src
        if self.var.name in self.fun.locals:
            dst = self.var.stackaddr
            # NOTE(review): this runs during Function.synth(), so the log
            # line lands at the end of the function listing.
            self.fun.log(f'storing {self.var}({reg}) to [sp, {dst}]')
            return [f'store {reg}, [sp, {dst}]']
        elif self.var.addr_reg is not None:
            return [f'store {reg}, [{self.var.addr_reg}]']
        return [f'set {sc}, {self.var.name}',
                f'nop // you know, in case blah',
                f'store {reg}, [{sc}]']


class Assign(AsmOp):
    """Register-to-register copy `var = src` (`or x, y, y` is the move idiom
    used throughout, cf. Move)."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.src, self.var = ops

    @property
    def out(self):
        return self.var

    def synth(self, scratches):
        return [f'or {self.var}, {self.src}, {self.src}']


class SetAddr(AsmOp):
    """Put the address of symbol `ident` into register `dest`."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.dest, self.ident = ops

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        reg = self.dest
        return [f'set {reg}, {self.ident}',
                f'nop // placeholder for a long address']


class Set16Imm(AsmOp):
    """Load a 16-bit immediate: `set` the low byte, then `seth` the high
    byte only when it is non-zero (assumes `set` leaves the high byte zero
    -- confirm against the ISA). Bits above 16 are silently dropped."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.dest, self.imm16 = ops

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        reg = self.dest
        hi = (self.imm16 >> 8) & 0xff
        lo = (self.imm16 >> 0) & 0xff
        if hi != 0:
            return [f'set {reg}, {lo}', f'seth {reg}, {hi}']
        else:
            return [f'set {reg}, {lo}']


class Move(AsmOp):
    """Copy register `src` into `dst`."""
    scratch_need = 0

    def __init__(self, dst, src):
        self.dst = dst
        self.src = src

    def synth(self, scratches):
        return [f'or {self.dst}, {self.src}, {self.src}']


class Swap(AsmOp):
    """Exchange registers a0 and a1 through one scratch register."""
    scratch_need = 1

    def __init__(self, a0, a1):
        self.a0 = a0
        self.a1 = a1

    def synth(self, scratches):
        (sc0,) = scratches
        return [f'or {sc0}, {self.a0}, {self.a0}',
                f'or {self.a0}, {self.a1}, {self.a1}',
                f'or {self.a1}, {sc0}, {sc0}']


class IfOp(AsmOp):
    """Conditional on a register value: branch to the else label (or to
    endif when there is no else) when `cond` compares equal to 0.
    synth_else()/synth_endif() emit the jump and the labels."""
    scratch_need = 1

    def __init__(self, fun, op):
        self.fun = fun
        self.cond, mark, self.has_else = op
        self.then_mark = f'_then_{mark}'
        self.else_mark = f'_else_{mark}'
        self.endif_mark = f'_endif_{mark}'

    def synth(self, scratches):
        sc0 = scratches[0]
        if self.has_else:
            return [f'set {sc0}, 0',
                    f'cmp {sc0}, {self.cond}', # flag if cond == 0
                    f'beq {self.else_mark}']
        else:
            return [f'set {sc0}, 0',
                    f'cmp {sc0}, {self.cond}',
                    f'beq {self.endif_mark}']

    def synth_else(self):
        return [f'cmp r0, r0', # trick because beq is better than "or pc, ."
                f'beq {self.endif_mark}',
                f'{self.else_mark}:']

    def synth_endif(self):
        return [f'{self.endif_mark}:']


class IfEq(IfOp):
    """`if (a == b)` specialisation: compares the two registers directly and
    presumably branches on `bneq` (a != b). Does not call IfOp.__init__, so
    `fun`/`cond` are not set; only the labels are."""
    scratch_need = 0

    def __init__(self, a, b, mark, has_else):
        self.a = a
        self.b = b
        self.has_else = has_else
        self.then_mark = f'_then_{mark}'
        self.else_mark = f'_else_{mark}'
        self.endif_mark = f'_endif_{mark}'

    def synth(self, scratches):
        if self.has_else:
            return [f'cmp {self.a}, {self.b}', f'bneq {self.else_mark}']
        else:
            return [f'cmp {self.a}, {self.b}', f'bneq {self.endif_mark}']


class WhileOp(AsmOp):
    """Loop condition check: leave the loop when `cond` equals 0.
    synth_loop() emits the loop head label; synth_endwhile() jumps back and
    emits the exit label."""
    scratch_need = 1

    @staticmethod
    def synth_loop(mark):
        loop_mark = f'_loop_{mark}'
        return [f'{loop_mark}:']

    def __init__(self, cond, mark):
        self.cond = cond
        self.loop_mark = f'_loop_{mark}'
        self.endwhile_mark = f'_endwhile_{mark}'

    def synth(self, scratches):
        sc0 = scratches[0]
        return [f'set {sc0}, 0', f'cmp {sc0}, {self.cond}', f'beq {self.endwhile_mark}']

    def synth_endwhile(self):
        return [f'cmp r0, r0', f'beq {self.loop_mark}', f'{self.endwhile_mark}:']


class Delayed(AsmOp):
    """Lazy operand: `out` runs `out_cb` to emit the code that produces the
    value and returns its register.

    NOTE(review): `out_cb` runs on every `.out` access, so reading `.out`
    twice emits the load twice. synth() takes no `scratches` argument,
    unlike the other ops, so it cannot be passed to CcInterp._synth.
    """

    def __init__(self, out_cb):
        self.out_cb = out_cb

    @property
    def out(self):
        return self.out_cb()

    def synth(self):
        return []


def litt_type(val):
    """Type of a literal: str -> const char pointer, float -> const float,
    anything else -> const int."""
    if isinstance(val, str):
        return PointerType(pointed=PodType(pod='char', const=True))
    elif isinstance(val, float):
        return PodType(pod='float', const=True)
    else:
        return PodType(pod='int', const=True)


def return_type(type):
    """Return type of a FunType."""
    return type.ret


def deref_type(type):
    """Pointee type of a PointerType."""
    return type.pointed


class CcInterp(lark.visitors.Interpreter):
    """Top-down code generator over the CcTransform output.

    Expression nodes get `.op` (whose `.out` is the result register) and
    `.type`; lvalues also get `.var`. A parent may set `.request_out` on a
    child to ask for the result in a specific register. Generated
    functions are collected in `self.funs`, string literals in
    `self.strings`.
    """

    def __init__(self):
        self.global_scope = Scope()
        self.cur_scope = self.global_scope
        self.cur_fun = None
        self.funs = []
        self.next_reg = 0
        self.next_marker = 0
        self.strings = {}           # token -> string literal contents
        self.next_string_token = 0

    def _lookup_symbol(self, s):
        """Resolve `s` from the current scope outward; None if unknown."""
        scope = self.cur_scope
        while scope is not None:
            if s in scope.symbols:
                return scope.symbols[s]
            scope = scope.parent
        return None

    def _get_reg(self):
        return self.cur_fun.regs.take()

    def _get_marker(self):
        """Return a fresh integer used to make labels unique."""
        mark = self.next_marker
        self.next_marker += 1
        return mark

    @contextlib.contextmanager
    def _nosynth(self):
        """Temporarily turn _synth into a no-op (used to visit subtrees only
        for their types). NOTE(review): not restored if the block raises."""
        old = self._synth
        self._synth = lambda *_: None
        yield
        self._synth = old

    def _synth(self, op):
        """Queue `op` in the current function's output.

        Scratch registers are claimed only for the duration of this call and
        are free again for later ops; presumably safe because each op's
        scratches are dead once its instructions end. May evict a variable if
        no register is free.
        """
        with self.cur_fun.regs.borrow(op.scratch_need) as scratches:
            self._log(f'{op.__class__.__name__}')
            self.cur_fun.ops.append(lambda: op.synth(scratches))

    def _load(self, ident, reg=None):
        """Build (but do not queue) the op that makes `ident` available.

        Functions and any symbol in the global scope yield SetAddr, i.e. the
        register receives the symbol's address. Other symbols reuse their
        register when the RegBank already holds them (Reg), and otherwise
        get a Load.

        NOTE(review): for global scalars this yields the address rather
        than the value -- confirm intended. On the volatile path `reg` may be
        None, and the register is not claimed from the RegBank.
        """
        s = self._lookup_symbol(ident)
        assert s is not None, f'unknown identifier {ident}'
        if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
            reg = self.cur_fun.regs.take(reg=reg, orwhatever=True)
            return SetAddr(self.cur_fun, [reg, ident])
        else:
            if s.type.volatile:
                self._log(f'loading volatile {s}')
                return Load(self.cur_fun, [reg, s])
            reg, created = self.cur_fun.regs.load(s)
            if created:
                return Load(self.cur_fun, [reg, s])
            else:
                self._log(f'{s} was already in {reg}')
                return Reg(reg)

    def identifier(self, tree):
        """Attach `.var`/`.type` and a Delayed op that loads the identifier
        only when its value is first requested."""
        # TODO: not actually load the value until it's used in an expression
        # could have the op.out() function have a side effect that does that
        # if it's an assignment, we need to make a Variable with the proper
        # address and assign the register to it.
        # If it's volatile, we need to flush it.
        def delayed_load():
            self._log(f'delay-loading {tree.children[0]}')
            reg = getattr(tree, 'request_out', None)
            op = self._load(tree.children[0], reg=reg)
            self._synth(op)
            return op.out
        tree.op = Delayed(delayed_load)
        tree.var = self._lookup_symbol(tree.children[0])
        assert tree.var is not None, f'undefined identifier: {tree.children[0]}'
        tree.type = tree.var.type

    def _make_string(self, s):
        """Register string literal `s` as a global `_strN` symbol and return
        that token."""
        nexttok = self.next_string_token
        self.next_string_token += 1
        token = f'_str{nexttok}'
        self.strings[token] = s
        self.global_scope.symbols[token] = Variable(litt_type(s), token)
        return token

    def literal(self, tree):
        """String literals become global symbols (handled as identifiers);
        other literals get a Delayed Set16Imm into a register.

        NOTE(review): lark Tokens subclass str, so this relies on numeric
        literals having been converted to non-str values upstream (not
        visible here).
        """
        imm = tree.children[0]
        if isinstance(imm, str):
            tok = self._make_string(imm)
            tree.children = [tok]
            return self.identifier(tree)
        def delayed():
            if hasattr(tree, 'request_out'):
                reg = self.cur_fun.regs.take(reg=tree.request_out, orwhatever=True)
            else:
                reg = self.cur_fun.regs.take()
            assert reg is not None, f'weird: {self.cur_fun.regs}'
            op = Set16Imm(self.cur_fun, [reg, imm])
            self._synth(op)
            return op.out
        tree.op = Delayed(delayed)
        tree.type = litt_type(tree.children[0])

    # NOTE(review): _assign is defined again further down with the same
    # body; the later definition is the one the class ends up with.
    def _assign(self, left, right):
        """Bind register `right` to variable `left` and queue a Store;
        returns the Store op's `out`."""
        self._log(f'assigning {left} = {right}')
        self.cur_fun.regs.assign(left, right)
        op = Store(self.cur_fun, [right, left])
        self._synth(op)
        return op.out

    def assignment(self, tree):
        self.visit_children(tree)
        val = tree.children[1].op.out
        assert val is not None, f'no output!!! {tree.pretty()}'
        # left hand side is an lvalue, retrieve it
        var = tree.children[0].var
        # NOTE(review): _assign returns `op.out`, not an op, so tree.op is not
        # an op object here.
        tree.op = self._assign(var, val)

    def global_var(self, tree):
        """Register a global variable symbol.

        NOTE(review): the initializer value `val` is computed but never used,
        and parse_tree emits no storage for globals.
        """
        self.visit_children(tree)
        var = Variable.from_def(tree)
        self.global_scope.symbols[var.name] = var
        val = 0
        if len(tree.children) > 2:
            val = tree.children[2].children[0].value

    def fun_decl(self, tree):
        """Register a function prototype in the current scope."""
        fun = FunctionSpec(tree.children[0])
        self.cur_scope.symbols[fun.name] = fun

    def _reg_shuffle(self, src, dst):
        """Emit Moves/Swaps so that the value in src[i] ends up in dst[i].

        `src` is mutated in place to track where values currently are, so it
        must be a list. Destinations not present in `src` are written with a
        Move without being claimed from the RegBank.
        """
        def swap(old, new):
            if old == new:
                return
            oldpos = src.index(old)
            try:
                newpos = src.index(new)
            except ValueError:
                src[oldpos] = new
            else:
                src[newpos], src[oldpos] = src[oldpos], src[newpos]
            self._synth(Swap(old, new))
        for i, r in enumerate(dst):
            if r not in src:
                #self.cur_fun.regs.take(r) # precious!
                self._synth(Move(r, src[i]))
            else:
                swap(r, src[i])

    def _prep_fun_call(self, fn_reg, params):
        """Move all params to r0-rn."""
        target_regs = [f'r{i}' for i in range(len(params) + 1)]
        self._reg_shuffle(params + [fn_reg], target_regs)
        return target_regs[-1]

    @contextlib.contextmanager
    def _push(self, dontpush=()):
        """Prepare for a call: evict all variables (running their cleanups),
        then spill every other claimed register not in `dontpush` to a
        temporary stack slot. Yields the list of (reg, slot) spilled."""
        pushed = []
        self.cur_fun.regs.evict_all()
        with self.cur_fun.tempstack():
            for reg in allregs():
                if reg not in self.cur_fun.regs.available and reg not in dontpush:
                    self._log(f'reg in use: {reg}')
                    stk = self.cur_fun.get_stack()
                    pushed.append((reg, stk))
                    self._synth(Push(reg, stk))
                    self.cur_fun.regs.give(reg)
            yield pushed

    def fun_call(self, tree):
        """Compile `f(a0, ..., an-1)`.

        Arguments are requested in r0..rn-1 and the callee address in rn;
        live registers are spilled around the call and reloaded afterwards.
        The result (r0) is moved to `request_out` if the parent asked for
        one. After the call the RegBank is reset, so no variable is
        considered register-resident.
        """
        self.cur_fun.fun_calls += 1
        if len(tree.children) == 1:
            param_children = []
        elif tree.children[1].data == 'comma_expr':
            param_children = tree.children[1].children
        else:
            param_children = [tree.children[1]]
        if True:
            # pre-allocate output registers
            for i, child in enumerate(param_children):
                self._log(f'requesting r{i} for {child}')
                child.request_out = f'r{i}'
            self._log(f'requesting r{len(param_children)} for {tree.children[0]}')
            tree.children[0].request_out = f'r{len(param_children)}'
        self.visit_children(tree)
        fn_reg = tree.children[0].op.out
        param_regs = [c.op.out for c in param_children]
        self._log(f'calling {tree.children[0].children[0]}({param_regs})')
        with self._push(dontpush=param_regs + [fn_reg]) as pushed:
            self._flush_deferred() # side effects
            fn_reg = self._prep_fun_call(fn_reg, param_regs)
            self.cur_fun.regs.reset()
            self.cur_fun.regs.take(fn_reg)
            for reg in (f'r{i}' for i in range(len(param_regs))):
                self.cur_fun.regs.take(reg)
            tree.op = FnCall(self.cur_fun, [fn_reg, param_regs])
            self._synth(tree.op)
            self.cur_fun.regs.reset()
            pops = []
            for reg, stk in pushed:
                self.cur_fun.regs.take(reg)
                pops.append(Pop(reg, stk))
            ret = 'r0'
            if hasattr(tree, 'request_out'):
                ret = tree.request_out
                self._synth(Move(ret, 'r0'))
            # if `ret` is one of the reloaded registers, move the result
            # elsewhere before the Pops overwrite it
            nret = self.cur_fun.regs.take(ret, orwhatever=True)
            if nret != ret:
                self._synth(Move(nret, ret))
            # ad-hoc object exposing only `.out`
            tree.op = type('blah', (), {'out': nret})
            for op in pops:
                self._synth(op)
        tree.type = return_type(tree.children[0].type)

    def _flush_deferred(self):
        """Queue and clear the ops deferred to the end of the statement
        (e.g. post-increment updates)."""
        if self.cur_fun.deferred_ops:
            self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}')
        for op in self.cur_fun.deferred_ops:
            self._synth(op)
        self.cur_fun.deferred_ops = []

    def statement(self, tree):
        """Compile a statement, flush deferred ops, free temporaries."""
        self.visit_children(tree)
        self._flush_deferred()
        self.cur_fun.regs.drop_temps()

    iter_expression = statement

    def if_stat(self, tree):
        """Compile if/else.

        `a == b` conditions use IfEq directly. After the branches, only the
        variable->register bindings that hold in every branch (and at
        entry) are kept.

        NOTE(review): `begin_vars[1]` is a plain dict copy of the cleanup
        table, so `begin_vars[1][r]` raises KeyError for a register that had
        no cleanup entry at entry (e.g. a variable bound via
        RegBank.load). `.get(r, [])` looks safer -- confirm.
        """
        mark = self._get_marker()
        has_else = len(tree.children) > 2
        # optimization!!!!
        if tree.children[0].data == 'eq':
            self.visit_children(tree.children[0])
            a, b = (c.op.out for c in tree.children[0].children)
            op = IfEq(a, b, mark, has_else)
        else:
            self.visit(tree.children[0])
            op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
        self._synth(op)
        self.cur_fun.regs.drop_temps()
        begin_vars = self.cur_fun.regs.snap()
        valid_vars_set = set(begin_vars[0].items())
        self.visit(tree.children[1])
        valid_vars_set &= set(self.cur_fun.regs.snap()[0].items())
        if has_else:
            self.cur_fun.regs.restore(begin_vars)
            self.cur_fun.ops.append(op.synth_else)
            self.visit(tree.children[2])
            valid_vars_set &= set(self.cur_fun.regs.snap()[0].items())
        valid_vars = dict(valid_vars_set)
        valid_snap = (valid_vars, {r: begin_vars[1][r] for r in valid_vars.values()})
        self.cur_fun.regs.restore(valid_snap)
        self.cur_fun.ops.append(op.synth_endif)

    def _restore_vars(self, begin_vars):
        """At the end of a loop body, put variables back in the registers
        they had at the loop head (`begin_vars`: var -> reg), reloading
        those no longer in any register.

        NOTE(review): when `shuffles` is non-empty, `src` is a tuple (from
        zip), but _reg_shuffle swaps entries of `src` in place, which raises
        TypeError if a Swap is needed. Passing `list(src)` looks intended.
        """
        curvars = self.cur_fun.regs.vars
        # 0. do the loop shuffle
        shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars}
        dst, src = tuple(zip(*shuffles.items())) or ([], [])
        self._reg_shuffle(src, dst)
        # 1. load the rest
        for v, r in begin_vars.items():
            if v in curvars:
                continue
            self._log(f'loading missing var {v} into {r}')
            self._synth(Load(self.cur_fun, [r, v]))

    def while_stat(self, tree):
        """Compile `while (cond) body`.

        After the loop, the RegBank is restored to its state right after
        the condition was visited (note: the condition's `.out` is read only
        after that snapshot).
        """
        mark = self._get_marker()
        self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
        begin_vars = dict(self.cur_fun.regs.vars)
        self.visit(tree.children[0])
        postcond = self.cur_fun.regs.snap()
        op = WhileOp(tree.children[0].op.out, mark)
        self._synth(op)
        self.cur_fun.regs.drop_temps()
        self.visit(tree.children[1])
        self._restore_vars(begin_vars)
        self.cur_fun.ops.append(op.synth_endwhile)
        self.cur_fun.regs.restore(postcond)

    def for_stat(self, tree):
        """Compile `for (init; cond; step) body` like while_stat, with the
        init before the loop head and the step after the body."""
        mark = self._get_marker()
        self.visit(tree.children[0]) # initialization
        self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
        begin_vars = dict(self.cur_fun.regs.vars)
        self.visit(tree.children[1])
        postcond = self.cur_fun.regs.snap()
        op = WhileOp(tree.children[1].op.out, mark)
        self._synth(op)
        self.cur_fun.regs.drop_temps()
        self.visit(tree.children[3])
        self.visit(tree.children[2]) # 3rd statement
        self._restore_vars(begin_vars)
        self.cur_fun.ops.append(op.synth_endwhile)
        self.cur_fun.regs.restore(postcond)

    def _unary_op(op):
        """Class-body factory: handler for a unary operator compiled as
        `op(fun, [dest, operand])`. NOTE(review): does not set tree.type."""
        def _f(self, tree):
            self.visit_children(tree)
            operand = tree.children[0].op.out
            reg = self.cur_fun.regs.take()
            tree.op = op(self.cur_fun, [reg, operand])
            self._synth(tree.op)
            self.cur_fun.regs.give(operand)
        return _f

    def _deref(self, tree):
        """Turn the address computed by tree.children[0] into an lvalue
        `.var`, with a Delayed op that loads through it on demand."""
        reg = tree.children[0].op.out
        var = Variable.from_dereference(reg, tree.children[0])
        self._log(f'making var {var} from derefing reg {reg}')
        tree.var = var
        def delayed_load():
            self._log(f'delay-loading {tree.children[0]}')
            op = Load(self.cur_fun, [reg, var])
            self._synth(op)
            return op.out
        tree.op = Delayed(delayed_load)
        tree.type = var.type

    def dereference(self, tree):
        self.visit_children(tree)
        self._deref(tree)

    def pointer_access(self, tree):
        """Compile `p->field` as `*(p + offset)`.

        Children are first visited with synthesis disabled only to learn
        the struct type. NOTE(review): the nesting of the final
        visit_children was re-inferred; as written, the `add` node built for
        offs > 0 is visited twice (emitting the addition twice), and
        re-visiting resets the child's `.type` that was just set. Structs
        are looked up in the global scope only (see struct_def).
        """
        with self._nosynth():
            self.visit_children(tree)
        ch_strct, ch_field = tree.children
        type = ch_strct.type.pointed
        strct = self.global_scope.structs[type.struct]
        offs = strct.offset(ch_field)
        if offs > 0:
            load_offs = lark.Tree('literal', [offs])
            addop = lark.Tree('add', [tree.children[0], load_offs])
            self.visit(addop)
            addop.type = PointerType(pointed=strct.field_type(ch_field))
            tree.children = [addop]
        else:
            tree.children[0].type = PointerType(pointed=strct.field_type(ch_field))
            tree.children = [tree.children[0]]
        self.visit_children(tree)
        self._deref(tree)

    def post_increment(self, tree):
        """`x++`: the value is x's current register; the increment and the
        store back are deferred to the end of the statement."""
        self.visit_children(tree)
        tree.op = Reg(tree.children[0].op.out)
        var = tree.children[0].var
        reg = tree.op.out
        self.cur_fun.deferred_ops.append(Incr(self.cur_fun, [reg, reg]))
        self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
        tree.type = var.type

    def post_decrement(self, tree):
        """`x--`: like post_increment with Decr."""
        self.visit_children(tree)
        tree.op = Reg(tree.children[0].op.out)
        var = tree.children[0].var
        reg = tree.op.out
        self.cur_fun.deferred_ops.append(Decr(self.cur_fun, [reg, reg]))
        self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
        tree.type = var.type

    def pre_increment(self, tree):
        """`++x`: increment in place and register a store-back cleanup.

        NOTE(review): `var` is not defined in this method (nor at module
        level in this copy), so the cleanup lambda raises NameError when it
        runs; presumably `var = tree.children[0].var` is missing. tree.type is
        also never set.
        """
        self.visit_children(tree)
        reg = tree.children[0].op.out
        tree.op = Incr(self.cur_fun, [reg, reg])
        self._synth(tree.op)
        cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
        self.cur_fun.regs.cleanup[reg].append(cleanup)

    bool_not = _unary_op(BoolNot)

    def _binary_op(op):
        """Class-body factory: handler for a binary operator compiled as
        `op(fun, [dest, left, right])`; the result type is the left
        operand's type."""
        def _f(self, tree):
            self.visit_children(tree)
            left, right = (x.op.out for x in tree.children)
            dest = self.cur_fun.regs.take()
            tree.op = op(self.cur_fun, [dest, left, right])
            self._synth(tree.op)
            self.cur_fun.regs.give(left)
            self.cur_fun.regs.give(right)
            tree.type = tree.children[0].type # because uhm reasons
        return _f

    def _combo(uop, bop):
        """Class-body factory: compile a node as `uop(bop(children))`, e.g.
        `a == b` as `!(a != b)`."""
        def _f(self, tree):
            ftree = lark.Tree(bop, tree.children)
            tree.children = [ftree]
            uop.__get__(self)(tree)
        return _f

    shl = _binary_op(ShlOp)
    add = _binary_op(AddOp)
    sub = _binary_op(SubOp)
    mul = _binary_op(MulOp)
    _and = _binary_op(AndOp)
    _or = _binary_op(OrOp)
    # ...
    gt = _binary_op(GtOp)
    lt = _binary_op(LtOp)
    neq = _binary_op(NeqOp)
    eq = _combo(bool_not, 'neq')

    def shr(self, tree):
        """Right shift; the shift amount must be a literal."""
        self.visit(tree.children[0])
        left = tree.children[0].op.out
        assert tree.children[1].data == 'literal'
        right = tree.children[1].children[0]
        dest = self.cur_fun.regs.take()
        tree.op = ShrOp(self.cur_fun, [dest, left, right])
        self._synth(tree.op)
        self.cur_fun.regs.give(left)
        tree.type = tree.children[0].type # because uhm reasons

    def _assign(self, left, right):
        """Bind register `right` to variable `left` and queue a Store;
        returns the Store op's `out`. (Duplicate of the definition above;
        this one wins.)"""
        self._log(f'assigning {left} = {right}')
        self.cur_fun.regs.assign(left, right)
        op = Store(self.cur_fun, [right, left])
        self._synth(op)
        return op.out

    def _assign_op(op):
        """Class-body factory: compound assignment `x op= y` compiled as
        `x = x op y` using the rule named `op`."""
        def f(self, tree):
            opnode = lark.Tree(op, tree.children)
            self.visit(opnode)
            val = opnode.op.out
            # left hand side is an lvalue, retrieve it
            var = tree.children[0].var
            tree.op = self._assign(var, val)
        return f

    or_ass = _assign_op('_or')
    add_ass = _assign_op('add')

    # NOTE(review): not bound to any rule name in this copy.
    def _forward_op(self, tree):
        self.visit_children(tree)
        tree.op = tree.children[0].op

    def cast(self, tree):
        """`(type) expr`: reuse the expression's op (forwarding any
        requested output register) and retype it; no code is emitted for
        the conversion itself."""
        if hasattr(tree, 'request_out'):
            tree.children[1].request_out = tree.request_out
        self.visit_children(tree)
        tree.op = tree.children[1].op
        tree.type = tree.children[0]

    def _log(self, line):
        self.cur_fun.log(line)

    def local_var(self, tree):
        """Declare a local: reserve a stack slot, register the symbol, and if
        there is an initializer bind its register to the variable with a
        store-back cleanup."""
        self.visit_children(tree)
        assert self.cur_fun is not None
        assert self.cur_scope is not None
        var = Variable.from_def(tree)
        var.stackaddr = self.cur_fun.get_stack() # will have to invert
        self.cur_scope.symbols[var.name] = var
        self.cur_fun.locals[var.name] = var
        if len(tree.children) > 2:
            initval = tree.children[2].children[0].op.out
            cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
            self.cur_fun.regs.assign(var, initval, cleanup=cleanup)
            self._log(f'assigning {var} = {initval}')

    def fun_def(self, tree):
        """Compile a function definition.

        Parameter i arrives in r{i} and gets a stack slot plus a cleanup
        that stores it there on eviction. If the body has no return
        statement: `main` returns 0, void functions return, and anything
        else fails the assertion.
        """
        prot, body = tree.children
        fun = Function(prot)
        assert self.cur_fun is None
        self.cur_fun = fun
        self.cur_scope.symbols[fun.spec.name] = fun.spec
        params = fun.param_dict()
        def getcleanup(var):
            def cl(reg):
                self._log(f'cleanup for {var} in {reg}')
                self._synth(Store(self.cur_fun, [reg, var]))
            return cl
        for i, param in enumerate(fun.params):
            fun.locals[param.name] = param
            param.stackaddr = fun.get_stack()
            self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}')
            cleanup = getcleanup(param)
            fun.regs.loaded(param, f'r{i}', cleanup=cleanup)
        fun_scope = Scope()
        fun_scope.parent = self.cur_scope
        fun_scope.symbols.update(params)
        body.scope = fun_scope
        self.visit_children(tree)
        if fun.ret is None:
            if fun.spec.name == 'main':
                self._synth(Set16Imm(fun, ['r0', 0]))
                self._synth(ReturnReg(fun, ['r0']))
            elif fun.spec.return_type == VoidType():
                self._synth(ReturnReg(fun, ['r0']))
            else:
                # always fails here: non-void, non-main functions must
                # contain a return statement
                assert fun.ret is not None
        self.cur_fun = None
        self.funs.append(fun)

    def body(self, tree):
        """Compound statement: push a scope (the function's own scope for a
        function body), visit, pop."""
        bscope = getattr(tree, 'scope', Scope())
        bscope.parent = self.cur_scope
        self.cur_scope = bscope
        self.visit_children(tree)
        self.cur_scope = bscope.parent

    def return_stat(self, tree):
        """`return expr;` (assumes an expression child is present)."""
        assert self.cur_fun is not None
        self.cur_fun.ret = True
        self.visit_children(tree)
        expr_reg = tree.children[0].op.out
        tree.op = ReturnReg(self.cur_fun, [expr_reg])
        self._synth(tree.op)

    def struct_def(self, tree):
        """Register a struct in the current scope.

        NOTE(review): pointer_access only looks in global_scope.structs, so a
        struct defined inside a function body is not found there.
        """
        self.visit_children(tree)
        strct = Struct(tree)
        self.cur_scope.structs[strct.name] = strct


# Program entry stub: clear r0/r1, set up sp, call `main` with lr pointing
# just past the call, then spin forever.
# NOTE(review): not referenced by parse_tree in this copy.
preamble = [f'_start:',
            f'xor r0, r0, r0',
            f'xor r1, r1, r1',
            f'set sp, 0',
            f'seth sp, {0x11}', # 256 bytes of stack ought to be enough
            f'set r2, main',
            f'set r3, 0',
            f'add lr, pc, r3',
            f'or pc, r2, r2',
            f'cmp r0, r0',
            f'beq [pc, -4] // loop forever',
            ]


def filter_dupes(ops):
    """Yield `ops` minus self-moves of the form `or rN, rN, rN`.

    NOTE(review): the pattern is searched, not anchored, so it also drops
    `xor rN, rN, rN` (register zeroing) and e.g. `or r1, r1, r10`;
    anchoring it and adding a word boundary looks intended.
    """
    dupe_re = re.compile(r'or (r\d+), \1, \1')
    for op in ops:
        if dupe_re.search(op):
            continue
        yield op


def parse_tree(tree, debug=False):
    """Compile a parsed lark tree into assembly text: every function
    followed by every string literal as big-endian 16-bit `.word`s
    (NUL-terminated, padded to an even length).

    NOTE(review): padding is computed on the character count before UTF-8
    encoding, so a non-ASCII string can end up with an odd byte count and
    struct.unpack raises struct.error.
    """
    tr = CcTransform()
    tree = tr.transform(tree)
    if debug:
        print(tree.pretty())
    inte = CcInterp()
    inte.visit(tree)
    out = []
    for fun in inte.funs:
        out += fun.synth()
        out.append('')
    for sym, dat in inte.strings.items():
        out += [f'.global {sym}', f'{sym}:']
        dat += '\0'
        if len(dat) % 2 != 0:
            dat += '\0'
        dat = dat.encode()
        nwords = len(dat) // 2
        out += [f'.word 0x{d:04x}' for d in struct.unpack(f'>{nwords}H', dat)]
        out.append('')
    return '\n'.join(filter_dupes(out))


def larkparse(f, debug=False):
    """Parse the source read from file object `f` (text or bytes) with the
    grammar at GRAMMAR_FILE and return the generated assembly text."""
    with open(GRAMMAR_FILE) as g:
        asparser = lark.Lark(g.read())
    data = f.read()
    if isinstance(data, bytes):
        data = data.decode()
    tree = asparser.parse(data)
    return parse_tree(tree, debug=debug)


def parse_args():
    """Parse the command line.

    NOTE(review): `input` is a required positional (no nargs='?'), so its
    stdin default never applies; `--compile` is parsed but unused by main().
    """
    parser = argparse.ArgumentParser(description='Compile.')
    parser.add_argument('--debug', action='store_true', help='print the AST')
    parser.add_argument('--assembly', '-S', action='store_true', help='output assembly')
    parser.add_argument('--compile', '-c', action='store_true', help='compile a single file')
    parser.add_argument('input', type=argparse.FileType('r'), default=sys.stdin, help='input file (default: stdin)')
    parser.add_argument('--output', '-o', type=argparse.FileType('wb'), default=sys.stdout.buffer, help='output file')
    parser.add_argument('-I', dest='include_dirs', action='append', default=[], help='include dirs')
    return parser.parse_args()


def preprocess(fin, include_dirs):
    """Run CPP (plus -I flags) on file object `fin` and return its stdout as
    a StringIO. Raises RuntimeError on a non-zero exit; cpp's stderr is not
    captured."""
    cmd = list(CPP) + [f'-I{x}' for x in include_dirs]
    p = subprocess.Popen(cmd, stdin=fin, stdout=subprocess.PIPE)
    out, _ = p.communicate()
    if p.returncode != 0:
        raise RuntimeError(f'preprocessor error')
    return io.StringIO(out.decode())


def assemble(text, fout):
    """Assemble assembly `text` with the `as` module and write the object to
    binary file `fout`."""
    fin = io.StringIO(text)
    asmod.write_obj(fout, *asmod.larkparse(fin))


def main():
    """CLI entry point: preprocess, compile, then write assembly (-S) or an
    assembled object to --output."""
    args = parse_args()
    assy = larkparse(preprocess(args.input, args.include_dirs), debug=args.debug)
    if args.assembly:
        args.output.write(assy.encode() + b'\n')
    else:
        assemble(assy, args.output)


if __name__ == "__main__":
    main()