#include <cmath>     // sqrt, fabs
#include <fstream>   // output file
#include <iostream>  // cout, cerr
#include <stdexcept> // exceptions
#include <vector>    // RAII buffers for solver stages and state vectors
using namespace std;

// Length of the state vector y[] used by main() and oscylacje_sprzezone().
constexpr int N = 8;

// Acceptance window for the error estimate in odeint_RK4F5:
// below blad_min the step h is doubled, above blad_max it is halved.
constexpr double blad_min = 1E-2;
constexpr double blad_max = 1E1;

// Reference units (SI).
constexpr double jednostkaOdleglosci = 384403000; // m (initial Earth-Moon separation in main)
constexpr double jednostkaMasy = 5.9736E24;       // kg (Earth mass in grawitacja)
constexpr double jednostkaCzasu = 24 * 60 * 60;   // one Earth day in seconds

// Physical problems: each function returns d y[i] / dt, the right-hand side f(i, y, t) passed to the solvers.
/// Right-hand side for a 1D throw under constant acceleration (N = 2).
/// State layout: y[0] = position, y[1] = velocity.
/// @param i  index of the equation to evaluate (0 or 1)
/// @param y  current state vector
/// @param t  time (unused; the system is autonomous)
/// @return   d y[i] / dt
/// @throws std::runtime_error for any i other than 0 or 1
/// NOTE(review): the acceleration is +g, which implies the position axis
/// points downward — confirm this is the intended sign convention.
template<typename T>
T rzut(int i, T* y, T t) //1D,n=2 -> N=2
{
	static const double g = 9.81; // m/s^2 (SI)

	if (i < 0 || i > 1)
		throw std::runtime_error("Zy numer rwnania");

	// Evaluated in double, then converted to T on return.
	const double pochodna = (i == 0) ? static_cast<double>(y[1]) : g;
	return pochodna;
}

/// Right-hand side for a 1D damped harmonic oscillator (N = 2):
///   x'' = -w^2 x - 2 b x'
/// State layout: y[0] = position x, y[1] = velocity x'.
/// @param i  index of the equation to evaluate (0 or 1)
/// @param y  current state vector
/// @param t  time (unused; the system is autonomous)
/// @return   d y[i] / dt
/// @throws std::runtime_error for any i other than 0 or 1
template<typename T>
T oscylacje(int i, T* y, T t) //1D,n=2 -> N=2
{
	static const double w = 1; // angular frequency
	static const double b = 0; // damping coefficient (0 = undamped)

	double pochodna;
	if (i == 0)
	{
		pochodna = y[1];                   // dx/dt = v
	}
	else if (i == 1)
	{
		pochodna = -w*w*y[0] - 2 * b*y[1]; // dv/dt = -w^2 x - 2 b v
	}
	else
	{
		throw std::runtime_error("Zy numer rwnania");
	}
	return pochodna;
}

/// Right-hand side for a 1D chain of N/2 bodies coupled to their nearest
/// neighbours by identical springs (k/m = w^2) of rest length 1, with free
/// ends and optional damping b. Uses the global N as the state length.
/// State layout: y[2j] = position of body j, y[2j + 1] = its velocity.
/// @param i  index of the equation to evaluate (0 .. N-1)
/// @param y  current state vector
/// @param t  time (unused; the system is autonomous)
/// @return   d y[i] / dt
template<typename T>
T oscylacje_sprzezone(int i, T* y, T t) //n w 1D -> N = 2n
{
	static const double w = 1;
	static const double b = 0;

	// Even index: position, whose derivative is the velocity stored right after it.
	if (i % 2 == 0)
	{
		const double predkosc = y[i + 1];
		return predkosc;
	}

	// Odd index: velocity of the body whose position is y[i - 1].
	const bool maLewegoSasiada = i > 1;      // previous body's position is y[i - 3]
	const bool maPrawegoSasiada = i < N - 1; // next body's position is y[i + 1]

	double przyspieszenie = -2 * b*y[i];
	if (maLewegoSasiada)
		przyspieszenie += -w*w*(y[i - 1] - y[i - 3] - 1); // left spring pushes/pulls back
	if (maPrawegoSasiada)
		przyspieszenie += +w*w*(y[i + 1] - y[i - 1] - 1); // right spring pushes/pulls forward
	return przyspieszenie;
}

