Skip to content

Fix Trailer based Http Checksum for Async Request body with variable chunk size #3380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSSDKforJavav2-5ecdce1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Fixed issue where request used to fail while calculating Trailer based checksum for Async File Request body."
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.createChunk;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import org.reactivestreams.Subscriber;
Expand All @@ -41,12 +43,16 @@
@SdkInternalApi
public class ChecksumCalculatingAsyncRequestBody implements AsyncRequestBody {

public static final int DEFAULT_CHUNK_SIZE = 16 * 1024;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of making it a public member in this class, can we create a constant class to put all checksum related chunk size constants?


public static final byte[] FINAL_BYTE = new byte[0];
private final AsyncRequestBody wrapped;
private final SdkChecksum sdkChecksum;
private final Algorithm algorithm;
private final String trailerHeader;
private final AtomicLong remainingBytes;
private final long totalBytes;
private final ByteBuffer currentBuffer;

private ChecksumCalculatingAsyncRequestBody(DefaultBuilder builder) {

Expand All @@ -57,8 +63,10 @@ private ChecksumCalculatingAsyncRequestBody(DefaultBuilder builder) {
this.algorithm = builder.algorithm;
this.sdkChecksum = builder.algorithm != null ? SdkChecksum.forAlgorithm(algorithm) : null;
this.trailerHeader = builder.trailerHeader;
this.remainingBytes = new AtomicLong(wrapped.contentLength()
.orElseThrow(() -> new UnsupportedOperationException("Content length must be supplied.")));
this.totalBytes = wrapped.contentLength()
.orElseThrow(() -> new UnsupportedOperationException("Content length must be supplied."));
this.remainingBytes = new AtomicLong();
this.currentBuffer = ByteBuffer.allocate(DEFAULT_CHUNK_SIZE);
}

/**
Expand Down Expand Up @@ -148,7 +156,11 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
if (sdkChecksum != null) {
sdkChecksum.reset();
}
wrapped.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum, trailerHeader, remainingBytes));

this.remainingBytes.set(totalBytes);

wrapped.flatMapIterable(this::bufferAndCreateChunks)
.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum, trailerHeader, totalBytes));
}

private static final class ChecksumCalculatingSubscriber implements Subscriber<ByteBuffer> {
Expand All @@ -162,11 +174,11 @@ private static final class ChecksumCalculatingSubscriber implements Subscriber<B

ChecksumCalculatingSubscriber(Subscriber<? super ByteBuffer> wrapped,
SdkChecksum checksum,
String trailerHeader, AtomicLong remainingBytes) {
String trailerHeader, long totalBytes) {
this.wrapped = wrapped;
this.checksum = checksum;
this.trailerHeader = trailerHeader;
this.remainingBytes = remainingBytes;
this.remainingBytes = new AtomicLong(totalBytes);
}

@Override
Expand All @@ -189,7 +201,8 @@ public void onNext(ByteBuffer byteBuffer) {
ByteBuffer allocatedBuffer = getFinalChecksumAppendedChunk(byteBuffer);
wrapped.onNext(allocatedBuffer);
} else {
wrapped.onNext(byteBuffer);
ByteBuffer allocatedBuffer = createChunk(byteBuffer, false);
wrapped.onNext(allocatedBuffer);
}
} catch (SdkException sdkException) {
this.subscription.cancel();
Expand Down Expand Up @@ -225,4 +238,48 @@ public void onComplete() {
wrapped.onComplete();
}
}

private Iterable<ByteBuffer> bufferAndCreateChunks(ByteBuffer buffer) {
int startPosition = 0;
int currentBytesRead = buffer.remaining();

List<ByteBuffer> resultBufferedList = new ArrayList<>();
do {
int bufferedBytes = currentBuffer.position();
int availableToRead = DEFAULT_CHUNK_SIZE - bufferedBytes;
int bytesToMove = Math.min(availableToRead, currentBytesRead - startPosition);

if (bufferedBytes == 0) {
currentBuffer.put(buffer.array(), startPosition, bytesToMove);
} else {
currentBuffer.put(buffer.array(), 0, bytesToMove);
}

startPosition = startPosition + bytesToMove;

// Send the data once the buffer is full
if (currentBuffer.position() == DEFAULT_CHUNK_SIZE) {
currentBuffer.position(0);
ByteBuffer bufferToSend = ByteBuffer.allocate(DEFAULT_CHUNK_SIZE);
bufferToSend.put(currentBuffer.array(), 0, DEFAULT_CHUNK_SIZE);
bufferToSend.clear();
currentBuffer.clear();
resultBufferedList.add(bufferToSend);
remainingBytes.addAndGet(-DEFAULT_CHUNK_SIZE);
}

} while (startPosition < currentBytesRead);

int bufferedBytes = currentBuffer.position();
// Send the remainder buffered bytes at the end when there no more bytes
if (bufferedBytes > 0 && remainingBytes.get() == bufferedBytes) {
currentBuffer.clear();
ByteBuffer trimmedBuffer = ByteBuffer.allocate(bufferedBytes);
trimmedBuffer.put(currentBuffer.array(), 0, bufferedBytes);
trimmedBuffer.clear();
resultBufferedList.add(trimmedBuffer);
remainingBytes.addAndGet(-bufferedBytes);
}
return resultBufferedList;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

package software.amazon.awssdk.core.internal.interceptor;

import static software.amazon.awssdk.core.internal.io.AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength;

import java.util.Optional;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.ClientType;
Expand Down Expand Up @@ -97,7 +99,9 @@ public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context,
private static SdkHttpRequest updateHeadersForTrailerChecksum(Context.ModifyHttpRequest context, ChecksumSpecs checksum,
long checksumContentLength, long originalContentLength) {

long chunkLength = ChunkContentUtils.calculateChunkLength(originalContentLength);
long chunkLength =
calculateStreamContentLength(originalContentLength, ChecksumCalculatingAsyncRequestBody.DEFAULT_CHUNK_SIZE);

return context.httpRequest().copy(r ->
r.putHeader(HttpChecksumConstant.HEADER_FOR_TRAILER_REFERENCE, checksum.headerName())
.putHeader("Content-encoding", HttpChecksumConstant.AWS_CHUNKED_HEADER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.internal.io.AwsChunkedEncodingInputStream;
import software.amazon.awssdk.core.internal.io.AwsUnsignedChunkedEncodingInputStream;
import software.amazon.awssdk.core.internal.util.HttpChecksumResolver;
import software.amazon.awssdk.core.internal.util.HttpChecksumUtils;
Expand Down Expand Up @@ -73,7 +74,7 @@ public Optional<RequestBody> modifyHttpContent(Context.ModifyHttpRequest context
RequestBody.fromContentProvider(
streamProvider,
AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength(
requestBody.optionalContentLength().orElse(0L))
requestBody.optionalContentLength().orElse(0L), AwsChunkedEncodingInputStream.DEFAULT_CHUNK_SIZE)
+ checksumContentLength,
requestBody.contentType()));
}
Expand Down Expand Up @@ -112,7 +113,8 @@ public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, Execu
.putHeader("x-amz-decoded-content-length", Long.toString(originalContentLength))
.putHeader(CONTENT_LENGTH,
Long.toString(AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength(
originalContentLength) + checksumContentLength)));
originalContentLength, AwsChunkedEncodingInputStream.DEFAULT_CHUNK_SIZE)
+ checksumContentLength)));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
@SdkInternalApi
public abstract class AwsChunkedEncodingInputStream extends SdkInputStream {

protected static final int DEFAULT_CHUNK_SIZE = 128 * 1024;
public static final int DEFAULT_CHUNK_SIZE = 128 * 1024;
protected static final int SKIP_BUFFER_SIZE = 256 * 1024;
protected static final String CRLF = "\r\n";
protected static final byte[] FINAL_CHUNK = new byte[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,10 @@ public static Builder builder() {
* @return Content length of the trailer that will be appended at the end.
*/
public static long calculateChecksumContentLength(Algorithm algorithm, String headerName) {
int checksumLength = algorithm.base64EncodedLength();

return (headerName.length()
+ HEADER_COLON_SEPARATOR.length()
+ checksumLength
+ CRLF.length());
return headerName.length()
+ HEADER_COLON_SEPARATOR.length()
+ algorithm.base64EncodedLength().longValue()
+ CRLF.length() + CRLF.length();
}

/**
Expand All @@ -68,17 +66,20 @@ private static long calculateChunkLength(long originalContentLength) {
+ CRLF.length();
}

public static long calculateStreamContentLength(long originalLength) {
if (originalLength < 0) {
throw new IllegalArgumentException("Non negative content length expected.");
public static long calculateStreamContentLength(long originalLength, long defaultChunkSize) {
if (originalLength < 0 || defaultChunkSize == 0) {
throw new IllegalArgumentException(originalLength + ", " + defaultChunkSize + "Args <= 0 not expected");
}

long maxSizeChunks = originalLength / DEFAULT_CHUNK_SIZE;
long remainingBytes = originalLength % DEFAULT_CHUNK_SIZE;
long maxSizeChunks = originalLength / defaultChunkSize;
long remainingBytes = originalLength % defaultChunkSize;

long allChunks = maxSizeChunks * calculateChunkLength(defaultChunkSize);
long remainingInChunk = remainingBytes > 0 ? calculateChunkLength(remainingBytes) : 0;
// last byte is composed of a "0" and "\r\n"
long lastByteSize = 1 + (long) CRLF.length();

return maxSizeChunks * calculateChunkLength(DEFAULT_CHUNK_SIZE)
+ (remainingBytes > 0 ? calculateChunkLength(remainingBytes) : 0)
+ calculateChunkLength(0);
return allChunks + remainingInChunk + lastByteSize;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ public void readAwsUnsignedChunkedEncodingInputStream() throws IOException {
public void lengthsOfCalculateByChecksumCalculatingInputStream(){

String initialString = "Hello world";
long calculateChunkLength = AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength(initialString.length());
long calculateChunkLength = AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength(initialString.length(),
AwsChunkedEncodingInputStream.DEFAULT_CHUNK_SIZE);
long checksumContentLength = AwsUnsignedChunkedEncodingInputStream.calculateChecksumContentLength(
SHA256_ALGORITHM, SHA256_HEADER_NAME);
assertThat(calculateChunkLength).isEqualTo(21);
assertThat(checksumContentLength).isEqualTo(69);
assertThat(calculateChunkLength).isEqualTo(19);
assertThat(checksumContentLength).isEqualTo(71);
}

@Test
Expand Down
Loading