# synth/tools/cc.py
# 2021-03-20 21:25:01 -07:00
#
# 1335 lines
# 38 KiB
# Python

import argparse
import ast
import collections
import contextlib
from dataclasses import dataclass
import importlib
import io
import re
import struct
import subprocess
import sys
import lark
# The assembler module is called "as", which is a Python keyword, so it has
# to be imported by name through importlib.
asmod = importlib.import_module("as")
# Absolute path of the lark grammar read by larkparse().
GRAMMAR_FILE = '/home/paulmathieu/vhdl/tools/cc.ebnf'
# Base preprocessor command; preprocess() appends the -I include options.
CPP = ('cpp', '-P')
class Scope:
    """One level of name lookup: symbols, struct definitions, parent link."""

    def __init__(self):
        # Enclosing scope, or None for the outermost one.
        self.parent = None
        # Struct name -> struct definition.
        self.structs = {}
        # Identifier -> symbol object.
        self.symbols = {}
class Variable:
    """A named, typed value the compiler can load from or store to.

    Attributes:
        type: the declared type object.
        name: the identifier (``'deref'`` for dereferenced expressions).
        addr_reg: register holding the address of the value, or None.
    """

    def __init__(self, type, name, addr_reg=None):
        self.name = name
        self.type = type
        self.addr_reg = addr_reg

    def __repr__(self):
        return '<Var: {} {}>'.format(self.type, self.name)

    @classmethod
    def from_def(cls, tree):
        """Build a variable from a definition node (children: type, name)."""
        declared_type = tree.children[0]
        identifier = tree.children[1]
        return cls(declared_type, identifier)

    @classmethod
    def from_dereference(cls, reg, tree):
        """Build the variable reached through the pointer held in ``reg``."""
        pointee = tree.type.pointed
        return cls(pointee, 'deref', addr_reg=reg)

    # Design notes on register handling:
    # - all registers pointing to unwritten stuff can be dropped as soon as
    #   we're done with them:
    #   - already fed them into the operations
    #   - end of statement
    # - maybe should have a separate storage for special registers?
    # need ways to:
    # - assign a list of registers into r0, ... rn for function call
    #   - ... and run all callbacks
    # - get the register for an identifier
    # - dereference an expression:
    #   - essentially turn a temp into an lvalue-address
    #   - if read, need a second reg for its value
    # - mark a variable to have its address in a register
    # - delay reading the identifier if it's lhs
    # - store the value to memory when registers are claimed
    #   - if it was modified through:
    #     - assignment
    #     - pre-post *crement
    #   - retrieve the memory type:
    #     - stack
    #     - absolute
    #     - register
    #   - or just store the value when the variable is assigned
    # - get a temporary register
    #   - return it
    # - I guess dereferencing is an upgrade

    def type(self, tree):
        # Debug helper; on instances it is hidden by the ``type`` attribute
        # assigned in __init__.
        print(tree)
class RegBank:
    """Allocator for the registers r0..r11 and the variables cached in them.

    State:
        available: free register names, handed out from the front.
        vars: variable -> register currently holding it.
        varregs: register -> variable (inverse of ``vars``).
        cleanup: register -> callbacks ``cb(reg)`` to run before the register
            is reused (used by callers to write the variable back).

    Invariant maintained by every method: a register that appears in
    ``varregs`` is never in ``available``.
    """

    def __init__(self, logger=None):
        self.reset()
        self.log = logger or print

    def reset(self):
        """Forget every mapping and mark all registers free."""
        self.available = [f'r{i}' for i in range(12)]
        self.vars = {}
        self.varregs = {}
        self.cleanup = collections.defaultdict(list)

    def take(self, reg=None):
        """Reserve ``reg``, or any free register when ``reg`` is None.

        A variable living in the requested register is evicted first. When no
        register is free, one cached variable is evicted to make room.
        """
        if reg is not None:
            if reg not in self.available:
                self.evict(self.var(reg))
            if reg in self.available:
                self.available.remove(reg)
            return reg
        if not self.available:
            assert self.vars, "nothing to clean, no more regs :/"
            # storing one arbitrary var (the oldest entry)
            self.evict(next(iter(self.vars)))
        return self.available.pop(0)

    def give(self, reg):
        """Return a temporary register to the free list."""
        if reg in self.varregs:
            # Need to call evict() with the var to free it.
            return
        if reg in self.available:
            # Already free; a second entry would hand it out twice.
            return
        self.available.insert(0, reg)

    def loaded(self, var, reg, cleanup=None):
        """Tells the regbank some variable was loaded into the given register."""
        # Reserve first: take() evicts whatever variable is mapped to the
        # register, so the new mapping must not be recorded before it runs.
        self.take(reg)
        self.vars[var] = reg
        self.varregs[reg] = var
        if cleanup is not None:
            self.log(f'recording cleanup for {reg}({var.name})')
            self.cleanup[reg].append(cleanup)

    def stored(self, var):
        """Tells the regbank the given var was stored to memory, register can be freed."""
        assert var in self.vars
        reg = self.vars.pop(var)
        del self.varregs[reg]
        self.give(reg)

    def load(self, var):
        """Returns the reg associated with the var, or a new reg if none was,
        and True if the var was created, False otherwise."""
        self.log(f'vars: {self.vars}, varregs: {self.varregs}')
        if var not in self.vars:
            reg = self.take()
            self.vars[var] = reg
            self.varregs[reg] = var
            return reg, True
        return self.vars[var], False

    def assign(self, var, reg, cleanup=None):
        """Assign a previously-used register to a variable."""
        if var in self.vars:
            # The old copy is superseded: its pending write-back must not run
            # later against a register that no longer holds this variable.
            self.cleanup.pop(self.vars[var], None)
            self.stored(var)
        if reg in self.varregs:
            for cb in self.cleanup.pop(reg, []):
                cb(reg)
            self.stored(self.varregs[reg])
        # stored() may have put the register back on the free list; it now
        # holds ``var`` and must not be handed out as a temporary.
        if reg in self.available:
            self.available.remove(reg)
        self.vars[var] = reg
        self.varregs[reg] = var
        if cleanup is not None:
            self.cleanup[reg].append(cleanup)

    def var(self, reg):
        """Return the variable held in ``reg``, or None."""
        return self.varregs.get(reg, None)

    def evict(self, var):
        """Runs var callbacks & frees the register."""
        if var not in self.vars:
            return
        reg = self.vars[var]
        # Cleanups are keyed by register everywhere they are recorded.
        for cb in self.cleanup.pop(reg, []):
            cb(reg)
        # A callback may have requested a register itself and thereby already
        # released this variable; only release it if it is still resident.
        if var in self.vars:
            self.stored(var)

    def flush_all(self):
        """Run every pending cleanup, then reset the bank."""
        for reg in list(self.cleanup):
            callbacks = self.cleanup.pop(reg, [])
            if not callbacks:
                continue
            var = self.varregs.get(reg)
            name = var.name if var is not None else '?'
            self.log(f'flushing {reg}({name})')
            for cb in callbacks:
                cb(reg)
        self.reset()

    def swapped(self, reg0, reg1):
        """Record that the contents of ``reg0`` and ``reg1`` were exchanged."""
        var0 = self.varregs.get(reg0, None)
        var1 = self.varregs.get(reg1, None)
        # Pending write-backs travel with the variables they belong to.
        cleanup0 = self.cleanup.pop(reg0, None)
        cleanup1 = self.cleanup.pop(reg1, None)
        if var0 is not None:
            self.stored(var0)
        elif reg0 not in self.available:
            self.give(reg0)
        if var1 is not None:
            self.stored(var1)
        elif reg1 not in self.available:
            self.give(reg1)
        if var0 is not None:
            self.loaded(var0, reg1)
        if var1 is not None:
            self.loaded(var1, reg0)
        if cleanup0:
            self.cleanup[reg1] = cleanup0
        if cleanup1:
            self.cleanup[reg0] = cleanup1

    def snap(self):
        """Return a snapshot (vars, cleanup) suitable for restore()."""
        cleanup = {reg: list(cbs) for reg, cbs in self.cleanup.items()}
        return dict(self.vars), cleanup

    def restore(self, snap):
        """Reset the bank and re-create the mappings saved by snap()."""
        self.reset()
        vardict, cleanup = snap
        for var, reg in vardict.items():
            self.loaded(var, reg)
            self.cleanup[reg] = list(cleanup.get(reg, []))

    @contextlib.contextmanager
    def borrow(self, howmany):
        """Context manager lending ``howmany`` scratch registers."""
        regs = [self.take() for _ in range(howmany)]
        try:
            yield regs
        finally:
            for reg in regs:
                self.give(reg)
