package defpackage;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:CP_RL.class */
/**
 * Actor-critic learner that drives a {@link CartPole} simulation.
 *
 * <p>A critic network ({@link #V_crit}) and one actor network ({@link #mu_act}) are both
 * {@code NGnet} instances laid out over the same four plant state components, read through
 * {@code cp.getx(0..3)}. Judging by the ranges passed to {@code setParam}, index 0 is the
 * angle (theta), 1 the angular velocity (omega), 2 the position (x) and 3 the velocity (v).
 * Each call to {@link #iteration()} picks a noisy action, computes a temporal-difference
 * error, updates weights and eligibility traces, and then advances the plant.
 *
 * <p>NOTE(review): this file is decompiler output. Several members shown as {@code public}
 * were package-private in the original class file, and {@link #readCoef()} could not be
 * recovered. The semantics of {@code CartPole} and {@code NGnet} are not visible here.
 */
public class CP_RL {
    /**
     * Slack added to a step-count quotient before it is truncated to an int, so that a ratio
     * such as 0.3 / 0.1 (which evaluates to 2.9999999999999996) still yields 3 steps.
     */
    private static final double STEP_COUNT_EPS = 1.0e-9d;

    /** Actor weights are clipped to the range [-limit, +limit] after every update. */
    private static final double ACTOR_WEIGHT_LIMIT = 10.0d;

    /** Gain applied to {@code u / umax} when forming the actor/noise trace drive term. */
    private static final double ACTION_GAIN = 5.0d;

    /** pi/2 to double precision; one of the thresholds used to detect an angle wrap-around. */
    private static final double HALF_PI = 1.5707963267948966d;

    /** 3*pi/4 to double precision; the other wrap-around detection threshold. */
    private static final double THREE_QUARTER_PI = 2.356194490192345d;

    /** pi/4 to double precision; half-width of the angle band scored by {@link #one_trial}. */
    private static final double QUARTER_PI = 0.7853981633974483d;

    /** Amount of learner time (in units of {@link #dt}) spent learning after a failure. */
    private static final double FAILURE_PENALTY_TIME = 1.0d;

    /** Simulated plant. */
    CartPole cp;
    /** Most recent control value handed to {@code cp.setu}. */
    double u;
    /** Exploration noise scale: logistic function of {@code w_sig[0]}. */
    double sig;
    /** Most recent output of the actor network. */
    double mu;
    /** Most recent temporal-difference error. */
    double TD;
    /** Critic value seen on the previous iteration (or at episode start). */
    double Vold;
    /** Number of critic weights: product of the four critic grid sizes. */
    int nV_crit;
    int ntheta_crit;
    int nomega_crit;
    int nx_crit;
    int nv_crit;
    /** Critic network. */
    NGnet V_crit;
    /** Eligibility trace for each critic weight. */
    double[] eV_crit;
    /** Number of actor outputs; fixed to 1 by the constructor. */
    int d_act;
    /** Number of actor weights: product of the four actor grid sizes. */
    int nmu_act;
    int ntheta_act;
    int nomega_act;
    int nx_act;
    int nv_act;
    /** Actor networks, one per action dimension. */
    NGnet[] mu_act;
    /** Eligibility trace for each actor weight, indexed [action][weight]. */
    double[][] emu_act;
    /** Parameter whose logistic transform gives {@link #sig}. */
    double[] w_sig;
    /** Eligibility trace of {@link #w_sig}. */
    double[] esig_act;
    /** Time step passed to the {@code CartPole} constructor. */
    double dt_model;
    /** Learner time step used in all update equations. */
    double dt;
    /** Number of {@code cp.nextstep()} calls made per {@link #iteration()}. */
    int Niteration;
    /** Input dimension passed to every {@code NGnet}. */
    int d = 4;
    /** Selects the reward in {@link #reward(int)}: 0 = cos(angle), 1 = -1, anything else = 0. */
    int reward_flag = 0;
    /** Learner time accumulated by {@link #one_trial(double)}; never reset in this class. */
    double t = 0.0d;
    /** Scale applied to the squashed action to obtain {@link #u}. */
    double umax = 10.0d;
    double thetamax = 3.141592653589793d;
    double thetamin = -this.thetamax;
    double omegamax = 7.853981633974483d;
    double omegamin = -this.omegamax;
    double xmax = 2.4d;
    double xmin = -this.xmax;
    double vmax = 2.0d;
    double vmin = -this.vmax;
    /** Divides the critic value in the TD error. */
    double tau = 1.0d;
    /** Decay constant of the critic traces. */
    double kappa_crit = 0.1d;
    /** Decay constant of the actor traces. */
    double kappa_mu = 0.5d;
    /** Decay constant of the noise-scale trace. */
    double kappa_sig = 0.5d;
    /** Critic learning rate. */
    double alpha = 30.0d;
    /** Actor learning rate. */
    double beta_mu = 30.0d;
    /** Noise-scale learning rate. */
    double beta_sig = 30.0d;

