cc: add support for string literals and some stuff

- string literals added as rodata
- fixed register shuffling in for/while loops & fun calls
- a few other fixes
This commit is contained in:
Paul Mathieu 2021-03-14 12:16:21 -07:00
parent 54c69dd962
commit 7841987b6e

View File

@ -1,10 +1,12 @@
import argparse import argparse
import ast
import collections import collections
import contextlib import contextlib
from dataclasses import dataclass from dataclasses import dataclass
import importlib import importlib
import io import io
import re import re
import struct
import subprocess import subprocess
import sys import sys
@ -89,7 +91,8 @@ class RegBank:
if reg is not None: if reg is not None:
if reg not in self.available: if reg not in self.available:
self.evict(self.var(reg)) self.evict(self.var(reg))
self.available.remove(reg) if reg in self.available:
self.available.remove(reg)
return reg return reg
if not self.available: if not self.available:
assert self.vars, "nothing to clean, no more regs :/" assert self.vars, "nothing to clean, no more regs :/"
@ -189,7 +192,8 @@ class RegBank:
self.reset() self.reset()
vardict, cleanup = snap vardict, cleanup = snap
for var, reg in vardict.items(): for var, reg in vardict.items():
self.loaded(var, reg, cleanup=cleanup.get(reg, None)) self.loaded(var, reg)
self.cleanup[reg] = cleanup.get(reg, [])
@contextlib.contextmanager @contextlib.contextmanager
def borrow(self, howmany): def borrow(self, howmany):
@ -328,6 +332,13 @@ class CcTransform(lark.visitors.Transformer):
pointed = children[-1] pointed = children[-1]
return PointerType(volatile=volat, pointed=pointed) return PointerType(volatile=volat, pointed=pointed)
def comma_expr(self, children):
c1, c2 = children
if c1.data == 'comma_expr':
c1.children.append(c2)
return c1
return lark.Tree('comma_expr', [c1, c2])
volatile = lambda *_: 'volatile' volatile = lambda *_: 'volatile'
const = lambda *_: 'const' const = lambda *_: 'const'
@ -359,6 +370,7 @@ class CcTransform(lark.visitors.Transformer):
IDENTIFIER = str IDENTIFIER = str
SIGNED_NUMBER = int SIGNED_NUMBER = int
HEX_LITTERAL = lambda _, x: int(x[2:], 16) HEX_LITTERAL = lambda _, x: int(x[2:], 16)
ESCAPED_STRING = lambda _, x: ast.literal_eval(x)
class AsmOp: class AsmOp:
@ -624,6 +636,16 @@ class Set16Imm(AsmOp):
else: else:
return [f'set {reg}, {lo}'] return [f'set {reg}, {lo}']
class Move(AsmOp):
scratch_need = 0
def __init__(self, dst, src):
self.dst = dst
self.src = src
def synth(self, scratches):
return [f'or {self.dst}, {self.src}, {self.src}']
class Swap(AsmOp): class Swap(AsmOp):
scratch_need = 1 scratch_need = 1
@ -669,6 +691,27 @@ class IfOp(AsmOp):
return [f'{self.endif_mark}:'] return [f'{self.endif_mark}:']
class IfEq(IfOp):
scratch_need = 0
def __init__(self, a, b, mark, has_else):
self.a = a
self.b = b
self.has_else = has_else
self.then_mark = f'_then_{mark}'
self.else_mark = f'_else_{mark}'
self.endif_mark = f'_endif_{mark}'
def synth(self, scratches):
if self.has_else:
return [f'cmp {self.a}, {self.b}',
f'bneq {self.else_mark}']
else:
return [f'cmp {self.a}, {self.b}',
f'bneq {self.endif_mark}']
class WhileOp(AsmOp): class WhileOp(AsmOp):
scratch_need = 1 scratch_need = 1
@ -731,6 +774,8 @@ class CcInterp(lark.visitors.Interpreter):
self.funs = [] self.funs = []
self.next_reg = 0 self.next_reg = 0
self.next_marker = 0 self.next_marker = 0
self.strings = {}
self.next_string_token = 0
def _lookup_symbol(self, s): def _lookup_symbol(self, s):
scope = self.cur_scope scope = self.cur_scope
@ -764,7 +809,7 @@ class CcInterp(lark.visitors.Interpreter):
def _load(self, ident): def _load(self, ident):
s = self._lookup_symbol(ident) s = self._lookup_symbol(ident)
assert s is not None, f'unknown identifier {ident}' assert s is not None, f'unknown identifier {ident}'
if isinstance(s, FunctionSpec) or s in self.global_scope.symbols: if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
reg = self._get_reg() reg = self._get_reg()
return SetAddr(self.cur_fun, [reg, ident]) return SetAddr(self.cur_fun, [reg, ident])
else: else:
@ -795,8 +840,20 @@ class CcInterp(lark.visitors.Interpreter):
assert tree.var is not None, f'undefined identifier: {tree.children[0]}' assert tree.var is not None, f'undefined identifier: {tree.children[0]}'
tree.type = tree.var.type tree.type = tree.var.type
def _make_string(self, s):
nexttok = self.next_string_token
self.next_string_token += 1
token = f'_str{nexttok}'
self.strings[token] = s
self.global_scope.symbols[token] = Variable(litt_type(s), token)
return token
def litteral(self, tree): def litteral(self, tree):
imm = tree.children[0] imm = tree.children[0]
if isinstance(imm, str):
tok = self._make_string(imm)
tree.children = [tok]
return self.identifier(tree)
reg = self._get_reg() reg = self._get_reg()
assert self.cur_fun is not None assert self.cur_fun is not None
tree.op = Set16Imm(self.cur_fun, [reg, imm]) tree.op = Set16Imm(self.cur_fun, [reg, imm])
@ -830,27 +887,37 @@ class CcInterp(lark.visitors.Interpreter):
fun = FunctionSpec(tree.children[0]) fun = FunctionSpec(tree.children[0])
self.cur_scope.symbols[fun.name] = fun self.cur_scope.symbols[fun.name] = fun
def _prep_fun_call(self, fn_reg, params): def _reg_shuffle(self, src, dst):
"""Move all params to r0-rn."""
def swap(old, new): def swap(old, new):
if old == new: if old == new:
return return
oldpos = params.index(old) oldpos = src.index(old)
try: try:
newpos = params.index(new) newpos = src.index(new)
except ValueError: except ValueError:
params[oldpos] = new src[oldpos] = new
else: else:
params[newpos], params[oldpos] = params[oldpos], params[newpos] src[newpos], src[oldpos] = src[oldpos], src[newpos]
self._synth(Swap(old, new)) self._synth(Swap(old, new))
if fn_reg in [f'r{i}' for i in range(len(params))]: for i, r in enumerate(dst):
new_fn = f'r{len(params)}' if r not in src:
self._synth(Swap(fn_reg, new_fn)) self.cur_fun.regs.take(r) # precious!
self._synth(Move(r, src[i]))
else:
swap(r, src[i])
def _prep_fun_call(self, fn_reg, params):
"""Move all params to r0-rn."""
target_regs = [f'r{i}' for i in range(len(params))]
new_fn = fn_reg
while new_fn in target_regs or new_fn in params:
new_fn = self.cur_fun.regs.take()
if new_fn != fn_reg:
self._synth(Move(new_fn, fn_reg))
fn_reg = new_fn fn_reg = new_fn
for i, param in enumerate(params): self._reg_shuffle(params, target_regs)
new = f'r{i}'
swap(param, new)
return fn_reg return fn_reg
def fun_call(self, tree): def fun_call(self, tree):
@ -863,6 +930,7 @@ class CcInterp(lark.visitors.Interpreter):
param_regs = [c.op.out for c in tree.children[1].children] param_regs = [c.op.out for c in tree.children[1].children]
else: else:
param_regs = [tree.children[1].op.out] param_regs = [tree.children[1].op.out]
self._log(f'calling {tree.children[0].children[0]}({param_regs})')
self.cur_fun.regs.flush_all() self.cur_fun.regs.flush_all()
for i in range(len(param_regs)): for i in range(len(param_regs)):
self.cur_fun.regs.take(f'r{i}') self.cur_fun.regs.take(f'r{i}')
@ -885,10 +953,16 @@ class CcInterp(lark.visitors.Interpreter):
iter_expression = statement iter_expression = statement
def if_stat(self, tree): def if_stat(self, tree):
self.visit(tree.children[0])
mark = self._get_marker() mark = self._get_marker()
has_else = len(tree.children) > 2 has_else = len(tree.children) > 2
op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else]) # optimization!!!!
if tree.children[0].data == 'eq':
self.visit_children(tree.children[0])
a, b = (c.op.out for c in tree.children[0].children)
op = IfEq(a, b, mark, has_else)
else:
self.visit(tree.children[0])
op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
self._synth(op) self._synth(op)
begin_vars = self.cur_fun.regs.snap() begin_vars = self.cur_fun.regs.snap()
self.visit(tree.children[1]) self.visit(tree.children[1])
@ -898,20 +972,33 @@ class CcInterp(lark.visitors.Interpreter):
self.visit(tree.children[2]) self.visit(tree.children[2])
self.cur_fun.ops.append(op.synth_endif) self.cur_fun.ops.append(op.synth_endif)
def _restore_vars(self, begin_vars):
curvars = self.cur_fun.regs.vars
# 0. do the loop shuffle
shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars}
dst, src = tuple(zip(*shuffles.items())) or ([], [])
self._reg_shuffle(src, dst)
# 1. load the rest
for v, r in begin_vars.items():
if v in curvars:
continue
self._log(f'loading missing var {v} into {r}')
self._synth(Load(self.cur_fun, [r, v]))
def while_stat(self, tree): def while_stat(self, tree):
mark = self._get_marker() mark = self._get_marker()
self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark)) self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
begin_vars = dict(self.cur_fun.regs.vars) begin_vars = dict(self.cur_fun.regs.vars)
self.visit(tree.children[0]) self.visit(tree.children[0])
postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[0].op.out, mark) op = WhileOp(tree.children[0].op.out, mark)
self._synth(op) self._synth(op)
self.visit(tree.children[1]) self.visit(tree.children[1])
for v, r in begin_vars.items(): self._restore_vars(begin_vars)
rvars = self.cur_fun.regs.vars
if v not in rvars or rvars[v] != r:
self._log(f'loading missing var {v}')
self._synth(Load(self.cur_fun, [r, v]))
self.cur_fun.ops.append(op.synth_endwhile) self.cur_fun.ops.append(op.synth_endwhile)
self.cur_fun.regs.restore(postcond)
def for_stat(self, tree): def for_stat(self, tree):
mark = self._get_marker() mark = self._get_marker()
@ -919,16 +1006,14 @@ class CcInterp(lark.visitors.Interpreter):
self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark)) self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
begin_vars = dict(self.cur_fun.regs.vars) begin_vars = dict(self.cur_fun.regs.vars)
self.visit(tree.children[1]) self.visit(tree.children[1])
postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[1].op.out, mark) op = WhileOp(tree.children[1].op.out, mark)
self._synth(op) self._synth(op)
self.visit(tree.children[3]) self.visit(tree.children[3])
self.visit(tree.children[2]) # 3rd statement self.visit(tree.children[2]) # 3rd statement
for v, r in begin_vars.items(): self._restore_vars(begin_vars)
rvars = self.cur_fun.regs.vars
if v not in rvars or rvars[v] != r:
self._log(f'loading missing var {v}')
self._synth(Load(self.cur_fun, [r, v]))
self.cur_fun.ops.append(op.synth_endwhile) self.cur_fun.ops.append(op.synth_endwhile)
self.cur_fun.regs.restore(postcond)
def _unary_op(op): def _unary_op(op):
def _f(self, tree): def _f(self, tree):
@ -964,11 +1049,15 @@ class CcInterp(lark.visitors.Interpreter):
type = ch_strct.type.pointed type = ch_strct.type.pointed
strct = self.global_scope.structs[type.struct] strct = self.global_scope.structs[type.struct]
offs = strct.offset(ch_field) offs = strct.offset(ch_field)
load_offs = lark.Tree('litteral', [offs]) if offs > 0:
addop = lark.Tree('add', [tree.children[0], load_offs]) load_offs = lark.Tree('litteral', [offs])
self.visit(addop) addop = lark.Tree('add', [tree.children[0], load_offs])
addop.type = PointerType(pointed=strct.field_type(ch_field)) self.visit(addop)
tree.children = [addop] addop.type = PointerType(pointed=strct.field_type(ch_field))
tree.children = [addop]
else:
tree.children[0].type = PointerType(pointed=strct.field_type(ch_field))
tree.children = [tree.children[0]]
self._deref(tree) self._deref(tree)
def post_increment(self, tree): def post_increment(self, tree):
@ -1066,7 +1155,7 @@ class CcInterp(lark.visitors.Interpreter):
if len(tree.children) > 2: if len(tree.children) > 2:
initval = tree.children[2].children[0].op.out initval = tree.children[2].children[0].op.out
cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var])) cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
self.cur_fun.regs.assign(var, initval, cleanup) self.cur_fun.regs.assign(var, initval, cleanup=cleanup)
self._log(f'assigning {var} = {initval}') self._log(f'assigning {var} = {initval}')
def fun_def(self, tree): def fun_def(self, tree):
@ -1087,7 +1176,7 @@ class CcInterp(lark.visitors.Interpreter):
param.stackaddr = fun.get_stack() param.stackaddr = fun.get_stack()
self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}') self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}')
cleanup = getcleanup(param) cleanup = getcleanup(param)
fun.regs.loaded(param, f'r{i}', cleanup) fun.regs.loaded(param, f'r{i}', cleanup=cleanup)
fun_scope = Scope() fun_scope = Scope()
fun_scope.parent = self.cur_scope fun_scope.parent = self.cur_scope
@ -1161,6 +1250,16 @@ def parse_tree(tree, debug=False):
for fun in inte.funs: for fun in inte.funs:
out += fun.synth() out += fun.synth()
out.append('') out.append('')
for sym, dat in inte.strings.items():
out += [f'.global {sym}',
f'{sym}:']
if len(dat) % 2 != 0:
dat += '\0'
dat = dat.encode()
nwords = len(dat) // 2
out += [f'.word 0x{d:04x}' for d in struct.unpack(f'>{nwords}H', dat)]
out.append('')
return '\n'.join(filter_dupes(out)) return '\n'.join(filter_dupes(out))