cc: mostly fixed function calls

This commit is contained in:
Paul Mathieu 2021-03-20 21:27:26 -07:00
parent eb303641d9
commit ef81ec3b12

View File

@ -4,6 +4,7 @@ import collections
import contextlib import contextlib
from dataclasses import dataclass from dataclasses import dataclass
import importlib import importlib
import inspect
import io import io
import re import re
import struct import struct
@ -34,6 +35,9 @@ class Variable:
def __repr__(self): def __repr__(self):
return f'<Var: {self.type} {self.name}>' return f'<Var: {self.type} {self.name}>'
def __str__(self):
return self.name
@classmethod @classmethod
def from_def(cls, tree): def from_def(cls, tree):
type = tree.children[0] type = tree.children[0]
@ -76,21 +80,42 @@ class Variable:
print(tree) print(tree)
def allregs():
    """Yield the names of the allocatable registers, 'r0' through 'r11'.

    A fresh generator is returned on every call, so callers that need to
    iterate more than once (or index into the result) must materialise it,
    e.g. with ``list(allregs())``.
    """
    return (f'r{i}' for i in range(12))
class RegBank: class RegBank:
def __init__(self, logger=None): def __init__(self, logger=None):
self.reset() self.reset()
self.log = logger or print self.log = logger or print
def reset(self): def reset(self):
self.available = [f'r{i}' for i in range(12)] self.available = list(allregs())
self.vars = {} self.vars = {}
self.varregs = {} self.varregs = {}
self.cleanup = collections.defaultdict(list) self.cleanup = collections.defaultdict(list)
self.taken = {}
def take(self, reg=None): def __str__(self):
return f'regs(avail={self.available}, vars={self.vars})'
def take(self, reg=None, orwhatever=False):
out = self._take(reg=reg, orwhatever=orwhatever)
self.taken[out] = inspect.stack()[1]
return out
def _take(self, reg=None, orwhatever=False):
if reg is not None: if reg is not None:
if reg not in self.available: if reg not in self.available:
self.evict(self.var(reg)) if reg in self.varregs:
self.evict(self.var(reg))
elif orwhatever:
return self._take()
else:
raise RuntimeError(
f'trying to overtake {reg}, '
f'varregs: {self.varregs}, vars: {self.vars}, '
f'taken in {self.taken[reg]}')
if reg in self.available: if reg in self.available:
self.available.remove(reg) self.available.remove(reg)
return reg return reg
@ -102,6 +127,7 @@ class RegBank:
return self.available.pop(0) return self.available.pop(0)
def give(self, reg): def give(self, reg):
assert reg is not None
if reg in self.varregs: if reg in self.varregs:
# Need to call evict() with the var to free it. # Need to call evict() with the var to free it.
return return
@ -109,11 +135,13 @@ class RegBank:
def loaded(self, var, reg, cleanup=None): def loaded(self, var, reg, cleanup=None):
"""Tells the regbank some variable was loaded into the given register.""" """Tells the regbank some variable was loaded into the given register."""
assert reg is not None
assert var is not None
self.vars[var] = reg self.vars[var] = reg
self.varregs[reg] = var self.varregs[reg] = var
self.take(reg) self.take(reg)
if cleanup is not None: if cleanup is not None:
self.log(f'recording cleanup for {reg}({var.name})') self.log(f'recording cleanup for {reg}({var})')
self.cleanup[reg].append(cleanup) self.cleanup[reg].append(cleanup)
def stored(self, var): def stored(self, var):
@ -136,6 +164,7 @@ class RegBank:
def assign(self, var, reg, cleanup=None): def assign(self, var, reg, cleanup=None):
"""Assign a previously-used register to a variable.""" """Assign a previously-used register to a variable."""
assert reg is not None
if var in self.vars: if var in self.vars:
self.stored(var) self.stored(var)
if reg in self.varregs: if reg in self.varregs:
@ -150,21 +179,25 @@ class RegBank:
def var(self, reg): def var(self, reg):
return self.varregs.get(reg, None) return self.varregs.get(reg, None)
def evict_all(self):
for var in list(self.vars):
self.evict(var)
def evict(self, var): def evict(self, var):
"""Runs var callbacks & frees the register.""" """Runs var callbacks & frees the register."""
assert var in self.vars, f'trying to evict {var}'
self.log(f'evicting {var}')
if var not in self.vars: if var not in self.vars:
return return
reg = self.vars[var] reg = self.vars[var]
for cb in self.cleanup.pop(var, []): for cb in self.cleanup.pop(reg, []):
cb(reg) cb(reg)
self.stored(var) self.stored(var)
def flush_all(self): def drop_temps(self):
for reg in list(self.cleanup): for reg in allregs():
self.log(f'flushing {reg}({self.varregs[reg].name})') if reg not in self.available and reg not in self.varregs:
for cb in self.cleanup.pop(reg): self.give(reg)
cb(reg)
self.reset()
def swapped(self, reg0, reg1): def swapped(self, reg0, reg1):
var0 = self.varregs.get(reg0, None) var0 = self.varregs.get(reg0, None)
@ -192,6 +225,7 @@ class RegBank:
self.reset() self.reset()
vardict, cleanup = snap vardict, cleanup = snap
for var, reg in vardict.items(): for var, reg in vardict.items():
assert reg is not None, f'bad snapshot: {snap}'
self.loaded(var, reg) self.loaded(var, reg)
self.cleanup[reg] = cleanup.get(reg, []) self.cleanup[reg] = cleanup.get(reg, [])
@ -225,6 +259,7 @@ class Function:
self.params = [Variable(*x.children) for x in fun_prot.children[2:]] self.params = [Variable(*x.children) for x in fun_prot.children[2:]]
self.ret = None self.ret = None
self.nextstack = 0 self.nextstack = 0
self.topstack = 0
self.statements = [] self.statements = []
self.regs = RegBank(logger=self.log) self.regs = RegBank(logger=self.log)
self.deferred_ops = [] self.deferred_ops = []
@ -236,13 +271,24 @@ class Function:
@property @property
def stack_usage(self): def stack_usage(self):
return self.nextstack + 2 top_usage = max(self.topstack, self.nextstack)
if self.fun_calls > 0:
return top_usage + 2
return top_usage
def get_stack(self, size=2): def get_stack(self, size=2):
stk = self.nextstack stk = self.nextstack
self.nextstack += size self.nextstack += size
return stk return stk
@contextlib.contextmanager
def tempstack(self):
    """Scope temporary stack allocations to the enclosed ``with`` block.

    ``self.nextstack`` is remembered on entry and put back on exit, so any
    slots handed out inside the block are released afterwards.  Before the
    reset, ``self.topstack`` is raised to the highest ``nextstack`` value
    reached, so the peak usage is not lost.
    """
    base = self.nextstack
    try:
        yield
    finally:
        # Run the bookkeeping even when the body raises; otherwise the
        # temporary slots would stay allocated and the peak would be lost.
        if self.nextstack > self.topstack:
            self.topstack = self.nextstack
        self.nextstack = base
