import argparse
import os
import string
import struct
import sys

import obj_pb2

# Path of the crt0 object that main() appends to the user's object files.
# The default is a developer-specific absolute path; set the CRT0_FILE
# environment variable (non-empty) to use another crt0.o without editing
# this file.
CRT0_FILE = (os.environ.get('CRT0_FILE')
             or '/home/paulmathieu/vhdl/tools/crt0.o')


def parse_objs(objs):
    """Deserialize each object file and merge their contents.

    Args:
        objs: iterable of binary file objects, each containing one
            serialized ``obj_pb2.ObjFile`` message.

    Returns:
        A ``(sections, relocs)`` tuple of lists that gather the sections
        and relocations of all inputs, in input order.
    """
    all_sections, all_relocs = [], []
    for fobj in objs:
        parsed = obj_pb2.ObjFile()
        parsed.ParseFromString(fobj.read())
        all_sections.extend(parsed.sections)
        all_relocs.extend(parsed.relocs)
    return all_sections, all_relocs


def map_sections(sections, offset=0):
    """Assign a load address to every section, ``_start`` first.

    Sections are laid out back to back starting at *offset*. The first
    section named ``_start`` is placed first, followed by every other
    section in its original order. Each section occupies
    ``len(sec.text)`` bytes.

    Args:
        sections: indexable sequence of sections exposing ``name`` and
            ``text``. It is not modified.
        offset: address given to the first section.

    Returns:
        A list of ``(address, section)`` pairs in layout order.

    Raises:
        ValueError: if no section is named ``_start``.
    """
    start_idx = next(
        (i for i, sec in enumerate(sections) if sec.name == '_start'), None)
    # An explicit check (not ``assert``) so it still fires under ``python -O``.
    if start_idx is None:
        raise ValueError("could not find symbol _start :/")

    # Build a reordered copy instead of removing from the caller's list.
    ordered = [sections[start_idx]]
    ordered += [sec for i, sec in enumerate(sections) if i != start_idx]

    secmap = []
    addr = offset
    for sec in ordered:
        secmap.append((addr, sec))
        addr += len(sec.text)
    return secmap


def do_relocs(secmap, relocs):
    """Patch every relocation site with its target section's address.

    Each relocation names the section to patch (``reloc.section``), the
    section whose address is wanted (``reloc.target``) and the byte
    offset of the 4-byte site inside the patched section
    (``reloc.offset``). The site is rewritten as::

        0xe0 | reg, addr & 0xff, 0x90 | reg, (addr >> 8) & 0xff

    where ``reg`` is the low nibble of the site's original first byte.
    The patched bytes are stored back into ``sec.text`` on the section
    objects from *secmap*.

    Args:
        secmap: list of ``(address, section)`` pairs from map_sections.
        relocs: iterable of relocation records.

    Raises:
        AssertionError: if a relocation names an unknown section or target.
        ValueError: if a relocation site does not lie entirely inside its
            section.
    """
    namemap = {sec.name: (addr, sec) for addr, sec in secmap}
    for reloc in relocs:
        assert reloc.section in namemap
        assert reloc.target in namemap
        _, sec = namemap[reloc.section]
        target_addr = namemap[reloc.target][0]
        start, end = reloc.offset, reloc.offset + 4
        # Without this check, a site overlapping the end of the section
        # would silently grow it through the slice assignment below.
        # That shifts the addresses of every later section.
        if start < 0 or end > len(sec.text):
            raise ValueError(
                f"relocation at offset {reloc.offset} in section "
                f"{reloc.section!r} overruns its {len(sec.text)}-byte text")
        # the reloc hex should look like /e.ff0000/
        buff = bytearray(sec.text)
        reg = buff[start] & 0xf
        # NOTE(review): only the low 16 bits of target_addr are encoded;
        # larger addresses are silently truncated.
        buff[start:end] = [
            0xe0 | reg, (target_addr >> 0) & 0xff,
            0x90 | reg, (target_addr >> 8) & 0xff,
        ]
        sec.text = bytes(buff)

def dump(secmap):
    """Concatenate the section contents of *secmap* in map order.

    Args:
        secmap: sequence of ``(address, section)`` pairs; only each
            section's ``text`` is used.

    Returns:
        A bytearray holding the linked image.
    """
    return bytearray(b''.join(section.text for _addr, section in secmap))


def _parse_int(text):
    """Parse an integer given in decimal or with a 0x/0o/0b prefix."""
    try:
        return int(text, 0)
    except ValueError:
        # int(..., 0) rejects leading zeros such as '010'; keep accepting
        # those as plain decimal, as the previous ``type=int`` did.
        return int(text)


def parse_args():
    """Parse the linker's command line.

    Returns:
        argparse.Namespace with ``debug`` (bool), ``objfiles`` (list of
        binary file objects), ``output`` (binary file object, stdout by
        default), ``vhdl`` (template path or None) and ``offset`` (int).
    """
    parser = argparse.ArgumentParser(description='Link object files.')
    parser.add_argument('--debug', action='store_true',
                        help='print debug info')
    parser.add_argument('objfiles', metavar='O', nargs='+',
                        type=argparse.FileType('rb'),
                        help="input object file(s); '-' reads stdin")
    parser.add_argument('--output', '-o', type=argparse.FileType('wb'),
                        default=sys.stdout.buffer, help='output file')
    parser.add_argument('--vhdl', help='vhdl output with given template')
    # Addresses are commonly written in hex, so accept 0x... too.
    parser.add_argument('--offset', default=0, type=_parse_int,
                        help='memory offset to link from '
                             '(decimal or 0x-prefixed)')
    return parser.parse_args()


def main():
    """Command-line entry point: link the object files and write the image.

    The object files named on the command line, plus CRT0_FILE, are
    parsed, laid out from ``--offset``, relocated and concatenated. The
    image is written raw, or with ``--vhdl`` it is substituted into the
    given string.Template as ``$words`` and ``$nwords``. ``$words`` holds
    the image as big-endian 16-bit ``x"hhhh"`` words, joined by ",\\n".
    """
    args = parse_args()

    with open(CRT0_FILE, 'rb') as crt0:
        sections, relocs = parse_objs(args.objfiles + [crt0])
    # argparse opened the input objects; close them now that they are read.
    for obj in args.objfiles:
        obj.close()

    sectionmap = map_sections(sections, offset=args.offset)
    do_relocs(sectionmap, relocs)
    text = dump(sectionmap)

    if args.vhdl:
        if len(text) % 2:
            # The image is emitted as 16-bit words; pad an odd-length
            # image with a zero byte instead of failing in struct.unpack.
            text += b'\x00'
        words = struct.unpack(f'>{len(text) // 2}H', text)
        subd = dict(words=',\n'.join(f'x"{w:04x}"' for w in words),
                    nwords=len(words))

        with open(args.vhdl) as fin:
            tpl = string.Template(fin.read())
        args.output.write(tpl.substitute(subd).encode())

    else:
        args.output.write(text)


# Run the linker only when executed as a script, not when imported.
if __name__ == "__main__":
    main()