class FunctionSpec:
    """Signature of a function, built from a prototype node.

    The prototype's children are: return type, name, then one node per
    parameter whose first child is the parameter type.
    """

    def __init__(self, fun_prot):
        self.return_type = fun_prot.children[0]
        self.name = fun_prot.children[1]
        self.param_types = [x.children[0] for x in fun_prot.children[2:]]

    def __repr__(self):
        # Parameter types are arbitrary objects, not necessarily strings, so
        # they must be converted before joining.
        params = ', '.join(str(p) for p in self.param_types)
        return f'<Function: {self.return_type} {self.name}({params})>'

    @property
    def type(self):
        """The function's type as a FunType."""
        return FunType(ret=self.return_type, params=self.param_types)
class Function:
    """A function being compiled: its spec, locals, registers and op list.

    ``ops`` holds zero-argument callables, each returning a list of assembly
    lines; they are only evaluated by synth().
    """

    def __init__(self, fun_prot):
        param_nodes = fun_prot.children[2:]
        self.locals = {}
        self.spec = FunctionSpec(fun_prot)
        self.params = [Variable(*node.children) for node in param_nodes]
        self.ret = None
        self.nextstack = 0
        self.statements = []
        self.regs = RegBank(logger=self.log)
        self.deferred_ops = []
        self.fun_calls = 0
        self.ops = []

    def log(self, line):
        """Queue ``line`` to be emitted as an assembly comment."""
        def emit_comment():
            return [f'// {line}']
        self.ops.append(emit_comment)

    @property
    def stack_usage(self):
        """Stack bytes handed out so far, plus two."""
        return 2 + self.nextstack

    def get_stack(self, size=2):
        """Reserve ``size`` bytes of stack and return their offset."""
        offset = self.nextstack
        self.nextstack = offset + size
        return offset

    def param_dict(self):
        """Return the parameters keyed by name."""
        return {param.name: param for param in self.params}

    def __repr__(self):
        return repr(self.spec)

    def synth(self):
        """Evaluate all queued ops and return the function's assembly lines."""
        lines = []
        if self.fun_calls > 0:
            # Save the link register and make room on the stack.
            lines.append(f'store lr, [sp, -2]')
            lines.append(f'set r4, {self.stack_usage}')
            lines.append(f'sub sp, sp, r4')
        # Iterate the live list: evaluating an op may append log entries.
        for emit in self.ops:
            lines += emit()
        indented = [f' {x}' if x[-1] == ':' else f' {x}' for x in lines]
        header = [f'.global {self.spec.name}', f'{self.spec.name}:']
        return header + indented
class Struct:
    """A struct definition: a name and an ordered list of fields."""

    def __init__(self, struct_def):
        name_node, fields_node = struct_def.children[0], struct_def.children[1]
        self.name = name_node
        self.fields = [Variable(*decl.children) for decl in fields_node.children]

    def offset(self, field):
        """Byte offset of ``field``; every field occupies two bytes."""
        names = [member.name for member in self.fields]
        return names.index(field) * 2

    def field_type(self, field):
        """Type of ``field``, or None when the struct has no such field."""
        return next((member.type for member in self.fields
                     if member.name == field), None)
@dataclass
class CType:
    """Base of all C types; carries the type qualifiers."""
    volatile: bool = False
    const: bool = False


@dataclass
class StructType(CType):
    """A struct type, referenced by struct name."""
    struct: str = ''


@dataclass
class PointerType(CType):
    """Pointer to ``pointed``."""
    pointed: CType = None


@dataclass
class FunType(CType):
    """Function type: return type and parameter types."""
    ret: CType = None
    params: list = None  # list of CType, or None when unset