def param_dict(self): def param_dict(self):
return {p.name: p for p in self.params} return {p.name: p for p in self.params}
@ -250,14 +296,12 @@ class Function:
return repr(self.spec) return repr(self.spec)
def synth(self): def synth(self):
ops = []
if self.fun_calls > 0: if self.fun_calls > 0:
preamble = [f'store lr, [sp, -2]', ops += [f'store lr, [sp, -2]']
f'set r4, {self.stack_usage}', if self.stack_usage > 0:
f'sub sp, sp, r4'] ops += [f'set r4, {self.stack_usage}',
else: f'sub sp, sp, r4']
preamble = []
ops = preamble
for op in self.ops: for op in self.ops:
ops += op() ops += op()
indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in ops] indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in ops]
@ -524,17 +568,22 @@ class ReturnReg(AsmOp):
(self.ret_reg,) = ops (self.ret_reg,) = ops
def synth(self, scratches): def synth(self, scratches):
if self.fun.fun_calls == 0: ops = []
return [f'or r0, {self.ret_reg}, {self.ret_reg}',
f'or pc, lr, lr']
sc0 = scratches[0]
stack_usage = self.fun.stack_usage stack_usage = self.fun.stack_usage
ret = self.ret_reg (sc0,) = scratches
assert stack_usage < 255 if stack_usage > 0:
return [f'set {sc0}, {stack_usage}', ops += [f'set {sc0}, {stack_usage}',
f'add sp, sp, {sc0}', f'add sp, sp, {sc0}']
f'or r0, {ret}, {ret}', if self.fun.ret is not None:
f'load pc, [sp, -2] // return'] ret = self.ret_reg
ops += [f'or r0, {ret}, {ret}']
if self.fun.fun_calls == 0:
ops += [f'or pc, lr, lr']
else:
ops += [f'load pc, [sp, -2] // return']
return ops
class Load(AsmOp): class Load(AsmOp):
@ -574,6 +623,29 @@ class Load(AsmOp):
f'nop // in case we load a far global', f'nop // in case we load a far global',
f'load {reg}, [{reg}]'] # unsure if we need a cast here f'load {reg}, [{reg}]'] # unsure if we need a cast here
class Push(AsmOp):
    """Op that saves a register into a slot addressed relative to ``sp``."""

    def __init__(self, reg, stackaddr):
        # reg: name of the register to save.
        # stackaddr: offset used in the emitted ``[sp, <offset>]`` operand.
        self.reg = reg
        self.stackaddr = stackaddr

    def synth(self, scratches):
        """Return the single ``store`` instruction; ``scratches`` is unused."""
        return [f'store {self.reg}, [sp, {self.stackaddr}]']
