
LIBRARY IEEE;
USE IEEE.std_logic_1164.ALL;
USE IEEE.std_logic_arith.ALL;
--USE work.fix_std.ALL;
--USE work.fixed_pkg.ALL;
--USE IEEE.fixed_pkg.ALL;

entity matrix_math is
  port (
    -- Output toward the 3D projector: three signed 32-bit components packed
    -- as [95:64] = third, [63:32] = second, [31:0] = first component.
    RESULT : out std_logic_vector(95 downto 0);

    -- Controller-side interface
    CLK    : in  std_logic;                       -- clock (rising edge)
    RST_N  : in  std_logic;                       -- asynchronous reset, active low
    DATA   : in  std_logic_vector(511 downto 0);  -- sixteen 32-bit operand words
    OPCODE : in  std_logic_vector(2 downto 0);    -- operation select
    ENABLE : in  std_logic;                       -- qualifies OPCODE / DATA
    BUSY   : out std_logic                        -- '1' whenever the FSM is not idle
  );
end entity matrix_math;


architecture matrix_math_arch of matrix_math is
  -- Controller FSM.  NOP is the idle state.  Every other state lasts exactly
  -- one clock cycle and then returns to NOP.
  type state_type is (NOP, SETMATRIX, VECTORMULT, IDENTITY, MATRIXMULT);

  signal state, nextstate : state_type;

  -- Stored local matrix.  Both localmat and DATA use the same packing:
  -- element "rc" (row r, column c, 1-based) is the 32-bit word at index
  -- 4*(r-1)+(c-1).  So element 11 is bits 31..0 and element 44 is bits
  -- 511..480.  Elements are signed 16.16 fixed point (1.0 = x"00010000").
  signal localmat : std_logic_vector(511 downto 0);

  -- Combinational next values for the registered localmat and RESULT
  signal localmatset : std_logic_vector(511 downto 0);
  signal RESULTSET   : std_logic_vector(95 downto 0);

  -- Resized matrix product, packed in the same layout as localmat
  signal MATPRODUCT  : std_logic_vector(511 downto 0);

  -- Mrc  - stored matrix elements
  -- Arc  - input (DATA) elements
  -- RRrc - matrix product elements, full 64-bit precision (32.32)
  -- Rrc  - matrix product elements resized back to 16.16
  -- VEEn - vector product components, full 64-bit precision (32.32)
  -- Vn   - vector product components resized back to 16.16
  signal M11, M12, M13, M14, M21, M22, M23, M24,
         M31, M32, M33, M34, M41, M42, M43, M44 : signed(31 downto 0);
  signal A11, A12, A13, A14, A21, A22, A23, A24,
         A31, A32, A33, A34, A41, A42, A43, A44 : signed(31 downto 0);
  signal R11, R12, R13, R14, R21, R22, R23, R24,
         R31, R32, R33, R34, R41, R42, R43, R44 : signed(31 downto 0);
  signal V1, V2, V3 : signed(31 downto 0);

  signal RR11, RR12, RR13, RR14, RR21, RR22, RR23, RR24,
         RR31, RR32, RR33, RR34, RR41, RR42, RR43, RR44 : signed(63 downto 0);
  signal VEE1, VEE2, VEE3 : signed(63 downto 0);

  -- Reset values
  constant DATA_512_ZEROS : std_logic_vector(511 downto 0) := (others => '0');
  constant DATA_96_ZEROS  : std_logic_vector(95 downto 0)  := (others => '0');

  -- 1.0 in the 16.16 fixed-point element format
  constant FIXED_ONE : signed(31 downto 0) := x"00010000";

  -- Identity matrix in the localmat layout (most-significant word first)
  constant IDENTITY_MATRIX : std_logic_vector(511 downto 0) :=
    x"00010000" & x"00000000" & x"00000000" & x"00000000" &  -- 44 43 42 41
    x"00000000" & x"00010000" & x"00000000" & x"00000000" &  -- 34 33 32 31
    x"00000000" & x"00000000" & x"00010000" & x"00000000" &  -- 24 23 22 21
    x"00000000" & x"00000000" & x"00000000" & x"00010000";   -- 14 13 12 11

