synth/tools/cc.py

1544 lines
45 KiB
Python
Raw Normal View History

import argparse
import ast
2021-02-17 21:20:30 +00:00
import collections
import contextlib
from dataclasses import dataclass
import importlib
2021-03-21 04:27:26 +00:00
import inspect
import io
import os
2021-02-17 21:20:30 +00:00
import re
import struct
import subprocess
2021-02-17 21:20:30 +00:00
import sys
import lark
asmod = importlib.import_module("as")
_HERE = os.path.dirname(__file__)
GRAMMAR_FILE = os.path.join(_HERE, 'cc.ebnf')
CPP = ('cpp', '-P')
2021-02-17 21:20:30 +00:00
class Scope:
    """One level of symbol table, linked to its enclosing scope via `parent`."""

    def __init__(self):
        # Enclosing scope; None until a caller links this scope to another one.
        self.parent = None
        # Struct definitions, looked up by struct name.
        self.structs = dict()
        # Declared identifiers, looked up by name.
        self.symbols = dict()
2021-02-17 21:20:30 +00:00
class Variable:
    """A typed, named value; `addr_reg` optionally names the register holding its address."""

    def __init__(self, type, name, addr_reg=None):
        self.type, self.name, self.addr_reg = type, name, addr_reg

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<Var: {} {}>'.format(self.type, self.name)

    @classmethod
    def from_def(cls, tree):
        """Build a variable from a definition node whose first two children are type and name."""
        declared_type, declared_name = tree.children[0], tree.children[1]
        return cls(declared_type, declared_name)

    @classmethod
    def from_dereference(cls, reg, tree):
        """Build the anonymous variable pointed to by `tree`, whose address sits in `reg`."""
        pointee_type = tree.type.pointed
        return cls(pointee_type, 'deref', addr_reg=reg)
2021-02-17 21:20:30 +00:00
# - all registers pointing to unwritten stuff can be dropped as soon as we're
# done with them:
# - already fed them into the operations
# - end of statement
# - maybe should have a separate storage for special registers?
# need ways to:
# - assign a list of registers into r0, ... rn for function call
# - ... and run all callbacks
# - get the register for an identifier
# - dereference an expression:
# - essentially turn a temp into an lvalue-address
# - if read, need a second reg for its value
# - mark a variable to have its address in a register
# - delay reading the identifier if it's lhs
# - store the value to memory when registers are claimed
# - if it was modified through:
# - assignment
# - pre-post *crement
# - retrieve the memory type:
# - stack
# - absolute
# - register
# - or just store the value when the variable is assigned
# - get a temporary register
# - return it
# - I guess dereferencing is an upgrade
def type(self, tree):
    # Debug stub: does nothing but print the tree it receives.
    # NOTE(review): takes `self`, so presumably a method, but the owning class
    # cannot be determined from this view -- confirm before relying on it.
    print(tree)
2021-03-21 04:27:26 +00:00
def allregs():
    """Return an iterator over the allocatable register names, r0 through r11."""
    return (f'r{i}' for i in range(12))


class RegBank:
    """Bookkeeping for register allocation.

    A register is in exactly one of three states: free (listed in
    `available`), bound to a variable (`vars` / `varregs`), or taken as an
    anonymous temporary (in neither).  `cleanup` holds per-register callbacks
    that run when the variable living in that register is evicted.
    """

    def __init__(self, logger=None):
        self.reset()
        self.log = logger or print

    def reset(self):
        """Forget every allocation; all registers become free again."""
        self.available = list(allregs())
        self.vars = {}      # variable -> register
        self.varregs = {}   # register -> variable
        self.cleanup = collections.defaultdict(list)
        self.taken = {}     # register -> caller frame info, for diagnostics

    def __str__(self):
        return f'regs(avail={self.available}, vars={self.vars})'

    def take(self, reg=None, orwhatever=False):
        """Reserve `reg` (or any free register when None) and return its name.

        With `orwhatever=True`, a requested register that is busy as a
        temporary is silently replaced by another free one.
        """
        out = self._take(reg=reg, orwhatever=orwhatever)
        # Remember who took it so that conflicts can be reported later.
        self.taken[out] = inspect.stack()[1]
        return out

    def _take(self, reg=None, orwhatever=False):
        if reg is not None:
            if reg not in self.available:
                if reg in self.varregs:
                    # Held by a variable: evicting it frees the register.
                    self.evict(self.var(reg))
                elif orwhatever:
                    return self._take()
                else:
                    # .get(): the register may never have gone through take(),
                    # and the diagnostic itself must not raise KeyError.
                    raise RuntimeError(
                        f'trying to overtake {reg}, '
                        f'varregs: {self.varregs}, vars: {self.vars}, '
                        f'taken in {self.taken.get(reg)}')
            if reg in self.available:
                self.available.remove(reg)
            return reg
        if not self.available:
            assert self.vars, "nothing to clean, no more regs :/"
            # storing one random var
            self.evict(next(iter(self.vars)))
        return self.available.pop(0)

    def give(self, reg):
        """Return a temporary register to the free list.

        Registers bound to a variable are ignored (use evict()), and so are
        registers that are already free: listing a register twice would let
        it be handed out to two users at once.
        """
        assert reg is not None
        if reg in self.varregs:
            # Need to call evict() with the var to free it.
            return
        if reg in self.available:
            return
        self.available.insert(0, reg)

    def loaded(self, var, reg, cleanup=None):
        """Tells the regbank some variable was loaded into the given register."""
        assert reg is not None
        assert var is not None
        self.vars[var] = reg
        self.varregs[reg] = var
        self.take(reg)
        if cleanup is not None:
            self.log(f'recording cleanup for {reg}({var})')
            self.cleanup[reg].append(cleanup)

    def stored(self, var):
        """Tells the regbank the given var was stored to memory, register can be freed."""
        assert var in self.vars
        reg = self.vars.pop(var)
        del self.varregs[reg]
        self.give(reg)

    def load(self, var):
        """Returns the reg associated with the var, or a new reg if none was,
        and True if the var was created, False otherwise."""
        self.log(f'vars: {self.vars}, varregs: {self.varregs}')
        if var not in self.vars:
            reg = self.take()
            self.vars[var] = reg
            self.varregs[reg] = var
            return reg, True
        return self.vars[var], False

    def assign(self, var, reg, cleanup=None):
        """Assign a previously-used register to a variable."""
        assert reg is not None
        if var in self.vars:
            self.stored(var)
        if reg in self.varregs:
            for cb in self.cleanup.pop(reg, []):
                cb(reg)
            self.stored(self.varregs[reg])
        self.vars[var] = reg
        self.varregs[reg] = var
        # stored() above may have put `reg` back on the free list; a register
        # bound to a variable must never be handed out by take().
        if reg in self.available:
            self.available.remove(reg)
        if cleanup is not None:
            self.cleanup[reg].append(cleanup)

    def var(self, reg):
        """Return the variable bound to `reg`, or None."""
        return self.varregs.get(reg, None)

    def evict_all(self):
        """Evict every variable currently held in a register."""
        for var in list(self.vars):
            self.evict(var)

    def evict(self, var):
        """Runs var callbacks & frees the register."""
        assert var in self.vars, f'trying to evict {var}'
        self.log(f'evicting {var}')
        if var not in self.vars:
            # Only reachable when assertions are disabled.
            return
        reg = self.vars[var]
        for cb in self.cleanup.pop(reg, []):
            cb(reg)
        self.stored(var)

    def drop_temps(self):
        """Release every register that is neither free nor bound to a variable."""
        for reg in allregs():
            if reg not in self.available and reg not in self.varregs:
                self.give(reg)

    def swapped(self, reg0, reg1):
        """Update the bookkeeping after the contents of two registers were exchanged."""
        var0 = self.varregs.get(reg0, None)
        var1 = self.varregs.get(reg1, None)
        if var0 is not None:
            self.stored(var0)
        elif reg0 not in self.available:
            self.give(reg0)
        if var1 is not None:
            self.stored(var1)
        elif reg1 not in self.available:
            self.give(reg1)
        if var0 is not None:
            self.loaded(var0, reg1)
        if var1 is not None:
            self.loaded(var1, reg0)

    def snap(self):
        """Return a (vars, cleanup) snapshot that later changes to the bank do not affect."""
        cleanup = collections.defaultdict(list)
        cleanup.update({r: list(c) for r, c in self.cleanup.items()})
        return dict(self.vars), cleanup

    def restore(self, snap):
        """Reset the bank and re-create the variable bindings recorded by snap()."""
        self.reset()
        vardict, cleanup = snap
        for var, reg in vardict.items():
            assert reg is not None, f'bad snapshot: {snap}'
            self.loaded(var, reg)
            # Copy: the snapshot may be restored again later and must not see
            # callbacks appended after this restore.
            self.cleanup[reg] = list(cleanup.get(reg, []))

    @contextlib.contextmanager
    def borrow(self, howmany):
        """Lend `howmany` registers for the duration of a `with` block."""
        regs = [self.take() for _ in range(howmany)]
        try:
            yield regs
        finally:
            # Hand them back even if the body raised.
            for reg in regs:
                self.give(reg)
