import argparse import ast import collections import contextlib from dataclasses import dataclass import importlib import inspect import io import os import re import struct import subprocess import sys import lark asmod = importlib.import_module("as") _HERE = os.path.dirname(__file__) GRAMMAR_FILE = os.path.join(_HERE, 'cc.ebnf') CPP = ('cpp', '-P') class Scope: def __init__(self): self.symbols = {} self.parent = None self.structs = {} class Variable: def __init__(self, type, name, addr_reg=None): self.type = type self.name = name self.addr_reg = addr_reg def __repr__(self): return f'' def __str__(self): return self.name @classmethod def from_def(cls, tree): type = tree.children[0] name = tree.children[1] return cls(type, name) @classmethod def from_dereference(cls, reg, tree): return cls(tree.type.pointed, 'deref', addr_reg=reg) # - all registers pointing to unwritten stuff can be dropped as soon as we're # done with them: # - already fed them into the operations # - end of statement # - maybe should have a separate storeage for special registers? # need ways to: # - assign a list of registers into r0, ... rn for function call # - ... 
and run all callbacks # - get the register for an identifier # - dereference an expression: # - essentially turn a temp into an lvalue-address # - if read, need a second reg for its value # - mark a variable to have its address in a register # - delay reading the identifier if it's lhs # - store the value to memory when registers are claimed # - if it was modified through: # - assignment # - pre-post *crement # - retrieve the memory type: # - stack # - absolute # - register # - or just store the value when the variable is assigned # - get a temporary register # - return it # - I guess dereferencing is an upgrade def type(self, tree): print(tree) def allregs(): return (f'r{i}' for i in range(12)) class RegBank: def __init__(self, logger=None): self.reset() self.log = logger or print def reset(self): self.available = list(allregs()) self.vars = {} self.varregs = {} self.cleanup = collections.defaultdict(list) self.taken = {} def __str__(self): return f'regs(avail={self.available}, vars={self.vars})' def take(self, reg=None, orwhatever=False): out = self._take(reg=reg, orwhatever=orwhatever) self.taken[out] = inspect.stack()[1] return out def _take(self, reg=None, orwhatever=False): if reg is not None: if reg not in self.available: if reg in self.varregs: self.evict(self.var(reg)) elif orwhatever: return self._take() else: raise RuntimeError( f'trying to overtake {reg}, ' f'varregs: {self.varregs}, vars: {self.vars}, ' f'taken in {self.taken[reg]}') if reg in self.available: self.available.remove(reg) return reg if not self.available: assert self.vars, "nothing to clean, no more regs :/" # storing one random var var = list(self.vars.keys())[0] self.evict(var) return self.available.pop(0) def give(self, reg): assert reg is not None if reg in self.varregs: # Need to call evict() with the var to free it. 
return self.available.insert(0, reg) def loaded(self, var, reg, cleanup=None): """Tells the regbank some variable was loaded into the given register.""" assert reg is not None assert var is not None self.vars[var] = reg self.varregs[reg] = var self.take(reg) if cleanup is not None: self.log(f'recording cleanup for {reg}({var})') self.cleanup[reg].append(cleanup) def stored(self, var): """Tells the regbank the given var was stored to memory, register can be freed.""" assert var in self.vars reg = self.vars.pop(var) del self.varregs[reg] self.give(reg) def load(self, var): """Returns the reg associated with the var, or a new reg if none was, and True if the var was created, False otherwise.""" self.log(f'vars: {self.vars}, varregs: {self.varregs}') if var not in self.vars: reg = self.take() self.vars[var] = reg self.varregs[reg] = var return reg, True return self.vars[var], False def assign(self, var, reg, cleanup=None): """Assign a previously-used register to a variable.""" assert reg is not None if var in self.vars: self.stored(var) if reg in self.varregs: for cb in self.cleanup.pop(reg, []): cb(reg) self.stored(self.varregs[reg]) self.vars[var] = reg self.varregs[reg] = var if cleanup is not None: self.cleanup[reg].append(cleanup) def var(self, reg): return self.varregs.get(reg, None) def evict_all(self): for var in list(self.vars): self.evict(var) def evict(self, var): """Runs var callbacks & frees the register.""" assert var in self.vars, f'trying to evict {var}' self.log(f'evicting {var}') if var not in self.vars: return reg = self.vars[var] for cb in self.cleanup.pop(reg, []): cb(reg) self.stored(var) def drop_temps(self): for reg in allregs(): if reg not in self.available and reg not in self.varregs: self.give(reg) def swapped(self, reg0, reg1): var0 = self.varregs.get(reg0, None) var1 = self.varregs.get(reg1, None) if var0 is not None: self.stored(var0) elif reg0 not in self.available: self.give(reg0) if var1 is not None: self.stored(var1) elif reg1 not in 
self.available: self.give(reg1) if var0 is not None: self.loaded(var0, reg1) if var1 is not None: self.loaded(var1, reg0) def snap(self): cleanup = collections.defaultdict(list) cleanup.update({r: list(c) for r, c in self.cleanup.items()}) return dict(self.vars), cleanup def restore(self, snap): self.reset() vardict, cleanup = snap for var, reg in vardict.items(): assert reg is not None, f'bad snapshot: {snap}' self.loaded(var, reg) self.cleanup[reg] = cleanup.get(reg, []) @contextlib.contextmanager def borrow(self, howmany): regs = [self.take() for i in range(howmany)] yield regs for reg in regs: self.give(reg) class FunctionSpec: def __init__(self, fun_prot): self.return_type = fun_prot.children[0] self.name = fun_prot.children[1] self.param_types = [x.children[0] for x in fun_prot.children[2:]] def __repr__(self): params = ', '.join(self.param_types) return f'' @property def type(self): return FunType(ret=self.return_type, params=self.param_types) class Function: def __init__(self, fun_prot): self.locals = {} self.spec = FunctionSpec(fun_prot) self.params = [Variable(*x.children) for x in fun_prot.children[2:]] self.ret = None self.nextstack = 0 self.topstack = 0 self.statements = [] self.regs = RegBank(logger=self.log) self.deferred_ops = [] self.fun_calls = 0 self.ops = [] def log(self, line): self.ops.append(lambda: [f'// {line}']) @property def stack_usage(self): top_usage = max(self.topstack, self.nextstack) if self.fun_calls > 0: return top_usage + 2 return top_usage def get_stack(self, size=2): stk = self.nextstack self.nextstack += size return stk @contextlib.contextmanager def tempstack(self): stk = self.nextstack yield if self.nextstack > self.topstack: self.topstack = self.nextstack self.nextstack = stk def param_dict(self): return {p.name: p for p in self.params} def __repr__(self): return repr(self.spec) def synth(self): ops = [] if self.fun_calls > 0: ops += [f'store lr, [sp, -2]'] if self.stack_usage > 0: ops += [f'set r4, {self.stack_usage}', 
f'sub sp, sp, r4'] for op in self.ops: ops += op() indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in ops] return [f'.global {self.spec.name}', f'{self.spec.name}:'] + indented class Struct: def __init__(self, struct_def): self.name = struct_def.children[0] self.fields = [Variable(*x.children) for x in struct_def.children[1].children] def offset(self, field): return 2 * [x.name for x in self.fields].index(field) def field_type(self, field): for f in self.fields: if f.name == field: return f.type return None @dataclass class CType: volatile: bool = False const: bool = False size: int = 2 @dataclass class StructType(CType): struct: str = '' @dataclass class PointerType(CType): pointed: CType = None @dataclass class FunType(CType): ret: CType = None params: [CType] = None @dataclass class PodType(CType): pod: str = 'int' signed = True @dataclass class VoidType(CType): pass class CcTransform(lark.visitors.Transformer): def _binary_op(litt): @lark.v_args(tree=True) def _f(self, tree): left, right = tree.children if left.data == 'literal' and right.data == 'literal': tree.data = 'literal' tree.children = [litt(left.children[0], right.children[0])] return tree return _f def field(self, children): (field,) = children return field def struct_type(self, children): (name,) = children return StructType(struct=name) def pointer(self, children): volat = 'volatile' in children pointed = children[-1] return PointerType(volatile=volat, pointed=pointed) def funptr_type(self, children): ret, *params = children return FunType(ret=ret, params=params) def comma_expr(self, children): c1, c2 = children if c1.data == 'comma_expr': c1.children.append(c2) return c1 return lark.Tree('comma_expr', [c1, c2]) volatile = lambda *_: 'volatile' const = lambda *_: 'const' def type(self, children): volat = 'volatile' in children const = 'const' in children typ = children[-1] if isinstance(typ, str): if typ == 'void': return VoidType() size = 1 if typ == 'char' else 2 return 
PodType(volatile=volat, const=const, pod=typ, size=size) else: return typ # operations on literals can be done by the compiler add = _binary_op(lambda a, b: a+b) sub = _binary_op(lambda a, b: a-b) mul = _binary_op(lambda a, b: a*b) shl = _binary_op(lambda a, b: a< 0: ops += [f'set {sc0}, {stack_usage}', f'add sp, sp, {sc0}'] if self.fun.ret is not None: ret = self.ret_reg ops += [f'or r0, {ret}, {ret}'] if self.fun.fun_calls == 0: ops += [f'or pc, lr, lr'] else: ops += [f'load pc, [sp, -2] // return'] return ops class Load(AsmOp): scratch_need = 0 def __init__(self, fun, ops): self.fun = fun self.dest, self.var = ops fun.log(f'8bit load? {self.var.type} {self.is8bit}') @property def is8bit(self): if not isinstance(self.var.type, PodType): return False if self.var.type.pod in ['uint8_t', 'int8_t', 'char']: return True return False def _maybecast8bit(self, reg): if self.is8bit: return [f'shr {reg}, {reg}, 8'] return [] @property def out(self): return self.dest def synth(self, scratches): reg = self.dest if self.var.name in self.fun.locals: src = self.var.stackaddr return [f'load {reg}, [sp, {src}]'] # no 8bit cast to/from stack elif self.var.addr_reg is not None: return [f'load {reg}, [{self.var.addr_reg}]'] + self._maybecast8bit(reg) else: return [f'set {reg}, {self.var.name}', f'nop // in case we load a far global', f'load {reg}, [{reg}]'] # unsure if we need a cast here class Push(AsmOp): def __init__(self, reg, stackaddr): self.reg = reg self.stackaddr = stackaddr def synth(self, scratches): return [f'store {self.reg}, [sp, {self.stackaddr}]'] class Pop(AsmOp): def __init__(self, reg, stackaddr): self.reg = reg self.stackaddr = stackaddr @property def out(self): return self.reg def synth(self, scratches): return [f'load {self.reg}, [sp, {self.stackaddr}]'] class Store(AsmOp): scratch_need = 1 def __init__(self, fun, ops): self.fun = fun self.src, self.var = ops def synth(self, scratches): (sc,) = scratches reg = self.src if self.var.name in self.fun.locals: dst = 
self.var.stackaddr self.fun.log(f'storing {self.var}({reg}) to [sp, {dst}]') return [f'store {reg}, [sp, {dst}]'] elif self.var.addr_reg is not None: return [f'store {reg}, [{self.var.addr_reg}]'] return [f'set {sc}, {self.var.name}', f'nop // you know, in case blah', f'store {reg}, [{sc}]'] class Assign(AsmOp): scratch_need = 0 def __init__(self, fun, ops): self.fun = fun self.src, self.var = ops @property def out(self): return self.var def synth(self, scratches): return [f'or {self.var}, {self.src}, {self.src}'] class SetAddr(AsmOp): scratch_need = 0 def __init__(self, fun, ops): self.dest, self.ident = ops @property def out(self): return self.dest def synth(self, scratches): reg = self.dest return [f'set {reg}, {self.ident}', f'nop // placeholder for a long address'] class Set16Imm(AsmOp): scratch_need = 0 def __init__(self, fun, ops): self.dest, self.imm16 = ops @property def out(self): return self.dest def synth(self, scratches): reg = self.dest hi = (self.imm16 >> 8) & 0xff lo = (self.imm16 >> 0) & 0xff if hi != 0: return [f'set {reg}, {lo}', f'seth {reg}, {hi}'] else: return [f'set {reg}, {lo}'] class Move(AsmOp): scratch_need = 0 def __init__(self, dst, src): self.dst = dst self.src = src def synth(self, scratches): return [f'or {self.dst}, {self.src}, {self.src}'] class Swap(AsmOp): scratch_need = 1 def __init__(self, a0, a1): self.a0 = a0 self.a1 = a1 def synth(self, scratches): (sc0,) = scratches return [f'or {sc0}, {self.a0}, {self.a0}', f'or {self.a0}, {self.a1}, {self.a1}', f'or {self.a1}, {sc0}, {sc0}'] class IfOp(AsmOp): scratch_need = 1 def __init__(self, fun, op): self.fun = fun self.cond, mark, self.has_else = op self.then_mark = f'_then_{mark}' self.else_mark = f'_else_{mark}' self.endif_mark = f'_endif_{mark}' def synth(self, scratches): sc0 = scratches[0] if self.has_else: return [f'set {sc0}, 0', f'cmp {sc0}, {self.cond}', # flag if cond == 0 f'beq {self.else_mark}'] else: return [f'set {sc0}, 0', f'cmp {sc0}, {self.cond}', f'beq 
{self.endif_mark}'] def synth_else(self): return [f'cmp r0, r0', # trick because beq is better than "or pc, ." f'beq {self.endif_mark}', f'{self.else_mark}:'] def synth_endif(self): return [f'{self.endif_mark}:'] class IfEq(IfOp): scratch_need = 0 def __init__(self, a, b, mark, has_else): self.a = a self.b = b self.has_else = has_else self.then_mark = f'_then_{mark}' self.else_mark = f'_else_{mark}' self.endif_mark = f'_endif_{mark}' def synth(self, scratches): if self.has_else: return [f'cmp {self.a}, {self.b}', f'bneq {self.else_mark}'] else: return [f'cmp {self.a}, {self.b}', f'bneq {self.endif_mark}'] class WhileOp(AsmOp): scratch_need = 1 @staticmethod def synth_loop(mark): loop_mark = f'_loop_{mark}' return [f'{loop_mark}:'] def __init__(self, cond, mark): self.cond = cond self.loop_mark = f'_loop_{mark}' self.endwhile_mark = f'_endwhile_{mark}' def synth(self, scratches): sc0 = scratches[0] return [f'set {sc0}, 0', f'cmp {sc0}, {self.cond}', f'beq {self.endwhile_mark}'] def synth_endwhile(self): return [f'cmp r0, r0', f'beq {self.loop_mark}', f'{self.endwhile_mark}:'] class Delayed(AsmOp): def __init__(self, out_cb): self.out_cb = out_cb @property def out(self): return self.out_cb() def synth(self): return [] def litt_type(val): if isinstance(val, str): return PointerType(pointed=PodType(pod='char', size=1, const=True)) elif isinstance(val, float): return PodType(pod='float', const=True) else: return PodType(pod='int', const=True) def return_type(type): return type.ret def deref_type(type): return type.pointed class CcInterp(lark.visitors.Interpreter): def __init__(self): self.global_scope = Scope() self.cur_scope = self.global_scope self.cur_fun = None self.funs = [] self.next_reg = 0 self.next_marker = 0 self.strings = {} self.next_string_token = 0 self.rodata = {} def _lookup_symbol(self, s): scope = self.cur_scope while scope is not None: if s in scope.symbols: return scope.symbols[s] scope = scope.parent return None def _get_reg(self): return 
self.cur_fun.regs.take() def _get_marker(self): mark = self.next_marker self.next_marker += 1 return mark @contextlib.contextmanager def _nosynth(self): old = self._synth self._synth = lambda *_: None yield self._synth = old def _synth(self, op): with self.cur_fun.regs.borrow(op.scratch_need) as scratches: self._log(f'{op.__class__.__name__}') self.cur_fun.ops.append(lambda: op.synth(scratches)) def _load(self, ident, reg=None): s = self._lookup_symbol(ident) assert s is not None, f'unknown identifier {ident}' if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols: reg = self.cur_fun.regs.take(reg=reg, orwhatever=True) return SetAddr(self.cur_fun, [reg, s.name]) else: if s.type.volatile: self._log(f'loading volatile {s}') return Load(self.cur_fun, [reg, s]) reg, created = self.cur_fun.regs.load(s) if created: return Load(self.cur_fun, [reg, s]) else: self._log(f'{s} was already in {reg}') return Reg(reg) def identifier(self, tree): # TODO: not actually load the value until it's used in an expression # could have the op.out() function have a side effect that does that # if it's an assignment, we need to make a Variable with the proper # address and assign the register to it. # If it's volatile, we need to flush it. 
def delayed_load(): self._log(f'delay-loading {tree.children[0]}') reg = getattr(tree, 'request_out', None) op = self._load(tree.children[0], reg=reg) self._synth(op) return op.out tree.op = Delayed(delayed_load) tree.var = self._lookup_symbol(tree.children[0]) assert tree.var is not None, f'undefined identifier: {tree.children[0]}' tree.type = tree.var.type def _make_string(self, s): nexttok = self.next_string_token self.next_string_token += 1 token = f'_str{nexttok}' self.strings[token] = s self.global_scope.symbols[token] = Variable(litt_type(s), token) return token def literal(self, tree): imm = tree.children[0] if isinstance(imm, str): tok = self._make_string(imm) tree.children = [tok] return self.identifier(tree) def delayed(): if hasattr(tree, 'request_out'): reg = self.cur_fun.regs.take(reg=tree.request_out, orwhatever=True) else: reg = self.cur_fun.regs.take() assert reg is not None, f'weird: {self.cur_fun.regs}' op = Set16Imm(self.cur_fun, [reg, imm]) self._synth(op) return op.out tree.op = Delayed(delayed) tree.type = litt_type(tree.children[0]) def array_item(self, tree): # transform blarg[foo] into *(blarg + foo) because reasons # TODO: bring this over to the main parser, we need to know the type # of blarg in order to compute the actual offset s = self._lookup_symbol(tree.children[0].children[0]) if s.type.pointed.size == 1: # easy deref + add tree.children = [lark.Tree('add', tree.children)] return self.dereference(tree) # everything else is 16 bit, do something like: # array[index] === *(array + (index << 1)) add1 = lark.Tree('shl', [tree.children[1], lark.Tree('literal', [1])]) add2 = lark.Tree('add', [tree.children[0], add1]) tree.children = [add2] return self.dereference(tree) def _assign(self, left, right): self._log(f'assigning {left} = {right}') self.cur_fun.regs.assign(left, right) op = Store(self.cur_fun, [right, left]) self._synth(op) return op.out def assignment(self, tree): self.visit_children(tree) val = tree.children[1].op.out assert 
val is not None, f'no output!!! {tree.pretty()}' # left hand side is an lvalue, retrieve it var = tree.children[0].var tree.op = self._assign(var, val) def global_var(self, tree): self.visit_children(tree) var = Variable.from_def(tree) self.global_scope.symbols[var.name] = var val = 0 if len(tree.children) > 2: val = tree.children[2].children[0].value def fun_decl(self, tree): fun = FunctionSpec(tree.children[0]) self.cur_scope.symbols[fun.name] = fun def _reg_shuffle(self, src, dst): def swap(old, new): if old == new: return oldpos = src.index(old) try: newpos = src.index(new) except ValueError: src[oldpos] = new else: src[newpos], src[oldpos] = src[oldpos], src[newpos] self._synth(Swap(old, new)) for i, r in enumerate(dst): if r not in src: #self.cur_fun.regs.take(r) # precious! self._synth(Move(r, src[i])) else: swap(r, src[i]) def _prep_fun_call(self, fn_reg, params): """Move all params to r0-rn.""" target_regs = [f'r{i}' for i in range(len(params) + 1)] self._reg_shuffle(params + [fn_reg], target_regs) return target_regs[-1] @contextlib.contextmanager def _push(self, dontpush=()): pushed = [] self.cur_fun.regs.evict_all() with self.cur_fun.tempstack(): for reg in allregs(): if reg not in self.cur_fun.regs.available and reg not in dontpush: self._log(f'reg in use: {reg}') stk = self.cur_fun.get_stack() pushed.append((reg, stk)) self._synth(Push(reg, stk)) self.cur_fun.regs.give(reg) yield pushed def fun_call(self, tree): self.cur_fun.fun_calls += 1 if len(tree.children) == 1: param_children = [] elif tree.children[1].data == 'comma_expr': param_children = tree.children[1].children else: param_children = [tree.children[1]] if True: # pre-allocate output registers for i, child in enumerate(param_children): self._log(f'requesting r{i} for {child}') child.request_out = f'r{i}' self._log(f'requesting r{len(param_children)} for {tree.children[0]}') tree.children[0].request_out = f'r{len(param_children)}' self.visit_children(tree) fn_reg = tree.children[0].op.out 
param_regs = [c.op.out for c in param_children] self._log(f'calling {tree.children[0].children[0]}({param_regs})') with self._push(dontpush=param_regs + [fn_reg]) as pushed: self._flush_deferred() # side effects fn_reg = self._prep_fun_call(fn_reg, param_regs) self.cur_fun.regs.reset() self.cur_fun.regs.take(fn_reg) for reg in (f'r{i}' for i in range(len(param_regs))): self.cur_fun.regs.take(reg) tree.op = FnCall(self.cur_fun, [fn_reg, param_regs]) self._synth(tree.op) self.cur_fun.regs.reset() pops = [] for reg, stk in pushed: self.cur_fun.regs.take(reg) pops.append(Pop(reg, stk)) ret = 'r0' if hasattr(tree, 'request_out'): ret = tree.request_out self._synth(Move(ret, 'r0')) nret = self.cur_fun.regs.take(ret, orwhatever=True) if nret != ret: self._synth(Move(nret, ret)) tree.op = type('blah', (), {'out': nret}) for op in pops: self._synth(op) tree.type = return_type(tree.children[0].type) def _flush_deferred(self): if self.cur_fun.deferred_ops: self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}') for op in self.cur_fun.deferred_ops: self._synth(op) self.cur_fun.deferred_ops = [] def statement(self, tree): self.visit_children(tree) self._flush_deferred() self.cur_fun.regs.drop_temps() iter_expression = statement def if_stat(self, tree): mark = self._get_marker() has_else = len(tree.children) > 2 # optimization!!!! 
if tree.children[0].data == 'eq': self.visit_children(tree.children[0]) a, b = (c.op.out for c in tree.children[0].children) op = IfEq(a, b, mark, has_else) else: self.visit(tree.children[0]) op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else]) self._synth(op) self.cur_fun.regs.drop_temps() begin_vars = self.cur_fun.regs.snap() valid_vars_set = set(begin_vars[0].items()) self.visit(tree.children[1]) valid_vars_set &= set(self.cur_fun.regs.snap()[0].items()) if has_else: self.cur_fun.regs.restore(begin_vars) self.cur_fun.ops.append(op.synth_else) self.visit(tree.children[2]) valid_vars_set &= set(self.cur_fun.regs.snap()[0].items()) valid_vars = dict(valid_vars_set) valid_snap = (valid_vars, {r: begin_vars[1][r] for r in valid_vars.values()}) self.cur_fun.regs.restore(valid_snap) self.cur_fun.ops.append(op.synth_endif) def _restore_vars(self, begin_vars): curvars = self.cur_fun.regs.vars # 0. do the loop shuffle shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars} dst, src = tuple(zip(*shuffles.items())) or ([], []) self._reg_shuffle(src, dst) # 1. 
load the rest for v, r in begin_vars.items(): if v in curvars: continue self._log(f'loading missing var {v} into {r}') self._synth(Load(self.cur_fun, [r, v])) def while_stat(self, tree): mark = self._get_marker() self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark)) begin_vars = dict(self.cur_fun.regs.vars) self.visit(tree.children[0]) postcond = self.cur_fun.regs.snap() op = WhileOp(tree.children[0].op.out, mark) self._synth(op) self.cur_fun.regs.drop_temps() self.visit(tree.children[1]) self._restore_vars(begin_vars) self.cur_fun.ops.append(op.synth_endwhile) self.cur_fun.regs.restore(postcond) def for_stat(self, tree): mark = self._get_marker() self.visit(tree.children[0]) # initialization self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark)) begin_vars = dict(self.cur_fun.regs.vars) self.visit(tree.children[1]) postcond = self.cur_fun.regs.snap() op = WhileOp(tree.children[1].op.out, mark) self._synth(op) self.cur_fun.regs.drop_temps() self.visit(tree.children[3]) self.visit(tree.children[2]) # 3rd statement self._restore_vars(begin_vars) self.cur_fun.ops.append(op.synth_endwhile) self.cur_fun.regs.restore(postcond) def _unary_op(op): def _f(self, tree): self.visit_children(tree) operand = tree.children[0].op.out reg = self.cur_fun.regs.take() tree.op = op(self.cur_fun, [reg, operand]) self._synth(tree.op) self.cur_fun.regs.give(operand) return _f def _deref(self, tree): reg = tree.children[0].op.out var = Variable.from_dereference(reg, tree.children[0]) self._log(f'making var {var} from derefing reg {reg}') tree.var = var def delayed_load(): self._log(f'delay-loading {tree.children[0]}') op = Load(self.cur_fun, [reg, var]) self._synth(op) return op.out tree.op = Delayed(delayed_load) tree.type = var.type def dereference(self, tree): self.visit_children(tree) self._deref(tree) def pointer_access(self, tree): with self._nosynth(): self.visit_children(tree) ch_strct, ch_field = tree.children type = ch_strct.type.pointed strct = 
self.global_scope.structs[type.struct] offs = strct.offset(ch_field) if offs > 0: load_offs = lark.Tree('literal', [offs]) addop = lark.Tree('add', [tree.children[0], load_offs]) self.visit(addop) addop.type = PointerType(pointed=strct.field_type(ch_field)) tree.children = [addop] else: tree.children[0].type = PointerType(pointed=strct.field_type(ch_field)) tree.children = [tree.children[0]] self.visit_children(tree) self._deref(tree) def post_increment(self, tree): self.visit_children(tree) tree.op = Reg(tree.children[0].op.out) var = tree.children[0].var reg = tree.op.out self.cur_fun.deferred_ops.append(Incr(self.cur_fun, [reg, reg])) self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var])) tree.type = var.type def post_decrement(self, tree): self.visit_children(tree) tree.op = Reg(tree.children[0].op.out) var = tree.children[0].var reg = tree.op.out self.cur_fun.deferred_ops.append(Decr(self.cur_fun, [reg, reg])) self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var])) tree.type = var.type def pre_increment(self, tree): self.visit_children(tree) reg = tree.children[0].op.out tree.op = Incr(self.cur_fun, [reg, reg]) self._synth(tree.op) cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var])) self.cur_fun.regs.cleanup[reg].append(cleanup) bool_not = _unary_op(BoolNot) def _binary_op(op): def _f(self, tree): self.visit_children(tree) left, right = (x.op.out for x in tree.children) dest = self.cur_fun.regs.take() tree.op = op(self.cur_fun, [dest, left, right]) self._synth(tree.op) self.cur_fun.regs.give(left) self.cur_fun.regs.give(right) tree.type = tree.children[0].type # because uhm reasons return _f def _combo(uop, bop): def _f(self, tree): ftree = lark.Tree(bop, tree.children) tree.children = [ftree] uop.__get__(self)(tree) return _f shl = _binary_op(ShlOp) add = _binary_op(AddOp) sub = _binary_op(SubOp) mul = _binary_op(MulOp) _and = _binary_op(AndOp) _or = _binary_op(OrOp) # ... 
gt = _binary_op(GtOp) lt = _binary_op(LtOp) neq = _binary_op(NeqOp) eq = _combo(bool_not, 'neq') def shr(self, tree): self.visit(tree.children[0]) left = tree.children[0].op.out assert tree.children[1].data == 'literal' right = tree.children[1].children[0] dest = self.cur_fun.regs.take() tree.op = ShrOp(self.cur_fun, [dest, left, right]) self._synth(tree.op) self.cur_fun.regs.give(left) tree.type = tree.children[0].type # because uhm reasons def _assign(self, left, right): self._log(f'assigning {left} = {right}') self.cur_fun.regs.assign(left, right) op = Store(self.cur_fun, [right, left]) self._synth(op) return op.out def _assign_op(op): def f(self, tree): opnode = lark.Tree(op, tree.children) self.visit(opnode) val = opnode.op.out # left hand side is an lvalue, retrieve it var = tree.children[0].var tree.op = self._assign(var, val) return f or_ass = _assign_op('_or') add_ass = _assign_op('add') def _forward_op(self, tree): self.visit_children(tree) tree.op = tree.children[0].op def cast(self, tree): if hasattr(tree, 'request_out'): tree.children[1].request_out = tree.request_out self.visit_children(tree) tree.op = tree.children[1].op tree.type = tree.children[0] def _log(self, line): self.cur_fun.log(line) def local_var(self, tree): self.visit_children(tree) assert self.cur_fun is not None assert self.cur_scope is not None var = Variable.from_def(tree) var.stackaddr = self.cur_fun.get_stack() # will have to invert self.cur_scope.symbols[var.name] = var self.cur_fun.locals[var.name] = var if len(tree.children) > 2: initval = tree.children[2].op.out cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var])) self.cur_fun.regs.assign(var, initval, cleanup=cleanup) self._log(f'assigning {var} = {initval}') def local_array(self, tree): self.visit_children(tree) assert self.cur_fun is not None assert self.cur_scope is not None var = Variable.from_def(tree) if var.type.const: # storing in rodata assert len(tree.children) > 3 tok = self.next_string_token 
self.next_string_token += 1 tok = f'_dat{tok}' assert len(tree.children[3].children) > 0 fields =[c.children[0] for c in tree.children[3].children] if tree.children[2].data == 'empty_array': fieldcount = len(fields) else: fieldcount = tree.children[2].children[0] if len(fields) < fieldcount: fields += [0] * (fieldcount - len(fields)) name = tree.children[1] self.rodata[tok] = (tree.children[0], fields) var = Variable(PointerType(pointed=litt_type(fields[0])), tok) self.global_scope.symbols[name] = var def fun_def(self, tree): prot, body = tree.children fun = Function(prot) assert self.cur_fun is None self.cur_fun = fun self.cur_scope.symbols[fun.spec.name] = fun.spec params = fun.param_dict() def getcleanup(var): def cl(reg): self._log(f'cleanup for {var} in {reg}') self._synth(Store(self.cur_fun, [reg, var])) return cl for i, param in enumerate(fun.params): fun.locals[param.name] = param param.stackaddr = fun.get_stack() self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}') cleanup = getcleanup(param) fun.regs.loaded(param, f'r{i}', cleanup=cleanup) fun_scope = Scope() fun_scope.parent = self.cur_scope fun_scope.symbols.update(params) body.scope = fun_scope self.visit_children(tree) if fun.ret is None: if fun.spec.name == 'main': self._synth(Set16Imm(fun, ['r0', 0])) self._synth(ReturnReg(fun, ['r0'])) elif fun.spec.return_type == VoidType(): self._synth(ReturnReg(fun, ['r0'])) else: assert fun.ret is not None self.cur_fun = None self.funs.append(fun) def body(self, tree): bscope = getattr(tree, 'scope', Scope()) bscope.parent = self.cur_scope self.cur_scope = bscope self.visit_children(tree) self.cur_scope = bscope.parent def return_stat(self, tree): assert self.cur_fun is not None self.cur_fun.ret = True self.visit_children(tree) expr_reg = tree.children[0].op.out tree.op = ReturnReg(self.cur_fun, [expr_reg]) self._synth(tree.op) def struct_def(self, tree): self.visit_children(tree) strct = Struct(tree) self.cur_scope.structs[strct.name] = strct 
# Startup stub as a list of assembly lines: a `_start` label, register/stack
# setup, a jump to `main`, and a final self-branch that loops forever.
# NOTE(review): nothing in the visible part of this file references
# `preamble`; presumably a linker/driver elsewhere emits it -- confirm.
preamble = [
    '_start:',
    'xor r0, r0, r0',
    'xor r1, r1, r1',
    'set sp, 0',
    f'seth sp, {0x11}',  # 256 bytes of stack ought to be enough
    'set r2, main',
    'set r3, 0',
    'add lr, pc, r3',
    'or pc, r2, r2',
    'cmp r0, r0',
    'beq [pc, -4] // loop forever',
]


