// Copyright (C) 2010, Guy Barrand. All rights reserved.
// See the file tools.license for terms.

#ifndef tools_mat
#define tools_mat

#ifdef TOOLS_MEM
#include "../mem"
#endif

#include "MATCOM"

//#include <cstring> //memcpy

//#define TOOLS_MAT_NEW
namespace tools {

// Square matrix of compile-time dimension D with elements of type T.
//
// The D*D elements live in m_vec : an in-object array by default, or a
// new[]-allocated array owned by the instance when TOOLS_MAT_NEW is defined.
// When TOOLS_MEM is defined, constructions and destructions are reported to
// mem::increment / mem::decrement under the name "tools::mat".
//
// NOTE(review): zero() and _copy() are used below but not declared in this
// class body; they presumably come from the TOOLS_MATCOM expansion (see the
// MATCOM header) - confirm there.
template <class T,unsigned int D>
class mat {
  // total number of stored elements.
  static const unsigned int _D2 = D*D;
  // element count accessor (nmat has a run-time equivalent of the same name).
  unsigned int dim2() const {return _D2;}
  // TOOLS_MATCOM is expanded with TOOLS_MAT_CLASS naming this class.
#define TOOLS_MAT_CLASS mat
  TOOLS_MATCOM
#undef TOOLS_MAT_CLASS
private:
#ifdef TOOLS_MEM
  // class name handed to mem::increment / mem::decrement.
  static const std::string& s_class() {
    static const std::string s_v("tools::mat");
    return s_v;
  }
#endif
public:
  // Default constructor : every element is set to zero().
  // With TOOLS_MEM, passing a_inc=false skips the mem::increment call
  // (the destructor still calls mem::decrement unconditionally).
  mat(
#ifdef TOOLS_MEM
      bool a_inc = true
#endif      
      ) {
#ifdef TOOLS_MEM
    if(a_inc) mem::increment(s_class().c_str());
#endif
#ifdef TOOLS_MAT_NEW
    m_vec = new T[D*D];
#endif
    for(unsigned int i=0;i<_D2;i++) m_vec[i] = zero();
  }
  // Releases the heap storage (TOOLS_MAT_NEW only) and reports the
  // destruction (TOOLS_MEM only).
  virtual ~mat() {
#ifdef TOOLS_MAT_NEW
    delete [] m_vec;
#endif
#ifdef TOOLS_MEM
    mem::decrement(s_class().c_str());
#endif
  }
public:
  // Copy constructor : sets up its own storage, then hands the source
  // elements to _copy().
  mat(const mat& a_from) {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
#ifdef TOOLS_MAT_NEW
    m_vec = new T[D*D];
#endif
    _copy(a_from.m_vec);
  }
  // Copy assignment : self-assignment is a no-op; the storage is reused
  // since both sides always hold D*D elements.
  mat& operator=(const mat& a_from){
    if(&a_from==this) return *this;
    _copy(a_from.m_vec);
    return *this;
  }
public:
  // Builds a matrix from the raw array a_v, passed as is to _copy().
  // NOTE(review): a_v presumably has to provide at least D*D elements in
  // the same layout as m_vec - confirm against _copy() in MATCOM.
  mat(const T a_v[]){
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
#ifdef TOOLS_MAT_NEW
    m_vec = new T[D*D];
#endif
    _copy(a_v);
  }
public:
  // number of rows (equal to the number of columns).
  unsigned int dimension() const {return D;}
protected:
  // element storage, D*D entries.
#ifdef TOOLS_MAT_NEW
  T* m_vec;
#else
  T m_vec[D*D];
#endif