class FunctionSpec:
    """Signature of a function, read from a prototype node.

    The prototype's children are: return type, name, then one node per
    parameter whose first child is the parameter type.
    """

    def __init__(self, fun_prot):
        self.return_type = fun_prot.children[0]
        self.name = fun_prot.children[1]
        self.param_types = [x.children[0] for x in fun_prot.children[2:]]

    def __repr__(self):
        # Parameter types are not necessarily strings, so convert explicitly;
        # str.join() raises TypeError on anything else.
        params = ', '.join(str(t) for t in self.param_types)
        return f'<Function: {self.return_type} {self.name}({params})>'

    @property
    def type(self):
        """The function's type as a FunType built from return and parameter types."""
        return FunType(ret=self.return_type, params=self.param_types)
2021-02-17 21:20:30 +00:00
class Function:
    """Compilation state of one function: parameters, stack layout, registers, queued ops."""

    def __init__(self, fun_prot):
        self.locals = {}
        self.spec = FunctionSpec(fun_prot)
        # Children after the return type and the name are the parameter nodes.
        self.params = [Variable(*x.children) for x in fun_prot.children[2:]]
        self.ret = None
        self.nextstack = 0   # next free stack offset
        self.topstack = 0    # high-water mark left behind by tempstack() regions
        self.statements = []
        self.regs = RegBank(logger=self.log)
        self.deferred_ops = []
        self.fun_calls = 0
        self.ops = []        # zero-argument callables, each returning a list of asm lines

    def log(self, line):
        """Queue `line` for emission as a `//` comment among the generated ops."""
        self.ops.append(lambda: [f'// {line}'])

    @property
    def stack_usage(self):
        """Stack space needed, plus 2 when the function makes calls (synth() saves lr at [sp, -2])."""
        top_usage = max(self.topstack, self.nextstack)
        if self.fun_calls > 0:
            return top_usage + 2
        return top_usage

    def get_stack(self, size=2):
        """Reserve `size` units of stack and return the offset of the reservation."""
        stk = self.nextstack
        self.nextstack += size
        return stk

    @contextlib.contextmanager
    def tempstack(self):
        """Release, on exit, every stack slot reserved inside the `with` block.

        The peak usage is remembered in `topstack` so stack_usage still
        accounts for it.
        """
        stk = self.nextstack
        try:
            yield
        finally:
            # Restore the allocation pointer even if the body raised.
            if self.nextstack > self.topstack:
                self.topstack = self.nextstack
            self.nextstack = stk

    def param_dict(self):
        """Return the parameters keyed by name."""
        return {p.name: p for p in self.params}

    def __repr__(self):
        return repr(self.spec)

    def synth(self):
        """Return the function's assembly lines: label, prologue, then every queued op."""
        ops = []
        if self.fun_calls > 0:
            ops += ['store lr, [sp, -2]']
        if self.stack_usage > 0:
            ops += [f'set r4, {self.stack_usage}',
                    'sub sp, sp, r4']
        for op in self.ops:
            ops += op()
        # NOTE(review): both branches produce the same text here; presumably
        # labels and instructions were meant to get different indentation -- confirm.
        # endswith() also copes with an empty line, which x[-1] would not.
        indented = [f' {x}' if x.endswith(':') else f' {x}' for x in ops]
        return [f'.global {self.spec.name}',
                f'{self.spec.name}:'] + indented
class Struct:
    """A struct definition: a name plus an ordered list of fields."""

    def __init__(self, struct_def):
        name_node, body_node = struct_def.children[0], struct_def.children[1]
        self.name = name_node
        self.fields = [Variable(*member.children) for member in body_node.children]

    def offset(self, field):
        """Return the offset of `field`: two units per preceding field.

        Raises ValueError when the struct has no such field.
        """
        names = [member.name for member in self.fields]
        position = names.index(field)
        return 2 * position

    def field_type(self, field):
        """Return the type of `field`, or None when the struct has no such field."""
        matches = (member.type for member in self.fields if member.name == field)
        return next(matches, None)
@dataclass
class CType:
    """Base of all C types: qualifiers plus a size (2 unless stated otherwise)."""
    volatile: bool = False
    const: bool = False
    size: int = 2


@dataclass
class StructType(CType):
    """Reference to a struct by name."""
    struct: str = ''


@dataclass
class PointerType(CType):
    """Pointer to `pointed`."""
    pointed: CType = None


@dataclass
class FunType(CType):
    """Function type: return type and parameter types."""
    ret: CType = None
    params: [CType] = None


@dataclass
class PodType(CType):
    """Plain scalar type identified by its C name."""
    pod: str = 'int'
    # Annotated so it is a real dataclass field: without the annotation it was
    # a bare class attribute, left out of __init__, __repr__ and __eq__.
    signed: bool = True


@dataclass
class VoidType(CType):
    """The `void` type."""
    pass
2021-02-17 21:20:30 +00:00
class CcTransform(lark.visitors.Transformer):
    """Tree rewrite pass: turns type nodes into CType objects, converts
    terminals to Python values and folds arithmetic on literals."""

    def _binary_op(litt):
        """Build a rule callback that folds `litt` over two literal operands."""
        @lark.v_args(tree=True)
        def _f(self, tree):
            lhs, rhs = tree.children
            if lhs.data != 'literal' or rhs.data != 'literal':
                # At least one operand is not a constant: leave the node alone.
                return tree
            tree.data = 'literal'
            tree.children = [litt(lhs.children[0], rhs.children[0])]
            return tree
        return _f

    def field(self, children):
        [name] = children
        return name

    def struct_type(self, children):
        [struct_name] = children
        return StructType(struct=struct_name)

    def pointer(self, children):
        # The pointed-to type is always the last child.
        return PointerType(volatile='volatile' in children, pointed=children[-1])

    def funptr_type(self, children):
        ret, *params = children
        return FunType(params=params, ret=ret)

    def comma_expr(self, children):
        """Flatten nested comma expressions into one node holding every operand."""
        head, tail = children
        if head.data != 'comma_expr':
            return lark.Tree('comma_expr', [head, tail])
        head.children.append(tail)
        return head

    volatile = lambda *_: 'volatile'
    const = lambda *_: 'const'

    def type(self, children):
        """Build the type object for a (qualified) type node."""
        is_volatile = 'volatile' in children
        is_const = 'const' in children
        base = children[-1]
        if not isinstance(base, str):
            # Already a type object (struct, pointer, ...): pass it through.
            return base
        if base == 'void':
            return VoidType()
        width = 1 if base == 'char' else 2
        return PodType(volatile=is_volatile, const=is_const, pod=base, size=width)

    # operations on literals can be done by the compiler
    add = _binary_op(lambda a, b: a+b)
    sub = _binary_op(lambda a, b: a-b)
    mul = _binary_op(lambda a, b: a*b)
    shl = _binary_op(lambda a, b: a<<b)
    CHARACTER = lambda _, x: ord(ast.literal_eval(x)[0])
    IDENTIFIER = str
    SIGNED_NUMBER = int
    HEX_LITTERAL = lambda _, x: int(x[2:], 16)
    ESCAPED_STRING = lambda _, x: ast.literal_eval(x)
    INT = int
2021-02-17 21:20:30 +00:00
class AsmOp:
    """Base class of the synthesizable operations; on its own it emits one `nop`."""

    # How many scratch registers synth() expects to receive.
    scratch_need = 0

    @property
    def out(self):
        """Register holding the result; the base operation has none."""
        return None

    def synth(self, scratches):
        """Return the assembly lines of this operation."""
        return ['nop']
class Reg:
    """Stand-in operation for a value that already sits in a register; emits no code."""

    scratch_need = 0

    def __init__(self, reg):
        # `out` is a plain attribute here, mirroring the `out` property of the ops.
        self.out = reg

    def synth(self, scratches):
        return list()