class Pop(AsmOp):
    """Op that reloads a register from a slot addressed relative to ``sp``."""

    def __init__(self, reg, stackaddr):
        # reg: name of the register to reload.
        # stackaddr: offset used in the emitted ``[sp, <offset>]`` operand.
        self.reg = reg
        self.stackaddr = stackaddr

    @property
    def out(self):
        """The register that holds the value once the op has run."""
        return self.reg

    def synth(self, scratches):
        """Return the single ``load`` instruction; ``scratches`` is unused."""
        return [f'load {self.reg}, [sp, {self.stackaddr}]']
class Store(AsmOp): class Store(AsmOp):
scratch_need = 1 scratch_need = 1
@ -581,10 +653,6 @@ class Store(AsmOp):
self.fun = fun self.fun = fun
self.src, self.var = ops self.src, self.var = ops
@property
def out(self):
return None
def synth(self, scratches): def synth(self, scratches):
(sc,) = scratches (sc,) = scratches
reg = self.src reg = self.src
@ -820,11 +888,11 @@ class CcInterp(lark.visitors.Interpreter):
self._log(f'{op.__class__.__name__}') self._log(f'{op.__class__.__name__}')
self.cur_fun.ops.append(lambda: op.synth(scratches)) self.cur_fun.ops.append(lambda: op.synth(scratches))
def _load(self, ident): def _load(self, ident, reg=None):
s = self._lookup_symbol(ident) s = self._lookup_symbol(ident)
assert s is not None, f'unknown identifier {ident}' assert s is not None, f'unknown identifier {ident}'
if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols: if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
reg = self._get_reg() reg = reg or self._get_reg()
return SetAddr(self.cur_fun, [reg, ident]) return SetAddr(self.cur_fun, [reg, ident])
else: else:
if s.type.volatile: if s.type.volatile:
@ -845,7 +913,8 @@ class CcInterp(lark.visitors.Interpreter):
# If it's volatile, we need to flush it. # If it's volatile, we need to flush it.
def delayed_load(): def delayed_load():
self._log(f'delay-loading {tree.children[0]}') self._log(f'delay-loading {tree.children[0]}')
op = self._load(tree.children[0]) reg = getattr(tree, 'request_out', None)
op = self._load(tree.children[0], reg=reg)
self._synth(op) self._synth(op)
return op.out return op.out
@ -868,10 +937,17 @@ class CcInterp(lark.visitors.Interpreter):
tok = self._make_string(imm) tok = self._make_string(imm)
tree.children = [tok] tree.children = [tok]
return self.identifier(tree) return self.identifier(tree)
reg = self._get_reg() def delayed():
assert self.cur_fun is not None if hasattr(tree, 'request_out'):
tree.op = Set16Imm(self.cur_fun, [reg, imm]) reg = self.cur_fun.regs.take(reg=tree.request_out, orwhatever=True)
self._synth(tree.op) else:
reg = self.cur_fun.regs.take()
assert reg is not None, f'weird: {self.cur_fun.regs}'
op = Set16Imm(self.cur_fun, [reg, imm])
self._synth(op)
return op.out
tree.op = Delayed(delayed)
tree.type = litt_type(tree.children[0]) tree.type = litt_type(tree.children[0])
def _assign(self, left, right): def _assign(self, left, right):
@ -884,6 +960,7 @@ class CcInterp(lark.visitors.Interpreter):
def assignment(self, tree): def assignment(self, tree):
self.visit_children(tree) self.visit_children(tree)
val = tree.children[1].op.out val = tree.children[1].op.out
assert val is not None, f'no output!!! {tree.pretty()}'
# left hand side is an lvalue, retrieve it # left hand side is an lvalue, retrieve it
var = tree.children[0].var var = tree.children[0].var
@ -916,7 +993,7 @@ class CcInterp(lark.visitors.Interpreter):
for i, r in enumerate(dst): for i, r in enumerate(dst):
if r not in src: if r not in src:
self.cur_fun.regs.take(r) # precious! #self.cur_fun.regs.take(r) # precious!
self._synth(Move(r, src[i])) self._synth(Move(r, src[i]))
else: else:
swap(r, src[i]) swap(r, src[i])
@ -924,46 +1001,92 @@ class CcInterp(lark.visitors.Interpreter):
def _prep_fun_call(self, fn_reg, params): def _prep_fun_call(self, fn_reg, params):
"""Move all params to r0-rn.""" """Move all params to r0-rn."""
target_regs = [f'r{i}' for i in range(len(params))] target_regs = [f'r{i}' for i in range(len(params) + 1)]
new_fn = fn_reg self._reg_shuffle(params + [fn_reg], target_regs)
while new_fn in target_regs or new_fn in params: return target_regs[-1]
new_fn = self.cur_fun.regs.take()
if new_fn != fn_reg: @contextlib.contextmanager
self._synth(Move(new_fn, fn_reg)) def _push(self, dontpush=()):
fn_reg = new_fn pushed = []
self._reg_shuffle(params, target_regs) self.cur_fun.regs.evict_all()
return fn_reg with self.cur_fun.tempstack():
for reg in allregs():
if reg not in self.cur_fun.regs.available and reg not in dontpush:
self._log(f'reg in use: {reg}')
stk = self.cur_fun.get_stack()
pushed.append((reg, stk))
self._synth(Push(reg, stk))
self.cur_fun.regs.give(reg)
yield pushed
def fun_call(self, tree): def fun_call(self, tree):
self.cur_fun.fun_calls += 1 self.cur_fun.fun_calls += 1
self.visit_children(tree)
fn_reg = tree.children[0].op.out
if len(tree.children) == 1: if len(tree.children) == 1:
param_regs = [] param_children = []
elif tree.children[1].data == 'comma_expr': elif tree.children[1].data == 'comma_expr':
param_regs = [c.op.out for c in tree.children[1].children] param_children = tree.children[1].children
else: else:
param_regs = [tree.children[1].op.out] param_children = [tree.children[1]]
self._log(f'calling {tree.children[0].children[0]}({param_regs})')
self.cur_fun.regs.flush_all()
for i in range(len(param_regs)): if True:
self.cur_fun.regs.take(f'r{i}') # pre-allocate output registers
fn_reg = self._prep_fun_call(fn_reg, param_regs) for i, child in enumerate(param_children):
self.cur_fun.regs.take(fn_reg) self._log(f'requesting r{i} for {child}')
tree.op = FnCall(self.cur_fun, [fn_reg, param_regs]) child.request_out = f'r{i}'
self._synth(tree.op) self._log(f'requesting r{len(param_children)} for {tree.children[0]}')
self.cur_fun.regs.reset() tree.children[0].request_out = f'r{len(param_children)}'
self.cur_fun.regs.take('r0') self.visit_children(tree)
fn_reg = tree.children[0].op.out
param_regs = [c.op.out for c in param_children]
self._log(f'calling {tree.children[0].children[0]}({param_regs})')
with self._push(dontpush=param_regs + [fn_reg]) as pushed:
self._flush_deferred() # side effects
fn_reg = self._prep_fun_call(fn_reg, param_regs)
self.cur_fun.regs.reset()
self.cur_fun.regs.take(fn_reg)
for reg in (f'r{i}' for i in range(len(param_regs))):
self.cur_fun.regs.take(reg)
tree.op = FnCall(self.cur_fun, [fn_reg, param_regs])
self._synth(tree.op)
self.cur_fun.regs.reset()
pops = []
for reg, stk in pushed:
self.cur_fun.regs.take(reg)
pops.append(Pop(reg, stk))
ret = 'r0'
if hasattr(tree, 'request_out'):
ret = tree.request_out
self._synth(Move(ret, 'r0'))
nret = self.cur_fun.regs.take(ret, orwhatever=True)
if nret != ret:
self._synth(Move(nret, ret))
tree.op = type('blah', (), {'out': nret})
for op in pops:
self._synth(op)
tree.type = return_type(tree.children[0].type) tree.type = return_type(tree.children[0].type)
def statement(self, tree): def _flush_deferred(self):
self.visit_children(tree)
if self.cur_fun.deferred_ops: if self.cur_fun.deferred_ops:
self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}') self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}')
for op in self.cur_fun.deferred_ops: for op in self.cur_fun.deferred_ops:
self._synth(op) self._synth(op)
self.cur_fun.deferred_ops = [] self.cur_fun.deferred_ops = []
def statement(self, tree):
self.visit_children(tree)
self._flush_deferred()
self.cur_fun.regs.drop_temps()
iter_expression = statement iter_expression = statement
def if_stat(self, tree): def if_stat(self, tree):
@ -978,6 +1101,7 @@ class CcInterp(lark.visitors.Interpreter):
self.visit(tree.children[0]) self.visit(tree.children[0])
op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else]) op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
self._synth(op) self._synth(op)
self.cur_fun.regs.drop_temps()
begin_vars = self.cur_fun.regs.snap() begin_vars = self.cur_fun.regs.snap()
self.visit(tree.children[1]) self.visit(tree.children[1])
if has_else: if has_else:
@ -1009,6 +1133,7 @@ class CcInterp(lark.visitors.Interpreter):
postcond = self.cur_fun.regs.snap() postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[0].op.out, mark) op = WhileOp(tree.children[0].op.out, mark)
self._synth(op) self._synth(op)
self.cur_fun.regs.drop_temps()
self.visit(tree.children[1]) self.visit(tree.children[1])
self._restore_vars(begin_vars) self._restore_vars(begin_vars)
self.cur_fun.ops.append(op.synth_endwhile) self.cur_fun.ops.append(op.synth_endwhile)
@ -1023,6 +1148,7 @@ class CcInterp(lark.visitors.Interpreter):
postcond = self.cur_fun.regs.snap() postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[1].op.out, mark) op = WhileOp(tree.children[1].op.out, mark)
self._synth(op) self._synth(op)
self.cur_fun.regs.drop_temps()
self.visit(tree.children[3]) self.visit(tree.children[3])
self.visit(tree.children[2]) # 3rd statement self.visit(tree.children[2]) # 3rd statement
self._restore_vars(begin_vars) self._restore_vars(begin_vars)
@ -1093,7 +1219,14 @@ class CcInterp(lark.visitors.Interpreter):
self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var])) self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
tree.type = var.type tree.type = var.type
def pre_increment(self, tree):
    """Emit an in-place increment and schedule the write-back of the operand.

    The operand's register is incremented in place, and a cleanup callback
    is recorded for that register which stores the value back to the
    operand's variable.
    """
    self.visit_children(tree)
    operand = tree.children[0]
    reg = operand.op.out
    # The variable must be bound here: the callback below closes over it,
    # and it was previously an undefined name (NameError when it ran).
    var = operand.var
    tree.op = Incr(self.cur_fun, [reg, reg])
    self._synth(tree.op)

    def cleanup(reg):
        self._synth(Store(self.cur_fun, [reg, var]))

    self.cur_fun.regs.cleanup[reg].append(cleanup)
    # NOTE(review): unlike the handler just above, this one does not set
    # tree.type; confirm whether consumers of the node need it.
bool_not = _unary_op(BoolNot) bool_not = _unary_op(BoolNot)
def _binary_op(op): def _binary_op(op):
@ -1152,6 +1285,8 @@ class CcInterp(lark.visitors.Interpreter):
tree.op = tree.children[0].op tree.op = tree.children[0].op
def cast(self, tree): def cast(self, tree):
if hasattr(tree, 'request_out'):
tree.children[1].request_out = tree.request_out
self.visit_children(tree) self.visit_children(tree)
tree.op = tree.children[1].op tree.op = tree.children[1].op
tree.type = tree.children[0] tree.type = tree.children[0]
@ -1184,7 +1319,10 @@ class CcInterp(lark.visitors.Interpreter):
params = fun.param_dict() params = fun.param_dict()
def getcleanup(var): def getcleanup(var):
return lambda reg: self._synth(Store(self.cur_fun, [reg, var])) def cl(reg):
self._log(f'cleanup for {var} in {reg}')
self._synth(Store(self.cur_fun, [reg, var]))
return cl
for i, param in enumerate(fun.params): for i, param in enumerate(fun.params):
fun.locals[param.name] = param fun.locals[param.name] = param