synth/tools/cc.py

1544 lines
45 KiB
Python

import argparse
import ast
import collections
import contextlib
from dataclasses import dataclass
import importlib
import inspect
import io
import os
import re
import struct
import subprocess
import sys
import lark
# The assembler module is named "as", which is a Python keyword, so it can
# only be loaded through importlib rather than an `import` statement.
asmod = importlib.import_module("as")
# Directory containing this file; the grammar file cc.ebnf is expected next to it.
_HERE = os.path.dirname(__file__)
GRAMMAR_FILE = os.path.join(_HERE, 'cc.ebnf')
# C preprocessor command line; `-P` suppresses linemarker output.
# Presumably run on the input before parsing (the caller is outside this view).
CPP = ('cpp', '-P')
class Scope:
    """One lexical scope: symbol and struct tables plus the enclosing scope."""

    def __init__(self):
        # struct name -> Struct declared in this scope
        self.structs = {}
        # name -> Variable / FunctionSpec visible in this scope
        self.symbols = {}
        # enclosing Scope, None for the outermost one
        self.parent = None
class Variable:
    """A named, typed storage location.

    ``addr_reg`` is set for anonymous locations reached through a pointer:
    the register holds the address of the value.
    """

    def __init__(self, type, name, addr_reg=None):
        self.type, self.name, self.addr_reg = type, name, addr_reg

    def __repr__(self):
        return f'<Var: {self.type} {self.name}>'

    def __str__(self):
        return self.name

    @classmethod
    def from_def(cls, tree):
        """Build from a definition node whose first two children are (type, name)."""
        var_type = tree.children[0]
        var_name = tree.children[1]
        return cls(var_type, var_name)

    @classmethod
    def from_dereference(cls, reg, tree):
        """Variable for ``*expr`` where ``reg`` holds the value of pointer ``tree``."""
        pointee = tree.type.pointed
        return cls(pointee, 'deref', addr_reg=reg)

    # Design notes for the register handling:
    # - all registers pointing to unwritten stuff can be dropped as soon as
    #   we're done with them:
    # - already fed them into the operations
    # - end of statement
    # - maybe should have a separate storage for special registers?
    # need ways to:
    # - assign a list of registers into r0, ... rn for function call
    # - ... and run all callbacks
    # - get the register for an identifier
    # - dereference an expression:
    # - essentially turn a temp into an lvalue-address
    # - if read, need a second reg for its value
    # - mark a variable to have its address in a register
    # - delay reading the identifier if it's lhs
    # - store the value to memory when registers are claimed
    # - if it was modified through:
    # - assignment
    # - pre-post *crement
    # - retrieve the memory type:
    # - stack
    # - absolute
    # - register
    # - or just store the value when the variable is assigned
    # - get a temporary register
    # - return it
    # - I guess dereferencing is an upgrade

    def type(self, tree):
        # Debug leftover. Instances never reach it, because __init__ stores a
        # ``type`` attribute that shadows this method.
        print(tree)
def allregs():
    """Return a fresh generator over the general-purpose registers r0 .. r11."""
    return ('r%d' % n for n in range(12))
class RegBank:
    """Allocator for the general-purpose registers (as listed by allregs()).

    Each register is free (in ``available``, handed out from the front), a
    temporary, or the cached home of a variable (``vars`` maps var -> reg,
    ``varregs`` maps reg -> var). ``cleanup`` holds per-register callbacks
    run when the variable cached in that register is evicted; callers use
    them to write the value back to memory. ``taken`` records, for
    debugging, the caller frame that last took each register.
    """

    def __init__(self, logger=None):
        self.reset()
        self.log = logger or print

    def reset(self):
        """Forget every allocation: all registers become free."""
        self.available = list(allregs())
        self.vars = {}
        self.varregs = {}
        self.cleanup = collections.defaultdict(list)
        self.taken = {}

    def __str__(self):
        return f'regs(avail={self.available}, vars={self.vars})'

    def take(self, reg=None, orwhatever=False):
        """Claim a register and return its name.

        With ``reg`` given, claim that register, evicting a variable cached
        in it if needed. If it is a busy temporary, fall back to any free
        register when ``orwhatever`` is true; otherwise raise RuntimeError.
        Without ``reg``, return the first free register, spilling a variable
        when none is free.
        """
        out = self._take(reg=reg, orwhatever=orwhatever)
        self.taken[out] = inspect.stack()[1]
        return out

    def _take(self, reg=None, orwhatever=False):
        if reg is not None:
            if reg not in self.available:
                if reg in self.varregs:
                    self.evict(self.var(reg))
                elif orwhatever:
                    return self._take()
                else:
                    raise RuntimeError(
                        f'trying to overtake {reg}, '
                        f'varregs: {self.varregs}, vars: {self.vars}, '
                        f'taken in {self.taken.get(reg)}')
            if reg in self.available:
                self.available.remove(reg)
                return reg
        if not self.available:
            assert self.vars, "nothing to clean, no more regs :/"
            # storing one random var
            var = list(self.vars.keys())[0]
            self.evict(var)
        return self.available.pop(0)

    def give(self, reg):
        """Return a temporary register to the free pool.

        Registers caching a variable are left alone (evict() frees those).
        Giving back an already-free register is a no-op, so a register can
        never sit in the pool twice and be handed out to two users.
        """
        assert reg is not None
        if reg in self.varregs or reg in self.available:
            return
        self.available.insert(0, reg)

    def loaded(self, var, reg, cleanup=None):
        """Tells the regbank some variable was loaded into the given register."""
        assert reg is not None
        assert var is not None
        # Claim the register *before* recording the mapping. A variable
        # previously cached in ``reg`` is then the one evicted, not ``var``.
        # A register that is already a taken temporary is simply reused.
        if reg in self.available or reg in self.varregs:
            self.take(reg)
        self.vars[var] = reg
        self.varregs[reg] = var
        if cleanup is not None:
            self.log(f'recording cleanup for {reg}({var})')
            self.cleanup[reg].append(cleanup)

    def stored(self, var):
        """Tells the regbank the given var was stored to memory, register can be freed."""
        assert var in self.vars
        reg = self.vars.pop(var)
        del self.varregs[reg]
        self.give(reg)

    def load(self, var):
        """Returns the reg associated with the var, or a new reg if none was,
        and True if the var was created, False otherwise."""
        self.log(f'vars: {self.vars}, varregs: {self.varregs}')
        if var not in self.vars:
            reg = self.take()
            self.vars[var] = reg
            self.varregs[reg] = var
            return reg, True
        return self.vars[var], False

    def assign(self, var, reg, cleanup=None):
        """Make ``reg`` (already holding the new value) the home of ``var``."""
        assert reg is not None
        if var in self.vars:
            # The previous copy of var is dead. Drop its write-back callbacks
            # so they don't fire later for whatever reuses that register.
            self.cleanup.pop(self.vars[var], None)
            self.stored(var)
        if reg in self.varregs:
            # reg cached another variable: write it back before repurposing.
            for cb in self.cleanup.pop(reg, []):
                cb(reg)
            self.stored(self.varregs[reg])
        # stored() may have returned reg to the free pool, but it is in use.
        if reg in self.available:
            self.available.remove(reg)
        self.vars[var] = reg
        self.varregs[reg] = var
        if cleanup is not None:
            self.cleanup[reg].append(cleanup)

    def var(self, reg):
        """Variable cached in ``reg``, or None."""
        return self.varregs.get(reg, None)

    def evict_all(self):
        """Evict every cached variable (running their cleanup callbacks)."""
        for var in list(self.vars):
            self.evict(var)

    def evict(self, var):
        """Runs var callbacks & frees the register."""
        assert var in self.vars, f'trying to evict {var}'
        self.log(f'evicting {var}')
        # Only reachable when asserts are disabled (python -O).
        if var not in self.vars:
            return
        reg = self.vars[var]
        for cb in self.cleanup.pop(reg, []):
            cb(reg)
        self.stored(var)

    def drop_temps(self):
        """Free every register that is neither free nor caching a variable."""
        for reg in allregs():
            if reg not in self.available and reg not in self.varregs:
                self.give(reg)

    def swapped(self, reg0, reg1):
        """Update bookkeeping after the contents of reg0 and reg1 were exchanged."""
        var0 = self.varregs.get(reg0, None)
        var1 = self.varregs.get(reg1, None)
        if var0 is not None:
            self.stored(var0)
        elif reg0 not in self.available:
            self.give(reg0)
        if var1 is not None:
            self.stored(var1)
        elif reg1 not in self.available:
            self.give(reg1)
        if var0 is not None:
            self.loaded(var0, reg1)
        if var1 is not None:
            self.loaded(var1, reg0)

    def snap(self):
        """Copy of the (var -> reg, reg -> cleanup callbacks) state."""
        cleanup = collections.defaultdict(list)
        cleanup.update({r: list(c) for r, c in self.cleanup.items()})
        return dict(self.vars), cleanup

    def restore(self, snap):
        """Reset, then re-establish the variable cache recorded by snap()."""
        self.reset()
        vardict, cleanup = snap
        for var, reg in vardict.items():
            assert reg is not None, f'bad snapshot: {snap}'
            self.loaded(var, reg)
            # Copy so later appends don't mutate the snapshot (it may be reused).
            self.cleanup[reg] = list(cleanup.get(reg, []))

    @contextlib.contextmanager
    def borrow(self, howmany):
        """Lend ``howmany`` temporary registers for the duration of a ``with`` block."""
        regs = [self.take() for _ in range(howmany)]
        try:
            yield regs
        finally:
            for reg in regs:
                self.give(reg)