class BinOp(AsmOp):
    """Base of the two-operand operations; `ops` is (dest, left, right)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        dest, left, right = ops
        self.fun = fun
        self.dest = dest
        self.left = left
        self.right = right

    @property
    def out(self):
        """The destination register."""
        return self.dest
def make_cpu_bin_op(cpu_op):
    """Return a BinOp subclass emitting the single instruction `cpu_op dest, left, right`."""
    class _C(BinOp):
        def synth(self, _):
            line = '{} {}, {}, {}'.format(cpu_op, self.dest, self.left, self.right)
            return [line]
    return _C
# Binary operations that map onto exactly one three-operand CPU instruction.
AddOp = make_cpu_bin_op('add')
SubOp = make_cpu_bin_op('sub')
MulOp = make_cpu_bin_op('mul')
# no div
# no mod either
AndOp = make_cpu_bin_op('and')
OrOp = make_cpu_bin_op('or')
XorOp = make_cpu_bin_op('xor')
ShrOp = make_cpu_bin_op('shr') # 2nd operand needs to be a literal
2021-02-17 21:20:30 +00:00
class ShlOp(BinOp):
    """Left shift done in software: copies `left` into `dest`, then runs a
    counted loop that adds `dest` to itself."""

    scratch_need = 2

    def synth(self, scratches):
        counter, one = scratches
        dest = self.dest
        lines = [
            f'set {one}, 1',
            f'or {dest}, {self.left}, {self.left}',   # dest = left
            f'sub {counter}, {self.right}, {one}',
            'beq [pc, 4]',
            f'add {dest}, {dest}, {dest}',            # dest doubles each round
            f'sub {counter}, {counter}, {one}',
            'bneq [pc, -8]',
        ]
        return lines
2021-02-17 21:20:30 +00:00
class LtOp(BinOp):
    """Comparison producing 0 or 1 in `dest` from `left - right`."""

    scratch_need = 1

    def synth(self, scratches):
        diff = scratches[0]
        dest = self.dest
        return [
            f'set {dest}, 0',
            f'sub {diff}, {self.left}, {self.right}',
            'bneq [pc, 0]',
            f'set {dest}, 1',
        ]
class GtOp(LtOp):
    """Greater-than, expressed as LtOp with the two operands exchanged."""

    def __init__(self, fun, ops):
        dest, left, right = ops
        swapped = [dest, right, left]
        super().__init__(fun, swapped)
class UnOp(AsmOp):
    """Base of the one-operand operations; `ops` is (dest, operand)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        dest, operand = ops
        self.fun = fun
        self.dest = dest
        self.operand = operand

    @property
    def out(self):
        """The destination register."""
        return self.dest
class Incr(UnOp):
    """dest = operand + 1, with the result also copied back into operand."""

    scratch_need = 1

    def synth(self, scratches):
        one = scratches[0]
        dest, operand = self.dest, self.operand
        return [
            f'set {one}, 1',
            f'add {dest}, {operand}, {one}',
            f'or {operand}, {dest}, {dest}',   # write the new value back
        ]
2021-02-23 04:34:17 +00:00
class Decr(UnOp):
    """dest = operand - 1, with the result also copied back into operand."""

    scratch_need = 1

    def synth(self, scratches):
        one = scratches[0]
        dest, operand = self.dest, self.operand
        return [
            f'set {one}, 1',
            f'sub {dest}, {operand}, {one}',
            f'or {operand}, {dest}, {dest}',   # write the new value back
        ]
2021-02-17 21:20:30 +00:00
class NotOp(UnOp):
    """Emits a single `not dest, operand` instruction."""

    def synth(self, scratches):
        line = 'not {}, {}'.format(self.dest, self.operand)
        return [line]
class BoolNot(UnOp):
    """Logical negation producing 0 or 1 in `dest` by comparing the operand against 0."""

    def synth(self, scratches):
        dest = self.dest
        return [
            f'set {dest}, 0',
            f'cmp {dest}, {self.operand}',
            'bneq [pc, 0]',
            f'set {dest}, 1',
        ]
class NeqOp(BinOp):
    """Inequality as a subtraction: the result is non-zero exactly when the operands differ."""

    def synth(self, scratches):
        line = 'sub {}, {}, {}'.format(self.dest, self.left, self.right)
        return [line]
class FnCall(AsmOp):
    """Call through a register: sets `lr` from `pc`, then copies the target into `pc`.

    `ops` is (register holding the callee address, parameter registers).
    """

    scratch_need = 1

    def __init__(self, fun, ops):
        dest_fn, params = ops
        self.fun = fun
        self.dest_fn = dest_fn
        self.params = params

    @property
    def out(self):
        """The result of a call is always read from r0."""
        return 'r0'

    def synth(self, scratches):
        zero = scratches[0]
        target = self.dest_fn
        return [
            f'set {zero}, 0',
            f'add lr, pc, {zero}',
            f'or pc, {target}, {target}',
        ]
class ReturnReg(AsmOp):
    """Function epilogue: releases the stack frame, moves the result into r0 and returns."""

    scratch_need = 1

    def __init__(self, fun, ops):
        (ret_reg,) = ops
        self.fun = fun
        self.ret_reg = ret_reg

    def synth(self, scratches):
        fun = self.fun
        frame = fun.stack_usage
        (tmp,) = scratches
        lines = []
        if frame > 0:
            # Undo the `sub sp` done by the prologue.
            lines.append(f'set {tmp}, {frame}')
            lines.append(f'add sp, sp, {tmp}')
        if fun.ret is not None:
            result = self.ret_reg
            lines.append(f'or r0, {result}, {result}')
        if fun.fun_calls != 0:
            # lr was saved on the stack by the prologue; jump through that slot.
            lines.append('load pc, [sp, -2] // return')
        else:
            lines.append('or pc, lr, lr')
        return lines
2021-02-17 21:20:30 +00:00
class Load(AsmOp):
    """Read a variable into a register; `ops` is (dest register, variable).

    Locals are read from the stack, dereferenced values through their address
    register, anything else through its symbol name.
    """

    scratch_need = 0

    def __init__(self, fun, ops):
        dest, var = ops
        self.fun = fun
        self.dest = dest
        self.var = var
        fun.log(f'8bit load? {self.var.type} {self.is8bit}')

    @property
    def is8bit(self):
        """True when the variable is a plain scalar with an 8-bit C type."""
        var_type = self.var.type
        return isinstance(var_type, PodType) and var_type.pod in ('uint8_t', 'int8_t', 'char')

    def _maybecast8bit(self, reg):
        """Return the extra `shr ..., 8` line used for 8-bit values, or nothing."""
        return [f'shr {reg}, {reg}, 8'] if self.is8bit else []

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        dest, var = self.dest, self.var
        if var.name in self.fun.locals:
            # no 8bit cast to/from stack
            return [f'load {dest}, [sp, {var.stackaddr}]']
        if var.addr_reg is not None:
            return [f'load {dest}, [{var.addr_reg}]'] + self._maybecast8bit(dest)
        # unsure if we need a cast here
        return [
            f'set {dest}, {var.name}',
            'nop // in case we load a far global',
            f'load {dest}, [{dest}]',
        ]
2021-02-17 21:20:30 +00:00
2021-03-21 04:27:26 +00:00
class Push(AsmOp):
    """Save a register into the stack slot at offset `stackaddr`."""

    def __init__(self, reg, stackaddr):
        self.reg, self.stackaddr = reg, stackaddr

    def synth(self, scratches):
        line = 'store {}, [sp, {}]'.format(self.reg, self.stackaddr)
        return [line]
class Pop(AsmOp):
    """Reload a register from the stack slot at offset `stackaddr`."""

    def __init__(self, reg, stackaddr):
        self.reg, self.stackaddr = reg, stackaddr

    @property
    def out(self):
        return self.reg

    def synth(self, scratches):
        line = 'load {}, [sp, {}]'.format(self.reg, self.stackaddr)
        return [line]
2021-02-17 21:20:30 +00:00
class Store(AsmOp):
    """Write a register back to a variable; `ops` is (source register, variable).

    Mirrors Load: locals go to the stack, dereferenced values through their
    address register, anything else through its symbol name.
    """

    scratch_need = 1

    def __init__(self, fun, ops):
        src, var = ops
        self.fun = fun
        self.src = src
        self.var = var

    def synth(self, scratches):
        (addr,) = scratches
        src, var = self.src, self.var
        if var.name in self.fun.locals:
            slot = var.stackaddr
            self.fun.log(f'storing {var}({src}) to [sp, {slot}]')
            return [f'store {src}, [sp, {slot}]']
        if var.addr_reg is not None:
            return [f'store {src}, [{var.addr_reg}]']
        # The scratch register carries the symbol's address.
        return [
            f'set {addr}, {var.name}',
            'nop // you know, in case blah',
            f'store {src}, [{addr}]',
        ]
