cc: add support for string literals and some stuff

- string literals added as rodata
- fixed register shuffling in for/while loops & fun calls
- a few other fixes
This commit is contained in:
Paul Mathieu 2021-03-14 12:16:21 -07:00
parent 54c69dd962
commit 7841987b6e

View File

@ -1,10 +1,12 @@
import argparse
import ast
import collections
import contextlib
from dataclasses import dataclass
import importlib
import io
import re
import struct
import subprocess
import sys
@ -89,7 +91,8 @@ class RegBank:
if reg is not None:
if reg not in self.available:
self.evict(self.var(reg))
self.available.remove(reg)
if reg in self.available:
self.available.remove(reg)
return reg
if not self.available:
assert self.vars, "nothing to clean, no more regs :/"
@ -189,7 +192,8 @@ class RegBank:
self.reset()
vardict, cleanup = snap
for var, reg in vardict.items():
self.loaded(var, reg, cleanup=cleanup.get(reg, None))
self.loaded(var, reg)
self.cleanup[reg] = cleanup.get(reg, [])
@contextlib.contextmanager
def borrow(self, howmany):
@ -328,6 +332,13 @@ class CcTransform(lark.visitors.Transformer):
pointed = children[-1]
return PointerType(volatile=volat, pointed=pointed)
def comma_expr(self, children):
    """Build a flat ``comma_expr`` node from a two-operand comma expression.

    ``children`` holds exactly two operands.  When the left operand is already
    a ``comma_expr`` node, the right operand is appended to that node's
    children and the same node is returned, so chained commas stay flat
    instead of nesting.  Otherwise a new ``comma_expr`` tree wrapping both
    operands is returned.
    """
    left, right = children
    if left.data != 'comma_expr':
        return lark.Tree('comma_expr', [left, right])
    # Extend the existing node in place rather than wrapping it again.
    left.children.append(right)
    return left
# Callbacks for the type-qualifier keywords: whatever arguments they are called
# with, each one simply yields its keyword as a plain string.
volatile = lambda *_: 'volatile'
const = lambda *_: 'const'
@ -359,6 +370,7 @@ class CcTransform(lark.visitors.Transformer):
# Terminal conversions: each entry turns the matched token into a Python value.
IDENTIFIER = str
SIGNED_NUMBER = int
# Skips the first two characters (presumably the "0x" prefix) and parses the
# remainder as a base-16 integer.
HEX_LITTERAL = lambda _, x: int(x[2:], 16)
# Evaluates the token as a Python literal, so escape sequences follow Python's
# rules.  NOTE(review): Python and C escapes are not identical -- confirm this
# is acceptable for the source language being compiled.
ESCAPED_STRING = lambda _, x: ast.literal_eval(x)
class AsmOp:
@ -624,6 +636,16 @@ class Set16Imm(AsmOp):
else:
return [f'set {reg}, {lo}']
class Move(AsmOp):
    """Copy one register into another with a single ``or`` instruction."""

    # The copy needs no scratch register.
    scratch_need = 0

    def __init__(self, dst, src):
        self.dst, self.src = dst, src

    def synth(self, scratches):
        """Return the assembly lines for the copy; ``scratches`` is unused."""
        source = self.src
        # Emitted as ``or dst, src, src``: both operands are the source.
        return [f'or {self.dst}, {source}, {source}']
class Swap(AsmOp):
scratch_need = 1
@ -669,6 +691,27 @@ class IfOp(AsmOp):
return [f'{self.endif_mark}:']
class IfEq(IfOp):
    """``if`` on an equality test, emitted as a compare plus one branch."""

    # The compare-and-branch sequence needs no scratch register.
    scratch_need = 0

    def __init__(self, a, b, mark, has_else):
        self.a, self.b = a, b
        self.has_else = has_else
        # Label names derived from the unique marker of this ``if``.
        self.then_mark, self.else_mark, self.endif_mark = (
            f'_then_{mark}',
            f'_else_{mark}',
            f'_endif_{mark}',
        )

    def synth(self, scratches):
        """Return the compare and the branch taken when ``a != b``.

        The branch targets the ``else`` label when the statement has an else
        clause, and the ``endif`` label otherwise.  ``scratches`` is unused.
        """
        skip_to = self.else_mark if self.has_else else self.endif_mark
        return [f'cmp {self.a}, {self.b}', f'bneq {skip_to}']
