package io.gravitee.policy.jws;

import io.gravitee.gateway.api.ExecutionContext;
import io.gravitee.gateway.api.Request;
import io.gravitee.gateway.api.buffer.Buffer;
import io.gravitee.gateway.api.http.stream.TransformableRequestStreamBuilder;
import io.gravitee.gateway.api.stream.ReadWriteStream;
import io.gravitee.gateway.api.stream.exception.TransformationException;
import io.gravitee.policy.api.PolicyChain;
import io.gravitee.policy.api.PolicyResult;
import io.gravitee.policy.api.annotations.OnRequestContent;
import io.gravitee.policy.jws.configuration.JWSPolicyConfiguration;
import io.gravitee.policy.jws.utils.JsonUtils;
import io.gravitee.policy.jws.utils.JwsHeader;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.SignatureException;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.SigningKeyResolverAdapter;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.impl.DefaultClaims;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.Key;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.cert.CRLException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509CRL;
import java.security.cert.X509CRLEntry;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.RSAPublicKeySpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.xml.bind.DatatypeConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment;
import sun.security.x509.CRLDistributionPointsExtension;
import sun.security.x509.DistributionPoint;
import sun.security.x509.GeneralNames;
import sun.security.x509.URIName;
import sun.security.x509.X509CertImpl;

/* loaded from: input_file:io/gravitee/policy/jws/JWSPolicy.class */
public class JWSPolicy {
    private static final String DEFAULT_KID = "default";
    private static final String PUBLIC_KEY_PROPERTY = "policy.jws.kid.%s";
    private static final String PEM_EXTENSION = ".pem";
    private static final String BEGIN_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----";
    private static final String END_PRIVATE_KEY = "-----END PRIVATE KEY-----";
    private static final String BEGIN_RSA_PRIVATE_KEY = "-----BEGIN RSA PRIVATE KEY-----";
    private static final String END_RSA_PRIVATE_KEY = "-----END RSA PRIVATE KEY-----";
    private static final String JSON_CTY = "json";
    private static final String APPLICATION_PREFIX = "application/";
    private JWSPolicyConfiguration jwsPolicyConfiguration;
    private static final Logger LOGGER = LoggerFactory.getLogger(JWSPolicy.class);
    private static final Pattern SSH_PUB_KEY = Pattern.compile("ssh-(rsa|dsa) ([A-Za-z0-9/+]+=*) (.*)");
    private static final String JSON_TYP = "JSON";
    private static final String JOSE_JSON_TYP = "JOSE+JSON";
    private static final String[] AUTHORIZED_TYPES = {JSON_TYP, JOSE_JSON_TYP};

    public JWSPolicy(JWSPolicyConfiguration jWSPolicyConfiguration) {
        this.jwsPolicyConfiguration = jWSPolicyConfiguration;
    }

    @OnRequestContent
    public ReadWriteStream onRequestContent(Request request, ExecutionContext executionContext, PolicyChain policyChain) {
        return TransformableRequestStreamBuilder.on(request).chain(policyChain).contentType("application/json").transform(map(executionContext, policyChain)).build();
    }

    Function<Buffer, Buffer> map(ExecutionContext executionContext, PolicyChain policyChain) {
        return buffer -> {
            try {
                return Buffer.buffer(JsonUtils.writeValueAsString(validateJsonWebToken(buffer.toString(), executionContext)));
            } catch (UnsupportedJwtException | ExpiredJwtException | MalformedJwtException | SignatureException | IllegalArgumentException | CertificateException e) {
                LOGGER.error("Failed to decoding JWS token", e);
                policyChain.streamFailWith(PolicyResult.failure(401, "Unauthorized"));
                return null;
            } catch (Exception e2) {
                LOGGER.error("Error occurs while decoding JWS token", e2);
                throw new TransformationException("Unable to apply JWS decode: " + e2.getMessage(), e2);
            }
        };
    }

