library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;
use work.spi_pkg.all;

-- Control unit of an SPI master: sequences chip-select, SCK masking/clock
-- enabling and the tx/rx word handshakes. The datapath (shift registers,
-- clock divider) lives outside this entity and is steered via the *_o ports.
entity spi_master_ctrl is

  generic (
    SIZES            : natural_vector := (8, 16);  -- selectable word lengths (bits), indexed by size_sel_i
    SIZES_2LOG       : natural := 1;               -- width of size_sel_i
    DIVISORS         : natural_vector := (2, 4, 6, 8, 16, 32, 64, 128, 256);  -- selectable clock divisors, indexed by div_sel_i
    -- NOTE(review): with the defaults, a 3-bit div_sel_i can only address
    -- DIVISORS(0..7); the last default entry (256) is unreachable — confirm
    -- whether DIVISORS_LOG2 or the default list should change.
    DIVISORS_LOG2    : natural := 3;               -- width of div_sel_i
    CSN_PULSE_CYCLES : natural := 1                -- CSN-high hold loaded into the counter when pulse_csn_i ends a frame
  );

  port (
    clk_i                : in  std_logic;  -- all state updates on the rising edge
    rst_in               : in  std_logic;  -- synchronous reset, active low
    en_i                 : in  std_logic;  -- '0' forces the sequencer to RESET; also gates the output enables
    size_sel_i           : in  std_logic_vector(SIZES_2LOG - 1 downto 0);     -- index into SIZES
    div_sel_i            : in  std_logic_vector(DIVISORS_LOG2 - 1 downto 0);  -- index into DIVISORS
    pulse_csn_i          : in  std_logic;  -- '1' after a word: release CSN and return to IDLE
    clock_phase_i        : in  std_logic;  -- adds one counter step to the first word of a frame
    counter_overflow_i   : in  std_logic;  -- decrements the internal counter (presumably a divider tick — confirm)
    rx_block_on_full_i   : in  std_logic;  -- '1': stall tx while a received word is unread; '0': overwrite and flag loss
    rx_en_i              : in  std_logic;  -- '0' disables the rx FSM and never blocks tx
    rx_valid_o           : out std_logic;  -- received word available
    rx_ready_i           : in  std_logic;  -- consumer accepts the received word
    tx_en_i              : in  std_logic;  -- '0' disables the tx FSM (tx_valid_i alone then starts transfers)
    tx_valid_i           : in  std_logic;  -- upstream offers a word to send
    tx_ready_o           : out std_logic;  -- controller can accept a word to send
    busy_o               : out std_logic;  -- '0' only while the sequencer is in IDLE
    err_lost_rx_data_o   : out std_logic;  -- latched flag: a received word was overwritten before being read
    clear_lost_rx_data_i : in  std_logic;  -- clears err_lost_rx_data_o
    rst_on               : out std_logic;  -- driven '0' while the sequencer is in RESET (presumably an active-low datapath reset)
    csn_o                : out std_logic;  -- chip select, '0' during SHIFTING and NEXT_DATA
    csn_en_o             : out std_logic;  -- output enable for CSN
    mosi_en_o            : out std_logic;  -- output enable for MOSI (en_i and tx_en_i)
    miso_en_o            : out std_logic;  -- output enable for MISO (tied '0')
    sck_mask_o           : out std_logic;  -- '0' masks SCK between words in NEXT_DATA
    sck_en_o             : out std_logic;  -- output enable for SCK
    gen_clk_en_o         : out std_logic;  -- enables the external clock generator ('0' in RESET and IDLE)
    latch_tx_data_o      : out std_logic   -- one-cycle strobe: latch the tx word offered on tx_valid_i
  );

end entity spi_master_ctrl;

