import argparse
import ast
import collections
import contextlib
from dataclasses import dataclass
import importlib
import inspect
import io
import os
import re
import struct
import subprocess
import sys

import lark

asmod = importlib.import_module("as")

# Directory containing this source file.
_HERE = os.path.dirname(__file__)
# Grammar file shipped next to this module (presumably the lark grammar for
# the compiled C subset -- TODO confirm where it is loaded).
GRAMMAR_FILE = os.path.join(_HERE, 'cc.ebnf')
# Preprocessor command prefix (with GNU cpp, -P suppresses linemarkers).
CPP = ('cpp', '-P')


class Scope:
    """One lexical scope: its symbols, struct definitions and enclosing scope."""

    def __init__(self):
        # name -> Variable / FunctionSpec declared in this scope
        self.symbols = dict()
        # enclosing scope; None for the outermost one
        self.parent = None
        # struct definitions visible in this scope
        self.structs = dict()


class Variable:
    """A named storage location known to the compiler.

    ``type`` is the variable's C type and ``name`` its identifier.
    ``addr_reg``, when not None, names the register holding the variable's
    address; ``from_dereference`` uses it for values reached through a pointer.
    """

    def __init__(self, type, name, addr_reg=None):
        self.type, self.name, self.addr_reg = type, name, addr_reg

    def __repr__(self):
        return f'<Var: {self.type} {self.name}>'

    def __str__(self):
        return self.name

    @classmethod
    def from_def(cls, tree):
        """Build a Variable from a definition node whose children start with (type, name)."""
        var_type, var_name = tree.children[0], tree.children[1]
        return cls(var_type, var_name)

    @classmethod
    def from_dereference(cls, reg, tree):
        """Variable standing for ``*tree``: typed by the pointee, addressed through ``reg``."""
        pointee = tree.type.pointed
        return cls(pointee, 'deref', addr_reg=reg)

    # Design notes (work in progress) on register handling:
    # - Registers holding values that were never written can be dropped as
    #   soon as we're done with them:
    #   - once they have been fed into the operations;
    #   - at the end of the statement.
    # - Maybe special registers should have separate storage?
    # Needed operations:
    # - assign a list of registers into r0, ... rn for a function call
    #   - ... and run all callbacks
    # - get the register for an identifier
    # - dereference an expression:
    #   - essentially turn a temp into an lvalue address
    #     - if read, a second register is needed for its value
    #   - mark a variable as having its address in a register
    # - delay reading the identifier if it's the lhs
    # - store the value to memory when registers are claimed
    #   - if it was modified through:
    #       - assignment
    #       - pre/post increment/decrement
    #   - retrieve the memory type:
    #     - stack
    #     - absolute
    #     - register
    # - or just store the value when the variable is assigned
    # - get a temporary register
    # - return it
    # - dereferencing is essentially an upgrade

    # NOTE(review): every instance shadows this method with the ``type``
    # attribute set in __init__, so it is only reachable as ``Variable.type``;
    # it looks like leftover debug code.
    def type(self, tree):
        print(tree)


def allregs():
    """Return a fresh one-shot iterator over the register names 'r0' .. 'r11'."""
    return map('r{}'.format, range(12))


class RegBank:
    """Book-keeping for the general-purpose registers produced by ``allregs``.

    A register is in exactly one of three states:
      * free      -- listed in ``available`` (``take`` hands out from the front);
      * variable  -- caches a variable's value (``vars``: var -> reg, and the
                     inverse ``varregs``); only ``evict``/``stored`` free it;
      * temporary -- neither of the above; released by ``give``/``drop_temps``.
    ``cleanup`` maps a register to callbacks that ``evict`` (and ``assign``)
    run, with the register as argument, before forgetting the variable in it.
    """

    def __init__(self, logger=None):
        self.reset()
        self.log = logger or print

    def reset(self):
        """Forget every allocation: all registers become free again."""
        self.available = list(allregs())
        self.vars = {}
        self.varregs = {}
        self.cleanup = collections.defaultdict(list)
        # reg -> info about who took it; only used in error messages
        self.taken = {}

    def __str__(self):
        return f'regs(avail={self.available}, vars={self.vars})'

    def take(self, reg=None, orwhatever=False):
        """Claim a register and return its name.

        Args:
            reg: specific register wanted, or None for any free one.
            orwhatever: when ``reg`` is held as a temporary, fall back to any
                free register instead of raising.

        Raises:
            RuntimeError: ``reg`` is held as a temporary and ``orwhatever`` is False.
        """
        out = self._take(reg=reg, orwhatever=orwhatever)
        self.taken[out] = self._caller_info()
        return out

    @staticmethod
    def _caller_info():
        """Return (filename, lineno, function, ...) of whoever called ``take``.

        Cheaper than ``inspect.stack()``, which materialises every frame of the
        visitor's deep recursion (reading source context) on each call and
        keeps all those frames alive through the stored FrameInfo.
        """
        frame = inspect.currentframe()
        try:
            if frame is None:  # interpreter without frame support
                return None
            # frame is _caller_info; f_back is take(); its f_back is the caller.
            caller = frame.f_back.f_back if frame.f_back is not None else None
            if caller is None:
                return None
            return inspect.getframeinfo(caller, context=0)
        finally:
            del frame  # break the frame <-> local reference cycle

    def _take(self, reg=None, orwhatever=False):
        if reg is not None:
            if reg not in self.available:
                if reg in self.varregs:
                    self.evict(self.var(reg))
                elif orwhatever:
                    return self._take()
                else:
                    raise RuntimeError(
                            f'trying to overtake {reg}, '
                            f'varregs: {self.varregs}, vars: {self.vars}, '
                            f'taken in {self.taken.get(reg)}')
            if reg in self.available:
                self.available.remove(reg)
            return reg
        if not self.available:
            assert self.vars, "nothing to clean, no more regs :/"
            # Spill the first recorded var to free its register.
            var = next(iter(self.vars))
            self.evict(var)
        return self.available.pop(0)

    def give(self, reg):
        """Release a temporary register back to the free list."""
        assert reg is not None
        if reg in self.varregs:
            # Need to call evict() with the var to free it.
            return
        if reg in self.available:
            # Already free: inserting it twice would let take() hand the same
            # register to two users.
            return
        self.available.insert(0, reg)

    def loaded(self, var, reg, cleanup=None):
        """Tells the regbank some variable was loaded into the given register.

        If ``reg`` caches a different variable, that variable is evicted first
        (running its cleanups), so ``vars``/``varregs`` stay consistent.
        If ``reg`` is a temporary already held by the caller, it is simply
        relabelled as holding ``var``.
        """
        assert reg is not None
        assert var is not None
        holder = self.varregs.get(reg)
        if holder is not None and holder != var:
            self.evict(holder)
        if reg in self.available:
            self.take(reg)
        # Record the mapping only after claiming the register: the old code
        # recorded it first, which made take() evict the var just loaded.
        self.vars[var] = reg
        self.varregs[reg] = var
        if cleanup is not None:
            self.log(f'recording cleanup for {reg}({var})')
            self.cleanup[reg].append(cleanup)

    def stored(self, var):
        """Tells the regbank the given var was stored to memory, register can be freed."""
        assert var in self.vars
        reg = self.vars.pop(var)
        del self.varregs[reg]
        self.give(reg)

    def load(self, var):
        """Returns the reg associated with the var, or a new reg if none was,
        and True if the var was created, False otherwise."""
        self.log(f'vars: {self.vars}, varregs: {self.varregs}')
        if var not in self.vars:
            reg = self.take()
            self.vars[var] = reg
            self.varregs[reg] = var
            return reg, True
        return self.vars[var], False

    def assign(self, var, reg, cleanup=None):
        """Assign a previously-used register to a variable.

        Any register ``var`` cached before is forgotten. If ``reg`` cached
        another variable, that variable's cleanups run and it is forgotten.
        """
        assert reg is not None
        if var in self.vars:
            self.stored(var)
        if reg in self.varregs:
            for cb in self.cleanup.pop(reg, []):
                cb(reg)
            self.stored(self.varregs[reg])
        self.vars[var] = reg
        self.varregs[reg] = var
        # stored() above may have put ``reg`` back on the free list (``x = x``
        # or ``y = x``). It now caches ``var``, so it must not be handed out.
        if reg in self.available:
            self.available.remove(reg)
        if cleanup is not None:
            self.cleanup[reg].append(cleanup)

    def var(self, reg):
        """Return the variable cached in ``reg``, or None."""
        return self.varregs.get(reg, None)

    def evict_all(self):
        """Evict every cached variable (running cleanups)."""
        for var in list(self.vars):
            self.evict(var)

    def evict(self, var):
        """Runs var callbacks & frees the register."""
        assert var in self.vars, f'trying to evict {var}'
        self.log(f'evicting {var}')
        if var not in self.vars:  # only reachable when asserts are disabled
            return
        reg = self.vars[var]
        for cb in self.cleanup.pop(reg, []):
            cb(reg)
        self.stored(var)

    def drop_temps(self):
        """Release every temporary (taken, non-variable) register."""
        for reg in allregs():
            if reg not in self.available and reg not in self.varregs:
                self.give(reg)

    def swapped(self, reg0, reg1):
        """Update the book-keeping after the contents of reg0 and reg1 were exchanged."""
        var0 = self.varregs.get(reg0, None)
        var1 = self.varregs.get(reg1, None)

        if var0 is not None:
            self.stored(var0)
        elif reg0 not in self.available:
            self.give(reg0)

        if var1 is not None:
            self.stored(var1)
        elif reg1 not in self.available:
            self.give(reg1)

        if var0 is not None:
            self.loaded(var0, reg1)
        if var1 is not None:
            self.loaded(var1, reg0)

    def snap(self):
        """Return copies of (var -> reg mapping, reg -> cleanup callbacks)."""
        cleanup = collections.defaultdict(list)
        cleanup.update({r: list(c) for r, c in self.cleanup.items()})
        return dict(self.vars), cleanup

    def restore(self, snap):
        """Reset, then re-establish the variable mappings of a ``snap()`` result.

        Temporaries are not restored; they all become free.
        """
        self.reset()
        vardict, cleanup = snap
        for var, reg in vardict.items():
            assert reg is not None, f'bad snapshot: {snap}'
            self.loaded(var, reg)
            self.cleanup[reg] = cleanup.get(reg, [])

    @contextlib.contextmanager
    def borrow(self, howmany):
        """Take ``howmany`` registers for the duration of the block.

        They are given back even if the block raises.
        """
        regs = [self.take() for _ in range(howmany)]
        try:
            yield regs
        finally:
            for reg in regs:
                self.give(reg)


