#!/usr/bin/env python3

import argparse
import collections
import glob
import itertools
import os
import re


# global config
# Cross toolchain: each tool is <toolchain_path>/bin/<toolchain_prefix><tool>.
toolchain_path = "/opt/xpack-riscv-none-elf-gcc-14.2.0-3"
toolchain_prefix = "riscv-none-elf-"
# Directory that receives build.ninja; every generated path is relative to it
# because the build is run with `ninja -C <builddir>`.
builddir = "build"
# Output directory named in the configure message printed by main().
outdir = "out"
# Default linker script for target images (build_image can override it).
linker_script = "apps/app.ld"
# Host-side (clang++) unit-test build, instrumented for source-based coverage.
hostcflags = "-g -std=c++20 -fprofile-instr-generate -fcoverage-mapping"
hostlibs = "-lgtest -lgmock -lgtest_main"
hostldflags = "-fprofile-instr-generate -fcoverage-mapping"
include_dirs = [
    "hal/uart",
    "hal/lib/common",
]

# -I paths are rewritten relative to builddir, where the compiler is invoked.
include_flags = [f"-I{os.path.relpath(i, builddir)}" for i in include_dirs]
# Extra project-wide preprocessor flags (currently none).
project_flags = []
cpp_flags = ["-DNDEBUG"] + include_flags + project_flags
common_flags = [
    "-g",
    "-Wall",
    "-Wextra",
    "-flto",
    "-march=rv32i",
    "-ffunction-sections",
    "-Oz",
]

# Copy instead of aliasing so that changing cc_flags can never silently change
# common_flags (cxx_flags below is already a separate list).
cc_flags = list(common_flags)
cxx_flags = common_flags + ["-std=c++20", "-fno-rtti", "-fno-exceptions", "-Wno-missing-field-initializers"]
ldflags = [
    "-Oz",
    "-g",
    "-Wl,--gc-sections",
    "-Wl,--print-memory-usage",
    "-flto",
    "-march=rv32i",
]


def get_cxx_flags():
    """Return a new list: the preprocessor flags followed by the C++ compiler flags."""
    return [*cpp_flags, *cxx_flags]


def get_cc_flags():
    """Return a new list: the preprocessor flags followed by the C compiler flags."""
    return [*cpp_flags, *cc_flags]


def get_ldflags():
    """Return the target linker flags as a new list.

    A copy is returned, matching get_cxx_flags()/get_cc_flags(), so callers
    cannot mutate the module-level ``ldflags`` through the result.
    """
    return list(ldflags)


def add_cpp_flag(flag):
    """Append *flag* to the global preprocessor flags.

    These flags feed the target C and C++ compile rules, so this must be
    called before the rule text is generated.
    """
    cpp_flags.extend((flag,))


def gen_rules():
    """Yield the header of build.ninja: toolchain variables, then every rule.

    This is a generator, so the flag lists are read only when it is first
    advanced. Flags added with add_cpp_flag() before iteration are included.

    Target rules (cxx, cc, as, link, objcopy) use the cross toolchain through
    the $cxx/$cc/$as/$objcopy variables. Host rules (hostcxx, hostlink,
    profdata, cov, hosttest) invoke clang++/llvm-* by bare name, i.e. from
    PATH, to build and run unit tests and collect coverage.
    """
    tools = {"cxx": "g++", "cc": "gcc", "as": "as", "objcopy": "objcopy"}

    # NOTE(review): the `as` rule passes no -march; confirm the assembler's
    # default ISA matches -march=rv32i before adding .s sources.
    # The link rule goes through get_ldflags() for consistency with the
    # cxx/cc rules, which use their getter functions.
    rules = f"""
rule cxx
  command = $cxx -MMD -MT $out -MF $out.d {' '.join(get_cxx_flags())} -c $in -o $out
  description = CXX $out
  depfile = $out.d
  deps = gcc

rule cc
  command = $cc -MMD -MT $out -MF $out.d {' '.join(get_cc_flags())} -c $in -o $out
  description = CC $out
  depfile = $out.d
  deps = gcc

rule as
  command = $as $in -o $out
  description = AS $out

rule link
  command = $cxx {' '.join(get_ldflags())} -Wl,-T$linker_script -o $out $in $libs
  description = LINK $out

rule objcopy
  command = $objcopy -O binary $in $out
  description = OBJCOPY $out

rule hostcxx
  command = clang++ -MMD -MT $out -MF $out.d {hostcflags} -c $in -o $out
  description = HOSTCXX $out
  depfile = $out.d
  deps = gcc

rule hostlink
  command = clang++ {hostldflags} -o $out $in {hostlibs}
  description = HOSTLINK $out

rule profdata
  command = llvm-profdata merge -sparse $profraw -o $out
  description = PROFDATA

rule cov
  command = llvm-cov show --output-dir cov -format html --instr-profile $profdata $objects && touch $out
  description = COV

rule hosttest
  command = LLVM_PROFILE_FILE=$in.profraw ./$in && touch $out
  description = TEST $in
"""

    # One ninja variable per tool, e.g. cxx = <toolchain>/bin/riscv-none-elf-g++.
    for var, tool in tools.items():
        yield f"{var} = {os.path.join(toolchain_path, 'bin', toolchain_prefix + tool)}"

    yield from rules.splitlines()


