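A Java Implementation for Reading, Sorting, and Saving 100 Million Positive Integers
Background
While browsing the web I stumbled on an interesting topic: 《[赛博斗蛐蛐]1亿个整数的读取、排序和保存,极致的优化,给我吓到了》 ("[Cyber cricket fight] Reading, sorting, and saving 100 million integers: the extreme optimizations scared me"). The uploader posed the following task:
Sort 100 million positive integers.
The requirements are:
1. After the program starts, it asks for the path of a text file. (Both absolute and relative paths must be handled.)
2. Read the numbers from that text file and sort them in ascending order. Duplicates are kept and simply follow one another. The file has 100 million lines, one number per line, all positive integers, with a maximum of 2147483647, that is, 2^31-1.
3. When sorting is done, write the result to a file named sort.txt in the current directory.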
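Requirement 1 asks for a path typed in at runtime, while the implementation later in this post takes it from the command line. As a minimal sketch of what the prompt could look like, assuming a hypothetical class `InputPathSketch` with a helper `resolve` (neither is part of the original code): relative paths are resolved against the current working directory, and absolute paths are used as given.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Scanner;

// Hypothetical helper, not part of the original SortFile class.
final class InputPathSketch {

    // Resolves user input against baseDir unless it is already absolute,
    // then removes "." and ".." segments.
    static Path resolve(String raw, Path baseDir) {
        Path p = Path.of(raw.trim());
        return (p.isAbsolute() ? p : baseDir.resolve(p)).normalize();
    }

    public static void main(String[] args) {
        System.out.print("Path to the input file: ");
        try (Scanner in = new Scanner(System.in)) {
            if (!in.hasNextLine()) {
                System.err.println("No path entered");
                return;
            }
            // The current working directory serves as the base for relative paths.
            Path cwd = Path.of("").toAbsolutePath();
            Path file = resolve(in.nextLine(), cwd);
            if (Files.isRegularFile(file)) {
                System.out.println("Will read " + file);
            } else {
                System.out.println("Not a regular file: " + file);
            }
        }
    }
}
```

Passing the base directory as a parameter keeps `resolve` easy to check without touching the real working directory.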
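All sorts of experts then contributed their own solutions; the fastest needed only a little over 2 seconds.
Since this is a "cricket fight", I figured I would join in too.
Implementation
The experts went as far as hand-tuning for specific instruction sets. As a mere beginner, I can only use the most conventional approach.
The task asks for the output to go to sort.txt, but to make debugging easier I write to "out-<timestamp>.txt" instead.
Code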
import java.io.File;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.LongAdder;
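public class SortFile {
    // Size of each read chunk: 1 MiB. Every chunk is parsed on its own virtual thread.
    private static final int BUFFER_SIZE = 2 * 512 * 1024;
    // Size of each memory-mapped write region: 1 MiB.
    private static final int WRITE_BUFFER_SIZE = 8 * 128 * 1024;
    // The task fixes the input at exactly 100 million numbers; the code below relies on that.
    private static final int TOTAL_NUMBERS = 100_000_000;
    private static final byte LF = '\n';
    private static final byte ZERO = '0';
    private static final byte NINE = '9';
    // Virtual threads require JDK 21 or later.
    private static final ThreadFactory TF = Thread.ofVirtual().name("sort").factory();
    // Bytes (digits plus line feeds) the output file will need, summed across reader threads.
    private static final LongAdder CHARACTER_COUNTER = new LongAdder();
    private static String outputFilePath;
    private static int[] allNumber;
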
    public static void main(String[] args) throws IOException, InterruptedException {
        if (args.length < 1) {
            System.out.println("Please provide the file path as an argument");
            return;
        }
        String filePath = args[0];
        if (!new File(filePath).exists()) {
            System.out.println("File path [" + filePath + "] does not exist");
            return;
        }
        // Three phases, each timed in milliseconds: read, sort, write.
        long start = System.currentTimeMillis();
        readFileInChunks(filePath);
        long readTime = System.currentTimeMillis();
        System.out.println("readTime = " + (readTime - start));
        Arrays.parallelSort(allNumber);
        long sortTime = System.currentTimeMillis();
        System.out.println("sortTime = " + (sortTime - readTime));
        writeNumbersToFileMultiThreaded(outputFilePath);
        long writeTime = System.currentTimeMillis();
        System.out.println("writeTime = " + (writeTime - sortTime));
        System.out.println("allTime = " + (writeTime - start));
    }

    // Splits the file into chunks of about 1 MiB that end on a line boundary, parses each
    // chunk on its own virtual thread, then concatenates the results into allNumber.
    private static void readFileInChunks(String filePath) throws IOException, InterruptedException {
        try (FileChannel fc = FileChannel.open(Path.of(filePath), StandardOpenOption.READ)) {
            long fileSize = fc.size();
            long pointer = 0;
            // Upper bound on the chunk count; chunks may grow slightly to reach a line feed.
            long threadNum = (fileSize + BUFFER_SIZE - 1) / BUFFER_SIZE;
            int[][] threadNumbers = new int[(int) threadNum][];
            List<Thread> threadList = new ArrayList<>((int) threadNum);
            for (int i = 0; i < threadNum; i++) {
                long startPointer = pointer;
                long endPointer = startPointer + BUFFER_SIZE;
                long chunkSize;
                // With LF line endings a line is at most 10 digits plus the line feed,
                // hence the 11-byte margin.
                if (endPointer < (fileSize - 11L)) {
                    long lastLineSize = findLastLineSize(fc, endPointer);
                    chunkSize = BUFFER_SIZE + lastLineSize;
                } else {
                    // Near the end of the file: take whatever is left (possibly nothing).
                    chunkSize = fileSize - startPointer;
                }
                pointer = startPointer + chunkSize;
                MappedByteBuffer buffer = fc.map(FileChannel.MapMode.READ_ONLY, startPointer, chunkSize);
                int finalI = i;
                Thread thread = TF.newThread(() -> readFile(buffer, threadNumbers, finalI));
                thread.start();
                threadList.add(thread);
            }
            allNumber = new int[TOTAL_NUMBERS];
            int index = 0;
            outputFilePath = "." + File.separator + "out-" + System.currentTimeMillis() + ".txt";
            for (int i = 0, threadListSize = threadList.size(); i < threadListSize; i++) {
                threadList.get(i).join();
                int[] array = threadNumbers[i];
                // A trailing chunk can be empty, in which case no array was stored for it.
                if (array != null) {
                    System.arraycopy(array, 0, allNumber, index, array.length);
                    index += array.length;
                }
            }
        }
    }

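    // Parses one chunk: accumulates ASCII digits into an int and emits it at each line feed.
    // Any byte that is neither a digit nor a line feed (for example '\r') is ignored.
    private static void readFile(MappedByteBuffer buffer, int[][] ints, int i) {
        int number = 0;
        long counter = 0; // bytes this chunk will occupy in the output (digits + line feeds)
        List<Integer> integerList = new ArrayList<>();
        boolean inNumber = false;
        while (buffer.hasRemaining()) {
            byte b = buffer.get();
            if (b == LF) {
                if (inNumber) {
                    integerList.add(number);
                    number = 0;
                    inNumber = false;
                    counter++;
                }
            } else if (b >= ZERO && b <= NINE) {
                number = number * 10 + (b - ZERO);
                inNumber = true;
                counter++;
            }
        }
        buffer.clear();
        if (inNumber) {
            // Last number of a file without a trailing line feed. The writer still appends
            // a line feed after it, so that byte has to be counted here as well.
            integerList.add(number);
            counter++;
        }
        // Always store an array (possibly empty) so the merge step never sees null.
        ints[i] = integerList.stream().mapToInt(Integer::intValue).toArray();
        CHARACTER_COUNTER.add(counter);
    }
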
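    // Returns how many bytes lie between position endNum - 1 and the next line feed,
    // that is, how far a chunk must be extended so that it ends exactly on a line feed.
    private static long findLastLineSize(FileChannel fc, long endNum) throws IOException {
        MappedByteBuffer buffer = fc.map(FileChannel.MapMode.READ_ONLY, endNum - 1, 11L);
        long num = 0;
        while (buffer.hasRemaining()) {
            byte b = buffer.get();
            if (b == LF) {
                return num;
            } else {
                num++;
            }
        }
        return num;
    }
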
    // Renders the sorted numbers into one big byte array in parallel, then copies that
    // array into the output file through memory-mapped regions, also in parallel.
    private static void writeNumbersToFileMultiThreaded(String filePath) throws IOException, InterruptedException {
        // READ is needed alongside WRITE in order to map regions in READ_WRITE mode.
        try (FileChannel fileChannel = FileChannel.open(Path.of(filePath), StandardOpenOption.READ, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
            // Total output size in bytes; it must fit in a single byte[] (below 2 GiB).
            long counter = CHARACTER_COUNTER.longValue();
            long threadNum = (counter + WRITE_BUFFER_SIZE - 1) / WRITE_BUFFER_SIZE;
            int threadCore = Runtime.getRuntime().availableProcessors();
            // How many numbers each formatting thread handles.
            int byteThreadBufferSize = (TOTAL_NUMBERS + threadCore - 1) / threadCore;
            byte[] allNumberByte = new byte[(int) counter];
            int[] byteThreadStartPointers = new int[threadCore];
            List<Thread> byteThreadList = new ArrayList<>(threadCore);
            int countForByteThreadEndPointers = 0;
            int pointerForByteThreadEndPointers = 0;
            // Pass 1 (single-threaded): find the byte offset where each thread's slice starts.
            for (int i = 0, j = allNumber.length; i < j; i++) {
                if (i > 0 && i % byteThreadBufferSize == 0) {
                    byteThreadStartPointers[++pointerForByteThreadEndPointers] = countForByteThreadEndPointers;
                }
                countForByteThreadEndPointers = countForByteThreadEndPointers + stringSize(allNumber[i]) + 1;
            }
            // Pass 2: each thread renders its slice of numbers as ASCII text at its own offset.
            for (int i = 0; i < threadCore; i++) {
                int start = i * byteThreadBufferSize;
                int end = (i == threadCore - 1) ? TOTAL_NUMBERS : (i + 1) * byteThreadBufferSize;
                int finalI = i;
                Thread thread = TF.newThread(() -> {
                    int count = byteThreadStartPointers[finalI];
                    for (int j = start; j < end; j++) {
                        int number = allNumber[j];
                        int digits = stringSize(number);
                        // getChars fills the digits backwards, ending just before index count.
                        count += digits;
                        getChars(number, count, allNumberByte);
                        allNumberByte[count++] = LF;
                    }
                });
                thread.start();
                byteThreadList.add(thread);
            }
            List<Thread> threadList = new ArrayList<>((int) threadNum);
            for (Thread thread : byteThreadList) {
                thread.join();
            }
            // Pass 3: copy the finished bytes into the file in 1 MiB mapped regions.
            for (int i = 0; i < threadNum; i++) {
                long position = (long) i * WRITE_BUFFER_SIZE;
                long size = (i == threadNum - 1) ? (counter - position) : WRITE_BUFFER_SIZE;
                MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, position, size);
                Thread thread = TF.newThread(() -> buffer.put(allNumberByte, (int) position, (int) size));
                thread.start();
                threadList.add(thread);
            }
            for (Thread thread : threadList) {
                thread.join();
            }
        }
    }

    static final byte[] DigitTens = {
        '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        '1', '1', '1', '1', '1', '1', '1', '1', '1', '1',
        '2', '2', '2', '2', '2', '2', '2', '2', '2', '2',
        '3', '3', '3', '3', '3', '3', '3', '3', '3', '3',
        '4', '4', '4', '4', '4', '4', '4', '4', '4', '4',
        '5', '5', '5', '5', '5', '5', '5', '5', '5', '5',
        '6', '6', '6', '6', '6', '6', '6', '6', '6', '6',
        '7', '7', '7', '7', '7', '7', '7', '7', '7', '7',
        '8', '8', '8', '8', '8', '8', '8', '8', '8', '8',
        '9', '9', '9', '9', '9', '9', '9', '9', '9', '9',
    };

    static final byte[] DigitOnes = {
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    };

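    // Writes the decimal digits of the non-negative int i into buf, backwards, so that the
    // last digit lands at buf[index - 1]. Works on -i and emits two digits per iteration
    // using the lookup tables above.
    static void getChars(int i, int index, byte[] buf) {
        int q, r;
        int charPos = index;
        i = -i;
        while (i <= -100) {
            q = i / 100;
            r = (q * 100) - i;
            i = q;
            buf[--charPos] = DigitOnes[r];
            buf[--charPos] = DigitTens[r];
        }
        // At most two digits are left at this point.
        buf[--charPos] = DigitOnes[-i];
        if (i < -9) {
            buf[--charPos] = DigitTens[-i];
        }
    }

    // Returns the number of decimal digits of the non-negative int x (1 to 10).
    private static int stringSize(int x) {
        int d = 0;
        x = -x;
        int p = -10;
        for (int i = 1; i < 10; i++) {
            if (x > p)
                return i + d;
            p = 10 * p;
        }
        return 10 + d;
    }
}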
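The last two methods, getChars and stringSize (together with the DigitTens and DigitOnes tables), are lifted from the JDK source so I didn't have to write them myself.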
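The write phase hinges on one idea: if the output length of every number is known in advance, each thread can be handed the exact byte offset where its slice begins and can format into a shared array without any locking. The following is a minimal sketch of that idea, not the code above. It assumes a hypothetical class `OffsetFormatSketch`, uses simplified helpers `digits` and `putDigits` in place of the JDK-derived `stringSize` and `getChars`, and uses plain platform threads rather than virtual threads.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical illustration of offset-based parallel formatting.
final class OffsetFormatSketch {

    // Number of decimal digits of a non-negative int.
    static int digits(int x) {
        int n = 1;
        while (x >= 10) {
            x /= 10;
            n++;
        }
        return n;
    }

    // Writes the digits of x backwards so that the last digit lands at buf[end - 1].
    static void putDigits(int x, int end, byte[] buf) {
        do {
            buf[--end] = (byte) ('0' + x % 10);
            x /= 10;
        } while (x > 0);
    }

    // Formats non-negative numbers as "n\n" lines using `workers` threads (workers >= 1).
    static byte[] format(int[] numbers, int workers) {
        int per = (numbers.length + workers - 1) / workers; // numbers per worker
        int[] startOffset = new int[workers];
        int total = 0;
        // Sequential pass: record the byte offset at which each worker's slice starts.
        for (int i = 0; i < numbers.length; i++) {
            if (i % per == 0) {
                startOffset[i / per] = total;
            }
            total += digits(numbers[i]) + 1; // digits plus one line feed
        }
        byte[] out = new byte[total];
        Thread[] threads = new Thread[workers];
        for (int w = 0; w < workers; w++) {
            int from = Math.min(w * per, numbers.length);
            int to = Math.min(from + per, numbers.length);
            int offset = startOffset[w];
            // Slices are disjoint, so the workers never write to the same bytes.
            threads[w] = new Thread(() -> {
                int pos = offset;
                for (int j = from; j < to; j++) {
                    pos += digits(numbers[j]);
                    putDigits(numbers[j], pos, out);
                    out[pos++] = '\n';
                }
            });
            threads[w].start();
        }
        try {
            for (Thread t : threads) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while formatting", e);
        }
        return out;
    }

    public static void main(String[] args) {
        byte[] bytes = format(new int[] {3, 42, 100, 2147483647}, 3);
        System.out.print(new String(bytes, StandardCharsets.US_ASCII));
    }
}
```

The sequential offset pass costs one extra sweep over the array, which is the price of letting the formatting threads run without coordination.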
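Running it
I tested with both JDK 21 and JDK 23; JDK 23 was slightly faster.
Test environment (the program prints its timings in milliseconds):
- CPU: 5600G
- Memory: 16 GB
- Disk: PCIe 3.0 SSD
- OS: Ubuntu 24.04.1 LTS
- JDK: OpenJDK Runtime Environment Zulu23.28+85-CA (build 23+37)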
time java SortFile.java 1e8.txt
readTime = 856
sortTime = 1150
writeTime = 678
allTime = 2684
real 0m2.946s
user 0m20.174s
sys 0m3.709s
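Curiously, after building a native image with GraalVM, the speedup became a Schrödinger's improvement: sometimes it is there, sometimes it isn't.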
time ./sortfile 1e8.txt
readTime = 1184
sortTime = 930
writeTime = 583
allTime = 2697
real 0m2.840s
user 0m17.030s
sys 0m3.841s
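Compared with running on the JVM, reading became considerably slower for reasons I don't understand, while sorting and writing got noticeably faster; in some runs the native image even comes from behind to beat the JVM.
If anyone who knows this area can explain it, I'd appreciate it.
Wrapping up
I can't match the experts' instruction-level tuning, but using nothing beyond what the platform provides out of the box still gets a respectable speed, and I'm fairly happy with that.
Because the task specifies "positive integers" and exactly "100 million" of them, the code takes a number of shortcuts. If the task changes, the code has to change with it. Good enough for now.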
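To make one of those shortcuts concrete, here is a minimal sketch of the digit-accumulating parse, written against a plain `byte[]` rather than a `MappedByteBuffer`. The class `DigitParseSketch` and the method `parse` are hypothetical names for illustration. The sketch assumes ASCII digits, values that fit in an `int`, and no sign characters; every byte that is neither a digit nor a line feed is skipped.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Hypothetical illustration of a parser that only understands digits and line feeds.
final class DigitParseSketch {

    static int[] parse(byte[] data) {
        int[] out = new int[16];
        int size = 0;
        int value = 0;
        boolean inNumber = false;
        for (byte b : data) {
            if (b >= '0' && b <= '9') {
                // Accumulate the next decimal digit.
                value = value * 10 + (b - '0');
                inNumber = true;
            } else if (b == '\n' && inNumber) {
                // A line feed ends the current number.
                if (size == out.length) {
                    out = Arrays.copyOf(out, size * 2);
                }
                out[size++] = value;
                value = 0;
                inNumber = false;
            }
            // Any other byte, such as '\r' or '-', is silently ignored.
        }
        if (inNumber) {
            // Last number when the input has no trailing line feed.
            if (size == out.length) {
                out = Arrays.copyOf(out, size * 2);
            }
            out[size++] = value;
        }
        return Arrays.copyOf(out, size);
    }

    public static void main(String[] args) {
        byte[] input = "12\r\n7\n2147483647".getBytes(StandardCharsets.US_ASCII);
        System.out.println(Arrays.toString(parse(input)));
        // The sign is dropped, which is why this only suits non-negative input.
        System.out.println(Arrays.toString(parse("-5\n".getBytes(StandardCharsets.US_ASCII))));
    }
}
```

The second call in `main` shows the limitation: a line containing `-5` is read as 5, so supporting negative numbers would need explicit sign handling in both the parser and the formatter.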
https://halvamazd.lv6.fun/posts/1e8-positive-integer-sort-java/