class Assign(AsmOp):
    """Register-to-register copy; `ops` is (source, destination)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        src, var = ops
        self.fun = fun
        self.src = src
        self.var = var

    @property
    def out(self):
        return self.var

    def synth(self, scratches):
        line = 'or {}, {}, {}'.format(self.var, self.src, self.src)
        return [line]
class SetAddr(AsmOp):
    """Put the address of symbol `ident` into a register; `ops` is (dest, ident)."""

    scratch_need = 0

    def __init__(self, fun, ops):
        dest, ident = ops
        self.dest = dest
        self.ident = ident

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        return [
            'set {}, {}'.format(self.dest, self.ident),
            'nop // placeholder for a long address',
        ]
class Set16Imm(AsmOp):
    """Load a 16-bit immediate; `ops` is (dest, value).

    The low byte goes through `set`; `seth` follows only when the high byte
    is non-zero.
    """

    scratch_need = 0

    def __init__(self, fun, ops):
        dest, imm16 = ops
        self.dest = dest
        self.imm16 = imm16

    @property
    def out(self):
        return self.dest

    def synth(self, scratches):
        dest = self.dest
        high = (self.imm16 >> 8) & 0xff
        low = self.imm16 & 0xff
        lines = [f'set {dest}, {low}']
        if high != 0:
            lines.append(f'seth {dest}, {high}')
        return lines
class Move(AsmOp):
    """Copy register `src` into `dst`, encoded as `or dst, src, src`."""

    scratch_need = 0

    def __init__(self, dst, src):
        self.dst, self.src = dst, src

    def synth(self, scratches):
        line = 'or {}, {}, {}'.format(self.dst, self.src, self.src)
        return [line]
2021-02-17 21:20:30 +00:00
class Swap(AsmOp):
    """Exchange two registers through one scratch register."""

    scratch_need = 1

    def __init__(self, a0, a1):
        self.a0, self.a1 = a0, a1

    def synth(self, scratches):
        (tmp,) = scratches
        first, second = self.a0, self.a1
        return [
            f'or {tmp}, {first}, {first}',
            f'or {first}, {second}, {second}',
            f'or {second}, {tmp}, {tmp}',
        ]
class IfOp(AsmOp):
    """Conditional on a register; `op` is (condition register, label id, has_else).

    synth() emits the test, synth_else() the jump over the else part plus its
    label, synth_endif() the closing label.
    """

    scratch_need = 1

    def __init__(self, fun, op):
        cond, mark, has_else = op
        self.fun = fun
        self.cond = cond
        self.has_else = has_else
        self.then_mark = f'_then_{mark}'
        self.else_mark = f'_else_{mark}'
        self.endif_mark = f'_endif_{mark}'

    def synth(self, scratches):
        zero = scratches[0]
        # With no else part, a false condition skips straight to the end.
        skip_to = self.else_mark if self.has_else else self.endif_mark
        return [
            f'set {zero}, 0',
            f'cmp {zero}, {self.cond}',  # flag if cond == 0
            f'beq {skip_to}',
        ]

    def synth_else(self):
        return [
            'cmp r0, r0',  # trick because beq is better than "or pc, ."
            f'beq {self.endif_mark}',
            f'{self.else_mark}:',
        ]

    def synth_endif(self):
        return [self.endif_mark + ':']
class IfEq(IfOp):
    """Conditional comparing two registers directly, so no scratch register is needed."""

    scratch_need = 0

    def __init__(self, a, b, mark, has_else):
        self.a, self.b = a, b
        self.has_else = has_else
        self.then_mark = f'_then_{mark}'
        self.else_mark = f'_else_{mark}'
        self.endif_mark = f'_endif_{mark}'

    def synth(self, scratches):
        # With no else part, unequal operands skip straight to the end.
        skip_to = self.else_mark if self.has_else else self.endif_mark
        return [
            f'cmp {self.a}, {self.b}',
            f'bneq {skip_to}',
        ]
2021-02-17 21:20:30 +00:00
class WhileOp(AsmOp):
    """Loop test on a condition register.

    synth_loop() gives the label at the top of the loop, synth() the exit
    test, synth_endwhile() the jump back plus the exit label.
    """

    scratch_need = 1

    @staticmethod
    def synth_loop(mark):
        return ['_loop_{}:'.format(mark)]

    def __init__(self, cond, mark):
        self.cond = cond
        self.loop_mark = f'_loop_{mark}'
        self.endwhile_mark = f'_endwhile_{mark}'

    def synth(self, scratches):
        zero = scratches[0]
        return [
            f'set {zero}, 0',
            f'cmp {zero}, {self.cond}',
            f'beq {self.endwhile_mark}',
        ]

    def synth_endwhile(self):
        return [
            'cmp r0, r0',
            f'beq {self.loop_mark}',
            f'{self.endwhile_mark}:',
        ]
class Delayed(AsmOp):
    """Operation whose output register is computed only when `out` is read.

    Each read of `out` calls `out_cb`, which is where the work (and any code
    emission) happens; the operation itself emits nothing.
    """

    def __init__(self, out_cb):
        self.out_cb = out_cb

    @property
    def out(self):
        return self.out_cb()

    def synth(self, scratches=None):
        # Accepts `scratches` like every other operation so that it can be
        # passed through the same synth(scratches) call path.
        return []
def litt_type(val):
    """Return the constant type of a literal value.

    Strings are pointers to const char, floats are 'float', and everything
    else is treated as 'int'.
    """
    if isinstance(val, str):
        char_type = PodType(pod='char', size=1, const=True)
        return PointerType(pointed=char_type)
    pod_name = 'float' if isinstance(val, float) else 'int'
    return PodType(pod=pod_name, const=True)
def return_type(type):
    """Return the `ret` attribute (the return type) of a function type."""
    result = type.ret
    return result
def deref_type(type):
    """Return the `pointed` attribute (the pointee type) of a pointer type."""
    pointee = type.pointed
    return pointee
2021-02-17 21:20:30 +00:00
class CcInterp(lark.visitors.Interpreter):
def __init__(self):
    """Start with an empty global scope and no function being compiled."""
    # Scoping: compilation starts in the global scope.
    root = Scope()
    self.global_scope = root
    self.cur_scope = root
    # Functions.
    self.cur_fun = None
    self.funs = []
    # Counters used to hand out unique numbers.
    self.next_reg = 0
    self.next_marker = 0
    self.next_string_token = 0
    # Collected data.
    self.strings = {}
    self.rodata = {}
2021-02-17 21:20:30 +00:00
def _lookup_symbol(self, s):
scope = self.cur_scope
while scope is not None:
if s in scope.symbols:
return scope.symbols[s]
scope = scope.parent
return None
def _get_reg(self):
    """Take a register from the current function's register bank."""
    bank = self.cur_fun.regs
    return bank.take()
def _get_marker(self):
mark = self.next_marker
self.next_marker += 1
return mark
@contextlib.contextmanager
def _nosynth(self):
old = self._synth
self._synth = lambda *_: None
yield
self._synth = old
2021-02-17 21:20:30 +00:00
def _synth(self, op):
    """Queue `op` on the current function, lending it the scratch registers it asks for.

    The scratch registers go back to the bank as soon as the op is queued;
    the assembly text is produced later from the recorded register names.
    """
    fun = self.cur_fun
    with fun.regs.borrow(op.scratch_need) as scratches:
        self._log(op.__class__.__name__)

        def emit():
            return op.synth(scratches)

        fun.ops.append(emit)
2021-03-21 04:27:26 +00:00
def _load(self, ident, reg=None):
    """Return the operation that makes `ident` available in a register.

    Functions and globals yield their address (SetAddr).  Volatile variables
    are re-read every time.  Other variables are read once and then served
    from the register the bank associates with them.  `reg` is the preferred
    destination register, honoured when possible.
    """
    s = self._lookup_symbol(ident)
    assert s is not None, f'unknown identifier {ident}'
    regs = self.cur_fun.regs
    if isinstance(s, FunctionSpec) or ident in self.global_scope.symbols:
        reg = regs.take(reg=reg, orwhatever=True)
        return SetAddr(self.cur_fun, [reg, s.name])
    if s.type.volatile:
        self._log(f'loading volatile {s}')
        # Reserve a destination: without this the load targeted `None` when no
        # register was requested, and an unreserved one when it was.
        reg = regs.take(reg=reg, orwhatever=True)
        return Load(self.cur_fun, [reg, s])
    reg, created = regs.load(s)
    if created:
        return Load(self.cur_fun, [reg, s])
    self._log(f'{s} was already in {reg}')
    return Reg(reg)
