package org.springframework.ai.vectorstore.pgvector;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.pgvector.PGvector;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.util.JacksonUtils;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/* loaded from: input_file:org/springframework/ai/vectorstore/pgvector/PgVectorStore.class */
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {
    public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;
    public static final int INVALID_EMBEDDING_DIMENSION = -1;
    public static final String DEFAULT_TABLE_NAME = "vector_store";
    public static final String DEFAULT_VECTOR_INDEX_NAME = "spring_ai_vector_index";
    public static final String DEFAULT_SCHEMA_NAME = "public";
    public static final boolean DEFAULT_SCHEMA_VALIDATION = false;
    public static final int MAX_DOCUMENT_BATCH_SIZE = 10000;
    public final FilterExpressionConverter filterExpressionConverter;
    private final String vectorTableName;
    private final String vectorIndexName;
    private final JdbcTemplate jdbcTemplate;
    private final String schemaName;
    private final PgIdType idType;
    private final boolean schemaValidation;
    private final boolean initializeSchema;
    private final int dimensions;
    private final PgDistanceType distanceType;
    private final ObjectMapper objectMapper;
    private final boolean removeExistingVectorStoreTable;
    private final PgIndexType createIndexMethod;
    private final PgVectorSchemaValidator schemaValidator;
    private final int maxDocumentBatchSize;
    public static final PgIdType DEFAULT_ID_TYPE = PgIdType.UUID;
    private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class);
    private static Map<PgDistanceType, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(PgDistanceType.COSINE_DISTANCE, VectorStoreSimilarityMetric.COSINE, PgDistanceType.EUCLIDEAN_DISTANCE, VectorStoreSimilarityMetric.EUCLIDEAN, PgDistanceType.NEGATIVE_INNER_PRODUCT, VectorStoreSimilarityMetric.DOT);

    /* loaded from: input_file:org/springframework/ai/vectorstore/pgvector/PgVectorStore$DocumentRowMapper.class */
    private static class DocumentRowMapper implements RowMapper<Document> {
        private static final String COLUMN_EMBEDDING = "embedding";
        private static final String COLUMN_METADATA = "metadata";
        private static final String COLUMN_ID = "id";
        private static final String COLUMN_CONTENT = "content";
        private static final String COLUMN_DISTANCE = "distance";
        private final ObjectMapper objectMapper;

        DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        /* renamed from: mapRow, reason: merged with bridge method [inline-methods] */
        public Document m5mapRow(ResultSet resultSet, int i) throws SQLException {
            String string = resultSet.getString(COLUMN_ID);
            String string2 = resultSet.getString(COLUMN_CONTENT);
            PGobject pGobject = (PGobject) resultSet.getObject(COLUMN_METADATA, PGobject.class);
            Float valueOf = Float.valueOf(resultSet.getFloat(COLUMN_DISTANCE));
            Map<String, Object> map = toMap(pGobject);
            map.put(DocumentMetadata.DISTANCE.value(), valueOf);
            return Document.builder().id(string).text(string2).metadata(map).score(Double.valueOf(1.0d - valueOf.floatValue())).build();
        }

        private Map<String, Object> toMap(PGobject pGobject) {
            try {
                return (Map) this.objectMapper.readValue(pGobject.getValue(), Map.class);
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        }
    }

    /* loaded from: input_file:org/springframework/ai/vectorstore/pgvector/PgVectorStore$PgDistanceType.class */
    public enum PgDistanceType {
        EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
        NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
        COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");

        public final String operator;
        public final String index;
        public final String similaritySearchSqlTemplate;

        PgDistanceType(String str, String str2, String str3) {
            this.operator = str;
            this.index = str2;
            this.similaritySearchSqlTemplate = str3;
        }
    }

    /* loaded from: input_file:org/springframework/ai/vectorstore/pgvector/PgVectorStore$PgIdType.class */
    public enum PgIdType {
        UUID,
        TEXT,
        INTEGER,
        SERIAL,
        BIGSERIAL
    }

    /* loaded from: input_file:org/springframework/ai/vectorstore/pgvector/PgVectorStore$PgIndexType.class */
    public enum PgIndexType {
        NONE,
        IVFFLAT,
        HNSW
    }

    /* loaded from: input_file:org/springframework/ai/vectorstore/pgvector/PgVectorStore$PgVectorStoreBuilder.class */
    public static final class PgVectorStoreBuilder extends AbstractVectorStoreBuilder<PgVectorStoreBuilder> {
        private final JdbcTemplate jdbcTemplate;
        private String schemaName;
        private String vectorTableName;
        private PgIdType idType;
        private boolean vectorTableValidationsEnabled;
        private int dimensions;
        private PgDistanceType distanceType;
        private boolean removeExistingVectorStoreTable;
        private PgIndexType indexType;
        private boolean initializeSchema;
        private int maxDocumentBatchSize;

        private PgVectorStoreBuilder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            this.schemaName = PgVectorStore.DEFAULT_SCHEMA_NAME;
            this.vectorTableName = PgVectorStore.DEFAULT_TABLE_NAME;
            this.idType = PgVectorStore.DEFAULT_ID_TYPE;
            this.vectorTableValidationsEnabled = false;
            this.dimensions = -1;
            this.distanceType = PgDistanceType.COSINE_DISTANCE;
            this.removeExistingVectorStoreTable = false;
            this.indexType = PgIndexType.HNSW;
            this.maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE;
            Assert.notNull(jdbcTemplate, "JdbcTemplate must not be null");
            this.jdbcTemplate = jdbcTemplate;
        }

        public PgVectorStoreBuilder schemaName(String str) {
            this.schemaName = str;
            return this;
        }

        public PgVectorStoreBuilder vectorTableName(String str) {
            this.vectorTableName = str;
            return this;
        }

        public PgVectorStoreBuilder idType(PgIdType pgIdType) {
            this.idType = pgIdType;
            return this;
        }

        public PgVectorStoreBuilder vectorTableValidationsEnabled(boolean z) {
            this.vectorTableValidationsEnabled = z;
            return this;
        }

        public PgVectorStoreBuilder dimensions(int i) {
            this.dimensions = i;
            return this;
        }

        public PgVectorStoreBuilder distanceType(PgDistanceType pgDistanceType) {
            this.distanceType = pgDistanceType;
            return this;
        }

        public PgVectorStoreBuilder removeExistingVectorStoreTable(boolean z) {
            this.removeExistingVectorStoreTable = z;
            return this;
        }

        public PgVectorStoreBuilder indexType(PgIndexType pgIndexType) {
            this.indexType = pgIndexType;
            return this;
        }

        public PgVectorStoreBuilder initializeSchema(boolean z) {
            this.initializeSchema = z;
            return this;
        }

        public PgVectorStoreBuilder maxDocumentBatchSize(int i) {
            this.maxDocumentBatchSize = i;
            return this;
        }

        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public PgVectorStore m9build() {
            return new PgVectorStore(this);
        }
    }

    protected PgVectorStore(PgVectorStoreBuilder pgVectorStoreBuilder) {
        super(pgVectorStoreBuilder);
        this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
        Assert.notNull(pgVectorStoreBuilder.jdbcTemplate, "JdbcTemplate must not be null");
        this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
        String str = pgVectorStoreBuilder.vectorTableName;
        this.vectorTableName = str.isEmpty() ? DEFAULT_TABLE_NAME : str.trim();
        logger.info("Using the vector table name: {}. Is empty: {}", this.vectorTableName, Boolean.valueOf(this.vectorTableName.isEmpty()));
        this.vectorIndexName = this.vectorTableName.equals(DEFAULT_TABLE_NAME) ? DEFAULT_VECTOR_INDEX_NAME : this.vectorTableName + "_index";
        this.schemaName = pgVectorStoreBuilder.schemaName;
        this.idType = pgVectorStoreBuilder.idType;
        this.schemaValidation = pgVectorStoreBuilder.vectorTableValidationsEnabled;
        this.jdbcTemplate = pgVectorStoreBuilder.jdbcTemplate;
        this.dimensions = pgVectorStoreBuilder.dimensions;
        this.distanceType = pgVectorStoreBuilder.distanceType;
        this.removeExistingVectorStoreTable = pgVectorStoreBuilder.removeExistingVectorStoreTable;
        this.createIndexMethod = pgVectorStoreBuilder.indexType;
        this.initializeSchema = pgVectorStoreBuilder.initializeSchema;
        this.schemaValidator = new PgVectorSchemaValidator(this.jdbcTemplate);
        this.maxDocumentBatchSize = pgVectorStoreBuilder.maxDocumentBatchSize;
    }

    public PgDistanceType getDistanceType() {
        return this.distanceType;
    }

    public static PgVectorStoreBuilder builder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        return new PgVectorStoreBuilder(jdbcTemplate, embeddingModel);
    }

    public void doAdd(List<Document> list) {
        List embed = this.embeddingModel.embed(list, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        batchDocuments(list).forEach(list2 -> {
            insertOrUpdateBatch(list2, list, embed);
        });
    }

    private List<List<Document>> batchDocuments(List<Document> list) {
        ArrayList arrayList = new ArrayList();
        int i = 0;
        while (true) {
            int i2 = i;
            if (i2 >= list.size()) {
                return arrayList;
            }
            arrayList.add(list.subList(i2, Math.min(i2 + this.maxDocumentBatchSize, list.size())));
            i = i2 + this.maxDocumentBatchSize;
        }
    }

    private void insertOrUpdateBatch(final List<Document> list, final List<Document> list2, final List<float[]> list3) {
        this.jdbcTemplate.batchUpdate("INSERT INTO " + getFullyQualifiedTableName() + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() { // from class: org.springframework.ai.vectorstore.pgvector.PgVectorStore.1
            public void setValues(PreparedStatement preparedStatement, int i) throws SQLException {
                Document document = (Document) list.get(i);
                Object convertIdToPgType = PgVectorStore.this.convertIdToPgType(document.getId());
                String text = document.getText();
                String json = PgVectorStore.this.toJson(document.getMetadata());
                PGvector pGvector = new PGvector((float[]) list3.get(list2.indexOf(document)));
                StatementCreatorUtils.setParameterValue(preparedStatement, 1, Integer.MIN_VALUE, convertIdToPgType);
                StatementCreatorUtils.setParameterValue(preparedStatement, 2, Integer.MIN_VALUE, text);
                StatementCreatorUtils.setParameterValue(preparedStatement, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(preparedStatement, 4, Integer.MIN_VALUE, pGvector);
                StatementCreatorUtils.setParameterValue(preparedStatement, 5, Integer.MIN_VALUE, text);
                StatementCreatorUtils.setParameterValue(preparedStatement, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(preparedStatement, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return list.size();
            }
        });
    }

    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        } catch (JsonProcessingException e) {
            throw new RuntimeException((Throwable) e);
        }
    }

    private Object convertIdToPgType(String str) {
        switch (getIdType()) {
            case UUID:
                return UUID.fromString(str);
            case TEXT:
                return str;
            case INTEGER:
            case SERIAL:
                return Integer.valueOf(str);
            case BIGSERIAL:
                return Long.valueOf(str);
            default:
                throw new IncompatibleClassChangeError();
        }
    }

    public void doDelete(List<String> list) {
        int i = 0;
        Iterator<String> it = list.iterator();
        while (it.hasNext()) {
            i += this.jdbcTemplate.update("DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?", new Object[]{UUID.fromString(it.next())});
        }
    }

    protected void doDelete(Filter.Expression expression) {
        try {
            this.jdbcTemplate.update("DELETE FROM " + getFullyQualifiedTableName() + " WHERE metadata::jsonb @@ '" + this.filterExpressionConverter.convertExpression(expression) + "'::jsonpath");
        } catch (Exception e) {
            throw new IllegalStateException("Failed to delete documents by filter", e);
        }
    }

    public List<Document> doSimilaritySearch(SearchRequest searchRequest) {
        String convertExpression = searchRequest.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(searchRequest.getFilterExpression()) : "";
        String str = StringUtils.hasText(convertExpression) ? " AND metadata::jsonb @@ '" + convertExpression + "'::jsonpath " : "";
        double similarityThreshold = 1.0d - searchRequest.getSimilarityThreshold();
        PGvector queryEmbedding = getQueryEmbedding(searchRequest.getQuery());
        return this.jdbcTemplate.query(String.format(getDistanceType().similaritySearchSqlTemplate, getFullyQualifiedTableName(), str), new DocumentRowMapper(this.objectMapper), new Object[]{queryEmbedding, queryEmbedding, Double.valueOf(similarityThreshold), Integer.valueOf(searchRequest.getTopK())});
    }

    public List<Double> embeddingDistance(String str) {
        return this.jdbcTemplate.query("SELECT embedding " + comparisonOperator() + " ? AS distance FROM " + getFullyQualifiedTableName(), new RowMapper<Double>() { // from class: org.springframework.ai.vectorstore.pgvector.PgVectorStore.2
            /* renamed from: mapRow, reason: merged with bridge method [inline-methods] */
            public Double m3mapRow(ResultSet resultSet, int i) throws SQLException {
                return Double.valueOf(resultSet.getDouble("distance"));
            }
        }, new Object[]{getQueryEmbedding(str)});
    }

    private PGvector getQueryEmbedding(String str) {
        return new PGvector(this.embeddingModel.embed(str));
    }

    private String comparisonOperator() {
        return getDistanceType().operator;
    }

    public void afterPropertiesSet() {
        logger.info("Initializing PGVectorStore schema for table: {} in schema: {}", getVectorTableName(), getSchemaName());
        logger.info("vectorTableValidationsEnabled {}", Boolean.valueOf(this.schemaValidation));
        if (this.schemaValidation) {
            this.schemaValidator.validateTableSchema(getSchemaName(), getVectorTableName());
        }
        if (!this.initializeSchema) {
            logger.debug("Skipping the schema initialization for the table: {}", getFullyQualifiedTableName());
            return;
        }
        this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
        this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
        this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
        this.jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", getSchemaName()));
        if (this.removeExistingVectorStoreTable) {
            this.jdbcTemplate.execute(String.format("DROP TABLE IF EXISTS %s", getFullyQualifiedTableName()));
        }
        this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", getFullyQualifiedTableName(), Integer.valueOf(embeddingDimensions())));
        if (this.createIndexMethod != PgIndexType.NONE) {
            this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", getVectorIndexName(), getFullyQualifiedTableName(), this.createIndexMethod, getDistanceType().index));
        }
    }

    private String getFullyQualifiedTableName() {
        return this.schemaName + "." + this.vectorTableName;
    }

    private PgIdType getIdType() {
        return this.idType;
    }

    private String getVectorTableName() {
        return this.vectorTableName;
    }

    private String getSchemaName() {
        return this.schemaName;
    }

    private String getVectorIndexName() {
        return this.vectorIndexName;
    }

    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        }
        try {
            int dimensions = this.embeddingModel.dimensions();
            return dimensions > 0 ? dimensions : OPENAI_EMBEDDING_DIMENSION_SIZE;
        } catch (Exception e) {
            logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", e);
            return OPENAI_EMBEDDING_DIMENSION_SIZE;
        }
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String str) {
        return VectorStoreObservationContext.builder(VectorStoreProvider.PG_VECTOR.value(), str).collectionName(this.vectorTableName).dimensions(Integer.valueOf(embeddingDimensions())).namespace(this.schemaName).similarityMetric(getSimilarityMetric());
    }

    private String getSimilarityMetric() {
        return !SIMILARITY_TYPE_MAPPING.containsKey(getDistanceType()) ? getDistanceType().name() : SIMILARITY_TYPE_MAPPING.get(this.distanceType).value();
    }

    public <T> Optional<T> getNativeClient() {
        return Optional.of(this.jdbcTemplate);
    }
}