class FunctionSpec:
    """A function prototype: return type, name and parameter types.

    ``fun_prot.children`` is ``[return_type, name, param, ...]``, where the
    first child of each param node is that parameter's type.
    """

    def __init__(self, fun_prot):
        self.return_type = fun_prot.children[0]
        self.name = fun_prot.children[1]
        self.param_types = [x.children[0] for x in fun_prot.children[2:]]

    def __repr__(self):
        # Parameter types may be CType objects rather than strings (the
        # transformer's ``type`` rule builds dataclass instances), so
        # stringify them before joining. str.join raises TypeError otherwise.
        params = ', '.join(str(t) for t in self.param_types)
        return f'<Function: {self.return_type} {self.name}({params})>'

    @property
    def type(self):
        """The function's FunType (return type + parameter types)."""
        return FunType(ret=self.return_type, params=self.param_types)


class Function:
    """Per-function compilation state.

    Holds the prototype, parameters, stack layout and register bank, plus
    ``ops``: a list of zero-argument callables. Each returns a list of
    assembly lines and is only invoked by ``synth()``.
    """

    def __init__(self, fun_prot):
        self.locals = {}
        self.spec = FunctionSpec(fun_prot)
        self.params = [Variable(*x.children) for x in fun_prot.children[2:]]
        self.ret = None
        self.nextstack = 0   # next free stack offset
        self.topstack = 0    # high-water mark reached inside tempstack() scopes
        self.statements = []
        self.regs = RegBank(logger=self.log)
        self.deferred_ops = []
        self.fun_calls = 0
        self.ops = []

    def log(self, line):
        """Queue ``line`` as an assembly comment at the current position."""
        self.ops.append(lambda: [f'// {line}'])

    @property
    def stack_usage(self):
        """Frame size to reserve: the deepest stack offset used, plus 2 for
        the saved ``lr`` when the function makes calls."""
        top_usage = max(self.topstack, self.nextstack)
        if self.fun_calls > 0:
            return top_usage + 2
        return top_usage

    def get_stack(self, size=2):
        """Allocate ``size`` units of stack and return their offset."""
        stk = self.nextstack
        self.nextstack += size
        return stk

    @contextlib.contextmanager
    def tempstack(self):
        """Scope for temporary stack slots.

        On exit (including via an exception) the high-water mark is recorded
        in ``topstack`` and the slots allocated inside are released.
        """
        stk = self.nextstack
        try:
            yield
        finally:
            self.topstack = max(self.topstack, self.nextstack)
            self.nextstack = stk

    def param_dict(self):
        """Map parameter names to their Variables."""
        return {p.name: p for p in self.params}

    def __repr__(self):
        return repr(self.spec)

    def synth(self):
        """Return the function's assembly as a list of lines.

        Emits the ``.global``/label header, a prologue (save ``lr`` when the
        function calls others, reserve the stack frame), then the output of
        every queued op. Labels (lines ending in ':') are indented by one
        space, everything else by two.
        """
        ops = []
        if self.fun_calls > 0:
            ops.append('store lr, [sp, -2]')
        if self.stack_usage > 0:
            ops += [f'set r4, {self.stack_usage}',
                    'sub sp, sp, r4']
        for op in self.ops:
            ops += op()
        # endswith() rather than x[-1]: an empty line must not crash.
        indented = [f' {x}' if x.endswith(':') else f'  {x}' for x in ops]
        return [f'.global {self.spec.name}',
                f'{self.spec.name}:'] + indented


class Struct:
    """A struct definition: its name and ordered list of fields (Variables)."""

    def __init__(self, struct_def):
        self.name = struct_def.children[0]
        members = struct_def.children[1].children
        self.fields = [Variable(*member.children) for member in members]

    def offset(self, field):
        """Offset of ``field``, counting 2 units per preceding field.

        Raises ValueError when the struct has no such field.
        """
        names = [f.name for f in self.fields]
        return names.index(field) * 2

    def field_type(self, field):
        """Type of the first field named ``field``, or None if there is none."""
        return next((f.type for f in self.fields if f.name == field), None)


@dataclass
class CType:
    """Base of the compiler's C type model: qualifiers plus storage size."""
    volatile: bool = False
    const: bool = False
    # storage size; CcTransform.type uses 1 for char and 2 otherwise
    # (presumably bytes -- TODO confirm)
    size: int = 2

@dataclass
class StructType(CType):
    """Reference to a struct type by name."""
    struct: str = ''

@dataclass
class PointerType(CType):
    """Pointer to ``pointed``."""
    pointed: CType = None

@dataclass
class FunType(CType):
    """Function (or function-pointer) type: return type and parameter types."""
    ret: CType = None
    # list of parameter CTypes; the ``[CType]`` annotation is informal
    params: [CType] = None

@dataclass
class PodType(CType):
    """Scalar type named by ``pod`` (e.g. 'int', 'char', 'float')."""
    pod: str = 'int'
    # NOTE: unannotated, so this is a plain class attribute shared by all
    # instances, not a dataclass field (it cannot be passed to __init__).
    signed = True

@dataclass
class VoidType(CType):
    """The ``void`` type."""
    pass


class CcTransform(lark.visitors.Transformer):
    """Bottom-up lark Transformer run over the parse tree.

    Builds CType objects for type nodes, flattens nested comma expressions,
    folds ``add``/``sub``/``mul``/``shl`` whose operands are both literals,
    and converts terminal tokens into Python values.
    """
    def _binary_op(litt):
        # Factory used at class-definition time below: returns a rule handler
        # that replaces ``op(literal, literal)`` by ``literal(litt(a, b))``.
        @lark.v_args(tree=True)  # handler receives the whole Tree
        def _f(self, tree):
            left, right = tree.children
            if left.data == 'literal' and right.data == 'literal':
                tree.data = 'literal'
                tree.children = [litt(left.children[0], right.children[0])]
            return tree
        return _f

    def field(self, children):
        (field,) = children
        return field

    def struct_type(self, children):
        (name,) = children
        return StructType(struct=name)

    def pointer(self, children):
        # Only the 'volatile' qualifier is recognised here; 'const' is ignored.
        volat = 'volatile' in children
        pointed = children[-1]
        return PointerType(volatile=volat, pointed=pointed)

    def funptr_type(self, children):
        ret, *params = children
        return FunType(ret=ret, params=params)

    def comma_expr(self, children):
        # Flatten left-nested ``(a, b), c`` into a single comma_expr node.
        c1, c2 = children
        if c1.data == 'comma_expr':
            c1.children.append(c2)
            return c1
        return lark.Tree('comma_expr', [c1, c2])

    # Qualifier rules collapse to marker strings checked by ``type``/``pointer``.
    volatile = lambda *_: 'volatile'
    const = lambda *_: 'const'

    def type(self, children):
        # The last child is the base type: a name string, or an already-built
        # CType (struct/pointer/...).
        volat = 'volatile' in children
        const = 'const' in children
        typ = children[-1]
        if isinstance(typ, str):
            if typ == 'void':
                return VoidType()
            size = 1 if typ == 'char' else 2
            return PodType(volatile=volat, const=const, pod=typ, size=size)
        else:
            # NOTE(review): qualifiers are dropped for non-name base types
            # (and for void) -- confirm this is intended.
            return typ

    # operations on literals can be done by the compiler
    add = _binary_op(lambda a, b: a+b)
    sub = _binary_op(lambda a, b: a-b)
    mul = _binary_op(lambda a, b: a*b)
    shl = _binary_op(lambda a, b: a<<b)

    # Terminal callbacks: lark calls the attribute named after each terminal
    # with the token. Lambdas are bound (first arg is self); plain types
    # like str/int are called with the token alone.
    CHARACTER = lambda _, x: ord(ast.literal_eval(x)[0])  # 'c' -> code point
    IDENTIFIER = str
    SIGNED_NUMBER = int
    HEX_LITTERAL = lambda _, x: int(x[2:], 16)  # strip the 2-char prefix
    ESCAPED_STRING = lambda _, x: ast.literal_eval(x)
    INT = int


class AsmOp:
    """Base class for assembly ops.

    ``scratch_need`` is how many scratch registers ``synth`` expects.
    ``out`` is the register holding the op's result (None when there is none).
    """
    scratch_need = 0

    def synth(self, scratches):
        """Return the op's assembly lines; the base op is a single nop."""
        return ['nop']

    out = property(lambda self: None)


