-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Optimize String write #1651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize String write #1651
Changes from 11 commits
d0f68c4
73abcc9
577f6bf
502f84b
879ddbd
759381d
e29638d
8ba1940
3980457
d9cf649
3ff0644
43f1663
dd5fe4d
e4f6f31
db424d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
|
||
package com.mongodb.internal.connection; | ||
|
||
import org.bson.BsonSerializationException; | ||
import org.bson.ByteBuf; | ||
import org.bson.io.OutputBuffer; | ||
|
||
|
@@ -25,8 +26,10 @@ | |
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import static com.mongodb.assertions.Assertions.assertFalse; | ||
import static com.mongodb.assertions.Assertions.assertTrue; | ||
import static com.mongodb.assertions.Assertions.notNull; | ||
import static java.lang.String.format; | ||
|
||
/** | ||
* <p>This class is not part of the public API and may be removed or changed at any time</p> | ||
|
@@ -100,11 +103,17 @@ private ByteBuf getCurrentByteBuffer() { | |
return getByteBufferAtIndex(curBufferIndex); | ||
} | ||
|
||
private ByteBuf getNextByteBuffer() { | ||
assertFalse(bufferList.get(curBufferIndex).hasRemaining()); | ||
return getByteBufferAtIndex(++curBufferIndex); | ||
} | ||
|
||
private ByteBuf getByteBufferAtIndex(final int index) { | ||
if (bufferList.size() < index + 1) { | ||
bufferList.add(bufferProvider.getBuffer(index >= (MAX_SHIFT - INITIAL_SHIFT) | ||
? MAX_BUFFER_SIZE | ||
: Math.min(INITIAL_BUFFER_SIZE << index, MAX_BUFFER_SIZE))); | ||
ByteBuf buffer = bufferProvider.getBuffer(index >= (MAX_SHIFT - INITIAL_SHIFT) | ||
? MAX_BUFFER_SIZE | ||
: Math.min(INITIAL_BUFFER_SIZE << index, MAX_BUFFER_SIZE)); | ||
bufferList.add(buffer); | ||
} | ||
return bufferList.get(index); | ||
} | ||
|
@@ -147,6 +156,16 @@ public List<ByteBuf> getByteBuffers() { | |
return buffers; | ||
} | ||
|
||
public List<ByteBuf> getDuplicateByteBuffers() { | ||
ensureOpen(); | ||
|
||
List<ByteBuf> buffers = new ArrayList<>(bufferList.size()); | ||
for (final ByteBuf cur : bufferList) { | ||
buffers.add(cur.duplicate().order(ByteOrder.LITTLE_ENDIAN)); | ||
} | ||
return buffers; | ||
} | ||
|
||
|
||
@Override | ||
public int pipe(final OutputStream out) throws IOException { | ||
|
@@ -155,14 +174,18 @@ public int pipe(final OutputStream out) throws IOException { | |
byte[] tmp = new byte[INITIAL_BUFFER_SIZE]; | ||
|
||
int total = 0; | ||
for (final ByteBuf cur : getByteBuffers()) { | ||
ByteBuf dup = cur.duplicate(); | ||
while (dup.hasRemaining()) { | ||
int numBytesToCopy = Math.min(dup.remaining(), tmp.length); | ||
dup.get(tmp, 0, numBytesToCopy); | ||
out.write(tmp, 0, numBytesToCopy); | ||
List<ByteBuf> byteBuffers = getByteBuffers(); | ||
try { | ||
for (final ByteBuf cur : byteBuffers) { | ||
while (cur.hasRemaining()) { | ||
int numBytesToCopy = Math.min(cur.remaining(), tmp.length); | ||
cur.get(tmp, 0, numBytesToCopy); | ||
out.write(tmp, 0, numBytesToCopy); | ||
} | ||
total += cur.limit(); | ||
} | ||
total += dup.limit(); | ||
} finally { | ||
byteBuffers.forEach(ByteBuf::release); | ||
rozza marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
return total; | ||
} | ||
|
@@ -282,4 +305,165 @@ private static final class BufferPositionPair { | |
this.position = position; | ||
} | ||
} | ||
|
||
protected int writeCharacters(final String str, final boolean checkNullTermination) { | ||
rozza marked this conversation as resolved.
Show resolved
Hide resolved
|
||
int stringLength = str.length(); | ||
int sp = 0; | ||
int prevPos = position; | ||
|
||
ByteBuf curBuffer = getCurrentByteBuffer(); | ||
int curBufferPos = curBuffer.position(); | ||
int curBufferLimit = curBuffer.limit(); | ||
int remaining = curBufferLimit - curBufferPos; | ||
|
||
if (curBuffer.hasArray()) { | ||
byte[] dst = curBuffer.array(); | ||
int arrayOffset = curBuffer.arrayOffset(); | ||
if (remaining >= str.length() + 1) { | ||
// Write ASCII characters directly to the array until we hit a non-ASCII character. | ||
sp = writeOnArrayAscii(str, dst, arrayOffset + curBufferPos, checkNullTermination); | ||
curBufferPos += sp; | ||
// If the whole string was written as ASCII, append the null terminator. | ||
if (sp == stringLength) { | ||
dst[arrayOffset + curBufferPos++] = 0; | ||
position += sp + 1; | ||
curBuffer.position(curBufferPos); | ||
return sp + 1; | ||
} | ||
// Otherwise, update the position to reflect the partial write. | ||
position += sp; | ||
curBuffer.position(curBufferPos); | ||
} | ||
} | ||
|
||
// We get here, when the buffer is not backed by an array, or when the string contains at least one non-ASCII characters. | ||
return writeOnBuffers(str, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we have within this a fast PATH for ASCII too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I’d expect the fast path for buffers to be in the Are you suggesting we add a fast path similar to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yep, since I see the ascii path there is already taking care to change the buffer to write against instead of performing a lookup per each byte to write. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it makes sense. Thanks for the suggestion. I implemented The implementation of writeOnBuffersAsccii
I printed the assembly to see the JIT’s behavior. It looks like the loop wasn’t unrolled by JIT - there’s only one I’ve shared a GitHub Gist with the shortened assembly (keeping the key parts) and a pseudo-Java interpretation to show how the assembly might map back to the logic: Gist. Local perf showed modest gains likely limited by dynamic buffer allocation, as you noted. I’ll run more tests on a dedicated perf instance to confirm. If I missed anything in the assembly, please let me know! I’m merging this PR for the current improvements, but I agree tighter loops or manual unrolling could be further explored, keeping in mind the maintainability trade-off. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I cannot see the assembly there , but the not decoded binary instead - did you miss the https://blogs.oracle.com/javamagazine/post/java-hotspot-hsdis-disassembler There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean something which uses compile command too as TechEmpower/FrameworkBenchmarks#9800 (comment) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You’re right - my earlier Gist showed raw hex. I recompiled on Oracle JDK 17.0.7 with The main loop (0x0000000113a12c40–0x0000000113a12d3c) seems to have no unrolling:
A second There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mmm It looks to me that the unrolling was having a factor of two, I will re read it again since I am more used to x86 asm :) Anyway, I suggest to look at the PR I sent for this same optimization: having the check for remaining buffer space in the loop would bloat the loop body, reducing the chances that C2 will unroll it many times. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another reason why the Netty version loop body is too fat is because the JIT doesn't trust final fields and since you can get a new mongo ByteBuf at each iteration, the required amount of pointer chase (mongo buf, Netty swapped big, Netty buf, Unsafe..) is way to much; this prevent massively to have unrolling. |
||
checkNullTermination, | ||
sp, | ||
stringLength, | ||
curBufferLimit, | ||
curBufferPos, | ||
curBuffer, | ||
prevPos); | ||
} | ||
|
||
private int writeOnBuffers(final String str, | ||
final boolean checkNullTermination, | ||
final int stringPointer, | ||
final int stringLength, | ||
final int bufferLimit, | ||
final int bufferPos, | ||
final ByteBuf buffer, | ||
final int prevPos) { | ||
int remaining; | ||
int sp = stringPointer; | ||
int curBufferPos = bufferPos; | ||
int curBufferLimit = bufferLimit; | ||
ByteBuf curBuffer = buffer; | ||
while (sp < stringLength) { | ||
remaining = curBufferLimit - curBufferPos; | ||
int c = str.charAt(sp); | ||
|
||
if (checkNullTermination && c == 0x0) { | ||
throw new BsonSerializationException( | ||
format("BSON cstring '%s' is not valid because it contains a null character " + "at index %d", str, sp)); | ||
} | ||
|
||
if (c < 0x80) { | ||
if (remaining == 0) { | ||
curBuffer = getNextByteBuffer(); | ||
curBufferPos = 0; | ||
curBufferLimit = curBuffer.limit(); | ||
} | ||
curBuffer.put((byte) c); | ||
curBufferPos++; | ||
position++; | ||
} else if (c < 0x800) { | ||
if (remaining < 2) { | ||
// Not enough space: use write() to handle buffer boundary | ||
write((byte) (0xc0 + (c >> 6))); | ||
write((byte) (0x80 + (c & 0x3f))); | ||
|
||
curBuffer = getCurrentByteBuffer(); | ||
curBufferPos = curBuffer.position(); | ||
curBufferLimit = curBuffer.limit(); | ||
} else { | ||
curBuffer.put((byte) (0xc0 + (c >> 6))); | ||
curBuffer.put((byte) (0x80 + (c & 0x3f))); | ||
curBufferPos += 2; | ||
position += 2; | ||
} | ||
} else { | ||
// Handle multibyte characters (may involve surrogate pairs). | ||
c = Character.codePointAt(str, sp); | ||
/* | ||
Malformed surrogate pairs are encoded as-is (3 byte code unit) without substituting any code point. | ||
This known deviation from the spec and current functionality remains for backward compatibility. | ||
Ticket: JAVA-5575 | ||
*/ | ||
if (c < 0x10000) { | ||
if (remaining < 3) { | ||
write((byte) (0xe0 + (c >> 12))); | ||
write((byte) (0x80 + ((c >> 6) & 0x3f))); | ||
write((byte) (0x80 + (c & 0x3f))); | ||
|
||
curBuffer = getCurrentByteBuffer(); | ||
curBufferPos = curBuffer.position(); | ||
curBufferLimit = curBuffer.limit(); | ||
} else { | ||
curBuffer.put((byte) (0xe0 + (c >> 12))); | ||
curBuffer.put((byte) (0x80 + ((c >> 6) & 0x3f))); | ||
curBuffer.put((byte) (0x80 + (c & 0x3f))); | ||
curBufferPos += 3; | ||
position += 3; | ||
} | ||
} else { | ||
if (remaining < 4) { | ||
write((byte) (0xf0 + (c >> 18))); | ||
write((byte) (0x80 + ((c >> 12) & 0x3f))); | ||
write((byte) (0x80 + ((c >> 6) & 0x3f))); | ||
write((byte) (0x80 + (c & 0x3f))); | ||
|
||
curBuffer = getCurrentByteBuffer(); | ||
curBufferPos = curBuffer.position(); | ||
curBufferLimit = curBuffer.limit(); | ||
} else { | ||
curBuffer.put((byte) (0xf0 + (c >> 18))); | ||
curBuffer.put((byte) (0x80 + ((c >> 12) & 0x3f))); | ||
curBuffer.put((byte) (0x80 + ((c >> 6) & 0x3f))); | ||
curBuffer.put((byte) (0x80 + (c & 0x3f))); | ||
curBufferPos += 4; | ||
position += 4; | ||
} | ||
} | ||
} | ||
sp += Character.charCount(c); | ||
} | ||
|
||
getCurrentByteBuffer().put((byte) 0); | ||
position++; | ||
return position - prevPos; | ||
} | ||
|
||
private static int writeOnArrayAscii(final String str, | ||
final byte[] dst, | ||
final int arrayPosition, | ||
final boolean checkNullTermination) { | ||
int pos = arrayPosition; | ||
int sp = 0; | ||
// Fast common path: This tight loop is JIT-friendly (simple, no calls, few branches), | ||
// It might be unrolled for performance. | ||
for (; sp < str.length(); sp++, pos++) { | ||
char c = str.charAt(sp); | ||
if (checkNullTermination && c == 0) { | ||
throw new BsonSerializationException( | ||
format("BSON cstring '%s' is not valid because it contains a null character " + "at index %d", str, sp)); | ||
} | ||
if (c >= 0x80) { | ||
break; | ||
} | ||
dst[pos] = (byte) c; | ||
} | ||
return sp; | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.