    /**
     * Builds the plant, the critic and the actor.
     *
     * @param cp1 first {@code CartPole} constructor argument (meaning defined by that class)
     * @param cp2 second {@code CartPole} constructor argument
     * @param cp3 third {@code CartPole} constructor argument
     * @param cp4 fourth {@code CartPole} constructor argument
     * @param cp5 fifth {@code CartPole} constructor argument
     * @param nThetaCrit critic grid size along state index 0
     * @param nOmegaCrit critic grid size along state index 1
     * @param nXCrit critic grid size along state index 2
     * @param nVCrit critic grid size along state index 3
     * @param nThetaAct actor grid size along state index 0
     * @param nOmegaAct actor grid size along state index 1
     * @param nXAct actor grid size along state index 2
     * @param nVAct actor grid size along state index 3
     * @param modelDt plant time step, also forwarded to {@code CartPole}
     * @param learnerDt learner time step
     */
    public CP_RL(double cp1, double cp2, double cp3, double cp4, double cp5,
            int nThetaCrit, int nOmegaCrit, int nXCrit, int nVCrit,
            int nThetaAct, int nOmegaAct, int nXAct, int nVAct,
            double modelDt, double learnerDt) {
        this.cp = new CartPole(cp1, cp2, cp3, cp4, cp5, modelDt);
        this.ntheta_crit = nThetaCrit;
        this.nomega_crit = nOmegaCrit;
        this.nx_crit = nXCrit;
        this.nv_crit = nVCrit;
        this.ntheta_act = nThetaAct;
        this.nomega_act = nOmegaAct;
        this.nx_act = nXAct;
        this.nv_act = nVAct;
        this.dt_model = modelDt;
        this.dt = learnerDt;
        this.Niteration = wholeSteps(learnerDt, modelDt);

        // Freshly allocated double arrays are already all zero, so no explicit clearing is needed.
        this.nV_crit = this.ntheta_crit * this.nomega_crit * this.nx_crit * this.nv_crit;
        this.eV_crit = new double[this.nV_crit];
        this.V_crit = newStateNet(this.ntheta_crit, this.nomega_crit, this.nx_crit, this.nv_crit);

        this.d_act = 1;
        this.nmu_act = this.ntheta_act * this.nomega_act * this.nx_act * this.nv_act;
        this.emu_act = new double[this.d_act][this.nmu_act];
        this.w_sig = new double[this.d_act];
        this.esig_act = new double[this.d_act];
        this.mu_act = new NGnet[this.d_act];
        for (int a = 0; a < this.d_act; a++) {
            this.mu_act[a] = newStateNet(this.ntheta_act, this.nomega_act, this.nx_act, this.nv_act);
        }
        this.TD = 0.0d;
    }

    /** Returns how many whole {@code step}s fit into {@code span}, tolerating rounding error. */
    private static int wholeSteps(double span, double step) {
        return (int) ((span / step) + STEP_COUNT_EPS);
    }

    /** Creates a network over the four state ranges with the given grid sizes and zeroed weights. */
    private NGnet newStateNet(int nTheta, int nOmega, int nX, int nV) {
        NGnet net = new NGnet(this.d);
        net.setParam(0, this.thetamin, this.thetamax, nTheta);
        net.setParam(1, this.omegamin, this.omegamax, nOmega);
        net.setParam(2, this.xmin, this.xmax, nX);
        net.setParam(3, this.vmin, this.vmax, nV);
        net.initw(0.0d, 0.0d);
        return net;
    }

    /** Copies the four current plant state components into the inputs of {@code net}. */
    private void loadState(NGnet net) {
        net.setX(0, this.cp.getx(0));
        net.setX(1, this.cp.getx(1));
        net.setX(2, this.cp.getx(2));
        net.setX(3, this.cp.getx(3));
    }

    /** Zeroes every eligibility trace (critic, actor and noise scale). */
    private void clearTraces() {
        for (int k = 0; k < this.nV_crit; k++) {
            this.eV_crit[k] = 0.0d;
        }
        for (int a = 0; a < this.d_act; a++) {
            for (int k = 0; k < this.nmu_act; k++) {
                this.emu_act[a][k] = 0.0d;
            }
            this.esig_act[a] = 0.0d;
        }
    }

    /** Resets plant time and state, re-seeds {@link #Vold} from the critic and clears traces. */
    private void beginEpisode(double theta, double omega) {
        this.cp.sett(0.0d);
        this.reward_flag = 0;
        this.cp.setx(0, theta);
        this.cp.setx(1, omega);
        this.cp.setx(2, 0.0d);
        this.cp.setx(3, 0.0d);
        loadState(this.V_crit);
        this.Vold = this.V_crit.getVal();
        clearTraces();
    }