def identifier(self, tree):
    """Annotate an identifier node with its variable, its type and a lazy load.

    No code is emitted here: the load happens the first time `tree.op.out`
    is read.
    """
    # TODO: not actually load the value until it's used in an expression
    # could have the op.out() function have a side effect that does that
    # if it's an assignment, we need to make a Variable with the proper
    # address and assign the register to it.
    # If it's volatile, we need to flush it.
    def delayed_load():
        self._log(f'delay-loading {tree.children[0]}')
        # A parent node may have asked for the value in a specific register.
        wanted = getattr(tree, 'request_out', None)
        load_op = self._load(tree.children[0], reg=wanted)
        self._synth(load_op)
        return load_op.out

    tree.op = Delayed(delayed_load)
    symbol = self._lookup_symbol(tree.children[0])
    tree.var = symbol
    assert tree.var is not None, f'undefined identifier: {tree.children[0]}'
    tree.type = symbol.type
2021-02-17 21:20:30 +00:00
def _make_string(self, s):
    """Register string literal `s` under a fresh `_strN` global symbol and return that name."""
    index = self.next_string_token
    self.next_string_token = index + 1
    token = f'_str{index}'
    self.strings[token] = s
    # Declaring it as a global lets the literal be referenced like any identifier.
    self.global_scope.symbols[token] = Variable(litt_type(s), token)
    return token
2021-03-21 04:25:01 +00:00
def literal(self, tree):
    """Annotate a literal node with its type and a lazy immediate load.

    String literals are turned into references to a generated global symbol
    and then handled like identifiers.
    """
    imm = tree.children[0]
    if isinstance(imm, str):
        tree.children = [self._make_string(imm)]
        return self.identifier(tree)

    def delayed():
        regs = self.cur_fun.regs
        if not hasattr(tree, 'request_out'):
            reg = regs.take()
        else:
            # A parent node asked for the value in a specific register.
            reg = regs.take(reg=tree.request_out, orwhatever=True)
        assert reg is not None, f'weird: {self.cur_fun.regs}'
        set_op = Set16Imm(self.cur_fun, [reg, imm])
        self._synth(set_op)
        return set_op.out

    tree.op = Delayed(delayed)
    tree.type = litt_type(tree.children[0])
2021-02-17 21:20:30 +00:00
def array_item(self, tree):
    """Rewrite `base[index]` into a dereference of `base + offset` and compile that.

    The first child must be an identifier node: its name is looked up to
    learn the element size.
    """
    # transform blarg[foo] into *(blarg + foo) because reasons
    # TODO: bring this over to the main parser, we need to know the type
    # of blarg in order to compute the actual offset
    s = self._lookup_symbol(tree.children[0].children[0])
    if s.type.pointed.size == 1:
        # easy deref + add
        tree.children = [lark.Tree('add', tree.children)]
        return self.dereference(tree)
    # everything else is 16 bit, do something like:
    # array[index] === *(array + (index << 1))
    add1 = lark.Tree('shl', [tree.children[1], lark.Tree('literal', [1])])
    add2 = lark.Tree('add', [tree.children[0], add1])
    tree.children = [add2]
    return self.dereference(tree)
2021-02-17 21:20:30 +00:00
def _assign(self, left, right):
    """Bind register `right` to variable `left` and queue the store to memory.

    Returns the `out` of the queued Store operation.
    NOTE(review): Store defines no `out` of its own in the visible code, so
    this looks like it is always None -- confirm that callers expect that.
    """
    self._log(f'assigning {left} = {right}')
    fun = self.cur_fun
    fun.regs.assign(left, right)
    store = Store(fun, [right, left])
    self._synth(store)
    return store.out
2021-02-17 21:20:30 +00:00
def assignment(self, tree):
    """Compile `lhs = rhs`: evaluate the right side, then assign it to the left side's variable."""
    self.visit_children(tree)
    target, source = tree.children[0], tree.children[1]
    val = source.op.out
    assert val is not None, f'no output!!! {tree.pretty()}'
    # The left side is an lvalue: only its variable is needed; reading its
    # `op.out` would trigger a load of the old value.
    tree.op = self._assign(target.var, val)
2021-02-17 21:20:30 +00:00
def global_var(self, tree):
    """Declare a global variable in the global scope."""
    self.visit_children(tree)
    declared = Variable.from_def(tree)
    self.global_scope.symbols[declared.name] = declared
    # NOTE(review): the initial value is read but nothing visible here uses
    # it, so initializers appear to be ignored -- confirm.
    has_initializer = len(tree.children) > 2
    val = tree.children[2].children[0].value if has_initializer else 0
def fun_decl(self, tree):
    """Record a function prototype in the current scope under the function's name."""
    prototype = tree.children[0]
    spec = FunctionSpec(prototype)
    self.cur_scope.symbols[spec.name] = spec
def _reg_shuffle(self, src, dst):
    """Emit the moves and swaps that bring the value in `src[i]` into `dst[i]`.

    `src` and `dst` are equally long sequences of register names; neither is
    modified.
    """
    # Work on a private list: the bookkeeping below assigns by index, which
    # fails on tuples and would otherwise alter the caller's sequence.
    src = list(src)

    def swap(old, new):
        if old == new:
            return
        oldpos = src.index(old)
        try:
            newpos = src.index(new)
        except ValueError:
            src[oldpos] = new
        else:
            src[newpos], src[oldpos] = src[oldpos], src[newpos]
        self._synth(Swap(old, new))

    for i, r in enumerate(dst):
        if r not in src:
            #self.cur_fun.regs.take(r) # precious!
            # The target holds no pending source value: a plain copy is enough.
            self._synth(Move(r, src[i]))
        else:
            swap(r, src[i])
def _prep_fun_call(self, fn_reg, params):
    """Move all params to r0-rn.

    The callee address goes into the register right after the last parameter;
    that register's name is returned.
    """
    count = len(params)
    target_regs = ['r%d' % i for i in range(count + 1)]
    self._reg_shuffle(params + [fn_reg], target_regs)
    return target_regs[count]
@contextlib.contextmanager
def _push(self, dontpush=()):
    """Save every busy register to temporary stack slots around a `with` block.

    All variables are evicted first; each register that is still busy and not
    listed in `dontpush` is then stored to a fresh slot and released.  Yields
    the list of (register, stack offset) pairs that were saved; popping them
    is left to the caller.
    """
    pushed = []
    self.cur_fun.regs.evict_all()
    # The slots live inside a tempstack() region, so they are released when
    # the block ends.
    with self.cur_fun.tempstack():
        for reg in allregs():
            if reg not in self.cur_fun.regs.available and reg not in dontpush:
                self._log(f'reg in use: {reg}')
                stk = self.cur_fun.get_stack()
                pushed.append((reg, stk))
                self._synth(Push(reg, stk))
                self.cur_fun.regs.give(reg)
        yield pushed
2021-02-17 21:20:30 +00:00
def fun_call(self, tree):
    """Compile a function call.

    Children: the callee expression, optionally followed by one argument or
    a comma_expr node holding several.  Arguments are requested in r0..rn-1
    and the callee address in rn; other busy registers are saved around the
    call and reloaded afterwards.  The result is taken from r0.
    """
    self.cur_fun.fun_calls += 1
    if len(tree.children) == 1:
        param_children = []
    elif tree.children[1].data == 'comma_expr':
        param_children = tree.children[1].children
    else:
        param_children = [tree.children[1]]
    if True:
        # pre-allocate output registers
        for i, child in enumerate(param_children):
            self._log(f'requesting r{i} for {child}')
            child.request_out = f'r{i}'
        self._log(f'requesting r{len(param_children)} for {tree.children[0]}')
        tree.children[0].request_out = f'r{len(param_children)}'
    self.visit_children(tree)
    # Reading `op.out` triggers the delayed loads of callee and arguments.
    fn_reg = tree.children[0].op.out
    param_regs = [c.op.out for c in param_children]
    self._log(f'calling {tree.children[0].children[0]}({param_regs})')
    with self._push(dontpush=param_regs + [fn_reg]) as pushed:
        self._flush_deferred() # side effects
        fn_reg = self._prep_fun_call(fn_reg, param_regs)
        # Rebuild the bank so that only the call's own registers are reserved.
        self.cur_fun.regs.reset()
        self.cur_fun.regs.take(fn_reg)
        for reg in (f'r{i}' for i in range(len(param_regs))):
            self.cur_fun.regs.take(reg)
        tree.op = FnCall(self.cur_fun, [fn_reg, param_regs])
        self._synth(tree.op)
        # After the call only the saved registers are reserved again.
        self.cur_fun.regs.reset()
        pops = []
        for reg, stk in pushed:
            self.cur_fun.regs.take(reg)
            pops.append(Pop(reg, stk))
        ret = 'r0'
        if hasattr(tree, 'request_out'):
            ret = tree.request_out
            self._synth(Move(ret, 'r0'))
        # If the wanted register is about to be reloaded, park the result elsewhere.
        nret = self.cur_fun.regs.take(ret, orwhatever=True)
        if nret != ret:
            self._synth(Move(nret, ret))
        # Minimal stand-in object exposing only `out`.
        tree.op = type('blah', (), {'out': nret})
        for op in pops:
            self._synth(op)
    tree.type = return_type(tree.children[0].type)
