diff --git a/prepare_abeobk.sh b/prepare_abeobk.sh index 08a8afdcb..380e2093c 100755 --- a/prepare_abeobk.sh +++ b/prepare_abeobk.sh @@ -20,6 +20,6 @@ sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_abeobk_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -H:InlineAllBonus=10 -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -H:InlineAllBonus=10 -H:-GenLoopSafepoints --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java index 2340bca79..88de5d2a9 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java @@ -34,7 +34,6 @@ import sun.misc.Unsafe; public class CalculateAverage_abeobk { - private static final boolean SHOW_ANALYSIS = false; private static final int CPU_CNT = Runtime.getRuntime().availableProcessors(); private static final String FILE = "./measurements.txt"; @@ -42,7 +41,7 @@ public class CalculateAverage_abeobk { private static final long BUCKET_MASK = BUCKET_SIZE - 1; private static final int MAX_STR_LEN = 100; private static final int MAX_STATIONS = 10000; - private static final long CHUNK_SZ = 1 << 22; // 4MB chunk + private static final long CHUNK_SZ = 1 << 22; private static final Unsafe UNSAFE = initUnsafe(); private static final long[] HASH_MASKS = new long[]{ 0x0L, @@ -60,10 +59,6 @@ public class CalculateAverage_abeobk { private static int chunk_cnt; private static long start_addr, end_addr; - private static final void debug(String s, Object... args) { - System.out.println(String.format(s, args)); - } - private static Unsafe initUnsafe() { try { Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); @@ -75,12 +70,117 @@ private static Unsafe initUnsafe() { } } - // use native type, less conversion - static class Node { + /* + * MAIN FUNCTION + */ + public static void main(String[] args) throws InterruptedException, IOException { + // thomaswue trick + if (args.length == 0 || !("--worker".equals(args[0]))) { + spawnWorker(); + return; + } + + var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); + long file_size = file.size(); + start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); + end_addr = start_addr + file_size; + + // only use all cpus on large file + int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT; + chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ); + + // spawn workers + for (var w : IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList()) { + w.join(); + } + + // collect results + TreeMap ms = new TreeMap<>(); + for (var crr : mapref.get()) { + if (crr == null) + continue; + var prev = ms.putIfAbsent(crr.key(), crr); + if (prev != null) + prev.merge(crr); + } + // print result + System.out.println(ms); + System.out.close(); + } + + /* + * HELPER FUNCTIONS + */ + + // Get semicolon pos code + static final long getSemiCode(final long w) { + long x = w ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; + return (x - 0x0101010101010101L) & (~x & 0x8080808080808080L); + } + + // Get new line pos code + static final long getLFCode(final long w) { + long x = w ^ 0x0A0A0A0A0A0A0A0AL; // xor with \n\n\n\n\n\n\n\n + return (x - 0x0101010101010101L) & (~x & 0x8080808080808080L); + } + + // Get decimal point pos code + static final int getDotCode(final long w) { + return Long.numberOfTrailingZeros(~w & 0x10101000); + } + + // Convert semicolon pos code to position + static final int getSemiPos(final long spc) { + return Long.numberOfTrailingZeros(spc) >>> 3; + } + + // Find next line address + static final long nextLF(long addr) { + long word = UNSAFE.getLong(addr); + long lfpos_code = getLFCode(word); + while (lfpos_code == 0) { + addr += 8; + word = UNSAFE.getLong(addr); + lfpos_code = getLFCode(word); + } + return addr + (Long.numberOfTrailingZeros(lfpos_code) >>> 3) + 1; + } + + // Parse number + // great idea from merykitty (Quan Anh Mai) + static final long num(long w, int d) { + int shift = 28 - d; + long signed = (~w << 59) >> 63; + long dsmask = ~(signed & 0xFF); + long digits = ((w & dsmask) << shift) & 0x0F000F0F00L; + long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; + return ((abs_val ^ signed) - signed); + } + + // Hash mixer + static final long mix(long hash) { + long h = hash * 37; + return (h ^ (h >>> 29)); + } + + // Spawn worker (thomaswue trick + private static void spawnWorker() throws IOException { + ProcessHandle.Info info = ProcessHandle.current().info(); + ArrayList workerCommand = new ArrayList<>(); + info.command().ifPresent(workerCommand::add); + info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args))); + workerCommand.add("--worker"); + new ProcessBuilder() + .command(workerCommand) + .start() + .getInputStream() + .transferTo(System.out); + } + + final static class Node { long addr; long hash; long word0; - long tail; long sum; long min, max; int keylen; @@ -98,23 +198,36 @@ final String key() { return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8); } - Node(long a, long t, int kl, long h) { + Node(long a, long h, int kl, long v) { + addr = a; + min = max = v; + keylen = kl; + hash = h; + } + + Node(long a, long h, int kl) { addr = a; - tail = t; + hash = h; min = 999; max = -999; keylen = kl; + } + + Node(long a, long w0, long h, int kl, long v) { + addr = a; + word0 = w0; hash = h; + min = max = v; + keylen = kl; } - Node(long a, long w0, long t, int kl, long h) { + Node(long a, long w0, long h, int kl) { addr = a; word0 = w0; + hash = h; min = 999; max = -999; - tail = t; keylen = kl; - hash = h; } final void add(long val) { @@ -139,8 +252,8 @@ final void merge(Node other) { } } - final boolean contentEquals(long other_addr, long other_word0, long other_tail, long kl) { - if (word0 != other_word0 || tail != other_tail) + final boolean contentEquals(long other_addr, long other_word0, long other_hash, long kl) { + if (word0 != other_word0 || hash != other_hash) return false; // this is faster than comparision if key is short long xsum = 0; @@ -152,7 +265,7 @@ final boolean contentEquals(long other_addr, long other_word0, long other_tail, } final boolean contentEquals(Node other) { - if (tail != other.tail) + if (hash != other.hash) return false; long n = keylen & 0xF8; for (long i = 0; i < n; i += 8) { @@ -163,150 +276,13 @@ final boolean contentEquals(Node other) { } } - // idea from royvanrijn - static final long getSemiPosCode(final long word) { - long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; - return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); - } - - static final long getLFCode(final long word) { - long xor_semi = word ^ 0x0A0A0A0A0A0A0A0AL; // xor with \n\n\n\n\n\n\n\n - return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); - } - - static final long nextLine(long addr) { - long word = UNSAFE.getLong(addr); - long lfpos_code = getLFCode(word); - while (lfpos_code == 0) { - addr += 8; - word = UNSAFE.getLong(addr); - lfpos_code = getLFCode(word); - } - return addr + (Long.numberOfTrailingZeros(lfpos_code) >>> 3) + 1; - } - - // speed/collision balance - static final long xxh32(long hash) { - long h = hash * 37; - return (h ^ (h >>> 29)); - } - - static final class ChunkParser { - long addr; - long end; - Node[] map; - - ChunkParser(Node[] m, long a, long e) { - map = m; - addr = a; - end = e; - } - - final boolean ok() { - return addr < end; - } - - final long word() { - return UNSAFE.getLong(addr); - } - - final long val() { - long num_word = UNSAFE.getLong(addr); - int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 3; - // great idea from merykitty (Quan Anh Mai) - int shift = 28 - dot_pos; - long signed = (~num_word << 59) >> 63; - long dsmask = ~(signed & 0xFF); - long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; - long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; - return ((abs_val ^ signed) - signed); - } - - // optimize for contest - // save as much slow memory access as possible - // about 50% key < 8chars, 25% key bettween 8-10 chars - // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... - final Node key(long word0, long semipos_code) { - long row_addr = addr; - // about 50% chance key < 8 chars - if (semipos_code != 0) { - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos + 1; - long tail = word0 & HASH_MASKS[semi_pos]; - long hash = xxh32(tail); - int bucket = (int) (hash & BUCKET_MASK); - while (true) { - Node node = map[bucket]; - if (node == null) { - return (map[bucket] = new Node(row_addr, tail, semi_pos, hash)); - } - if (node.tail == tail) { - return node; - } - bucket++; - } - } - - addr += 8; - long word = UNSAFE.getLong(addr); - semipos_code = getSemiPosCode(word); - // 43% chance - if (semipos_code != 0) { - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos + 1; - long tail = (word & HASH_MASKS[semi_pos]); - long hash = xxh32(word0 ^ tail); - int bucket = (int) (hash & BUCKET_MASK); - while (true) { - Node node = map[bucket]; - if (node == null) { - return (map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash)); - } - if (node.word0 == word0 && node.tail == tail) { - return node; - } - bucket++; - } - } - - // why not going for more? tested, slower - long hash = word0; - while (semipos_code == 0) { - hash ^= word; - addr += 8; - word = UNSAFE.getLong(addr); - semipos_code = getSemiPosCode(word); - } - - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos; - long keylen = addr - row_addr; - addr++; - long tail = (word & HASH_MASKS[semi_pos]); - hash = xxh32(hash ^ tail); - int bucket = (int) (hash & BUCKET_MASK); - - while (true) { - Node node = map[bucket]; - if (node == null) { - return (map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash)); - } - if (node.contentEquals(row_addr, word0, tail, keylen)) { - return node; - } - bucket++; - } - } - } - // Thread pool worker static final class Worker extends Thread { final int thread_id; // for debug use only - int cls = 0; Worker(int i) { thread_id = i; + this.setPriority(Thread.MAX_PRIORITY); this.start(); } @@ -322,15 +298,15 @@ public void run() { // find start of line if (id > 0) { - addr = nextLine(addr); + addr = nextLF(addr); } final int num_segs = 3; long seglen = (end - addr) / num_segs; long a0 = addr; - long a1 = nextLine(addr + 1 * seglen); - long a2 = nextLine(addr + 2 * seglen); + long a1 = nextLF(addr + 1 * seglen); + long a2 = nextLF(addr + 2 * seglen); ChunkParser p0 = new ChunkParser(map, a0, a1); ChunkParser p1 = new ChunkParser(map, a1, a2); ChunkParser p2 = new ChunkParser(map, a2, end); @@ -339,9 +315,9 @@ public void run() { long w0 = p0.word(); long w1 = p1.word(); long w2 = p2.word(); - long sc0 = getSemiPosCode(w0); - long sc1 = getSemiPosCode(w1); - long sc2 = getSemiPosCode(w2); + long sc0 = getSemiCode(w0); + long sc1 = getSemiCode(w1); + long sc2 = getSemiCode(w2); Node n0 = p0.key(w0, sc0); Node n1 = p1.key(w1, sc1); Node n2 = p2.key(w2, sc2); @@ -355,21 +331,21 @@ public void run() { while (p0.ok()) { long w = p0.word(); - long sc = getSemiPosCode(w); + long sc = getSemiCode(w); Node n = p0.key(w, sc); long v = p0.val(); n.add(v); } while (p1.ok()) { long w = p1.word(); - long sc = getSemiPosCode(w); + long sc = getSemiCode(w); Node n = p1.key(w, sc); long v = p1.val(); n.add(v); } while (p2.ok()) { long w = p2.word(); - long sc = getSemiPosCode(w); + long sc = getSemiCode(w); Node n = p2.key(w, sc); long v = p2.val(); n.add(v); @@ -396,65 +372,127 @@ public void run() { break; } bucket++; - if (SHOW_ANALYSIS) - cls++; } } } } - - if (SHOW_ANALYSIS) { - debug("Thread %d collision = %d", thread_id, cls); - } } } - // thomaswue trick - private static void spawnWorker() throws IOException { - ProcessHandle.Info info = ProcessHandle.current().info(); - ArrayList workerCommand = new ArrayList<>(); - info.command().ifPresent(workerCommand::add); - info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args))); - workerCommand.add("--worker"); - new ProcessBuilder() - .command(workerCommand) - .start() - .getInputStream() - .transferTo(System.out); - } + static final class ChunkParser { + long addr; + long end; + Node[] map; - public static void main(String[] args) throws InterruptedException, IOException { - // thomaswue trick - if (args.length == 0 || !("--worker".equals(args[0]))) { - spawnWorker(); - return; + ChunkParser(Node[] m, long a, long e) { + map = m; + addr = a; + end = e; } - var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); - long file_size = file.size(); - start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); - end_addr = start_addr + file_size; + final boolean ok() { + return addr < end; + } - // only use all cpus on large file - int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT; - chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ); + final long word() { + return UNSAFE.getLong(addr); + } - // spawn workers - for (var w : IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList()) { - w.join(); + final void skip(int n) { + addr += n; } - // collect results - TreeMap ms = new TreeMap<>(); - for (var crr : mapref.get()) { - if (crr == null) - continue; - var prev = ms.putIfAbsent(crr.key(), crr); - if (prev != null) - prev.merge(crr); + final void skip(long n) { + addr += n; + } + + final long val0() { + long w = word(); + int d = getDotCode(w); + return num(w, d); + } + + final long val() { + long w = word(); + int d = getDotCode(w); + skip((d >>> 3) + 3); + return num(w, d); + } + + // optimize for contest + // save as much slow memory access as possible + // about 50% key < 8chars, 25% key bettween 8-10 chars + // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... + final Node key(long word0, long semipos_code) { + long row_addr = addr; + // about 50% chance key < 8 chars + if (semipos_code != 0) { + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + skip(semi_pos + 1); + long tail = word0 & HASH_MASKS[semi_pos]; + long hash = mix(tail); + int bucket = (int) (hash & BUCKET_MASK); + while (true) { + Node node = map[bucket]; + if (node == null) { + return (map[bucket] = new Node(row_addr, hash, semi_pos)); + } + if (node.hash == hash) { + return node; + } + bucket++; + } + } + + skip(8); + long word = UNSAFE.getLong(addr); + semipos_code = getSemiCode(word); + // 43% chance + if (semipos_code != 0) { + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + skip(semi_pos + 1); + long tail = word0 ^ (word & HASH_MASKS[semi_pos]); + long hash = mix(tail); + int bucket = (int) (hash & BUCKET_MASK); + while (true) { + Node node = map[bucket]; + if (node == null) { + return (map[bucket] = new Node(row_addr, word0, hash, semi_pos + 8)); + } + if (node.word0 == word0 && node.hash == hash) { + return node; + } + bucket++; + } + } + + // why not going for more? tested, slower + long hash = word0; + while (semipos_code == 0) { + hash ^= word; + skip(8); + word = UNSAFE.getLong(addr); + semipos_code = getSemiCode(word); + } + + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + skip(semi_pos); + long keylen = addr - row_addr; + skip(1); + long tail = hash ^ (word & HASH_MASKS[semi_pos]); + hash = mix(tail); + int bucket = (int) (hash & BUCKET_MASK); + + while (true) { + Node node = map[bucket]; + if (node == null) { + return (map[bucket] = new Node(row_addr, word0, hash, (int) keylen)); + } + if (node.contentEquals(row_addr, word0, hash, keylen)) { + return node; + } + bucket++; + } } - // print result - System.out.println(ms); - System.out.close(); } } \ No newline at end of file