// Square of x. This is a macro: the argument is textually substituted and
// evaluated twice, so do not pass expressions with side effects (e.g. SQR(i++)).
#define SQR(x) ((x)*(x))

/// Right-hand side for the planar two-body problem (Earth + Moon), N = 8.
/// State layout (SI units):
///   Earth: y[0] = x, y[1] = vx, y[2] = y, y[3] = vy
///   Moon:  y[4] = x, y[5] = vx, y[6] = y, y[7] = vy
/// @param i  index of the equation to evaluate (0..7)
/// @param y  current state vector
/// @param t  time (unused; the system is autonomous)
/// @return   d y[i] / dt
/// @throws std::runtime_error for i outside 0..7
template<typename T>
T grawitacja(int i, T* y, T t) //zagadnienie 2 cial w 2D
{
	static const double G = 6.6742867E-11; //m^3/kg/s^2
	static const double M = jednostkaMasy;        // Earth mass
	static const double m = 0.0123*jednostkaMasy; // Moon mass (0.0123 Earth masses)

	// Position equations (dx/dt = v) need no force: return early instead of
	// paying for the sqrt and divisions below on half of all calls.
	switch (i)
	{
	case 0: return y[1];
	case 2: return y[3];
	case 4: return y[5];
	case 6: return y[7];
	case 1: case 3: case 5: case 7:
		break;
	default:
		throw std::runtime_error("Zy numer rwnania");
	}

	// Vector r12 from Earth to Moon and its length.
	const double odleglosc_x = y[4] - y[0]; //(r12)_x
	const double odleglosc_y = y[6] - y[2]; //(r12)_y
	const double kwadrat_odleglosci = SQR(odleglosc_x) + SQR(odleglosc_y); //r12^2
	const double odleglosc = std::sqrt(kwadrat_odleglosci); //r12
	const double sila = G*M*m / kwadrat_odleglosci; //|F|

	// Unit vector along r12.
	const double _x = odleglosc_x / odleglosc;
	const double _y = odleglosc_y / odleglosc;

	// Equal and opposite forces: Earth is pulled along +r12, the Moon along -r12.
	switch (i)
	{
	case 1: return sila*_x / M;
	case 3: return sila*_y / M;
	case 5: return -sila*_x / m;
	default: return -sila*_y / m; // i == 7
	}
}


/// Explicit (forward) Euler step: y_nast[i] = y[i] + h * f(i, y, t).
/// @param N       number of equations (length of y and y_nast)
/// @param f       right-hand side; f(i, y, t) returns d y[i] / dt
/// @param y       state at time t (read only)
/// @param t       current time
/// @param h       step size
/// @param y_nast  output buffer receiving the state at time t + h
/// @return        y_nast
template<typename T>
T* odeint_Euler(int N, T(*f)(int, T*, T), T* y, T t, T h, T* y_nast)
{
	int i = 0;
	while (i < N)
	{
		const T pochodna = f(i, y, t);
		y_nast[i] = y[i] + h * pochodna;
		++i;
	}
	return y_nast;
}

