cc: mostly fixed function calls
This commit is contained in:
parent
eb303641d9
commit
ef81ec3b12
260
tools/cc.py
260
tools/cc.py
@ -4,6 +4,7 @@ import collections
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
import importlib
|
||||
import inspect
|
||||
import io
|
||||
import re
|
||||
import struct
|
||||
@ -34,6 +35,9 @@ class Variable:
|
||||
def __repr__(self):
|
||||
return f'<Var: {self.type} {self.name}>'
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@classmethod
|
||||
def from_def(cls, tree):
|
||||
type = tree.children[0]
|
||||
@ -76,21 +80,42 @@ class Variable:
|
||||
print(tree)
|
||||
|
||||
|
||||
def allregs():
|
||||
return (f'r{i}' for i in range(12))
|
||||
|
||||
|
||||
class RegBank:
|
||||
def __init__(self, logger=None):
|
||||
self.reset()
|
||||
self.log = logger or print
|
||||
|
||||
def reset(self):
|
||||
self.available = [f'r{i}' for i in range(12)]
|
||||
self.available = list(allregs())
|
||||
self.vars = {}
|
||||
self.varregs = {}
|
||||
self.cleanup = collections.defaultdict(list)
|
||||
self.taken = {}
|
||||
|
||||
def take(self, reg=None):
|
||||
def __str__(self):
|
||||
return f'regs(avail={self.available}, vars={self.vars})'
|
||||
|
||||
def take(self, reg=None, orwhatever=False):
|
||||
out = self._take(reg=reg, orwhatever=orwhatever)
|
||||
self.taken[out] = inspect.stack()[1]
|
||||
return out
|
||||
|
||||
def _take(self, reg=None, orwhatever=False):
|
||||
if reg is not None:
|
||||
if reg not in self.available:
|
||||
if reg in self.varregs:
|
||||
self.evict(self.var(reg))
|
||||
elif orwhatever:
|
||||
return self._take()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'trying to overtake {reg}, '
|
||||
f'varregs: {self.varregs}, vars: {self.vars}, '
|
||||
f'taken in {self.taken[reg]}')
|
||||
if reg in self.available:
|
||||
self.available.remove(reg)
|
||||
return reg
|
||||
@ -102,6 +127,7 @@ class RegBank:
|
||||
return self.available.pop(0)
|
||||
|
||||
def give(self, reg):
|
||||
assert reg is not None
|
||||
if reg in self.varregs:
|
||||
# Need to call evict() with the var to free it.
|
||||
return
|
||||
@ -109,11 +135,13 @@ class RegBank:
|
||||
|
||||
def loaded(self, var, reg, cleanup=None):
|
||||
"""Tells the regbank some variable was loaded into the given register."""
|
||||
assert reg is not None
|
||||
assert var is not None
|
||||
self.vars[var] = reg
|
||||
self.varregs[reg] = var
|
||||
self.take(reg)
|
||||
if cleanup is not None:
|
||||
self.log(f'recording cleanup for {reg}({var.name})')
|
||||
self.log(f'recording cleanup for {reg}({var})')
|
||||
self.cleanup[reg].append(cleanup)
|
||||
|
||||
def stored(self, var):
|
||||
@ -136,6 +164,7 @@ class RegBank:
|
||||
|
||||
def assign(self, var, reg, cleanup=None):
|
||||
"""Assign a previously-used register to a variable."""
|
||||
assert reg is not None
|
||||
if var in self.vars:
|
||||
self.stored(var)
|
||||
if reg in self.varregs:
|
||||
@ -150,21 +179,25 @@ class RegBank:
|
||||
def var(self, reg):
|
||||
return self.varregs.get(reg, None)
|
||||
|
||||
def evict_all(self):
|
||||
for var in list(self.vars):
|
||||
self.evict(var)
|
||||
|
||||
def evict(self, var):
|
||||
"""Runs var callbacks & frees the register."""
|
||||
assert var in self.vars, f'trying to evict {var}'
|
||||
self.log(f'evicting {var}')
|
||||
if var not in self.vars:
|
||||
return
|
||||
reg = self.vars[var]
|
||||
for cb in self.cleanup.pop(var, []):
|
||||
for cb in self.cleanup.pop(reg, []):
|
||||
cb(reg)
|
||||
self.stored(var)
|
||||
|
||||
def flush_all(self):
|
||||
for reg in list(self.cleanup):
|
||||
self.log(f'flushing {reg}({self.varregs[reg].name})')
|
||||
for cb in self.cleanup.pop(reg):
|
||||
cb(reg)
|
||||
self.reset()
|
||||
def drop_temps(self):
|
||||
for reg in allregs():
|
||||
if reg not in self.available and reg not in self.varregs:
|
||||
self.give(reg)
|
||||
|
||||
def swapped(self, reg0, reg1):
|
||||
var0 = self.varregs.get(reg0, None)
|
||||
@ -192,6 +225,7 @@ class RegBank:
|
||||
self.reset()
|
||||
vardict, cleanup = snap
|
||||
for var, reg in vardict.items():
|
||||
assert reg is not None, f'bad snapshot: {snap}'
|
||||
self.loaded(var, reg)
|
||||
self.cleanup[reg] = cleanup.get(reg, [])
|
||||
|
||||
@ -225,6 +259,7 @@ class Function:
|
||||
self.params = [Variable(*x.children) for x in fun_prot.children[2:]]
|
||||
self.ret = None
|
||||
self.nextstack = 0
|
||||
self.topstack = 0
|
||||
self.statements = []
|
||||
self.regs = RegBank(logger=self.log)
|
||||
self.deferred_ops = []
|
||||
@ -236,13 +271,24 @@ class Function:
|
||||
|
||||
@property
|
||||
def stack_usage(self):
|
||||
return self.nextstack + 2
|
||||
top_usage = max(self.topstack, self.nextstack)
|
||||
if self.fun_calls > 0:
|
||||
return top_usage + 2
|
||||
return top_usage
|
||||
|
||||
def get_stack(self, size=2):
|
||||
stk = self.nextstack
|
||||
self.nextstack += size
|
||||
return stk
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tempstack(self):
|
||||
stk = self.nextstack
|
||||
yield
|
||||
if self.nextstack > self.topstack:
|
||||
self.topstack = self.nextstack
|
||||
self.nextstack = stk
|
||||
|
||||
def param_dict(self):
|
||||
return {p.name: p for p in self.params}
|
||||
|
||||
@ -250,14 +296,12 @@ class Function:
|
||||
return repr(self.spec)
|
||||
|
||||
def synth(self):
|
||||
ops = []
|
||||
if self.fun_calls > 0:
|
||||
preamble = [f'store lr, [sp, -2]',
|
||||
f'set r4, {self.stack_usage}',
|
||||
ops += [f'store lr, [sp, -2]']
|
||||
if self.stack_usage > 0:
|
||||
ops += [f'set r4, {self.stack_usage}',
|
||||
f'sub sp, sp, r4']
|
||||
else:
|
||||
preamble = []
|
||||
|
||||
ops = preamble
|
||||
for op in self.ops:
|
||||
ops += op()
|
||||
indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in ops]
|
||||
@ -524,17 +568,22 @@ class ReturnReg(AsmOp):
|
||||
(self.ret_reg,) = ops
|
||||
|
||||
def synth(self, scratches):
|
||||
if self.fun.fun_calls == 0:
|
||||
return [f'or r0, {self.ret_reg}, {self.ret_reg}',
|
||||
f'or pc, lr, lr']
|
||||
sc0 = scratches[0]
|
||||
ops = []
|
||||
stack_usage = self.fun.stack_usage
|
||||
(sc0,) = scratches
|
||||
if stack_usage > 0:
|
||||
ops += [f'set {sc0}, {stack_usage}',
|
||||
f'add sp, sp, {sc0}']
|
||||
if self.fun.ret is not None:
|
||||
ret = self.ret_reg
|
||||
assert stack_usage < 255
|
||||
return [f'set {sc0}, {stack_usage}',
|
||||
f'add sp, sp, {sc0}',
|
||||
f'or r0, {ret}, {ret}',
|
||||
f'load pc, [sp, -2] // return']
|
||||
ops += [f'or r0, {ret}, {ret}']
|
||||
|
||||
if self.fun.fun_calls == 0:
|
||||
ops += [f'or pc, lr, lr']
|
||||
else:
|
||||
ops += [f'load pc, [sp, -2] // return']
|
||||
|
||||
return ops
|
||||
|
||||
|
||||
class Load(AsmOp):
|
||||
@ -574,6 +623,29 @@ class Load(AsmOp):
|
||||
f'nop // in case we load a far global',
|
||||
f'load {reg}, [{reg}]'] # unsure if we need a cast here
|
||||
|
||||
|
||||
class Push(AsmOp):
|
||||
def __init__(self, reg, stackaddr):
|
||||
self.reg = reg
|
||||
self.stackaddr = stackaddr
|
||||
|
||||
def synth(self, scratches):
|
||||
return [f'store {self.reg}, [sp, {self.stackaddr}]']
|
||||
|
||||
|
||||
class Pop(AsmOp):
|
||||
def __init__(self, reg, stackaddr):
|
||||
self.reg = reg
|
||||
self.stackaddr = stackaddr
|
||||
|
||||
@property
|
||||
def out(self):
|
||||
return self.reg
|
||||
|
||||
def synth(self, scratches):
|
||||
return [f'load {self.reg}, [sp, {self.stackaddr}]']
|
||||
|
||||
|
||||
class Store(AsmOp):
|
||||
scratch_need = 1
|
||||
|
||||
@ -581,10 +653,6 @@ class Store(AsmOp):
|
||||
self.fun = fun
|
||||
self.src, self.var = ops
|
||||
|
||||
@property
|
||||
def out(self):
|
||||
return None
|
||||
|
||||
def synth(self, scratches):
|
||||
(sc,) = scratches
|
||||
reg = self.src
|
||||
@ -820,11 +888,11 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
self._log(f'{op.__class__.__name__}')
|
||||
self.cur_fun.ops.append(lambda: op.synth(scratches))
|
||||
|
||||
def _load(self, ident):
|
||||
def _load(self, ident, reg=None):
|
||||
s = self._lookup_symbol(ident)
|
||||
assert s is not None, f'unknown identifier {ident}'
|
||||
if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
|
||||
reg = self._get_reg()
|
||||
reg = reg or self._get_reg()
|
||||
return SetAddr(self.cur_fun, [reg, ident])
|
||||
else:
|
||||
if s.type.volatile:
|
||||
@ -845,7 +913,8 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
# If it's volatile, we need to flush it.
|
||||
def delayed_load():
|
||||
self._log(f'delay-loading {tree.children[0]}')
|
||||
op = self._load(tree.children[0])
|
||||
reg = getattr(tree, 'request_out', None)
|
||||
op = self._load(tree.children[0], reg=reg)
|
||||
self._synth(op)
|
||||
return op.out
|
||||
|
||||
@ -868,10 +937,17 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
tok = self._make_string(imm)
|
||||
tree.children = [tok]
|
||||
return self.identifier(tree)
|
||||
reg = self._get_reg()
|
||||
assert self.cur_fun is not None
|
||||
tree.op = Set16Imm(self.cur_fun, [reg, imm])
|
||||
self._synth(tree.op)
|
||||
def delayed():
|
||||
if hasattr(tree, 'request_out'):
|
||||
reg = self.cur_fun.regs.take(reg=tree.request_out, orwhatever=True)
|
||||
else:
|
||||
reg = self.cur_fun.regs.take()
|
||||
assert reg is not None, f'weird: {self.cur_fun.regs}'
|
||||
op = Set16Imm(self.cur_fun, [reg, imm])
|
||||
self._synth(op)
|
||||
return op.out
|
||||
|
||||
tree.op = Delayed(delayed)
|
||||
tree.type = litt_type(tree.children[0])
|
||||
|
||||
def _assign(self, left, right):
|
||||
@ -884,6 +960,7 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
def assignment(self, tree):
|
||||
self.visit_children(tree)
|
||||
val = tree.children[1].op.out
|
||||
assert val is not None, f'no output!!! {tree.pretty()}'
|
||||
|
||||
# left hand side is an lvalue, retrieve it
|
||||
var = tree.children[0].var
|
||||
@ -916,7 +993,7 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
|
||||
for i, r in enumerate(dst):
|
||||
if r not in src:
|
||||
self.cur_fun.regs.take(r) # precious!
|
||||
#self.cur_fun.regs.take(r) # precious!
|
||||
self._synth(Move(r, src[i]))
|
||||
else:
|
||||
swap(r, src[i])
|
||||
@ -924,46 +1001,92 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
|
||||
def _prep_fun_call(self, fn_reg, params):
|
||||
"""Move all params to r0-rn."""
|
||||
target_regs = [f'r{i}' for i in range(len(params))]
|
||||
new_fn = fn_reg
|
||||
while new_fn in target_regs or new_fn in params:
|
||||
new_fn = self.cur_fun.regs.take()
|
||||
if new_fn != fn_reg:
|
||||
self._synth(Move(new_fn, fn_reg))
|
||||
fn_reg = new_fn
|
||||
self._reg_shuffle(params, target_regs)
|
||||
return fn_reg
|
||||
target_regs = [f'r{i}' for i in range(len(params) + 1)]
|
||||
self._reg_shuffle(params + [fn_reg], target_regs)
|
||||
return target_regs[-1]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _push(self, dontpush=()):
|
||||
pushed = []
|
||||
self.cur_fun.regs.evict_all()
|
||||
with self.cur_fun.tempstack():
|
||||
for reg in allregs():
|
||||
if reg not in self.cur_fun.regs.available and reg not in dontpush:
|
||||
self._log(f'reg in use: {reg}')
|
||||
stk = self.cur_fun.get_stack()
|
||||
pushed.append((reg, stk))
|
||||
self._synth(Push(reg, stk))
|
||||
self.cur_fun.regs.give(reg)
|
||||
yield pushed
|
||||
|
||||
def fun_call(self, tree):
|
||||
self.cur_fun.fun_calls += 1
|
||||
if len(tree.children) == 1:
|
||||
param_children = []
|
||||
elif tree.children[1].data == 'comma_expr':
|
||||
param_children = tree.children[1].children
|
||||
else:
|
||||
param_children = [tree.children[1]]
|
||||
|
||||
|
||||
if True:
|
||||
# pre-allocate output registers
|
||||
for i, child in enumerate(param_children):
|
||||
self._log(f'requesting r{i} for {child}')
|
||||
child.request_out = f'r{i}'
|
||||
self._log(f'requesting r{len(param_children)} for {tree.children[0]}')
|
||||
tree.children[0].request_out = f'r{len(param_children)}'
|
||||
self.visit_children(tree)
|
||||
fn_reg = tree.children[0].op.out
|
||||
if len(tree.children) == 1:
|
||||
param_regs = []
|
||||
elif tree.children[1].data == 'comma_expr':
|
||||
param_regs = [c.op.out for c in tree.children[1].children]
|
||||
else:
|
||||
param_regs = [tree.children[1].op.out]
|
||||
param_regs = [c.op.out for c in param_children]
|
||||
self._log(f'calling {tree.children[0].children[0]}({param_regs})')
|
||||
self.cur_fun.regs.flush_all()
|
||||
for i in range(len(param_regs)):
|
||||
self.cur_fun.regs.take(f'r{i}')
|
||||
|
||||
with self._push(dontpush=param_regs + [fn_reg]) as pushed:
|
||||
|
||||
self._flush_deferred() # side effects
|
||||
fn_reg = self._prep_fun_call(fn_reg, param_regs)
|
||||
self.cur_fun.regs.reset()
|
||||
self.cur_fun.regs.take(fn_reg)
|
||||
for reg in (f'r{i}' for i in range(len(param_regs))):
|
||||
self.cur_fun.regs.take(reg)
|
||||
|
||||
tree.op = FnCall(self.cur_fun, [fn_reg, param_regs])
|
||||
self._synth(tree.op)
|
||||
|
||||
self.cur_fun.regs.reset()
|
||||
self.cur_fun.regs.take('r0')
|
||||
|
||||
pops = []
|
||||
for reg, stk in pushed:
|
||||
self.cur_fun.regs.take(reg)
|
||||
pops.append(Pop(reg, stk))
|
||||
|
||||
ret = 'r0'
|
||||
if hasattr(tree, 'request_out'):
|
||||
ret = tree.request_out
|
||||
self._synth(Move(ret, 'r0'))
|
||||
|
||||
nret = self.cur_fun.regs.take(ret, orwhatever=True)
|
||||
if nret != ret:
|
||||
self._synth(Move(nret, ret))
|
||||
tree.op = type('blah', (), {'out': nret})
|
||||
|
||||
for op in pops:
|
||||
self._synth(op)
|
||||
|
||||
tree.type = return_type(tree.children[0].type)
|
||||
|
||||
def statement(self, tree):
|
||||
self.visit_children(tree)
|
||||
def _flush_deferred(self):
|
||||
if self.cur_fun.deferred_ops:
|
||||
self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}')
|
||||
for op in self.cur_fun.deferred_ops:
|
||||
self._synth(op)
|
||||
self.cur_fun.deferred_ops = []
|
||||
|
||||
def statement(self, tree):
|
||||
self.visit_children(tree)
|
||||
self._flush_deferred()
|
||||
self.cur_fun.regs.drop_temps()
|
||||
|
||||
iter_expression = statement
|
||||
|
||||
def if_stat(self, tree):
|
||||
@ -978,6 +1101,7 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
self.visit(tree.children[0])
|
||||
op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
|
||||
self._synth(op)
|
||||
self.cur_fun.regs.drop_temps()
|
||||
begin_vars = self.cur_fun.regs.snap()
|
||||
self.visit(tree.children[1])
|
||||
if has_else:
|
||||
@ -1009,6 +1133,7 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
postcond = self.cur_fun.regs.snap()
|
||||
op = WhileOp(tree.children[0].op.out, mark)
|
||||
self._synth(op)
|
||||
self.cur_fun.regs.drop_temps()
|
||||
self.visit(tree.children[1])
|
||||
self._restore_vars(begin_vars)
|
||||
self.cur_fun.ops.append(op.synth_endwhile)
|
||||
@ -1023,6 +1148,7 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
postcond = self.cur_fun.regs.snap()
|
||||
op = WhileOp(tree.children[1].op.out, mark)
|
||||
self._synth(op)
|
||||
self.cur_fun.regs.drop_temps()
|
||||
self.visit(tree.children[3])
|
||||
self.visit(tree.children[2]) # 3rd statement
|
||||
self._restore_vars(begin_vars)
|
||||
@ -1093,7 +1219,14 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
|
||||
tree.type = var.type
|
||||
|
||||
pre_increment = _unary_op(Incr)
|
||||
def pre_increment(self, tree):
|
||||
self.visit_children(tree)
|
||||
reg = tree.children[0].op.out
|
||||
tree.op = Incr(self.cur_fun, [reg, reg])
|
||||
self._synth(tree.op)
|
||||
cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
|
||||
self.cur_fun.regs.cleanup[reg].append(cleanup)
|
||||
|
||||
bool_not = _unary_op(BoolNot)
|
||||
|
||||
def _binary_op(op):
|
||||
@ -1152,6 +1285,8 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
tree.op = tree.children[0].op
|
||||
|
||||
def cast(self, tree):
|
||||
if hasattr(tree, 'request_out'):
|
||||
tree.children[1].request_out = tree.request_out
|
||||
self.visit_children(tree)
|
||||
tree.op = tree.children[1].op
|
||||
tree.type = tree.children[0]
|
||||
@ -1184,7 +1319,10 @@ class CcInterp(lark.visitors.Interpreter):
|
||||
params = fun.param_dict()
|
||||
|
||||
def getcleanup(var):
|
||||
return lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
|
||||
def cl(reg):
|
||||
self._log(f'cleanup for {var} in {reg}')
|
||||
self._synth(Store(self.cur_fun, [reg, var]))
|
||||
return cl
|
||||
|
||||
for i, param in enumerate(fun.params):
|
||||
fun.locals[param.name] = param
|
||||
|
Loading…
Reference in New Issue
Block a user