class FunctionSpec:
    """Signature of a function: return type, name and parameter types.

    ``fun_prot`` is a prototype node whose children are
    (return type, name, param...), with each param node's first child
    being its type.
    """

    def __init__(self, fun_prot):
        self.return_type = fun_prot.children[0]
        self.name = fun_prot.children[1]
        self.param_types = [x.children[0] for x in fun_prot.children[2:]]

    def __repr__(self):
        # Parameter types may be CType objects (see CcTransform.type) rather
        # than strings, so stringify them before joining.
        params = ', '.join(str(p) for p in self.param_types)
        return f'<Function: {self.return_type} {self.name}({params})>'

    @property
    def type(self):
        """The function's FunType."""
        return FunType(ret=self.return_type, params=self.param_types)
class Function:
    """Per-function code-generation state.

    Holds the signature, parameters and locals, stack-frame bookkeeping,
    the function's RegBank, and ``ops``: zero-argument callables that each
    return a list of assembly lines. synth() evaluates them only at the
    end, so they can use values known late (such as stack_usage).
    """

    def __init__(self, fun_prot):
        self.locals = {}
        self.spec = FunctionSpec(fun_prot)
        self.params = [Variable(*x.children) for x in fun_prot.children[2:]]
        self.ret = None
        self.nextstack = 0
        self.topstack = 0
        self.statements = []
        self.regs = RegBank(logger=self.log)
        self.deferred_ops = []
        self.fun_calls = 0
        self.ops = []

    def log(self, line):
        """Queue ``line`` as an assembly comment."""
        self.ops.append(lambda: [f'// {line}'])

    @property
    def stack_usage(self):
        """Frame size: stack high-water mark, plus 2 when the function makes calls
        (the slot synth() saves lr into)."""
        top_usage = max(self.topstack, self.nextstack)
        if self.fun_calls > 0:
            return top_usage + 2
        return top_usage

    def get_stack(self, size=2):
        """Reserve ``size`` bytes of frame and return their offset."""
        stk = self.nextstack
        self.nextstack += size
        return stk

    @contextlib.contextmanager
    def tempstack(self):
        """Release every slot reserved inside the ``with`` block on exit,
        while still counting them towards the frame's high-water mark."""
        stk = self.nextstack
        try:
            yield
        finally:
            if self.nextstack > self.topstack:
                self.topstack = self.nextstack
            self.nextstack = stk

    def param_dict(self):
        return {p.name: p for p in self.params}

    def __repr__(self):
        return repr(self.spec)

    def synth(self):
        """Return the function's assembly: label, prologue, then all queued ops."""
        ops = []
        if self.fun_calls > 0:
            ops += [f'store lr, [sp, -2]']
        if self.stack_usage > 0:
            # Parameters arrive in r0..r{n-1}. Use a scratch register that
            # holds none of them: r4 as before, or the first one past the
            # parameters when there are more than four.
            nparams = len(self.params)
            scratch = 'r4' if nparams <= 4 else f'r{nparams}'
            ops += [f'set {scratch}, {self.stack_usage}',
                    f'sub sp, sp, {scratch}']
        for op in self.ops:
            ops += op()
        # Labels and instructions currently get the same prefix.
        indented = [f' {x}' if x.endswith(':') else f' {x}' for x in ops]
        return [f'.global {self.spec.name}',
                f'{self.spec.name}:'] + indented
class Struct:
    """A struct definition: name plus ordered fields, each in one 2-byte slot."""

    def __init__(self, struct_def):
        self.name = struct_def.children[0]
        members = struct_def.children[1].children
        self.fields = [Variable(*member.children) for member in members]

    def offset(self, field):
        """Byte offset of ``field``; raises ValueError if there is no such field."""
        names = [f.name for f in self.fields]
        return names.index(field) * 2

    def field_type(self, field):
        """Type of ``field``, or None when the struct has no such field."""
        return next((f.type for f in self.fields if f.name == field), None)
@dataclass
class CType:
    """Base C type: qualifiers plus size (1 for char, 2 by default)."""
    volatile: bool = False
    const: bool = False
    size: int = 2


@dataclass
class StructType(CType):
    """``struct <name>``."""
    struct: str = ''


@dataclass
class PointerType(CType):
    """Pointer to ``pointed``."""
    pointed: CType = None


@dataclass
class FunType(CType):
    """Function type: return type and list of parameter types."""
    ret: CType = None
    params: list = None


@dataclass
class PodType(CType):
    """Plain scalar type such as int or char."""
    pod: str = 'int'
    # No annotation: a plain class attribute, not a dataclass field.
    signed = True


@dataclass
class VoidType(CType):
    """The ``void`` type."""