2021-02-17 21:20:30 +00:00
2021-03-21 04:27:26 +00:00
def _flush_deferred(self):
    """Queue every deferred operation of the current function, then clear the list."""
    fun = self.cur_fun
    pending = fun.deferred_ops
    if pending:
        self._log(f'deferred logic: {len(pending)}')
    for op in pending:
        self._synth(op)
    fun.deferred_ops = []
2021-03-21 04:27:26 +00:00
def statement(self, tree):
    """Compile one statement, then run its deferred operations and release its temporaries."""
    self.visit_children(tree)
    # Deferred operations (queued by post-increment and post-decrement) are
    # emitted once the whole statement has been compiled.
    self._flush_deferred()
    self.cur_fun.regs.drop_temps()

# An iteration expression is compiled exactly like a statement.
iter_expression = statement
def if_stat(self, tree):
    """Compile `if (cond) then [else ...]`.

    Children: condition, then-branch, optional else-branch.  After the
    statement, only variables that sit in the same register at the start and
    at the end of every branch stay bound to a register.
    """
    mark = self._get_marker()
    has_else = len(tree.children) > 2
    # optimization!!!!
    if tree.children[0].data == 'eq':
        # Compare the two operands directly instead of computing a boolean first.
        self.visit_children(tree.children[0])
        a, b = (c.op.out for c in tree.children[0].children)
        op = IfEq(a, b, mark, has_else)
    else:
        self.visit(tree.children[0])
        op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
    self._synth(op)
    self.cur_fun.regs.drop_temps()
    # Register state on entry: the else-branch must start from it too.
    begin_vars = self.cur_fun.regs.snap()
    valid_vars_set = set(begin_vars[0].items())
    self.visit(tree.children[1])
    valid_vars_set &= set(self.cur_fun.regs.snap()[0].items())
    if has_else:
        self.cur_fun.regs.restore(begin_vars)
        self.cur_fun.ops.append(op.synth_else)
        self.visit(tree.children[2])
        valid_vars_set &= set(self.cur_fun.regs.snap()[0].items())
    valid_vars = dict(valid_vars_set)
    valid_snap = (valid_vars, {r: begin_vars[1][r] for r in valid_vars.values()})
    self.cur_fun.regs.restore(valid_snap)
    self.cur_fun.ops.append(op.synth_endif)
def _restore_vars(self, begin_vars):
curvars = self.cur_fun.regs.vars
# 0. do the loop shuffle
shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars}
dst, src = tuple(zip(*shuffles.items())) or ([], [])
self._reg_shuffle(src, dst)
# 1. load the rest
for v, r in begin_vars.items():
if v in curvars:
continue
self._log(f'loading missing var {v} into {r}')
self._synth(Load(self.cur_fun, [r, v]))
2021-02-17 21:20:30 +00:00
def while_stat(self, tree):
    """Compile `while (cond) body`; children are the condition and the body."""
    mark = self._get_marker()
    self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
    # Register assignment at the loop label: the end of the body must
    # re-create it before jumping back.
    begin_vars = dict(self.cur_fun.regs.vars)
    self.visit(tree.children[0])
    # State right after the condition, which is what the code following the
    # loop sees.
    postcond = self.cur_fun.regs.snap()
    op = WhileOp(tree.children[0].op.out, mark)
    self._synth(op)
    self.cur_fun.regs.drop_temps()
    self.visit(tree.children[1])
    self._restore_vars(begin_vars)
    self.cur_fun.ops.append(op.synth_endwhile)
    self.cur_fun.regs.restore(postcond)
2021-02-17 21:20:30 +00:00
def for_stat(self, tree):
    """Compile `for (init; cond; step) body`.

    Children, in order: init, condition, step, body.  The body is compiled
    before the step so that the step runs at the end of each iteration.
    """
    mark = self._get_marker()
    self.visit(tree.children[0]) # initialization
    self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
    # Register assignment at the loop label: the end of each iteration must
    # re-create it before jumping back.
    begin_vars = dict(self.cur_fun.regs.vars)
    self.visit(tree.children[1])
    # State right after the condition, which is what the code following the
    # loop sees.
    postcond = self.cur_fun.regs.snap()
    op = WhileOp(tree.children[1].op.out, mark)
    self._synth(op)
    self.cur_fun.regs.drop_temps()
    self.visit(tree.children[3])
    self.visit(tree.children[2]) # 3rd statement
    self._restore_vars(begin_vars)
    self.cur_fun.ops.append(op.synth_endwhile)
    self.cur_fun.regs.restore(postcond)
2021-02-17 21:20:30 +00:00
def _unary_op(op):
    """Build the visitor method for a one-operand node compiled with operation class `op`."""
    def _f(self, tree):
        self.visit_children(tree)
        bank = self.cur_fun.regs
        operand = tree.children[0].op.out
        dest = bank.take()
        tree.op = op(self.cur_fun, [dest, operand])
        self._synth(tree.op)
        # The operand register is not needed once the result exists.
        bank.give(operand)
    return _f
def _deref(self, tree):
    """Annotate `tree` as the value pointed to by its first child.

    The pointer is evaluated right away; reading the pointed-to value is
    delayed until `tree.op.out` is read, and reuses the pointer's register.
    """
    pointer = tree.children[0]
    reg = pointer.op.out
    var = Variable.from_dereference(reg, pointer)
    self._log(f'making var {var} from derefing reg {reg}')
    tree.var = var

    def delayed_load():
        self._log(f'delay-loading {tree.children[0]}')
        load = Load(self.cur_fun, [reg, var])
        self._synth(load)
        return load.out

    tree.op = Delayed(delayed_load)
    tree.type = var.type
def dereference(self, tree):
    """Compile `*expr`: evaluate the pointer expression, then mark the node as its target."""
    self.visit_children(tree)
    return self._deref(tree)
def pointer_access(self, tree):
    """Compile `ptr->field` as a dereference of `ptr + offset of field`.

    Children: the struct pointer expression and the field name.
    """
    # Dry run with emission disabled, only to learn the pointer's type.
    with self._nosynth():
        self.visit_children(tree)
    ch_strct, ch_field = tree.children
    type = ch_strct.type.pointed
    strct = self.global_scope.structs[type.struct]
    offs = strct.offset(ch_field)
    if offs > 0:
        # Replace the children with `ptr + offs`, typed as pointer to the field.
        load_offs = lark.Tree('literal', [offs])
        addop = lark.Tree('add', [tree.children[0], load_offs])
        self.visit(addop)
        addop.type = PointerType(pointed=strct.field_type(ch_field))
        tree.children = [addop]
    else:
        # First field: the struct pointer itself points at it.
        tree.children[0].type = PointerType(pointed=strct.field_type(ch_field))
        tree.children = [tree.children[0]]
        self.visit_children(tree)
    self._deref(tree)
2021-02-17 21:20:30 +00:00
def post_increment(self, tree):
    """Handle a post-increment expression.

    The expression's value is the operand's current register; the increment
    and the store back to the variable are queued on
    ``cur_fun.deferred_ops`` rather than synthesized here.
    """
    self.visit_children(tree)
    target = tree.children[0]
    tree.op = Reg(target.op.out)
    var = target.var
    reg = tree.op.out
    self.cur_fun.deferred_ops.append(Incr(self.cur_fun, [reg, reg]))
    self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
    tree.type = var.type
