#ifndef MATRIX_H__
#define MATRIX_H__

#include <cmath>
#include <cstddef>

// Pattern selector understood by Matrix::Fill.
enum fill
{
	none,   // 0 - leave the buffer untouched
	zeros,  // 1 - every element becomes zero
	ones,   // 2 - every element becomes one
	eye     // 3 - one on the main diagonal, zero everywhere else
};

// Axis selector; Matrix4::FillRotate switches on it to pick the rotation axis.
enum dim
{
    x,  // 0
    y,  // 1
    z   // 2
};

// Number of elements in a DIM x DIM matrix. Expands correctly only where a
// template parameter named DIM is in scope. It is not #undef'd, so it stays
// visible to the code that follows the class.
#define _MATRIX_SIZE (DIM*DIM)

// Collection of static helpers that operate on square DIM x DIM matrices
// stored as flat buffers of DIM*DIM elements of type T.
template<typename T, size_t DIM>
class Matrix
{
	// Private and never defined in this header: the class is a pure holder of
	// static functions and is not meant to be instantiated.
	Matrix();

protected:
	// Element constants, defined out of class as T converted from 0 and 1.
	static const T zero;
	static const T one;

	// Pointer flavours taken by the helpers; the pointer itself is const in both.
	using constMatPtr = T* const;        // elements may be written
	using constMat    = const T* const;  // elements are read-only

public:
	// Writes the pattern chosen by `type` into `matrix`; fill::none (the
	// default) leaves the buffer unchanged.
	static void Fill(constMatPtr& matrix, const fill& type = fill::none);

	// Stores the product of `lhsMatrix` and `rhsMatrix` in `outMatrix`.
	static void Multiply(constMat& lhsMatrix, constMat& rhsMatrix, constMatPtr& outMatrix);

	// Stores the transpose of `inMatrix` in `outMatrix`.
	static void Transpose(constMat& inMatrix, constMatPtr& outMatrix);
};

// Out-of-class definitions of Matrix's static element constants: the integer
// literals 0 and 1 converted to the element type T.
template<typename T, size_t DIM>
const T Matrix<T, DIM>::zero(static_cast<T>(0));

template<typename T, size_t DIM>
const T Matrix<T, DIM>::one(static_cast<T>(1));

// Static helpers that fill 4x4 matrices (flat buffers of 16 elements of T)
// with common transformation and projection matrices.
template<typename T>
class Matrix4 : public Matrix<T, 4>
{
protected:
	// Matrix<T, 4> is a dependent base class, so its members are not found by
	// unqualified name lookup inside this template on standard-conforming
	// compilers. Re-declare the names this class and its member definitions use.
	typedef typename Matrix<T, 4>::constMatPtr constMatPtr;
	typedef typename Matrix<T, 4>::constMat constMat;
	using Matrix<T, 4>::one;

public:
	using Matrix<T, 4>::Fill;

	// Identity matrix with x, y, z stored at flat indices 3, 7 and 11.
	static void FillTransform(constMatPtr& matrix, const T& x, const T& y, const T& z);

	// Diagonal matrix (scaleX, scaleY, scaleZ, 1).
	static void FillScale(constMatPtr& matrix, const T& scaleX, const T& scaleY, const T& scaleZ);

	// Rotation by `angle` about the axis selected by `dimension`.
	static void FillRotate(constMatPtr& matrix, const T& angle, const dim& dimension);

	// Perspective projection built from the six clipping-plane values.
	static void FillFrustum(constMatPtr& matrix, const T& left, const T& right, const T& bottom, const T& top, const T& nearVal, const T& farVal);

	// Orthographic projection built from the six clipping-plane values.
	static void FillOrtho(constMatPtr& matrix, const T& left, const T& right, const T& bottom, const T& top, const T& nearVal, const T& farVal);
};

// Shorthand names for the single- and double-precision 4x4 helper classes.
using mat4  = Matrix4<float>;
using mat4d = Matrix4<double>;

// Writes the pattern selected by `type` into the DIM*DIM elements starting at
// `matrix`:
//   fill::zeros - every element is set to `zero`
//   fill::ones  - every element is set to `one`
//   fill::eye   - `one` on the main diagonal, `zero` elsewhere
//   fill::none  - (and any unrecognised value) the buffer is left unchanged
//
// The default argument (`fill::none`) is supplied by the in-class declaration
// only; repeating it on this definition is a redefinition that
// standard-conforming compilers reject.
template<typename T, size_t DIM>
void Matrix<T, DIM>::Fill(constMatPtr& matrix, const fill& type)
{
	const size_t count = DIM * DIM;
	switch (type)
	{
	case fill::zeros:
		for (size_t i = 0; i < count; ++i)
			matrix[i] = zero;
		break;
	case fill::ones:
		for (size_t i = 0; i < count; ++i)
			matrix[i] = one;
		break;
	case fill::eye:
		for (size_t i = 0; i < count; ++i)
			matrix[i] = zero;
		// Diagonal entry d sits at d*DIM + d whichever way the buffer is read.
		for (size_t d = 0; d < DIM; ++d)
			matrix[d * DIM + d] = one;
		break;
	case fill::none:
	default:
		break;
	}
}

