A Java implementation for reading, sorting, and saving 100 million positive integers

Background

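While browsing online, I stumbled on an interesting topic: "[赛博斗蛐蛐]1亿个整数的读取、排序和保存,极致的优化,给我吓到了" ("[Cyber Cricket Fighting] Reading, sorting and saving 100 million integers: the extreme optimizations blew me away"). The uploader posed this problem: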

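Sort 100 million positive integers.
Requirements:
1. When the program starts, it asks for the path of a text file. (Both absolute and relative paths must be supported.)
2. Read the numbers from that file and sort them in ascending order; equal numbers are simply placed one after another. The file has 100 million lines with one number per line. Every number is a positive integer, and the largest possible value is 2147483647, i.e. 2^31-1.
3. After sorting, write the result to a file named sort.txt in the current directory.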
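For requirement 1, supporting both kinds of path mostly means resolving a relative path against the current working directory. Here is a minimal sketch that takes the path as a command-line argument, as SortFile does. The class PathResolution and its methods are hypothetical names, and using java.nio.file for this is my own assumption:

```java
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper class: resolves the input path and the required output path.
final class PathResolution {

    // A relative argument such as "data/1e8.txt" is resolved against the current working
    // directory; an absolute one is kept as-is. normalize() removes "." and ".." segments.
    static Path resolveInput(String arg) {
        return Path.of(arg).toAbsolutePath().normalize();
    }

    // Requirement 3: sort.txt in the current directory.
    static Path outputPath() {
        return Path.of("sort.txt").toAbsolutePath();
    }

    public static void main(String[] args) {
        if (args.length < 1) {
            System.out.println("Usage: java PathResolution <input-file>");
            return;
        }
        Path input = resolveInput(args[0]);
        if (!Files.isRegularFile(input)) {
            System.out.println("File [" + input + "] does not exist or is not a regular file");
            return;
        }
        System.out.println("input  = " + input);
        System.out.println("output = " + outputPath());
    }
}
```

`Files.isRegularFile` also rejects directories. A plain `File.exists()` check would accept them.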

Then experts of all kinds contributed their own solutions. The fastest needed only a little over 2 seconds.

Since it's a "cricket fight", I figured I'd join in the fun too.

Implementation

The experts went as far as instruction-set-level tricks. As a newbie, I can only use the most conventional tools.

The task asks for the output to go to sort.txt, but to make debugging easier I write to "out-%timestamp.txt" instead.
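When debugging a fast version, it helps to have a plain and obviously correct version whose output you can diff against on small inputs. Here is a sketch of such a baseline, which is not meant to be fast. BaselineSort is a hypothetical name. The sketch reads the file line by line, collects the values into a growable int[], sorts them, and writes the result:

```java
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

// Hypothetical single-threaded reference version, only for checking results on small files.
final class BaselineSort {

    static void sortFile(Path input, Path output) throws IOException {
        int[] numbers = new int[1024];
        int count = 0;
        try (BufferedReader reader = Files.newBufferedReader(input, StandardCharsets.US_ASCII)) {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.strip();
                if (line.isEmpty()) {
                    continue; // tolerate blank lines
                }
                if (count == numbers.length) {
                    numbers = Arrays.copyOf(numbers, count * 2); // grow without boxing
                }
                numbers[count++] = Integer.parseInt(line);
            }
        }
        Arrays.sort(numbers, 0, count); // equal values end up next to each other
        try (BufferedWriter writer = Files.newBufferedWriter(output, StandardCharsets.US_ASCII)) {
            for (int i = 0; i < count; i++) {
                writer.write(Integer.toString(numbers[i]));
                writer.write('\n');
            }
        }
    }

    public static void main(String[] args) throws IOException {
        sortFile(Path.of(args[0]), Path.of("sort.txt"));
    }
}
```

On a small test file, `diff` between this output and the fast version's out-*.txt should print nothing.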

Code

import java.io.File;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.LongAdder;

public class SortFile {

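    private static final int BUFFER_SIZE = 2 * 512 * 1024;       // 1 MiB of input per read task
    private static final int WRITE_BUFFER_SIZE = 8 * 128 * 1024; // 1 MiB of output per write task
    // Shortcut: the task guarantees exactly 100 million numbers, so the array size is fixed
    private static final int TOTAL_NUMBERS = 100_000_000;
    private static final byte LF = '\n';
    private static final byte ZERO = '0';
    private static final byte NINE = '9';
    // Virtual threads (final since JDK 21) run the read, format and write tasks
    private static final ThreadFactory TF = Thread.ofVirtual().name("sort").factory();
    // Total output size in bytes: every digit plus one '\n' per number
    private static final LongAdder CHARACTER_COUNTER = new LongAdder();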
    private static String outputFilePath;
    private static int[] allNumber;

    public static void main(String[] args) throws IOException, InterruptedException {
        if (args.length < 1) {
            System.out.println("Please provide the file path as an argument");
            return;
        }
        String filePath = args[0];
        if (!new File(filePath).exists()) {
            System.out.println("File path [" + filePath + "] does not exist");
            return;
        }
        long start = System.currentTimeMillis();
        readFileInChunks(filePath);
        long readTime = System.currentTimeMillis();
        System.out.println("readTime = " + (readTime - start));
        Arrays.parallelSort(allNumber);
        long sortTime = System.currentTimeMillis();
        System.out.println("sortTime = " + (sortTime - readTime));
        writeNumbersToFileMultiThreaded(outputFilePath);
        long writeTime = System.currentTimeMillis();
        System.out.println("writeTime = " + (writeTime - sortTime));
        System.out.println("allTime = " + (writeTime - start));
    }

    private static void readFileInChunks(String filePath) throws IOException, InterruptedException {
        try (FileChannel fc = FileChannel.open(Path.of(filePath), StandardOpenOption.READ)) {
            long fileSize = fc.size();
            long pointer = 0;
            long threadNum = (fileSize + BUFFER_SIZE - 1) / BUFFER_SIZE;
            int[][] threadNumers = new int[(int) threadNum][];
            List<Thread> threadList = new ArrayList<>((int) threadNum);

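            // One read task per BUFFER_SIZE chunk; each chunk is extended to the end of its
            // last line (see findLastLineSize) so that no number is split across two tasks
            for (int i = 0; i < threadNum; i++) {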
                long startPointer = pointer;
                long endPointer = startPointer + BUFFER_SIZE;
                long chunkSize;
                if (endPointer < (fileSize - 11L)) {
                    long lastLineSize = findLastLineSize(fc, endPointer);
                    chunkSize = BUFFER_SIZE + lastLineSize;
                } else {
                    chunkSize = fileSize - startPointer;
                }
                pointer = startPointer + chunkSize;
                MappedByteBuffer buffer = fc.map(FileChannel.MapMode.READ_ONLY, startPointer, chunkSize);
                int finalI = i;
                Thread thread = TF.newThread(() -> readFile(buffer, threadNumers, finalI));
                thread.start();
                threadList.add(thread);
            }

            allNumber = new int[TOTAL_NUMBERS];
            int index = 0;
            // Timestamped name for debugging; the task itself asks for sort.txt
            outputFilePath = "." + File.separator + "out-" + System.currentTimeMillis() + ".txt";

            for (int i = 0, threadListSize = threadList.size(); i < threadListSize; i++) {
                threadList.get(i).join();
                int[] array = threadNumers[i];
                System.arraycopy(array, 0, allNumber, index, array.length);
                index += array.length;
            }
        }
    }

    private static void readFile(MappedByteBuffer buffer, int[][] ints, int i) {
        int number = 0;
        long counter = 0;
        List<Integer> integerList = new ArrayList<>();
        boolean inNumber = false;

        while (buffer.hasRemaining()) {
            byte b = buffer.get();
            if (b == LF) {
                if (inNumber) {
                    integerList.add(number);
                    number = 0;
                    inNumber = false;
                    counter++;
                }
            } else if (b >= ZERO && b <= NINE) {
                number = number * 10 + (b - ZERO);
                inNumber = true;
                counter++;
            }
        }

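        buffer.clear();
        if (inNumber) {
            // The chunk ended without a trailing '\n' (last line of the file): keep the
            // number and count the '\n' that the writer will append after it
            integerList.add(number);
            counter++;
        }
        // Always store an array, even an empty one, so the merge loop never hits null
        ints[i] = integerList.stream().mapToInt(Integer::intValue).toArray();
        CHARACTER_COUNTER.add(counter);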
    }

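    // Returns how many bytes, starting at endNum - 1, come before the next '\n', i.e. how far the
    // chunk must be extended to end right after a newline; 11 bytes cover a 10-digit number plus '\n'
    private static long findLastLineSize(FileChannel fc, long endNum) throws IOException {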
        MappedByteBuffer buffer = fc.map(FileChannel.MapMode.READ_ONLY, endNum - 1, 11L);
        long num = 0;
        while (buffer.hasRemaining()) {
            byte b = buffer.get();
            if (b == LF) {
                return num;
            } else {
                num++;
            }
        }
        return num;
    }

    private static void writeNumbersToFileMultiThreaded(String filePath) throws IOException, InterruptedException {
        try (FileChannel fileChannel = FileChannel.open(Path.of(filePath), StandardOpenOption.READ, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
            long counter = CHARACTER_COUNTER.longValue();
            long threadNum = (counter + WRITE_BUFFER_SIZE - 1) / WRITE_BUFFER_SIZE;
            int threadCore = Runtime.getRuntime().availableProcessors();
            int byteThreadBufferSize = (TOTAL_NUMBERS + threadCore - 1) / threadCore;
            byte[] allNumberByte = new byte[(int) counter];
            int[] byteThreadStartPointers = new int[threadCore];
            List<Thread> byteThreadList = new ArrayList<>(threadCore);
            int countForByteThreadEndPointers = 0;
            int pointerForByteThreadEndPointers = 0;

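            // Sequential pass: each write task starts at the running total of
            // (digit count + 1 for '\n') over all earlier numbers
            for (int i = 0, j = allNumber.length; i < j; i++) {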
                if (i > 0 && i % byteThreadBufferSize == 0) {
                    byteThreadStartPointers[++pointerForByteThreadEndPointers] = countForByteThreadEndPointers;
                }
                countForByteThreadEndPointers = countForByteThreadEndPointers + stringSize(allNumber[i]) + 1;
            }

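            // Each task formats its slice of numbers into its own disjoint range of allNumberByte
            for (int i = 0; i < threadCore; i++) {
                int start = i * byteThreadBufferSize;
                // Clamp so a slice never runs past the end of allNumber for unusual core counts
                int end = (i == threadCore - 1) ? TOTAL_NUMBERS : Math.min((i + 1) * byteThreadBufferSize, TOTAL_NUMBERS);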
                int finalI = i;
                Thread thread = TF.newThread(() -> {
                    int count = byteThreadStartPointers[finalI];
                    for (int j = start; j < end; j++) {
                        int number = allNumber[j];
                        int stringSize = stringSize(number);
                        count += stringSize;
                        getChars(number, count, allNumberByte);
                        allNumberByte[count++] = LF;
                    }
                });
                thread.start();
                byteThreadList.add(thread);
            }

            List<Thread> threadList = new ArrayList<>((int) threadNum);

            for (Thread thread : byteThreadList) {
                thread.join();
            }

            // Copy the formatted bytes into the output file through WRITE_BUFFER_SIZE mapped windows
            for (int i = 0; i < threadNum; i++) {
                long position = (long) i * WRITE_BUFFER_SIZE;
                long size = (i == threadNum - 1) ? (counter - position) : WRITE_BUFFER_SIZE;
                MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, position, size);
                Thread thread = TF.newThread(() -> buffer.put(allNumberByte, (int) position, (int) size));
                thread.start();
                threadList.add(thread);
            }

            for (Thread thread : threadList) {
                thread.join();
            }
        }
    }


    static final byte[] DigitTens = {
            '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
            '1', '1', '1', '1', '1', '1', '1', '1', '1', '1',
            '2', '2', '2', '2', '2', '2', '2', '2', '2', '2',
            '3', '3', '3', '3', '3', '3', '3', '3', '3', '3',
            '4', '4', '4', '4', '4', '4', '4', '4', '4', '4',
            '5', '5', '5', '5', '5', '5', '5', '5', '5', '5',
            '6', '6', '6', '6', '6', '6', '6', '6', '6', '6',
            '7', '7', '7', '7', '7', '7', '7', '7', '7', '7',
            '8', '8', '8', '8', '8', '8', '8', '8', '8', '8',
            '9', '9', '9', '9', '9', '9', '9', '9', '9', '9',
    };

    static final byte[] DigitOnes = {
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    };

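    // Adapted from the JDK source: writes the digits of i so that the last digit lands at
    // buf[index - 1]; it works on the negated value like the JDK does, and assumes i >= 0 here
    static void getChars(int i, int index, byte[] buf) {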
        int q, r;
        int charPos = index;
        i = -i;

        while (i <= -100) {
            q = i / 100;
            r = (q * 100) - i;
            i = q;
            buf[--charPos] = DigitOnes[r];
            buf[--charPos] = DigitTens[r];
        }

        buf[--charPos] = DigitOnes[-i];
        if (i < -9) {
            buf[--charPos] = DigitTens[-i];
        }
    }

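    // Adapted from the JDK source: number of decimal digits in x; the sign handling was
    // dropped (d is always 0), so it is only correct for x >= 0
    private static int stringSize(int x) {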
        int d = 0;
        x = -x;
        int p = -10;
        for (int i = 1; i < 10; i++) {
            if (x > p)
                return i + d;
            p = 10 * p;
        }
        return 10 + d;
    }
}
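The read phase rests on two ideas. First, the file is cut into chunks that always end on a newline. Second, each chunk's digits are parsed independently. Here is a small sketch of both on a plain byte[], which stands in for the MappedByteBuffer. ChunkParse and its methods are hypothetical names. Unlike readFile, the sketch collects values into a growable int[] instead of a List<Integer>, which avoids boxing every number:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical class: the ideas behind readFileInChunks/readFile, applied to a byte[].
final class ChunkParse {

    // Start offsets of chunks of roughly targetSize bytes. Each chunk is extended until its
    // last byte is '\n' (or the data ends), so no number is split between two chunks.
    static List<Integer> chunkStarts(byte[] data, int targetSize) {
        if (targetSize <= 0) {
            throw new IllegalArgumentException("targetSize must be positive");
        }
        List<Integer> starts = new ArrayList<>();
        int pos = 0;
        while (pos < data.length) {
            starts.add(pos);
            int end = Math.min(pos + targetSize, data.length);
            while (end < data.length && data[end - 1] != '\n') {
                end++;
            }
            pos = end;
        }
        return starts;
    }

    // Parses the numbers in data[from, to) into an int[] without boxing.
    // Bytes other than digits and '\n' (for example '\r') are ignored.
    static int[] parse(byte[] data, int from, int to) {
        int[] out = new int[16];
        int count = 0;
        int number = 0;
        boolean inNumber = false;
        for (int i = from; i < to; i++) {
            byte b = data[i];
            if (b >= '0' && b <= '9') {
                number = number * 10 + (b - '0');
                inNumber = true;
            } else if (b == '\n' && inNumber) {
                if (count == out.length) {
                    out = Arrays.copyOf(out, count * 2); // grow the int[] instead of boxing
                }
                out[count++] = number;
                number = 0;
                inNumber = false;
            }
        }
        if (inNumber) { // last number without a trailing '\n'
            out = Arrays.copyOf(out, count + 1);
            out[count++] = number;
        }
        return Arrays.copyOf(out, count); // never null; empty when the range has no numbers
    }
}
```

In SortFile terms, each pair of consecutive start offsets would become one `fc.map(...)` call and one virtual thread.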

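The getChars and stringSize methods at the end of SortFile, together with the DigitTens/DigitOnes tables, were lifted from the JDK source, which saved me from writing them myself.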
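In the write phase, stringSize and getChars let each number's bytes go straight into one shared array. A task can only start writing once it knows its byte offset, which is the running total of (digits + 1) over all earlier numbers. Below is a sketch of that idea in isolation. It uses plain one-digit-per-step loops instead of the JDK's two-digits-per-step tables. ParallelFormat, digits, writeDigits and format are hypothetical names. Unlike SortFile, which computes the offsets in one sequential pass over all numbers, this sketch counts each slice's bytes in parallel and only prefix-sums the per-slice totals:

```java
import java.util.stream.IntStream;

// Hypothetical class: parallel formatting of sorted ints into one byte[] ("n\n" per number).
final class ParallelFormat {

    // Decimal digit count of a non-negative int (what stringSize computes in SortFile).
    static int digits(int x) {
        int d = 1;
        while (x >= 10) {
            x /= 10;
            d++;
        }
        return d;
    }

    // Writes the digits of a non-negative x backward so the last digit lands at buf[end - 1]
    // (same contract as getChars, but one digit per step instead of two).
    static void writeDigits(int x, int end, byte[] buf) {
        int pos = end;
        do {
            buf[--pos] = (byte) ('0' + x % 10);
            x /= 10;
        } while (x != 0);
    }

    static byte[] format(int[] sorted, int slices) {
        if (slices <= 0) {
            throw new IllegalArgumentException("slices must be positive");
        }
        int per = (sorted.length + slices - 1) / slices; // numbers per slice (last may be shorter)

        // Pass 1 (parallel): bytes needed by each slice = sum of (digits + 1 for '\n').
        long[] sliceBytes = new long[slices];
        IntStream.range(0, slices).parallel().forEach(s -> {
            long bytes = 0;
            for (int j = s * per, end = Math.min((s + 1) * per, sorted.length); j < end; j++) {
                bytes += digits(sorted[j]) + 1;
            }
            sliceBytes[s] = bytes;
        });

        // A prefix sum over the per-slice totals gives each slice's starting byte offset.
        long[] sliceStart = new long[slices + 1];
        for (int s = 0; s < slices; s++) {
            sliceStart[s + 1] = sliceStart[s] + sliceBytes[s];
        }
        byte[] out = new byte[Math.toIntExact(sliceStart[slices])];

        // Pass 2 (parallel): slices write disjoint ranges of out, so no locking is needed.
        IntStream.range(0, slices).parallel().forEach(s -> {
            int pos = (int) sliceStart[s];
            for (int j = s * per, end = Math.min((s + 1) * per, sorted.length); j < end; j++) {
                pos += digits(sorted[j]);
                writeDigits(sorted[j], pos, out);
                out[pos++] = '\n';
            }
        });
        return out;
    }
}
```

The output does not depend on the number of slices. The slice count only changes how the work is divided.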

Running

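I tested with JDK 21 and JDK 23 (Thread.ofVirtual() needs JDK 21 or later), and JDK 23 was a little faster. The printed times are in milliseconds.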

Environment:

  • CPU: Ryzen 5 5600G
  • Memory: 16 GB
  • Disk: PCIe 3.0 SSD
  • OS: Ubuntu 24.04.1 LTS
  • JDK: OpenJDK Runtime Environment Zulu23.28+85-CA (build 23+37)

time java SortFile.java 1e8.txt 
readTime = 856
sortTime = 1150
writeTime = 678
allTime = 2684

real    0m2.946s
user    0m20.174s
sys     0m3.709s

Curiously, after I built a native image with GraalVM, the speed showed a "Schrödinger's" improvement:

time ./sortfile 1e8.txt 
readTime = 1184
sortTime = 930
writeTime = 583
allTime = 2697

real    0m2.840s
user    0m17.030s
sys     0m3.841s

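Compared with running on the JVM, reading became much slower for reasons I don't understand, but sorting and writing got noticeably faster. In some runs the native image even came from behind and beat the JVM overall.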

Could anyone who knows this area explain why?

Afterword

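I can't match the experts' instruction-set-level speed. Still, using only the built-in APIs gives a reasonably good result, and I'm fairly satisfied with it.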

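Because the task specifies "positive integers" and exactly "one hundred million" of them, the code takes quite a few shortcuts. Examples include the fixed-size TOTAL_NUMBERS array, digit formatting without sign handling, and an output buffer that must fit in a single int-indexed byte[]. If the task changed, the code would have to change too, but it will do for now.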
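One of those shortcuts, the fixed-size TOTAL_NUMBERS array, is easy to drop. Every read task already returns its own int[], so the total is just the sum of their lengths. Here is a hedged sketch of that merge step. ExactMerge is a hypothetical name, and the merge would run only after all read threads have been joined:

```java
// Hypothetical class: merge per-chunk arrays without assuming a fixed total count.
final class ExactMerge {

    static int[] concat(int[][] parts) {
        long total = 0;
        for (int[] part : parts) {
            if (part != null) { // tolerate tasks that produced no array
                total += part.length;
            }
        }
        // toIntExact throws instead of silently overflowing if there are too many numbers
        int[] all = new int[Math.toIntExact(total)];
        int index = 0;
        for (int[] part : parts) {
            if (part != null) {
                System.arraycopy(part, 0, all, index, part.length);
                index += part.length;
            }
        }
        return all;
    }
}
```

The sort and the writer would then use all.length instead of TOTAL_NUMBERS.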

A Java implementation for reading, sorting, and saving 100 million positive integers
https://halvamazd.lv6.fun/posts/1e8-positive-integer-sort-java/
Author: 哈瓦玛玛兹
Published: 2024-11-05
License: CC BY-NC-SA 4.0