Skip to content

Fix Trailer based Http Checksum for Async Request body with variable chunk size #3380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSSDKforJavav2-5ecdce1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Fixed issue where request used to fail while calculating Trailer based checksum for Async Request body."
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ public final class HttpChecksumConstant {

public static final String HEADER_FOR_TRAILER_REFERENCE = "x-amz-trailer";

/**
* Default chunk size for Async trailer based checksum data transfer*
*/
public static final int DEFAULT_ASYNC_CHUNK_SIZE = 16 * 1024;

private HttpChecksumConstant() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.core.internal.async;

import static software.amazon.awssdk.core.HttpChecksumConstant.DEFAULT_ASYNC_CHUNK_SIZE;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateChecksumContentLength;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateChunkLength;
import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.createChecksumTrailer;
Expand All @@ -41,12 +42,12 @@
@SdkInternalApi
public class ChecksumCalculatingAsyncRequestBody implements AsyncRequestBody {

public static final byte[] FINAL_BYTE = new byte[0];
private static final byte[] FINAL_BYTE = new byte[0];
private final AsyncRequestBody wrapped;
private final SdkChecksum sdkChecksum;
private final Algorithm algorithm;
private final String trailerHeader;
private final AtomicLong remainingBytes;
private final long totalBytes;

private ChecksumCalculatingAsyncRequestBody(DefaultBuilder builder) {

Expand All @@ -57,8 +58,8 @@ private ChecksumCalculatingAsyncRequestBody(DefaultBuilder builder) {
this.algorithm = builder.algorithm;
this.sdkChecksum = builder.algorithm != null ? SdkChecksum.forAlgorithm(algorithm) : null;
this.trailerHeader = builder.trailerHeader;
this.remainingBytes = new AtomicLong(wrapped.contentLength()
.orElseThrow(() -> new UnsupportedOperationException("Content length must be supplied.")));
this.totalBytes = wrapped.contentLength()
.orElseThrow(() -> new UnsupportedOperationException("Content length must be supplied."));
}

/**
Expand Down Expand Up @@ -148,7 +149,10 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
if (sdkChecksum != null) {
sdkChecksum.reset();
}
wrapped.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum, trailerHeader, remainingBytes));

SynchronousChunkBuffer synchronousChunkBuffer = new SynchronousChunkBuffer(totalBytes);
wrapped.flatMapIterable(synchronousChunkBuffer::buffer)
.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum, trailerHeader, totalBytes));
}

private static final class ChecksumCalculatingSubscriber implements Subscriber<ByteBuffer> {
Expand All @@ -162,11 +166,11 @@ private static final class ChecksumCalculatingSubscriber implements Subscriber<B

ChecksumCalculatingSubscriber(Subscriber<? super ByteBuffer> wrapped,
SdkChecksum checksum,
String trailerHeader, AtomicLong remainingBytes) {
String trailerHeader, long totalBytes) {
this.wrapped = wrapped;
this.checksum = checksum;
this.trailerHeader = trailerHeader;
this.remainingBytes = remainingBytes;
this.remainingBytes = new AtomicLong(totalBytes);
}

@Override
Expand All @@ -189,7 +193,8 @@ public void onNext(ByteBuffer byteBuffer) {
ByteBuffer allocatedBuffer = getFinalChecksumAppendedChunk(byteBuffer);
wrapped.onNext(allocatedBuffer);
} else {
wrapped.onNext(byteBuffer);
ByteBuffer allocatedBuffer = createChunk(byteBuffer, false);
wrapped.onNext(allocatedBuffer);
}
} catch (SdkException sdkException) {
this.subscription.cancel();
Expand All @@ -201,7 +206,7 @@ private ByteBuffer getFinalChecksumAppendedChunk(ByteBuffer byteBuffer) {
ByteBuffer finalChunkedByteBuffer = createChunk(ByteBuffer.wrap(FINAL_BYTE), true);
ByteBuffer checksumTrailerByteBuffer = createChecksumTrailer(
BinaryUtils.toBase64(checksumBytes), trailerHeader);
ByteBuffer contentChunk = createChunk(byteBuffer, false);
ByteBuffer contentChunk = byteBuffer.hasRemaining() ? createChunk(byteBuffer, false) : byteBuffer;

ByteBuffer checksumAppendedBuffer = ByteBuffer.allocate(
contentChunk.remaining()
Expand All @@ -225,4 +230,17 @@ public void onComplete() {
wrapped.onComplete();
}
}

private static final class SynchronousChunkBuffer {
private final ChunkBuffer chunkBuffer;

SynchronousChunkBuffer(long totalBytes) {
this.chunkBuffer = ChunkBuffer.builder().bufferSize(DEFAULT_ASYNC_CHUNK_SIZE).totalBytes(totalBytes).build();
}

private Iterable<ByteBuffer> buffer(ByteBuffer bytes) {
return chunkBuffer.bufferAndCreateChunks(bytes);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.core.internal.async;

import static software.amazon.awssdk.core.HttpChecksumConstant.DEFAULT_ASYNC_CHUNK_SIZE;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Validate;
import software.amazon.awssdk.utils.builder.SdkBuilder;

/**
* Class that will buffer incoming BufferBytes of totalBytes length to chunks of bufferSize*
*/
@SdkInternalApi
public final class ChunkBuffer {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, we should probably make this thread safe

private final AtomicLong remainingBytes;
private final ByteBuffer currentBuffer;
private final int bufferSize;

private ChunkBuffer(Long totalBytes, Integer bufferSize) {
Validate.notNull(totalBytes, "The totalBytes must not be null");

int chunkSize = bufferSize != null ? bufferSize : DEFAULT_ASYNC_CHUNK_SIZE;
this.bufferSize = chunkSize;
this.currentBuffer = ByteBuffer.allocate(chunkSize);
this.remainingBytes = new AtomicLong(totalBytes);
}

public static Builder builder() {
return new DefaultBuilder();
}


// currentBuffer and bufferedList can get over written if concurrent Threads calls this method at the same time.
public synchronized Iterable<ByteBuffer> bufferAndCreateChunks(ByteBuffer buffer) {
int startPosition = 0;
List<ByteBuffer> bufferedList = new ArrayList<>();
int currentBytesRead = buffer.remaining();
do {
int bufferedBytes = currentBuffer.position();
int availableToRead = bufferSize - bufferedBytes;
int bytesToMove = Math.min(availableToRead, currentBytesRead - startPosition);

if (bufferedBytes == 0) {
currentBuffer.put(buffer.array(), startPosition, bytesToMove);
} else {
currentBuffer.put(buffer.array(), 0, bytesToMove);
}

startPosition = startPosition + bytesToMove;

// Send the data once the buffer is full
if (currentBuffer.position() == bufferSize) {
currentBuffer.position(0);
ByteBuffer bufferToSend = ByteBuffer.allocate(bufferSize);
bufferToSend.put(currentBuffer.array(), 0, bufferSize);
bufferToSend.clear();
currentBuffer.clear();
bufferedList.add(bufferToSend);
remainingBytes.addAndGet(-bufferSize);
}
} while (startPosition < currentBytesRead);

int remainingBytesInBuffer = currentBuffer.position();

// Send the remaining buffer when
// 1. remainingBytes in buffer are same as the last few bytes to be read.
// 2. If it is a zero byte and the last byte to be read.
if (remainingBytes.get() == remainingBytesInBuffer &&
(buffer.remaining() == 0 || remainingBytesInBuffer > 0)) {
currentBuffer.clear();
ByteBuffer trimmedBuffer = ByteBuffer.allocate(remainingBytesInBuffer);
trimmedBuffer.put(currentBuffer.array(), 0, remainingBytesInBuffer);
trimmedBuffer.clear();
bufferedList.add(trimmedBuffer);
remainingBytes.addAndGet(-remainingBytesInBuffer);
}
return bufferedList;
}

public interface Builder extends SdkBuilder<Builder, ChunkBuffer> {

Builder bufferSize(int bufferSize);

Builder totalBytes(long totalBytes);


}

private static final class DefaultBuilder implements Builder {

private Integer bufferSize;
private Long totalBytes;

@Override
public ChunkBuffer build() {
return new ChunkBuffer(totalBytes, bufferSize);
}

@Override
public Builder bufferSize(int bufferSize) {
this.bufferSize = bufferSize;
return this;
}

@Override
public Builder totalBytes(long totalBytes) {
this.totalBytes = totalBytes;
return this;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

package software.amazon.awssdk.core.internal.interceptor;

import static software.amazon.awssdk.core.HttpChecksumConstant.DEFAULT_ASYNC_CHUNK_SIZE;
import static software.amazon.awssdk.core.internal.io.AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength;

import java.util.Optional;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.ClientType;
Expand Down Expand Up @@ -97,7 +100,9 @@ public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context,
private static SdkHttpRequest updateHeadersForTrailerChecksum(Context.ModifyHttpRequest context, ChecksumSpecs checksum,
long checksumContentLength, long originalContentLength) {

long chunkLength = ChunkContentUtils.calculateChunkLength(originalContentLength);
long chunkLength =
calculateStreamContentLength(originalContentLength, DEFAULT_ASYNC_CHUNK_SIZE);

return context.httpRequest().copy(r ->
r.putHeader(HttpChecksumConstant.HEADER_FOR_TRAILER_REFERENCE, checksum.headerName())
.putHeader("Content-encoding", HttpChecksumConstant.AWS_CHUNKED_HEADER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.internal.io.AwsChunkedEncodingInputStream;
import software.amazon.awssdk.core.internal.io.AwsUnsignedChunkedEncodingInputStream;
import software.amazon.awssdk.core.internal.util.HttpChecksumResolver;
import software.amazon.awssdk.core.internal.util.HttpChecksumUtils;
Expand Down Expand Up @@ -73,7 +74,7 @@ public Optional<RequestBody> modifyHttpContent(Context.ModifyHttpRequest context
RequestBody.fromContentProvider(
streamProvider,
AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength(
requestBody.optionalContentLength().orElse(0L))
requestBody.optionalContentLength().orElse(0L), AwsChunkedEncodingInputStream.DEFAULT_CHUNK_SIZE)
+ checksumContentLength,
requestBody.contentType()));
}
Expand Down Expand Up @@ -112,7 +113,8 @@ public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, Execu
.putHeader("x-amz-decoded-content-length", Long.toString(originalContentLength))
.putHeader(CONTENT_LENGTH,
Long.toString(AwsUnsignedChunkedEncodingInputStream.calculateStreamContentLength(
originalContentLength) + checksumContentLength)));
originalContentLength, AwsChunkedEncodingInputStream.DEFAULT_CHUNK_SIZE)
+ checksumContentLength)));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
@SdkInternalApi
public abstract class AwsChunkedEncodingInputStream extends SdkInputStream {

protected static final int DEFAULT_CHUNK_SIZE = 128 * 1024;
public static final int DEFAULT_CHUNK_SIZE = 128 * 1024;
protected static final int SKIP_BUFFER_SIZE = 256 * 1024;
protected static final String CRLF = "\r\n";
protected static final byte[] FINAL_CHUNK = new byte[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,10 @@ public static Builder builder() {
* @return Content length of the trailer that will be appended at the end.
*/
public static long calculateChecksumContentLength(Algorithm algorithm, String headerName) {
int checksumLength = algorithm.base64EncodedLength();

return (headerName.length()
+ HEADER_COLON_SEPARATOR.length()
+ checksumLength
+ CRLF.length());
return headerName.length()
+ HEADER_COLON_SEPARATOR.length()
+ algorithm.base64EncodedLength().longValue()
+ CRLF.length() + CRLF.length();
}

/**
Expand All @@ -68,17 +66,20 @@ private static long calculateChunkLength(long originalContentLength) {
+ CRLF.length();
}

public static long calculateStreamContentLength(long originalLength) {
if (originalLength < 0) {
throw new IllegalArgumentException("Non negative content length expected.");
public static long calculateStreamContentLength(long originalLength, long defaultChunkSize) {
if (originalLength < 0 || defaultChunkSize == 0) {
throw new IllegalArgumentException(originalLength + ", " + defaultChunkSize + "Args <= 0 not expected");
}

long maxSizeChunks = originalLength / DEFAULT_CHUNK_SIZE;
long remainingBytes = originalLength % DEFAULT_CHUNK_SIZE;
long maxSizeChunks = originalLength / defaultChunkSize;
long remainingBytes = originalLength % defaultChunkSize;

long allChunks = maxSizeChunks * calculateChunkLength(defaultChunkSize);
long remainingInChunk = remainingBytes > 0 ? calculateChunkLength(remainingBytes) : 0;
// last byte is composed of a "0" and "\r\n"
long lastByteSize = 1 + (long) CRLF.length();

return maxSizeChunks * calculateChunkLength(DEFAULT_CHUNK_SIZE)
+ (remainingBytes > 0 ? calculateChunkLength(remainingBytes) : 0)
+ calculateChunkLength(0);
return allChunks + remainingInChunk + lastByteSize;
}

@Override
Expand Down
Loading