// Computes, for every r and c in [0, DIM):
//   outMatrix[r + c*DIM] = sum over i of lhsMatrix[r + i*DIM] * rhsMatrix[i + c*DIM]
// Reading index r + c*DIM as "row r, column c", this is lhs x rhs.
// NOTE(review): the Matrix4 fill helpers appear to lay matrices out row by row
// (index = row*4 + column); under that reading this routine yields rhs x lhs.
// Confirm which convention callers expect.
//
// `outMatrix` may be the very same buffer as `lhsMatrix` and/or `rhsMatrix`;
// the product is then built in a temporary and copied back. Buffers that only
// partially overlap are not supported.
template<typename T, size_t DIM>
void Matrix<T, DIM>::Multiply(constMat& lhsMatrix, constMat& rhsMatrix, constMatPtr& outMatrix)
{
	if (outMatrix == lhsMatrix || outMatrix == rhsMatrix)
	{
		// Zero-filling the output first would wipe an operand, so compute into
		// scratch storage. The ?: keeps the array bound non-zero for DIM == 0.
		T scratch[(DIM * DIM) != 0 ? (DIM * DIM) : 1];
		Multiply(lhsMatrix, rhsMatrix, scratch);
		for (size_t i = 0; i < DIM * DIM; ++i)
			outMatrix[i] = scratch[i];
		return;
	}

	Fill(outMatrix, fill::zeros);
	for (size_t c = 0; c < DIM; ++c)
		for (size_t r = 0; r < DIM; ++r)
			for (size_t i = 0; i < DIM; ++i)
				outMatrix[r + c*DIM] += lhsMatrix[r + i*DIM] * rhsMatrix[i + c*DIM];
}

// Writes the transpose of `inMatrix` to `outMatrix`:
//   outMatrix[w + k*DIM] = inMatrix[k + w*DIM]   for all w, k in [0, DIM).
//
// `outMatrix` may be the same buffer as `inMatrix`, in which case the matrix is
// transposed in place. Buffers that only partially overlap are not supported.
template<typename T, size_t DIM>
void Matrix<T, DIM>::Transpose(constMat& inMatrix, constMatPtr& outMatrix)
{
	if (inMatrix == outMatrix)
	{
		// A plain copy would overwrite elements before they are read; swap each
		// off-diagonal pair once instead (the diagonal stays where it is).
		for (size_t w = 0; w < DIM; ++w)
			for (size_t k = w + 1; k < DIM; ++k)
			{
				const T held = outMatrix[w + k*DIM];
				outMatrix[w + k*DIM] = outMatrix[k + w*DIM];
				outMatrix[k + w*DIM] = held;
			}
		return;
	}

	for (size_t w = 0; w < DIM; ++w)
		for (size_t k = 0; k < DIM; ++k)
			outMatrix[w + k*DIM] = inMatrix[k + w*DIM];
}

// Overwrites the 16-element buffer `matrix` with an identity matrix and then
// stores x, y and z at flat indices 3, 7 and 11, i.e. in the last column when
// the buffer is read as four consecutive rows (a translation matrix under that
// reading).
template<typename T>
void Matrix4<T>::FillTransform(constMatPtr& matrix, const T& x, const T& y, const T& z)
{
	// The offsets arrive by reference and could refer to elements of `matrix`
	// itself, so copy them before the buffer is overwritten.
	const T offsetX = x;
	const T offsetY = y;
	const T offsetZ = z;

	// Fill belongs to the dependent base class; it must be qualified to be
	// found on standard-conforming compilers.
	Matrix<T, 4>::Fill(matrix, fill::eye);
	matrix[3]  = offsetX;
	matrix[7]  = offsetY;
	matrix[11] = offsetZ;
}

// Overwrites the 16-element buffer `matrix` with a diagonal matrix whose
// diagonal is (scaleX, scaleY, scaleZ, one); all other elements are zero.
template<typename T>
void Matrix4<T>::FillScale(constMatPtr& matrix, const T& scaleX, const T& scaleY, const T& scaleZ)
{
	// The factors arrive by reference and could refer to elements of `matrix`
	// itself, so copy them before the buffer is zeroed.
	const T factorX = scaleX;
	const T factorY = scaleY;
	const T factorZ = scaleZ;

	// Fill and one belong to the dependent base class; they must be qualified
	// to be found on standard-conforming compilers.
	Matrix<T, 4>::Fill(matrix, fill::zeros);
	matrix[0]  = factorX;
	matrix[5]  = factorY;
	matrix[10] = factorZ;
	matrix[15] = Matrix<T, 4>::one;
}