  // never called from this file; its body only instantiates mat<float,2>.
private:static void check_instantiation() {mat<float,2> dummy;}
};

// Square matrix whose dimension is chosen at run time.
//
// The m_D*m_D elements of type T live in a new[]-allocated array owned by
// the instance (m_vec). When TOOLS_MEM is defined, constructions and
// destructions are reported to mem::increment / mem::decrement under the
// name "tools::nmat".
//
// NOTE(review): zero() and _copy() are used below but not declared in this
// class body; they presumably come from the TOOLS_MATCOM expansion (see the
// MATCOM header) - confirm there.
template <class T>
class nmat {
  // total number of stored elements.
  unsigned int dim2() const {return m_D2;}
  // TOOLS_MATCOM is expanded with TOOLS_MAT_CLASS naming this class.
#define TOOLS_MAT_CLASS nmat
  TOOLS_MATCOM
#undef TOOLS_MAT_CLASS
private:
#ifdef TOOLS_MEM
  // class name handed to mem::increment / mem::decrement.
  static const std::string& s_class() {
    static const std::string s_v("tools::nmat");
    return s_v;
  }
#endif
public:
  // Builds an a_D x a_D matrix with every element set to zero().
  nmat(unsigned int a_D):m_D(a_D),m_D2(a_D*a_D),m_vec(0) {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
    m_vec = new T[m_D2];
    for(unsigned int i=0;i<m_D2;i++) m_vec[i] = zero();
  }
  // Releases the element storage.
  virtual ~nmat() {
    delete [] m_vec;
#ifdef TOOLS_MEM
    mem::decrement(s_class().c_str());
#endif
  }
public:
  // Copy constructor : takes the dimension of a_from, allocates its own
  // storage and hands the source elements to _copy().
  nmat(const nmat& a_from)
  :m_D(a_from.m_D),m_D2(a_from.m_D2),m_vec(0)
  {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
    m_vec = new T[m_D2];
    _copy(a_from.m_vec);
  }
  // Copy assignment : self-assignment is a no-op. The storage is reused when
  // both sides have the same dimension and reallocated otherwise.
  nmat& operator=(const nmat& a_from){
    if(&a_from==this) return *this;
    if(a_from.m_D!=m_D) {
      // Allocate the new storage BEFORE releasing the old one : if new[]
      // throws, this object is left untouched instead of holding a dangling
      // m_vec that the destructor would delete a second time.
      T* vec = new T[a_from.m_D2];
      delete [] m_vec;
      m_vec = vec;
      m_D = a_from.m_D;
      m_D2 = a_from.m_D2;
    }
    _copy(a_from.m_vec);
    return *this;
  }
public:
  // Builds an a_D x a_D matrix from the raw array a_v, passed as is to _copy().
  // NOTE(review): a_v presumably has to provide at least a_D*a_D elements in
  // the same layout as m_vec - confirm against _copy() in MATCOM.
  nmat(unsigned int a_D,const T a_v[])
  :m_D(a_D),m_D2(a_D*a_D),m_vec(0)
  {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
    m_vec = new T[m_D2];
    _copy(a_v);
  }
public:
  // number of rows (equal to the number of columns).
  unsigned int dimension() const {return m_D;}
protected:
  unsigned int m_D;  // dimension.
  unsigned int m_D2; // m_D*m_D, number of elements in m_vec.
  T* m_vec;          // owned element storage.
  // never called from this file; its body only instantiates nmat<float>.
private:static void check_instantiation() {nmat<float> dummy(2);}
};

// Converts a fixed-size mat<T,D> into a run-time sized nmat<T> of the same
// dimension; the D*D elements are copied one by one through operator[].
template <class T,unsigned int D>
inline nmat<T> copy(const mat<T,D>& a_from) {
  const unsigned int count = D*D;
  nmat<T> result(D);
  for(unsigned int index=0;index<count;++index) {
    result[index] = a_from[index];
  }
  return result;
}

// Multiplies in place (operator*=) every element of the container a_vec by
// a_mat, which has the element type of the container.
template <class VECTOR>
inline void multiply(VECTOR& a_vec,const typename VECTOR::value_type& a_mat) {
  typedef typename VECTOR::iterator iter_t;
  iter_t pos = a_vec.begin();
  while(pos!=a_vec.end()) {
    (*pos) *= a_mat;
    ++pos;
  }
}

// Scales every matrix of the container a_vec by calling its
// multiply(a_value) method; a_value has the matrix element type (elem_t).
template <class VECTOR>
inline void multiply(VECTOR& a_vec,const typename VECTOR::value_type::elem_t& a_value) {
  typedef typename VECTOR::iterator iter_t;
  iter_t pos = a_vec.begin();
  while(pos!=a_vec.end()) {
    pos->multiply(a_value);
    ++pos;
  }
}

///////////////////////////////////////////////////////////////////////////////////////
/// related to complex numbers : //////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////

// Replaces in place each of the dimension()*dimension() elements of a_m by
// a_conj(element). The elements are written through the pointer returned by
// a_m.data(), after casting its constness away.
template <class MAT>
inline void conjugate(MAT& a_m,typename MAT::elem_t (*a_conj)(const typename MAT::elem_t&)) {
  typedef typename MAT::elem_t elem_type; //typically std::complex<>
  elem_type* const cells = const_cast<elem_type*>(a_m.data());
  const unsigned int count = a_m.dimension()*a_m.dimension();
  for(unsigned int index=0;index<count;++index) {
    cells[index] = a_conj(cells[index]);
  }
}

// Returns true if a_imag(element) converts to false (i.e. is zero) for each
// of the dimension()*dimension() elements of a_m; stops at the first
// element for which it does not.
template <class MAT>
inline bool is_real(MAT& a_m,typename MAT::elem_t::value_type (*a_imag)(const typename MAT::elem_t&)) {
  typedef typename MAT::elem_t elem_type;
  const elem_type* cells = a_m.data();
  const unsigned int count = a_m.dimension()*a_m.dimension();
  for(unsigned int index=0;index<count;++index) {
    if(a_imag(cells[index])) return false;
  }
  return true;
}

// Returns true if, for each of the dimension()*dimension() elements of a_m,
// a_fabs(a_imag(element)) is strictly below a_prec; stops at the first
// element reaching a_prec.
//   a_imag : extracts the imaginary part of an element.
//   a_fabs : absolute value of that part, expressed in the PREC type.
template <class MAT,class PREC>
inline bool is_real_prec(MAT& a_m,typename MAT::elem_t::value_type (*a_imag)(const typename MAT::elem_t&),
                         const PREC& a_prec,PREC(*a_fabs)(const typename MAT::elem_t::value_type&)) {
  typedef typename MAT::elem_t T;
  // read-only scan : no need to cast away the constness of data().
  const T* pos = a_m.data();
  unsigned int D2 = a_m.dimension()*a_m.dimension();
  for(unsigned int i=0;i<D2;i++,pos++) {if(a_fabs(a_imag(*pos))>=a_prec) return false;}
  return true;
}

// Returns true if a_real(element) converts to false (i.e. is zero) for each
// of the dimension()*dimension() elements of a_m; stops at the first
// element for which it does not.
template <class MAT>
inline bool is_imag(MAT& a_m,typename MAT::elem_t::value_type (*a_real)(const typename MAT::elem_t&)) {
  typedef typename MAT::elem_t elem_type;
  const elem_type* cells = a_m.data();
  const unsigned int count = a_m.dimension()*a_m.dimension();
  for(unsigned int index=0;index<count;++index) {
    if(a_real(cells[index])) return false;
  }
  return true;
}

// Writes a_real(element of a_c) into the element of a_r at the same
// position. Returns false, leaving a_r untouched, when the two matrices do
// not have the same dimension(); returns true otherwise.
template <class CMAT,class RMAT>
inline bool to_real(const CMAT& a_c,RMAT& a_r,typename CMAT::elem_t::value_type (*a_real)(const typename CMAT::elem_t&)) {
  if(a_r.dimension()!=a_c.dimension()) return false;
  typedef typename CMAT::elem_t src_t;
  typedef typename RMAT::elem_t dst_t;
  const src_t* src = a_c.data();
  //a_r is written through its read accessor.
  dst_t* dst = const_cast<dst_t*>(a_r.data());
  const unsigned int count = a_c.dimension()*a_c.dimension();
  for(unsigned int index=0;index<count;++index) {
    dst[index] = a_real(src[index]);
  }
  return true;
}

// In-place "dagger" of a_m : every element is first replaced by
// a_conj(element) through conjugate(), then a_m.transpose() is called.
template <class MAT> 
inline void dagger(MAT& a_m,typename MAT::elem_t (*a_conj)(const typename MAT::elem_t&)) {
  conjugate<MAT>(a_m,a_conj);
  a_m.transpose();
}

// Expands the complex matrix a_c = X+iY into the real matrix
//   a_r = |  X  Y |
//         | -Y  X |
// a_r must have twice the dimension of a_c; if not, a_r.set_zero() is
// called and false is returned. Returns true on success.
//   a_real, a_imag : extract the real and imaginary part of an element.
template <class CMAT,class RMAT>
inline bool decomplex(const CMAT& a_c,RMAT& a_r,
                      typename CMAT::elem_t::value_type (*a_real)(const typename CMAT::elem_t&),
                      typename CMAT::elem_t::value_type (*a_imag)(const typename CMAT::elem_t&)) {
  typedef typename CMAT::elem_t cplx_t; //std::complex<double>
  typedef typename RMAT::elem_t real_t; //double

  const unsigned int half = a_c.dimension();
  if(a_r.dimension()!=2*half) {
    a_r.set_zero();
    return false;
  }
  real_t part;
  for(unsigned int row=0;row<half;++row) {
    for(unsigned int col=0;col<half;++col) {
      const cplx_t& elem = a_c.value(row,col);
      //X goes to the two diagonal blocks :
      part = a_real(elem);
      a_r.set_value(row,col,part);
      a_r.set_value(row+half,col+half,part);
      //Y goes to the upper right block, -Y to the lower left one :
      part = a_imag(elem);
      a_r.set_value(row,col+half,part);
      a_r.set_value(row+half,col,-part);
    }
  }
  return true;
}

// Container version : a_vr is resized to the size of a_vc, then
// decomplex(a_vc[i],a_vr[i],a_real,a_imag) is applied to each pair, that is :
//   CMAT = X+iY
//   RMAT = |  X  Y |
//          | -Y  X |
// At the first pair for which this fails, a_vr is cleared and false is
// returned. Returns true if every pair succeeded.
template <class VEC_CMAT,class VEC_RMAT>
inline bool decomplex(
 const VEC_CMAT& a_vc,VEC_RMAT& a_vr
,typename VEC_CMAT::value_type::elem_t::value_type (*a_real)(const typename VEC_CMAT::value_type::elem_t&)
,typename VEC_CMAT::value_type::elem_t::value_type (*a_imag)(const typename VEC_CMAT::value_type::elem_t&)
) {
  typedef typename VEC_CMAT::size_type sz_t;
  const sz_t count = a_vc.size();
  a_vr.resize(count);
  for(sz_t pos=0;pos<count;++pos) {
    const bool ok = decomplex(a_vc[pos],a_vr[pos],a_real,a_imag);
    if(ok) continue;
    a_vr.clear();
    return false;
  }
  return true;
}

///////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////

//for sf, mf :
//template <class T,unsigned int D>
//inline const T* get_data(const mat<T,D>& a_v) {return a_v.data();}

// Returns the commutator [a1,a2] = a1*a2 - a2*a1, built with the binary
// operator* and operator- of MAT.
template <class MAT>
inline MAT commutator(const MAT& a1,const MAT& a2) {
  return (a1*a2)-(a2*a1);
}

// Returns the anticommutator {a1,a2} = a1*a2 + a2*a1, built with the binary
// operator* and operator+ of MAT.
template <class MAT>
inline MAT anticommutator(const MAT& a1,const MAT& a2) {
  return (a1*a2)+(a2*a1);
}

// Commutator without returned temporary : a_res = a1*a2 - a2*a1.
// Only the assignment, operator*= and operator-= of MAT are used; a_tmp is
// a caller provided work matrix that ends up holding a2*a1.
// a_res and a_tmp are written before a1 and a2 are read for the last time,
// so they should be objects distinct from a1, from a2 and from each other.
template <class MAT>
inline void commutator(const MAT& a1,const MAT& a2,MAT& a_tmp,MAT& a_res) {
  a_res = a1;
  a_res *= a2;  // a_res = a1*a2
  a_tmp = a2;
  a_tmp *= a1;  // a_tmp = a2*a1
  a_res -= a_tmp;
}

// Commutator without returned temporary : a_res = a1*a2 - a2*a1, with the
// products done by MAT::mul_mtx(). a_tmp is a caller provided work matrix
// that ends up holding a2*a1.
// a_vtmp is forwarded untouched to mul_mtx().
// NOTE(review): a_vtmp is presumably a scratch array of elements used by
// mul_mtx(); its required size is not visible here - confirm in MATCOM.
// a_res and a_tmp are written before a1 and a2 are read for the last time,
// so they should be objects distinct from a1, from a2 and from each other.
template <class MAT,class T>
inline void commutator(const MAT& a1,const MAT& a2,MAT& a_tmp,T a_vtmp[],MAT& a_res) {
  a_res = a1;
  a_res.mul_mtx(a2,a_vtmp);
  a_tmp = a2;
  a_tmp.mul_mtx(a1,a_vtmp);
  a_res -= a_tmp;
}

// Anticommutator without returned temporary : a_res = a1*a2 + a2*a1.
// Only the assignment, operator*= and operator+= of MAT are used; a_tmp is
// a caller provided work matrix that ends up holding a2*a1.
// a_res and a_tmp are written before a1 and a2 are read for the last time,
// so they should be objects distinct from a1, from a2 and from each other.
template <class MAT>
inline void anticommutator(const MAT& a1,const MAT& a2,MAT& a_tmp,MAT& a_res) {
  a_res = a1;
  a_res *= a2;  // a_res = a1*a2
  a_tmp = a2;
  a_tmp *= a1;  // a_tmp = a2*a1
  a_res += a_tmp;
}

// Anticommutator without returned temporary : a_res = a1*a2 + a2*a1, with
// the products done by MAT::mul_mtx(). a_tmp is a caller provided work
// matrix that ends up holding a2*a1.
// a_vtmp is forwarded untouched to mul_mtx().
// NOTE(review): a_vtmp is presumably a scratch array of elements used by
// mul_mtx(); its required size is not visible here - confirm in MATCOM.
// a_res and a_tmp are written before a1 and a2 are read for the last time,
// so they should be objects distinct from a1, from a2 and from each other.
template <class MAT,class T>
inline void anticommutator(const MAT& a1,const MAT& a2,MAT& a_tmp,T a_vtmp[],MAT& a_res) {
  a_res = a1;
  a_res.mul_mtx(a2,a_vtmp);
  a_tmp = a2;
  a_tmp.mul_mtx(a1,a_vtmp);
  a_res += a_tmp;
}

// Checks that a_c is the commutator a_1*a_2 - a_2*a_1 within a_prec, without
// building any temporary matrix : each element of the commutator is computed
// from the raw data() arrays (element (r,c) read at index r + c*D) and
// compared to the matching element of a_c.
// Returns false as soon as an absolute difference reaches a_prec (>=).
template <class T,unsigned int D>
inline bool commutator_equal(const mat<T,D>& a_1,const mat<T,D>& a_2,const mat<T,D>& a_c,const T& a_prec) {
  const unsigned int dim = D;
  const T* lhs = a_1.data();
  const T* rhs = a_2.data();
  const T* expected = a_c.data();
  for(unsigned int row=0;row<dim;++row) {
    for(unsigned int col=0;col<dim;++col) {
      //element (row,col) of a_1*a_2 :
      T ab = T();
      for(unsigned int k=0;k<dim;++k) {
        ab += lhs[row+k*dim] * rhs[k+col*dim];
      }
      //element (row,col) of a_2*a_1 :
      T ba = T();
      for(unsigned int k=0;k<dim;++k) {
        ba += rhs[row+k*dim] * lhs[k+col*dim];
      }
      T delta = (ab-ba) - expected[row+col*dim];
      if(delta<T()) delta *= T(-1); //absolute value.
      if(delta>=a_prec) return false;
    }
  }
  return true;
}

// Checks that a_c is the anticommutator a_1*a_2 + a_2*a_1 within a_prec,
// without building any temporary matrix : each element of the anticommutator
// is computed from the raw data() arrays (element (r,c) read at index
// r + c*D) and compared to the matching element of a_c.
// Returns false as soon as an absolute difference reaches a_prec (>=).
template <class T,unsigned int D>
inline bool anticommutator_equal(const mat<T,D>& a_1,const mat<T,D>& a_2,const mat<T,D>& a_c,const T& a_prec) {
  const unsigned int dim = D;
  const T* lhs = a_1.data();
  const T* rhs = a_2.data();
  const T* expected = a_c.data();
  for(unsigned int row=0;row<dim;++row) {
    for(unsigned int col=0;col<dim;++col) {
      //element (row,col) of a_1*a_2 :
      T ab = T();
      for(unsigned int k=0;k<dim;++k) {
        ab += lhs[row+k*dim] * rhs[k+col*dim];
      }
      //element (row,col) of a_2*a_1 :
      T ba = T();
      for(unsigned int k=0;k<dim;++k) {
        ba += rhs[row+k*dim] * lhs[k+col*dim];
      }
      T delta = (ab+ba) - expected[row+col*dim];
      if(delta<T()) delta *= T(-1); //absolute value.
      if(delta>=a_prec) return false;
    }
  }
  return true;
}

// Matrix difference a1-a2 : a copy of a1 on which operator-= is applied.
template <class T,unsigned int D>
inline mat<T,D> operator-(const mat<T,D>& a1,const mat<T,D>& a2) {
  mat<T,D> result(a1);
  result -= a2;
  return result;
}
// Matrix sum a1+a2 : a copy of a1 on which operator+= is applied.
template <class T,unsigned int D>
inline mat<T,D> operator+(const mat<T,D>& a1,const mat<T,D>& a2) {
  mat<T,D> result(a1);
  result += a2;
  return result;
}
// Matrix product a1*a2 : a copy of a1 on which operator*=(a2) is applied.
template <class T,unsigned int D>
inline mat<T,D> operator*(const mat<T,D>& a1,const mat<T,D>& a2) {
  mat<T,D> result(a1);
  result *= a2;
  return result;
}
// Scalar on the left : a_fac*a_m is a copy of a_m on which
// operator*=(a_fac) is applied.
template <class T,unsigned int D>
inline mat<T,D> operator*(const T& a_fac,const mat<T,D>& a_m) {
  mat<T,D> result(a_m);
  result *= a_fac;
  return result;
}
/*
template <class T,unsigned int D>
inline mat<T,D> operator*(const mat<T,D>& a_m,const T& a_fac) {
  mat<T,D> res(a_m);
  res *= a_fac;
  return res;
}
*/

// Matrix difference a1-a2 : a copy of a1 on which operator-= is applied.
template <class T>
inline nmat<T> operator-(const nmat<T>& a1,const nmat<T>& a2) {
  nmat<T> result(a1);
  result -= a2;
  return result;
}
// Matrix sum a1+a2 : a copy of a1 on which operator+= is applied.
template <class T>
inline nmat<T> operator+(const nmat<T>& a1,const nmat<T>& a2) {
  nmat<T> result(a1);
  result += a2;
  return result;
}
// Matrix product a1*a2 : a copy of a1 on which operator*=(a2) is applied.
template <class T>
inline nmat<T> operator*(const nmat<T>& a1,const nmat<T>& a2) {
  nmat<T> result(a1);
  result *= a2;
  return result;
}
// Scalar on the left : a_fac*a_m is a copy of a_m on which
// operator*=(a_fac) is applied.
template <class T>
inline nmat<T> operator*(const T& a_fac,const nmat<T>& a_m) {
  nmat<T> result(a_m);
  result *= a_fac;
  return result;
}
/*
template <class T>
inline nmat<T> operator*(const nmat<T>& a_m,const T& a_fac) {
  nmat<T> res(a_m);
  res *= a_fac;
  return res;
}
*/

}