@dataclass
class PodType(CType):
    """Plain scalar type identified by its C spelling (``pod``)."""
    pod: str = 'int'
    # Annotated so that it is a real dataclass field (settable through the
    # constructor and included in repr/eq) rather than a class attribute.
    signed: bool = True


@dataclass
class VoidType(CType):
    """The ``void`` type."""
    pass
class CcTransform(lark.visitors.Transformer):
    """First pass over the parse tree: builds type objects, converts tokens
    to Python values and folds arithmetic on literals."""

    def _binary_op(litt):
        """Build a rule callback that folds the node when both operands are
        literals, using ``litt`` to compute the value."""
        @lark.v_args(tree=True)
        def _f(self, tree):
            lhs, rhs = tree.children
            both_literal = lhs.data == 'literal' and rhs.data == 'literal'
            if both_literal:
                folded = litt(lhs.children[0], rhs.children[0])
                tree.data = 'literal'
                tree.children = [folded]
            return tree
        return _f

    def field(self, children):
        (name,) = children
        return name

    def struct_type(self, children):
        (struct_name,) = children
        return StructType(struct=struct_name)

    def pointer(self, children):
        return PointerType(volatile='volatile' in children,
                           pointed=children[-1])

    def comma_expr(self, children):
        first, second = children
        if first.data != 'comma_expr':
            return lark.Tree('comma_expr', [first, second])
        # Flatten nested comma expressions into a single node.
        first.children.append(second)
        return first

    def volatile(*_):
        return 'volatile'

    def const(*_):
        return 'const'

    def type(self, children):
        typ = children[-1]
        if not isinstance(typ, str):
            # Already a type object built by another rule.
            return typ
        if typ == 'void':
            return VoidType()
        return PodType(volatile='volatile' in children,
                       const='const' in children,
                       pod=typ)

    def array_item(self, children):
        # transform blarg[foo] into *(blarg + foo) because reasons
        # TODO: bring this over to the main parser, we need to know the type
        # of blarg in order to compute the actual offset
        return lark.Tree('dereference', [lark.Tree('add', children)])

    # operations on literals can be done by the compiler
    add = _binary_op(lambda a, b: a + b)
    sub = _binary_op(lambda a, b: a - b)
    mul = _binary_op(lambda a, b: a * b)
    shl = _binary_op(lambda a, b: a << b)

    def CHARACTER(self, token):
        return ord(ast.literal_eval(token)[0])

    IDENTIFIER = str
    SIGNED_NUMBER = int

    def HEX_LITTERAL(self, token):
        return int(token[2:], 16)

    def ESCAPED_STRING(self, token):
        return ast.literal_eval(token)
class AsmOp:
    """Base assembly operation.

    ``scratch_need`` is the number of scratch registers synth() expects,
    ``out`` the register holding the result (None here).
    """

    scratch_need = 0

    @property
    def out(self):
        return None

    def synth(self, scratches):
        """Return the assembly lines for this operation."""
        return ['nop']
class Reg:
    """Pseudo-operation for a value that already sits in a register."""

    scratch_need = 0

    def __init__(self, reg):
        # Result register, exposed like the ``out`` of a real operation.
        self.out = reg

    def synth(self, scratches):
        """Nothing to emit."""
        return list()
class BinOp(AsmOp):
    """Base for two-operand operations; ``ops`` is (dest, left, right)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        dest, left, right = ops
        self.dest = dest
        self.left = left
        self.right = right

    @property
    def out(self):
        """Register receiving the result."""
        return self.dest
def make_cpu_bin_op(cpu_op):
    """Return a BinOp subclass emitting ``<cpu_op> dest, left, right``."""
    class _C(BinOp):
        def synth(self, _):
            line = '{} {}, {}, {}'.format(cpu_op, self.dest, self.left, self.right)
            return [line]
    return _C
# Binary operations that are emitted as a single three-register instruction.
AddOp = make_cpu_bin_op('add')
SubOp = make_cpu_bin_op('sub')
MulOp = make_cpu_bin_op('mul')
# no div
# no mod either
AndOp = make_cpu_bin_op('and')
OrOp = make_cpu_bin_op('or')
XorOp = make_cpu_bin_op('xor')
class ShlOp(BinOp):
    """Left shift, emitted as a loop that doubles ``dest``."""

    scratch_need = 2

    def synth(self, scratches):
        counter, one = scratches
        dest, value, amount = self.dest, self.left, self.right
        return [
            f'set {one}, 1',
            f'or {dest}, {value}, {value}',
            f'sub {counter}, {amount}, {one}',
            f'beq [pc, 6]',
            f'add {dest}, {dest}, {dest}',
            f'sub {counter}, {counter}, {one}',
            f'bneq [pc, -6]',
        ]
class LtOp(BinOp):
    """Comparison producing 0 or 1 in ``dest`` from ``left - right``."""

    scratch_need = 1

    def synth(self, scratches):
        diff = scratches[0]
        dest = self.dest
        return [
            f'set {dest}, 0',
            f'sub {diff}, {self.left}, {self.right}',
            f'bneq [pc, 2]',
            f'set {dest}, 1',
        ]
class GtOp(LtOp):
    """Same code as LtOp with the two operands exchanged."""

    def __init__(self, fun, ops):
        dest, left, right = ops
        swapped = [dest, right, left]
        super().__init__(fun, swapped)
class UnOp(AsmOp):
    """Base for one-operand operations; ``ops`` is (dest, operand)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        dest, operand = ops
        self.dest = dest
        self.operand = operand

    @property
    def out(self):
        """Register receiving the result."""
        return self.dest
class Incr(UnOp):
    """Add one to ``operand``; the result lands in ``dest`` and ``operand``."""

    scratch_need = 1

    def synth(self, scratches):
        one = scratches[0]
        dest, operand = self.dest, self.operand
        return [
            f'set {one}, 1',
            f'add {dest}, {operand}, {one}',
            f'or {operand}, {dest}, {dest}',
        ]
class Decr(UnOp):
    """Subtract one from ``operand``; the result lands in ``dest`` and
    ``operand``."""

    scratch_need = 1

    def synth(self, scratches):
        one = scratches[0]
        dest, operand = self.dest, self.operand
        return [
            f'set {one}, 1',
            f'sub {dest}, {operand}, {one}',
            f'or {operand}, {dest}, {dest}',
        ]