/// Explicit midpoint (second-order Runge-Kutta) step:
///   k1 = h f(y, t)
///   k2 = h f(y + k1/2, t + h/2)
///   y_nast = y + k2
/// @param N       number of equations (length of y and y_nast); N <= 0 is a no-op
/// @param f       right-hand side; f(i, y, t) returns d y[i] / dt
/// @param y       state at time t
/// @param t       current time
/// @param h       step size
/// @param y_nast  output buffer receiving the state at t + h (may alias y)
/// @return        y_nast
/// The midpoint state is held in a std::vector, so it is released on every
/// path, including when f throws. The previous new[] buffer was never deleted.
template<typename T>
T* odeint_MidPoint(int N, T(*f)(int, T*, T), T* y, T t, T h, T* y_nast)
{
	if (N <= 0) return y_nast;

	std::vector<T> y_tmp(static_cast<std::size_t>(N));
	for (int i = 0; i < N; ++i)
	{
		const T k1 = h*f(i, y, t);
		y_tmp[i] = y[i] + 0.5*k1;
	}
	for (int i = 0; i < N; ++i)
	{
		const T k2 = h * f(i, y_tmp.data(), t + 0.5*h);
		y_nast[i] = y[i] + k2;
	}
	return y_nast;
}

/// Classical fourth-order Runge-Kutta step:
///   k1 = h f(y, t)
///   k2 = h f(y + k1/2, t + h/2)
///   k3 = h f(y + k2/2, t + h/2)
///   k4 = h f(y + k3,   t + h)
///   y_nast = y + (k1 + 2 k2 + 2 k3 + k4) / 6
/// @param N       number of equations (length of y and y_nast); N <= 0 is a no-op
/// @param f       right-hand side; f(i, y, t) returns d y[i] / dt
/// @param y       state at time t
/// @param t       current time
/// @param h       step size
/// @param y_nast  output buffer receiving the state at t + h (may alias y)
/// @return        y_nast
/// Stage buffers are std::vectors, so nothing leaks if f throws. y_nast is
/// written only in the final loop, so an in-place step (y_nast == y) is safe.
template<typename T>
T* odeint_RK4(int N, T(*f)(int, T*, T), T* y, T t, T h, T* y_nast)
{
	if (N <= 0) return y_nast;

	const std::size_t n = static_cast<std::size_t>(N);
	std::vector<T> y_tmp(n);  // state for stages 2 and 4
	std::vector<T> y_tmp2(n); // state for stage 3 (previously y_nast was reused)
	std::vector<T> k1(n);
	std::vector<T> k2(n);
	std::vector<T> k3(n);
	std::vector<T> k4(n);

	for (int i = 0; i < N; ++i)
	{
		k1[i] = h*f(i, y, t);
		y_tmp[i] = y[i] + 0.5*k1[i];
	}
	for (int i = 0; i < N; ++i)
	{
		k2[i] = h*f(i, y_tmp.data(), t + 0.5*h);
		y_tmp2[i] = y[i] + 0.5*k2[i];
	}
	for (int i = 0; i < N; ++i)
	{
		k3[i] = h*f(i, y_tmp2.data(), t + 0.5*h);
		y_tmp[i] = y[i] + k3[i];
	}
	for (int i = 0; i < N; ++i)
	{
		// f reads only y_tmp here, so writing y_nast[i] (even if it is y) is safe.
		k4[i] = h* f(i, y_tmp.data(), t + h);
		y_nast[i] = y[i] + (k1[i] + 2.0*k2[i] + 2.0*k3[i] + k4[i]) / 6.0;
	}

	return y_nast;
}