class Reg:
    """Op for a value already sitting in register ``reg``: emits no code."""
    scratch_need = 0

    def __init__(self, reg):
        self.out = reg

    def synth(self, scratches):
        nothing = []
        return nothing


class BinOp(AsmOp):
    """Base for ``dest = left <op> right`` ops; ``ops`` is (dest, left, right)."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        dest, left, right = ops
        self.dest, self.left, self.right = dest, left, right

    out = property(lambda self: self.dest)


def make_cpu_bin_op(cpu_op):
    """Build a BinOp subclass emitting the three-operand instruction ``cpu_op``."""
    class _C(BinOp):
        def synth(self, _scratches):
            dest, lhs, rhs = self.dest, self.left, self.right
            return [f'{cpu_op} {dest}, {lhs}, {rhs}']
    return _C


# Ops that map directly onto one CPU instruction.
AddOp, SubOp, MulOp = map(make_cpu_bin_op, ('add', 'sub', 'mul'))
# no div / no mod op is provided
AndOp, OrOp, XorOp = map(make_cpu_bin_op, ('and', 'or', 'xor'))
ShrOp = make_cpu_bin_op('shr')  # 2nd operand needs to be a literal

class ShlOp(BinOp):
    """``dest = left << right`` built from repeated doubling (``add x, x, x``).

    Scratches: ``sc1`` holds the constant 1, ``sc0`` is the loop counter.

    NOTE(review): the counter starts at ``right - 1`` and the body is skipped
    when that is zero. If ``beq``/``bneq`` test the result of the preceding
    ``sub``, this performs ``right - 1`` doublings (``x << 1`` would leave
    ``x`` unchanged), and ``right == 0`` would loop until the counter wraps.
    Confirm against the target ISA's branch/flag semantics.
    """
    scratch_need = 2
    def synth(self, scratches):
        sc0, sc1 = scratches
        return [f'set {sc1}, 1',
                f'or {self.dest}, {self.left}, {self.left}',  # dest = left
                f'sub {sc0}, {self.right}, {sc1}',  # counter = right - 1
                f'beq [pc, 4]',
                f'add {self.dest}, {self.dest}, {self.dest}',  # dest *= 2
                f'sub {sc0}, {sc0}, {sc1}',
                f'bneq [pc, -8]']


class LtOp(BinOp):
    """Comparison op: ``dest`` starts at 0 and is set to 1 unless the branch
    following ``sub sc0, left, right`` is taken.

    NOTE(review): this is the same ``bneq [pc, 0]`` / ``set dest, 1`` pattern
    BoolNot uses (presumably to compute ``operand == 0``). If ``bneq`` only
    tests whether the ``sub`` result is zero, this yields ``left == right``
    rather than ``left < right`` -- confirm against the ISA's flag semantics.
    """
    scratch_need = 1
    def synth(self, scratches):
        sc0 = scratches[0]
        return [f'set {self.dest}, 0',
                f'sub {sc0}, {self.left}, {self.right}',
                f'bneq [pc, 0]',
                f'set {self.dest}, 1']

class GtOp(LtOp):
    """``left > right`` expressed as LtOp with the operands exchanged."""
    def __init__(self, fun, ops):
        dest, lhs, rhs = ops
        swapped = [dest, rhs, lhs]
        super().__init__(fun, swapped)


class UnOp(AsmOp):
    """Base for single-operand ops; ``ops`` is (dest, operand)."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        dest, operand = ops
        self.dest, self.operand = dest, operand

    out = property(lambda self: self.dest)


class Incr(UnOp):
    """``dest = operand + 1``; the result is also copied back into ``operand``."""
    scratch_need = 1
    def synth(self, scratches):
        one = scratches[0]
        dest, src = self.dest, self.operand
        return [
            f'set {one}, 1',
            f'add {dest}, {src}, {one}',
            f'or {src}, {dest}, {dest}',
        ]

class Decr(UnOp):
    """``dest = operand - 1``; the result is also copied back into ``operand``."""
    scratch_need = 1
    def synth(self, scratches):
        one = scratches[0]
        dest, src = self.dest, self.operand
        return [
            f'set {one}, 1',
            f'sub {dest}, {src}, {one}',
            f'or {src}, {dest}, {dest}',
        ]


class NotOp(UnOp):
    """Bitwise complement via the CPU's ``not`` instruction."""
    def synth(self, scratches):
        dest, src = self.dest, self.operand
        return [f'not {dest}, {src}']

class BoolNot(UnOp):
    """Logical not: ``dest`` starts at 0 and is set to 1 unless the branch
    after comparing 0 with ``operand`` is taken (presumably: 1 iff operand == 0)."""
    def synth(self, scratches):
        dest = self.dest
        lines = [f'set {dest}, 0']
        lines.append(f'cmp {dest}, {self.operand}')
        lines.append(f'bneq [pc, 0]')
        lines.append(f'set {dest}, 1')
        return lines

class NeqOp(BinOp):
    """``left != right`` as ``dest = left - right`` (non-zero exactly when they differ)."""
    def synth(self, scratches):
        dest, lhs, rhs = self.dest, self.left, self.right
        return [f'sub {dest}, {lhs}, {rhs}']


class FnCall(AsmOp):
    """Indirect call through the register ``dest_fn``; the result is in r0.

    ``params`` is kept for reference only; arguments are expected to already
    be in place when this op is emitted.
    """
    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        self.dest_fn, self.params = ops

    @property
    def out(self):
        return 'r0'

    def synth(self, scratches):
        zero = scratches[0]
        target = self.dest_fn
        # lr is computed from pc (+0); jump by copying the target into pc.
        return [f'set {zero}, 0',
                f'add lr, pc, {zero}',
                f'or pc, {target}, {target}']


class ReturnReg(AsmOp):
    """Function epilogue: release the frame, move the result to r0, return."""
    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        (self.ret_reg,) = ops

    def synth(self, scratches):
        fun = self.fun
        frame_size = fun.stack_usage
        (scratch,) = scratches
        lines = []
        if frame_size > 0:
            lines.extend([f'set {scratch}, {frame_size}',
                          f'add sp, sp, {scratch}'])
        if fun.ret is not None:
            src = self.ret_reg
            lines.append(f'or r0, {src}, {src}')
        # Leaf functions return through lr; others reload the saved lr slot.
        lines.append('or pc, lr, lr' if fun.fun_calls == 0
                     else 'load pc, [sp, -2]  // return')
        return lines


class Load(AsmOp):
    """Read variable ``var`` into register ``dest``.

    The source is chosen in order: a stack slot (locals), the address held in
    ``var.addr_reg`` (dereferences), or an absolute symbol address (globals).
    """
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.dest, self.var = ops
        fun.log(f'8bit load? {self.var.type} {self.is8bit}')

    @property
    def is8bit(self):
        vtype = self.var.type
        return isinstance(vtype, PodType) and vtype.pod in ('uint8_t', 'int8_t', 'char')

    def _maybecast8bit(self, reg):
        # assumes the byte lands in the high half of the loaded word -- TODO confirm
        return [f'shr {reg}, {reg}, 8'] if self.is8bit else []

    out = property(lambda self: self.dest)

    def synth(self, scratches):
        reg = self.dest
        var = self.var
        if var.name in self.fun.locals:
            return [f'load {reg}, [sp, {var.stackaddr}]']  # no 8bit cast to/from stack
        if var.addr_reg is not None:
            return [f'load {reg}, [{var.addr_reg}]'] + self._maybecast8bit(reg)
        return [f'set {reg}, {var.name}',
                f'nop  // in case we load a far global',
                f'load {reg}, [{reg}]']  # unsure if we need a cast here


class Push(AsmOp):
    """Save ``reg`` into the stack slot at offset ``stackaddr`` from sp."""
    def __init__(self, reg, stackaddr):
        self.reg, self.stackaddr = reg, stackaddr

    def synth(self, scratches):
        reg, slot = self.reg, self.stackaddr
        return [f'store {reg}, [sp, {slot}]']


class Pop(AsmOp):
    """Reload ``reg`` from the stack slot at offset ``stackaddr`` from sp."""
    def __init__(self, reg, stackaddr):
        self.reg, self.stackaddr = reg, stackaddr

    out = property(lambda self: self.reg)

    def synth(self, scratches):
        reg, slot = self.reg, self.stackaddr
        return [f'load {reg}, [sp, {slot}]']


class Store(AsmOp):
    """Write register ``src`` back to the storage of variable ``var``.

    The destination is chosen in order: a stack slot (locals), the address in
    ``var.addr_reg`` (dereferences), or the absolute symbol address (globals,
    using the scratch register to hold the address).
    """
    scratch_need = 1

    def __init__(self, fun, ops):
        self.fun = fun
        self.src, self.var = ops

    def synth(self, scratches):
        (sc,) = scratches
        reg = self.src
        if self.var.name in self.fun.locals:
            dst = self.var.stackaddr
            # Emit the trace comment inline. Calling self.fun.log() here would
            # append to fun.ops while Function.synth is iterating that list,
            # so the comment ended up at the very end of the listing.
            return [f'// storing {self.var}({reg}) to [sp, {dst}]',
                    f'store {reg}, [sp, {dst}]']
        elif self.var.addr_reg is not None:
            return [f'store {reg}, [{self.var.addr_reg}]']
        return [f'set {sc}, {self.var.name}',
                f'nop  // you know, in case blah',
                f'store {reg}, [{sc}]']