class NotOp(UnOp):
    """Emits a single ``not`` instruction."""

    def synth(self, scratches):
        line = 'not {}, {}'.format(self.dest, self.operand)
        return [line]
class BoolNot(UnOp):
    """Logical negation: ``dest`` ends up 0 or 1."""

    def synth(self, scratches):
        dest = self.dest
        return [
            f'set {dest}, 0',
            f'cmp {dest}, {self.operand}',
            f'bneq [pc, 2]',
            f'set {dest}, 1',
        ]
class NeqOp(BinOp):
    """Inequality, computed as the difference of the operands."""

    def synth(self, scratches):
        line = 'sub {}, {}, {}'.format(self.dest, self.left, self.right)
        return [line]
class FnCall(AsmOp):
    """Call through the register ``dest_fn``; the result register is r0.

    ``ops`` is (function register, parameter registers). The parameters are
    only recorded here; synth() does not emit anything for them.
    """

    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        dest_fn, params = ops
        self.dest_fn = dest_fn
        self.params = params

    @property
    def out(self):
        return 'r0'

    def synth(self, scratches):
        tmp = scratches[0]
        target = self.dest_fn
        return [
            f'set {tmp}, 2',
            f'add lr, pc, {tmp}',
            f'or pc, {target}, {target}',
        ]
class ReturnReg(AsmOp):
    """Return from the function with the value of ``ret_reg`` in r0."""

    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        (self.ret_reg,) = ops

    def synth(self, scratches):
        fun = self.fun
        if fun.fun_calls == 0:
            # No call was made from this function: jump back through lr.
            return [f'or r0, {self.ret_reg}, {self.ret_reg}',
                    f'or pc, lr, lr']
        tmp = scratches[0]
        frame = fun.stack_usage
        value = self.ret_reg
        assert frame < 255
        return [
            f'set {tmp}, {frame}',
            f'add sp, sp, {tmp}',
            f'or r0, {value}, {value}',
            f'load pc, [sp, -2] // return',
        ]
class Load(AsmOp):
    """Load variable ``var`` into register ``dest``; ``ops`` is (dest, var)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.dest, self.var = ops
        fun.log(f'8bit load? {self.var.type} {self.is8bit}')

    @property
    def is8bit(self):
        """True when the variable is a PodType with an 8-bit spelling."""
        vartype = self.var.type
        return (isinstance(vartype, PodType)
                and vartype.pod in ['uint8_t', 'int8_t', 'char'])

    def _maybecast8bit(self, reg):
        """Return the shift to apply after the load for 8-bit variables."""
        return [f'shr {reg}, {reg}, 8'] if self.is8bit else []

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        target = self.dest
        var = self.var
        if var.name in self.fun.locals:
            # no 8bit cast to/from stack
            return [f'load {target}, [sp, {var.stackaddr}]']
        if var.addr_reg is not None:
            loaded = [f'load {target}, [{var.addr_reg}]']
            return loaded + self._maybecast8bit(target)
        # Anything else is addressed by symbol name.
        return [
            f'set {target}, {var.name}',
            f'nop // in case we load a far global',
            f'load {target}, [{target}]',  # unsure if we need a cast here
        ]
class Store(AsmOp):
    """Store register ``src`` into variable ``var``; ``ops`` is (src, var)."""

    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        src, var = ops
        self.src = src
        self.var = var

    @property
    def out(self):
        return None

    def synth(self, scratches):
        (addr,) = scratches
        value = self.src
        var = self.var
        if var.name in self.fun.locals:
            slot = var.stackaddr
            self.fun.log(f'storing {var}({value}) to [sp, {slot}]')
            return [f'store {value}, [sp, {slot}]']
        if var.addr_reg is not None:
            return [f'store {value}, [{var.addr_reg}]']
        # Anything else is addressed by symbol name through the scratch.
        return [
            f'set {addr}, {var.name}',
            f'nop // you know, in case blah',
            f'store {value}, [{addr}]',
        ]
class Assign(AsmOp):
    """Copy ``src`` into ``var``; ``ops`` is (src, var)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        src, var = ops
        self.src = src
        self.var = var

    @property
    def out(self):
        return self.var

    def synth(self, scratches):
        line = 'or {}, {}, {}'.format(self.var, self.src, self.src)
        return [line]
class SetAddr(AsmOp):
    """Put the address of symbol ``ident`` in ``dest``; ``ops`` is
    (dest, ident)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        dest, ident = ops
        self.dest = dest
        self.ident = ident

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        target = self.dest
        return [
            f'set {target}, {self.ident}',
            f'nop // placeholder for a long address',
        ]
class Set16Imm(AsmOp):
    """Load the low 16 bits of an immediate into ``dest``; ``ops`` is
    (dest, imm16)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        dest, imm16 = ops
        self.dest = dest
        self.imm16 = imm16

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        target = self.dest
        # Split the low 16 bits into high and low bytes.
        hi, lo = divmod(self.imm16 & 0xffff, 0x100)
        lines = [f'set {target}, {lo}']
        if hi != 0:
            lines.append(f'seth {target}, {hi}')
        return lines
class Move(AsmOp):
    """Register-to-register copy, emitted as ``or dst, src, src``."""

    scratch_need = 0

    def __init__(self, dst, src):
        self.dst, self.src = dst, src

    def synth(self, scratches):
        line = 'or {}, {}, {}'.format(self.dst, self.src, self.src)
        return [line]
class Swap(AsmOp):
    """Exchange two registers through one scratch register."""

    scratch_need = 1

    def __init__(self, a0, a1):
        self.a0, self.a1 = a0, a1

    def synth(self, scratches):
        (tmp,) = scratches
        first, second = self.a0, self.a1
        return [
            f'or {tmp}, {first}, {first}',
            f'or {first}, {second}, {second}',
            f'or {second}, {tmp}, {tmp}',
        ]
