package org.springframework.ai.chat.client.advisor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import reactor.core.scheduler.Scheduler;

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.core.task.TaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;

/* loaded from: input_file:org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.class */
/**
 * Advisor that wraps a chat request in a retrieval-augmented generation (RAG) flow.
 *
 * <p>Before the model call ({@link #before(AdvisedRequest)}):
 * <ol>
 * <li>the user text is rendered with the user parameters into a {@link Query} that also carries
 * the conversation history;</li>
 * <li>the query is passed through the configured {@link QueryTransformer}s, in order;</li>
 * <li>the transformed query is optionally expanded into several queries by a {@link QueryExpander};</li>
 * <li>documents are retrieved for each distinct query asynchronously on the configured
 * {@link TaskExecutor};</li>
 * <li>the per-query results are combined by the {@link DocumentJoiner};</li>
 * <li>the {@link QueryAugmenter} rewrites the user text from the <em>original</em> (untransformed)
 * query and the joined documents.</li>
 * </ol>
 * The joined documents are also stored in the advise context under {@link #DOCUMENT_CONTEXT}.
 *
 * <p>After the model call ({@link #after(AdvisedResponse)}), the value stored in the advise context
 * under {@link #DOCUMENT_CONTEXT} is copied into the chat response metadata under the same key.
 */
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

    /** Key used for the retrieved documents in both the advise context and the response metadata. */
    public static final String DOCUMENT_CONTEXT = "rag_document_context";

    private final List<QueryTransformer> queryTransformers;

    @Nullable
    private final QueryExpander queryExpander;

    private final DocumentRetriever documentRetriever;

    private final DocumentJoiner documentJoiner;

    private final QueryAugmenter queryAugmenter;

    private final TaskExecutor taskExecutor;

    private final Scheduler scheduler;

    private final int order;

    /**
     * Fluent builder for {@link RetrievalAugmentationAdvisor}. Every component except the
     * {@link DocumentRetriever} is optional; unset components fall back to the defaults described
     * on the advisor's constructor.
     */
    public static final class Builder {

        private List<QueryTransformer> queryTransformers;

        private QueryExpander queryExpander;

        private DocumentRetriever documentRetriever;

        private DocumentJoiner documentJoiner;

        private QueryAugmenter queryAugmenter;

        private TaskExecutor taskExecutor;

        private Scheduler scheduler;

        private Integer order;

        private Builder() {
        }

        /** Sets the transformers applied, in list order, to the query before expansion. */
        public Builder queryTransformers(List<QueryTransformer> list) {
            this.queryTransformers = list;
            return this;
        }

        /** Varargs variant of {@link #queryTransformers(List)}. */
        public Builder queryTransformers(QueryTransformer... queryTransformerArr) {
            // Arrays.asList tolerates null elements so the constructor's assertion reports them.
            this.queryTransformers = Arrays.asList(queryTransformerArr);
            return this;
        }

        /** Sets the optional expander that turns the transformed query into several queries. */
        public Builder queryExpander(QueryExpander queryExpander) {
            this.queryExpander = queryExpander;
            return this;
        }

        /** Sets the retriever used to fetch documents for each query (required). */
        public Builder documentRetriever(DocumentRetriever documentRetriever) {
            this.documentRetriever = documentRetriever;
            return this;
        }

        /** Sets the joiner that merges the per-query document results. */
        public Builder documentJoiner(DocumentJoiner documentJoiner) {
            this.documentJoiner = documentJoiner;
            return this;
        }

        /** Sets the augmenter that rewrites the user text using the retrieved documents. */
        public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
            this.queryAugmenter = queryAugmenter;
            return this;
        }

        /** Sets the executor on which per-query retrievals run. */
        public Builder taskExecutor(TaskExecutor taskExecutor) {
            this.taskExecutor = taskExecutor;
            return this;
        }

        /** Sets the scheduler returned by {@link RetrievalAugmentationAdvisor#getScheduler()}. */
        public Builder scheduler(Scheduler scheduler) {
            this.scheduler = scheduler;
            return this;
        }

        /** Sets the advisor order returned by {@link RetrievalAugmentationAdvisor#getOrder()}. */
        public Builder order(Integer num) {
            this.order = num;
            return this;
        }

        /**
         * Builds the advisor.
         * @throws IllegalArgumentException if no document retriever was set or a query transformer
         * is {@code null}
         */
        public RetrievalAugmentationAdvisor build() {
            return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander,
                    this.documentRetriever, this.documentJoiner, this.queryAugmenter, this.taskExecutor,
                    this.scheduler, this.order);
        }

    }

    /**
     * Creates the advisor.
     * @param list query transformers applied in order; {@code null} means none. The list is copied.
     * @param queryExpander optional expander; when {@code null} only the transformed query is used
     * @param documentRetriever retriever used for every query; must not be {@code null}
     * @param documentJoiner joiner for per-query results; defaults to {@link ConcatenationDocumentJoiner}
     * @param queryAugmenter augmenter; defaults to {@code ContextualQueryAugmenter.builder().build()}
     * @param taskExecutor executor for retrievals; defaults to a dedicated thread pool
     * @param scheduler scheduler; defaults to {@link BaseAdvisor#DEFAULT_SCHEDULER}
     * @param num advisor order; defaults to {@code 0}
     * @throws IllegalArgumentException if {@code documentRetriever} is {@code null} or {@code list}
     * contains a {@code null} element
     */
    public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> list, @Nullable QueryExpander queryExpander,
            DocumentRetriever documentRetriever, @Nullable DocumentJoiner documentJoiner,
            @Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor,
            @Nullable Scheduler scheduler, @Nullable Integer num) {
        Assert.notNull(documentRetriever, "documentRetriever cannot be null");
        Assert.noNullElements(list, "queryTransformers cannot contain null elements");
        // Defensive, immutable copy: later changes to the caller's list must not affect this advisor.
        this.queryTransformers = list != null ? List.copyOf(list) : List.of();
        this.queryExpander = queryExpander;
        this.documentRetriever = documentRetriever;
        this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
        this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
        this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
        this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
        this.order = num != null ? num : 0;
    }

    /** Returns a new, empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Retrieves documents for the user's query and rewrites the user text with them.
     * @param advisedRequest the incoming request
     * @return a copy of the request whose user text has been augmented and whose advise context
     * additionally holds the joined documents under {@link #DOCUMENT_CONTEXT}
     */
    @Override
    public AdvisedRequest before(AdvisedRequest advisedRequest) {
        Map<String, Object> context = new HashMap<>(advisedRequest.adviseContext());

        // 0. Build the query from the rendered user text and the conversation history.
        Query originalQuery = Query.builder()
            .text(new PromptTemplate(advisedRequest.userText(), advisedRequest.userParams()).render())
            .history(advisedRequest.messages())
            .build();

        // 1. Apply the transformer chain in order.
        Query transformedQuery = originalQuery;
        for (QueryTransformer queryTransformer : this.queryTransformers) {
            transformedQuery = queryTransformer.apply(transformedQuery);
        }

        // 2. Optionally expand into several queries.
        List<Query> expandedQueries = this.queryExpander != null ? this.queryExpander.expand(transformedQuery)
                : List.of(transformedQuery);

        // 3. Retrieve and 4. join documents.
        List<Document> documents = this.documentJoiner.join(retrieveDocuments(expandedQueries));
        context.put(DOCUMENT_CONTEXT, documents);

        // 5. Augment. The augmenter deliberately sees the original query, not the transformed one.
        Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents);

        return AdvisedRequest.from(advisedRequest).userText(augmentedQuery.text()).adviseContext(context).build();
    }

    /**
     * Runs one asynchronous retrieval per distinct query on {@link #taskExecutor} and waits for all
     * of them.
     *
     * <p>Duplicate queries are retrieved only once. Without this, collecting results keyed by query
     * would fail with an {@link IllegalStateException} when an expander returns equal queries.
     * Results are kept in first-occurrence order of the queries so that joining is deterministic.
     * @param queries the queries to retrieve documents for
     * @return an insertion-ordered map from each distinct query to a single-element list holding its
     * retrieved documents
     */
    private Map<Query, List<List<Document>>> retrieveDocuments(List<Query> queries) {
        // Submit every retrieval before joining any, so they run concurrently.
        List<CompletableFuture<Map.Entry<Query, List<Document>>>> futures = queries.stream()
            .distinct()
            .map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), this.taskExecutor))
            .toList();

        Map<Query, List<List<Document>>> documentsForQuery = new LinkedHashMap<>();
        for (CompletableFuture<Map.Entry<Query, List<Document>>> future : futures) {
            Map.Entry<Query, List<Document>> entry = future.join();
            documentsForQuery.put(entry.getKey(), List.of(entry.getValue()));
        }
        return documentsForQuery;
    }

    /** Retrieves documents for a single query, paired with that query. */
    private Map.Entry<Query, List<Document>> getDocumentsForQuery(Query query) {
        return Map.entry(query, this.documentRetriever.retrieve(query));
    }

    /**
     * Copies the documents recorded in the advise context into the chat response metadata.
     * @param advisedResponse the response to enrich; its chat response may be {@code null}
     * @return a new advised response with the same advise context and a chat response whose
     * metadata contains {@link #DOCUMENT_CONTEXT}
     */
    @Override
    public AdvisedResponse after(AdvisedResponse advisedResponse) {
        ChatResponse.Builder chatResponseBuilder = advisedResponse.response() == null ? ChatResponse.builder()
                : ChatResponse.builder().from(advisedResponse.response());
        chatResponseBuilder.metadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
        return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
    }

    /** Returns the scheduler configured at construction time. */
    @Override
    public Scheduler getScheduler() {
        return this.scheduler;
    }

    /** Returns the advisor order configured at construction time (default {@code 0}). */
    @Override
    public int getOrder() {
        return this.order;
    }

    /**
     * Creates the executor used for retrievals when none is supplied. It is a
     * {@link ThreadPoolTaskExecutor} with thread name prefix {@code ai-advisor-}, core size 4,
     * max size 16, and a {@link ContextPropagatingTaskDecorator}.
     *
     * <p>NOTE(review): every advisor built without an explicit executor gets its own pool, and this
     * class never shuts it down. Also, with the executor's default queue capacity (unbounded) the
     * pool is not expected to grow beyond its core size, so the max size of 16 may have no effect.
     * Confirm whether either is intended.
     */
    private static TaskExecutor buildDefaultTaskExecutor() {
        ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
        threadPoolTaskExecutor.setThreadNamePrefix("ai-advisor-");
        threadPoolTaskExecutor.setCorePoolSize(4);
        threadPoolTaskExecutor.setMaxPoolSize(16);
        threadPoolTaskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
        threadPoolTaskExecutor.initialize();
        return threadPoolTaskExecutor;
    }

}