class Assign(AsmOp):
    """Copy register ``src`` into register ``var`` (``or var, src, src``)."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.fun = fun
        self.src, self.var = ops

    out = property(lambda self: self.var)

    def synth(self, scratches):
        src = self.src
        return [f'or {self.var}, {src}, {src}']


class SetAddr(AsmOp):
    """Load the address of symbol ``ident`` into ``dest`` (a nop reserves room
    for a long address). ``fun`` is accepted for interface uniformity only."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.dest, self.ident = ops

    out = property(lambda self: self.dest)

    def synth(self, scratches):
        target = self.dest
        return [
            f'set {target}, {self.ident}',
            f'nop  // placeholder for a long address',
        ]


class Set16Imm(AsmOp):
    """Load a 16-bit immediate: ``set`` the low byte, then ``seth`` the high
    byte only when it is non-zero. Bits above 16 are discarded."""
    scratch_need = 0

    def __init__(self, fun, ops):
        self.dest, self.imm16 = ops

    out = property(lambda self: self.dest)

    def synth(self, scratches):
        target = self.dest
        value = self.imm16
        low, high = value & 0xff, (value >> 8) & 0xff
        lines = [f'set {target}, {low}']
        if high:
            lines.append(f'seth {target}, {high}')
        return lines

class Move(AsmOp):
    """Register copy ``dst = src`` (encoded as ``or dst, src, src``)."""
    scratch_need = 0

    def __init__(self, dst, src):
        self.dst, self.src = dst, src

    def synth(self, scratches):
        source = self.src
        return [f'or {self.dst}, {source}, {source}']

class Swap(AsmOp):
    """Exchange registers ``a0`` and ``a1`` through one scratch register."""
    scratch_need = 1

    def __init__(self, a0, a1):
        self.a0, self.a1 = a0, a1

    def synth(self, scratches):
        (tmp,) = scratches
        first, second = self.a0, self.a1
        return [
            f'or {tmp}, {first}, {first}',
            f'or {first}, {second}, {second}',
            f'or {second}, {tmp}, {tmp}',
        ]


class IfOp(AsmOp):
    """Conditional branch on ``cond`` plus helpers for the else/endif labels.

    ``op`` is (cond_register, marker, has_else). ``synth`` jumps to the else
    label (or to endif when there is no else) when ``cond`` compares equal
    to 0.
    """
    scratch_need = 1

    def __init__(self, fun, op):
        self.fun = fun
        self.cond, mark, self.has_else = op
        self.then_mark, self.else_mark, self.endif_mark = (
            f'_{kind}_{mark}' for kind in ('then', 'else', 'endif'))

    def synth(self, scratches):
        zero = scratches[0]
        target = self.else_mark if self.has_else else self.endif_mark
        return [f'set {zero}, 0',
                f'cmp {zero}, {self.cond}',  # flag if cond == 0
                f'beq {target}']

    def synth_else(self):
        # ``cmp r0, r0`` always compares equal, making the beq unconditional.
        return ['cmp r0, r0',  # trick because beq is better than "or pc, ."
                f'beq {self.endif_mark}',
                f'{self.else_mark}:']

    def synth_endif(self):
        return [f'{self.endif_mark}:']


class IfEq(IfOp):
    """``if (a == b)`` fast path: compare the two registers directly and
    branch away (to else/endif) when they differ."""
    scratch_need = 0

    def __init__(self, a, b, mark, has_else):
        self.a, self.b, self.has_else = a, b, has_else
        self.then_mark = f'_then_{mark}'
        self.else_mark = f'_else_{mark}'
        self.endif_mark = f'_endif_{mark}'

    def synth(self, scratches):
        target = self.else_mark if self.has_else else self.endif_mark
        return [f'cmp {self.a}, {self.b}',
                f'bneq {target}']


class WhileOp(AsmOp):
    """Loop test: leave the loop when ``cond`` compares equal to 0.

    The loop label comes from ``synth_loop``; ``synth_endwhile`` jumps back
    to it and places the exit label.
    """
    scratch_need = 1

    @staticmethod
    def synth_loop(mark):
        return [f'_loop_{mark}:']

    def __init__(self, cond, mark):
        self.cond = cond
        self.loop_mark, self.endwhile_mark = f'_loop_{mark}', f'_endwhile_{mark}'

    def synth(self, scratches):
        zero = scratches[0]
        return [
            f'set {zero}, 0',
            f'cmp {zero}, {self.cond}',
            f'beq {self.endwhile_mark}',
        ]

    def synth_endwhile(self):
        # ``cmp r0, r0`` always compares equal, making the beq unconditional.
        return [
            'cmp r0, r0',
            f'beq {self.loop_mark}',
            f'{self.endwhile_mark}:',
        ]


class Delayed(AsmOp):
    """Placeholder op whose output register is produced lazily.

    ``out_cb`` is invoked on every access to ``out`` (the result is not
    cached), so the callback may emit code and claim a register each time.
    """
    def __init__(self, out_cb):
        self.out_cb = out_cb

    @property
    def out(self):
        return self.out_cb()

    def synth(self, scratches=None):
        # Nothing to emit: the callback synthesises its own ops when ``out``
        # is read. ``scratches`` keeps the AsmOp.synth(scratches) signature,
        # which is how ops are invoked; the default keeps synth() working.
        return []


def litt_type(val):
    """C type of a literal value: const char pointer for strings, otherwise a
    const float or int PodType."""
    if isinstance(val, str):
        return PointerType(pointed=PodType(pod='char', size=1, const=True))
    pod = 'float' if isinstance(val, float) else 'int'
    return PodType(pod=pod, const=True)


def return_type(type):
    """Return the ``ret`` field of a function type."""
    result = type.ret
    return result


def deref_type(type):
    """Return the pointee (``pointed`` field) of a pointer type."""
    pointee = type.pointed
    return pointee