    /**
     * Performs one learner step: selects an action, computes the TD error, updates the critic,
     * actor and noise-scale parameters together with their traces, then advances the plant by
     * {@link #Niteration} model steps.
     */
    public void iteration() {
        NGnet actor = this.mu_act[0];

        // Action selection: actor mean plus scaled Gaussian noise, squashed into (-umax, umax).
        loadState(actor);
        this.mu = actor.getVal();
        this.sig = 1.0d / (1.0d + Math.exp(-this.w_sig[0]));
        this.u = this.umax * sigmoid(this.mu + (this.sig * gausRand()));
        this.cp.setu(this.u);

        // TD error from the reward, the current value and its finite-difference time derivative.
        loadState(this.V_crit);
        double value = this.V_crit.getVal();
        this.TD = (reward(this.reward_flag) - (value / this.tau)) + ((value - this.Vold) / this.dt);
        this.Vold = value;

        // Critic: each weight is updated with its old trace before that trace is refreshed.
        for (int k = 0; k < this.nV_crit; k++) {
            this.V_crit.setw(k, this.V_crit.getw(k) + (this.alpha * this.TD * this.eV_crit[k] * this.dt));
            this.eV_crit[k] += (((-this.eV_crit[k]) / this.kappa_crit) + this.V_crit.getBaseVal(k)) * this.dt;
        }

        // Shared by the actor and noise-scale traces; u, umax and mu are fixed for this step.
        double actionOffset = ((ACTION_GAIN * this.u) / this.umax) - this.mu;

        // Actor: same update-then-refresh order, with the new weight clipped.
        for (int k = 0; k < this.nmu_act; k++) {
            double weight = actor.getw(k) + (this.beta_mu * this.TD * this.emu_act[0][k] * this.dt);
            if (weight > ACTOR_WEIGHT_LIMIT) {
                weight = ACTOR_WEIGHT_LIMIT;
            } else if (weight < -ACTOR_WEIGHT_LIMIT) {
                weight = -ACTOR_WEIGHT_LIMIT;
            }
            actor.setw(k, weight);
            double drive = actionOffset * actor.getBaseVal(k);
            this.emu_act[0][k] += (((-this.emu_act[0][k]) / this.kappa_mu) + drive) * this.dt;
        }

        // Noise scale parameter and its trace.
        this.w_sig[0] += this.beta_sig * this.TD * this.esig_act[0] * this.dt;
        double sigDrive = ((actionOffset * actionOffset) - (this.sig * this.sig)) * (1.0d - this.sig);
        this.esig_act[0] += (((-this.esig_act[0]) / this.kappa_sig) + sigDrive) * this.dt;

        // Advance the plant for one learner step worth of model steps.
        for (int step = 0; step < this.Niteration; step++) {
            this.cp.nextstep();
            this.cp.nextcopy();
        }
    }

    /**
     * Runs one episode from a random start angle until plant time reaches {@code maxTime} or a
     * failure occurs. A failure is more than two net angle wrap-arounds, {@code |getx(1)|}
     * above {@code 2 * omegamax}, or {@code |getx(2)|} above {@code xmax}. On failure the
     * reward flag is set to 1 and learning continues for a fixed penalty period before
     * returning.
     *
     * @param maxTime plant time limit for the episode
     * @return total learner time during which the angle was strictly within +/- pi/4 of zero
     */
    double one_trial(double maxTime) {
        double timeInBand = 0.0d;
        int wraps = 0;
        initialize();
        double prevTheta = this.cp.getx(0);
        while (this.cp.gett() < maxTime) {
            double theta = this.cp.getx(0);
            // A jump between the far positive and far negative ends counts as a wrap-around.
            if (prevTheta > HALF_PI && theta < -THREE_QUARTER_PI) {
                wraps++;
            }
            if (prevTheta < -THREE_QUARTER_PI && theta > HALF_PI) {
                wraps--;
            }
            prevTheta = theta;

            boolean failed = wraps > 2
                    || wraps < -2
                    || Math.abs(this.cp.getx(1)) > 2.0d * this.omegamax
                    || Math.abs(this.cp.getx(2)) > this.xmax;
            if (failed) {
                this.reward_flag = 1;
                int penaltySteps = wholeSteps(FAILURE_PENALTY_TIME, this.dt);
                for (int k = 0; k < penaltySteps; k++) {
                    iteration();
                    this.t += this.dt;
                }
                return timeInBand;
            }

            if (theta > -QUARTER_PI && theta < QUARTER_PI) {
                timeInBand += this.dt;
            }
            iteration();
            this.t += this.dt;
        }
        return timeInBand;
    }

