/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.gpu;

import com.ibm.cuda.CudaBuffer;
import com.ibm.cuda.CudaDevice;
import com.ibm.cuda.CudaGrid;
import com.ibm.cuda.CudaKernel;
import com.ibm.cuda.CudaModule;
import com.ibm.cuda.CudaStream;
import com.ibm.cuda.Dim3;
import com.ibm.gpu.CUDAManager;
import com.ibm.gpu.GPUConfigurationException;
import com.ibm.gpu.GPUSortException;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

final class SortNetwork {
    private static final ConcurrentHashMap<CudaDevice, SortKernels> deviceMap;
    private static final Integer[] powersOf2;
    private final CudaDevice device;
    private final int maxGridDimX;
    private CudaKernel sortFirst4;
    private CudaKernel sortOther1;
    private CudaKernel sortOther2;
    private CudaKernel sortOther3;
    private CudaKernel sortOther4;
    private CudaKernel sortPhase9;

    private static void checkIndices(int length, int fromIndex, int toIndex) {
        if (fromIndex > toIndex) {
            throw new IllegalArgumentException();
        }
        if (fromIndex < 0) {
            throw new ArrayIndexOutOfBoundsException(fromIndex);
        }
        if (toIndex > length) {
            throw new ArrayIndexOutOfBoundsException(toIndex);
        }
    }

    private static int roundUp(int value, int unit) {
        assert (value > 0);
        assert (unit > 0);
        int remainder = value % unit;
        return remainder == 0 ? value : value + (unit - remainder);
    }

    private static int significantBits(int value) {
        return 32 - Integer.numberOfLeadingZeros(Math.max(1, value));
    }

    static void sortArray(int deviceId, double[] array, int fromIndex, int toIndex) throws GPUConfigurationException, GPUSortException {
        CUDAManager manager = SortNetwork.traceStart(deviceId, "double", fromIndex, toIndex);
        try {
            SortNetwork network = new SortNetwork(deviceId);
            network.sort(array, fromIndex, toIndex);
        }
        catch (GPUConfigurationException | GPUSortException e) {
            SortNetwork.traceFailure(manager, e);
            throw e;
        }
        SortNetwork.traceSuccess(manager, deviceId, "double");
    }

    static void sortArray(int deviceId, float[] array, int fromIndex, int toIndex) throws GPUConfigurationException, GPUSortException {
        CUDAManager manager = SortNetwork.traceStart(deviceId, "float", fromIndex, toIndex);
        try {
            SortNetwork network = new SortNetwork(deviceId);
            network.sort(array, fromIndex, toIndex);
        }
        catch (GPUConfigurationException | GPUSortException e) {
            SortNetwork.traceFailure(manager, e);
            throw e;
        }
        SortNetwork.traceSuccess(manager, deviceId, "float");
    }

    static void sortArray(int deviceId, int[] array, int fromIndex, int toIndex) throws GPUConfigurationException, GPUSortException {
        CUDAManager manager = SortNetwork.traceStart(deviceId, "int", fromIndex, toIndex);
        try {
            SortNetwork network = new SortNetwork(deviceId);
            network.sort(array, fromIndex, toIndex);
        }
        catch (GPUConfigurationException | GPUSortException e) {
            SortNetwork.traceFailure(manager, e);
            throw e;
        }
        SortNetwork.traceSuccess(manager, deviceId, "int");
    }

    static void sortArray(int deviceId, long[] array, int fromIndex, int toIndex) throws GPUConfigurationException, GPUSortException {
        CUDAManager manager = SortNetwork.traceStart(deviceId, "long", fromIndex, toIndex);
        try {
            SortNetwork network = new SortNetwork(deviceId);
            network.sort(array, fromIndex, toIndex);
        }
        catch (GPUConfigurationException | GPUSortException e) {
            SortNetwork.traceFailure(manager, e);
            throw e;
        }
        SortNetwork.traceSuccess(manager, deviceId, "long");
    }

    private static void traceFailure(CUDAManager manager, Exception exception) {
        manager.outputIfVerbose(exception.getLocalizedMessage());
    }