class CcInterp(lark.visitors.Interpreter):
    def __init__(self):
        # scope chain: lookups start at cur_scope and walk up to global_scope
        self.global_scope = Scope()
        self.cur_scope = self.global_scope
        self.cur_fun = None  # Function currently being compiled
        self.funs = []
        # counters used to mint fresh registers / labels / string tokens
        self.next_reg = 0
        self.next_marker = 0
        self.next_string_token = 0
        self.strings = {}  # string token -> literal contents
        self.rodata = {}

    def _lookup_symbol(self, s):
        """Resolve name ``s`` from the innermost scope outwards; None if unknown."""
        scope = self.cur_scope
        while scope is not None and s not in scope.symbols:
            scope = scope.parent
        return None if scope is None else scope.symbols[s]

    def _get_reg(self):
        """Claim any free register from the current function's bank."""
        bank = self.cur_fun.regs
        return bank.take()

    def _get_marker(self):
        """Return a fresh integer used to build unique labels."""
        mark, self.next_marker = self.next_marker, self.next_marker + 1
        return mark


    @contextlib.contextmanager
    def _nosynth(self):
        """Temporarily turn ``_synth`` into a no-op (ops are built, not emitted).

        The previous ``_synth`` is restored even if the block raises;
        otherwise every later op would be silently dropped.
        """
        old = self._synth
        self._synth = lambda *_: None
        try:
            yield
        finally:
            self._synth = old

    def _synth(self, op):
        """Queue ``op`` for emission with ``op.scratch_need`` scratch registers.

        The scratches are borrowed only while queueing; the op's lines are
        produced later, when the function's ops are synthesised.
        """
        regs = self.cur_fun.regs
        with regs.borrow(op.scratch_need) as scratch_regs:
            self._log(f'{op.__class__.__name__}')

            def emit():
                return op.synth(scratch_regs)

            self.cur_fun.ops.append(emit)

    def _load(self, ident, reg=None):
        """Build the op that makes ``ident`` available in a register.

        Args:
            ident: identifier to load.
            reg: preferred destination register (e.g. a call-argument slot),
                or None. Honoured for addresses and volatile loads when
                possible; otherwise any free register is used.

        Returns:
            SetAddr for functions and global symbols (their address), Load for
            volatile variables or variables not yet cached, or Reg when the
            variable is already cached in a register.
        """
        s = self._lookup_symbol(ident)
        assert s is not None, f'unknown identifier {ident}'
        # Compare the *resolved* symbol, not just the name: a local that
        # shadows a global of the same name must not be treated as the global.
        is_global = self.global_scope.symbols.get(ident) is s
        if isinstance(s, FunctionSpec) or is_global:
            reg = self.cur_fun.regs.take(reg=reg, orwhatever=True)
            return SetAddr(self.cur_fun, [reg, s.name])
        if s.type.volatile:
            self._log(f'loading volatile {s}')
            # Volatile values are re-read each time and never cached, but they
            # still need a destination: ``reg`` is usually None here, which
            # produced ``load None, ...``.
            reg = self.cur_fun.regs.take(reg=reg, orwhatever=True)
            return Load(self.cur_fun, [reg, s])
        reg, created = self.cur_fun.regs.load(s)
        if created:
            return Load(self.cur_fun, [reg, s])
        self._log(f'{s} was already in {reg}')
        return Reg(reg)

    def identifier(self, tree):
        """Attach a lazily-loading op plus the resolved variable and type.

        Reading ``tree.op.out`` loads the identifier, into
        ``tree.request_out`` when that attribute is set. An identifier used
        only as an assignment target is therefore never loaded.
        """
        # TODO: volatile values need to be flushed; assignments need a
        # Variable with the proper address to bind the register to.
        def delayed_load():
            ident = tree.children[0]
            self._log(f'delay-loading {ident}')
            wanted = getattr(tree, 'request_out', None)
            load_op = self._load(ident, reg=wanted)
            self._synth(load_op)
            return load_op.out

        tree.op = Delayed(delayed_load)
        symbol = self._lookup_symbol(tree.children[0])
        tree.var = symbol
        assert symbol is not None, f'undefined identifier: {tree.children[0]}'
        tree.type = symbol.type

    def _make_string(self, s):
        """Register string literal ``s`` under a fresh ``_strN`` global symbol."""
        token = f'_str{self.next_string_token}'
        self.next_string_token += 1
        self.strings[token] = s
        self.global_scope.symbols[token] = Variable(litt_type(s), token)
        return token

    def literal(self, tree):
        """Handle a literal node.

        Strings become a global symbol referenced like an identifier. Numbers
        get a lazily-emitted Set16Imm into ``tree.request_out`` when that is
        set, otherwise into any free register.
        """
        imm = tree.children[0]
        if isinstance(imm, str):
            tree.children = [self._make_string(imm)]
            return self.identifier(tree)

        def delayed():
            regs = self.cur_fun.regs
            if hasattr(tree, 'request_out'):
                reg = regs.take(reg=tree.request_out, orwhatever=True)
            else:
                reg = regs.take()
                assert reg is not None, f'weird: {self.cur_fun.regs}'
            set_op = Set16Imm(self.cur_fun, [reg, imm])
            self._synth(set_op)
            return set_op.out

        tree.op = Delayed(delayed)
        tree.type = litt_type(imm)

    def array_item(self, tree):
        """Lower ``array[index]`` to a dereference of pointer arithmetic.

        1-unit elements use ``*(array + index)``; all others are treated as
        16-bit: ``*(array + (index << 1))``.
        """
        # TODO: bring this over to the main parser; the element type of the
        # array is needed to compute the actual offset.
        symbol = self._lookup_symbol(tree.children[0].children[0])
        if symbol.type.pointed.size != 1:
            scaled = lark.Tree('shl', [tree.children[1], lark.Tree('literal', [1])])
            tree.children = [lark.Tree('add', [tree.children[0], scaled])]
        else:
            tree.children = [lark.Tree('add', tree.children)]
        return self.dereference(tree)

    def _assign(self, left, right):
        """Bind register ``right`` to variable ``left`` and emit its Store.

        Returns the Store op's ``out`` (the AsmOp default).
        """
        self._log(f'assigning {left} = {right}')
        self.cur_fun.regs.assign(left, right)
        store = Store(self.cur_fun, [right, left])
        self._synth(store)
        return store.out

    def assignment(self, tree):
        """Compile ``lhs = rhs``.

        The right-hand side is evaluated into a register. That register is
        bound to the left-hand variable and stored back to it. The node's op
        exposes the register, so the assignment can itself be used as a
        value (e.g. ``a = b = c``).
        """
        self.visit_children(tree)
        val = tree.children[1].op.out
        assert val is not None, f'no output!!! {tree.pretty()}'

        # left hand side is an lvalue, retrieve it
        var = tree.children[0].var
        self._assign(var, val)
        # _assign() returns the Store's ``out``, which is always None, so
        # parents reading ``tree.op.out`` used to crash. Expose the value
        # register instead.
        tree.op = Reg(val)

    def global_var(self, tree):
        """Register a global variable definition in the global scope.

        NOTE(review): an initializer (``tree.children[2]``) is visited but its
        value is not recorded here -- confirm how globals get initialised.
        The previous code read ``.children[0].value`` into an unused local,
        which raises AttributeError once the transformer has turned the token
        into a plain int.
        """
        self.visit_children(tree)
        var = Variable.from_def(tree)
        self.global_scope.symbols[var.name] = var

    def fun_decl(self, tree):
        """Declare a function prototype in the current scope."""
        spec = FunctionSpec(tree.children[0])
        self.cur_scope.symbols[spec.name] = spec

    def _reg_shuffle(self, src, dst):
        """Emit moves/swaps so the value now in ``src[i]`` ends up in ``dst[i]``.

        Args:
            src: registers currently holding the values. Any sequence is
                accepted (``_restore_vars`` passes a tuple). A private list
                copy is rearranged, so the caller's sequence is not mutated.
            dst: target registers, index-aligned with ``src``.
        """
        # swap() rearranges ``src`` in place to track where each value lives;
        # tuples do not support item assignment, hence the copy.
        src = list(src)

        def swap(old, new):
            # Exchange registers ``old`` and ``new`` and mirror it in ``src``.
            if old == new:
                return
            oldpos = src.index(old)
            try:
                newpos = src.index(new)
            except ValueError:
                src[oldpos] = new
            else:
                src[newpos], src[oldpos] = src[oldpos], src[newpos]
            self._synth(Swap(old, new))

        for i, r in enumerate(dst):
            if r not in src:
                # ``r`` holds none of the shuffled values, so a copy suffices.
                # NOTE(review): ``r`` is not claimed in the register bank here.
                self._synth(Move(r, src[i]))
            else:
                swap(r, src[i])


    def _prep_fun_call(self, fn_reg, params):
        """Move the params into r0..r(n-1) and the function address into rn.

        Returns the register now holding the function address.
        """
        targets = [f'r{n}' for n in range(len(params) + 1)]
        sources = params + [fn_reg]
        self._reg_shuffle(sources, targets)
        return targets[-1]

    @contextlib.contextmanager
    def _push(self, dontpush=()):
        """Spill live registers to temporary stack slots around a call.

        All variables are evicted first. Every register still taken (and not
        listed in ``dontpush``) is saved with a Push into a fresh slot and
        released. Yields the ``(reg, stack_offset)`` pairs. The slots are
        reclaimed when the block exits.
        """
        saved = []
        bank = self.cur_fun.regs
        bank.evict_all()
        with self.cur_fun.tempstack():
            for reg in allregs():
                if reg in bank.available or reg in dontpush:
                    continue
                self._log(f'reg in use: {reg}')
                slot = self.cur_fun.get_stack()
                saved.append((reg, slot))
                self._synth(Push(reg, slot))
                bank.give(reg)
            yield saved

    def fun_call(self, tree):
        """Compile a call expression ``f(args...)``.

        Steps: ask each argument to be produced in r0..r(n-1) and the callee
        address in rn; evaluate them; spill other live registers; shuffle the
        arguments into place; emit the call; copy the result out of r0; then
        reload the spilled registers. Sets ``tree.op`` (result register in
        ``.out``) and ``tree.type`` (the callee's return type).
        """
        self.cur_fun.fun_calls += 1
        # Normalise the argument list: none, several (comma_expr) or one.
        if len(tree.children) == 1:
            param_children = []
        elif tree.children[1].data == 'comma_expr':
            param_children = tree.children[1].children
        else:
            param_children = [tree.children[1]]


        if True:  # always taken; kept as a grouping block
            # pre-allocate output registers
            for i, child in enumerate(param_children):
                self._log(f'requesting r{i} for {child}')
                child.request_out = f'r{i}'
            self._log(f'requesting r{len(param_children)} for {tree.children[0]}')
            tree.children[0].request_out = f'r{len(param_children)}'
            self.visit_children(tree)
            # Reading ``.out`` triggers any delayed loads: callee address
            # first, then the arguments in order.
            fn_reg = tree.children[0].op.out
            param_regs = [c.op.out for c in param_children]
            self._log(f'calling {tree.children[0].children[0]}({param_regs})')

        # Save every other live register; argument/callee registers are kept.
        with self._push(dontpush=param_regs + [fn_reg]) as pushed:

            self._flush_deferred()  # side effects
            fn_reg = self._prep_fun_call(fn_reg, param_regs)
            # After the shuffle only the argument and callee registers are in use.
            self.cur_fun.regs.reset()
            self.cur_fun.regs.take(fn_reg)
            for reg in (f'r{i}' for i in range(len(param_regs))):
                self.cur_fun.regs.take(reg)

            # NOTE: param_regs still lists the pre-shuffle registers; FnCall
            # does not use them when emitting code.
            tree.op = FnCall(self.cur_fun, [fn_reg, param_regs])
            self._synth(tree.op)

            # Nothing survives the call in registers.
            self.cur_fun.regs.reset()

            # Re-claim the spilled registers now; their Pop ops are emitted
            # below, after the result has been copied out.
            pops = []
            for reg, stk in pushed:
                self.cur_fun.regs.take(reg)
                pops.append(Pop(reg, stk))

        # The result arrives in r0; honour a requested output register.
        ret = 'r0'
        if hasattr(tree, 'request_out'):
            ret = tree.request_out
            self._synth(Move(ret, 'r0'))

        # If ``ret`` is busy (e.g. about to be restored by a Pop), copy the
        # result to another register before the pops run.
        nret = self.cur_fun.regs.take(ret, orwhatever=True)
        if nret != ret:
            self._synth(Move(nret, ret))
        # Ad-hoc op object: only ``.out`` is consumed by parents.
        tree.op = type('blah', (), {'out': nret})

        for op in pops:
            self._synth(op)

        tree.type = return_type(tree.children[0].type)

    def _flush_deferred(self):
        """Emit every op queued in the function's ``deferred_ops``, then clear the queue."""
        pending = self.cur_fun.deferred_ops
        if pending:
            self._log(f'deferred logic: {len(pending)}')
        for deferred in pending:
            self._synth(deferred)
        self.cur_fun.deferred_ops = []

    def statement(self, tree):
        """Compile one statement, emit deferred side effects, release temporaries."""
        self.visit_children(tree)
        self._flush_deferred()
        bank = self.cur_fun.regs
        bank.drop_temps()

    # Iteration expressions are handled exactly like statements.
    iter_expression = statement

    def if_stat(self, tree):
        """Compile ``if (cond) then [else other]``.

        The register bank is snapshotted before the branches. Afterwards it
        keeps only the var -> register pairs that hold at the end of every
        path (then, else, or the fall-through when there is no else), so code
        after the join makes no branch-specific assumptions.
        """
        mark = self._get_marker()
        has_else = len(tree.children) > 2
        # optimization!!!!
        # ``a == b`` conditions compare the two registers directly (IfEq)
        # instead of materialising a boolean first.
        if tree.children[0].data == 'eq':
            self.visit_children(tree.children[0])
            a, b = (c.op.out for c in tree.children[0].children)
            op = IfEq(a, b, mark, has_else)
        else:
            self.visit(tree.children[0])
            op = IfOp(self.cur_fun, [tree.children[0].op.out, mark, has_else])
        self._synth(op)
        self.cur_fun.regs.drop_temps()
        begin_vars = self.cur_fun.regs.snap()
        # Start from the state at the branch (this is the fall-through path's
        # state when there is no else).
        valid_vars_set = set(begin_vars[0].items())
        self.visit(tree.children[1])
        valid_vars_set &= set(self.cur_fun.regs.snap()[0].items())
        if has_else:
            # The else branch starts from the same state as the then branch.
            self.cur_fun.regs.restore(begin_vars)
            self.cur_fun.ops.append(op.synth_else)
            self.visit(tree.children[2])
            valid_vars_set &= set(self.cur_fun.regs.snap()[0].items())
        valid_vars = dict(valid_vars_set)
        # Keep the pre-branch cleanup callbacks for the surviving registers.
        valid_snap = (valid_vars, {r: begin_vars[1][r] for r in valid_vars.values()})
        self.cur_fun.regs.restore(valid_snap)
        self.cur_fun.ops.append(op.synth_endif)

    def _restore_vars(self, begin_vars):
        """Put variables back into the registers they held at loop entry.

        Used at the end of a loop body, so that on the back-edge each
        variable in ``begin_vars`` is in the same register as on the first
        iteration. ``begin_vars`` maps variable -> register and is captured
        at the top of the loop.
        """
        curvars = self.cur_fun.regs.vars

        # 0. do the loop shuffle: every variable that is still in a register
        #    moves from its current register (src) back to its loop-entry
        #    register (dst). `shuffles` maps entry register -> current
        #    register. With no shuffles, zip() yields nothing and the `or`
        #    supplies two empty lists.
        shuffles = {begin_vars[v]: r for v, r in curvars.items() if v in begin_vars}
        dst, src = tuple(zip(*shuffles.items())) or ([], [])
        self._reg_shuffle(src, dst)

        # 1. load the rest: variables that were in registers at loop entry but
        #    are no longer held in any register get reloaded into their
        #    loop-entry register.
        for v, r in begin_vars.items():
            if v in curvars:
                continue
            self._log(f'loading missing var {v} into {r}')
            self._synth(Load(self.cur_fun, [r, v]))

    def while_stat(self, tree):
        """Compile ``while (cond) body``.

        Children: [condition, body]. The emitted sequence is:
        1. the loop-head marker;
        2. the condition;
        3. a WhileOp on the condition register (presumably the exit test);
        4. the body;
        5. re-synchronisation of registers with the loop head;
        6. the end of the loop.
        """
        mark = self._get_marker()
        # Loop head: queued as a callable in the op list (like
        # synth_endwhile below) rather than synthesised immediately.
        self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
        # Variable -> register bindings at the loop head. The back-edge must
        # reproduce them (see _restore_vars).
        begin_vars = dict(self.cur_fun.regs.vars)
        self.visit(tree.children[0])
        # Allocator state right after the condition. It is restored once the
        # loop is closed, presumably because the loop is left right after a
        # condition evaluation.
        postcond = self.cur_fun.regs.snap()
        op = WhileOp(tree.children[0].op.out, mark)
        self._synth(op)
        self.cur_fun.regs.drop_temps()
        self.visit(tree.children[1])
        self._restore_vars(begin_vars)
        self.cur_fun.ops.append(op.synth_endwhile)
        self.cur_fun.regs.restore(postcond)

    def for_stat(self, tree):
        """Compile ``for (init; cond; iter) body``.

        Children: [init, condition, iteration, body]. This uses the same
        scheme as ``while_stat``, with two differences:
        - the init clause is emitted once, before the loop head;
        - the iteration clause is compiled after the body, just before the
          registers are re-synchronised with the loop head.
        """
        mark = self._get_marker()
        self.visit(tree.children[0])  # initialization
        # Loop head (queued as a callable, as in while_stat).
        self.cur_fun.ops.append(lambda: WhileOp.synth_loop(mark))
        # Bindings at the loop head that the back-edge must reproduce.
        begin_vars = dict(self.cur_fun.regs.vars)
        self.visit(tree.children[1])
        # Allocator state after the condition; restored after the loop.
        postcond = self.cur_fun.regs.snap()
        op = WhileOp(tree.children[1].op.out, mark)
        self._synth(op)
        self.cur_fun.regs.drop_temps()
        self.visit(tree.children[3])
        self.visit(tree.children[2])  # 3rd statement
        self._restore_vars(begin_vars)
        self.cur_fun.ops.append(op.synth_endwhile)
        self.cur_fun.regs.restore(postcond)

    def _unary_op(op):
        """Build a visitor method for the one-operand operator class ``op``.

        The generated handler does the following:
        - compiles its operand;
        - takes a fresh result register;
        - emits ``op(cur_fun, [result, operand])``;
        - hands the operand register back to the allocator.
        """
        def _visit(self, tree):
            self.visit_children(tree)
            src = tree.children[0].op.out
            dst = self.cur_fun.regs.take()
            tree.op = op(self.cur_fun, [dst, src])
            self._synth(tree.op)
            # The operand has been consumed; its register can be reused.
            self.cur_fun.regs.give(src)
        return _visit

    def _deref(self, tree):
        """Make ``tree`` an lvalue for the memory addressed by its first child.

        The child must already be compiled: its result register holds the
        address, and its ``type`` is a pointer type (``.pointed`` is read).
        This method sets three attributes:
        - ``tree.var``: a ``Variable`` whose address lives in that register,
          so stores and increments can write through it;
        - ``tree.type``: the pointed-to type;
        - ``tree.op``: a ``Delayed`` wrapper around the load. The memory
          read is presumably emitted only if the value is actually used
          (e.g. not when this is just an assignment target).
        """
        reg = tree.children[0].op.out
        var = Variable.from_dereference(reg, tree.children[0])
        self._log(f'making var {var} from derefing reg {reg}')
        tree.var = var
        def delayed_load():
            # Emit the Load through the address held in `reg`, and return
            # the Load's output register.
            self._log(f'delay-loading {tree.children[0]}')
            op = Load(self.cur_fun, [reg, var])
            self._synth(op)
            return op.out
        tree.op = Delayed(delayed_load)
        tree.type = var.type

    def dereference(self, tree):
        """Compile ``*expr``.

        The pointer operand is compiled first, leaving the address in its
        result register. ``_deref`` then turns this node into an lvalue
        over that address.
        """
        # operand first: _deref reads the child's op.out and type
        self.visit_children(tree)
        self._deref(tree)

    def pointer_access(self, tree):
        """Compile ``ptr->field`` as a dereference of ``ptr + offset(field)``.

        The steps are:
        1. Visit the children with synthesis disabled, only to learn the
           static pointer type.
        2. Use the struct layout to find the field's offset. A non-zero
           offset is added with a synthesized ``add`` node; a zero offset
           reuses the pointer expression directly.
        3. Retype the resulting address expression as a pointer to the
           field's type and hand it to ``_deref``.
        """
        with self._nosynth():
            self.visit_children(tree)
        ch_strct, ch_field = tree.children
        strct = self.global_scope.structs[ch_strct.type.pointed.struct]
        offs = strct.offset(ch_field)
        field_ptr_type = PointerType(pointed=strct.field_type(ch_field))
        if offs > 0:
            load_offs = lark.Tree('literal', [offs])
            addr = lark.Tree('add', [ch_strct, load_offs])
            self.visit(addr)
        else:
            addr = ch_strct
            tree.children = [addr]
            self.visit_children(tree)
        # Retype only *after* visiting. Handlers may (re)assign a node's
        # `type` when it is visited (e.g. _binary_op and _deref do). Setting
        # it beforehand could therefore be overwritten with the struct
        # pointer type, and _deref would then load the struct type instead
        # of the field.
        addr.type = field_ptr_type
        tree.children = [addr]
        self._deref(tree)

    def post_increment(self, tree):
        """Compile ``x++``.

        The expression's value is the operand register as currently held.
        The increment and the store back to ``x`` are queued on
        ``cur_fun.deferred_ops``, and are emitted when the queue is flushed.
        """
        self.visit_children(tree)
        operand = tree.children[0]
        tree.op = Reg(operand.op.out)
        target = operand.var
        out_reg = tree.op.out
        self.cur_fun.deferred_ops.extend([
            Incr(self.cur_fun, [out_reg, out_reg]),
            Store(self.cur_fun, [out_reg, target]),
        ])
        tree.type = target.type

    def post_decrement(self, tree):
        """Compile ``x--``.

        The expression's value is the operand register as currently held.
        The decrement and the store back to ``x`` are queued on
        ``cur_fun.deferred_ops``, and are emitted when the queue is flushed.
        """
        self.visit_children(tree)
        operand = tree.children[0]
        tree.op = Reg(operand.op.out)
        target = operand.var
        out_reg = tree.op.out
        self.cur_fun.deferred_ops.extend([
            Decr(self.cur_fun, [out_reg, out_reg]),
            Store(self.cur_fun, [out_reg, target]),
        ])
        tree.type = target.type

    def pre_increment(self, tree):
        """Compile ``++x``: increment the operand register in place.

        The incremented register is the expression's result. The store back
        to ``x`` is not emitted here. Instead it is registered as a cleanup
        callback on that register (``regs.cleanup[reg]``); the callback
        receives the register to store from.
        """
        self.visit_children(tree)
        operand = tree.children[0]
        # The lvalue must be bound explicitly: the cleanup below closes over
        # it and runs later. An unbound name here would only fail (NameError)
        # when the cleanup fires.
        var = operand.var
        reg = operand.op.out
        tree.op = Incr(self.cur_fun, [reg, reg])
        self._synth(tree.op)

        def cleanup(src):
            self._synth(Store(self.cur_fun, [src, var]))

        self.cur_fun.regs.cleanup[reg].append(cleanup)
        # Same result typing as post_increment/post_decrement.
        tree.type = var.type

    # `!x`: a BoolNot op on the operand register, built by _unary_op.
    bool_not = _unary_op(BoolNot)

    def _binary_op(op):
        """Build a visitor method for the two-operand operator class ``op``.

        The generated handler does the following:
        - compiles both operands;
        - emits ``op(cur_fun, [dest, lhs, rhs])`` into a freshly taken
          register;
        - releases both operand registers;
        - gives the result the left operand's type.
        """
        def _visit(self, tree):
            self.visit_children(tree)
            lhs, rhs = (child.op.out for child in tree.children)
            result = self.cur_fun.regs.take()
            tree.op = op(self.cur_fun, [result, lhs, rhs])
            self._synth(tree.op)
            for used in (lhs, rhs):
                self.cur_fun.regs.give(used)
            # The result simply inherits the left operand's type; no type
            # promotion/conversion is modelled here.
            tree.type = tree.children[0].type
        return _visit

    def _combo(uop, bop):
        """Build a handler that computes ``uop(bop(children...))``.

        The node's children are moved under a new ``bop`` subtree. That
        subtree becomes the single operand of the unary handler ``uop``.
        """
        def _visit(self, tree):
            inner = lark.Tree(bop, tree.children)
            tree.children = [inner]
            # `uop` is a plain function taken from the class body, so pass
            # `self` explicitly instead of binding it as a method.
            uop(self, tree)
        return _visit

    # Binary operator handlers. Each attribute name is (presumably) the tree
    # node name the interpreter dispatches on.
    shl = _binary_op(ShlOp)
    add = _binary_op(AddOp)
    sub = _binary_op(SubOp)
    mul = _binary_op(MulOp)
    # `and`/`or` are Python keywords, hence the leading underscores;
    # _assign_op('_or') builds trees named '_or'.
    _and = _binary_op(AndOp)
    _or = _binary_op(OrOp)
    # ...
    gt = _binary_op(GtOp)
    lt = _binary_op(LtOp)
    neq = _binary_op(NeqOp)
    # `a == b` is compiled as `!(a != b)`. if_stat bypasses this for `eq`
    # conditions and emits a direct IfEq instead.
    eq = _combo(bool_not, 'neq')

    def shr(self, tree):
        """Compile ``a >> n``, where ``n`` must be an integer literal.

        Only the left operand is compiled. The literal shift amount is
        passed to ShrOp as its value rather than as a register.
        """
        lhs_node = tree.children[0]
        self.visit(lhs_node)
        lhs = lhs_node.op.out
        amount_node = tree.children[1]
        assert amount_node.data == 'literal'
        amount = amount_node.children[0]
        result = self.cur_fun.regs.take()
        tree.op = ShrOp(self.cur_fun, [result, lhs, amount])
        self._synth(tree.op)
        self.cur_fun.regs.give(lhs)
        # Result typed like the left operand, as in _binary_op.
        tree.type = tree.children[0].type

    def _assign(self, left, right):
        """Assign register ``right`` to variable ``left``.

        The new binding is recorded with ``regs.assign``, and a Store is
        emitted immediately.

        Returns:
            The Store op's ``out``.
        """
        self._log(f'assigning {left} = {right}')
        self.cur_fun.regs.assign(left, right)
        store = Store(self.cur_fun, [right, left])
        self._synth(store)
        return store.out


    def _assign_op(op):
        """Build a handler for a compound assignment ``x op= e``.

        The node's children are compiled as the binary node ``op(x, e)``.
        The result is then assigned back to ``x``'s lvalue
        (``children[0].var``) via ``_assign``.
        """
        def _visit(self, tree):
            combined = lark.Tree(op, tree.children)
            self.visit(combined)
            result_reg = combined.op.out
            # The left-hand side is an lvalue; its Variable is on `.var`.
            target = tree.children[0].var
            tree.op = self._assign(target, result_reg)
        return _visit

    or_ass = _assign_op('_or')    # x |= e
    add_ass = _assign_op('add')   # x += e

    def _forward_op(self, tree):
        """Pass-through handler: compile the children and reuse the first one's op."""
        self.visit_children(tree)
        source = tree.children[0]
        tree.op = source.op

    def cast(self, tree):
        """Compile ``(type) expr``.

        Children: [target type, expression]. The cast itself emits no code:
        the expression's op becomes this node's op, and only the static
        type changes. If an output register was requested for this node,
        the request is forwarded to the expression.
        """
        if hasattr(tree, 'request_out'):
            # Let the operand compute straight into the requested register.
            tree.children[1].request_out = tree.request_out
        self.visit_children(tree)
        operand = tree.children[1]
        tree.op = operand.op
        tree.type = tree.children[0]

    def _log(self, line):
        """Forward a diagnostic line to the function currently being compiled."""
        fun = self.cur_fun
        fun.log(line)

    def local_var(self, tree):
        """Declare a local variable, optionally with an initializer.

        Children: [type, name, optional initializer expression]. The variable
        gets a stack address from ``cur_fun.get_stack()`` and is entered in
        both the current scope and the function's locals. When there is an
        initializer, the variable is bound to the initializer's register.
        The store to the stack is not emitted here; it is registered as the
        binding's cleanup callback.
        """
        self.visit_children(tree)
        assert self.cur_fun is not None
        assert self.cur_scope is not None
        fun = self.cur_fun
        var = Variable.from_def(tree)
        var.stackaddr = fun.get_stack()  # will have to invert
        self.cur_scope.symbols[var.name] = var
        fun.locals[var.name] = var
        if len(tree.children) <= 2:
            return
        initval = tree.children[2].op.out

        def cleanup(reg):
            self._synth(Store(self.cur_fun, [reg, var]))

        fun.regs.assign(var, initval, cleanup=cleanup)
        self._log(f'assigning {var} = {initval}')

    def local_array(self, tree):
        """Declare a local array. Only ``const`` arrays are handled.

        Children (as indexed below): [type, name, size, initializer list].
        ``size`` is an ``empty_array`` node for ``T a[] = {...}``; otherwise
        its first child is the declared element count.

        A const array is handled as follows:
        - its initializer values are recorded in ``self.rodata`` under a
          fresh ``_dat<N>`` label;
        - the values are zero-padded up to the declared count;
        - the array name is bound in the *global* scope to a pointer to
          that label.
        """
        self.visit_children(tree)
        assert self.cur_fun is not None
        assert self.cur_scope is not None
        var = Variable.from_def(tree)
        if var.type.const:
            # storing in rodata
            assert len(tree.children) > 3
            tok = self.next_string_token
            self.next_string_token += 1
            tok = f'_dat{tok}'
            assert len(tree.children[3].children) > 0
            # The first child of each initializer node is taken as its value.
            fields =[c.children[0] for c in tree.children[3].children]
            if tree.children[2].data == 'empty_array':
                fieldcount = len(fields)
            else:
                fieldcount = tree.children[2].children[0]
            # Missing trailing initializers become 0.
            # NOTE(review): surplus initializers (more than fieldcount) are
            # not rejected, and all of them end up in rodata.
            if len(fields) < fieldcount:
                fields += [0] * (fieldcount - len(fields))
            name = tree.children[1]
            self.rodata[tok] = (tree.children[0], fields)
            # The pointer's element type is derived from the first
            # initializer (litt_type), not from the declared type.
            var = Variable(PointerType(pointed=litt_type(fields[0])), tok)
            # Bound in global_scope, not cur_scope, so the name stays visible
            # after this function. NOTE(review): confirm this is intended.
            self.global_scope.symbols[name] = var
        # NOTE(review): non-const local arrays fall through here. No storage
        # is allocated and the name is never bound. TODO confirm unsupported.

    def fun_def(self, tree):
        """Compile a function definition.

        Children: [prototype, body]. The steps are:
        1. Register the function's spec in the current scope before the
           body is compiled, so the body can refer to it.
        2. Set up each parameter. It arrives in ``r<i>`` for its position
           ``i`` and gets a stack address. It is recorded with the allocator
           as loaded in that register, with a cleanup that stores it to its
           stack slot.
        3. Compile the body in a fresh scope that holds the parameters.

        ``main`` and ``void`` functions are always closed with an explicit
        return (``main`` returns 0). Any other function must contain a
        return statement.
        """
        prot, body = tree.children
        fun = Function(prot)
        assert self.cur_fun is None
        self.cur_fun = fun

        self.cur_scope.symbols[fun.spec.name] = fun.spec

        params = fun.param_dict()

        def getcleanup(var):
            # A factory (not a lambda inside the loop) so that each cleanup
            # closes over its own parameter.
            def cl(reg):
                self._log(f'cleanup for {var} in {reg}')
                self._synth(Store(self.cur_fun, [reg, var]))
            return cl

        for i, param in enumerate(fun.params):
            fun.locals[param.name] = param
            param.stackaddr = fun.get_stack()
            self._log(f'param [sp, {param.stackaddr}]: {param.name} in r{i}')
            cleanup = getcleanup(param)
            fun.regs.loaded(param, f'r{i}', cleanup=cleanup)

        fun_scope = Scope()
        fun_scope.parent = self.cur_scope
        fun_scope.symbols.update(params)

        # body() picks this scope up, so the parameters are visible inside.
        body.scope = fun_scope
        self.visit_children(tree)

        # `fun.ret` is set by *any* return statement (see return_stat),
        # including one nested inside an `if`. It therefore does not prove
        # that control cannot reach the end of the body. Always close main
        # and void functions with an explicit return; after a final `return`
        # this extra one is merely unreachable.
        if fun.spec.name == 'main':
            self._synth(Set16Imm(fun, ['r0', 0]))
            self._synth(ReturnReg(fun, ['r0']))
        elif fun.spec.return_type == VoidType():
            self._synth(ReturnReg(fun, ['r0']))
        else:
            assert fun.ret is not None, \
                f'function {fun.spec.name} has no return statement'
        self.cur_fun = None
        self.funs.append(fun)

    def body(self, tree):
        """Compile a ``{ ... }`` block in its own scope.

        If fun_def has pre-attached a scope (``tree.scope``, holding the
        parameters), that scope is used; otherwise a fresh ``Scope`` is
        created. The new scope is chained to the current one, and the
        enclosing scope is restored afterwards.
        """
        if hasattr(tree, 'scope'):
            inner = tree.scope
        else:
            inner = Scope()
        inner.parent = self.cur_scope
        self.cur_scope = inner
        self.visit_children(tree)
        self.cur_scope = inner.parent

    def return_stat(self, tree):
        """Compile ``return [expr];`` and mark the function as having a return.

        The value, if present, is the expression's result register. A bare
        ``return;`` (no children) returns via ``r0``, just as fun_def closes
        void functions, instead of failing with IndexError.
        """
        assert self.cur_fun is not None
        self.cur_fun.ret = True
        self.visit_children(tree)
        if tree.children:
            expr_reg = tree.children[0].op.out
        else:
            # No value: r0's contents are returned but carry no meaning.
            expr_reg = 'r0'
        tree.op = ReturnReg(self.cur_fun, [expr_reg])
        self._synth(tree.op)

    def struct_def(self, tree):
        """Record a struct definition in the current scope's struct table.

        NOTE(review): pointer_access looks structs up in
        ``global_scope.structs``. A struct defined in a nested scope would
        therefore not be found there; confirm that only top-level structs
        are expected.
        """
        self.visit_children(tree)
        definition = Struct(tree)
        self.cur_scope.structs[definition.name] = definition

