library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Top-level interface of the 16-bit CPU.
--
-- The CPU has two separate 16-bit buses: an instruction bus (code_*) and a
-- data bus (mem_*). Timing/latency of the attached memories is not defined
-- here; see the architecture for how each port is driven.
entity cpu is
  port (
    -- Clock and reset. In the architecture, rst is used as an asynchronous,
    -- active-high reset of the control state machine.
    clk       : in  std_logic;
    rst       : in  std_logic;

    -- Instruction bus: code_addr is driven from register 14 (the program
    -- counter); code_data is the 16-bit instruction word that is decoded.
    code_data : in  std_logic_vector(15 downto 0);
    code_addr : out std_logic_vector(15 downto 0);

    -- Data bus used by the LOAD and STORE instructions.
    mem_in    : in  std_logic_vector(15 downto 0);   -- data read from memory
    mem_out   : out std_logic_vector(15 downto 0);   -- data to be written
    mem_addr  : out std_logic_vector(15 downto 0);   -- load/store address
    mem_write : out std_logic;                       -- '1' while a STORE is decoded
    mem_read  : out std_logic                        -- '1' while a LOAD is decoded
  );
end entity cpu;

-- Behavioral implementation of the CPU.
--
-- Structure:
--   * 16 general registers (instances of `reg`), indexed 0..15.
--       - r14 is the program counter; it directly drives code_addr.
--       - bit 0 of r15 is the condition flag tested by BEQ / BNEQ.
--   * Two extra registers (load_reg, load_addr) that remember the destination
--     register and the address of a LOAD that is in flight.
--   * One `alu` instance whose `sel` input is the instruction opcode.
--   * A three-state controller:
--       RUN    - decode and execute the instruction on code_data.
--       LOAD   - second cycle of a LOAD: mem_in is written to the destination
--                register; no instruction is executed.
--       BRANCH - one idle cycle (a NOP is executed) after the program counter
--                has been redirected, so the instruction fetched before the
--                redirect is discarded. Also the state entered on reset.
--
-- Instruction word layout used by the decoder:
--   inst(15 downto 12) opcode
--   inst(11 downto 8)  register field 0 (usually the destination, rd)
--   inst(7 downto 4)   register field 1
--   inst(3 downto 0)   register field 2 or a 4-bit immediate
--   inst(7 downto 0)   8-bit immediate for SET / SETH / BEQ / BNEQ
architecture behavior of cpu is
  -- ALU, defined elsewhere in the project. `sel` receives the opcode.
  component alu is
    port(
          a: in std_logic_vector(15 downto 0);
          b: in std_logic_vector(15 downto 0);
          sel: in std_logic_vector(3 downto 0);

          flag: out std_logic;
          q: out std_logic_vector(15 downto 0)
        );
  end component;

  -- 16-bit register, defined elsewhere in the project.
  -- NOTE(review): presumably q follows d on the clock edge and rst clears it;
  -- confirm against the `reg` implementation.
  component reg is
    port(
          clk : in std_logic;
          rst : in std_logic;

          d : in std_logic_vector(15 downto 0);
          q : out std_logic_vector(15 downto 0)
        );
  end component;

  -- Register-file indices with a dedicated role.
  constant PC_REG   : natural := 14;  -- program counter, drives code_addr
  constant FLAG_REG : natural := 15;  -- bit 0 holds the condition flag

  -- ALU connections.
  signal alu_a: std_logic_vector(15 downto 0);
  signal alu_b: std_logic_vector(15 downto 0);
  signal alu_q: std_logic_vector(15 downto 0);
  signal alu_sel: std_logic_vector(3 downto 0);
  signal alu_flag: std_logic;

  -- State of a LOAD in flight: destination register number (low 4 bits of
  -- load_reg) and the memory address being read.
  signal load_reg_next, load_reg: std_logic_vector(15 downto 0);
  signal load_addr_next, load_addr: std_logic_vector(15 downto 0);

  -- Register file: reg_d is the next value, reg_q the current value.
  type regbank is array(0 to 15) of std_logic_vector(15 downto 0);
  signal reg_d: regbank;
  signal reg_q: regbank;

  type cpu_state_t is (RUN, LOAD, BRANCH);
  signal cpu_state, cpu_state_next: cpu_state_t;