class CcTransform(lark.visitors.Transformer):
    """lark Transformer that prepares the parse tree for code generation.

    Type nodes become CType objects and terminal tokens become Python
    values. Nested comma expressions are flattened, and binary operations
    whose two operands are literal nodes are folded into one literal.
    """
    def _binary_op(litt):
        # Class-body helper (no ``self``): builds a rule callback that
        # rewrites `literal OP literal` into one literal holding
        # ``litt(left, right)``. Any other node is returned unchanged.
        # NOTE(review): folded values are not truncated to 16 bits here.
        @lark.v_args(tree=True)
        def _f(self, tree):
            left, right = tree.children
            if left.data == 'literal' and right.data == 'literal':
                tree.data = 'literal'
                tree.children = [litt(left.children[0], right.children[0])]
            return tree
        return _f
    def field(self, children):
        # A field node reduces to its single child.
        (field,) = children
        return field
    def struct_type(self, children):
        (name,) = children
        return StructType(struct=name)
    def pointer(self, children):
        # Only ``volatile`` is read here. The last child is the pointed-to type.
        volat = 'volatile' in children
        pointed = children[-1]
        return PointerType(volatile=volat, pointed=pointed)
    def funptr_type(self, children):
        ret, *params = children
        return FunType(ret=ret, params=params)
    def comma_expr(self, children):
        # Flatten left-nested comma expressions into a single comma_expr.
        c1, c2 = children
        if c1.data == 'comma_expr':
            c1.children.append(c2)
            return c1
        return lark.Tree('comma_expr', [c1, c2])
    # Qualifier nodes collapse to marker strings that `type` and `pointer`
    # test for with `in`.
    volatile = lambda *_: 'volatile'
    const = lambda *_: 'const'
    def type(self, children):
        # The last child is either a type-name string or an already-built
        # CType, which is returned as is.
        # NOTE(review): qualifiers are dropped for void and for already-built
        # CTypes.
        volat = 'volatile' in children
        const = 'const' in children
        typ = children[-1]
        if isinstance(typ, str):
            if typ == 'void':
                return VoidType()
            size = 1 if typ == 'char' else 2
            return PodType(volatile=volat, const=const, pod=typ, size=size)
        else:
            return typ
    # operations on literals can be done by the compiler
    add = _binary_op(lambda a, b: a+b)
    sub = _binary_op(lambda a, b: a-b)
    mul = _binary_op(lambda a, b: a*b)
    shl = _binary_op(lambda a, b: a<<b)
    # Terminal callbacks: turn tokens into Python values.
    CHARACTER = lambda _, x: ord(ast.literal_eval(x)[0])
    IDENTIFIER = str
    SIGNED_NUMBER = int
    HEX_LITTERAL = lambda _, x: int(x[2:], 16)
    ESCAPED_STRING = lambda _, x: ast.literal_eval(x)
    INT = int
class AsmOp:
    """Base class for a queued assembly operation.

    ``scratch_need`` is how many temporary registers synth() expects.
    ``out`` is the register holding the op's result, if any.
    """
    scratch_need = 0

    def synth(self, scratches):
        return ['nop']

    @property
    def out(self):
        return None


class Reg:
    """Pseudo-op for a value already sitting in a register; emits nothing."""
    scratch_need = 0

    def __init__(self, reg):
        self.out = reg

    def synth(self, scratches):
        return []
class BinOp(AsmOp):
    """Binary operation ``dest = left <op> right``."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.dest, self.left, self.right = ops

    @property
    def out(self):
        return self.dest


def make_cpu_bin_op(cpu_op):
    """Build a BinOp subclass that emits the single instruction ``cpu_op``."""
    class _C(BinOp):
        def synth(self, _):
            line = f'{cpu_op} {self.dest}, {self.left}, {self.right}'
            return [line]
    return _C


# Binary operators with a direct instruction (no div or mod op here).
AddOp = make_cpu_bin_op('add')
SubOp = make_cpu_bin_op('sub')
MulOp = make_cpu_bin_op('mul')
AndOp = make_cpu_bin_op('and')
OrOp = make_cpu_bin_op('or')
XorOp = make_cpu_bin_op('xor')
# The shift amount must be a literal.
ShrOp = make_cpu_bin_op('shr')


class ShlOp(BinOp):
    """Left shift by a register amount, done as a loop that doubles ``dest``.

    NOTE(review): the counter starts at ``right - 1`` and the loop ends when
    it reaches zero, so ``dest`` looks doubled ``right - 1`` times. Confirm
    against the CPU's flag/branch semantics.
    """
    scratch_need = 2

    def synth(self, scratches):
        counter, one = scratches
        dest = self.dest
        return [
            f'set {one}, 1',
            f'or {dest}, {self.left}, {self.left}',
            f'sub {counter}, {self.right}, {one}',
            f'beq [pc, 4]',
            f'add {dest}, {dest}, {dest}',
            f'sub {counter}, {counter}, {one}',
            f'bneq [pc, -8]',
        ]
class LtOp(BinOp):
    """Comparison leaving 0 or 1 in ``dest`` (intended: left < right).

    NOTE(review): a `sub` followed by `bneq` skips the final `set dest, 1`.
    Depending on which flag `sub` sets, this may act as an equality test
    rather than less-than. Confirm against the CPU.
    """
    scratch_need = 1

    def synth(self, scratches):
        diff = scratches[0]
        dest = self.dest
        return [f'set {dest}, 0',
                f'sub {diff}, {self.left}, {self.right}',
                f'bneq [pc, 0]',
                f'set {dest}, 1']


class GtOp(LtOp):
    """``left > right``: LtOp with the operands swapped."""

    def __init__(self, fun, ops):
        dest, lhs, rhs = ops
        super().__init__(fun, [dest, rhs, lhs])
class UnOp(AsmOp):
    """One-operand operation writing ``dest`` from ``operand``."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.dest, self.operand = ops

    @property
    def out(self):
        return self.dest


class Incr(UnOp):
    """dest = operand + 1, with the result also copied back into operand."""
    scratch_need = 1

    def synth(self, scratches):
        one = scratches[0]
        dest, src = self.dest, self.operand
        return [f'set {one}, 1',
                f'add {dest}, {src}, {one}',
                f'or {src}, {dest}, {dest}']


class Decr(UnOp):
    """dest = operand - 1, with the result also copied back into operand."""
    scratch_need = 1

    def synth(self, scratches):
        one = scratches[0]
        dest, src = self.dest, self.operand
        return [f'set {one}, 1',
                f'sub {dest}, {src}, {one}',
                f'or {src}, {dest}, {dest}']


class NotOp(UnOp):
    """Bitwise complement."""

    def synth(self, scratches):
        return [f'not {self.dest}, {self.operand}']


class BoolNot(UnOp):
    """C logical not (0/1 result): compare operand with 0 and conditionally set 1."""

    def synth(self, scratches):
        dest = self.dest
        return [f'set {dest}, 0',
                f'cmp {dest}, {self.operand}',
                f'bneq [pc, 0]',
                f'set {dest}, 1']
class NeqOp(BinOp):
    """``left != right`` as the difference: nonzero exactly when they differ."""

    def synth(self, scratches):
        diff = f'sub {self.dest}, {self.left}, {self.right}'
        return [diff]
class FnCall(AsmOp):
    """Call through the address in ``dest_fn``; the result is read from r0."""
    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        self.dest_fn, self.params = ops

    @property
    def out(self):
        return 'r0'

    def synth(self, scratches):
        zero = scratches[0]
        target = self.dest_fn
        # lr = pc + 0, then jump by writing the target into pc.
        return [f'set {zero}, 0',
                f'add lr, pc, {zero}',
                f'or pc, {target}, {target}']