# Program entry stub. It does the following:
# - zeroes r0 and r1;
# - sets up sp;
# - jumps to `main`, with lr set from pc (add lr, pc, r3 with r3 = 0);
# - once main returns, spins forever.
preamble = [
    '_start:',
    'xor r0, r0, r0',
    'xor r1, r1, r1',
    'set sp, 0',
    f'seth sp, {0x11}',  # 256 bytes of stack ought to be enough
    'set r2, main',
    'set r3, 0',
    'add lr, pc, r3',
    'or pc, r2, r2',
    'cmp r0, r0',
    'beq [pc, -4]  // loop forever',
]

# Matches a self-move `or rN, rN, rN`. The line may be indented and may end
# with a `//` comment. The pattern is used with fullmatch so that lines
# such as `xor r1, r1, r1` (zeroes r1) or `or r1, r1, r12` (a real OR)
# are not mistaken for no-ops.
_SELF_MOVE_RE = re.compile(r'\s*or\s+(r\d+)\s*,\s*\1\s*,\s*\1\s*(?://.*)?')


def filter_dupes(ops):
    """Yield the assembly lines in ``ops``, minus no-op self-moves.

    A self-move is ``or rN, rN, rN``, which leaves ``rN`` unchanged. All
    other lines, including blank lines and labels, pass through unchanged
    and in order.

    Args:
        ops: iterable of assembly source lines (str).
    """
    for op in ops:
        if _SELF_MOVE_RE.fullmatch(op):
            continue
        yield op

