diff --git a/tools/cc.py b/tools/cc.py index 1d2b4c3..989344c 100644 --- a/tools/cc.py +++ b/tools/cc.py @@ -4,6 +4,7 @@ import collections import contextlib from dataclasses import dataclass import importlib +import inspect import io import re import struct @@ -34,6 +35,9 @@ class Variable: def __repr__(self): return f'' + def __str__(self): + return self.name + @classmethod def from_def(cls, tree): type = tree.children[0] @@ -76,21 +80,42 @@ class Variable: print(tree) +def allregs(): + return (f'r{i}' for i in range(12)) + + class RegBank: def __init__(self, logger=None): self.reset() self.log = logger or print def reset(self): - self.available = [f'r{i}' for i in range(12)] + self.available = list(allregs()) self.vars = {} self.varregs = {} self.cleanup = collections.defaultdict(list) + self.taken = {} - def take(self, reg=None): + def __str__(self): + return f'regs(avail={self.available}, vars={self.vars})' + + def take(self, reg=None, orwhatever=False): + out = self._take(reg=reg, orwhatever=orwhatever) + self.taken[out] = inspect.stack()[1] + return out + + def _take(self, reg=None, orwhatever=False): if reg is not None: if reg not in self.available: - self.evict(self.var(reg)) + if reg in self.varregs: + self.evict(self.var(reg)) + elif orwhatever: + return self._take() + else: + raise RuntimeError( + f'trying to overtake {reg}, ' + f'varregs: {self.varregs}, vars: {self.vars}, ' + f'taken in {self.taken[reg]}') if reg in self.available: self.available.remove(reg) return reg @@ -102,6 +127,7 @@ class RegBank: return self.available.pop(0) def give(self, reg): + assert reg is not None if reg in self.varregs: # Need to call evict() with the var to free it. return @@ -109,11 +135,13 @@ class RegBank: def loaded(self, var, reg, cleanup=None): """Tells the regbank some variable was loaded into the given register.""" + assert reg is not None + assert var is not None self.vars[var] = reg self.varregs[reg] = var self.take(reg) if cleanup is not None: - self.log(f'recording cleanup for {reg}({var.name})') + self.log(f'recording cleanup for {reg}({var})') self.cleanup[reg].append(cleanup) def stored(self, var): @@ -136,6 +164,7 @@ class RegBank: def assign(self, var, reg, cleanup=None): """Assign a previously-used register to a variable.""" + assert reg is not None if var in self.vars: self.stored(var) if reg in self.varregs: @@ -150,21 +179,25 @@ class RegBank: def var(self, reg): return self.varregs.get(reg, None) + def evict_all(self): + for var in list(self.vars): + self.evict(var) + def evict(self, var): """Runs var callbacks & frees the register.""" + assert var in self.vars, f'trying to evict {var}' + self.log(f'evicting {var}') if var not in self.vars: return reg = self.vars[var] - for cb in self.cleanup.pop(var, []): + for cb in self.cleanup.pop(reg, []): cb(reg) self.stored(var) - def flush_all(self): - for reg in list(self.cleanup): - self.log(f'flushing {reg}({self.varregs[reg].name})') - for cb in self.cleanup.pop(reg): - cb(reg) - self.reset() + def drop_temps(self): + for reg in allregs(): + if reg not in self.available and reg not in self.varregs: + self.give(reg) def swapped(self, reg0, reg1): var0 = self.varregs.get(reg0, None) @@ -192,6 +225,7 @@ class RegBank: self.reset() vardict, cleanup = snap for var, reg in vardict.items(): + assert reg is not None, f'bad snapshot: {snap}' self.loaded(var, reg) self.cleanup[reg] = cleanup.get(reg, []) @@ -225,6 +259,7 @@ class Function: self.params = [Variable(*x.children) for x in fun_prot.children[2:]] self.ret = None self.nextstack = 0 + self.topstack = 0 self.statements = [] self.regs = RegBank(logger=self.log) self.deferred_ops = [] @@ -236,13 +271,24 @@ class Function: @property def stack_usage(self): - return self.nextstack + 2 + top_usage = max(self.topstack, self.nextstack) + if self.fun_calls > 0: + return top_usage + 2 + return top_usage def get_stack(self, size=2): stk = self.nextstack self.nextstack += size return stk + @contextlib.contextmanager + def tempstack(self): + stk = self.nextstack + yield + if self.nextstack > self.topstack: + self.topstack = self.nextstack + self.nextstack = stk + def param_dict(self): return {p.name: p for p in self.params} @@ -250,14 +296,12 @@ class Function: return repr(self.spec) def synth(self): + ops = [] if self.fun_calls > 0: - preamble = [f'store lr, [sp, -2]', - f'set r4, {self.stack_usage}', - f'sub sp, sp, r4'] - else: - preamble = [] - - ops = preamble + ops += [f'store lr, [sp, -2]'] + if self.stack_usage > 0: + ops += [f'set r4, {self.stack_usage}', + f'sub sp, sp, r4'] for op in self.ops: ops += op() indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in ops] @@ -524,17 +568,22 @@ class ReturnReg(AsmOp): (self.ret_reg,) = ops def synth(self, scratches): - if self.fun.fun_calls == 0: - return [f'or r0, {self.ret_reg}, {self.ret_reg}', - f'or pc, lr, lr'] - sc0 = scratches[0] + ops = [] stack_usage = self.fun.stack_usage - ret = self.ret_reg - assert stack_usage < 255 - return [f'set {sc0}, {stack_usage}', - f'add sp, sp, {sc0}', - f'or r0, {ret}, {ret}', - f'load pc, [sp, -2] // return'] + (sc0,) = scratches + if stack_usage > 0: + ops += [f'set {sc0}, {stack_usage}', + f'add sp, sp, {sc0}'] + if self.fun.ret is not None: + ret = self.ret_reg + ops += [f'or r0, {ret}, {ret}'] + + if self.fun.fun_calls == 0: + ops += [f'or pc, lr, lr'] + else: + ops += [f'load pc, [sp, -2] // return'] + + return ops class Load(AsmOp): @@ -574,6 +623,29 @@ class Load(AsmOp): f'nop // in case we load a far global', f'load {reg}, [{reg}]'] # unsure if we need a cast here + +class Push(AsmOp): + def __init__(self, reg, stackaddr): + self.reg = reg + self.stackaddr = stackaddr + + def synth(self, scratches): + return [f'store {self.reg}, [sp, {self.stackaddr}]'] + + +class Pop(AsmOp): + def __init__(self, reg, stackaddr): + self.reg = reg + self.stackaddr = stackaddr + + @property + def out(self): + return self.reg + + def synth(self, scratches): + return [f'load {self.reg}, [sp, {self.stackaddr}]'] + + class Store(AsmOp): scratch_need = 1 @@ -581,10 +653,6 @@ class Store(AsmOp): self.fun = fun self.src, self.var = ops - @property - def out(self): - return None - def synth(self, scratches): (sc,) = scratches reg = self.src @@ -820,11 +888,11 @@ class CcInterp(lark.visitors.Interpreter): self._log(f'{op.__class__.__name__}') self.cur_fun.ops.append(lambda: op.synth(scratches)) - def _load(self, ident): + def _load(self, ident, reg=None): s = self._lookup_symbol(ident) assert s is not None, f'unknown identifier {ident}' if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols: - reg = self._get_reg() + reg = reg or self._get_reg() return SetAddr(self.cur_fun, [reg, ident]) else: if s.type.volatile: @@ -845,7 +913,8 @@ class CcInterp(lark.visitors.Interpreter): # If it's volatile, we need to flush it. def delayed_load(): self._log(f'delay-loading {tree.children[0]}') - op = self._load(tree.children[0]) + reg = getattr(tree, 'request_out', None) + op = self._load(tree.children[0], reg=reg) self._synth(op) return op.out @@ -868,10 +937,17 @@ class CcInterp(lark.visitors.Interpreter): tok = self._make_string(imm) tree.children = [tok] return self.identifier(tree) - reg = self._get_reg() - assert self.cur_fun is not None - tree.op = Set16Imm(self.cur_fun, [reg, imm]) - self._synth(tree.op) + def delayed(): + if hasattr(tree, 'request_out'): + reg = self.cur_fun.regs.take(reg=tree.request_out, orwhatever=True) + else: + reg = self.cur_fun.regs.take() + assert reg is not None, f'weird: {self.cur_fun.regs}' + op = Set16Imm(self.cur_fun, [reg, imm]) + self._synth(op) + return op.out + + tree.op = Delayed(delayed) tree.type = litt_type(tree.children[0]) def _assign(self, left, right): @@ -884,6 +960,7 @@ class CcInterp(lark.visitors.Interpreter): def assignment(self, tree): self.visit_children(tree) val = tree.children[1].op.out + assert val is not None, f'no output!!! {tree.pretty()}' # left hand side is an lvalue, retrieve it var = tree.children[0].var @@ -916,7 +993,7 @@ class CcInterp(lark.visitors.Interpreter): for i, r in enumerate(dst): if r not in src: - self.cur_fun.regs.take(r) # precious! + #self.cur_fun.regs.take(r) # precious! self._synth(Move(r, src[i])) else: swap(r, src[i]) @@ -924,46 +1001,92 @@ class CcInterp(lark.visitors.Interpreter): def _prep_fun_call(self, fn_reg, params): """Move all params to r0-rn.""" - target_regs = [f'r{i}' for i in range(len(params))] - new_fn = fn_reg - while new_fn in target_regs or new_fn in params: - new_fn = self.cur_fun.regs.take() - if new_fn != fn_reg: - self._synth(Move(new_fn, fn_reg)) - fn_reg = new_fn - self._reg_shuffle(params, target_regs) - return fn_reg + target_regs = [f'r{i}' for i in range(len(params) + 1)] + self._reg_shuffle(params + [fn_reg], target_regs) + return target_regs[-1] + + @contextlib.contextmanager + def _push(self, dontpush=()): + pushed = [] + self.cur_fun.regs.evict_all() + with self.cur_fun.tempstack(): + for reg in allregs(): + if reg not in self.cur_fun.regs.available and reg not in dontpush: + self._log(f'reg in use: {reg}') + stk = self.cur_fun.get_stack() + pushed.append((reg, stk)) + self._synth(Push(reg, stk)) + self.cur_fun.regs.give(reg) + yield pushed def fun_call(self, tree): self.cur_fun.fun_calls += 1 - self.visit_children(tree) - fn_reg = tree.children[0].op.out if len(tree.children) == 1: - param_regs = [] + param_children = [] elif tree.children[1].data == 'comma_expr': - param_regs = [c.op.out for c in tree.children[1].children] + param_children = tree.children[1].children else: - param_regs = [tree.children[1].op.out] - self._log(f'calling {tree.children[0].children[0]}({param_regs})') - self.cur_fun.regs.flush_all() - for i in range(len(param_regs)): - self.cur_fun.regs.take(f'r{i}') - fn_reg = self._prep_fun_call(fn_reg, param_regs) - self.cur_fun.regs.take(fn_reg) - tree.op = FnCall(self.cur_fun, [fn_reg, param_regs]) - self._synth(tree.op) - self.cur_fun.regs.reset() - self.cur_fun.regs.take('r0') + param_children = [tree.children[1]] + + + if True: + # pre-allocate output registers + for i, child in enumerate(param_children): + self._log(f'requesting r{i} for {child}') + child.request_out = f'r{i}' + self._log(f'requesting r{len(param_children)} for {tree.children[0]}') + tree.children[0].request_out = f'r{len(param_children)}' + self.visit_children(tree) + fn_reg = tree.children[0].op.out + param_regs = [c.op.out for c in param_children] + self._log(f'calling {tree.children[0].children[0]}({param_regs})') + + with self._push(dontpush=param_regs + [fn_reg]) as pushed: + + self._flush_deferred() # side effects + fn_reg = self._prep_fun_call(fn_reg, param_regs) + self.cur_fun.regs.reset() + self.cur_fun.regs.take(fn_reg) + for reg in (f'r{i}' for i in range(len(param_regs))): + self.cur_fun.regs.take(reg) + + tree.op = FnCall(self.cur_fun, [fn_reg, param_regs]) + self._synth(tree.op) + + self.cur_fun.regs.reset() + + pops = [] + for reg, stk in pushed: + self.cur_fun.regs.take(reg) + pops.append(Pop(reg, stk)) + + ret = 'r0' + if hasattr(tree, 'request_out'): + ret = tree.request_out + self._synth(Move(ret, 'r0')) + + nret = self.cur_fun.regs.take(ret, orwhatever=True) + if nret != ret: + self._synth(Move(nret, ret)) + tree.op = type('blah', (), {'out': nret}) + + for op in pops: + self._synth(op) + tree.type = return_type(tree.children[0].type) - def statement(self, tree): - self.visit_children(tree) + def _flush_deferred(self): if self.cur_fun.deferred_ops: self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}') for op in self.cur_fun.deferred_ops: self._synth(op) self.cur_fun.deferred_ops = [] + def statement(self, tree): + self.visit_children(tree) + self._flush_deferred() + self.cur_fun.regs.drop_temps() + iter_expression = statement def if_stat(self, tree): @@ -978,6 +1101,7 @@ class CcInterp(lark.visitors.Interpreter): self.visit(tree.children[0]) op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else]) self._synth(op) + self.cur_fun.regs.drop_temps() begin_vars = self.cur_fun.regs.snap() self.visit(tree.children[1]) if has_else: @@ -1009,6 +1133,7 @@ class CcInterp(lark.visitors.Interpreter): postcond = self.cur_fun.regs.snap() op = WhileOp(tree.children[0].op.out, mark) self._synth(op) + self.cur_fun.regs.drop_temps() self.visit(tree.children[1]) self._restore_vars(begin_vars) self.cur_fun.ops.append(op.synth_endwhile) @@ -1023,6 +1148,7 @@ class CcInterp(lark.visitors.Interpreter): postcond = self.cur_fun.regs.snap() op = WhileOp(tree.children[1].op.out, mark) self._synth(op) + self.cur_fun.regs.drop_temps() self.visit(tree.children[3]) self.visit(tree.children[2]) # 3rd statement self._restore_vars(begin_vars) @@ -1093,7 +1219,14 @@ class CcInterp(lark.visitors.Interpreter): self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var])) tree.type = var.type - pre_increment = _unary_op(Incr) + def pre_increment(self, tree): + self.visit_children(tree) + reg = tree.children[0].op.out + tree.op = Incr(self.cur_fun, [reg, reg]) + self._synth(tree.op) + cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var])) + self.cur_fun.regs.cleanup[reg].append(cleanup) + bool_not = _unary_op(BoolNot) def _binary_op(op): @@ -1152,6 +1285,8 @@ class CcInterp(lark.visitors.Interpreter): tree.op = tree.children[0].op def cast(self, tree): + if hasattr(tree, 'request_out'): + tree.children[1].request_out = tree.request_out self.visit_children(tree) tree.op = tree.children[1].op tree.type = tree.children[0] @@ -1184,7 +1319,10 @@ class CcInterp(lark.visitors.Interpreter): params = fun.param_dict() def getcleanup(var): - return lambda reg: self._synth(Store(self.cur_fun, [reg, var])) + def cl(reg): + self._log(f'cleanup for {var} in {reg}') + self._synth(Store(self.cur_fun, [reg, var])) + return cl for i, param in enumerate(fun.params): fun.locals[param.name] = param