# A self-move must be the *whole* instruction (optionally indented and/or
# followed by a `//` comment).  The pattern is applied with fullmatch():
# an unanchored search would also hit 'xor r0, r0, r0' (the text contains
# 'or r0, r0, r0') and 'or r1, r1, r10' (prefix match on the last operand),
# silently deleting real instructions.
_SELF_MOVE_RE = re.compile(r'\s*or (r\d+), \1, \1\s*(?://.*)?')


def filter_dupes(ops):
    """Yield the lines of *ops*, skipping no-op self moves.

    A self move is an instruction of the form ``or rN, rN, rN`` (the idiom
    this compiler uses for register moves, which degenerates to a no-op when
    source and destination coincide).

    Args:
        ops: iterable of assembly source lines (str).

    Yields:
        Every line that is not a self move, in the original order.
    """
    for op in ops:
        if _SELF_MOVE_RE.fullmatch(op):
            continue
        yield op


def _string_words(text):
    """Return the `.word` lines for a NUL-terminated, word-aligned string."""
    # Encode first, pad second: padding on the str length breaks as soon as
    # a character encodes to more than one byte.
    data = text.encode() + b'\0'
    if len(data) % 2:
        data += b'\0'
    # Big-endian 16-bit words, matching the original '>{n}H' unpack.
    return [f'.word 0x{word:04x}' for (word,) in struct.iter_unpack('>H', data)]