class IfOp(AsmOp):
    """Conditional branch on register ``cond``; ``op`` is
    (cond, mark, has_else). ``mark`` makes the labels unique."""

    scratch_need = 1

    def __init__(self, fun, op):
        self.fun = fun
        self.cond, mark, self.has_else = op
        self.then_mark = f'_then_{mark}'
        self.else_mark = f'_else_{mark}'
        self.endif_mark = f'_endif_{mark}'

    def synth(self, scratches):
        zero = scratches[0]
        # Skip to the else part when there is one, otherwise past the if.
        target = self.else_mark if self.has_else else self.endif_mark
        return [
            f'set {zero}, 0',
            f'cmp {zero}, {self.cond}',  # flag if cond == 0
            f'beq {target}',
        ]

    def synth_else(self):
        """End of the then part: jump over the else part, then its label."""
        return [
            f'cmp r0, r0',  # trick because beq is better than "or pc, ."
            f'beq {self.endif_mark}',
            f'{self.else_mark}:',
        ]

    def synth_endif(self):
        """Label closing the whole statement."""
        return [f'{self.endif_mark}:']
class IfEq(IfOp):
    """``if (a == b)`` compiled as a direct compare-and-branch."""

    scratch_need = 0

    def __init__(self, a, b, mark, has_else):
        self.a, self.b = a, b
        self.has_else = has_else
        self.then_mark = f'_then_{mark}'
        self.else_mark = f'_else_{mark}'
        self.endif_mark = f'_endif_{mark}'

    def synth(self, scratches):
        target = self.else_mark if self.has_else else self.endif_mark
        return [
            f'cmp {self.a}, {self.b}',
            f'bneq {target}',
        ]
class WhileOp(AsmOp):
    """Loop test on register ``cond``; ``mark`` makes the labels unique."""

    scratch_need = 1

    @staticmethod
    def synth_loop(mark):
        """Label placed at the top of the loop."""
        return ['_loop_{}:'.format(mark)]

    def __init__(self, cond, mark):
        self.cond = cond
        self.loop_mark = f'_loop_{mark}'
        self.endwhile_mark = f'_endwhile_{mark}'

    def synth(self, scratches):
        zero = scratches[0]
        return [
            f'set {zero}, 0',
            f'cmp {zero}, {self.cond}',
            f'beq {self.endwhile_mark}',
        ]

    def synth_endwhile(self):
        """Jump back to the top of the loop, then the exit label."""
        return [
            f'cmp r0, r0',
            f'beq {self.loop_mark}',
            f'{self.endwhile_mark}:',
        ]
class Delayed(AsmOp):
    """Operation whose work happens when its result register is requested.

    Reading ``out`` calls ``out_cb`` and returns its result. The callback is
    invoked on every access, so callers should read ``out`` once.
    """

    def __init__(self, out_cb):
        self.out_cb = out_cb

    @property
    def out(self):
        return self.out_cb()

    def synth(self, scratches=None):
        # Accepts ``scratches`` like every other operation so that a Delayed
        # can be synthesized through the common ``op.synth(scratches)`` path.
        return []
def litt_type(val):
    """Return the C type of a literal value: strings are pointers to const
    char, floats are const float, anything else is const int."""
    if isinstance(val, str):
        return PointerType(pointed=PodType(pod='char', const=True))
    pod = 'float' if isinstance(val, float) else 'int'
    return PodType(pod=pod, const=True)
def return_type(type):
    """Return the ``ret`` attribute of a function type."""
    result = type.ret
    return result
def deref_type(type):
    """Return the ``pointed`` attribute of a pointer type."""
    target = type.pointed
    return target