    private static CUDAManager traceStart(int deviceId, String type, int fromIndex, int toIndex) {
        CUDAManager manager = CUDAManager.instanceInternal();
        if (manager.getVerboseGPUOutput()) {
            manager.outputIfVerbose("Using device: " + deviceId + " to sort " + type + " array; elements " + fromIndex + " to " + toIndex);
        }
        return manager;
    }

    private static void traceSuccess(CUDAManager manager, int deviceId, String type) {
        if (manager.getVerboseGPUOutput()) {
            manager.outputIfVerbose("Sorted " + type + "s on device " + deviceId + " successfully");
        }
    }

    private SortNetwork(int deviceId) throws GPUConfigurationException {
        try {
            this.device = new CudaDevice(deviceId);
        }
        catch (NoClassDefFoundError e) {
            throw new GPUConfigurationException("Unsupported platform detected");
        }
        try {
            int capability = this.device.getAttribute(75);
            if (capability < 2) {
                throw new GPUConfigurationException("Compute capability 2.0 or better required");
            }
            this.maxGridDimX = this.device.getAttribute(5);
        }
        catch (Exception e) {
            throw new GPUConfigurationException(e.getLocalizedMessage(), e);
        }
    }

    private SortKernels getKernels() throws GPUSortException {
        try {
            return deviceMap.computeIfAbsent(this.device, SortKernels::create);
        }
        catch (DelayedException e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    private CudaGrid makeGrid(int threadCount, int blockSize, CudaStream stream) {
        int blockCount = Math.max(1, (threadCount + blockSize - 1) / blockSize);
        return new CudaGrid(this.makeGridDim(blockCount), new Dim3(blockSize), stream);
    }

    private Dim3 makeGridDim(int blockCount) {
        int blockDimX = Math.max(1, blockCount);
        int blockDimY = 1;
        while (blockDimX > this.maxGridDimX) {
            if ((blockDimX & 1) != 0) {
                ++blockDimX;
            }
            blockDimX >>= 1;
            blockDimY <<= 1;
        }
        return new Dim3(blockDimX, blockDimY);
    }

    private void sort(double[] array, int fromIndex, int toIndex) throws GPUSortException {
        int length = toIndex - fromIndex;
        if (length < 2) {
            SortNetwork.checkIndices(array.length, fromIndex, toIndex);
            return;
        }
        SortKernels kernels = this.getKernels();
        this.sortFirst4 = kernels.doubleSortFirst4;
        this.sortOther1 = kernels.doubleSortOther1;
        this.sortOther2 = kernels.doubleSortOther2;
        this.sortOther3 = kernels.doubleSortOther3;
        this.sortOther4 = kernels.doubleSortOther4;
        this.sortPhase9 = kernels.doubleSortPhase9;
        try (CudaBuffer gpuBuffer = new CudaBuffer(this.device, (long)length * 8L);){
            gpuBuffer.copyFrom(array, fromIndex, toIndex);
            this.sortBuffer(gpuBuffer, length);
            gpuBuffer.copyTo(array, fromIndex, toIndex);
        }
        catch (Exception e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    private void sort(float[] array, int fromIndex, int toIndex) throws GPUSortException {
        int length = toIndex - fromIndex;
        if (length < 2) {
            SortNetwork.checkIndices(array.length, fromIndex, toIndex);
            return;
        }
        SortKernels kernels = this.getKernels();
        this.sortFirst4 = kernels.floatSortFirst4;
        this.sortOther1 = kernels.floatSortOther1;
        this.sortOther2 = kernels.floatSortOther2;
        this.sortOther3 = kernels.floatSortOther3;
        this.sortOther4 = kernels.floatSortOther4;
        this.sortPhase9 = kernels.floatSortPhase9;
        try (CudaBuffer gpuBuffer = new CudaBuffer(this.device, (long)length * 4L);){
            gpuBuffer.copyFrom(array, fromIndex, toIndex);
            this.sortBuffer(gpuBuffer, length);
            gpuBuffer.copyTo(array, fromIndex, toIndex);
        }
        catch (Exception e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    private void sort(int[] array, int fromIndex, int toIndex) throws GPUSortException {
        int length = toIndex - fromIndex;
        if (length < 2) {
            SortNetwork.checkIndices(array.length, fromIndex, toIndex);
            return;
        }
        SortKernels kernels = this.getKernels();
        this.sortFirst4 = kernels.intSortFirst4;
        this.sortOther1 = kernels.intSortOther1;
        this.sortOther2 = kernels.intSortOther2;
        this.sortOther3 = kernels.intSortOther3;
        this.sortOther4 = kernels.intSortOther4;
        this.sortPhase9 = kernels.intSortPhase9;
        try (CudaBuffer gpuBuffer = new CudaBuffer(this.device, (long)length * 4L);){
            gpuBuffer.copyFrom(array, fromIndex, toIndex);
            this.sortBuffer(gpuBuffer, length);
            gpuBuffer.copyTo(array, fromIndex, toIndex);
        }
        catch (Exception e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    private void sort(long[] array, int fromIndex, int toIndex) throws GPUSortException {
        int length = toIndex - fromIndex;
        if (length < 2) {
            SortNetwork.checkIndices(array.length, fromIndex, toIndex);
            return;
        }
        SortKernels kernels = this.getKernels();
        this.sortFirst4 = kernels.longSortFirst4;
        this.sortOther1 = kernels.longSortOther1;
        this.sortOther2 = kernels.longSortOther2;
        this.sortOther3 = kernels.longSortOther3;
        this.sortOther4 = kernels.longSortOther4;
        this.sortPhase9 = kernels.longSortPhase9;
        try (CudaBuffer gpuBuffer = new CudaBuffer(this.device, (long)length * 8L);){
            gpuBuffer.copyFrom(array, fromIndex, toIndex);
            this.sortBuffer(gpuBuffer, length);
            gpuBuffer.copyTo(array, fromIndex, toIndex);
        }
        catch (Exception e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    private void sortBuffer(CudaBuffer buffer, int length) throws Exception {
        try (CudaStream stream = new CudaStream(this.device);){
            Integer boxLength = length;
            int phaseCount = 9;
            int inputSize = 512;
            int blockSize = 256;
            CudaGrid grid = this.makeGrid(length >> 1, 256, stream);
            this.sortPhase9.launch(grid, new Object[]{buffer, boxLength});
            phaseCount = SortNetwork.significantBits(length - 1);
            if (phaseCount <= 9) {
                return;
            }
            int blockSize2 = 256;
            CudaGrid gridOther = this.makeGrid(length >> 1, 256, stream);
            block17: for (int phase = 9; phase < phaseCount; ++phase) {
                int granule = 1 << phase;
                int grains = SortNetwork.roundUp(length, granule);
                CudaGrid grid2 = this.makeGrid(grains >> 1, 256, stream);
                this.sortFirst4.launch(grid2, new Object[]{buffer, boxLength, powersOf2[phase]});
                int step = phase;
                while ((step -= 4) >= 3) {
                    this.sortOther4.launch(gridOther, new Object[]{buffer, boxLength, powersOf2[step]});
                }
                switch (phase & 3) {
                    case 2: {
                        this.sortOther3.launch(gridOther, new Object[]{buffer, boxLength, powersOf2[2]});
                        continue block17;
                    }
                    case 1: {
                        this.sortOther2.launch(gridOther, new Object[]{buffer, boxLength, powersOf2[1]});
                        continue block17;
                    }
                    case 0: {
                        this.sortOther1.launch(gridOther, new Object[]{buffer, boxLength, powersOf2[0]});
                        continue block17;
                    }
                }
            }
        }
    }

    static {
        int phaseCount = 31;
        Integer[] powers = new Integer[31];
        for (int i = 0; i < 31; ++i) {
            powers[i] = 1 << i;
        }
        deviceMap = new ConcurrentHashMap();
        powersOf2 = powers;
    }

    private static final class SortKernels {
        final CudaKernel doubleSortFirst4;
        final CudaKernel doubleSortOther1;
        final CudaKernel doubleSortOther2;
        final CudaKernel doubleSortOther3;
        final CudaKernel doubleSortOther4;
        final CudaKernel doubleSortPhase9;
        final CudaKernel floatSortFirst4;
        final CudaKernel floatSortOther1;
        final CudaKernel floatSortOther2;
        final CudaKernel floatSortOther3;
        final CudaKernel floatSortOther4;
        final CudaKernel floatSortPhase9;
        final CudaKernel intSortFirst4;
        final CudaKernel intSortOther1;
        final CudaKernel intSortOther2;
        final CudaKernel intSortOther3;
        final CudaKernel intSortOther4;
        final CudaKernel intSortPhase9;
        final CudaKernel longSortFirst4;
        final CudaKernel longSortOther1;
        final CudaKernel longSortOther2;
        final CudaKernel longSortOther3;
        final CudaKernel longSortOther4;
        final CudaKernel longSortPhase9;

        static SortKernels create(CudaDevice device) {
            PrivilegedAction<SortKernels> action = () -> {
                try {
                    String code = "SortKernels.fatbin";
                    CudaModule module = null;
                    try {
                        try (InputStream fatbin = CUDAManager.class.getResourceAsStream(code);){
                            if (fatbin == null) {
                                throw new FileNotFoundException(code);
                            }
                            module = new CudaModule(device, fatbin);
                            SortKernels kernels = new SortKernels(module);
                            ShutdownHook.unloadOnShutdown(module);
                            module = null;
                            SortKernels sortKernels = kernels;
                            return sortKernels;
                        }
                        {
                            catch (Throwable throwable) {
                                throw throwable;
                            }
                        }
                    }
                    finally {
                        if (module != null) {
                            module.unload();
                        }
                    }
                }
                catch (Exception e) {
                    throw new DelayedException(e);
                }
            };
            return AccessController.doPrivileged(action);
        }

        private SortKernels(CudaModule module) throws Exception {
            this.doubleSortFirst4 = new CudaKernel(module, "DFirst4");
            this.doubleSortPhase9 = new CudaKernel(module, "DPhase9");
            this.doubleSortOther1 = new CudaKernel(module, "DOther1");
            this.doubleSortOther2 = new CudaKernel(module, "DOther2");
            this.doubleSortOther3 = new CudaKernel(module, "DOther3");
            this.doubleSortOther4 = new CudaKernel(module, "DOther4");
            this.floatSortFirst4 = new CudaKernel(module, "FFirst4");
            this.floatSortPhase9 = new CudaKernel(module, "FPhase9");
            this.floatSortOther1 = new CudaKernel(module, "FOther1");
            this.floatSortOther2 = new CudaKernel(module, "FOther2");
            this.floatSortOther3 = new CudaKernel(module, "FOther3");
            this.floatSortOther4 = new CudaKernel(module, "FOther4");
            this.intSortFirst4 = new CudaKernel(module, "IFirst4");
            this.intSortPhase9 = new CudaKernel(module, "IPhase9");
            this.intSortOther1 = new CudaKernel(module, "IOther1");
            this.intSortOther2 = new CudaKernel(module, "IOther2");
            this.intSortOther3 = new CudaKernel(module, "IOther3");
            this.intSortOther4 = new CudaKernel(module, "IOther4");
            this.longSortFirst4 = new CudaKernel(module, "JFirst4");
            this.longSortPhase9 = new CudaKernel(module, "JPhase9");
            this.longSortOther1 = new CudaKernel(module, "JOther1");
            this.longSortOther2 = new CudaKernel(module, "JOther2");
            this.longSortOther3 = new CudaKernel(module, "JOther3");
            this.longSortOther4 = new CudaKernel(module, "JOther4");
        }
    }

    private static final class ShutdownHook
    extends Thread {
        private static final Queue<CudaModule> modules = new ConcurrentLinkedQueue<CudaModule>();

        public static void unloadOnShutdown(CudaModule module) {
            modules.add(module);
        }

        private ShutdownHook() {
            super("GPU sort shutdown helper");
        }

        @Override
        public void run() {
            CudaModule module;
            while ((module = modules.poll()) != null) {
                try {
                    module.unload();
                }
                catch (Exception exception) {}
            }
        }

        static {
            AccessController.doPrivileged(() -> {
                Runtime.getRuntime().addShutdownHook(new ShutdownHook());
                return null;
            });
        }
    }

    private static final class DelayedException
    extends RuntimeException {
        private static final long serialVersionUID = 6735593106826400878L;

        DelayedException(Exception exception) {
            super(exception.getLocalizedMessage(), exception);
        }
    }
}