-- Control FSMs for the SPI master. Three cooperating state machines share a
-- single register process:
--   state    : main sequencer; owns the shift / CSN-hold counter and drives
--              csn_o, sck_mask_o, gen_clk_en_o, busy_o and rst_on.
--   tx_state : tx_valid_i/tx_ready_o handshake; raises tx_got_data once a
--              word is latched and waits for the sequencer's ack.
--   rx_state : rx_valid_o/rx_ready_i handshake; raises rx_block to stall the
--              tx side while an unread word is pending and
--              rx_block_on_full_i = '1', and flags overwritten words.
architecture a1 of spi_master_ctrl is
  -- Return the larger of two naturals (local helper used in the constant and
  -- subtype declarations below).
  function larger_of (a, b : natural) return natural is
  begin
    if a > b then
      return a;
    end if;
    return b;
  end function larger_of;

  constant MAX_SIZE : natural := get_max_natural(SIZES);
  constant MAX_DIVISOR : natural := get_max_natural(DIVISORS);

  -- Counter value loaded when NEXT_DATA ends a frame with a CSN pulse.
  -- Clamped so that CSN_PULSE_CYCLES = 0 yields 0 (same as 1) instead of -1,
  -- which is out of range for natural. IDLE keeps csn_o high for at least
  -- one cycle in either case.
  constant CSN_HOLD_COUNT : natural := larger_of(CSN_PULSE_CYCLES, 1) - 1;

  -- Largest value the counter is ever loaded with. switch_to_shifting loads
  -- at most 2 * shifting_length, and shifting_length is declared
  -- range 0 to MAX_SIZE * 2. A CSN pulse loads CSN_HOLD_COUNT. Bounding the
  -- subtype lets synthesis size the counter instead of building a
  -- full-width natural.
  subtype counter_t is natural range 0 to larger_of(4 * MAX_SIZE, CSN_HOLD_COUNT);

  -- Main sequencer states. CSN_RISING is declared but never entered; the
  -- "when others" branch of the state process maps it back to RESET.
  type states_t is (RESET, IDLE, SHIFTING, NEXT_DATA, CSN_RISING);

  type tx_states_t is (IDLE, TX_LATCHING_DATA, TX_LATCHED, TX_WAITING);
  type rx_states_t is (IDLE, RX_GOT_DATA, RX_INVALID_DATA);

  -- '1' while the rx side wants the tx side to stall (see rx_state).
  signal rx_block : std_logic;

  signal curr_rx_state : rx_states_t;
  signal next_rx_state : rx_states_t;

  signal curr_tx_state : tx_states_t;
  signal next_tx_state : tx_states_t;

  signal curr_state : states_t;
  signal next_state : states_t;

  -- Shift / CSN-hold counter, decremented on counter_overflow_i.
  signal curr_counter : counter_t;
  signal next_counter : counter_t;

  -- Set input of the lost-rx-data latch.
  signal set_lost_rx_data : std_logic;

  -- tx -> sequencer: a word is available to shift.
  signal tx_got_data : std_logic;
  -- sequencer -> tx: the available word has been taken for shifting.
  signal ack_tx_got_data : std_logic;

  -- One-cycle strobe when a SHIFTING phase finishes (counter reached 0).
  signal transmission_done : std_logic;

  signal shifting_length : integer range 0 to MAX_SIZE * 2;
  signal selected_divisor : integer range 0 to MAX_DIVISOR;
  signal clear_lost_rx_data : std_logic;