class WhileOp(AsmOp):
scratch_need = 1
@ -731,6 +774,8 @@ class CcInterp(lark.visitors.Interpreter):
self.funs = []
self.next_reg = 0
self.next_marker = 0
self.strings = {}
self.next_string_token = 0
def _lookup_symbol(self, s):
scope = self.cur_scope
@ -764,7 +809,7 @@ class CcInterp(lark.visitors.Interpreter):
def _load(self, ident):
s = self._lookup_symbol(ident)
assert s is not None, f'unknown identifier {ident}'
if isinstance(s, FunctionSpec) or s in self.global_scope.symbols:
if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
reg = self._get_reg()
return SetAddr(self.cur_fun, [reg, ident])
else:
@ -795,8 +840,20 @@ class CcInterp(lark.visitors.Interpreter):
assert tree.var is not None, f'undefined identifier: {tree.children[0]}'
tree.type = tree.var.type
def _make_string(self, s):
    """Register the string literal ``s`` and return the symbol that names it.

    A fresh ``_str<N>`` token is allocated from ``self.next_string_token``.
    The text is recorded in ``self.strings`` under that token, and a
    ``Variable`` of the literal's type is added to the global scope's symbols
    so the token can be looked up like any other identifier.
    """
    index = self.next_string_token
    self.next_string_token = index + 1
    name = f'_str{index}'
    self.strings[name] = s
    symbol = Variable(litt_type(s), name)
    self.global_scope.symbols[name] = symbol
    return name
