/*
 * This is some code that converts cartesian coordinate state vectors to 
 * classical orbital elements, and back again.
 */

#include <stdio.h>
#include <math.h>

// Power-of-ten scale factors and the gravitational constant used below.
// Each literal is the exponent-notation spelling of the same decimal value.

// 10^-11: exponent of the gravitational constant.
double Em11 = 1e-11;

// Gravitational constant, 6.673 x 10^-11 (presumably SI, m^3 kg^-1 s^-2 --
// TODO confirm against the callers' units).
double G = 6.67300 * Em11;

// 10^24: scale for masses quoted in units of 10^24.
double E24 = 1e24;

// 10^13 = 10^24 * 10^-11: lets G * mass be formed without an intermediate
// multiplication by the very small / very large factors.
double E24m11 = 1e13;

// Multiplier that converts a length given in kilometres to the base unit.
double km = 1e3;

// A central body described by its gravitational parameter and its radius.
// Both values are stored exactly as passed to the constructor.
class Body
{
protected:
    double GM;  // gravitational parameter (G * mass)
    double r;   // body radius

public:
    // GM: gravitational parameter, r: body radius.
    Body(double GM, double r) : GM(GM), r(r) { }

    // The accessors are const so that they can be called through a
    // const Body reference or pointer.

    // Mass recovered from the gravitational parameter using the file-level
    // constant G.
    double mass() const { return GM / G; }

    // Gravitational parameter as given at construction.
    double gm() const { return GM; }

    // Radius as given at construction.
    double radius() const { return r; }
};

// Minimal three-component vector of doubles.
//
// All operations are const-qualified and take vector operands by const
// reference, so they work on const objects and temporaries without copying.
class Vector3
{
public:
    double x, y, z;

    // Zero vector.
    Vector3() : x(0), y(0), z(0) { }

    // Vector with the given components.
    Vector3(double x, double y, double z) : x(x), y(y), z(z) { }

    // Squared Euclidean length.
    double sqlen() const { return x*x + y*y + z*z; }

    // Euclidean length.
    double len() const { return sqrt(sqlen()); }

    // Scalar (dot) product with o.
    double dot(const Vector3 &o) const
    {
        return x*o.x + y*o.y + z*o.z;
    }

    // Vector (cross) product: this x o. Note that '*' between two vectors
    // is the cross product, not a component-wise product.
    Vector3 operator*(const Vector3 &o) const
    {
        return Vector3(y*o.z - z*o.y, z*o.x - x*o.z, x*o.y - y*o.x);
    }

    // Scales every component by s.
    Vector3 operator*(double s) const
    {
        return Vector3(s*x, s*y, s*z);
    }

    // Subtracts the scalar s from every component.
    Vector3 operator-(double s) const
    {
        return Vector3(x-s, y-s, z-s);
    }

    // Component-wise sum.
    Vector3 operator+(const Vector3 &o) const
    {
        return Vector3(x+o.x, y+o.y, z+o.z);
    }

    // Component-wise difference.
    Vector3 operator-(const Vector3 &o) const
    {
        return Vector3(x-o.x, y-o.y, z-o.z);
    }
};

class Orbit
{
protected:
    double a, e, i, l, w, t;
    Vector3 epoch;
    Body *ref;

public:
    Orbit() { }
    Orbit(Body *ref, double a, double e, double i, double l, double w, double t) : ref(ref), a(a), e(e), i(i), l(l), w(w), t(t) { }
    Orbit(Body *ref, Vector3 r, Vector3 v) : ref(ref) { calcFromPosVel(r, v); }

    Body *reference() { return ref; }
    double semiMajorAxis() { return a; }
    double eccentricity() { return e; }
    double inclination() { return i; }
    double longitudeOfAscendingNode() { return l; }
    double argumentOfPeriapsis() { return w; }
    double trueAnomaly() { return t; }

    double semiMinorAxis()
    {
        return sqrt(a*a * (1 - e*e));
    }

    double period()
    { 
        double u = ref->gm();
        return 2 * M_PI * sqrt(a*a*a / u); 
    }

    double eccentricAnomaly()
    {
        double E = acos((e + cos(t)) / (1 + e * cos(t)));
        if (t > M_PI && E < M_PI)
            E = 2*M_PI - E;
        return E;
    }

    double meanAnomaly()
    {
        double E = eccentricAnomaly();
        double M = E - e * sin(E);
        if (E > M_PI && M < M_PI)
            M = 2*M_PI - M;
        return M;
    }

    double meanMotion()
    {
        double u = ref->gm();
        return sqrt(u / (a*a*a));
    }

    double timeSincePeriapsis()
    {
        return meanAnomaly() / meanMotion();
    }

    double semiparameter()
    {
        return a * (1 - e*e);
    }

    Vector3 position()
    {
        double p = semiparameter();
        Vector3 r;
        r.x = p * (cos(l) * cos(w + t) - sin(l) * cos(i) * sin(w + t));
        r.y = p * (sin(l) * cos(w + t) + cos(l) * cos(i) * sin(w + t));
        r.z = p * sin(i) * sin(w + t);
        return r - epoch;
    }

    Vector3 velocity()
    {
        double p = semiparameter();
        double u = ref->gm();
        Vector3 v;
        double g = -sqrt(u/p);
        v.x = g * (cos(l)          * (sin(w + t) + e * sin(w)) + 
                   sin(l) * cos(i) * (cos(w + t) + e * cos(w)));
        v.y = g * (sin(l)          * (sin(w + t) + e * sin(w)) -
                   cos(l) * cos(i) * (cos(w + t) + e * cos(w)));
        v.z = g * (sin(i) * (cos(w + t) + e * cos(w)));
        return v; 
    }

