#include <stdexcept>
#include <cmath>
#include "vector_sse.hpp"

namespace rk_sse {
	/**
	 * @brief Abstract state of an autonomous ODE system y' = f(y), advanced in
	 *        place by explicit Runge-Kutta schemes.
	 *
	 * Subclasses implement step(), which every integrator below evaluates as
	 * the right-hand side f(y). The state holds rank * dim doubles stored in a
	 * Vector_SSE padded to an even length.
	 *
	 * @tparam rank NOTE(review): presumably the number of state quantities per
	 *              dimension (e.g. position/velocity) — confirm with callers.
	 * @tparam dim  NOTE(review): presumably the spatial dimension — confirm.
	 */
	template<uint rank, uint dim>
	class StateVector {
		protected:
			// Storage length: rank * dim rounded up to an even number
			// (presumably so Vector_SSE can process doubles in pairs — confirm).
			static constexpr uint vsize = (rank * dim) + ((rank * dim) % 2);

			// Current step size; RK4F5SolverUpdate() adapts it in place.
			double h;
			Vector_SSE<vsize> stateVector;

			/// Right-hand side f(y) of the ODE, evaluated at the given state.
			virtual Vector_SSE<vsize> step(Vector_SSE<vsize>  vector) = 0;

		public:
			/// Tolerance on the mean absolute difference between the embedded
			/// 4th- and 5th-order solutions accepted by RK4F5SolverUpdate().
			static double maxError;

			/// @param h initial step size.
			StateVector(double h) : h(h) {}

			virtual ~StateVector() = default;

			/// Copies only the state values; this object's step size h is kept.
			StateVector<rank, dim> & operator=(const StateVector<rank, dim> & sv) {
				stateVector = sv.stateVector;
				return *this;
			}

			/// Mutable access to one component of the state.
			double & operator[](int index) {
				return stateVector[index];
			}

			/// Current step size (after RK4F5SolverUpdate(): the step proposed
			/// for the next call, not the one just taken).
			double getStep() const {
				return h;
			}

			/// One explicit (forward) Euler step of size h.
			void RK1SolverUpdate() {
				stateVector = stateVector + h * step(stateVector);
			}

			/// One explicit midpoint (2nd-order Runge-Kutta) step of size h.
			void RK2SolverUpdate() {
				Vector_SSE<vsize> k1;
				Vector_SSE<vsize> k2;

				k1 = h * step(stateVector);
				k2 = h * step(stateVector + 0.5 * k1);

				stateVector += k2;
			}

			/// One classical 4th-order Runge-Kutta step of size h.
			void RK4SolverUpdate() {
				Vector_SSE<vsize> k1;
				Vector_SSE<vsize> k2;
				Vector_SSE<vsize> k3;
				Vector_SSE<vsize> k4;

				k1 = h * step(stateVector);
				k2 = h * step(stateVector + 0.5 * k1);
				k3 = h * step(stateVector + 0.5 * k2);
				k4 = h * step(stateVector + k3);

				stateVector += (k1 + 2. * k2 + 2. * k3 + k4) / 6.;
			}

			/**
			 * @brief One adaptive Runge-Kutta-Fehlberg 4(5) step.
			 *
			 * Each attempt computes the embedded 4th- and 5th-order solutions and
			 * estimates the error as their mean absolute difference over the
			 * rank * dim meaningful components (the padding slot is ignored).
			 * h is then rescaled by (maxError / error)^(1/5), clamped to
			 * [minScale, maxScale]. Attempts repeat with the reduced h while the
			 * error exceeds maxError. The accepted 5th-order solution replaces
			 * the state.
			 *
			 * @note On return, h (see getStep()) is the step proposed for the
			 *       next call, not the step that was just taken.
			 */
			void RK4F5SolverUpdate() {
				// Fehlberg tableau. The system is autonomous (step() takes no
				// time argument), so the node coefficients c_i are not needed.
				static constexpr double b21 = 0.25;
				static constexpr double b31 = 3.0 / 32.0;
				static constexpr double b32 = 9.0 / 32.0;
				static constexpr double b41 = 1932.0 / 2197.0;
				static constexpr double b42 = -7200.0 / 2197.0;
				static constexpr double b43 = 7296.0 / 2197.0;
				static constexpr double b51 = 439.0 / 216.0;
				static constexpr double b52 = -8.0;
				static constexpr double b53 = 3680.0 / 513.0;
				static constexpr double b54 = -845.0 / 4104.0;
				static constexpr double b61 = -8.0 / 27.0;
				static constexpr double b62 = 2.0;
				static constexpr double b63 = -3544.0 / 2565.0;
				static constexpr double b64 = 1859.0 / 4104.0;
				static constexpr double b65 = -11.0 / 40.0;
				// 4th-order weights (w42 = 0).
				static constexpr double w41 = 25.0 / 216.0;
				static constexpr double w43 = 1408.0 / 2565.0;
				static constexpr double w44 = 2197.0 / 4104.0;
				static constexpr double w45 = -0.2;
				// 5th-order weights (w52 = 0).
				static constexpr double w51 = 16.0 / 135.0;
				static constexpr double w53 = 6656.0 / 12825.0;
				static constexpr double w54 = 28561.0 / 56430.0;
				static constexpr double w55 = -9.0 / 50.0;
				static constexpr double w56 = 2.0 / 55.0;

				// Number of meaningful components (excludes the padding slot).
				static constexpr uint ncomp = rank * dim;
				// Bounds on the per-attempt step rescale. The upper bound keeps
				// h finite when the error estimate is exactly zero.
				static constexpr double minScale = 0.1;
				static constexpr double maxScale = 5.0;

				Vector_SSE<vsize> k1;
				Vector_SSE<vsize> k2;
				Vector_SSE<vsize> k3;
				Vector_SSE<vsize> k4;
				Vector_SSE<vsize> k5;
				Vector_SSE<vsize> k6;
				Vector_SSE<vsize> next1, next2;
				double error;

				do {
					k1 = h * step(stateVector);
					k2 = h * step(stateVector + b21 * k1);
					k3 = h * step(stateVector + b31 * k1 + b32 * k2);
					k4 = h * step(stateVector + b41 * k1 + b42 * k2 + b43 * k3);
					k5 = h * step(stateVector + b51 * k1 + b52 * k2 + b53 * k3 + b54 * k4);
					k6 = h * step(stateVector + b61 * k1 + b62 * k2 + b63 * k3 + b64 * k4 + b65 * k5);

					next1 = stateVector + w41 * k1 + w43 * k3 + w44 * k4 + w45 * k5;
					next2 = stateVector + w51 * k1 + w53 * k3 + w54 * k4 + w55 * k5 + w56 * k6;

					// Fresh estimate for every attempt: the mean absolute difference.
					error = 0.0;
					for(uint i = 0; i < ncomp; i++) {
						error += std::fabs(next1[i] - next2[i]);
					}
					error /= ncomp;

					// Standard 5th-root controller driven by the public tolerance.
					const double scale = std::pow(maxError / error, 0.2);
					h *= std::fmin(maxScale, std::fmax(minScale, scale));
				} while(error > maxError);

				stateVector = next2;
			}
	};

	// Out-of-class definition of the static tolerance that RK4F5SolverUpdate()
	// compares its error estimate against; one value is shared by every object
	// of a given StateVector<rank, dim> instantiation.
	template<uint rank, uint dim>
	double StateVector<rank, dim>::maxError{1e-5};
	
}
