#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;

// Physics: right-hand sides f(i, y, t) of the first-order ODE systems
/// Right-hand side of a two-equation first-order ODE system
/// (the name translates to "throw in a gravitational field"):
///   equation 0: dy[0]/dt = y[1]
///   equation 1: dy[1]/dt = 9.81 (constant)
///
/// @param i  index of the equation to evaluate; must be 0 or 1
/// @param y  current state vector; only y[1] is read
/// @param t  independent variable; not used by this system
/// @return   value of the i-th derivative
/// @throws std::runtime_error when i is neither 0 nor 1
double rzutWPoluGrawitacyjnym(int i,double* y,double t)
{
	static const double przyspieszenie = 9.81;

	if(i == 0) return y[1];
	if(i == 1) return przyspieszenie;

	// Any other index does not belong to this system.
	throw std::runtime_error("Zly numer rownania");
}

/// Right-hand side of a two-equation first-order ODE system describing
/// a harmonic oscillator with constant k = 1:
///   equation 0: dy[0]/dt = y[1]
///   equation 1: dy[1]/dt = -k * y[0]
///
/// @param i  index of the equation to evaluate; must be 0 or 1
/// @param y  current state vector; y[0] and y[1] are read
/// @param t  independent variable; not used by this system
/// @return   value of the i-th derivative
/// @throws std::runtime_error when i is neither 0 nor 1
double oscylator(int i,double* y,double t)
{
	static const double sztywnosc = 1;

	// Reject indices outside the system before touching the state.
	if(i != 0 && i != 1)
		throw std::runtime_error("Zly numer rownania");

	return (i == 0) ? y[1] : -sztywnosc*y[0];
}

/// Right-hand side of a six-equation first-order ODE system
/// (the name translates to "set of coupled oscillators"), with k = 1.
/// State layout as used by the code: even slots y[0], y[2], y[4] are
/// differentiated into the odd slots that follow them, and the odd slots
/// y[1], y[3], y[5] receive coupling terms:
///   dy[0]/dt = y[1]    dy[1]/dt = -k*(y[0]-y[2]+1)
///   dy[2]/dt = y[3]    dy[3]/dt = -k*(y[2]-y[0]-1) - k*(y[2]-y[4]+1)
///   dy[4]/dt = y[5]    dy[5]/dt = -k*(y[4]-y[2]-1)
/// NOTE(review): the constant +/-1 offsets presumably encode a rest
/// separation of 1 between neighbours; confirm against the model.
///
/// @param i  index of the equation to evaluate; must be in 0..5
/// @param y  current state vector with at least six elements
/// @param t  independent variable; not used by this system
/// @return   value of the i-th derivative
/// @throws std::runtime_error when i is outside 0..5
double zbiorOscylatorowSprzezonych(int i,double* y,double t)
{
	static const double sztywnosc = 1;

	if(i < 0 || i > 5)
		throw std::runtime_error("Zly numer rownania");

	// Even equations: derivative of the slot is the slot right after it.
	if(i % 2 == 0)
		return y[i + 1];

	// Odd equations: coupling terms, one per neighbour.
	if(i == 1)
		return -sztywnosc*(y[0]-y[2]+1);
	if(i == 3)
		return -sztywnosc*(y[2]-y[0]-1)-sztywnosc*(y[2]-y[4]+1);
	return -sztywnosc*(y[4]-y[2]-1);
}

// Numerics: single-step integrators for dy/dt = f(i, y, t)
/// Performs one explicit (forward) Euler step of size h:
///   y_nast[i] = y[i] + h * f(i, y, t)   for i = 0 .. N-1
/// Original author's note: always divergent.
///
/// @param N       number of equations; nothing is written when N <= 0
/// @param f       right-hand side, called once per equation index
/// @param y       current state, N elements, only read here
/// @param t       current value of the independent variable, passed to f
/// @param h       step size
/// @param y_nast  output buffer for the next state, N elements
/// @return        y_nast
double* odeint_Euler(
	int N,double (*f)(int i,double* y,double t), 
	double* y, double t, double h,double* y_nast)
{
	int nr = 0;
	while(nr < N)
	{
		y_nast[nr] = y[nr] + h * f(nr, y, t);
		++nr;
	}
	return y_nast;
}

/// Performs one step of size h with the explicit midpoint scheme:
///   y_mid[i]  = y[i] + 0.5 * h * f(i, y, t)
///   y_nast[i] = y[i] + h * f(i, y_mid, t + 0.5*h)
///
/// The intermediate state lives in a std::vector, so it is released even
/// when f throws (the right-hand sides in this file throw on a bad index).
///
/// @param N       number of equations; when N <= 0 nothing is computed and
///                y_nast is returned untouched, matching odeint_Euler
/// @param f       right-hand side, called twice per equation index
/// @param y       current state, N elements, only read here
/// @param t       current value of the independent variable
/// @param h       step size
/// @param y_nast  output buffer for the next state, N elements
/// @return        y_nast
/// @throws        whatever f throws; y_nast may be partially written then
double* odeint_MidPoint(
	int N,double (*f)(int i,double* y,double t), 
	double* y, double t, double h,double* y_nast)
{
	// A non-positive size means there is nothing to integrate.
	if(N <= 0) return y_nast;

	std::vector<double> y_mid(static_cast<std::vector<double>::size_type>(N));

	// Half step: estimate the state at t + h/2 from the slope at t.
	for(int i = 0; i < N; ++i)
	{
		const double k1 = h*f(i,y,t);
		y_mid[i] = y[i]+0.5*k1;
	}

	// Full step: advance from y using the slope evaluated at the midpoint.
	for(int i = 0; i < N; ++i)
	{
		const double k2 = h*f(i,y_mid.data(), t+0.5*h);
		y_nast[i] = y[i]+k2;
	}
	return y_nast;
}


/// Integrates the six-equation coupled-oscillator system with the midpoint
/// scheme and prints one tab-separated row per step to standard output:
/// the time t followed by the N state values held before that step.
/// Prints "OK." when the loop finishes.
///
/// @return EXIT_SUCCESS
int main()
{
	const int N = 6;

	// Two buffers that trade roles after every step; std::vector releases
	// them automatically, also when an integrator call throws.
	std::vector<double> y(N, 0.0);
	std::vector<double> y_nast(N, 0.0);

	// Initial conditions: everything zero except y[1].
	y[1] = 0.1;
	y[5] = 0;

	const double tmax = 10;
	const double h = 0.001;

	// Count steps with an integer and derive t from it. Accumulating
	// t += h in a double lets rounding error build up, which can shift
	// the number of iterations by one.
	for(long krok = 0; krok * h < tmax; ++krok)
	{
		const double t = krok * h;

		// Alternative setups. The two-equation systems throw for
		// indices >= 2, so they need N = 2.
		//odeint_Euler(N,rzutWPoluGrawitacyjnym,y.data(),t,h,y_nast.data());
		//odeint_MidPoint(N,rzutWPoluGrawitacyjnym,y.data(),t,h,y_nast.data());
		//odeint_MidPoint(N,oscylator,y.data(),t,h,y_nast.data());
		odeint_MidPoint(N,zbiorOscylatorowSprzezonych,y.data(),t,h,y_nast.data());

		cout << t;
		for(int i = 0; i<N; ++i) cout << "\t" << y[i];
		cout << "\n";

		// The freshly computed state becomes the current one.
		y.swap(y_nast);
	}

	cout << "\n\nOK.\n\n";

	return EXIT_SUCCESS;
}