/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Matches;
import org.apache.lucene.search.MatchesUtils;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.opensearch.neuralsearch.executors.HybridQueryExecutor;
import org.opensearch.neuralsearch.executors.HybridQueryExecutorCollector;
import org.opensearch.neuralsearch.executors.HybridQueryScoreSupplierCollectorManager;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.query.HybridScorerSupplier;

public final class HybridQueryWeight
extends Weight {
    private final List<Weight> weights;
    private final ScoreMode scoreMode;

    public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        super((Query)hybridQuery);
        this.weights = hybridQuery.getSubQueries().stream().map(q -> {
            try {
                return searcher.createWeight(searcher.rewrite(q), scoreMode, boost);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }).collect(Collectors.toList());
        this.scoreMode = scoreMode;
    }

    public Matches matches(LeafReaderContext context, int doc) throws IOException {
        List mis = this.weights.stream().map(weight -> {
            try {
                return weight.matches(context, doc);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }).filter(Objects::nonNull).collect(Collectors.toList());
        return MatchesUtils.fromSubMatches(mis);
    }

    public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
        HybridQueryScoreSupplierCollectorManager manager = new HybridQueryScoreSupplierCollectorManager(context);
        ArrayList<Callable<Void>> scoreSupplierTasks = new ArrayList<Callable<Void>>();
        ArrayList<HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier>> collectors = new ArrayList<HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier>>();
        for (Weight weight : this.weights) {
            HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier> collector = manager.newCollector();
            collectors.add(collector);
            scoreSupplierTasks.add(() -> this.addScoreSupplier(weight, collector));
        }
        HybridQueryExecutor.getExecutor().invokeAll(scoreSupplierTasks);
        List<ScorerSupplier> scorerSuppliers = manager.mergeScoreSuppliers(collectors);
        if (scorerSuppliers.isEmpty() || scorerSuppliers.stream().allMatch(Objects::isNull)) {
            return null;
        }
        return new HybridScorerSupplier(scorerSuppliers, this, this.scoreMode, context);
    }

    private Void addScoreSupplier(Weight weight, HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier> collector) {
        collector.collect(leafReaderContext -> {
            try {
                return weight.scorerSupplier(leafReaderContext);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        });
        return null;
    }

    public boolean isCacheable(LeafReaderContext ctx) {
        if (this.weights.size() > 5) {
            return false;
        }
        return this.weights.stream().allMatch(w -> w.isCacheable(ctx));
    }

    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        boolean match = false;
        double max = 0.0;
        ArrayList<Explanation> subsOnNoMatch = new ArrayList<Explanation>();
        ArrayList<Explanation> subsOnMatch = new ArrayList<Explanation>();
        for (Weight wt : this.weights) {
            Explanation e = wt.explain(context, doc);
            if (e.isMatch()) {
                match = true;
                double score = e.getValue().doubleValue();
                max = Math.max(max, score);
                subsOnMatch.add(e);
                continue;
            }
            if (!match) {
                subsOnNoMatch.add(e);
            }
            subsOnMatch.add(e);
        }
        if (match) {
            String desc = "combined score of:";
            return Explanation.match((Number)max, (String)"combined score of:", subsOnMatch);
        }
        return Explanation.noMatch((String)"no matching clause", subsOnNoMatch);
    }

    @Generated
    List<Weight> getWeights() {
        return this.weights;
    }
}