def litteral(self, tree):
imm = tree.children[0]
if isinstance(imm, str):
tok = self._make_string(imm)
tree.children = [tok]
return self.identifier(tree)
reg = self._get_reg()
assert self.cur_fun is not None
tree.op = Set16Imm(self.cur_fun, [reg, imm])
@ -830,27 +887,37 @@ class CcInterp(lark.visitors.Interpreter):
fun = FunctionSpec(tree.children[0])
self.cur_scope.symbols[fun.name] = fun
def _prep_fun_call(self, fn_reg, params):
"""Move all params to r0-rn."""
def _reg_shuffle(self, src, dst):
def swap(old, new):
if old == new:
return
oldpos = params.index(old)
oldpos = src.index(old)
try:
newpos = params.index(new)
newpos = src.index(new)
except ValueError:
params[oldpos] = new
src[oldpos] = new
else:
params[newpos], params[oldpos] = params[oldpos], params[newpos]
src[newpos], src[oldpos] = src[oldpos], src[newpos]
self._synth(Swap(old, new))
if fn_reg in [f'r{i}' for i in range(len(params))]:
new_fn = f'r{len(params)}'
self._synth(Swap(fn_reg, new_fn))
for i, r in enumerate(dst):
if r not in src:
self.cur_fun.regs.take(r) # precious!
self._synth(Move(r, src[i]))
else:
swap(r, src[i])
def _prep_fun_call(self, fn_reg, params):
"""Move all params to r0-rn."""
target_regs = [f'r{i}' for i in range(len(params))]
new_fn = fn_reg
while new_fn in target_regs or new_fn in params:
new_fn = self.cur_fun.regs.take()
if new_fn != fn_reg:
self._synth(Move(new_fn, fn_reg))
fn_reg = new_fn
for i, param in enumerate(params):
new = f'r{i}'
swap(param, new)
self._reg_shuffle(params, target_regs)
return fn_reg
def fun_call(self, tree):
@ -863,6 +930,7 @@ class CcInterp(lark.visitors.Interpreter):
param_regs = [c.op.out for c in tree.children[1].children]
else:
param_regs = [tree.children[1].op.out]
self._log(f'calling {tree.children[0].children[0]}({param_regs})')
self.cur_fun.regs.flush_all()
for i in range(len(param_regs)):
self.cur_fun.regs.take(f'r{i}')
@ -885,10 +953,16 @@ class CcInterp(lark.visitors.Interpreter):
iter_expression = statement
def if_stat(self, tree):
self.visit(tree.children[0])
mark = self._get_marker()
has_else = len(tree.children) > 2
op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
# optimization!!!!
if tree.children[0].data == 'eq':
self.visit_children(tree.children[0])
a, b = (c.op.out for c in tree.children[0].children)
op = IfEq(a, b, mark, has_else)
else:
self.visit(tree.children[0])
op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
self._synth(op)
begin_vars = self.cur_fun.regs.snap()
self.visit(tree.children[1])
@ -898,20 +972,33 @@ class CcInterp(lark.visitors.Interpreter):
self.visit(tree.children[2])
self.cur_fun.ops.append(op.synth_endif)
def _restore_vars(self, begin_vars):
curvars = self.cur_fun.regs.vars
# 0. do the loop shuffle
shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars}
dst, src = tuple(zip(*shuffles.items())) or ([], [])
self._reg_shuffle(src, dst)
# 1. load the rest
for v, r in begin_vars.items():
if v in curvars:
continue
self._log(f'loading missing var {v} into {r}')
self._synth(Load(self.cur_fun, [r, v]))
def while_stat(self, tree):
mark = self._get_marker()
self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
begin_vars = dict(self.cur_fun.regs.vars)
self.visit(tree.children[0])
postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[0].op.out, mark)
self._synth(op)
self.visit(tree.children[1])
for v, r in begin_vars.items():
rvars = self.cur_fun.regs.vars
if v not in rvars or rvars[v] != r:
self._log(f'loading missing var {v}')
self._synth(Load(self.cur_fun, [r, v]))
self._restore_vars(begin_vars)
self.cur_fun.ops.append(op.synth_endwhile)
self.cur_fun.regs.restore(postcond)
def for_stat(self, tree):
mark = self._get_marker()
@ -919,16 +1006,14 @@ class CcInterp(lark.visitors.Interpreter):
self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
begin_vars = dict(self.cur_fun.regs.vars)
self.visit(tree.children[1])
postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[1].op.out, mark)
self._synth(op)
self.visit(tree.children[3])
self.visit(tree.children[2]) # 3rd statement
for v, r in begin_vars.items():
rvars = self.cur_fun.regs.vars
if v not in rvars or rvars[v] != r:
self._log(f'loading missing var {v}')
self._synth(Load(self.cur_fun, [r, v]))
self._restore_vars(begin_vars)
self.cur_fun.ops.append(op.synth_endwhile)
self.cur_fun.regs.restore(postcond)
def _unary_op(op):
def _f(self, tree):
@ -964,11 +1049,15 @@ class CcInterp(lark.visitors.Interpreter):
type = ch_strct.type.pointed
strct = self.global_scope.structs[type.struct]
offs = strct.offset(ch_field)
load_offs = lark.Tree('litteral', [offs])
addop = lark.Tree('add', [tree.children[0], load_offs])
self.visit(addop)
addop.type = PointerType(pointed=strct.field_type(ch_field))
tree.children = [addop]
if offs > 0:
load_offs = lark.Tree('litteral', [offs])
addop = lark.Tree('add', [tree.children[0], load_offs])
self.visit(addop)
addop.type = PointerType(pointed=strct.field_type(ch_field))
tree.children = [addop]
else:
tree.children[0].type = PointerType(pointed=strct.field_type(ch_field))
tree.children = [tree.children[0]]
self._deref(tree)
def post_increment(self, tree):
@ -1066,7 +1155,7 @@ class CcInterp(lark.visitors.Interpreter):
if len(tree.children) > 2:
initval = tree.children[2].children[0].op.out
cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
self.cur_fun.regs.assign(var, initval, cleanup)
self.cur_fun.regs.assign(var, initval, cleanup=cleanup)
self._log(f'assigning {var} = {initval}')
def fun_def(self, tree):
@ -1087,7 +1176,7 @@ class CcInterp(lark.visitors.Interpreter):
param.stackaddr = fun.get_stack()
self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}')
cleanup = getcleanup(param)
fun.regs.loaded(param, f'r{i}', cleanup)
fun.regs.loaded(param, f'r{i}', cleanup=cleanup)
fun_scope = Scope()
fun_scope.parent = self.cur_scope
@ -1161,6 +1250,16 @@ def parse_tree(tree, debug=False):
for fun in inte.funs:
out += fun.synth()
out.append('')
for sym, dat in inte.strings.items():
out += [f'.global {sym}',
f'{sym}:']
if len(dat) % 2 != 0:
dat += '\0'
dat = dat.encode()
nwords = len(dat) // 2
out += [f'.word 0x{d:04x}' for d in struct.unpack(f'>{nwords}H', dat)]
out.append('')
return '\n'.join(filter_dupes(out))