namespace tools {

////////////////////////////////////////////////
/// specific D=2 ///////////////////////////////
////////////////////////////////////////////////

// Fills a 2x2 matrix from arguments given row by row.
// a_<R><C> is written at index R + C*2 of the array returned by a_m.data()
// (constness cast away); no check is done on the dimension of a_m.
template <class MAT>
inline void matrix_set(MAT& a_m
,const typename MAT::elem_t& a_00,const typename MAT::elem_t& a_01
,const typename MAT::elem_t& a_10,const typename MAT::elem_t& a_11
){
  typedef typename MAT::elem_t elem_type;
  elem_type* const cells = const_cast<elem_type*>(a_m.data());
  //1 row :
  cells[0] = a_00;
  cells[2] = a_01;
  //2 row :
  cells[1] = a_10;
  cells[3] = a_11;
}

////////////////////////////////////////////////
/// specific D=3 ///////////////////////////////
////////////////////////////////////////////////
// Fills a 3x3 matrix from arguments given row by row.
// a_<R><C> is written at index R + C*3 of the array returned by a_m.data()
// (constness cast away); no check is done on the dimension of a_m.
template <class MAT>
inline void matrix_set(MAT& a_m
,const typename MAT::elem_t& a_00,const typename MAT::elem_t& a_01,const typename MAT::elem_t& a_02 //1 row
,const typename MAT::elem_t& a_10,const typename MAT::elem_t& a_11,const typename MAT::elem_t& a_12 //2 row
,const typename MAT::elem_t& a_20,const typename MAT::elem_t& a_21,const typename MAT::elem_t& a_22 //3 row
){
  typedef typename MAT::elem_t elem_type;
  elem_type* const cells = const_cast<elem_type*>(a_m.data());
  //1 row :
  cells[0] = a_00; cells[3] = a_01; cells[6] = a_02;
  //2 row :
  cells[1] = a_10; cells[4] = a_11; cells[7] = a_12;
  //3 row :
  cells[2] = a_20; cells[5] = a_21; cells[8] = a_22;
}

////////////////////////////////////////////////
/// specific D=4 ///////////////////////////////
////////////////////////////////////////////////
// Fills a 4x4 matrix from arguments given row by row.
// a_<R><C> is written at index R + C*4 of the array returned by a_m.data()
// (constness cast away); no check is done on the dimension of a_m.
template <class MAT>
inline void matrix_set(MAT& a_m
,const typename MAT::elem_t& a_00,const typename MAT::elem_t& a_01,const typename MAT::elem_t& a_02,const typename MAT::elem_t& a_03 //1 row
,const typename MAT::elem_t& a_10,const typename MAT::elem_t& a_11,const typename MAT::elem_t& a_12,const typename MAT::elem_t& a_13 //2 row
,const typename MAT::elem_t& a_20,const typename MAT::elem_t& a_21,const typename MAT::elem_t& a_22,const typename MAT::elem_t& a_23 //3 row
,const typename MAT::elem_t& a_30,const typename MAT::elem_t& a_31,const typename MAT::elem_t& a_32,const typename MAT::elem_t& a_33 //4 row
){
  typedef typename MAT::elem_t elem_type;
  elem_type* const cells = const_cast<elem_type*>(a_m.data());
  //1 row :
  cells[0] = a_00; cells[4] = a_01; cells[ 8] = a_02; cells[12] = a_03;
  //2 row :
  cells[1] = a_10; cells[5] = a_11; cells[ 9] = a_12; cells[13] = a_13;
  //3 row :
  cells[2] = a_20; cells[6] = a_21; cells[10] = a_22; cells[14] = a_23;
  //4 row :
  cells[3] = a_30; cells[7] = a_31; cells[11] = a_32; cells[15] = a_33;
}

////////////////////////////////////////////////
/// specific D=5 ///////////////////////////////
////////////////////////////////////////////////
// Fills a 5x5 matrix from arguments given row by row.
// a_<R><C> is written at index R + C*5 of the array returned by a_m.data()
// (constness cast away); no check is done on the dimension of a_m.
template <class MAT>
inline void matrix_set(MAT& a_m
,const typename MAT::elem_t& a_00,const typename MAT::elem_t& a_01,const typename MAT::elem_t& a_02,const typename MAT::elem_t& a_03,const typename MAT::elem_t& a_04 //1 row
,const typename MAT::elem_t& a_10,const typename MAT::elem_t& a_11,const typename MAT::elem_t& a_12,const typename MAT::elem_t& a_13,const typename MAT::elem_t& a_14 //2 row
,const typename MAT::elem_t& a_20,const typename MAT::elem_t& a_21,const typename MAT::elem_t& a_22,const typename MAT::elem_t& a_23,const typename MAT::elem_t& a_24 //3 row
,const typename MAT::elem_t& a_30,const typename MAT::elem_t& a_31,const typename MAT::elem_t& a_32,const typename MAT::elem_t& a_33,const typename MAT::elem_t& a_34 //4 row
,const typename MAT::elem_t& a_40,const typename MAT::elem_t& a_41,const typename MAT::elem_t& a_42,const typename MAT::elem_t& a_43,const typename MAT::elem_t& a_44 //5 row
){
  typedef typename MAT::elem_t elem_type;
  elem_type* const cells = const_cast<elem_type*>(a_m.data());
  //1 row :
  cells[0] = a_00; cells[5] = a_01; cells[10] = a_02; cells[15] = a_03; cells[20] = a_04;
  //2 row :
  cells[1] = a_10; cells[6] = a_11; cells[11] = a_12; cells[16] = a_13; cells[21] = a_14;
  //3 row :
  cells[2] = a_20; cells[7] = a_21; cells[12] = a_22; cells[17] = a_23; cells[22] = a_24;
  //4 row :
  cells[3] = a_30; cells[8] = a_31; cells[13] = a_32; cells[18] = a_33; cells[23] = a_34;
  //5 row :
  cells[4] = a_40; cells[9] = a_41; cells[14] = a_42; cells[19] = a_43; cells[24] = a_44;
}

////////////////////////////////////////////////
/// specific D=6 ///////////////////////////////
////////////////////////////////////////////////
// Fills a 6x6 matrix from arguments given row by row.
// a_<R><C> is written at index R + C*6 of the array returned by a_m.data()
// (constness cast away); no check is done on the dimension of a_m.
template <class MAT>
inline void matrix_set(MAT& a_m
,const typename MAT::elem_t& a_00,const typename MAT::elem_t& a_01,const typename MAT::elem_t& a_02,const typename MAT::elem_t& a_03,const typename MAT::elem_t& a_04,const typename MAT::elem_t& a_05 //1 row
,const typename MAT::elem_t& a_10,const typename MAT::elem_t& a_11,const typename MAT::elem_t& a_12,const typename MAT::elem_t& a_13,const typename MAT::elem_t& a_14,const typename MAT::elem_t& a_15 //2 row
,const typename MAT::elem_t& a_20,const typename MAT::elem_t& a_21,const typename MAT::elem_t& a_22,const typename MAT::elem_t& a_23,const typename MAT::elem_t& a_24,const typename MAT::elem_t& a_25 //3 row
,const typename MAT::elem_t& a_30,const typename MAT::elem_t& a_31,const typename MAT::elem_t& a_32,const typename MAT::elem_t& a_33,const typename MAT::elem_t& a_34,const typename MAT::elem_t& a_35 //4 row
,const typename MAT::elem_t& a_40,const typename MAT::elem_t& a_41,const typename MAT::elem_t& a_42,const typename MAT::elem_t& a_43,const typename MAT::elem_t& a_44,const typename MAT::elem_t& a_45 //5 row
,const typename MAT::elem_t& a_50,const typename MAT::elem_t& a_51,const typename MAT::elem_t& a_52,const typename MAT::elem_t& a_53,const typename MAT::elem_t& a_54,const typename MAT::elem_t& a_55 //6 row
){
  typedef typename MAT::elem_t elem_type;
  elem_type* const cells = const_cast<elem_type*>(a_m.data());
  //1 row :
  cells[0] = a_00; cells[ 6] = a_01; cells[12] = a_02; cells[18] = a_03; cells[24] = a_04; cells[30] = a_05;
  //2 row :
  cells[1] = a_10; cells[ 7] = a_11; cells[13] = a_12; cells[19] = a_13; cells[25] = a_14; cells[31] = a_15;
  //3 row :
  cells[2] = a_20; cells[ 8] = a_21; cells[14] = a_22; cells[20] = a_23; cells[26] = a_24; cells[32] = a_25;
  //4 row :
  cells[3] = a_30; cells[ 9] = a_31; cells[15] = a_32; cells[21] = a_33; cells[27] = a_34; cells[33] = a_35;
  //5 row :
  cells[4] = a_40; cells[10] = a_41; cells[16] = a_42; cells[22] = a_43; cells[28] = a_44; cells[34] = a_45;
  //6 row :
  cells[5] = a_50; cells[11] = a_51; cells[17] = a_52; cells[23] = a_53; cells[29] = a_54; cells[35] = a_55;
}

}