begin
  -- State register plus the two data registers.  The operation states
  -- (SETMATRIX, IDENTITY, MATRIXMULT) commit localmatset into localmat.
  -- VECTORMULT commits RESULTSET into RESULT, independent of ENABLE.
  Clock : process (CLK, RST_N)
  begin
    if RST_N = '0' then
      localmat <= DATA_512_ZEROS;
      RESULT   <= DATA_96_ZEROS;
      state    <= NOP;
    elsif rising_edge(CLK) then
      state <= nextstate;
      if state = VECTORMULT then
        RESULT <= RESULTSET;
      end if;
      if state = SETMATRIX or state = IDENTITY or state = MATRIXMULT then
        localmat <= localmatset;
      end if;
    end if;
  end process Clock;

  -- BUSY is high for the single cycle spent in any operation state
  BUSY <= '0' when state = NOP else '1';

  -- Unpack the stored matrix
  M11 <= signed(localmat( 31 downto   0));
  M12 <= signed(localmat( 63 downto  32));
  M13 <= signed(localmat( 95 downto  64));
  M14 <= signed(localmat(127 downto  96));
  M21 <= signed(localmat(159 downto 128));
  M22 <= signed(localmat(191 downto 160));
  M23 <= signed(localmat(223 downto 192));
  M24 <= signed(localmat(255 downto 224));
  M31 <= signed(localmat(287 downto 256));
  M32 <= signed(localmat(319 downto 288));
  M33 <= signed(localmat(351 downto 320));
  M34 <= signed(localmat(383 downto 352));
  M41 <= signed(localmat(415 downto 384));
  M42 <= signed(localmat(447 downto 416));
  M43 <= signed(localmat(479 downto 448));
  M44 <= signed(localmat(511 downto 480));

  -- Unpack the input operand
  A11 <= signed(DATA( 31 downto   0));
  A12 <= signed(DATA( 63 downto  32));
  A13 <= signed(DATA( 95 downto  64));
  A14 <= signed(DATA(127 downto  96));
  A21 <= signed(DATA(159 downto 128));
  A22 <= signed(DATA(191 downto 160));
  A23 <= signed(DATA(223 downto 192));
  A24 <= signed(DATA(255 downto 224));
  A31 <= signed(DATA(287 downto 256));
  A32 <= signed(DATA(319 downto 288));
  A33 <= signed(DATA(351 downto 320));
  A34 <= signed(DATA(383 downto 352));
  A41 <= signed(DATA(415 downto 384));
  A42 <= signed(DATA(447 downto 416));
  A43 <= signed(DATA(479 downto 448));
  A44 <= signed(DATA(511 downto 480));

  -- Matrix product localmat x DATA.  Each 16.16 x 16.16 product is 32.32.
  RR11 <= M11*A11 + M12*A21 + M13*A31 + M14*A41;
  RR12 <= M11*A12 + M12*A22 + M13*A32 + M14*A42;
  RR13 <= M11*A13 + M12*A23 + M13*A33 + M14*A43;
  RR14 <= M11*A14 + M12*A24 + M13*A34 + M14*A44;
  RR21 <= M21*A11 + M22*A21 + M23*A31 + M24*A41;
  RR22 <= M21*A12 + M22*A22 + M23*A32 + M24*A42;
  RR23 <= M21*A13 + M22*A23 + M23*A33 + M24*A43;
  RR24 <= M21*A14 + M22*A24 + M23*A34 + M24*A44;
  RR31 <= M31*A11 + M32*A21 + M33*A31 + M34*A41;
  RR32 <= M31*A12 + M32*A22 + M33*A32 + M34*A42;
  RR33 <= M31*A13 + M32*A23 + M33*A33 + M34*A43;
  RR34 <= M31*A14 + M32*A24 + M33*A34 + M34*A44;
  RR41 <= M41*A11 + M42*A21 + M43*A31 + M44*A41;
  RR42 <= M41*A12 + M42*A22 + M43*A32 + M44*A42;
  RR43 <= M41*A13 + M42*A23 + M43*A33 + M44*A43;
  RR44 <= M41*A14 + M42*A24 + M43*A34 + M44*A44;

  -- Back to 16.16: drop 16 fraction bits and keep 32 bits.  High integer
  -- bits are truncated, so out-of-range results wrap.
  R11 <= RR11(47 downto 16);
  R12 <= RR12(47 downto 16);
  R13 <= RR13(47 downto 16);
  R14 <= RR14(47 downto 16);
  R21 <= RR21(47 downto 16);
  R22 <= RR22(47 downto 16);
  R23 <= RR23(47 downto 16);
  R24 <= RR24(47 downto 16);
  R31 <= RR31(47 downto 16);
  R32 <= RR32(47 downto 16);
  R33 <= RR33(47 downto 16);
  R34 <= RR34(47 downto 16);
  R41 <= RR41(47 downto 16);
  R42 <= RR42(47 downto 16);
  R43 <= RR43(47 downto 16);
  R44 <= RR44(47 downto 16);

  -- Repack the product in the localmat layout
  MATPRODUCT( 31 downto   0) <= std_logic_vector(R11);
  MATPRODUCT( 63 downto  32) <= std_logic_vector(R12);
  MATPRODUCT( 95 downto  64) <= std_logic_vector(R13);
  MATPRODUCT(127 downto  96) <= std_logic_vector(R14);
  MATPRODUCT(159 downto 128) <= std_logic_vector(R21);
  MATPRODUCT(191 downto 160) <= std_logic_vector(R22);
  MATPRODUCT(223 downto 192) <= std_logic_vector(R23);
  MATPRODUCT(255 downto 224) <= std_logic_vector(R24);
  MATPRODUCT(287 downto 256) <= std_logic_vector(R31);
  MATPRODUCT(319 downto 288) <= std_logic_vector(R32);
  MATPRODUCT(351 downto 320) <= std_logic_vector(R33);
  MATPRODUCT(383 downto 352) <= std_logic_vector(R34);
  MATPRODUCT(415 downto 384) <= std_logic_vector(R41);
  MATPRODUCT(447 downto 416) <= std_logic_vector(R42);
  MATPRODUCT(479 downto 448) <= std_logic_vector(R43);
  MATPRODUCT(511 downto 480) <= std_logic_vector(R44);

  -- Vector product [x y z 1] x localmat, with x = A11, y = A12, z = A13.
  -- Row 4 holds the translation.  The implicit w = 1 must be 1.0 in 16.16,
  -- so the translation is multiplied by FIXED_ONE.  That puts it in the same
  -- 32.32 scale as the other products.  Adding M41..M43 unscaled would
  -- shrink the translation by 2**16 after the 47..16 resize.
  VEE1 <= M11*A11 + M21*A12 + M31*A13 + M41*FIXED_ONE;
  VEE2 <= M12*A11 + M22*A12 + M32*A13 + M42*FIXED_ONE;
  VEE3 <= M13*A11 + M23*A12 + M33*A13 + M43*FIXED_ONE;

  V1 <= VEE1(47 downto 16);
  V2 <= VEE2(47 downto 16);
  V3 <= VEE3(47 downto 16);

  -- Candidate RESULT value.  It is only registered while in VECTORMULT.
  RESULTSET <= std_logic_vector(V3) & std_logic_vector(V2) & std_logic_vector(V1);

  -- Next-state and next-matrix logic.
  -- In NOP, ENABLE = '1' selects an operation through OPCODE.
  -- In SETMATRIX, IDENTITY and MATRIXMULT, the stored matrix is only replaced
  -- if ENABLE is still '1'.  Otherwise the operation is dropped.
  -- Every operation state returns to NOP after one cycle.
  Nextreg : process (state, ENABLE, OPCODE, DATA, localmat, MATPRODUCT)
  begin
    -- Defaults: go (back) to idle and keep the stored matrix
    nextstate   <= NOP;
    localmatset <= localmat;

    case state is
      when NOP =>
        if ENABLE = '1' then
          case OPCODE is
            when "001"  => nextstate <= VECTORMULT;
            when "010"  => nextstate <= SETMATRIX;
            when "011"  => nextstate <= IDENTITY;
            when "100"  => nextstate <= MATRIXMULT;
            when others => nextstate <= NOP;  -- "000" and unused codes
          end case;
        end if;

      when SETMATRIX =>
        if ENABLE = '1' then
          localmatset <= DATA;
        end if;

      when IDENTITY =>
        if ENABLE = '1' then
          localmatset <= IDENTITY_MATRIX;
        end if;

      when MATRIXMULT =>
        if ENABLE = '1' then
          localmatset <= MATPRODUCT;
        end if;

      when VECTORMULT =>
        null;  -- RESULT is captured by the Clock process
    end case;
  end process Nextreg;

end matrix_math_arch;