def parse_tree(tree, debug=False):
    """Compile a parsed C translation unit to assembly text.

    Runs ``CcTransform`` over the raw parse tree, then ``CcInterp`` to
    generate code. The output contains, in this order:
    - each function's code;
    - the string literals, NUL-terminated and padded to whole 16-bit
      big-endian words;
    - the read-only data tables.
    Self-move no-ops are filtered out of the result.

    Args:
        tree: parse tree produced with the ``cc.ebnf`` grammar.
        debug: if true, print the transformed tree before code generation.

    Returns:
        The assembly as a single newline-joined string.
    """
    tree = CcTransform().transform(tree)

    if debug:
        print(tree.pretty())

    inte = CcInterp()
    inte.visit(tree)

    out = []

    for fun in inte.funs:
        out += fun.synth()
        out.append('')

    for sym, text in inte.strings.items():
        out += [f'.global {sym}',
                f'{sym}:']
        # Encode *before* terminating and padding. A non-ASCII character
        # takes several bytes in UTF-8, so padding by character count could
        # leave an odd byte count, which struct.unpack rejects.
        data = text.encode() + b'\0'
        if len(data) % 2 != 0:
            data += b'\0'
        nwords = len(data) // 2
        out += [f'.word 0x{w:04x}' for w in struct.unpack(f'>{nwords}H', data)]
        out.append('')

    for sym, (_type, fields) in inte.rodata.items():
        out += [f'.global {sym}',
                f'{sym}:']
        out += [f'.word 0x{d:04x}' for d in fields]
        out.append('')

    return '\n'.join(filter_dupes(out))