////////////////////////////////////////////////
////////////////////////////////////////////////
////////////////////////////////////////////////
#include <ostream>

namespace tools {

//NOTE : print is a Python keyword.
// Prints a_matrix on a_out, one row per line, each element preceded by a
// space and read with a_matrix.value(row,column). When aCMT is not empty it
// is printed first on its own line. Each line is ended with std::endl.
template <class MAT>
inline void dump(std::ostream& a_out,const std::string& aCMT,const MAT& a_matrix) {
  if(!aCMT.empty()) a_out << aCMT << std::endl;
  const unsigned int dim = a_matrix.dimension();
  for(unsigned int row=0;row<dim;++row) {
    for(unsigned int col=0;col<dim;++col) {
      a_out << " " << a_matrix.value(row,col);
    }
    a_out << std::endl;
  }
}

// Checks the inversion of a_matrix : the matrix produced by
// a_matrix.invert() is combined with a_matrix through mul_mtx() and the
// result compared with equal() to the identity.
// Returns false if invert() fails (nothing is printed), or if the comparison
// fails, in which case a_matrix is dumped on a_out. Returns true otherwise.
// MAT must be default constructible.
template <class MAT>
inline bool check_invert(const MAT& a_matrix,std::ostream& a_out) {
  MAT identity;
  identity.set_identity();
  MAT product;
  if(!a_matrix.invert(product)) return false;
  product.mul_mtx(a_matrix);
  if(product.equal(identity)) return true;
  dump(a_out,"problem with inv of :",a_matrix);
  return false;
}

}

#endif