// Runs one trial Runge-Kutta-Fehlberg 4(5) step of size h from (t, y).
// Writes the fifth-order solution into y5 (resized to N) and returns the mean
// absolute difference between the fourth- and fifth-order solutions, which is
// used as the local error estimate. Assumes N > 0.
template<typename T>
static T odeint_RK4F5_proba(int N, T(*f)(int, T*, T), T* y, T t, T h, std::vector<T>& y5)
{
	// Fehlberg tableau: aK = time node of stage K, bKJ = weight of k_J in the
	// state for stage K, w4J / w5J = weights of the 4th / 5th order solutions.
	static const T b31 = 3.0 / 32.0;
	static const T b32 = 9.0 / 32.0;
	static const T a3 = 3.0 / 8.0;

	static const T b41 = 1932.0 / 2197.0;
	static const T b42 = -7200.0 / 2197.0;
	static const T b43 = 7296.0 / 2197.0;
	static const T a4 = 12.0 / 13.0;

	static const T b51 = 439.0 / 216.0;
	static const T b53 = 3680.0 / 513.0;
	static const T b54 = -845.0 / 4104.0;

	static const T w41 = 25.0 / 216.0;
	static const T w43 = 1408.0 / 2565.0;
	static const T w44 = 2197.0 / 4104.0;

	static const T b61 = -8.0 / 27.0;
	static const T b63 = -3544.0 / 2565.0;
	static const T b64 = 1859.0 / 4104.0;
	static const T b65 = -11.0 / 40.0;

	static const T w51 = 16.0 / 135.0;
	static const T w53 = 6656.0 / 12825.0;
	static const T w54 = 28561.0 / 56430.0;
	static const T w55 = -9.0 / 50.0;
	static const T w56 = 2.0 / 55.0;

	const std::size_t n = static_cast<std::size_t>(N);
	std::vector<T> k1(n);
	std::vector<T> k2(n);
	std::vector<T> k3(n);
	std::vector<T> k4(n);
	std::vector<T> k5(n);
	std::vector<T> y_etap(n); // state at which the next stage is evaluated
	y5.resize(n);

	// Each stage: evaluate all k_s first, then build the next stage state, so
	// f never sees a half-updated y_etap.
	for (int i = 0; i < N; ++i) k1[i] = h*f(i, y, t);
	for (int i = 0; i < N; ++i) y_etap[i] = y[i] + 0.25*k1[i];

	for (int i = 0; i < N; ++i) k2[i] = h*f(i, y_etap.data(), t + 0.25*h);
	for (int i = 0; i < N; ++i) y_etap[i] = y[i] + b31*k1[i] + b32*k2[i];

	for (int i = 0; i < N; ++i) k3[i] = h*f(i, y_etap.data(), t + a3*h);
	for (int i = 0; i < N; ++i) y_etap[i] = y[i] + b41*k1[i] + b42*k2[i] + b43*k3[i];

	for (int i = 0; i < N; ++i) k4[i] = h*f(i, y_etap.data(), t + a4*h);
	for (int i = 0; i < N; ++i) y_etap[i] = y[i] + b51*k1[i] - 8.0*k2[i] + b53*k3[i] + b54*k4[i];

	for (int i = 0; i < N; ++i) k5[i] = h*f(i, y_etap.data(), t + h);
	for (int i = 0; i < N; ++i) y_etap[i] = y[i] + b61 * k1[i] + 2 * k2[i] + b63*k3[i] + b64*k4[i] + b65*k5[i];

	T blad = 0;
	for (int i = 0; i < N; ++i)
	{
		const T k6 = h*f(i, y_etap.data(), t + 0.5*h);
		const T y4 = y[i] + w41*k1[i] + w43*k3[i] + w44*k4[i] - 0.2*k5[i]; // RK4 on 5 stages
		y5[i] = y[i] + w51*k1[i] + w53*k3[i] + w54*k4[i] + w55*k5[i] + w56*k6; // RK5 on 6 stages
		blad += std::fabs(y5[i] - y4);
	}
	return blad / N; // mean over components, computed once after the sum
}