def larkparse(f, debug=False):
    """Parse C source from the file-like ``f`` and compile it to assembly.

    The grammar is read from ``GRAMMAR_FILE``. ``f.read()`` may return
    ``str`` or ``bytes``; bytes are decoded with the default (UTF-8) codec.
    Returns the text produced by ``parse_tree``.
    """
    with open(GRAMMAR_FILE) as grammar:
        parser = lark.Lark(grammar.read())
    source = f.read()
    if isinstance(source, bytes):
        source = source.decode()
    return parse_tree(parser.parse(source), debug=debug)


def parse_args():
    """Parse the compiler command line from ``sys.argv``.

    The input file is optional and defaults to stdin. The output defaults
    to binary stdout.
    """
    parser = argparse.ArgumentParser(description='Compile.')
    parser.add_argument('--debug', action='store_true',
                        help='print the AST')
    parser.add_argument('--assembly', '-S', action='store_true',
                        help='output assembly')
    # NOTE(review): parsed but not read by main().
    parser.add_argument('--compile', '-c', action='store_true',
                        help='compile a single file')
    # nargs='?' is what makes a positional optional. Without it argparse
    # requires the argument, and the stdin default can never take effect.
    parser.add_argument('input', nargs='?', type=argparse.FileType('r'),
                        default=sys.stdin, help='input file (default: stdin)')
    parser.add_argument('--output', '-o', type=argparse.FileType('wb'),
                        default=sys.stdout.buffer, help='output file')
    parser.add_argument('-I', dest='include_dirs', action='append',
                        default=[], help='include dirs')
    return parser.parse_args()


def preprocess(fin, include_dirs):
    """Run the C preprocessor (``CPP``) over ``fin`` and return its output.

    Args:
        fin: file object handed to the preprocessor as its stdin.
        include_dirs: directories added to the search path with ``-I``.

    Returns:
        An ``io.StringIO`` over the preprocessed source.

    Raises:
        RuntimeError: if the preprocessor exits with a non-zero status. Its
            own diagnostics go to the inherited stderr; the message here
            adds the command line and the exit status.
    """
    cmd = [*CPP, *(f'-I{d}' for d in include_dirs)]
    result = subprocess.run(cmd, stdin=fin, stdout=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(
            f'preprocessor error: {" ".join(cmd)} exited with status '
            f'{result.returncode}')
    return io.StringIO(result.stdout.decode())


def assemble(text, fout):
    """Assemble ``text`` with the ``as`` module and write the object file to ``fout``."""
    parsed = asmod.larkparse(io.StringIO(text))
    asmod.write_obj(fout, *parsed)


def main():
    """Command-line entry point.

    Preprocesses the input, compiles it to assembly, and then writes one of
    two outputs:
    - with ``-S``, the assembly text;
    - otherwise, the assembled object file.
    """
    args = parse_args()
    source = preprocess(args.input, args.include_dirs)
    assy = larkparse(source, debug=args.debug)
    if not args.assembly:
        assemble(assy, args.output)
        return
    args.output.write(assy.encode() + b'\n')


if __name__ == "__main__":
    main()