class CcInterp(lark.visitors.Interpreter):
    """Second pass: walks the transformed tree and records, per function,
    the operations that will later produce assembly.

    Visitor methods annotate tree nodes with ``op`` (the operation whose
    ``out`` is the result register), ``type`` and, for lvalues, ``var``.
    Compiled functions are collected in ``funs``, string literals in
    ``strings`` (label -> text).
    """

    def __init__(self):
        self.global_scope = Scope()
        self.cur_scope = self.global_scope
        self.cur_fun = None
        self.funs = []
        self.next_reg = 0
        self.next_marker = 0
        self.strings = {}
        self.next_string_token = 0

    def _lookup_symbol(self, s):
        """Find ``s`` in the current scope or its ancestors; None if absent."""
        scope = self.cur_scope
        while scope is not None:
            if s in scope.symbols:
                return scope.symbols[s]
            scope = scope.parent
        return None

    def _get_reg(self):
        """Take a free register from the current function's bank."""
        return self.cur_fun.regs.take()

    def _get_marker(self):
        """Return a fresh number used to build unique labels."""
        mark = self.next_marker
        self.next_marker += 1
        return mark

    @contextlib.contextmanager
    def _nosynth(self):
        """Temporarily turn _synth() into a no-op."""
        old = self._synth
        self._synth = lambda *_: None
        try:
            yield
        finally:
            # Always restore, otherwise an error inside the block would
            # silence code generation for the rest of the run.
            self._synth = old

    def _synth(self, op):
        """Queue ``op`` on the current function with its scratch registers."""
        with self.cur_fun.regs.borrow(op.scratch_need) as scratches:
            self._log(f'{op.__class__.__name__}')
            self.cur_fun.ops.append(lambda: op.synth(scratches))

    def _load(self, ident):
        """Return the operation that makes ``ident`` available in a register."""
        s = self._lookup_symbol(ident)
        assert s is not None, f'unknown identifier {ident}'
        if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
            reg = self._get_reg()
            return SetAddr(self.cur_fun, [reg, ident])
        if s.type.volatile:
            self._log(f'loading volatile {s}')
            # Volatile values are re-read every time, into a fresh register
            # that is not recorded as caching the variable.
            reg = self._get_reg()
            return Load(self.cur_fun, [reg, s])
        reg, created = self.cur_fun.regs.load(s)
        if created:
            return Load(self.cur_fun, [reg, s])
        self._log(f'{s} was already in {reg}')
        return Reg(reg)

    def identifier(self, tree):
        # TODO: not actually load the value until it's used in an expression
        # could have the op.out() function have a side effect that does that
        # if it's an assignment, we need to make a Variable with the proper
        # address and assign the register to it.
        # If it's volatile, we need to flush it.
        def delayed_load():
            self._log(f'delay-loading {tree.children[0]}')
            op = self._load(tree.children[0])
            self._synth(op)
            return op.out
        tree.op = Delayed(delayed_load)
        tree.var = self._lookup_symbol(tree.children[0])
        assert tree.var is not None, f'undefined identifier: {tree.children[0]}'
        tree.type = tree.var.type

    def _make_string(self, s):
        """Register string literal ``s`` as a global and return its label."""
        nexttok = self.next_string_token
        self.next_string_token += 1
        token = f'_str{nexttok}'
        self.strings[token] = s
        self.global_scope.symbols[token] = Variable(litt_type(s), token)
        return token

    def literal(self, tree):
        imm = tree.children[0]
        if isinstance(imm, str):
            # String literals become references to a generated global.
            tok = self._make_string(imm)
            tree.children = [tok]
            return self.identifier(tree)
        reg = self._get_reg()
        assert self.cur_fun is not None
        tree.op = Set16Imm(self.cur_fun, [reg, imm])
        self._synth(tree.op)
        tree.type = litt_type(tree.children[0])

    def _assign(self, left, right):
        """Bind variable ``left`` to register ``right`` and queue its store."""
        self._log(f'assigning {left} = {right}')
        self.cur_fun.regs.assign(left, right)
        op = Store(self.cur_fun, [right, left])
        self._synth(op)
        return op.out

    def assignment(self, tree):
        self.visit_children(tree)
        val = tree.children[1].op.out
        # left hand side is an lvalue, retrieve it
        var = tree.children[0].var
        tree.op = self._assign(var, val)

    def global_var(self, tree):
        self.visit_children(tree)
        var = Variable.from_def(tree)
        self.global_scope.symbols[var.name] = var
        # NOTE(review): the initial value is read but not used anywhere in
        # this method.
        val = 0
        if len(tree.children) > 2:
            val = tree.children[2].children[0].value

    def fun_decl(self, tree):
        fun = FunctionSpec(tree.children[0])
        self.cur_scope.symbols[fun.name] = fun

    def _reg_shuffle(self, src, dst):
        """Queue moves/swaps so that the value in ``src[i]`` ends up in
        ``dst[i]``. ``src`` must be a mutable list; it is updated in place."""
        def swap(old, new):
            if old == new:
                return
            oldpos = src.index(old)
            try:
                newpos = src.index(new)
            except ValueError:
                src[oldpos] = new
            else:
                src[newpos], src[oldpos] = src[oldpos], src[newpos]
            self._synth(Swap(old, new))

        for i, r in enumerate(dst):
            if r not in src:
                self.cur_fun.regs.take(r)  # precious!
                self._synth(Move(r, src[i]))
            else:
                swap(r, src[i])

    def _prep_fun_call(self, fn_reg, params):
        """Move all params to r0-rn."""
        target_regs = [f'r{i}' for i in range(len(params))]
        new_fn = fn_reg
        # The function address must not sit in a register used for arguments.
        while new_fn in target_regs or new_fn in params:
            new_fn = self.cur_fun.regs.take()
        if new_fn != fn_reg:
            self._synth(Move(new_fn, fn_reg))
            fn_reg = new_fn
        self._reg_shuffle(params, target_regs)
        return fn_reg

    def fun_call(self, tree):
        self.cur_fun.fun_calls += 1
        self.visit_children(tree)
        fn_reg = tree.children[0].op.out
        if len(tree.children) == 1:
            param_regs = []
        elif tree.children[1].data == 'comma_expr':
            param_regs = [c.op.out for c in tree.children[1].children]
        else:
            param_regs = [tree.children[1].op.out]
        self._log(f'calling {tree.children[0].children[0]}({param_regs})')
        self.cur_fun.regs.flush_all()
        for i, _ in enumerate(param_regs):
            self.cur_fun.regs.take(f'r{i}')
        fn_reg = self._prep_fun_call(fn_reg, param_regs)
        self.cur_fun.regs.take(fn_reg)
        tree.op = FnCall(self.cur_fun, [fn_reg, param_regs])
        self._synth(tree.op)
        # Nothing cached survives the call; only r0 (the result) is in use.
        self.cur_fun.regs.reset()
        self.cur_fun.regs.take('r0')
        tree.type = return_type(tree.children[0].type)

    def statement(self, tree):
        self.visit_children(tree)
        if self.cur_fun.deferred_ops:
            self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}')
            for op in self.cur_fun.deferred_ops:
                self._synth(op)
            self.cur_fun.deferred_ops = []

    iter_expression = statement

    def if_stat(self, tree):
        mark = self._get_marker()
        has_else = len(tree.children) > 2
        # optimization!!!!
        if tree.children[0].data == 'eq':
            self.visit_children(tree.children[0])
            a, b = (c.op.out for c in tree.children[0].children)
            op = IfEq(a, b, mark, has_else)
        else:
            self.visit(tree.children[0])
            op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
        self._synth(op)
        begin_vars = self.cur_fun.regs.snap()
        self.visit(tree.children[1])
        if has_else:
            self.cur_fun.regs.restore(begin_vars)
            self.cur_fun.ops.append(op.synth_else)
            self.visit(tree.children[2])
        self.cur_fun.ops.append(op.synth_endif)

    def _restore_vars(self, begin_vars):
        """Bring variables back to the registers recorded in ``begin_vars``
        (variable -> register)."""
        curvars = self.cur_fun.regs.vars
        # 0. do the loop shuffle
        shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars}
        # _reg_shuffle() updates its source list in place, so both sequences
        # are materialized as lists rather than the tuples zip() would give.
        dst = list(shuffles.keys())
        src = list(shuffles.values())
        self._reg_shuffle(src, dst)
        # 1. load the rest
        for v, r in begin_vars.items():
            if v in curvars:
                continue
            self._log(f'loading missing var {v} into {r}')
            self._synth(Load(self.cur_fun, [r, v]))

    def while_stat(self, tree):
        mark = self._get_marker()
        self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
        begin_vars = dict(self.cur_fun.regs.vars)
        self.visit(tree.children[0])
        postcond = self.cur_fun.regs.snap()
        op = WhileOp(tree.children[0].op.out, mark)
        self._synth(op)
        self.visit(tree.children[1])
        self._restore_vars(begin_vars)
        self.cur_fun.ops.append(op.synth_endwhile)
        self.cur_fun.regs.restore(postcond)

    def for_stat(self, tree):
        mark = self._get_marker()
        self.visit(tree.children[0])  # initialization
        self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
        begin_vars = dict(self.cur_fun.regs.vars)
        self.visit(tree.children[1])
        postcond = self.cur_fun.regs.snap()
        op = WhileOp(tree.children[1].op.out, mark)
        self._synth(op)
        self.visit(tree.children[3])
        self.visit(tree.children[2])  # 3rd statement
        self._restore_vars(begin_vars)
        self.cur_fun.ops.append(op.synth_endwhile)
        self.cur_fun.regs.restore(postcond)

    def _unary_op(op):
        """Build a visitor method applying the UnOp subclass ``op``."""
        def _f(self, tree):
            self.visit_children(tree)
            operand = tree.children[0].op.out
            reg = self.cur_fun.regs.take()
            tree.op = op(self.cur_fun, [reg, operand])
            self._synth(tree.op)
            self.cur_fun.regs.give(operand)
        return _f

    def _deref(self, tree):
        """Turn ``tree`` into an lvalue reached through its child's register."""
        reg = tree.children[0].op.out
        var = Variable.from_dereference(reg, tree.children[0])
        self._log(f'making var {var} from derefing reg {reg}')
        tree.var = var

        def delayed_load():
            self._log(f'delay-loading {tree.children[0]}')
            op = Load(self.cur_fun, [reg, var])
            self._synth(op)
            return op.out
        tree.op = Delayed(delayed_load)
        tree.type = var.type

    def dereference(self, tree):
        self.visit_children(tree)
        self._deref(tree)

    def pointer_access(self, tree):
        # First visit only to learn the type of the struct pointer.
        with self._nosynth():
            self.visit_children(tree)
        ch_strct, ch_field = tree.children
        type = ch_strct.type.pointed
        strct = self.global_scope.structs[type.struct]
        offs = strct.offset(ch_field)
        field_ptr = PointerType(pointed=strct.field_type(ch_field))
        if offs > 0:
            # ptr->field is compiled as *(ptr + offset).
            load_offs = lark.Tree('literal', [offs])
            addop = lark.Tree('add', [tree.children[0], load_offs])
            self.visit(addop)
            addop.type = field_ptr
            tree.children = [addop]
        else:
            tree.children = [tree.children[0]]
            self.visit_children(tree)
            # Set after the visit: visiting the child assigns its own type
            # and would overwrite the field type.
            tree.children[0].type = field_ptr
        self._deref(tree)

    def _post_step(self, tree, step_op):
        """Shared body of post_increment / post_decrement: the expression
        yields the current register, the update is deferred."""
        self.visit_children(tree)
        tree.op = Reg(tree.children[0].op.out)
        var = tree.children[0].var
        reg = tree.op.out
        self.cur_fun.deferred_ops.append(step_op(self.cur_fun, [reg, reg]))
        self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
        tree.type = var.type

    def post_increment(self, tree):
        self._post_step(tree, Incr)

    def post_decrement(self, tree):
        self._post_step(tree, Decr)

    pre_increment = _unary_op(Incr)
    bool_not = _unary_op(BoolNot)

    def _binary_op(op):
        """Build a visitor method applying the BinOp subclass ``op``."""
        def _f(self, tree):
            self.visit_children(tree)
            left, right = (x.op.out for x in tree.children)
            dest = self.cur_fun.regs.take()
            tree.op = op(self.cur_fun, [dest, left, right])
            self._synth(tree.op)
            self.cur_fun.regs.give(left)
            self.cur_fun.regs.give(right)
            tree.type = tree.children[0].type  # because uhm reasons
        return _f

    def _combo(uop, bop):
        """Build a visitor that wraps the node's children in a ``bop`` node
        and then applies the unary visitor ``uop`` to it."""
        def _f(self, tree):
            ftree = lark.Tree(bop, tree.children)
            tree.children = [ftree]
            uop.__get__(self)(tree)
        return _f

    shl = _binary_op(ShlOp)
    add = _binary_op(AddOp)
    sub = _binary_op(SubOp)
    mul = _binary_op(MulOp)
    _and = _binary_op(AndOp)
    _or = _binary_op(OrOp)
    # ...
    gt = _binary_op(GtOp)
    lt = _binary_op(LtOp)
    neq = _binary_op(NeqOp)
    eq = _combo(bool_not, 'neq')

    def _assign_op(op):
        """Build a visitor for compound assignment using rule ``op``."""
        def f(self, tree):
            opnode = lark.Tree(op, tree.children)
            self.visit(opnode)
            val = opnode.op.out
            # left hand side is an lvalue, retrieve it
            var = tree.children[0].var
            tree.op = self._assign(var, val)
        return f

    or_ass = _assign_op('_or')

    def _forward_op(self, tree):
        self.visit_children(tree)
        tree.op = tree.children[0].op

    def cast(self, tree):
        self.visit_children(tree)
        tree.op = tree.children[1].op
        tree.type = tree.children[0]

    def _log(self, line):
        self.cur_fun.log(line)

    def local_var(self, tree):
        self.visit_children(tree)
        assert self.cur_fun is not None
        assert self.cur_scope is not None
        var = Variable.from_def(tree)
        var.stackaddr = self.cur_fun.get_stack()  # will have to invert
        self.cur_scope.symbols[var.name] = var
        self.cur_fun.locals[var.name] = var
        if len(tree.children) > 2:
            initval = tree.children[2].children[0].op.out
            # The value stays in its register; it is written to the stack
            # slot only when the register bank runs this cleanup.
            cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
            self.cur_fun.regs.assign(var, initval, cleanup=cleanup)
            self._log(f'assigning {var} = {initval}')

    def fun_def(self, tree):
        prot, body = tree.children
        fun = Function(prot)
        assert self.cur_fun is None
        self.cur_fun = fun
        self.cur_scope.symbols[fun.spec.name] = fun.spec
        params = fun.param_dict()

        def getcleanup(var):
            return lambda reg: self._synth(Store(self.cur_fun, [reg, var]))

        # Parameter i arrives in register r<i> and gets a stack slot.
        for i, param in enumerate(fun.params):
            fun.locals[param.name] = param
            param.stackaddr = fun.get_stack()
            self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}')
            cleanup = getcleanup(param)
            fun.regs.loaded(param, f'r{i}', cleanup=cleanup)
        fun_scope = Scope()
        fun_scope.parent = self.cur_scope
        fun_scope.symbols.update(params)
        body.scope = fun_scope
        self.visit_children(tree)
        if fun.ret is None:
            # No return statement was seen in the body.
            if fun.spec.name == 'main':
                self._synth(Set16Imm(fun, ['r0', 0]))
                self._synth(ReturnReg(fun, ['r0']))
            elif fun.spec.return_type == VoidType():
                self._synth(ReturnReg(fun, ['r0']))
            else:
                assert fun.ret is not None, \
                    f'missing return statement in {fun.spec.name}'
        self.cur_fun = None
        self.funs.append(fun)

    def body(self, tree):
        # fun_def() attaches a prepared scope to the function body; any other
        # body gets a fresh one.
        bscope = getattr(tree, 'scope', Scope())
        bscope.parent = self.cur_scope
        self.cur_scope = bscope
        self.visit_children(tree)
        self.cur_scope = bscope.parent

    def return_stat(self, tree):
        assert self.cur_fun is not None
        self.cur_fun.ret = True
        self.visit_children(tree)
        expr_reg = tree.children[0].op.out
        tree.op = ReturnReg(self.cur_fun, [expr_reg])
        self._synth(tree.op)

    def struct_def(self, tree):
        self.visit_children(tree)
        strct = Struct(tree)
        self.cur_scope.structs[strct.name] = strct