    void print()
    {
        printf("semi-major axis: %f\n", a);
        printf("eccentricity: %f\n", e);
        printf("inclination: %f\n", i);
        printf("longitude of ascending node: %f\n", l);
        printf("argument of periapsis: %f\n", w);
        printf("true anomaly: %f\n", t);
        printf("orbit period: %f\n", period());
        printf("eccentric anomaly: %f\n", eccentricAnomaly());
        printf("mean anomaly: %f\n", meanAnomaly());
        printf("mean motion: %f\n", meanMotion());
        printf("time since periapsis: %f\n", timeSincePeriapsis());
        printf("epoch: %f %f %f\n", epoch.x, epoch.y, epoch.z);
        Vector3 pos = position();
        printf("position: %f %f %f, alt=%fkm\n", pos.x, pos.y, pos.z, (pos.len() - ref->radius()) / km);
        Vector3 vel = velocity();
        printf("velocity: %f %f %f, len=%f\n", vel.x, vel.y, vel.z, vel.len());
        printf("--\n");
    }    

    // vectors in geocentric equatorial inertial coordinates
    void calcFromPosVel(Vector3 r, Vector3 v)
    {
        // calculate specific relative angular momement
        Vector3 h = r * v;

        // calculate vector to the ascending node
        Vector3 n(-h.y, h.x, 0);

        // standard gravity
        double u = ref->gm();

        // calculate eccentricity vector and scalar
        Vector3 e = ((v * h) * (1.0 / u)) - (r * (1.0 / r.len()));
        this->e = e.len();

        // calculate specific orbital energy and semi-major axis
        double E = v.sqlen() * 0.5 - u / r.len();
        this->a = -u / (E * 2);

        // calculate inclination
        this->i = acos(h.z / h.len());

        // calculate longitude of ascending node
        if (this->i == 0.0)
            this->l = 0;
        else if (n.y >= 0.0)
            this->l = acos(n.x / n.len());
        else
            this->l = 2 * M_PI - acos(n.x / n.len());
        
        // calculate argument of periapsis
        if (this->i == 0.0)
            this->w = acos(e.x / e.len());
        else if (e.z >= 0.0)
            this->w = acos(n.dot(e) / (n.len() * e.len()));
        else
            this->w = 2 * M_PI - acos(n.dot(e) / (n.len() * e.len())); 

        // calculate true anomaly
        if (r.dot(v) >= 0.0)
            this->t = acos(e.dot(r) / (e.len() * r.len()));
        else
            this->t = 2 * M_PI - acos(e.dot(r) / (e.len() * r.len()));

        // calculate epoch
        this->epoch = Vector3(0,0,0);
        this->epoch = position() - r;
    }

    // For small eccentricities a good approximation of true anomaly can be 
    // obtained by the following formula (the error is of the order e^3)
    double estimateTrueAnomaly(double meanAnomaly)
    {
        double M = meanAnomaly;
        return M + 2 * e * sin(M) + 1.25 * e * e * sin(2 * M);
    }

    double calcEccentricAnomaly(double meanAnomaly)
    {
        double t = estimateTrueAnomaly(meanAnomaly);
        double E = acos((e + cos(t)) / (1 + e * cos(t)));
        double M = E - e * sin(E);

        // iterate to get M closer to meanAnomaly
        double rate = 0.01;
        bool lastDec = false;
        while(1) 
        {
            //printf("   using approx %f to %f\n", M, meanAnomaly);
            if (fabs(M - meanAnomaly) < 0.0000000000001)
                break;
            if (M > meanAnomaly)
            {
                E -= rate;
                lastDec = true;
            }
            else
            {
                E += rate;
                if (lastDec)
                    rate *= 0.1;
            }
            M = E - e * sin(E);
        }

        if (meanAnomaly > M_PI && E < M_PI)
            E = 2*M_PI - E;

        return E;
    }

    void calcTrueAnomaly(double eccentricAnomaly)
    {
        double E = eccentricAnomaly;
        this->t = acos((cos(E) - e) / (1 - e * cos(E)));
        //this->t = 2 * atan2(sqrt(1+e)*sin(E/2), sqrt(1-e)*cos(E/2));
        if (eccentricAnomaly > M_PI && this->t < M_PI)
            this->t = 2*M_PI - this->t;
    }

    void step(double time)
    {
        double M = meanAnomaly();
        M += meanMotion() * time;
        while (M < -2*M_PI)
            M = M + 2*M_PI;
        if (M < 0)
            M = 2*M_PI + M;
        while (M > 2*M_PI)
            M = M - 2*M_PI;
        double E = calcEccentricAnomaly(M);
        calcTrueAnomaly(E);
        //printf("since M: %f, E=%f, t=%f\n", M, E, t);
    }
};

// Earth: gravitational parameter formed as 6.673e-11 * 5.9742e24 (the two
// powers of ten are pre-combined in E24m11), radius 6378.1 km.
Body earth(6.67300 * 5.9742 * E24m11, 6378.1 * km);

// The Moon's orbit around Earth. Orbit passes its angles straight to
// sin()/cos(), i.e. it works in radians, so the 18.29 degree inclination is
// converted here instead of being passed as a raw 18.29.
Orbit moon(&earth, 384399 * km, 0.0549, 18.29 * M_PI / 180.0, 0, 0, 0);

int main()
{
    Vector3 pos(0, earth.radius() + 300 * km, 0); 
    Vector3 vel(-7000, 0, 0);
    Orbit o(&earth, pos, vel);
    o.print();
    for (int i = 0; i < o.period(); i++)
    {
        o.step(10);
        o.print();
    }
}