    /** Zeroes the critic weights, the first actor's weights and {@code w_sig[0]}, then re-initializes. */
    public void reset() {
        this.V_crit.initw(0.0d, 0.0d);
        this.mu_act[0].initw(0.0d, 0.0d);
        this.w_sig[0] = 0.0d;
        initialize();
    }

    /**
     * Starts an episode with state index 0 drawn uniformly from [thetamin, thetamax) and all
     * other state components zero.
     */
    public void initialize() {
        double theta = ((this.thetamax - this.thetamin) * Math.random()) + this.thetamin;
        beginEpisode(theta, 0.0d);
    }

    /**
     * Starts an episode from the given values; state components 2 and 3 are set to zero.
     *
     * @param d value for state index 0
     * @param d2 value for state index 1
     */
    public void initialize(double d, double d2) {
        beginEpisode(d, d2);
    }

    /**
     * Returns the instantaneous reward for the given mode.
     *
     * @param i 0 for the cosine of state index 0, 1 for a constant -1
     * @return the reward, or 0 for any other mode
     */
    double reward(int i) {
        if (i == 0) {
            return Math.cos(this.cp.getx(0));
        }
        if (i == 1) {
            return -1.0d;
        }
        return 0.0d;
    }

    /**
     * Squashing function {@code (2/pi) * atan(pi * d / 2)}, with values in (-1, 1).
     *
     * @param d input value
     * @return the squashed value
     */
    public double sigmoid(double d) {
        return (2.0d * Math.atan((Math.PI * d) / 2.0d)) / Math.PI;
    }

    /**
     * Draws one sample using the polar (rejection) form of the Box-Muller transform on two
     * {@code Math.random()} values.
     *
     * @return a normally distributed pseudo-random value
     */
    public double gausRand() {
        while (true) {
            double a = (2.0d * Math.random()) - 1.0d;
            double b = (2.0d * Math.random()) - 1.0d;
            double r2 = (a * a) + (b * b);
            // Keep only points strictly inside the unit circle, excluding the origin.
            if (r2 < 1.0d && r2 != 0.0d) {
                return b * Math.sqrt(((-2.0d) * Math.log(r2)) / r2);
            }
        }
    }

    /**
     * Command-line entry point. Expects twelve arguments: four critic grid sizes, four actor
     * grid sizes, the model time step, the learner time step, the number of trials and the
     * per-trial time limit. Prints the trial index and the value returned by
     * {@link #one_trial(double)} to both standard output and standard error.
     *
     * @param strArr command-line arguments
     */
    public static void main(String[] strArr) {
        if (strArr.length != 12) {
            System.err.println("java CP_RL Nc1 Nc2 Nc3 Nc4 Na1 Na2 Na3 Na4 dt1 dt2 n TMAX");
            System.exit(1);
        }
        int[] grid = new int[8];
        for (int k = 0; k < grid.length; k++) {
            grid[k] = Integer.parseInt(strArr[k]);
        }
        double modelDt = Double.parseDouble(strArr[8]);
        double learnerDt = Double.parseDouble(strArr[9]);
        int trials = Integer.parseInt(strArr[10]);
        double maxTime = Double.parseDouble(strArr[11]);

        CP_RL learner = new CP_RL(0.1d, 1.0d, 0.5d, 2.0E-6d, 5.0E-4d,
                grid[0], grid[1], grid[2], grid[3],
                grid[4], grid[5], grid[6], grid[7],
                modelDt, learnerDt);
        for (int trial = 0; trial < trials; trial++) {
            String line = trial + " " + learner.one_trial(maxTime);
            System.out.println(line);
            System.err.println(line);
        }
    }

    /** Returns plant state component {@code i}. */
    public double getx(int i) {
        return this.cp.getx(i);
    }

    /** Returns the plant's current time. */
    public double gett() {
        return this.cp.gett();
    }

    /** Sets the plant's current time. */
    void sett(double d) {
        this.cp.sett(d);
    }

    /** Returns the model time step. */
    double getdtmodel() {
        return this.dt_model;
    }

    /** Returns the learner time step. */
    public double getdt() {
        return this.dt;
    }

    /** Sets the reward mode used by {@link #reward(int)} during {@link #iteration()}. */
    public void setreward_flag(int i) {
        this.reward_flag = i;
    }

    /** Returns the upper bound of the state index 1 range. */
    public double getomegamax() {
        return this.omegamax;
    }

    /** Returns the upper bound of the state index 2 range. */
    public double getxmax() {
        return this.xmax;
    }

    /**
     * Placeholder for a method whose body the decompiler could not recover (the tool reported
     * a failed code restructure and skipped a 294-instruction dump). Always throws.
     *
     * @throws UnsupportedOperationException always
     */
    public void readCoef() {
        throw new UnsupportedOperationException("Method not decompiled: defpackage.CP_RL.readCoef():void");
    }
}
