diff --git a/tools/cc.py b/tools/cc.py index 07a8576..b522fd9 100644 --- a/tools/cc.py +++ b/tools/cc.py @@ -1,10 +1,12 @@ import argparse +import ast import collections import contextlib from dataclasses import dataclass import importlib import io import re +import struct import subprocess import sys @@ -89,7 +91,8 @@ class RegBank: if reg is not None: if reg not in self.available: self.evict(self.var(reg)) - self.available.remove(reg) + if reg in self.available: + self.available.remove(reg) return reg if not self.available: assert self.vars, "nothing to clean, no more regs :/" @@ -189,7 +192,8 @@ class RegBank: self.reset() vardict, cleanup = snap for var, reg in vardict.items(): - self.loaded(var, reg, cleanup=cleanup.get(reg, None)) + self.loaded(var, reg) + self.cleanup[reg] = cleanup.get(reg, []) @contextlib.contextmanager def borrow(self, howmany): @@ -328,6 +332,13 @@ class CcTransform(lark.visitors.Transformer): pointed = children[-1] return PointerType(volatile=volat, pointed=pointed) + def comma_expr(self, children): + c1, c2 = children + if c1.data == 'comma_expr': + c1.children.append(c2) + return c1 + return lark.Tree('comma_expr', [c1, c2]) + volatile = lambda *_: 'volatile' const = lambda *_: 'const' @@ -359,6 +370,7 @@ class CcTransform(lark.visitors.Transformer): IDENTIFIER = str SIGNED_NUMBER = int HEX_LITTERAL = lambda _, x: int(x[2:], 16) + ESCAPED_STRING = lambda _, x: ast.literal_eval(x) class AsmOp: @@ -624,6 +636,16 @@ class Set16Imm(AsmOp): else: return [f'set {reg}, {lo}'] +class Move(AsmOp): + scratch_need = 0 + + def __init__(self, dst, src): + self.dst = dst + self.src = src + + def synth(self, scratches): + return [f'or {self.dst}, {self.src}, {self.src}'] + class Swap(AsmOp): scratch_need = 1 @@ -669,6 +691,27 @@ class IfOp(AsmOp): return [f'{self.endif_mark}:'] +class IfEq(IfOp): + scratch_need = 0 + + def __init__(self, a, b, mark, has_else): + self.a = a + self.b = b + self.has_else = has_else + + self.then_mark = f'_then_{mark}' + self.else_mark = f'_else_{mark}' + self.endif_mark = f'_endif_{mark}' + + def synth(self, scratches): + if self.has_else: + return [f'cmp {self.a}, {self.b}', + f'bneq {self.else_mark}'] + else: + return [f'cmp {self.a}, {self.b}', + f'bneq {self.endif_mark}'] + + class WhileOp(AsmOp): scratch_need = 1 @@ -731,6 +774,8 @@ class CcInterp(lark.visitors.Interpreter): self.funs = [] self.next_reg = 0 self.next_marker = 0 + self.strings = {} + self.next_string_token = 0 def _lookup_symbol(self, s): scope = self.cur_scope @@ -764,7 +809,7 @@ class CcInterp(lark.visitors.Interpreter): def _load(self, ident): s = self._lookup_symbol(ident) assert s is not None, f'unknown identifier {ident}' - if isinstance(s, FunctionSpec) or s in self.global_scope.symbols: + if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols: reg = self._get_reg() return SetAddr(self.cur_fun, [reg, ident]) else: @@ -795,8 +840,20 @@ class CcInterp(lark.visitors.Interpreter): assert tree.var is not None, f'undefined identifier: {tree.children[0]}' tree.type = tree.var.type + def _make_string(self, s): + nexttok = self.next_string_token + self.next_string_token += 1 + token = f'_str{nexttok}' + self.strings[token] = s + self.global_scope.symbols[token] = Variable(litt_type(s), token) + return token + def litteral(self, tree): imm = tree.children[0] + if isinstance(imm, str): + tok = self._make_string(imm) + tree.children = [tok] + return self.identifier(tree) reg = self._get_reg() assert self.cur_fun is not None tree.op = Set16Imm(self.cur_fun, [reg, imm]) @@ -830,27 +887,37 @@ class CcInterp(lark.visitors.Interpreter): fun = FunctionSpec(tree.children[0]) self.cur_scope.symbols[fun.name] = fun - def _prep_fun_call(self, fn_reg, params): - """Move all params to r0-rn.""" + def _reg_shuffle(self, src, dst): def swap(old, new): if old == new: return - oldpos = params.index(old) + oldpos = src.index(old) try: - newpos = params.index(new) + newpos = src.index(new) except ValueError: - params[oldpos] = new + src[oldpos] = new else: - params[newpos], params[oldpos] = params[oldpos], params[newpos] + src[newpos], src[oldpos] = src[oldpos], src[newpos] self._synth(Swap(old, new)) - if fn_reg in [f'r{i}' for i in range(len(params))]: - new_fn = f'r{len(params)}' - self._synth(Swap(fn_reg, new_fn)) + for i, r in enumerate(dst): + if r not in src: + self.cur_fun.regs.take(r) # precious! + self._synth(Move(r, src[i])) + else: + swap(r, src[i]) + + + def _prep_fun_call(self, fn_reg, params): + """Move all params to r0-rn.""" + target_regs = [f'r{i}' for i in range(len(params))] + new_fn = fn_reg + while new_fn in target_regs or new_fn in params: + new_fn = self.cur_fun.regs.take() + if new_fn != fn_reg: + self._synth(Move(new_fn, fn_reg)) fn_reg = new_fn - for i, param in enumerate(params): - new = f'r{i}' - swap(param, new) + self._reg_shuffle(params, target_regs) return fn_reg def fun_call(self, tree): @@ -863,6 +930,7 @@ class CcInterp(lark.visitors.Interpreter): param_regs = [c.op.out for c in tree.children[1].children] else: param_regs = [tree.children[1].op.out] + self._log(f'calling {tree.children[0].children[0]}({param_regs})') self.cur_fun.regs.flush_all() for i in range(len(param_regs)): self.cur_fun.regs.take(f'r{i}') @@ -885,10 +953,16 @@ class CcInterp(lark.visitors.Interpreter): iter_expression = statement def if_stat(self, tree): - self.visit(tree.children[0]) mark = self._get_marker() has_else = len(tree.children) > 2 - op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else]) + # optimization!!!! + if tree.children[0].data == 'eq': + self.visit_children(tree.children[0]) + a, b = (c.op.out for c in tree.children[0].children) + op = IfEq(a, b, mark, has_else) + else: + self.visit(tree.children[0]) + op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else]) self._synth(op) begin_vars = self.cur_fun.regs.snap() self.visit(tree.children[1]) @@ -898,20 +972,33 @@ class CcInterp(lark.visitors.Interpreter): self.visit(tree.children[2]) self.cur_fun.ops.append(op.synth_endif) + def _restore_vars(self, begin_vars): + curvars = self.cur_fun.regs.vars + + # 0. do the loop shuffle + shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars} + dst, src = tuple(zip(*shuffles.items())) or ([], []) + self._reg_shuffle(src, dst) + + # 1. load the rest + for v, r in begin_vars.items(): + if v in curvars: + continue + self._log(f'loading missing var {v} into {r}') + self._synth(Load(self.cur_fun, [r, v])) + def while_stat(self, tree): mark = self._get_marker() self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark)) begin_vars = dict(self.cur_fun.regs.vars) self.visit(tree.children[0]) + postcond = self.cur_fun.regs.snap() op = WhileOp(tree.children[0].op.out, mark) self._synth(op) self.visit(tree.children[1]) - for v, r in begin_vars.items(): - rvars = self.cur_fun.regs.vars - if v not in rvars or rvars[v] != r: - self._log(f'loading missing var {v}') - self._synth(Load(self.cur_fun, [r, v])) + self._restore_vars(begin_vars) self.cur_fun.ops.append(op.synth_endwhile) + self.cur_fun.regs.restore(postcond) def for_stat(self, tree): mark = self._get_marker() @@ -919,16 +1006,14 @@ class CcInterp(lark.visitors.Interpreter): self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark)) begin_vars = dict(self.cur_fun.regs.vars) self.visit(tree.children[1]) + postcond = self.cur_fun.regs.snap() op = WhileOp(tree.children[1].op.out, mark) self._synth(op) self.visit(tree.children[3]) self.visit(tree.children[2]) # 3rd statement - for v, r in begin_vars.items(): - rvars = self.cur_fun.regs.vars - if v not in rvars or rvars[v] != r: - self._log(f'loading missing var {v}') - self._synth(Load(self.cur_fun, [r, v])) + self._restore_vars(begin_vars) self.cur_fun.ops.append(op.synth_endwhile) + self.cur_fun.regs.restore(postcond) def _unary_op(op): def _f(self, tree): @@ -964,11 +1049,15 @@ class CcInterp(lark.visitors.Interpreter): type = ch_strct.type.pointed strct = self.global_scope.structs[type.struct] offs = strct.offset(ch_field) - load_offs = lark.Tree('litteral', [offs]) - addop = lark.Tree('add', [tree.children[0], load_offs]) - self.visit(addop) - addop.type = PointerType(pointed=strct.field_type(ch_field)) - tree.children = [addop] + if offs > 0: + load_offs = lark.Tree('litteral', [offs]) + addop = lark.Tree('add', [tree.children[0], load_offs]) + self.visit(addop) + addop.type = PointerType(pointed=strct.field_type(ch_field)) + tree.children = [addop] + else: + tree.children[0].type = PointerType(pointed=strct.field_type(ch_field)) + tree.children = [tree.children[0]] self._deref(tree) def post_increment(self, tree): @@ -1066,7 +1155,7 @@ class CcInterp(lark.visitors.Interpreter): if len(tree.children) > 2: initval = tree.children[2].children[0].op.out cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var])) - self.cur_fun.regs.assign(var, initval, cleanup) + self.cur_fun.regs.assign(var, initval, cleanup=cleanup) self._log(f'assigning {var} = {initval}') def fun_def(self, tree): @@ -1087,7 +1176,7 @@ class CcInterp(lark.visitors.Interpreter): param.stackaddr = fun.get_stack() self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}') cleanup = getcleanup(param) - fun.regs.loaded(param, f'r{i}', cleanup) + fun.regs.loaded(param, f'r{i}', cleanup=cleanup) fun_scope = Scope() fun_scope.parent = self.cur_scope @@ -1161,6 +1250,16 @@ def parse_tree(tree, debug=False): for fun in inte.funs: out += fun.synth() out.append('') + for sym, dat in inte.strings.items(): + out += [f'.global {sym}', + f'{sym}:'] + if len(dat) % 2 != 0: + dat += '\0' + dat = dat.encode() + nwords = len(dat) // 2 + out += [f'.word 0x{d:04x}' for d in struct.unpack(f'>{nwords}H', dat)] + out.append('') + return '\n'.join(filter_dupes(out))