class ReturnReg(AsmOp):
    """Function epilogue: pop the frame, move the result to r0, jump back."""
    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        (self.ret_reg,) = ops

    def synth(self, scratches):
        frame = self.fun.stack_usage
        (tmp,) = scratches
        lines = []
        if frame > 0:
            lines.extend([f'set {tmp}, {frame}',
                          f'add sp, sp, {tmp}'])
        if self.fun.ret is not None:
            src = self.ret_reg
            lines.append(f'or r0, {src}, {src}')
        # Leaf functions return through lr. Otherwise the saved lr is
        # reloaded from the frame straight into pc.
        if self.fun.fun_calls == 0:
            lines.append(f'or pc, lr, lr')
        else:
            lines.append(f'load pc, [sp, -2] // return')
        return lines
class Load(AsmOp):
    """Read variable ``var`` into register ``dest``.

    Locals come from their stack slot. Dereferenced values come through
    ``var.addr_reg`` and are shifted right by 8 for 8-bit types. Anything
    else is read from its global symbol.
    """
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.dest, self.var = ops
        fun.log(f'8bit load? {self.var.type} {self.is8bit}')

    @property
    def is8bit(self):
        vtype = self.var.type
        return isinstance(vtype, PodType) and vtype.pod in ('uint8_t', 'int8_t', 'char')

    def _maybecast8bit(self, reg):
        return [f'shr {reg}, {reg}, 8'] if self.is8bit else []

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        dst = self.dest
        var = self.var
        if var.name in self.fun.locals:
            return [f'load {dst}, [sp, {var.stackaddr}]']  # no 8bit cast to/from stack
        if var.addr_reg is not None:
            return [f'load {dst}, [{var.addr_reg}]'] + self._maybecast8bit(dst)
        return [f'set {dst}, {var.name}',
                f'nop // in case we load a far global',
                f'load {dst}, [{dst}]']  # unsure if we need a cast here
class Push(AsmOp):
    """Save ``reg`` into stack slot ``[sp, stackaddr]``."""

    def __init__(self, reg, stackaddr):
        self.reg, self.stackaddr = reg, stackaddr

    def synth(self, scratches):
        return ['store {}, [sp, {}]'.format(self.reg, self.stackaddr)]


class Pop(AsmOp):
    """Reload ``reg`` from stack slot ``[sp, stackaddr]``."""

    def __init__(self, reg, stackaddr):
        self.reg, self.stackaddr = reg, stackaddr

    @property
    def out(self):
        return self.reg

    def synth(self, scratches):
        return ['load {}, [sp, {}]'.format(self.reg, self.stackaddr)]
class Store(AsmOp):
    """Write register ``src`` to variable ``var``.

    Locals go to their stack slot, dereferenced values go through
    ``var.addr_reg``, and globals go through a scratch register holding the
    symbol's address.
    """
    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        self.src, self.var = ops

    def synth(self, scratches):
        (tmp,) = scratches
        src = self.src
        var = self.var
        if var.name in self.fun.locals:
            slot = var.stackaddr
            self.fun.log(f'storing {var}({src}) to [sp, {slot}]')
            return [f'store {src}, [sp, {slot}]']
        if var.addr_reg is not None:
            return [f'store {src}, [{var.addr_reg}]']
        return [f'set {tmp}, {var.name}',
                f'nop // you know, in case blah',
                f'store {src}, [{tmp}]']
class Assign(AsmOp):
    """Copy register ``src`` into register ``var``."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.src, self.var = ops

    @property
    def out(self):
        return self.var

    def synth(self, scratches):
        src = self.src
        return [f'or {self.var}, {src}, {src}']
class SetAddr(AsmOp):
    """Put the address of symbol ``ident`` into ``dest``."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.dest, self.ident = ops

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        return [f'set {self.dest}, {self.ident}',
                f'nop // placeholder for a long address']