# Startup sequence: clears r0 and r1, initializes sp, calls main and then
# loops forever.
# NOTE(review): this list is not referenced anywhere else in the visible
# source.
preamble = [f'_start:',
            f'xor r0, r0, r0',
            f'xor r1, r1, r1',
            f'set sp, 0',
            f'seth sp, {0x11}',  # 256 bytes of stack ought to be enough
            f'set r2, main',
            f'set r3, 2',
            f'add lr, pc, r3',
            f'or pc, r2, r2',
            f'or pc, pc, pc // loop forever',
            ]
def filter_dupes(ops):
    """Yield the lines of ``ops`` except no-op moves ``or rN, rN, rN``.

    The pattern is anchored on the whole instruction (leading indentation
    and a trailing ``//`` comment are allowed) so that it neither matches
    inside another mnemonic such as ``xor r1, r1, r1`` nor treats ``r1`` as
    equal to a longer register name like ``r10``.
    """
    dupe_re = re.compile(r'^\s*or (r\d+), \1, \1\s*(?://.*)?$')
    for op in ops:
        if dupe_re.match(op):
            continue
        yield op
def parse_tree(tree, debug=False):
    """Compile a parse tree and return the assembly text.

    Args:
        tree: the lark parse tree of the C source.
        debug: when true, print the transformed tree.

    Returns:
        The assembly of every function followed by the string literals, one
        line per instruction, with no-op moves removed.
    """
    tr = CcTransform()
    tree = tr.transform(tree)
    if debug:
        print(tree.pretty())
    inte = CcInterp()
    inte.visit(tree)
    out = []
    for fun in inte.funs:
        out += fun.synth()
        out.append('')
    for sym, dat in inte.strings.items():
        out += [f'.global {sym}',
                f'{sym}:']
        # Terminate and pad on the encoded bytes: a non-ASCII character
        # takes more than one byte, so padding by character count could
        # leave an odd number of bytes for the 16-bit words below.
        raw = dat.encode() + b'\0'
        if len(raw) % 2 != 0:
            raw += b'\0'
        nwords = len(raw) // 2
        out += [f'.word 0x{d:04x}' for d in struct.unpack(f'>{nwords}H', raw)]
        out.append('')
    return '\n'.join(filter_dupes(out))
