cc: mostly fixed function calls

This commit is contained in:
Paul Mathieu 2021-03-20 21:27:26 -07:00
parent eb303641d9
commit ef81ec3b12

View File

@ -4,6 +4,7 @@ import collections
import contextlib
from dataclasses import dataclass
import importlib
import inspect
import io
import re
import struct
@ -34,6 +35,9 @@ class Variable:
def __repr__(self):
return f'<Var: {self.type} {self.name}>'
def __str__(self):
return self.name
@classmethod
def from_def(cls, tree):
type = tree.children[0]
@ -76,21 +80,42 @@ class Variable:
print(tree)
def allregs():
return (f'r{i}' for i in range(12))
class RegBank:
def __init__(self, logger=None):
self.reset()
self.log = logger or print
def reset(self):
self.available = [f'r{i}' for i in range(12)]
self.available = list(allregs())
self.vars = {}
self.varregs = {}
self.cleanup = collections.defaultdict(list)
self.taken = {}
def take(self, reg=None):
def __str__(self):
return f'regs(avail={self.available}, vars={self.vars})'
def take(self, reg=None, orwhatever=False):
out = self._take(reg=reg, orwhatever=orwhatever)
self.taken[out] = inspect.stack()[1]
return out
def _take(self, reg=None, orwhatever=False):
if reg is not None:
if reg not in self.available:
if reg in self.varregs:
self.evict(self.var(reg))
elif orwhatever:
return self._take()
else:
raise RuntimeError(
f'trying to overtake {reg}, '
f'varregs: {self.varregs}, vars: {self.vars}, '
f'taken in {self.taken[reg]}')
if reg in self.available:
self.available.remove(reg)
return reg
@ -102,6 +127,7 @@ class RegBank:
return self.available.pop(0)
def give(self, reg):
assert reg is not None
if reg in self.varregs:
# Need to call evict() with the var to free it.
return
@ -109,11 +135,13 @@ class RegBank:
def loaded(self, var, reg, cleanup=None):
"""Tells the regbank some variable was loaded into the given register."""
assert reg is not None
assert var is not None
self.vars[var] = reg
self.varregs[reg] = var
self.take(reg)
if cleanup is not None:
self.log(f'recording cleanup for {reg}({var.name})')
self.log(f'recording cleanup for {reg}({var})')
self.cleanup[reg].append(cleanup)
def stored(self, var):
@ -136,6 +164,7 @@ class RegBank:
def assign(self, var, reg, cleanup=None):
"""Assign a previously-used register to a variable."""
assert reg is not None
if var in self.vars:
self.stored(var)
if reg in self.varregs:
@ -150,21 +179,25 @@ class RegBank:
def var(self, reg):
return self.varregs.get(reg, None)
def evict_all(self):
for var in list(self.vars):
self.evict(var)
def evict(self, var):
"""Runs var callbacks & frees the register."""
assert var in self.vars, f'trying to evict {var}'
self.log(f'evicting {var}')
if var not in self.vars:
return
reg = self.vars[var]
for cb in self.cleanup.pop(var, []):
for cb in self.cleanup.pop(reg, []):
cb(reg)
self.stored(var)
def flush_all(self):
for reg in list(self.cleanup):
self.log(f'flushing {reg}({self.varregs[reg].name})')
for cb in self.cleanup.pop(reg):
cb(reg)
self.reset()
def drop_temps(self):
for reg in allregs():
if reg not in self.available and reg not in self.varregs:
self.give(reg)
def swapped(self, reg0, reg1):
var0 = self.varregs.get(reg0, None)
@ -192,6 +225,7 @@ class RegBank:
self.reset()
vardict, cleanup = snap
for var, reg in vardict.items():
assert reg is not None, f'bad snapshot: {snap}'
self.loaded(var, reg)
self.cleanup[reg] = cleanup.get(reg, [])
@ -225,6 +259,7 @@ class Function:
self.params = [Variable(*x.children) for x in fun_prot.children[2:]]
self.ret = None
self.nextstack = 0
self.topstack = 0
self.statements = []
self.regs = RegBank(logger=self.log)
self.deferred_ops = []
@ -236,13 +271,24 @@ class Function:
@property
def stack_usage(self):
return self.nextstack + 2
top_usage = max(self.topstack, self.nextstack)
if self.fun_calls > 0:
return top_usage + 2
return top_usage
def get_stack(self, size=2):
stk = self.nextstack
self.nextstack += size
return stk
@contextlib.contextmanager
def tempstack(self):
stk = self.nextstack
yield
if self.nextstack > self.topstack:
self.topstack = self.nextstack
self.nextstack = stk
def param_dict(self):
return {p.name: p for p in self.params}
@ -250,14 +296,12 @@ class Function:
return repr(self.spec)
def synth(self):
ops = []
if self.fun_calls > 0:
preamble = [f'store lr, [sp, -2]',
f'set r4, {self.stack_usage}',
ops += [f'store lr, [sp, -2]']
if self.stack_usage > 0:
ops += [f'set r4, {self.stack_usage}',
f'sub sp, sp, r4']
else:
preamble = []
ops = preamble
for op in self.ops:
ops += op()
indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in ops]
@ -524,17 +568,22 @@ class ReturnReg(AsmOp):
(self.ret_reg,) = ops
def synth(self, scratches):
if self.fun.fun_calls == 0:
return [f'or r0, {self.ret_reg}, {self.ret_reg}',
f'or pc, lr, lr']
sc0 = scratches[0]
ops = []
stack_usage = self.fun.stack_usage
(sc0,) = scratches
if stack_usage > 0:
ops += [f'set {sc0}, {stack_usage}',
f'add sp, sp, {sc0}']
if self.fun.ret is not None:
ret = self.ret_reg
assert stack_usage < 255
return [f'set {sc0}, {stack_usage}',
f'add sp, sp, {sc0}',
f'or r0, {ret}, {ret}',
f'load pc, [sp, -2] // return']
ops += [f'or r0, {ret}, {ret}']
if self.fun.fun_calls == 0:
ops += [f'or pc, lr, lr']
else:
ops += [f'load pc, [sp, -2] // return']
return ops
class Load(AsmOp):
@ -574,6 +623,29 @@ class Load(AsmOp):
f'nop // in case we load a far global',
f'load {reg}, [{reg}]'] # unsure if we need a cast here
class Push(AsmOp):
def __init__(self, reg, stackaddr):
self.reg = reg
self.stackaddr = stackaddr
def synth(self, scratches):
return [f'store {self.reg}, [sp, {self.stackaddr}]']
class Pop(AsmOp):
def __init__(self, reg, stackaddr):
self.reg = reg
self.stackaddr = stackaddr
@property
def out(self):
return self.reg
def synth(self, scratches):
return [f'load {self.reg}, [sp, {self.stackaddr}]']
class Store(AsmOp):
scratch_need = 1
@ -581,10 +653,6 @@ class Store(AsmOp):
self.fun = fun
self.src, self.var = ops
@property
def out(self):
return None
def synth(self, scratches):
(sc,) = scratches
reg = self.src
@ -820,11 +888,11 @@ class CcInterp(lark.visitors.Interpreter):
self._log(f'{op.__class__.__name__}')
self.cur_fun.ops.append(lambda: op.synth(scratches))
def _load(self, ident):
def _load(self, ident, reg=None):
s = self._lookup_symbol(ident)
assert s is not None, f'unknown identifier {ident}'
if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
reg = self._get_reg()
reg = reg or self._get_reg()
return SetAddr(self.cur_fun, [reg, ident])
else:
if s.type.volatile:
@ -845,7 +913,8 @@ class CcInterp(lark.visitors.Interpreter):
# If it's volatile, we need to flush it.
def delayed_load():
self._log(f'delay-loading {tree.children[0]}')
op = self._load(tree.children[0])
reg = getattr(tree, 'request_out', None)
op = self._load(tree.children[0], reg=reg)
self._synth(op)
return op.out
@ -868,10 +937,17 @@ class CcInterp(lark.visitors.Interpreter):
tok = self._make_string(imm)
tree.children = [tok]
return self.identifier(tree)
reg = self._get_reg()
assert self.cur_fun is not None
tree.op = Set16Imm(self.cur_fun, [reg, imm])
self._synth(tree.op)
def delayed():
if hasattr(tree, 'request_out'):
reg = self.cur_fun.regs.take(reg=tree.request_out, orwhatever=True)
else:
reg = self.cur_fun.regs.take()
assert reg is not None, f'weird: {self.cur_fun.regs}'
op = Set16Imm(self.cur_fun, [reg, imm])
self._synth(op)
return op.out
tree.op = Delayed(delayed)
tree.type = litt_type(tree.children[0])
def _assign(self, left, right):
@ -884,6 +960,7 @@ class CcInterp(lark.visitors.Interpreter):
def assignment(self, tree):
self.visit_children(tree)
val = tree.children[1].op.out
assert val is not None, f'no output!!! {tree.pretty()}'
# left hand side is an lvalue, retrieve it
var = tree.children[0].var
@ -916,7 +993,7 @@ class CcInterp(lark.visitors.Interpreter):
for i, r in enumerate(dst):
if r not in src:
self.cur_fun.regs.take(r) # precious!
#self.cur_fun.regs.take(r) # precious!
self._synth(Move(r, src[i]))
else:
swap(r, src[i])
@ -924,46 +1001,92 @@ class CcInterp(lark.visitors.Interpreter):
def _prep_fun_call(self, fn_reg, params):
"""Move all params to r0-rn."""
target_regs = [f'r{i}' for i in range(len(params))]
new_fn = fn_reg
while new_fn in target_regs or new_fn in params:
new_fn = self.cur_fun.regs.take()
if new_fn != fn_reg:
self._synth(Move(new_fn, fn_reg))
fn_reg = new_fn
self._reg_shuffle(params, target_regs)
return fn_reg
target_regs = [f'r{i}' for i in range(len(params) + 1)]
self._reg_shuffle(params + [fn_reg], target_regs)
return target_regs[-1]
@contextlib.contextmanager
def _push(self, dontpush=()):
pushed = []
self.cur_fun.regs.evict_all()
with self.cur_fun.tempstack():
for reg in allregs():
if reg not in self.cur_fun.regs.available and reg not in dontpush:
self._log(f'reg in use: {reg}')
stk = self.cur_fun.get_stack()
pushed.append((reg, stk))
self._synth(Push(reg, stk))
self.cur_fun.regs.give(reg)
yield pushed
def fun_call(self, tree):
self.cur_fun.fun_calls += 1
if len(tree.children) == 1:
param_children = []
elif tree.children[1].data == 'comma_expr':
param_children = tree.children[1].children
else:
param_children = [tree.children[1]]
if True:
# pre-allocate output registers
for i, child in enumerate(param_children):
self._log(f'requesting r{i} for {child}')
child.request_out = f'r{i}'
self._log(f'requesting r{len(param_children)} for {tree.children[0]}')
tree.children[0].request_out = f'r{len(param_children)}'
self.visit_children(tree)
fn_reg = tree.children[0].op.out
if len(tree.children) == 1:
param_regs = []
elif tree.children[1].data == 'comma_expr':
param_regs = [c.op.out for c in tree.children[1].children]
else:
param_regs = [tree.children[1].op.out]
param_regs = [c.op.out for c in param_children]
self._log(f'calling {tree.children[0].children[0]}({param_regs})')
self.cur_fun.regs.flush_all()
for i in range(len(param_regs)):
self.cur_fun.regs.take(f'r{i}')
with self._push(dontpush=param_regs + [fn_reg]) as pushed:
self._flush_deferred() # side effects
fn_reg = self._prep_fun_call(fn_reg, param_regs)
self.cur_fun.regs.reset()
self.cur_fun.regs.take(fn_reg)
for reg in (f'r{i}' for i in range(len(param_regs))):
self.cur_fun.regs.take(reg)
tree.op = FnCall(self.cur_fun, [fn_reg, param_regs])
self._synth(tree.op)
self.cur_fun.regs.reset()
self.cur_fun.regs.take('r0')
pops = []
for reg, stk in pushed:
self.cur_fun.regs.take(reg)
pops.append(Pop(reg, stk))
ret = 'r0'
if hasattr(tree, 'request_out'):
ret = tree.request_out
self._synth(Move(ret, 'r0'))
nret = self.cur_fun.regs.take(ret, orwhatever=True)
if nret != ret:
self._synth(Move(nret, ret))
tree.op = type('blah', (), {'out': nret})
for op in pops:
self._synth(op)
tree.type = return_type(tree.children[0].type)
def statement(self, tree):
self.visit_children(tree)
def _flush_deferred(self):
if self.cur_fun.deferred_ops:
self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}')
for op in self.cur_fun.deferred_ops:
self._synth(op)
self.cur_fun.deferred_ops = []
def statement(self, tree):
self.visit_children(tree)
self._flush_deferred()
self.cur_fun.regs.drop_temps()
iter_expression = statement
def if_stat(self, tree):
@ -978,6 +1101,7 @@ class CcInterp(lark.visitors.Interpreter):
self.visit(tree.children[0])
op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
self._synth(op)
self.cur_fun.regs.drop_temps()
begin_vars = self.cur_fun.regs.snap()
self.visit(tree.children[1])
if has_else:
@ -1009,6 +1133,7 @@ class CcInterp(lark.visitors.Interpreter):
postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[0].op.out, mark)
self._synth(op)
self.cur_fun.regs.drop_temps()
self.visit(tree.children[1])
self._restore_vars(begin_vars)
self.cur_fun.ops.append(op.synth_endwhile)
@ -1023,6 +1148,7 @@ class CcInterp(lark.visitors.Interpreter):
postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[1].op.out, mark)
self._synth(op)
self.cur_fun.regs.drop_temps()
self.visit(tree.children[3])
self.visit(tree.children[2]) # 3rd statement
self._restore_vars(begin_vars)
@ -1093,7 +1219,14 @@ class CcInterp(lark.visitors.Interpreter):
self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
tree.type = var.type
pre_increment = _unary_op(Incr)
def pre_increment(self, tree):
self.visit_children(tree)
reg = tree.children[0].op.out
tree.op = Incr(self.cur_fun, [reg, reg])
self._synth(tree.op)
cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
self.cur_fun.regs.cleanup[reg].append(cleanup)
bool_not = _unary_op(BoolNot)
def _binary_op(op):
@ -1152,6 +1285,8 @@ class CcInterp(lark.visitors.Interpreter):
tree.op = tree.children[0].op
def cast(self, tree):
if hasattr(tree, 'request_out'):
tree.children[1].request_out = tree.request_out
self.visit_children(tree)
tree.op = tree.children[1].op
tree.type = tree.children[0]
@ -1184,7 +1319,10 @@ class CcInterp(lark.visitors.Interpreter):
params = fun.param_dict()
def getcleanup(var):
return lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
def cl(reg):
self._log(f'cleanup for {var} in {reg}')
self._synth(Store(self.cur_fun, [reg, var]))
return cl
for i, param in enumerate(fun.params):
fun.locals[param.name] = param