begin  -- architecture a1
  -- State and counter registers for all three FSMs.
  registers: process (clk_i) is
  begin  -- process registers
    if rising_edge(clk_i) then          -- rising clock edge
      if rst_in = '0' then              -- synchronous reset (active low)
        curr_counter <= 0;
        curr_state <= RESET;
        curr_tx_state <= IDLE;
        curr_rx_state <= IDLE;
      else
        curr_counter <= next_counter;
        curr_state <= next_state;
        curr_tx_state <= next_tx_state;
        curr_rx_state <= next_rx_state;
      end if;
    end if;
  end process registers;

  -- Main sequencer. Default outputs are assigned first and each state
  -- overrides what it needs.
  state: process (all) is
    -- Schedule a state transition and (re)load the counter.
    procedure switch_to (
      constant state   : in states_t;
      constant counter : in natural) is
    begin  -- procedure switch_to
      next_state <= state;
      next_counter <= counter;
    end procedure switch_to;

    -- Enter SHIFTING with the counter loaded for one word of shifting_length
    -- bits (2 counter steps per bit, presumably one per SCK edge — confirm).
    -- The first word of a frame starts one step short, plus one step when
    -- clock_phase_i = '1'. Back-to-back words (is_next_data) get the full
    -- count, minus two when the selected divisor is 2.
    procedure switch_to_shifting(constant is_next_data: boolean) is
      variable count : natural;
    begin  -- procedure switch_to_shifting
      if is_next_data then
        if selected_divisor = 2 then
          count := shifting_length * 2 - 2;
        else
          count := shifting_length * 2;
        end if;
      else
        count := shifting_length * 2 - 1;
        if clock_phase_i = '1' then
          count := count + 1;
        end if;
      end if;

      switch_to(SHIFTING, count);
    end procedure switch_to_shifting;

    variable zero : std_logic;
  begin  -- process state
    -- Counter: hold by default and count down one step per
    -- counter_overflow_i until 0. A switch_to below overrides this.
    next_counter <= curr_counter;
    if curr_counter /= 0 and counter_overflow_i = '1' then
      next_counter <= curr_counter - 1;
    end if;

    if curr_counter = 0 then
      zero := '1';
    else
      zero := '0';
    end if;

    -- Defaults: stay put, CSN high, SCK unmasked, clock generator on, busy.
    transmission_done <= '0';
    next_state <= curr_state;

    gen_clk_en_o <= '1';
    ack_tx_got_data <= '0';

    rst_on <= '1';

    sck_mask_o <= '1';
    busy_o <= '1';
    csn_o <= '1';

    case curr_state is
      when RESET =>
        -- rst_on is driven low while in this state.
        switch_to(IDLE, 0);
        rst_on <= '0';
        gen_clk_en_o <= '0';
        csn_o <= '1';
      when IDLE =>
        -- Wait for the CSN hold counter to expire and a tx word to be ready.
        -- NOTE(review): gen_clk_en_o is '0' here; if counter_overflow_i is
        -- produced by that generator, a non-zero CSN_HOLD_COUNT
        -- (CSN_PULSE_CYCLES > 1) would never count down — confirm.
        busy_o <= '0';
        gen_clk_en_o <= '0';

        if zero = '1' and tx_got_data = '1' then
          switch_to_shifting(false);
          gen_clk_en_o <= '1';
          ack_tx_got_data <= '1';
        end if;
      when SHIFTING =>
        csn_o <= '0';
        sck_mask_o <= '1';

        if zero = '1' then
          transmission_done <= '1';
          switch_to(NEXT_DATA, 0);
        end if;
      when NEXT_DATA =>
        -- Word finished: pulse CSN and go idle, continue with the next word
        -- while keeping CSN low, or go idle.
        csn_o <= '0';
        sck_mask_o <= '0';

        if pulse_csn_i = '1' then
          switch_to(IDLE, CSN_HOLD_COUNT);
        elsif tx_got_data = '1' then
          sck_mask_o <= '1';
          switch_to_shifting(true);
          ack_tx_got_data <= '1';
        else
          switch_to(IDLE, 0);
        end if;
      when others =>
        next_state <= RESET;
    end case;

    -- Disabling the core restarts the sequencer from RESET (the tx and rx
    -- FSMs follow via their curr_state = RESET checks).
    if en_i = '0' then
      next_state <= RESET;
    end if;
  end process state;

  -- TX handshake FSM.
  --   TX_LATCHING_DATA : tx_ready_o = '1'; a valid word is latched.
  --   TX_LATCHED       : tx_got_data = '1' until the sequencer acks it.
  --   TX_WAITING       : word is being shifted; once the shift is done (or
  --                      the sequencer is not shifting) and rx is not
  --                      blocking, accept the next word.
  tx_state: process(all) is
  begin  -- process tx_state
    next_tx_state <= curr_tx_state;

    latch_tx_data_o <= '0';
    tx_got_data <= '0';
    tx_ready_o <= '0';

    case curr_tx_state is
      when IDLE =>
        next_tx_state <= TX_LATCHING_DATA;
      when TX_LATCHING_DATA =>
        tx_ready_o <= '1';

        if tx_valid_i = '1' then
          latch_tx_data_o <= '1';
          next_tx_state <= TX_LATCHED;

          if ack_tx_got_data = '1' then
            next_tx_state <= TX_WAITING;
          end if;
        end if;
      when TX_LATCHED =>
        tx_got_data <= '1';

        if ack_tx_got_data = '1' then
          next_tx_state <= TX_WAITING;
        end if;
      when TX_WAITING =>
        if (transmission_done = '1' or curr_state /= SHIFTING) and rx_block = '0' then

          -- prevent pulse...
          -- NOTE(review): the word below is latched whenever tx_valid_i = '1',
          -- even when this condition keeps tx_ready_o at '0' (e.g.
          -- rx_en_i = '0', so rx_block = '0', with rx_block_on_full_i = '1'
          -- and rx_ready_i = '0'). That breaks the valid/ready handshake —
          -- confirm the intended condition.
          if rx_ready_i = '1' or rx_block_on_full_i = '0' then
            tx_ready_o <= '1';
          end if;

          next_tx_state <= TX_LATCHING_DATA;

          if tx_valid_i = '1' then
            next_tx_state <= TX_LATCHED;
            latch_tx_data_o <= '1';
          end if;
        end if;
      when others =>
        next_tx_state <= IDLE;
    end case;

    if curr_state = RESET then
      next_tx_state <= IDLE;
    end if;

    -- With tx disabled, treat tx_valid_i as "a word is available" unless rx
    -- is blocking.
    -- NOTE(review): in RX_GOT_DATA, rx_block depends on tx_got_data, so this
    -- closes a combinational loop rx_block -> tx_got_data -> rx_block when
    -- rx_block_on_full_i = '1'. Consider breaking it.
    if tx_en_i = '0' then
      next_tx_state <= IDLE;
      tx_got_data <= not rx_block and tx_valid_i;               -- simulate always receiving new data
    end if;
  end process tx_state;

  -- RX handshake FSM.
  --   RX_INVALID_DATA : no word pending; on transmission_done the new word
  --                     is shown on rx_valid_o and kept if not taken.
  --   RX_GOT_DATA     : word pending until rx_ready_i takes it or a new tx
  --                     word overrides it (which flags set_lost_rx_data).
  rx_state: process(all) is
  begin  -- process rx_state
    next_rx_state <= curr_rx_state;

    rx_block <= rx_block_on_full_i;
    rx_valid_o <= '0';
    set_lost_rx_data <= '0';

    case curr_rx_state is
      when IDLE =>
        next_rx_state <= RX_INVALID_DATA;
        rx_block <= '0';
      when RX_GOT_DATA =>
        rx_valid_o <= '1';
        if rx_ready_i = '1' or tx_got_data = '1' then
          next_rx_state <= RX_INVALID_DATA;
          rx_block <= '0';
          rx_valid_o <= '0';

          -- Dropped because of a new tx word rather than being read.
          if rx_ready_i = '0' then
            set_lost_rx_data <= '1';
          end if;
        end if;
      when RX_INVALID_DATA =>
        rx_block <= '0';
        if transmission_done = '1' then
          -- Not taken this cycle: keep it pending and, when configured,
          -- stall the tx side.
          if rx_ready_i = '0' then
            rx_block <= rx_block_on_full_i;
            next_rx_state <= RX_GOT_DATA;
          end if;
          rx_valid_o <= '1';            -- TODO check
        end if;
      when others =>
        next_rx_state <= IDLE;
    end case;

    if curr_state = RESET then
      next_rx_state <= IDLE;
    end if;

    if rx_en_i = '0' then
      next_rx_state <= IDLE;
      rx_block <= '0';                  -- do not block if disabled
    end if;
  end process rx_state;

  -- Lost-rx-data flag: set by rx_state, cleared by clear_lost_rx_data_i or
  -- while the sequencer is in RESET.
  -- NOTE(review): set_lost_rx_data is combinational; confirm rs_latch
  -- tolerates glitches on set_i (or register it first).
  error_rx_lost : entity work.rs_latch
    port map (
      set_i   => set_lost_rx_data,
      reset_i => clear_lost_rx_data,
      q_o     => err_lost_rx_data_o);

  -- Internal
  clear_lost_rx_data <= '1' when clear_lost_rx_data_i = '1' or curr_state = RESET else '0';
  shifting_length <= SIZES(to_integer(unsigned(size_sel_i)));
  selected_divisor <= DIVISORS(to_integer(unsigned(div_sel_i)));

  -- Enable Outputs (each output has exactly one concurrent driver)
  miso_en_o <= '0';
  mosi_en_o <= en_i and tx_en_i;
  csn_en_o <= en_i;
  sck_en_o <= en_i;                      -- TODO make it configurable so sck can be Z when not communicating

end architecture a1;