class Set16Imm(AsmOp):
    """Load a 16-bit immediate: `set` the low byte, then `seth` the high byte
    when it is nonzero (values are masked to 16 bits)."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.dest, self.imm16 = ops

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        reg = self.dest
        low = self.imm16 & 0xff
        high = (self.imm16 >> 8) & 0xff
        lines = [f'set {reg}, {low}']
        if high != 0:
            lines.append(f'seth {reg}, {high}')
        return lines
class Move(AsmOp):
    """Register copy ``dst = src``, encoded as `or dst, src, src`."""
    scratch_need = 0

    def __init__(self, dst, src):
        self.dst, self.src = dst, src

    def synth(self, scratches):
        src = self.src
        return [f'or {self.dst}, {src}, {src}']


class Swap(AsmOp):
    """Exchange the contents of two registers through one scratch register."""
    scratch_need = 1

    def __init__(self, a0, a1):
        self.a0, self.a1 = a0, a1

    def synth(self, scratches):
        (tmp,) = scratches
        first, second = self.a0, self.a1
        return [f'or {tmp}, {first}, {first}',
                f'or {first}, {second}, {second}',
                f'or {second}, {tmp}, {tmp}']
class IfOp(AsmOp):
    """Conditional branch on ``cond`` being zero, plus the else/endif labels."""
    scratch_need = 1

    def __init__(self, fun, op):
        self.fun = fun
        self.cond, mark, self.has_else = op
        self._name_labels(mark)

    def _name_labels(self, mark):
        """Derive the then/else/endif label names from the unique ``mark``."""
        self.then_mark = f'_then_{mark}'
        self.else_mark = f'_else_{mark}'
        self.endif_mark = f'_endif_{mark}'

    def _skip_target(self):
        """Label to jump to when the condition fails."""
        return self.else_mark if self.has_else else self.endif_mark

    def synth(self, scratches):
        zero = scratches[0]
        # flag if cond == 0, then skip the "then" branch
        return [f'set {zero}, 0',
                f'cmp {zero}, {self.cond}',
                f'beq {self._skip_target()}']

    def synth_else(self):
        # `cmp r0, r0` always compares equal, so the beq acts as an
        # unconditional jump (trick because beq is better than "or pc, .").
        return [f'cmp r0, r0',
                f'beq {self.endif_mark}',
                f'{self.else_mark}:']

    def synth_endif(self):
        return [f'{self.endif_mark}:']


class IfEq(IfOp):
    """Fused `if (a == b)`: a single compare, branching away on inequality."""
    scratch_need = 0

    def __init__(self, a, b, mark, has_else):
        self.a = a
        self.b = b
        self.has_else = has_else
        self._name_labels(mark)

    def synth(self, scratches):
        return [f'cmp {self.a}, {self.b}',
                f'bneq {self._skip_target()}']
class WhileOp(AsmOp):
    """Loop-head test that exits the loop when ``cond`` is zero."""
    scratch_need = 1

    @staticmethod
    def synth_loop(mark):
        """Label placed at the top of the loop, before the condition."""
        return [f'_loop_{mark}:']

    def __init__(self, cond, mark):
        self.cond = cond
        self.loop_mark = f'_loop_{mark}'
        self.endwhile_mark = f'_endwhile_{mark}'

    def synth(self, scratches):
        zero = scratches[0]
        return [f'set {zero}, 0',
                f'cmp {zero}, {self.cond}',
                f'beq {self.endwhile_mark}']

    def synth_endwhile(self):
        # `cmp r0, r0` compares equal, making the beq an unconditional jump back.
        return ['cmp r0, r0',
                f'beq {self.loop_mark}',
                f'{self.endwhile_mark}:']
class Delayed(AsmOp):
    """Op whose output register is produced lazily.

    Every read of ``out`` calls ``out_cb``. The callbacks used in this file
    emit the code computing the value and return its register.
    """

    def __init__(self, out_cb):
        self.out_cb = out_cb

    @property
    def out(self):
        return self.out_cb()

    def synth(self, scratches=None):
        # Takes the scratch list like every other AsmOp.synth, so it can be
        # synthesized uniformly; the real code is emitted by out_cb.
        return []
def litt_type(val):
    """CType of a literal: const char pointer for str, const float or int otherwise."""
    if isinstance(val, str):
        return PointerType(pointed=PodType(pod='char', size=1, const=True))
    pod = 'float' if isinstance(val, float) else 'int'
    return PodType(pod=pod, const=True)
def return_type(type):
    """Return type of function type ``type``."""
    result = type.ret
    return result


def deref_type(type):
    """Pointed-to type of pointer type ``type``."""
    target = type.pointed
    return target
class CcInterp(lark.visitors.Interpreter):
def __init__(self):
self.global_scope = Scope()
self.cur_scope = self.global_scope
self.cur_fun = None
self.funs = []
self.next_reg = 0
self.next_marker = 0
self.strings = {}
self.next_string_token = 0
self.rodata = {}
def _lookup_symbol(self, s):
scope = self.cur_scope
while scope is not None:
if s in scope.symbols:
return scope.symbols[s]
scope = scope.parent
return None
def _get_reg(self):
return self.cur_fun.regs.take()
def _get_marker(self):
mark = self.next_marker
self.next_marker += 1
return mark
@contextlib.contextmanager
def _nosynth(self):
old = self._synth
self._synth = lambda *_: None
yield
self._synth = old
def _synth(self, op):
with self.cur_fun.regs.borrow(op.scratch_need) as scratches:
self._log(f'{op.__class__.__name__}')
self.cur_fun.ops.append(lambda: op.synth(scratches))
def _load(self, ident, reg=None):
s = self._lookup_symbol(ident)
assert s is not None, f'unknown identifier {ident}'
if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
reg = self.cur_fun.regs.take(reg=reg, orwhatever=True)
return SetAddr(self.cur_fun, [reg, s.name])
else:
if s.type.volatile:
self._log(f'loading volatile {s}')
return Load(self.cur_fun, [reg, s])
reg, created = self.cur_fun.regs.load(s)
if created:
return Load(self.cur_fun, [reg, s])
else:
self._log(f'{s} was already in {reg}')
return Reg(reg)
def identifier(self, tree):
# TODO: not actually load the value until it's used in an expression
# could have the op.out() function have a side effect that does that
# if it's an assignment, we need to make a Variable with the proper
# address and assign the register to it.
# If it's volatile, we need to flush it.
def delayed_load():
self._log(f'delay-loading {tree.children[0]}')
reg = getattr(tree, 'request_out', None)
op = self._load(tree.children[0], reg=reg)
self._synth(op)
return op.out
tree.op = Delayed(delayed_load)
tree.var = self._lookup_symbol(tree.children[0])
assert tree.var is not None, f'undefined identifier: {tree.children[0]}'
tree.type = tree.var.type
def _make_string(self, s):
nexttok = self.next_string_token
self.next_string_token += 1
token = f'_str{nexttok}'
self.strings[token] = s
self.global_scope.symbols[token] = Variable(litt_type(s), token)
return token
def literal(self, tree):
imm = tree.children[0]
if isinstance(imm, str):
tok = self._make_string(imm)
tree.children = [tok]
return self.identifier(tree)
def delayed():
if hasattr(tree, 'request_out'):
reg = self.cur_fun.regs.take(reg=tree.request_out, orwhatever=True)
else:
reg = self.cur_fun.regs.take()
assert reg is not None, f'weird: {self.cur_fun.regs}'
op = Set16Imm(self.cur_fun, [reg, imm])
self._synth(op)
return op.out
tree.op = Delayed(delayed)
tree.type = litt_type(tree.children[0])
def array_item(self, tree):
# transform blarg[foo] into *(blarg + foo) because reasons
# TODO: bring this over to the main parser, we need to know the type
# of blarg in order to compute the actual offset
s = self._lookup_symbol(tree.children[0].children[0])
if s.type.pointed.size == 1:
# easy deref + add
tree.children = [lark.Tree('add', tree.children)]
return self.dereference(tree)
# everything else is 16 bit, do something like:
# array[index] === *(array + (index << 1))
add1 = lark.Tree('shl', [tree.children[1], lark.Tree('literal', [1])])
add2 = lark.Tree('add', [tree.children[0], add1])
tree.children = [add2]
return self.dereference(tree)
def _assign(self, left, right):
self._log(f'assigning {left} = {right}')
self.cur_fun.regs.assign(left, right)
op = Store(self.cur_fun, [right, left])
self._synth(op)
return op.out
def assignment(self, tree):
self.visit_children(tree)
val = tree.children[1].op.out
assert val is not None, f'no output!!! {tree.pretty()}'
# left hand side is an lvalue, retrieve it
var = tree.children[0].var
tree.op = self._assign(var, val)
def global_var(self, tree):
self.visit_children(tree)
var = Variable.from_def(tree)
self.global_scope.symbols[var.name] = var
val = 0
if len(tree.children) > 2:
val = tree.children[2].children[0].value
def fun_decl(self, tree):
fun = FunctionSpec(tree.children[0])
self.cur_scope.symbols[fun.name] = fun
def _reg_shuffle(self, src, dst):
def swap(old, new):
if old == new:
return
oldpos = src.index(old)
try:
newpos = src.index(new)
except ValueError:
src[oldpos] = new
else:
src[newpos], src[oldpos] = src[oldpos], src[newpos]
self._synth(Swap(old, new))
for i, r in enumerate(dst):
if r not in src:
#self.cur_fun.regs.take(r) # precious!
self._synth(Move(r, src[i]))
else:
swap(r, src[i])
def _prep_fun_call(self, fn_reg, params):
"""Move all params to r0-rn."""
target_regs = [f'r{i}' for i in range(len(params) + 1)]
self._reg_shuffle(params + [fn_reg], target_regs)
return target_regs[-1]
@contextlib.contextmanager
def _push(self, dontpush=()):
pushed = []
self.cur_fun.regs.evict_all()
with self.cur_fun.tempstack():
for reg in allregs():
if reg not in self.cur_fun.regs.available and reg not in dontpush:
self._log(f'reg in use: {reg}')
stk = self.cur_fun.get_stack()
pushed.append((reg, stk))
self._synth(Push(reg, stk))
self.cur_fun.regs.give(reg)
yield pushed
def fun_call(self, tree):
self.cur_fun.fun_calls += 1
if len(tree.children) == 1:
param_children = []
elif tree.children[1].data == 'comma_expr':
param_children = tree.children[1].children
else:
param_children = [tree.children[1]]
if True:
# pre-allocate output registers
for i, child in enumerate(param_children):
self._log(f'requesting r{i} for {child}')
child.request_out = f'r{i}'
self._log(f'requesting r{len(param_children)} for {tree.children[0]}')
tree.children[0].request_out = f'r{len(param_children)}'
self.visit_children(tree)
fn_reg = tree.children[0].op.out
param_regs = [c.op.out for c in param_children]
self._log(f'calling {tree.children[0].children[0]}({param_regs})')
with self._push(dontpush=param_regs + [fn_reg]) as pushed:
self._flush_deferred() # side effects
fn_reg = self._prep_fun_call(fn_reg, param_regs)
self.cur_fun.regs.reset()
self.cur_fun.regs.take(fn_reg)
for reg in (f'r{i}' for i in range(len(param_regs))):
self.cur_fun.regs.take(reg)
tree.op = FnCall(self.cur_fun, [fn_reg, param_regs])
self._synth(tree.op)
self.cur_fun.regs.reset()
pops = []
for reg, stk in pushed:
self.cur_fun.regs.take(reg)
pops.append(Pop(reg, stk))
ret = 'r0'
if hasattr(tree, 'request_out'):
ret = tree.request_out
self._synth(Move(ret, 'r0'))
nret = self.cur_fun.regs.take(ret, orwhatever=True)
if nret != ret:
self._synth(Move(nret, ret))
tree.op = type('blah', (), {'out': nret})
for op in pops:
self._synth(op)
tree.type = return_type(tree.children[0].type)
def _flush_deferred(self):
if self.cur_fun.deferred_ops:
self._log(f'deferred logic: {len(self.cur_fun.deferred_ops)}')
for op in self.cur_fun.deferred_ops:
self._synth(op)
self.cur_fun.deferred_ops = []
def statement(self, tree):
self.visit_children(tree)
self._flush_deferred()
self.cur_fun.regs.drop_temps()
iter_expression = statement
def if_stat(self, tree):
mark = self._get_marker()
has_else = len(tree.children) > 2
# optimization!!!!
if tree.children[0].data == 'eq':
self.visit_children(tree.children[0])
a, b = (c.op.out for c in tree.children[0].children)
op = IfEq(a, b, mark, has_else)
else:
self.visit(tree.children[0])
op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
self._synth(op)
self.cur_fun.regs.drop_temps()
begin_vars = self.cur_fun.regs.snap()
valid_vars_set = set(begin_vars[0].items())
self.visit(tree.children[1])
valid_vars_set &= set(self.cur_fun.regs.snap()[0].items())
if has_else:
self.cur_fun.regs.restore(begin_vars)
self.cur_fun.ops.append(op.synth_else)
self.visit(tree.children[2])
valid_vars_set &= set(self.cur_fun.regs.snap()[0].items())
valid_vars = dict(valid_vars_set)
valid_snap = (valid_vars, {r: begin_vars[1][r] for r in valid_vars.values()})
self.cur_fun.regs.restore(valid_snap)
self.cur_fun.ops.append(op.synth_endif)
def _restore_vars(self, begin_vars):
curvars = self.cur_fun.regs.vars
# 0. do the loop shuffle
shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars}
dst, src = tuple(zip(*shuffles.items())) or ([], [])
self._reg_shuffle(src, dst)
# 1. load the rest
for v, r in begin_vars.items():
if v in curvars:
continue
self._log(f'loading missing var {v} into {r}')
self._synth(Load(self.cur_fun, [r, v]))
def while_stat(self, tree):
mark = self._get_marker()
self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
begin_vars = dict(self.cur_fun.regs.vars)
self.visit(tree.children[0])
postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[0].op.out, mark)
self._synth(op)
self.cur_fun.regs.drop_temps()
self.visit(tree.children[1])
self._restore_vars(begin_vars)
self.cur_fun.ops.append(op.synth_endwhile)
self.cur_fun.regs.restore(postcond)
def for_stat(self, tree):
mark = self._get_marker()
self.visit(tree.children[0]) # initialization
self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
begin_vars = dict(self.cur_fun.regs.vars)
self.visit(tree.children[1])
postcond = self.cur_fun.regs.snap()
op = WhileOp(tree.children[1].op.out, mark)
self._synth(op)
self.cur_fun.regs.drop_temps()
self.visit(tree.children[3])
self.visit(tree.children[2]) # 3rd statement
self._restore_vars(begin_vars)
self.cur_fun.ops.append(op.synth_endwhile)
self.cur_fun.regs.restore(postcond)
def _unary_op(op):
def _f(self, tree):
self.visit_children(tree)
operand = tree.children[0].op.out
reg = self.cur_fun.regs.take()
tree.op = op(self.cur_fun, [reg, operand])
self._synth(tree.op)
self.cur_fun.regs.give(operand)
return _f
def _deref(self, tree):
reg = tree.children[0].op.out
var = Variable.from_dereference(reg, tree.children[0])
self._log(f'making var {var} from derefing reg {reg}')
tree.var = var
def delayed_load():
self._log(f'delay-loading {tree.children[0]}')
op = Load(self.cur_fun, [reg, var])
self._synth(op)
return op.out
tree.op = Delayed(delayed_load)
tree.type = var.type
def dereference(self, tree):
self.visit_children(tree)
self._deref(tree)
def pointer_access(self, tree):
with self._nosynth():
self.visit_children(tree)
ch_strct, ch_field = tree.children
type = ch_strct.type.pointed
strct = self.global_scope.structs[type.struct]
offs = strct.offset(ch_field)
if offs > 0:
load_offs = lark.Tree('literal', [offs])
addop = lark.Tree('add', [tree.children[0], load_offs])
self.visit(addop)
addop.type = PointerType(pointed=strct.field_type(ch_field))
tree.children = [addop]
else:
tree.children[0].type = PointerType(pointed=strct.field_type(ch_field))
tree.children = [tree.children[0]]
self.visit_children(tree)
self._deref(tree)
def post_increment(self, tree):
self.visit_children(tree)
tree.op = Reg(tree.children[0].op.out)
var = tree.children[0].var
reg = tree.op.out
self.cur_fun.deferred_ops.append(Incr(self.cur_fun, [reg, reg]))
self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
tree.type = var.type
def post_decrement(self, tree):
self.visit_children(tree)
tree.op = Reg(tree.children[0].op.out)
var = tree.children[0].var
reg = tree.op.out
self.cur_fun.deferred_ops.append(Decr(self.cur_fun, [reg, reg]))
self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
tree.type = var.type
def pre_increment(self, tree):
self.visit_children(tree)
reg = tree.children[0].op.out
tree.op = Incr(self.cur_fun, [reg, reg])
self._synth(tree.op)
cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
self.cur_fun.regs.cleanup[reg].append(cleanup)
bool_not = _unary_op(BoolNot)
def _binary_op(op):
def _f(self, tree):
self.visit_children(tree)
left, right = (x.op.out for x in tree.children)
dest = self.cur_fun.regs.take()
tree.op = op(self.cur_fun, [dest, left, right])
self._synth(tree.op)
self.cur_fun.regs.give(left)
self.cur_fun.regs.give(right)
tree.type = tree.children[0].type # because uhm reasons
return _f
def _combo(uop, bop):
def _f(self, tree):
ftree = lark.Tree(bop, tree.children)
tree.children = [ftree]
uop.__get__(self)(tree)
return _f
shl = _binary_op(ShlOp)
add = _binary_op(AddOp)
sub = _binary_op(SubOp)
mul = _binary_op(MulOp)
_and = _binary_op(AndOp)
_or = _binary_op(OrOp)
# ...
gt = _binary_op(GtOp)
lt = _binary_op(LtOp)
neq = _binary_op(NeqOp)
eq = _combo(bool_not, 'neq')
def shr(self, tree):
self.visit(tree.children[0])
left = tree.children[0].op.out
assert tree.children[1].data == 'literal'
right = tree.children[1].children[0]
dest = self.cur_fun.regs.take()
tree.op = ShrOp(self.cur_fun, [dest, left, right])
self._synth(tree.op)
self.cur_fun.regs.give(left)
tree.type = tree.children[0].type # because uhm reasons
def _assign(self, left, right):
self._log(f'assigning {left} = {right}')
self.cur_fun.regs.assign(left, right)
op = Store(self.cur_fun, [right, left])
self._synth(op)
return op.out
def _assign_op(op):
def f(self, tree):
opnode = lark.Tree(op, tree.children)
self.visit(opnode)
val = opnode.op.out
# left hand side is an lvalue, retrieve it
var = tree.children[0].var
tree.op = self._assign(var, val)
return f
or_ass = _assign_op('_or')
add_ass = _assign_op('add')
def _forward_op(self, tree):
self.visit_children(tree)
tree.op = tree.children[0].op
def cast(self, tree):
if hasattr(tree, 'request_out'):
tree.children[1].request_out = tree.request_out
self.visit_children(tree)
tree.op = tree.children[1].op
tree.type = tree.children[0]
def _log(self, line):
self.cur_fun.log(line)
def local_var(self, tree):
self.visit_children(tree)
assert self.cur_fun is not None
assert self.cur_scope is not None
var = Variable.from_def(tree)
var.stackaddr = self.cur_fun.get_stack() # will have to invert
self.cur_scope.symbols[var.name] = var
self.cur_fun.locals[var.name] = var
if len(tree.children) > 2:
initval = tree.children[2].op.out
cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, var]))
self.cur_fun.regs.assign(var, initval, cleanup=cleanup)
self._log(f'assigning {var} = {initval}')
def local_array(self, tree):
    """Compile a local array declaration.

    Only ``const`` arrays are supported. Their initializer words go into
    ``self.rodata`` under a fresh ``_datN`` label. The array name is bound,
    in ``self.global_scope``, to a pointer Variable that names that label.
    A declared size larger than the initializer list is zero-filled.

    NOTE(review): the name is bound in the global scope rather than
    ``self.cur_scope``, so it is visible outside the declaring block;
    confirm this is intended.

    Raises:
        NotImplementedError: for a non-const local array. Previously such a
            declaration was silently dropped, with no storage and no symbol.
        ValueError: if the initializer has more elements than the declared
            array size.
    """
    self.visit_children(tree)
    assert self.cur_fun is not None
    assert self.cur_scope is not None
    decl = Variable.from_def(tree)
    if not decl.type.const:
        raise NotImplementedError(
            f"non-const local array '{tree.children[1]}' is not supported")
    # Const arrays live in rodata, so an initializer list is required.
    assert len(tree.children) > 3
    tok = f'_dat{self.next_string_token}'
    self.next_string_token += 1
    initializers = tree.children[3].children
    assert len(initializers) > 0
    fields = [c.children[0] for c in initializers]
    if tree.children[2].data == 'empty_array':
        # `T name[] = {...}`: the size comes from the initializer.
        fieldcount = len(fields)
    else:
        fieldcount = tree.children[2].children[0]
    if len(fields) > fieldcount:
        raise ValueError(
            f"too many initializers for array '{tree.children[1]}': "
            f"{len(fields)} > {fieldcount}")
    if len(fields) < fieldcount:
        # Zero-fill the remaining elements.
        fields += [0] * (fieldcount - len(fields))
    name = tree.children[1]
    self.rodata[tok] = (tree.children[0], fields)
    var = Variable(PointerType(pointed=litt_type(fields[0])), tok)
    self.global_scope.symbols[name] = var
def fun_def(self, tree):
    """Compile a function definition (prototype, body) into a Function.

    Steps:
      * Register the function's spec under its name in the current scope.
      * Give each parameter a stack slot and mark it as loaded in ``r{i}``,
        with a cleanup callback that stores the register back to it.
      * Attach a scope holding the parameters to the body node.
      * Visit the body.
      * If no ``return`` statement was seen, synthesize an implicit return.
        ``main`` returns 0 and void functions return through ``r0``.

    The finished Function is appended to ``self.funs``.

    Raises:
        ValueError: if a non-void function other than ``main`` contains no
            ``return`` statement. This used to be an always-failing
            ``assert``, which disappears under ``python -O`` and left the
            function with no return at all.
    """
    prot, body = tree.children
    fun = Function(prot)
    assert self.cur_fun is None
    self.cur_fun = fun
    self.cur_scope.symbols[fun.spec.name] = fun.spec
    params = fun.param_dict()

    def getcleanup(var):
        # Factory so each closure binds its own `var`, not the loop variable.
        def cl(reg):
            self._log(f'cleanup for {var} in {reg}')
            self._synth(Store(self.cur_fun, [reg, var]))
        return cl

    for i, param in enumerate(fun.params):
        fun.locals[param.name] = param
        param.stackaddr = fun.get_stack()
        self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}')
        fun.regs.loaded(param, f'r{i}', cleanup=getcleanup(param))
    fun_scope = Scope()
    fun_scope.parent = self.cur_scope
    fun_scope.symbols.update(params)
    body.scope = fun_scope
    self.visit_children(tree)
    if fun.ret is None:
        if fun.spec.name == 'main':
            # Falling off the end of main returns 0.
            self._synth(Set16Imm(fun, ['r0', 0]))
            self._synth(ReturnReg(fun, ['r0']))
        elif fun.spec.return_type == VoidType():
            self._synth(ReturnReg(fun, ['r0']))
        else:
            raise ValueError(
                f"non-void function '{fun.spec.name}' has no return statement")
    self.cur_fun = None
    self.funs.append(fun)
def body(self, tree):
    """Visit a block body inside its own scope.

    Uses the scope already attached to the node (``fun_def`` attaches one
    holding the parameters), or a fresh Scope otherwise. The scope is linked
    under the current scope for the duration of the visit, then the previous
    scope is restored.
    """
    if hasattr(tree, 'scope'):
        inner = tree.scope
    else:
        inner = Scope()
    inner.parent = self.cur_scope
    self.cur_scope = inner
    self.visit_children(tree)
    self.cur_scope = inner.parent
def return_stat(self, tree):
    """Synthesize a return for a ``return`` statement.

    Marks the current function as having an explicit return. For
    ``return expr;`` the expression's output register is returned. A bare
    ``return;`` has no expression child, and used to crash on
    ``children[0]``. It now returns through ``r0``, the same way ``fun_def``
    ends a void function.
    """
    assert self.cur_fun is not None
    self.cur_fun.ret = True
    self.visit_children(tree)
    # An omitted optional expression may appear either as no child at all
    # or as a None placeholder, depending on how the grammar spells it.
    expr = tree.children[0] if tree.children else None
    out_reg = 'r0' if expr is None else expr.op.out
    tree.op = ReturnReg(self.cur_fun, [out_reg])
    self._synth(tree.op)
def struct_def(self, tree):
    """Build a Struct from the definition and register it by name in the innermost scope."""
    self.visit_children(tree)
    new_struct = Struct(tree)
    scope = self.cur_scope
    scope.structs[new_struct.name] = new_struct
# Startup code emitted at the `_start` label.
preamble = [
    '_start:',
    # Clear r0 and r1 (x ^ x == 0).
    'xor r0, r0, r0',
    'xor r1, r1, r1',
    # Stack pointer setup.
    'set sp, 0',
    f'seth sp, {0x11}',  # 256 bytes of stack ought to be enough
    # Call main: r2 <- main, lr <- pc + 0 (presumably the return address;
    # depends on the ISA's pc semantics), then pc <- r2 | r2 == r2.
    'set r2, main',
    'set r3, 0',
    'add lr, pc, r3',
    'or pc, r2, r2',
    # After main returns, branch back forever.
    'cmp r0, r0',
    'beq [pc, -4] // loop forever',
]
def filter_dupes(ops):
    """Yield the assembly lines in `ops`, skipping register self-moves.

    A line ``or rN, rN, rN`` writes ``rN | rN`` back into ``rN`` and leaves
    the register's contents unchanged, so it is dropped. Any other line is
    passed through unchanged.

    The pattern must match from the start of the line, ignoring leading
    whitespace, and the last operand must end on a word boundary. The
    previous unanchored ``search`` also matched inside
    ``xor r0, r0, r0`` (a register-clearing instruction) and
    ``or r1, r1, r12``, silently deleting real instructions.
    """
    dupe_re = re.compile(r'\s*or (r\d+), \1, \1\b')
    for op in ops:
        if dupe_re.match(op):
            continue
        yield op
def _word_directives(words):
    """Format integers as ``.word 0xNNNN`` lines, masked to 16 bits.

    Masking turns a negative value into its 16-bit two's-complement form
    (e.g. -1 -> 0xffff) instead of the invalid ``0x-001``.
    """
    return [f'.word 0x{w & 0xffff:04x}' for w in words]


def _string_words(text):
    """Return `text` as NUL-terminated UTF-8, packed into big-endian 16-bit words.

    Encoding happens before padding, so the even-length check is done on the
    byte count. Checking the character count breaks for non-ASCII text.
    """
    data = text.encode() + b'\0'
    if len(data) % 2:
        data += b'\0'
    return struct.unpack(f'>{len(data) // 2}H', data)


def parse_tree(tree, debug=False):
    """Compile a lark parse tree into assembly text.

    The tree is rewritten with CcTransform (and pretty-printed when `debug`)
    and then interpreted by CcInterp. The output contains, in order:
      * each function's synthesized code;
      * each string literal as a ``.global`` label plus NUL-terminated words;
      * each rodata table as a ``.global`` label plus its words.
    Register self-moves are removed by ``filter_dupes``.
    """
    tree = CcTransform().transform(tree)
    if debug:
        print(tree.pretty())
    interp = CcInterp()
    interp.visit(tree)
    out = []
    for fun in interp.funs:
        out += fun.synth()
        out.append('')
    for sym, text in interp.strings.items():
        out += [f'.global {sym}', f'{sym}:']
        out += _word_directives(_string_words(text))
        out.append('')
    # The declared element type is not needed to emit raw words.
    for sym, (_elem_type, fields) in interp.rodata.items():
        out += [f'.global {sym}', f'{sym}:']
        out += _word_directives(fields)
        out.append('')
    return '\n'.join(filter_dupes(out))
def larkparse(f, debug=False):
    """Parse C source read from file-like `f` and compile it to assembly.

    The grammar is loaded from GRAMMAR_FILE on every call. `f` may yield
    either str or bytes; bytes are decoded before parsing.
    """
    with open(GRAMMAR_FILE) as grammar_file:
        grammar = grammar_file.read()
    parser = lark.Lark(grammar)
    source = f.read()
    if isinstance(source, bytes):
        source = source.decode()
    return parse_tree(parser.parse(source), debug=debug)
def parse_args():
    """Parse the compiler's command-line arguments.

    `input` is optional and defaults to standard input. Previously it lacked
    ``nargs='?'``, so argparse always required it and the documented stdin
    default could never take effect.
    """
    parser = argparse.ArgumentParser(description='Compile.')
    parser.add_argument('--debug', action='store_true',
                        help='print the AST')
    parser.add_argument('--assembly', '-S', action='store_true',
                        help='output assembly')
    parser.add_argument('--compile', '-c', action='store_true',
                        help='compile a single file')
    parser.add_argument('input', nargs='?', type=argparse.FileType('r'),
                        default=sys.stdin, help='input file (default: stdin)')
    parser.add_argument('--output', '-o', type=argparse.FileType('wb'),
                        default=sys.stdout.buffer, help='output file')
    parser.add_argument('-I', dest='include_dirs', action='append',
                        default=[], help='include dirs')
    return parser.parse_args()
def preprocess(fin, include_dirs):
    """Run the C preprocessor (CPP) over `fin` and return its output as text.

    Args:
        fin: the source to preprocess. An object with a usable ``fileno()``
            (a real file or stdin) is handed to the child process directly.
            Other readable objects (e.g. ``io.StringIO``) are read and piped
            in.
        include_dirs: directories passed to the preprocessor as ``-I<dir>``.

    Returns:
        An ``io.StringIO`` holding the preprocessed source.

    Raises:
        RuntimeError: if the preprocessor exits with a non-zero status.
    """
    cmd = list(CPP) + [f'-I{d}' for d in include_dirs]
    try:
        fin.fileno()
    except (AttributeError, io.UnsupportedOperation):
        # In-memory streams have no file descriptor to hand to the child.
        data = fin.read()
        if isinstance(data, str):
            data = data.encode()
        result = subprocess.run(cmd, input=data, stdout=subprocess.PIPE)
    else:
        result = subprocess.run(cmd, stdin=fin, stdout=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(
            f'preprocessor error (exit status {result.returncode})')
    return io.StringIO(result.stdout.decode())
def assemble(text, fout):
    """Parse assembly `text` with the `as` module and write the object file to `fout`."""
    parsed = asmod.larkparse(io.StringIO(text))
    asmod.write_obj(fout, *parsed)
def main():
    """Command-line entry point: preprocess, compile, then emit the result.

    With ``-S`` the assembly text plus a trailing newline is written to the
    output. Otherwise the assembly is assembled into an object file.
    Files that argparse opened are closed, and so flushed, before
    returning. The standard streams are left open.
    """
    args = parse_args()
    try:
        source = preprocess(args.input, args.include_dirs)
        assy = larkparse(source, debug=args.debug)
        if args.assembly:
            args.output.write(assy.encode() + b'\n')
        else:
            assemble(assy, args.output)
    finally:
        # argparse.FileType opens named files eagerly and never closes them.
        if args.input is not sys.stdin:
            args.input.close()
        if args.output is not sys.stdout.buffer:
            args.output.close()
# Run the compiler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()