def parse_tree(tree, debug=False):
    """Compile a parsed lark tree into assembly text.

    The tree is run through CcTransform, then interpreted by CcInterp.  The
    output consists of every compiled function, followed by the string
    literals and the read-only data tables collected by the interpreter,
    each section terminated by an empty line.

    Args:
        tree: the lark parse tree of the (preprocessed) C source.
        debug: when true, pretty-print the transformed tree to stdout.

    Returns:
        The assembly listing as a single newline-joined string, with no-op
        self moves removed (see filter_dupes).
    """
    tree = CcTransform().transform(tree)
    if debug:
        print(tree.pretty())
    inte = CcInterp()
    inte.visit(tree)

    out = []
    for fun in inte.funs:
        out += fun.synth()
        out.append('')
    for sym, text in inte.strings.items():
        out += [f'.global {sym}', f'{sym}:']
        out += _string_words(text)
        out.append('')
    # The stored element type is not needed to emit the words.
    for sym, (_ctype, fields) in inte.rodata.items():
        out += [f'.global {sym}', f'{sym}:']
        out += [f'.word 0x{d:04x}' for d in fields]
        out.append('')
    return '\n'.join(filter_dupes(out))


def larkparse(f, debug=False):
    """Parse C source read from file object *f* and return assembly text.

    Args:
        f: readable file object; may yield str or bytes.
        debug: forwarded to parse_tree.

    Returns:
        The assembly listing produced by parse_tree.
    """
    with open(GRAMMAR_FILE) as g:
        asparser = lark.Lark(g.read())
    data = f.read()
    if isinstance(data, bytes):
        data = data.decode()
    tree = asparser.parse(data)
    return parse_tree(tree, debug=debug)