/// Adaptive Runge-Kutta-Fehlberg 4(5) step.
/// Repeats trial steps, doubling h while the error estimate is below blad_min
/// and halving it while the estimate is above blad_max (or NaN). Once h has
/// been halved in this call it is never doubled again, which prevents
/// endless double/halve oscillation.
/// @param N       number of equations (length of y and y_nast); N <= 0 is a no-op
/// @param f       right-hand side; f(i, y, t) returns d y[i] / dt
/// @param y       state at time t
/// @param t       current time
/// @param h       in: trial step size; out: the step actually used for y_nast
/// @param y_nast  output buffer receiving the fifth-order state at t + h (may alias y)
/// @return        y_nast
/// @throws std::runtime_error if the error stays above blad_max after the
///         maximum number of step-size changes
template<typename T>
T* odeint_RK4F5(int N, T(*f)(int, T*, T), T* y, T t, T& h, T* y_nast)
{
	if (N <= 0) return y_nast;

	// Cap on step-size changes in one call: 2^100 is far beyond any sensible
	// adjustment and keeps the loop finite (e.g. zero error => endless doubling).
	static const int max_zmian = 100;

	std::vector<T> y5;
	bool zmniejszono = false; // true once h has been halved during this call
	for (int zmiany = 0; ; ++zmiany)
	{
		const T blad = odeint_RK4F5_proba(N, f, y, t, h, y5);

		if (!(blad <= blad_max)) // too inaccurate, or NaN
		{
			if (zmiany >= max_zmian)
				throw std::runtime_error("odeint_RK4F5: step size adaptation failed");
			h /= 2;
			zmniejszono = true;
		}
		else if (blad < blad_min && !zmniejszono && zmiany < max_zmian)
		{
			h *= 2;
		}
		else
		{
			break; // accepted: y5 was computed with the current h
		}
	}

	// y is no longer read, so copying into y_nast is safe even if y_nast == y.
	for (int i = 0; i < N; ++i) y_nast[i] = y5[i];
	return y_nast;
}

/// Integrates the Earth-Moon system (grawitacja) for 365 days with the
/// adaptive odeint_RK4F5 solver. Each step writes one line to wyniki.dat:
/// t, the step h taken from t, and the state vector y at time t.
int main()
{
	// State vectors, zero-initialised. std::vector releases both buffers on
	// exit; the raw-array version never freed y_nast.
	std::vector<double> y(static_cast<std::size_t>(N));
	std::vector<double> y_nast(static_cast<std::size_t>(N));

	const double tmax = 365 * jednostkaCzasu; //s
	double h = 1; //s, initial step; adapted by odeint_RK4F5

	// Initial conditions: Earth at rest at the origin; Moon on the y axis at
	// jednostkaOdleglosci, moving along x.
	y[6] = jednostkaOdleglosci;
	y[5] = 1.022E3; //m/s

	std::ofstream plik_wy("wyniki.dat");
	if (!plik_wy)
	{
		std::cerr << "Cannot open wyniki.dat for writing\n";
		return 1;
	}
	plik_wy.precision(10);
	plik_wy.setf(std::ios::scientific);

	// Time evolution of the system.
	for (double t = 0; t < tmax; t += h)
	{
		// Alternative problems/solvers (N must match the problem):
		//odeint_Euler<double>(N, rzut<double>, y.data(), t, h, y_nast.data());
		//odeint_MidPoint<double>(N, rzut<double>, y.data(), t, h, y_nast.data());

		//odeint_Euler<double>(N, oscylacje<double>, y.data(), t, h, y_nast.data());
		//odeint_MidPoint<double>(N, oscylacje<double>, y.data(), t, h, y_nast.data());
		//odeint_RK4<double>(N, oscylacje<double>, y.data(), t, h, y_nast.data());

		//odeint_Euler<double>(N, oscylacje_sprzezone<double>, y.data(), t, h, y_nast.data());
		//odeint_MidPoint<double>(N, oscylacje_sprzezone<double>, y.data(), t, h, y_nast.data());
		//odeint_RK4<double>(N, oscylacje_sprzezone<double>, y.data(), t, h, y_nast.data());
		//odeint_RK4F5<double>(N, oscylacje_sprzezone<double>, y.data(), t, h, y_nast.data());

		// May change h; on return h is the step actually taken, so t += h stays consistent.
		odeint_RK4F5<double>(N, grawitacja<double>, y.data(), t, h, y_nast.data());

		plik_wy << t << "\t" << h;
		for (int i = 0; i < N; ++i) plik_wy << "\t" << y[i];
		plik_wy << "\n";

		// The new state becomes the current one; swapping vectors is O(1).
		y.swap(y_nast);
	}

	plik_wy.close();

	std::cout << "\n\nOK.\n";
	return 0;
}