/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sshd.client.kex;

import java.math.BigInteger;
import java.security.PublicKey;
import java.util.Objects;
import org.apache.sshd.client.kex.AbstractDHClientKeyExchange;
import org.apache.sshd.client.session.AbstractClientSession;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.kex.AbstractDH;
import org.apache.sshd.common.kex.DHFactory;
import org.apache.sshd.common.kex.KexProposalOption;
import org.apache.sshd.common.kex.KeyExchange;
import org.apache.sshd.common.kex.KeyExchangeFactory;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.signature.Signature;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
import org.apache.sshd.common.util.security.SecurityUtils;
import org.apache.sshd.core.CoreModuleProperties;

public class DHGEXClient
extends AbstractDHClientKeyExchange {
    protected final DHFactory factory;
    protected byte expected;
    protected int min;
    protected int prf;
    protected int max;
    protected AbstractDH dh;
    protected byte[] g;
    private byte[] p;
    private BigInteger pValue;

    protected DHGEXClient(DHFactory factory2, Session session2) {
        super(session2);
        this.factory = Objects.requireNonNull(factory2, "No factory");
        this.min = CoreModuleProperties.PROP_DHGEX_CLIENT_MIN_KEY.get(session2).orElse(SecurityUtils.getMinDHGroupExchangeKeySize());
        this.max = CoreModuleProperties.PROP_DHGEX_CLIENT_MAX_KEY.get(session2).orElse(SecurityUtils.getMaxDHGroupExchangeKeySize());
        this.prf = CoreModuleProperties.PROP_DHGEX_CLIENT_PRF_KEY.get(session2).orElse(Math.min(4096, this.max));
    }

    @Override
    public final String getName() {
        return this.factory.getName();
    }

    protected byte[] getP() {
        return this.p;
    }

    protected BigInteger getPValue() {
        if (this.pValue == null) {
            this.pValue = BufferUtils.fromMPIntBytes(this.getP());
        }
        return this.pValue;
    }

    protected void setP(byte[] p) {
        this.p = p;
        if (this.pValue != null) {
            this.pValue = null;
        }
    }

    protected void validateEValue() throws Exception {
        this.validateEValue(this.getPValue());
    }

    protected void validateFValue() throws Exception {
        this.validateFValue(this.getPValue());
    }

    public static KeyExchangeFactory newFactory(final DHFactory delegate) {
        return new KeyExchangeFactory(){

            @Override
            public String getName() {
                return delegate.getName();
            }

            @Override
            public KeyExchange createKeyExchange(Session session2) throws Exception {
                return new DHGEXClient(delegate, session2);
            }

            public String toString() {
                return NamedFactory.class.getSimpleName() + "<" + KeyExchange.class.getSimpleName() + ">[" + this.getName() + "]";
            }
        };
    }

    @Override
    public void init(byte[] v_s, byte[] v_c, byte[] i_s, byte[] i_c) throws Exception {
        super.init(v_s, v_c, i_s, i_c);
        Session s2 = this.getSession();
        if (this.log.isDebugEnabled()) {
            this.log.debug("init({})[{}] Send SSH_MSG_KEX_DH_GEX_REQUEST - min={}, prf={}, max={}", this, s2, this.min, this.prf, this.max);
        }
        if (this.max < this.min || this.prf < this.min || this.max < this.prf) {
            throw new SshException(3, "Protocol error: bad parameters " + this.min + " !< " + this.prf + " !< " + this.max);
        }
        Buffer buffer = s2.createBuffer((byte)34, 32);
        buffer.putInt(this.min);
        buffer.putInt(this.prf);
        buffer.putInt(this.max);
        s2.writePacket(buffer);
        this.expected = (byte)31;
    }

    @Override
    public boolean next(int cmd, Buffer buffer) throws Exception {
        AbstractClientSession session2 = this.getClientSession();
        boolean debugEnabled = this.log.isDebugEnabled();
        if (debugEnabled) {
            this.log.debug("next({})[{}] process command={} (expected={})", this, session2, KeyExchange.getGroupKexOpcodeName(cmd), KeyExchange.getGroupKexOpcodeName(this.expected));
        }
        if (cmd != this.expected) {
            throw new SshException(3, "Protocol error: expected packet " + KeyExchange.getGroupKexOpcodeName(this.expected) + ", got " + KeyExchange.getGroupKexOpcodeName(cmd));
        }
        if (cmd == 31) {
            this.setP(buffer.getMPIntAsBytes());
            this.g = buffer.getMPIntAsBytes();
            this.dh = this.getDH(this.getPValue(), new BigInteger(this.g));
            this.hash = this.dh.getHash();
            this.hash.init();
            byte[] e2 = this.updateE(this.dh.getE());
            this.validateEValue();
            if (debugEnabled) {
                this.log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_INIT", (Object)this, (Object)session2);
            }
            buffer = session2.createBuffer((byte)32, e2.length + 8);
            buffer.putMPInt(e2);
            session2.writePacket(buffer);
            this.expected = (byte)33;
            return false;
        }
        if (cmd == 33) {
            if (debugEnabled) {
                this.log.debug("next({})[{}] validate SSH_MSG_KEX_DH_GEX_REPLY - min={}, prf={}, max={}", this, session2, this.min, this.prf, this.max);
            }
            byte[] k_s = buffer.getBytes();
            byte[] f2 = this.updateF(buffer);
            byte[] sig = buffer.getBytes();
            this.validateFValue();
            this.dh.setF(f2);
            this.k = this.normalize(this.dh.getK());
            buffer = new ByteArrayBuffer(k_s);
            PublicKey serverKey = buffer.getRawPublicKey();
            String keyAlg = session2.getNegotiatedKexParameter(KexProposalOption.SERVERKEYS);
            if (GenericUtils.isEmpty(keyAlg)) {
                throw new SshException("Unsupported server key type: " + serverKey.getAlgorithm() + " [" + serverKey.getFormat() + "]");
            }
            buffer = new ByteArrayBuffer();
            buffer.putBytes(this.v_c);
            buffer.putBytes(this.v_s);
            buffer.putBytes(this.i_c);
            buffer.putBytes(this.i_s);
            buffer.putBytes(k_s);
            buffer.putInt(this.min);
            buffer.putInt(this.prf);
            buffer.putInt(this.max);
            buffer.putMPInt(this.getP());
            buffer.putMPInt(this.g);
            buffer.putMPInt(this.getE());
            buffer.putMPInt(f2);
            buffer.putBytes(this.k);
            this.hash.update(buffer.array(), 0, buffer.available());
            this.h = this.hash.digest();
            Signature verif = ValidateUtils.checkNotNull(NamedFactory.create(session2.getSignatureFactories(), keyAlg), "No verifier located for algorithm=%s", (Object)keyAlg);
            verif.initVerifier(session2, serverKey);
            verif.update(session2, this.h);
            if (!verif.verify(session2, sig)) {
                throw new SshException(3, "KeyExchange signature verification failed for key type=" + keyAlg);
            }
            session2.setServerKey(serverKey);
            return true;
        }
        throw new IllegalStateException("Unknown command value: " + KeyExchange.getGroupKexOpcodeName(cmd));
    }

    protected AbstractDH getDH(BigInteger p, BigInteger g2) throws Exception {
        return this.factory.create(p, g2);
    }
}

