/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sshd.server.kex;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.math.BigInteger;
import java.net.URL;
import java.security.KeyPair;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.sshd.common.Factory;
import org.apache.sshd.common.FactoryManager;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.kex.DHFactory;
import org.apache.sshd.common.kex.DHG;
import org.apache.sshd.common.kex.DHGroupData;
import org.apache.sshd.common.kex.KexProposalOption;
import org.apache.sshd.common.kex.KeyExchange;
import org.apache.sshd.common.kex.KeyExchangeFactory;
import org.apache.sshd.common.random.Random;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.signature.Signature;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
import org.apache.sshd.common.util.security.SecurityUtils;
import org.apache.sshd.core.CoreModuleProperties;
import org.apache.sshd.server.kex.AbstractDHServerKeyExchange;
import org.apache.sshd.server.kex.Moduli;
import org.apache.sshd.server.session.ServerSession;

public class DHGEXServer
extends AbstractDHServerKeyExchange {
    protected final DHFactory factory;
    protected DHG dh;
    protected int min;
    protected int prf;
    protected int max;
    protected byte expected;
    protected boolean oldRequest;

    protected DHGEXServer(DHFactory factory2, Session session2) {
        super(session2);
        this.factory = Objects.requireNonNull(factory2, "No factory");
    }

    @Override
    public final String getName() {
        return this.factory.getName();
    }

    public static KeyExchangeFactory newFactory(final DHFactory factory2) {
        return new KeyExchangeFactory(){

            @Override
            public KeyExchange createKeyExchange(Session session2) throws Exception {
                return new DHGEXServer(factory2, session2);
            }

            @Override
            public String getName() {
                return factory2.getName();
            }

            public String toString() {
                return NamedFactory.class.getSimpleName() + "<" + KeyExchange.class.getSimpleName() + ">[" + this.getName() + "]";
            }
        };
    }

    @Override
    public void init(byte[] v_s, byte[] v_c, byte[] i_s, byte[] i_c) throws Exception {
        super.init(v_s, v_c, i_s, i_c);
        this.expected = (byte)34;
    }

    @Override
    public boolean next(int cmd, Buffer buffer) throws Exception {
        ServerSession session2 = this.getServerSession();
        boolean debugEnabled = this.log.isDebugEnabled();
        if (debugEnabled) {
            this.log.debug("next({})[{}] process command={} (expected={})", this, session2, KeyExchange.getGroupKexOpcodeName(cmd), KeyExchange.getGroupKexOpcodeName(this.expected));
        }
        if (cmd == 30 && this.expected == 34) {
            this.oldRequest = true;
            this.min = CoreModuleProperties.PROP_DHGEX_SERVER_MIN_KEY.get(session2).orElse(SecurityUtils.getMinDHGroupExchangeKeySize());
            this.prf = buffer.getInt();
            this.max = CoreModuleProperties.PROP_DHGEX_SERVER_MAX_KEY.get(session2).orElse(SecurityUtils.getMaxDHGroupExchangeKeySize());
            if (this.max < this.min || this.prf < this.min || this.max < this.prf) {
                throw new SshException(3, "Protocol error: bad parameters " + this.min + " !< " + this.prf + " !< " + this.max);
            }
            this.dh = this.chooseDH(this.min, this.prf, this.max);
            this.setF(this.dh.getE());
            BigInteger pValue = this.dh.getP();
            this.validateFValue(pValue);
            this.hash = this.dh.getHash();
            this.hash.init();
            if (debugEnabled) {
                this.log.debug("next({})[{}] send (old request) SSH_MSG_KEX_DH_GEX_GROUP - min={}, prf={}, max={}", this, session2, this.min, this.prf, this.max);
            }
            buffer = session2.createBuffer((byte)31);
            buffer.putMPInt(pValue);
            buffer.putMPInt(this.dh.getG());
            session2.writePacket(buffer);
            this.expected = (byte)32;
            return false;
        }
        if (cmd == 34 && this.expected == 34) {
            this.min = buffer.getInt();
            this.prf = buffer.getInt();
            this.max = buffer.getInt();
            if (this.prf < this.min || this.max < this.prf) {
                throw new SshException(3, "Protocol error: bad parameters " + this.min + " !< " + this.prf + " !< " + this.max);
            }
            this.dh = this.chooseDH(this.min, this.prf, this.max);
            this.setF(this.dh.getE());
            BigInteger pValue = this.dh.getP();
            this.validateFValue(pValue);
            this.hash = this.dh.getHash();
            this.hash.init();
            if (debugEnabled) {
                this.log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_GROUP - min={}, prf={}, max={}", this, session2, this.min, this.prf, this.max);
            }
            buffer = session2.createBuffer((byte)31);
            buffer.putMPInt(pValue);
            buffer.putMPInt(this.dh.getG());
            session2.writePacket(buffer);
            this.expected = (byte)32;
            return false;
        }
        if (cmd != this.expected) {
            throw new SshException(3, "Protocol error: expected packet " + KeyExchange.getGroupKexOpcodeName(this.expected) + ", got " + KeyExchange.getGroupKexOpcodeName(cmd));
        }
        if (cmd == 32) {
            byte[] e2 = this.updateE(buffer.getMPIntAsBytes());
            BigInteger pValue = this.dh.getP();
            this.validateEValue(pValue);
            this.dh.setF(e2);
            this.k = this.normalize(this.dh.getK());
            KeyPair kp = Objects.requireNonNull(session2.getHostKey(), "No server key pair available");
            String algo = session2.getNegotiatedKexParameter(KexProposalOption.SERVERKEYS);
            Signature sig = ValidateUtils.checkNotNull(NamedFactory.create(session2.getSignatureFactories(), algo), "Unknown negotiated server keys: %s", (Object)algo);
            sig.initSigner(session2, kp.getPrivate());
            buffer = new ByteArrayBuffer();
            buffer.putRawPublicKey(kp.getPublic());
            byte[] k_s = buffer.getCompactData();
            buffer.clear();
            buffer.putBytes(this.v_c);
            buffer.putBytes(this.v_s);
            buffer.putBytes(this.i_c);
            buffer.putBytes(this.i_s);
            buffer.putBytes(k_s);
            if (this.oldRequest) {
                buffer.putInt(this.prf);
            } else {
                buffer.putInt(this.min);
                buffer.putInt(this.prf);
                buffer.putInt(this.max);
            }
            buffer.putMPInt(pValue);
            buffer.putMPInt(this.dh.getG());
            buffer.putMPInt(e2);
            byte[] f2 = this.getF();
            buffer.putMPInt(f2);
            buffer.putBytes(this.k);
            this.hash.update(buffer.array(), 0, buffer.available());
            this.h = this.hash.digest();
            sig.update(session2, this.h);
            buffer.clear();
            buffer.putString(algo);
            byte[] sigBytes = sig.sign(session2);
            buffer.putBytes(sigBytes);
            byte[] sigH = buffer.getCompactData();
            if (this.log.isTraceEnabled()) {
                this.log.trace("next({})[{}][K_S]:  {}", this, session2, BufferUtils.toHex(k_s));
                this.log.trace("next({})[{}][f]:    {}", this, session2, BufferUtils.toHex(f2));
                this.log.trace("next({})[{}][sigH]: {}", this, session2, BufferUtils.toHex(sigH));
            }
            if (debugEnabled) {
                this.log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_REPLY - old={}, min={}, prf={}, max={}", this, session2, this.oldRequest, this.min, this.prf, this.max);
            }
            buffer = session2.prepareBuffer((byte)33, BufferUtils.clear(buffer));
            buffer.putBytes(k_s);
            buffer.putBytes(f2);
            buffer.putBytes(sigH);
            session2.writePacket(buffer);
            return true;
        }
        return false;
    }

    protected DHG chooseDH(int min2, int prf, int max2) throws Exception {
        List<Moduli.DhGroup> groups2;
        ServerSession session2 = this.getServerSession();
        List<Moduli.DhGroup> selected = this.selectModuliGroups(session2, min2, prf, max2, groups2 = this.loadModuliGroups(session2));
        if (GenericUtils.isEmpty(selected)) {
            if (!CoreModuleProperties.ALLOW_DHG1_KEX_FALLBACK.getRequired(session2).booleanValue()) {
                this.log.error("chooseDH({})[{}][prf={}, min={}, max={}] No suitable primes found - failing", this, session2, prf, min2, max2);
                throw new SshException(3, "No suitable primes found for DH group exchange");
            }
            this.log.warn("chooseDH({})[{}][prf={}, min={}, max={}] No suitable primes found - defaulting to DHG1", this, session2, prf, min2, max2);
            return this.getDH(new BigInteger(DHGroupData.getP1()), new BigInteger(DHGroupData.getG()));
        }
        FactoryManager manager = Objects.requireNonNull(session2.getFactoryManager(), "No factory manager");
        Factory<? extends Random> factory2 = Objects.requireNonNull(manager.getRandomFactory(), "No random factory");
        Random random2 = Objects.requireNonNull(factory2.create(), "No random generator");
        int which = random2.random(selected.size());
        Moduli.DhGroup group = selected.get(which);
        if (this.log.isTraceEnabled()) {
            this.log.trace("chooseDH({})[{}][prf={}, min={}, max={}] selected {}", this, session2, prf, min2, max2, group);
        }
        return this.getDH(group.getP(), group.getG());
    }

    protected List<Moduli.DhGroup> selectModuliGroups(ServerSession session2, int min2, int prf, int max2, List<Moduli.DhGroup> groups2) throws Exception {
        int maxDHGroupExchangeKeySize = SecurityUtils.getMaxDHGroupExchangeKeySize();
        int minDHGroupExchangeKeySize = SecurityUtils.getMinDHGroupExchangeKeySize();
        min2 = Math.max(min2, minDHGroupExchangeKeySize);
        prf = Math.max(prf, minDHGroupExchangeKeySize);
        prf = Math.min(prf, maxDHGroupExchangeKeySize);
        max2 = Math.min(max2, maxDHGroupExchangeKeySize);
        ArrayList<Moduli.DhGroup> selected = new ArrayList<Moduli.DhGroup>();
        int bestSize = 0;
        boolean traceEnabled = this.log.isTraceEnabled();
        for (Moduli.DhGroup group : groups2) {
            int size2 = group.getSize();
            if (size2 < min2 || size2 > max2) {
                if (!traceEnabled) continue;
                this.log.trace("selectModuliGroups({})[{}] - skip group={} - size not in range [{}-{}]", this, session2, group, min2, max2);
                continue;
            }
            if (size2 > prf && size2 < bestSize || size2 > bestSize && bestSize < prf) {
                bestSize = size2;
                if (traceEnabled) {
                    this.log.trace("selectModuliGroups({})[{}][prf={}, min={}, max={}] new best size={} from group={}", this, session2, prf, min2, max2, bestSize, group);
                }
                selected.clear();
            }
            if (size2 != bestSize) continue;
            if (traceEnabled) {
                this.log.trace("selectModuliGroups({})[{}][prf={}, min={}, max={}] selected {}", this, session2, prf, min2, max2, group);
            }
            selected.add(group);
        }
        return selected;
    }

    protected List<Moduli.DhGroup> loadModuliGroups(ServerSession session2) throws IOException {
        URL moduli;
        String moduliStr = CoreModuleProperties.MODULI_URL.getOrNull(session2);
        List<Moduli.DhGroup> groups2 = null;
        if (!GenericUtils.isEmpty(moduliStr)) {
            try {
                moduli = new URL(moduliStr);
                groups2 = Moduli.parseModuli(moduli);
            }
            catch (IOException e2) {
                this.log.warn("loadModuliGroups({})[{}] Error ({}) loading external moduli from {}: {}", this, session2, e2.getClass().getSimpleName(), moduliStr, e2.getMessage());
            }
        }
        if (groups2 == null) {
            moduliStr = "/org/apache/sshd/moduli";
            try {
                moduli = this.getClass().getResource(moduliStr);
                if (moduli == null) {
                    throw new FileNotFoundException("Missing internal moduli file");
                }
                moduliStr = moduli.toExternalForm();
                groups2 = Moduli.loadInternalModuli(moduli);
            }
            catch (IOException e3) {
                this.log.warn("loadModuliGroups({})[{}] Error ({}) loading internal moduli from {}: {}", this, session2, e3.getClass().getSimpleName(), moduliStr, e3.getMessage());
                throw e3;
            }
        }
        if (this.log.isDebugEnabled()) {
            this.log.debug("loadModuliGroups({})[{}] Loaded {} moduli groups from {}", this, session2, GenericUtils.size(groups2), moduliStr);
        }
        return groups2;
    }

    protected DHG getDH(BigInteger p, BigInteger g2) throws Exception {
        return (DHG)this.factory.create(p, g2);
    }
}