def larkparse(f, debug=False):
    """Parse the C source read from file object ``f`` with the grammar in
    GRAMMAR_FILE and return the compiled assembly text."""
    with open(GRAMMAR_FILE) as grammar_file:
        grammar = grammar_file.read()
        asparser = lark.Lark(grammar)
    source = f.read()
    if isinstance(source, bytes):
        source = source.decode()
    return parse_tree(asparser.parse(source), debug=debug)
def parse_args():
    """Parse the command line and return the argparse namespace.

    The input file is optional: without it the source is read from stdin,
    as the help text states.
    """
    parser = argparse.ArgumentParser(description='Compile.')
    parser.add_argument('--debug', action='store_true',
                        help='print the AST')
    parser.add_argument('--assembly', '-S', action='store_true',
                        help='output assembly')
    parser.add_argument('--compile', '-c', action='store_true',
                        help='compile a single file')
    # nargs='?' is what makes the positional optional; without it argparse
    # ignores the default and always demands an argument.
    parser.add_argument('input', nargs='?', type=argparse.FileType('r'),
                        default=sys.stdin, help='input file (default: stdin)')
    parser.add_argument('--output', '-o', type=argparse.FileType('wb'),
                        default=sys.stdout.buffer, help='output file')
    parser.add_argument('-I', dest='include_dirs', action='append',
                        default=[], help='include dirs')
    return parser.parse_args()
def preprocess(fin, include_dirs):
    """Feed ``fin`` to the preprocessor command CPP, with one -I option per
    entry of ``include_dirs``, and return its output as a text stream.

    Raises:
        RuntimeError: the preprocessor exited with a non-zero status.
    """
    cmd = [*CPP]
    cmd.extend(f'-I{directory}' for directory in include_dirs)
    proc = subprocess.Popen(cmd, stdin=fin, stdout=subprocess.PIPE)
    stdout, _ = proc.communicate()
    if proc.returncode:
        raise RuntimeError('preprocessor error')
    return io.StringIO(stdout.decode())
def assemble(text, fout):
    """Run assembly ``text`` through the assembler module and write the
    resulting object to ``fout``."""
    parsed = asmod.larkparse(io.StringIO(text))
    asmod.write_obj(fout, *parsed)
def main():
    """Command-line entry point: preprocess, compile, then either write the
    assembly (-S) or assemble it into the output file."""
    args = parse_args()
    source = preprocess(args.input, args.include_dirs)
    assy = larkparse(source, debug=args.debug)
    if not args.assembly:
        assemble(assy, args.output)
        return
    args.output.write(assy.encode() + b'\n')
# Run the compiler only when this file is executed as a script.
if __name__ == "__main__":
    main()