    private DefaultClaims validateJsonWebToken(String str, ExecutionContext executionContext) throws CertificateException {
        JwtParser parser = Jwts.parser();
        SigningKeyResolver signingKeyResolverByGatewaySettings = getSigningKeyResolverByGatewaySettings(executionContext);
        parser.setSigningKeyResolver(signingKeyResolverByGatewaySettings);
        Jws parseClaimsJws = parser.parseClaimsJws(str);
        String str2 = (String) parseClaimsJws.getHeader().get(JwsHeader.TYPE);
        if (str2 != null && !str2.isEmpty() && !Arrays.asList(AUTHORIZED_TYPES).contains(str2.toUpperCase())) {
            throw new MalformedJwtException("Only " + AUTHORIZED_TYPES + " JWS typ header are authorized but was " + str2);
        }
        String str3 = (String) parseClaimsJws.getHeader().get(JwsHeader.CONTENT_TYPE);
        if (str3 != null && !str3.isEmpty()) {
            String replaceAll = str3.toLowerCase().replaceAll(APPLICATION_PREFIX, "");
            if (!JSON_CTY.equals(replaceAll)) {
                throw new MalformedJwtException("Only json JWS cty header is authorized but was " + replaceAll);
            }
        }
        List list = (List) parseClaimsJws.getHeader().get(JwsHeader.X509_CERT_CHAIN);
        String[] strArr = (String[]) list.toArray(new String[list.size()]);
        if (strArr == null || strArr.length == 0) {
            throw new MalformedJwtException("X5C JWS Header is missing");
        }
        X509Certificate extractCertificateFromX5CHeader = extractCertificateFromX5CHeader(strArr);
        RSAPublicKey rSAPublicKey = (RSAPublicKey) signingKeyResolverByGatewaySettings.resolveSigningKey(parseClaimsJws.getHeader(), (Claims) parseClaimsJws.getBody());
        RSAPublicKey rSAPublicKey2 = (RSAPublicKey) extractCertificateFromX5CHeader.getPublicKey();
        if (rSAPublicKey2.getPublicExponent().compareTo(rSAPublicKey.getPublicExponent()) != 0) {
            throw new SignatureException("Certificate public key exponent is different compare to the given public key exponent");
        }
        if (rSAPublicKey2.getModulus().compareTo(rSAPublicKey.getModulus()) != 0) {
            throw new SignatureException("Certificate public key modulus is different compare to the given public key modulus");
        }
        if (this.jwsPolicyConfiguration.isCheckCertificateValidity()) {
            extractCertificateFromX5CHeader.checkValidity();
        }
        if (this.jwsPolicyConfiguration.isCheckCertificateRevocation()) {
            validateCRLSFromCertificate(extractCertificateFromX5CHeader);
        }
        return (DefaultClaims) parseClaimsJws.getBody();
    }

    private SigningKeyResolver getSigningKeyResolverByGatewaySettings(final ExecutionContext executionContext) {
        return new SigningKeyResolverAdapter() { // from class: io.gravitee.policy.jws.JWSPolicy.1
            public Key resolveSigningKey(io.jsonwebtoken.JwsHeader jwsHeader, Claims claims) {
                String keyId = jwsHeader.getKeyId();
                if (keyId == null || keyId.isEmpty()) {
                    keyId = JWSPolicy.DEFAULT_KID;
                }
                String property = ((Environment) executionContext.getComponent(Environment.class)).getProperty(String.format(JWSPolicy.PUBLIC_KEY_PROPERTY, keyId));
                if (property == null || property.trim().isEmpty()) {
                    return null;
                }
                if (JWSPolicy.SSH_PUB_KEY.matcher(property).matches()) {
                    return JWSPolicy.parsePublicKey(property);
                }
                if (!property.endsWith(JWSPolicy.PEM_EXTENSION)) {
                    return null;
                }
                try {
                    return JWSPolicy.extractPublicKeyFromPEMFile(property);
                } catch (Exception e) {
                    JWSPolicy.LOGGER.error("Failed to load PEM file", e);
                    return null;
                }
            }
        };
    }