def parse_args():
    """Build the command-line parser and parse sys.argv.

    Returns:
        The argparse.Namespace with `debug`, `assembly`, `compile`, `input`,
        `output` and `include_dirs` attributes.
    """
    parser = argparse.ArgumentParser(description='Compile.')
    parser.add_argument('--debug', action='store_true',
                        help='print the AST')
    parser.add_argument('--assembly', '-S', action='store_true',
                        help='output assembly')
    parser.add_argument('--compile', '-c', action='store_true',
                        help='compile a single file')
    # nargs='?' is what makes the positional optional; without it argparse
    # requires the argument and the stdin default is never used.
    parser.add_argument('input', nargs='?', type=argparse.FileType('r'),
                        default=sys.stdin,
                        help='input file (default: stdin)')
    parser.add_argument('--output', '-o', type=argparse.FileType('wb'),
                        default=sys.stdout.buffer, help='output file')
    parser.add_argument('-I', dest='include_dirs', action='append',
                        default=[], help='include dirs')
    return parser.parse_args()


def preprocess(fin, include_dirs):
    """Run the C preprocessor over *fin* and return its output.

    Args:
        fin: file object handed to the preprocessor as its stdin.
        include_dirs: directories passed to the preprocessor as `-I` flags.

    Returns:
        An io.StringIO holding the preprocessed source.

    Raises:
        RuntimeError: if the preprocessor exits with a non-zero status.
    """
    cmd = list(CPP) + [f'-I{x}' for x in include_dirs]
    proc = subprocess.run(cmd, stdin=fin, stdout=subprocess.PIPE, check=False)
    if proc.returncode != 0:
        raise RuntimeError('preprocessor error')
    return io.StringIO(proc.stdout.decode())


def assemble(text, fout):
    """Assemble *text* with the `as` module and write the object to *fout*."""
    fin = io.StringIO(text)
    asmod.write_obj(fout, *asmod.larkparse(fin))


def main():
    """Command-line entry point: preprocess, compile, then emit or assemble."""
    args = parse_args()
    assy = larkparse(preprocess(args.input, args.include_dirs),
                     debug=args.debug)
    if args.assembly:
        args.output.write(assy.encode() + b'\n')
    else:
        assemble(assy, args.output)


if __name__ == "__main__":
    main()