2021-02-17 21:20:30 +00:00
2021-02-23 04:34:17 +00:00
def post_decrement(self, tree):
    """Handle a post-decrement expression.

    The expression's value is the operand's current register; the decrement
    and the store back to the variable are queued on
    ``cur_fun.deferred_ops`` rather than synthesized here.
    """
    self.visit_children(tree)
    target = tree.children[0]
    tree.op = Reg(target.op.out)
    var = target.var
    reg = tree.op.out
    self.cur_fun.deferred_ops.append(Decr(self.cur_fun, [reg, reg]))
    self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
    tree.type = var.type
2021-02-23 04:34:17 +00:00
2021-03-21 04:27:26 +00:00
def pre_increment(self, tree):
    """Handle a pre-increment expression.

    The operand's register is incremented in place right away, and a cleanup
    callback is registered on that register to store the new value back to
    the operand's variable.
    """
    self.visit_children(tree)
    target = tree.children[0]
    reg = target.op.out
    # The variable must be bound here: the cleanup closure below refers to
    # it, and without this binding the callback fails with NameError.
    var = target.var
    tree.op = Incr(self.cur_fun, [reg, reg])
    self._synth(tree.op)

    def cleanup(reg):
        return self._synth(Store(self.cur_fun, [reg, var]))

    self.cur_fun.regs.cleanup[reg].append(cleanup)
    # Same typing rule as post_increment / post_decrement.
    tree.type = var.type
2021-02-17 21:20:30 +00:00
# Visitor for boolean negation, generated by _unary_op from the BoolNot operation.
bool_not = _unary_op(BoolNot)
def _binary_op(op):
    """Build a visitor method for a two-operand operation.

    ``op`` is called as ``op(cur_fun, [dest, left, right])``; the returned
    function is meant to be bound as a method of the visitor.
    """
    def _f(self, tree):
        self.visit_children(tree)
        # Exactly two children are expected; anything else fails to unpack.
        lhs, rhs = [child.op.out for child in tree.children]
        dest = self.cur_fun.regs.take()
        tree.op = op(self.cur_fun, [dest, lhs, rhs])
        self._synth(tree.op)
        # Both operand registers are handed back after the op is emitted.
        self.cur_fun.regs.give(lhs)
        self.cur_fun.regs.give(rhs)
        # The result simply inherits the left operand's type.
        tree.type = tree.children[0].type

    return _f
def _combo(uop, bop):
    """Build a visitor that applies unary visitor ``uop`` to a ``bop`` node.

    ``bop`` is the node name of the inner operation and ``uop`` is an unbound
    visitor function.
    """
    def _f(self, tree):
        # Push the original operands down into a new inner node...
        inner = lark.Tree(bop, tree.children)
        tree.children = [inner]
        # ...then run the unary visitor, bound to this instance, on the
        # rewritten tree.
        bound = uop.__get__(self)
        bound(tree)

    return _f
# Visitors for two-operand operators, each generated from its operation class.
shl = _binary_op(ShlOp)

add = _binary_op(AddOp)
sub = _binary_op(SubOp)
mul = _binary_op(MulOp)
_and = _binary_op(AndOp)
_or = _binary_op(OrOp)

# Comparisons.
gt = _binary_op(GtOp)
lt = _binary_op(LtOp)
neq = _binary_op(NeqOp)

# Equality is expressed as the boolean negation of a 'neq' node.
eq = _combo(bool_not, 'neq')
2021-02-17 21:20:30 +00:00
2021-07-26 06:55:36 +00:00
def shr(self, tree):
    """Handle a right shift whose shift amount is a literal.

    Only the left operand is visited; the right operand must be a
    ``literal`` node and its value is passed to ``ShrOp`` directly.
    """
    value_node, amount_node = tree.children[0], tree.children[1]
    self.visit(value_node)
    left = value_node.op.out
    assert amount_node.data == 'literal'
    right = amount_node.children[0]
    dest = self.cur_fun.regs.take()
    tree.op = ShrOp(self.cur_fun, [dest, left, right])
    self._synth(tree.op)
    self.cur_fun.regs.give(left)
    # The result inherits the left operand's type.
    tree.type = tree.children[0].type
def _assign(self, left, right):
    """Assign ``right`` to variable ``left`` and synthesize the store.

    Returns the ``out`` of the synthesized ``Store`` operation.
    """
    self._log(f'assigning {left} = {right}')
    # Record the assignment in the register tracker before emitting the store.
    self.cur_fun.regs.assign(left, right)
    store = Store(self.cur_fun, [right, left])
    self._synth(store)
    return store.out
def _assign_op(op):
    """Build a visitor for a compound assignment.

    ``op`` is the node name of the underlying binary operation.
    """
    def f(self, tree):
        # Evaluate the binary operation on a synthetic node that shares the
        # assignment's children.
        binary = lark.Tree(op, tree.children)
        self.visit(binary)
        result = binary.op.out
        # The left-hand side is an lvalue; assign the result to its variable.
        target = tree.children[0].var
        tree.op = self._assign(target, result)

    return f
# Compound-assignment visitors, built on the '_or' and 'add' node visitors.
or_ass = _assign_op('_or')
add_ass = _assign_op('add')
2021-02-17 21:20:30 +00:00
def _forward_op(self, tree):
    """Visit the children and adopt the first child's op as this node's op."""
    self.visit_children(tree)
    first = tree.children[0]
    tree.op = first.op
def cast(self, tree):
    """Handle a cast: children are the target type and the expression.

    The node takes its op from the expression and its type from the first
    child.  A ``request_out`` set on the node is passed on to the expression
    before it is visited.
    """
    try:
        wanted = tree.request_out
    except AttributeError:
        pass
    else:
        tree.children[1].request_out = wanted

    self.visit_children(tree)
    tree.op = tree.children[1].op
    tree.type = tree.children[0]
2021-02-17 21:20:30 +00:00
def _log(self, line):
    """Forward ``line`` to the log of the function currently being compiled."""
    current = self.cur_fun
    current.log(line)
def local_var(self, tree):
    """Declare a local variable, with an optional initializer.

    The variable gets a stack slot and is registered in both the current
    scope and the current function's locals.  If a third child is present it
    is the initializer: its result register is assigned to the variable,
    with a cleanup callback that stores the register back to the variable.
    """
    self.visit_children(tree)
    assert self.cur_fun is not None
    assert self.cur_scope is not None

    var = Variable.from_def(tree)
    var.stackaddr = self.cur_fun.get_stack()  # will have to invert
    self.cur_scope.symbols[var.name] = var
    self.cur_fun.locals[var.name] = var

    # No initializer: nothing more to do.
    if len(tree.children) <= 2:
        return

    initval = tree.children[2].op.out

    def cleanup(reg):
        return self._synth(Store(self.cur_fun, [reg, var]))

    self.cur_fun.regs.assign(var, initval, cleanup=cleanup)
    self._log(f'assigning {var} = {initval}')
def local_array(self, tree):
    """Declare a local array; const arrays are emitted as read-only data.

    For a const array the initializer values are recorded in ``self.rodata``
    under a fresh ``_dat<N>`` symbol, and the array name is bound in the
    global scope to a pointer variable referring to that symbol.

    NOTE(review): the nesting was reconstructed from a copy of the file
    without indentation; everything after the const check is taken to belong
    to it because it depends on names bound there.
    """
    self.visit_children(tree)
    assert self.cur_fun is not None
    assert self.cur_scope is not None
    var = Variable.from_def(tree)
    if var.type.const:
        # storing in rodata
        assert len(tree.children) > 3
        index = self.next_string_token
        self.next_string_token += 1
        tok = f'_dat{index}'

        initializers = tree.children[3].children
        assert len(initializers) > 0
        fields = [item.children[0] for item in initializers]

        size_node = tree.children[2]
        if size_node.data == 'empty_array':
            # No explicit size: the initializer list determines it.
            fieldcount = len(fields)
        else:
            fieldcount = size_node.children[0]
            # Explicit size larger than the initializer list: zero-fill.
            missing = fieldcount - len(fields)
            if missing > 0:
                fields += [0] * missing

        name = tree.children[1]
        self.rodata[tok] = (tree.children[0], fields)
        var = Variable(PointerType(pointed=litt_type(fields[0])), tok)
        self.global_scope.symbols[name] = var