    public void validateCRLSFromCertificate(X509Certificate x509Certificate, BigInteger bigInteger) throws CertificateException {
        X509CRLEntry x509CRLEntry = null;
        CRLDistributionPointsExtension cRLDistributionPointsExtension = ((X509CertImpl) x509Certificate).getCRLDistributionPointsExtension();
        if (cRLDistributionPointsExtension == null) {
            throw new CertificateException("Failed to find CRL distribution points for the given certificate");
        }
        try {
            Iterator it = ((ArrayList) cRLDistributionPointsExtension.get("points")).iterator();
            boolean z = false;
            while (it.hasNext() && x509CRLEntry == null) {
                GeneralNames fullName = ((DistributionPoint) it.next()).getFullName();
                for (int i = 0; i < fullName.size(); i++) {
                    z = false;
                    if (x509CRLEntry != null) {
                        break;
                    }
                    DataInputStream dataInputStream = null;
                    try {
                        try {
                            dataInputStream = new DataInputStream(new URL(fullName.get(i).getName().getURI().toString()).openConnection().getInputStream());
                            x509CRLEntry = ((X509CRL) certificateFactory().generateCRL(dataInputStream)).getRevokedCertificate(bigInteger != null ? bigInteger : x509Certificate.getSerialNumber());
                            if (dataInputStream != null) {
                                dataInputStream.close();
                            }
                        } catch (Throwable th) {
                            if (dataInputStream != null) {
                                dataInputStream.close();
                            }
                            throw th;
                        }
                    } catch (Exception e) {
                        z = true;
                        LOGGER.warn("Failed to get the certificate revocation list, try the next one if any", e);
                        if (dataInputStream != null) {
                            dataInputStream.close();
                        }
                    }
                }
                if (z && !it.hasNext()) {
                    throw new CertificateException("An error has occurred while checking if certificate was revoked");
                }
            }
            if (x509CRLEntry != null) {
                throw new CertificateException("Certificate has been revoked");
            }
        } catch (IOException e2) {
            throw new CertificateException("Failed to get CRL distribution points");
        }
    }

    private void validateCRLSFromCertificate(X509Certificate x509Certificate) throws CertificateException {
        validateCRLSFromCertificate(x509Certificate, null);
    }

    static RSAPublicKey parsePublicKey(String str) {
        Matcher matcher = SSH_PUB_KEY.matcher(str);
        if (!matcher.matches()) {
            return null;
        }
        String group = matcher.group(1);
        String group2 = matcher.group(2);
        if ("rsa".equalsIgnoreCase(group)) {
            return parseSSHPublicKey(group2);
        }
        throw new IllegalArgumentException("Only RSA is currently supported, but algorithm was " + group);
    }

    private static RSAPublicKey parseSSHPublicKey(String str) {
        byte[] bArr = {0, 0, 0, 7, 115, 115, 104, 45, 114, 115, 97};
        ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(Base64.getDecoder().decode(StandardCharsets.UTF_8.encode(str)).array());
        byte[] bArr2 = new byte[11];
        try {
            if (byteArrayInputStream.read(bArr2) == 11 && Arrays.equals(bArr, bArr2)) {
                return createPublicKey(new BigInteger(readBigInteger(byteArrayInputStream)), new BigInteger(readBigInteger(byteArrayInputStream)));
            }
            throw new IllegalArgumentException("SSH key prefix not found");
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    static RSAPublicKey createPublicKey(BigInteger bigInteger, BigInteger bigInteger2) {
        try {
            return (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(bigInteger, bigInteger2));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    static X509Certificate extractCertificateFromX5CHeader(String[] strArr) throws CertificateException {
        return extractCertificate(new ByteArrayInputStream(DatatypeConverter.parseBase64Binary(strArr[0])));
    }

    static PublicKey extractPublicKeyFromPEMFile(String str) throws IOException, CertificateException {
        return extractCertificate(new ByteArrayInputStream(new String(Files.readAllBytes(Paths.get(str, new String[0])), Charset.forName(StandardCharsets.UTF_8.name())).replaceAll(BEGIN_PRIVATE_KEY, "").replaceAll(END_PRIVATE_KEY, "").replaceAll(BEGIN_RSA_PRIVATE_KEY, "").replaceAll(END_RSA_PRIVATE_KEY, "").getBytes(StandardCharsets.UTF_8))).getPublicKey();
    }

    static X509Certificate extractCertificate(InputStream inputStream) throws CertificateException {
        return (X509Certificate) certificateFactory().generateCertificate(inputStream);
    }

    static CertificateFactory certificateFactory() throws CertificateException {
        return CertificateFactory.getInstance("X.509");
    }

    private static byte[] readBigInteger(ByteArrayInputStream byteArrayInputStream) throws IOException {
        byte[] bArr = new byte[4];
        if (byteArrayInputStream.read(bArr) != 4) {
            throw new IOException("Expected length data as 4 bytes");
        }
        int i = (bArr[0] << 24) | (bArr[1] << 16) | (bArr[2] << 8) | bArr[3];
        byte[] bArr2 = new byte[i];
        if (byteArrayInputStream.read(bArr2) != i) {
            throw new IOException("Expected " + i + " key bytes");
        }
        return bArr2;
    }
}