// Overwrites the 16-element buffer `matrix` with an identity matrix and then
// writes a 2x2 rotation block for the axis selected by `dimension`.
// `angle` is passed straight to std::sin / std::cos, so it is in radians.
// A `dimension` value other than x, y or z leaves the identity matrix in place.
//
// Elements written (s = sin(angle), c = cos(angle)):
//   x: [5] = c, [6] = -s, [9] = s,  [10] = c
//   y: [0] = c, [2] = s,  [8] = -s, [10] = c
//   z: [0] = c, [1] = -s, [4] = s,  [5] = c
template<typename T>
void Matrix4<T>::FillRotate(constMatPtr& matrix, const T& angle, const dim& dimension)
{
	// Evaluated before the buffer is touched, so `angle` may even refer to an
	// element of `matrix`. Named so as not to shadow the math functions.
	const T sinA = static_cast<T>(std::sin(angle));
	const T cosA = static_cast<T>(std::cos(angle));

	// Fill belongs to the dependent base class; it must be qualified to be
	// found on standard-conforming compilers.
	Matrix<T, 4>::Fill(matrix, fill::eye);

	switch (dimension)
	{
	case dim::x:
		matrix[5]  =  cosA;
		matrix[6]  = -sinA;
		matrix[9]  =  sinA;
		matrix[10] =  cosA;
		break;
	case dim::y:
		matrix[0]  =  cosA;
		matrix[2]  =  sinA;
		matrix[8]  = -sinA;
		matrix[10] =  cosA;
		break;
	case dim::z:
		matrix[0] =  cosA;
		matrix[1] = -sinA;
		matrix[4] =  sinA;
		matrix[5] =  cosA;
		break;
	default:
		break;
	}
}

// Overwrites the 16-element buffer `matrix` with a perspective projection
// built from the six clipping-plane values. The non-zero entries are
//   [0]  = 2*near / (right - left)      [2]  = (right + left) / (right - left)
//   [5]  = 2*near / (top - bottom)      [6]  = (top + bottom) / (top - bottom)
//   [10] = (far + near) / (near - far)  [11] = 2*near*far / (near - far)
//   [14] = -1
// which is the glFrustum matrix written row by row.
// No validation is performed: right == left, top == bottom or
// nearVal == farVal results in a division by zero.
template<typename T>
void Matrix4<T>::FillFrustum(constMatPtr& matrix, const T& left, const T& right, const T& bottom, const T& top, const T& nearVal, const T& farVal)
{
	// Every input arrives by reference and could refer to an element of
	// `matrix`, so all entries are computed before the buffer is zeroed.
	const T width    = right - left;
	const T height   = top - bottom;
	const T depth    = nearVal - farVal;
	const T twoNear  = 2 * nearVal;

	const T scaleX   = twoNear / width;
	const T offsetX  = (right+left) / width;
	const T scaleY   = twoNear / height;
	const T offsetY  = (top+bottom) / height;
	const T scaleZ   = (farVal+nearVal) / depth;
	const T offsetZ  = twoNear*farVal / depth;

	// Fill belongs to the dependent base class; it must be qualified to be
	// found on standard-conforming compilers.
	Matrix<T, 4>::Fill(matrix, fill::zeros);
	matrix[0]  = scaleX;
	matrix[2]  = offsetX;
	matrix[5]  = scaleY;
	matrix[6]  = offsetY;
	matrix[10] = scaleZ;
	matrix[11] = offsetZ;
	matrix[14] = -1;
}

// Overwrites the 16-element buffer `matrix` with an orthographic projection
// built from the six clipping-plane values. The non-zero entries are
//   [0]  = 2 / (right - left)     [3]  = -(right + left) / (right - left)
//   [5]  = 2 / (top - bottom)     [7]  = -(top + bottom) / (top - bottom)
//   [10] = -2 / (far - near)      [11] = -(far + near) / (far - near)
//   [15] = 1
// which is the glOrtho matrix written row by row.
// No validation is performed: right == left, top == bottom or
// nearVal == farVal results in a division by zero.
template<typename T>
void Matrix4<T>::FillOrtho(constMatPtr& matrix, const T& left, const T& right, const T& bottom, const T& top, const T& nearVal, const T& farVal)
{
	// Every input arrives by reference and could refer to an element of
	// `matrix`, so all entries are computed before the buffer is zeroed.
	const T width   = right - left;
	const T height  = top - bottom;
	const T depth   = farVal - nearVal;

	const T scaleX  = 2 / width;
	const T offsetX = -(right+left) / width;
	const T scaleY  = 2 / height;
	const T offsetY = -(top+bottom) / height;
	const T scaleZ  = -2 / depth;
	const T offsetZ = -(farVal+nearVal) / depth;

	// Fill belongs to the dependent base class; it must be qualified to be
	// found on standard-conforming compilers.
	Matrix<T, 4>::Fill(matrix, fill::zeros);
	matrix[0]  = scaleX;
	matrix[3]  = offsetX;
	matrix[5]  = scaleY;
	matrix[7]  = offsetY;
	matrix[10] = scaleZ;
	matrix[11] = offsetZ;
	matrix[15] = 1;
}

#endif /* MATRIX_H__ */