begin
  cpu_alu: alu port map(a => alu_a, b => alu_b, sel => alu_sel, flag => alu_flag, q => alu_q);

  load_reg_r: reg port map(clk => clk, rst => rst, d => load_reg_next, q => load_reg);
  load_addr_r: reg port map(clk => clk, rst => rst, d => load_addr_next, q => load_addr);

  -- One `reg` instance per architectural register.
  allregs:
  for i in 0 to 15 generate
    regx: reg port map(clk => clk, rst => rst, d => reg_d(i), q => reg_q(i));
  end generate allregs;

  -- Controller state register with asynchronous, active-high reset.
  process(clk, rst)
  begin
    if rst = '1' then
      cpu_state <= BRANCH;  -- wait a cycle at first
    elsif rising_edge(clk) then
      cpu_state <= cpu_state_next;
    end if;
  end process;

  -- The program counter is the instruction fetch address.
  code_addr <= reg_q(PC_REG);

  -- Combinational decode/execute logic.
  --
  -- Every output is given a default first and overridden further down; within
  -- a process the last assignment to a signal wins, so the ORDER of the
  -- sections below is significant.
  process(code_data, reg_q, mem_in, alu_q, alu_flag, cpu_state, load_addr, load_reg) is
    variable inst: std_logic_vector(15 downto 0);
    variable regn_0: natural;
    variable regn_1: natural;
    variable regn_2: natural;
    variable load_dst: natural;
    variable do_alu: std_logic;
    -- true when the decoded instruction writes register regn_0
    variable rd_written: boolean;
    variable pc_plus_2: std_logic_vector(15 downto 0);
    variable mem_ea: std_logic_vector(15 downto 0);
    variable branch_target: std_logic_vector(15 downto 0);
  begin
    ---------------------------------------------------------------------
    -- Defaults: no memory access, ALU idle, registers hold their value.
    ---------------------------------------------------------------------
    mem_write <= '0';
    mem_read <= '0';
    mem_addr <= x"0000";
    mem_out <= x"0000";

    alu_sel <= "0000";
    alu_a <= x"0000";
    alu_b <= x"0000";

    do_alu := '0';
    rd_written := false;

    for i in 0 to 15 loop
      reg_d(i) <= reg_q(i);
    end loop;

    cpu_state_next <= RUN;

    load_reg_next <= load_reg;
    load_addr_next <= load_addr;

    -- Instructions are 16 bits wide, so sequential flow advances the PC by 2.
    pc_plus_2 := std_logic_vector(unsigned(reg_q(PC_REG)) + 2);

    ---------------------------------------------------------------------
    -- Per-state behavior: choose the instruction to execute and advance
    -- the program counter.
    ---------------------------------------------------------------------
    case cpu_state is
      when RUN =>
        reg_d(PC_REG) <= pc_plus_2;
        inst := code_data;
      when LOAD =>
        inst := x"0000";  -- NOP
        mem_addr <= load_addr;  -- maintain this until we're done reading

        load_dst := to_integer(unsigned(load_reg(3 downto 0)));
        if load_dst = PC_REG then
          -- Loading into the PC is a jump: discard the prefetched instruction.
          cpu_state_next <= BRANCH;
        else
          -- The PC was held while the LOAD was issued; resume counting.
          reg_d(PC_REG) <= pc_plus_2;
        end if;

        -- Must follow the PC update above so a load into r14 wins.
        reg_d(load_dst) <= mem_in;
      when BRANCH =>
        inst := x"0000";  -- NOP
        reg_d(PC_REG) <= pc_plus_2;
    end case;

    ---------------------------------------------------------------------
    -- Field extraction.
    ---------------------------------------------------------------------
    regn_0 := to_integer(unsigned(inst(11 downto 8)));
    regn_1 := to_integer(unsigned(inst(7 downto 4)));
    regn_2 := to_integer(unsigned(inst(3 downto 0)));

    -- LOAD/STORE address: register field 1 plus the 4-bit immediate shifted
    -- left by one; numeric_std "+" sign-extends the shorter operand.
    mem_ea := std_logic_vector(signed(reg_q(regn_1)) + signed(inst(3 downto 0) & '0'));

    -- BEQ/BNEQ target: register field 0 plus the 8-bit immediate shifted left
    -- by one, sign-extended the same way.
    branch_target := std_logic_vector(signed(reg_q(regn_0)) + signed(inst(7 downto 0) & '0'));

    ---------------------------------------------------------------------
    -- Instruction decode.
    ---------------------------------------------------------------------
    case inst(15 downto 12) is
      when "0000" => -- NOP
      when "0001" => -- LOAD rn, [rm, imm] (imm is signed 4 bits)
        mem_read <= '1';
        cpu_state_next <= LOAD;
        reg_d(PC_REG) <= reg_q(PC_REG);  -- halt the prefetcher

        load_addr_next <= mem_ea;
        mem_addr       <= mem_ea;
        load_reg_next(3 downto 0) <= inst(11 downto 8);
      when "0010" => -- STORE rn, [rm, imm]
        mem_write <= '1';
        mem_addr <= mem_ea;
        mem_out <= reg_q(regn_0);

      --- ALU stuff
      when "0011" => do_alu := '1'; -- ADD rd, rn, rm (rd := rn + rm)
      when "0100" => do_alu := '1'; -- SUB rd, rn, rm (rd := rn - rm)
      when "0101" => do_alu := '1'; -- OR rd, rn, rm (rd := rn or rm)
      when "0110" => do_alu := '1'; -- AND rd, rn, rm (rd := rn and rm)
      when "0111" => do_alu := '1'; -- NOT rd, rn (rd := not rn)
      when "1000" => do_alu := '1'; -- XOR rd, rn, rm (rd := rn xor rm)
      when "1001" => -- SETH rd, imm
        -- Only the upper byte is replaced; the lower byte keeps whatever was
        -- assigned to reg_d(regn_0) earlier in this process.
        reg_d(regn_0)(15 downto 8) <= inst(7 downto 0);
        rd_written := true;
      when "1010" =>  -- SHR rd, rn, imm (rd := rn >> imm)
        -- Handled apart from the other ALU ops because operand b is an
        -- immediate; unlike them it does not update the flag bit.
        alu_sel <= inst(15 downto 12);
        alu_a <= reg_q(regn_1);
        alu_b <= x"000" & inst(3 downto 0);
        reg_d(regn_0) <= alu_q;
        rd_written := true;
      when "1011" => do_alu := '1'; -- MUL rd, rn, rm (rd := rn * rm)

      when "1100" => -- CMP rn, rm (flag := 1 if equal)
        alu_sel <= "1100";
        alu_a <= reg_q(regn_0);
        alu_b <= reg_q(regn_1);
        reg_d(FLAG_REG)(0) <= alu_flag;

      when "1101" => -- BEQ [rn, imm] (jump to [rn, imm] if flag is set, imm is signed 8 bits)
        if reg_q(FLAG_REG)(0) = '1' then
          reg_d(PC_REG) <= branch_target;
          cpu_state_next <= BRANCH;
        end if;
      when "1110" => -- SET rd, imm (rd := imm, imm is 8 bit)
        reg_d(regn_0) <= x"00" & inst(7 downto 0);
        rd_written := true;
      when "1111" => -- BNEQ [rn, imm]
        if reg_q(FLAG_REG)(0) = '0' then
          reg_d(PC_REG) <= branch_target;
          cpu_state_next <= BRANCH;
        end if;

      when others => -- do nothing
    end case;

    ---------------------------------------------------------------------
    -- Register-register ALU operations (rd := rn <op> rm).
    ---------------------------------------------------------------------
    if do_alu = '1' then
      -- 1:1 mapping
      alu_sel <= inst(15 downto 12);
      alu_a <= reg_q(regn_1);
      alu_b <= reg_q(regn_2);
      reg_d(regn_0) <= alu_q;
      -- Assigned after the result so the flag wins when rd is r15.
      reg_d(FLAG_REG)(0) <= alu_flag;
      rd_written := true;
    end if;

    ---------------------------------------------------------------------
    -- Any instruction that writes the program counter is a jump: spend one
    -- BRANCH cycle so the already-fetched next instruction is not executed.
    -- This applies uniformly to the ALU ops as well as SET, SETH and SHR.
    ---------------------------------------------------------------------
    if rd_written and regn_0 = PC_REG then
      cpu_state_next <= BRANCH;
    end if;
  end process;

end behavior;