/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
 * {@link FlatVectorsScorer} for scalar-quantized vectors.
 *
 * <p>When the supplied vector values are {@link RandomAccessQuantizedByteVectorValues}, scoring
 * is done directly on the quantized bytes. Any other vector values are forwarded to the
 * non-quantized delegate given to the constructor. {@code byte[]} queries are always forwarded
 * to the delegate.
 */
public class Lucene99ScalarQuantizedVectorScorer
implements FlatVectorsScorer {
    private final FlatVectorsScorer nonQuantizedDelegate;

    /**
     * @param flatVectorsScorer scorer used for vector values that are not scalar-quantized and
     *     for all {@code byte[]} queries
     */
    public Lucene99ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) {
        this.nonQuantizedDelegate = flatVectorsScorer;
    }

    /**
     * Returns a supplier that scores quantized vectors against each other.
     *
     * <p>Falls back to the delegate when {@code vectorValues} is not quantized.
     */
    @Override
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
            VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
            throws IOException {
        if (vectorValues instanceof RandomAccessQuantizedByteVectorValues) {
            return new ScalarQuantizedRandomVectorScorerSupplier(
                    (RandomAccessQuantizedByteVectorValues) vectorValues, similarityFunction);
        }
        return this.nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
    }

    /**
     * Returns a scorer for a float query.
     *
     * <p>When {@code vectorValues} is quantized, the query is quantized first with the stored
     * vectors' {@link ScalarQuantizer}. Otherwise this delegates.
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues, float[] target)
            throws IOException {
        if (vectorValues instanceof RandomAccessQuantizedByteVectorValues) {
            RandomAccessQuantizedByteVectorValues quantizedByteVectorValues =
                    (RandomAccessQuantizedByteVectorValues) vectorValues;
            ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
            // The quantized query has one byte per dimension; quantizeQuery fills it in and
            // returns the query-side score correction.
            byte[] targetBytes = new byte[target.length];
            float offsetCorrection = ScalarQuantizedVectorScorer.quantizeQuery(
                    target, targetBytes, similarityFunction, scalarQuantizer);
            return fromVectorSimilarity(targetBytes, offsetCorrection, similarityFunction,
                    scalarQuantizer.getConstantMultiplier(), quantizedByteVectorValues);
        }
        return this.nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    /** {@code byte[]} queries are never scored here; they always go to the delegate. */
    @Override
    public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues, byte[] target)
            throws IOException {
        return this.nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    @Override
    public String toString() {
        return "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + this.nonQuantizedDelegate + ")";
    }

    /**
     * Builds the quantized scorer for {@code sim}.
     *
     * <ul>
     *   <li>{@code EUCLIDEAN}: {@code 1 / (1 + squareDistance * constMultiplier)}.
     *   <li>{@code COSINE} and {@code DOT_PRODUCT}: the corrected dot product {@code d}, mapped
     *       to {@code max((1 + d) / 2, 0)}.
     *   <li>{@code MAXIMUM_INNER_PRODUCT}: the corrected dot product, mapped through
     *       {@link VectorUtil#scaleMaxInnerProductScore}.
     * </ul>
     *
     * @param targetBytes quantized query; it is retained (not copied) by the returned scorer
     * @param offsetCorrection query-side score correction
     * @param sim similarity to score with
     * @param constMultiplier quantizer constant that scales raw integer distances
     * @param values quantized vectors to score against
     * @throws IllegalArgumentException if {@code sim} is not handled
     */
    static RandomVectorScorer fromVectorSimilarity(byte[] targetBytes, float offsetCorrection,
            VectorSimilarityFunction sim, float constMultiplier, RandomAccessQuantizedByteVectorValues values) {
        switch (sim) {
            case EUCLIDEAN: {
                return new Euclidean(values, constMultiplier, targetBytes);
            }
            case COSINE:
            case DOT_PRODUCT: {
                return dotProductFactory(targetBytes, offsetCorrection, constMultiplier, values,
                        f -> Math.max((1.0f + f) / 2.0f, 0.0f));
            }
            case MAXIMUM_INNER_PRODUCT: {
                return dotProductFactory(targetBytes, offsetCorrection, constMultiplier, values,
                        VectorUtil::scaleMaxInnerProductScore);
            }
        }
        throw new IllegalArgumentException("Unsupported similarity function: " + sim);
    }

    /**
     * Chooses the dot-product implementation for the quantizer's bit width.
     *
     * <ul>
     *   <li>More than 4 bits: plain byte dot product.
     *   <li>At most 4 bits, with stored byte length different from the dimension and a slice
     *       available: score straight from the slice using
     *       {@link VectorUtil#int4DotProductPacked}. This presumably means two values are packed
     *       per byte.
     *   <li>Otherwise: {@link VectorUtil#int4DotProduct} on the unpacked bytes.
     * </ul>
     */
    private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory(byte[] targetBytes,
            float offsetCorrection, float constMultiplier, RandomAccessQuantizedByteVectorValues values,
            FloatToFloatFunction scoreAdjustmentFunction) {
        if (values.getScalarQuantizer().getBits() <= 4) {
            if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) {
                return new CompressedInt4DotProduct(values, constMultiplier, targetBytes,
                        offsetCorrection, scoreAdjustmentFunction);
            }
            return new Int4DotProduct(values, constMultiplier, targetBytes, offsetCorrection,
                    scoreAdjustmentFunction);
        }
        return new DotProduct(values, constMultiplier, targetBytes, offsetCorrection,
                scoreAdjustmentFunction);
    }

    /**
     * Supplies scorers whose query is one of the stored quantized vectors.
     *
     * <p>Queries are read through {@code values1} and targets through {@code values2}. The two
     * are separate copies so that reading targets cannot disturb the query read.
     *
     * <p>Every scorer from one supplier reads targets through the same {@code values2} instance.
     * Such scorers are therefore presumably not safe for concurrent use; use {@link #copy()} to
     * get an independent supplier.
     */
    private static final class ScalarQuantizedRandomVectorScorerSupplier
    implements RandomVectorScorerSupplier {
        private final VectorSimilarityFunction vectorSimilarityFunction;
        private final RandomAccessQuantizedByteVectorValues values;
        private final RandomAccessQuantizedByteVectorValues values1;
        private final RandomAccessQuantizedByteVectorValues values2;

        public ScalarQuantizedRandomVectorScorerSupplier(RandomAccessQuantizedByteVectorValues values,
                VectorSimilarityFunction vectorSimilarityFunction) throws IOException {
            this.values = values;
            this.values1 = values.copy();
            this.values2 = values.copy();
            this.vectorSimilarityFunction = vectorSimilarityFunction;
        }

        /** Returns a scorer whose query is the stored vector at {@code ord}. */
        @Override
        public RandomVectorScorer scorer(int ord) throws IOException {
            // Clone the query bytes. vectorValue may return a buffer that values1 reuses on its
            // next read, and the scorer keeps a reference to its query. Without the copy, a later
            // scorer(...) call would silently replace the query of every scorer already handed out.
            byte[] queryVector = this.values1.vectorValue(ord).clone();
            float queryOffset = this.values1.getScoreCorrectionConstant(ord);
            return fromVectorSimilarity(queryVector, queryOffset, this.vectorSimilarityFunction,
                    this.values.getScalarQuantizer().getConstantMultiplier(), this.values2);
        }

        @Override
        public ScalarQuantizedRandomVectorScorerSupplier copy() throws IOException {
            return new ScalarQuantizedRandomVectorScorerSupplier(this.values.copy(), this.vectorSimilarityFunction);
        }

        @Override
        public String toString() {
            return "ScalarQuantizedRandomVectorScorerSupplier(vectorSimilarityFunction="
                    + this.vectorSimilarityFunction + ")";
        }
    }

    /** Primitive {@code float -> float} mapping, used to avoid boxing. */
    @FunctionalInterface
    private static interface FloatToFloatFunction {
        public float apply(float value);
    }

    /**
     * Dot product over unpacked vectors of at most 4 bits.
     *
     * <p>The score is {@code adjust(dot * constMultiplier + queryOffset + vectorOffset)}.
     */
    private static final class Int4DotProduct
    extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final float constMultiplier;
        private final RandomAccessQuantizedByteVectorValues values;
        private final byte[] targetBytes;
        private final float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        public Int4DotProduct(RandomAccessQuantizedByteVectorValues values, float constMultiplier,
                byte[] targetBytes, float offsetCorrection, FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int vectorOrdinal) throws IOException {
            byte[] storedVector = this.values.vectorValue(vectorOrdinal);
            float vectorOffset = this.values.getScoreCorrectionConstant(vectorOrdinal);
            int dotProduct = VectorUtil.int4DotProduct(storedVector, this.targetBytes);
            assert (dotProduct >= 0);
            float adjustedDistance = (float) dotProduct * this.constMultiplier + this.offsetCorrection + vectorOffset;
            return this.scoreAdjustmentFunction.apply(adjustedDistance);
        }
    }

    /**
     * Dot product that reads packed 4-bit vectors directly from the values' slice.
     *
     * <p>Uses {@link VectorUtil#int4DotProductPacked} against the unpacked query. The score uses
     * the same correction formula as {@link Int4DotProduct}.
     */
    private static final class CompressedInt4DotProduct
    extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final float constMultiplier;
        private final RandomAccessQuantizedByteVectorValues values;
        // Scratch buffer reused by every score() call.
        private final byte[] compressedVector;
        private final byte[] targetBytes;
        private final float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        private CompressedInt4DotProduct(RandomAccessQuantizedByteVectorValues values, float constMultiplier,
                byte[] targetBytes, float offsetCorrection, FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.compressedVector = new byte[values.getVectorByteLength()];
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int vectorOrdinal) throws IOException {
            // Records are stored back to back: getVectorByteLength() packed bytes, followed by
            // Float.BYTES trailing bytes (presumably the per-vector correction constant).
            long recordOffset = (long) vectorOrdinal * (this.values.getVectorByteLength() + Float.BYTES);
            this.values.getSlice().seek(recordOffset);
            this.values.getSlice().readBytes(this.compressedVector, 0, this.compressedVector.length);
            float vectorOffset = this.values.getScoreCorrectionConstant(vectorOrdinal);
            int dotProduct = VectorUtil.int4DotProductPacked(this.targetBytes, this.compressedVector);
            assert (dotProduct >= 0);
            float adjustedDistance = (float) dotProduct * this.constMultiplier + this.offsetCorrection + vectorOffset;
            return this.scoreAdjustmentFunction.apply(adjustedDistance);
        }
    }

    /**
     * Dot product over quantized vectors wider than 4 bits.
     *
     * <p>The score is {@code adjust(dot * constMultiplier + queryOffset + vectorOffset)}.
     */
    private static final class DotProduct
    extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final float constMultiplier;
        private final RandomAccessQuantizedByteVectorValues values;
        private final byte[] targetBytes;
        private final float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        public DotProduct(RandomAccessQuantizedByteVectorValues values, float constMultiplier,
                byte[] targetBytes, float offsetCorrection, FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int vectorOrdinal) throws IOException {
            byte[] storedVector = this.values.vectorValue(vectorOrdinal);
            float vectorOffset = this.values.getScoreCorrectionConstant(vectorOrdinal);
            int dotProduct = VectorUtil.dotProduct(storedVector, this.targetBytes);
            assert (dotProduct >= 0);
            float adjustedDistance = (float) dotProduct * this.constMultiplier + this.offsetCorrection + vectorOffset;
            return this.scoreAdjustmentFunction.apply(adjustedDistance);
        }
    }

    /**
     * Euclidean scorer.
     *
     * <p>The score is {@code 1 / (1 + squareDistance * constMultiplier)}, so a distance of 0
     * scores 1. No offset corrections are applied.
     */
    private static final class Euclidean
    extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final float constMultiplier;
        private final byte[] targetBytes;
        private final RandomAccessQuantizedByteVectorValues values;

        private Euclidean(RandomAccessQuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) {
            super(values);
            this.values = values;
            this.constMultiplier = constMultiplier;
            this.targetBytes = targetBytes;
        }

        @Override
        public float score(int node) throws IOException {
            byte[] nodeVector = this.values.vectorValue(node);
            int squareDistance = VectorUtil.squareDistance(nodeVector, this.targetBytes);
            float adjustedDistance = (float) squareDistance * this.constMultiplier;
            return 1.0f / (1.0f + adjustedDistance);
        }
    }
}