2021-02-17 21:20:30 +00:00
def fun_def(self, tree):
    """Compile a function definition (prototype + body).

    Registers the function in the current scope, gives every parameter a
    stack slot and marks it as loaded in register ``r<i>``, visits the body
    in a new scope holding the parameters, and emits an implicit return when
    the body contained no return statement.

    NOTE(review): the if/elif/else nesting at the end was reconstructed from
    a copy of the file without indentation.
    """
    prot, body = tree.children
    fun = Function(prot)
    assert self.cur_fun is None
    self.cur_fun = fun
    self.cur_scope.symbols[fun.spec.name] = fun.spec
    params = fun.param_dict()

    def make_store_back(var):
        # One closure per parameter, so each callback keeps its own `var`.
        def store_back(reg):
            self._log(f'cleanup for {var} in {reg}')
            self._synth(Store(self.cur_fun, [reg, var]))
        return store_back

    for i, param in enumerate(fun.params):
        fun.locals[param.name] = param
        param.stackaddr = fun.get_stack()
        self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}')
        fun.regs.loaded(param, f'r{i}', cleanup=make_store_back(param))

    # The body is visited in its own scope, pre-populated with the parameters.
    fun_scope = Scope()
    fun_scope.parent = self.cur_scope
    fun_scope.symbols.update(params)
    body.scope = fun_scope
    self.visit_children(tree)

    if fun.ret is None:
        # No return statement was seen in the body.
        if fun.spec.name == 'main':
            # main returns 0 by default
            self._synth(Set16Imm(fun, ['r0', 0]))
            self._synth(ReturnReg(fun, ['r0']))
        elif fun.spec.return_type == VoidType():
            self._synth(ReturnReg(fun, ['r0']))
        else:
            # Any other function must have returned explicitly; this fails.
            assert fun.ret is not None

    self.cur_fun = None
    self.funs.append(fun)
def body(self, tree):
    """Visit a block inside its own scope.

    Uses ``tree.scope`` when one was attached to the node, otherwise a fresh
    ``Scope``.  The scope is chained to the current one for the duration of
    the visit and the current scope is reset to the parent afterwards.
    """
    fresh = Scope()
    scope = getattr(tree, 'scope', fresh)
    scope.parent = self.cur_scope
    self.cur_scope = scope
    self.visit_children(tree)
    # Leave the block: step back to the enclosing scope.
    self.cur_scope = scope.parent
def return_stat(self, tree):
    """Handle a return statement: evaluate the expression and return it.

    Also flags the current function as containing an explicit return.
    """
    fun = self.cur_fun
    assert fun is not None
    fun.ret = True
    self.visit_children(tree)
    # The first child is the returned expression; its register is returned.
    value_reg = tree.children[0].op.out
    tree.op = ReturnReg(self.cur_fun, [value_reg])
    self._synth(tree.op)
def struct_def(self, tree):
    """Register a struct definition in the current scope, keyed by its name."""
    self.visit_children(tree)
    definition = Struct(tree)
    self.cur_scope.structs[definition.name] = definition
2021-02-17 21:20:30 +00:00
# Startup assembly: clears r0/r1, sets up sp, calls main and then spins.
preamble = [
    '_start:',
    'xor r0, r0, r0',
    'xor r1, r1, r1',
    'set sp, 0',
    f'seth sp, {0x11}',  # 256 bytes of stack ought to be enough
    'set r2, main',
    'set r3, 0',
    'add lr, pc, r3',
    'or pc, r2, r2',
    'cmp r0, r0',
    'beq [pc, -4] // loop forever',
]
def filter_dupes(ops):
    """Yield the lines of ``ops`` minus no-op ``or rN, rN, rN`` instructions.

    An ``or`` of a register with itself into itself changes nothing, so such
    lines are dropped.  The pattern is anchored on word boundaries so that
    only that exact instruction matches: ``xor r0, r0, r0`` (which contains
    ``or r0, r0, r0`` as a substring) and ``or r1, r1, r10`` (whose last
    operand merely starts with ``r1``) are real instructions and are kept.
    """
    dupe_re = re.compile(r'\bor (r\d+), \1, \1\b')
    for op in ops:
        if dupe_re.search(op):
            continue
        yield op
def parse_tree(tree, debug=False):
    """Compile a parsed tree into assembly text.

    The tree is run through ``CcTransform`` and then ``CcInterp``.  The
    output consists of the synthesized functions, followed by the string
    literals and the read-only data tables as ``.word`` directives, with
    no-op instructions removed by ``filter_dupes``.

    ``debug`` prints the transformed tree before code generation.
    Returns the assembly as a single newline-joined string.
    """
    tree = CcTransform().transform(tree)
    if debug:
        print(tree.pretty())

    inte = CcInterp()
    inte.visit(tree)

    out = []
    for fun in inte.funs:
        out += fun.synth()
        out.append('')

    for sym, dat in inte.strings.items():
        out += [f'.global {sym}',
                f'{sym}:']
        # Terminate and pad on the *encoded* bytes: padding the str before
        # encoding yields an odd byte count for multi-byte characters, which
        # makes the 16-bit unpack below fail.
        raw = dat.encode() + b'\0'
        if len(raw) % 2 != 0:
            raw += b'\0'
        nwords = len(raw) // 2
        # Emitted as big-endian 16-bit words.
        out += [f'.word 0x{d:04x}' for d in struct.unpack(f'>{nwords}H', raw)]
        out.append('')

    for sym, (type, fields) in inte.rodata.items():
        out += [f'.global {sym}',
                f'{sym}:']
        out += [f'.word 0x{d:04x}' for d in fields]
        out.append('')

    return '\n'.join(filter_dupes(out))
def larkparse(f, debug=False):
    """Parse the source read from file object ``f`` and compile it.

    The grammar is loaded from ``GRAMMAR_FILE``.  ``f.read()`` may return
    either ``str`` or ``bytes``; bytes are decoded first.  Returns whatever
    ``parse_tree`` returns for the parsed tree.
    """
    with open(GRAMMAR_FILE) as grammar_file:
        grammar = grammar_file.read()
    parser = lark.Lark(grammar)

    source = f.read()
    if isinstance(source, bytes):
        source = source.decode()

    return parse_tree(parser.parse(source), debug=debug)
def parse_args():
    """Parse the command line and return the resulting namespace.

    Attributes: ``debug``, ``assembly``, ``compile``, ``input`` (open text
    file), ``output`` (open binary file) and ``include_dirs`` (list of
    ``-I`` directories).
    """
    parser = argparse.ArgumentParser(description='Compile.')
    parser.add_argument('--debug', action='store_true',
                        help='print the AST')
    parser.add_argument('--assembly', '-S', action='store_true',
                        help='output assembly')
    parser.add_argument('--compile', '-c', action='store_true',
                        help='compile a single file')
    # nargs='?' makes the positional optional; without it argparse treats the
    # argument as required and the documented stdin default is never used.
    parser.add_argument('input', nargs='?', type=argparse.FileType('r'),
                        default=sys.stdin, help='input file (default: stdin)')
    parser.add_argument('--output', '-o', type=argparse.FileType('wb'),
                        default=sys.stdout.buffer, help='output file')
    parser.add_argument('-I', dest='include_dirs', action='append',
                        default=[], help='include dirs')
    return parser.parse_args()
2021-03-13 23:44:12 +00:00
def preprocess(fin, include_dirs):
    """Run the C preprocessor over ``fin`` and return its output.

    ``fin`` is passed to the child process as its stdin.  Each entry of
    ``include_dirs`` becomes a ``-I`` option appended to the ``CPP`` command.
    Returns an ``io.StringIO`` with the decoded output.

    Raises ``RuntimeError`` if the preprocessor exits with a non-zero status.
    """
    cmd = [*CPP, *(f'-I{directory}' for directory in include_dirs)]
    result = subprocess.run(cmd, stdin=fin, stdout=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError('preprocessor error')
    return io.StringIO(result.stdout.decode())
def assemble(text, fout):
    """Assemble ``text`` with the ``as`` module and write the object to ``fout``."""
    source = io.StringIO(text)
    parsed = asmod.larkparse(source)
    asmod.write_obj(fout, *parsed)
def main():
    """Command-line entry point: preprocess, compile, then emit output.

    With ``--assembly`` the assembly text is written to the output file;
    otherwise it is passed to ``assemble``.
    """
    args = parse_args()
    preprocessed = preprocess(args.input, args.include_dirs)
    assy = larkparse(preprocessed, debug=args.debug)
    if not args.assembly:
        assemble(assy, args.output)
        return
    # The output file is binary, hence the encode.
    args.output.write(assy.encode() + b'\n')
2021-02-17 21:20:30 +00:00
# Run the compiler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()