def get_suffix_rule(filename, cxx_rule="cxx", cc_rule="cc"):
    """Return the name of the ninja rule that compiles *filename*, or None.

    The suffix is the text after the last ``.`` (the whole name if there is
    none). ``c`` maps to *cc_rule*, ``cc``/``cpp`` map to *cxx_rule* and ``s``
    maps to the ``as`` rule. Every other suffix yields None.
    """
    suffix = filename.rsplit(".", 1)[-1]
    rules_by_suffix = {
        "c": cc_rule,
        "cc": cxx_rule,
        "cpp": cxx_rule,
        "s": "as",
    }
    # A plain dict lookup with a None default; no defaultdict needed.
    return rules_by_suffix.get(suffix)


def make_cxx_rule(name, cflags=()):
    """Return the lines of a ninja C++ compile rule called *name*.

    *cflags* (any iterable of strings) is placed before the global flags from
    get_cxx_flags(). The returned list starts with an empty line.

    NOTE(review): because the global flags come last, a conflicting per-rule
    option (e.g. -O2 vs the global -Oz) is overridden under GCC's
    last-option-wins rule. Confirm this ordering is intended.
    """
    # list(...) lets the tuple default work: `() + list` raises TypeError.
    flags = " ".join(list(cflags) + get_cxx_flags())
    rule = f"""
rule {name}
  command = $cxx -MMD -MT $out -MF $out.d {flags} -c $in -o $out
  description = CXX $out
  depfile = $out.d
  deps = gcc
"""
    return rule.splitlines()


def make_cc_rule(name, cflags=()):
    """Return the lines of a ninja C compile rule called *name*.

    *cflags* (any iterable of strings) is placed before the global flags from
    get_cc_flags(). The returned list starts with an empty line.

    NOTE(review): because the global flags come last, a conflicting per-rule
    option is overridden under GCC's last-option-wins rule. Confirm this
    ordering is intended.
    """
    # list(...) lets the tuple default work: `() + list` raises TypeError.
    flags = " ".join(list(cflags) + get_cc_flags())
    rule = f"""
rule {name}
  command = $cc -MMD -MT $out -MF $out.d {flags} -c $in -o $out
  description = CC $out
  depfile = $out.d
  deps = gcc
"""
    return rule.splitlines()


def source_set(name, sources, cflags=()):
    """Describe how to compile a group of target sources.

    Args:
        name: used to name the per-set rules when *cflags* is non-empty.
        sources: source paths relative to the project root. Files whose
            suffix has no compile rule (see get_suffix_rule) are skipped.
        cflags: extra compiler flags for this set only (list or tuple). When
            non-empty, dedicated ``cxx_<name>``/``cc_<name>`` rules are emitted.

    Returns:
        ``(objects, lines)``. ``objects`` are the object paths as written in
        build.ninja. ``lines`` are the ninja rule and build lines that produce
        those objects.
    """
    # Inputs are made relative to builddir. The object path keeps the source
    # path and replaces only its final extension. Replacing every ".<word>" in
    # the path mangled dotted directories and made a.test.cc and a.bench.cc
    # collide.
    builds = [
        (os.path.relpath(src, builddir), os.path.splitext(src)[0] + ".o")
        for src in sources
        if get_suffix_rule(src) is not None
    ]

    lines = []

    # Sets with their own cflags get dedicated rules named after the set.
    # Otherwise the shared cxx/cc rules from gen_rules() are used.
    cxx_rule = "cxx"
    cc_rule = "cc"
    if cflags:
        extra = list(cflags)  # rule builders concatenate with a list
        cxx_rule = f"cxx_{name}"
        lines += make_cxx_rule(cxx_rule, cflags=extra)
        cc_rule = f"cc_{name}"
        lines += make_cc_rule(cc_rule, cflags=extra)

    for src, obj in builds:
        rule = get_suffix_rule(src, cxx_rule=cxx_rule, cc_rule=cc_rule)
        if rule is not None:
            lines.append(f"build {obj}: {rule} {src}")

    return [obj for _, obj in builds], lines


def build_source_set(source_set):
    """Yield the ninja lines (rules and build statements) stored in *source_set*.

    *source_set* is an ``(objects, lines)`` pair as returned by source_set().
    """
    yield from source_set[1]


def build_image(source_set, elf_out, dependencies=(), bin_out=None, linker_script=linker_script):
    """Yield ninja lines that compile *source_set* and link it into an ELF image.

    Args:
        source_set: ``(objects, lines)`` pair from source_set(). Its lines are
            emitted here.
        elf_out: path of the linked ELF, relative to the project root.
        dependencies: further source-set pairs whose objects are linked in.
            Only their objects are used; their build lines must be emitted
            elsewhere (e.g. via build_source_set).
        bin_out: optional raw-binary output produced from the ELF via objcopy.
        linker_script: linker script path, relative to the project root.
            It is also an implicit dependency of the link step.
    """
    # All paths in build.ninja are relative to the build directory.
    linker_script = os.path.relpath(linker_script, builddir)
    elf_out = os.path.relpath(elf_out, builddir)

    own_objects, lines = source_set
    yield from lines

    # Build a fresh list. Extending `own_objects` in place would mutate the
    # caller's source-set tuple and leak dependency objects into later uses.
    objects = list(own_objects)
    for dep_objects, _ in dependencies:
        objects.extend(dep_objects)

    yield f"build {elf_out}: link {' '.join(objects)} | {linker_script}"
    yield f"  linker_script = {linker_script}"
    if bin_out is not None:
        bin_out = os.path.relpath(bin_out, builddir)
        yield f"build {bin_out}: objcopy {elf_out}"


def build_test(name, sources):
    """Yield ninja lines that build and run a host-side unit test.

    Each source is compiled with ``hostcxx`` into ``<name>_<path-without-ext>.o``.
    The objects are linked into the executable ``name`` with ``hostlink``.
    ``<name>.run`` is the stamp produced by running it with ``hosttest``.
    """
    # Only the final extension is replaced, so dotted directory names and
    # multi-dot file names map to distinct object paths.
    builds = [
        (os.path.relpath(src, builddir), f"{name}_{os.path.splitext(src)[0]}.o")
        for src in sources
    ]

    for src, obj in builds:
        yield f"build {obj}: hostcxx {src}"

    objects = " ".join(obj for _, obj in builds)

    yield f"build {name}: hostlink {objects}"
    yield f"build {name}.run: hosttest {name}"


def make_coverage(binaries):
    """Yield ninja lines that merge test profiles and render an HTML coverage report.

    ``profdata`` implicitly depends on every ``<binary>.run`` stamp and merges
    the ``<binary>.profraw`` files. ``cov/index.html`` is then produced from
    the merged profile and the given binaries.
    """
    bins = " ".join(binaries)
    profraw = " ".join(map("{}.profraw".format, binaries))
    objects = " ".join(map("--object {}".format, binaries))
    testruns = " ".join(map("{}.run".format, binaries))

    yield from (
        f"build profdata: profdata | {testruns}",
        f"  profraw = {profraw}",
        f"build cov/index.html: cov {bins} | profdata",
        "  profdata = profdata",
        f"  objects = {objects}",
    )


# Source sets are (objects, ninja lines) pairs produced by source_set().
hal = source_set("hal", [
        "hal/start.cc",
        "hal/lib/common/xil_assert.c",
        "hal/uart/xuartlite.c",
        "hal/uart/xuartlite_stats.c",
        "hal/uart/xuartlite_intr.c",
        ])

# glob() returns entries in filesystem order. Sorting keeps the generated
# build.ninja identical between configure runs.
bootloader = source_set("bootloader", sorted(glob.glob("./bootloader/**/*.cc", recursive=True)))
helloworld = source_set("helloworld", sorted(glob.glob("./apps/helloworld/**/*.cc", recursive=True)))

bootloader_image = build_image(
    bootloader,
    dependencies=[hal],
    elf_out="out/bootloader.elf",
    linker_script="bootloader/bootloader.ld",
)

helloworld_image = build_image(
    helloworld,
    dependencies=[hal],
    elf_out="out/helloworld.elf",
    bin_out="out/helloworld.bin",
)

# Everything main() writes after the rule header, in order. The HAL build
# lines are emitted once here; the images only link against its objects.
# NOTE: this shadows the builtin all(); main() refers to it by this name.
all = [build_source_set(hal), bootloader_image, helloworld_image]


def parse_args():
    """Parse the command line. ``--version`` is mandatory."""
    cli = argparse.ArgumentParser(description="Generate ninja build files.")
    cli.add_argument(
        "--version",
        required=True,
        help="version tag (typically from `git describe`)",
    )
    parsed = cli.parse_args()
    return parsed


def main():
    """Parse arguments and write ``<builddir>/build.ninja``."""
    args = parse_args()
    # The backslash-escaped quotes are consumed by the shell running the
    # compile command, so the macro expands to a C string literal.
    # NOTE(review): the version is embedded unescaped; `$`, quotes or
    # whitespace in it would break the generated command.
    add_cpp_flag(f'-DGIT_VERSION_TAG=\\"{args.version}\\"')
    # gen_rules() is lazy: its flag lists are read when the chain is joined
    # below, i.e. after the version define has been added.
    header = gen_rules()
    lines = itertools.chain(header, *all)

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(builddir, exist_ok=True)

    with open(os.path.join(builddir, "build.ninja"), "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
        f.write("\n")

    print(
        f'Configure done. Build with "ninja -C {builddir}". Output will be in {outdir}/'
    )


if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())