Skip to content

Commit

Permalink
Pass ByteBufAllocator to methods to allow usage for allocations (#43)
Browse files Browse the repository at this point in the history
Motivation:

We might need to do temporary allocations in our implementations, to minimize the overhead we should allow the user to pass the ByteBufAllocator that should be used.

Modifications:

- Change methods signatures to take ByteBufAllocator as well
- Use allocator to allocate the used Nonce and so reduce overhead

Result:

Less overhead
  • Loading branch information
normanmaurer authored Jan 11, 2024
1 parent be308f6 commit dbcbc74
Show file tree
Hide file tree
Showing 17 changed files with 162 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.incubator.codec.hpke.bouncycastle;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.incubator.codec.hpke.AEADContext;
import io.netty.incubator.codec.hpke.CryptoException;
import org.bouncycastle.crypto.InvalidCipherTextException;
Expand Down Expand Up @@ -45,13 +46,13 @@ protected byte[] execute(byte[] aad, byte[] in, int inOffset, int inLength)
}

@Override
public void seal(ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException {
public void seal(ByteBufAllocator alloc, ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException {
checkClosed();
seal.execute(aad, pt, out);
}

@Override
public void open(ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException {
public void open(ByteBufAllocator alloc, ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException {
checkClosed();
open.execute(aad, ct, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.incubator.codec.hpke.bouncycastle;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.incubator.codec.hpke.CryptoException;
import io.netty.incubator.codec.hpke.HPKERecipientContext;
import org.bouncycastle.crypto.InvalidCipherTextException;
Expand All @@ -36,7 +37,7 @@ protected byte[] execute(byte[] aad, byte[] in, int inOffset, int inLength)
}

@Override
public void open(ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException {
public void open(ByteBufAllocator alloc, ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException {
checkClosed();
open.execute(aad, ct, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.incubator.codec.hpke.bouncycastle;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.incubator.codec.hpke.CryptoException;
import io.netty.incubator.codec.hpke.HPKESenderContext;
import org.bouncycastle.crypto.InvalidCipherTextException;
Expand All @@ -41,7 +42,7 @@ public byte[] encapsulation() {
}

@Override
public void seal(ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException {
public void seal(ByteBufAllocator alloc, ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException {
checkClosed();
seal.execute(aad, pt, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
package io.netty.incubator.codec.hpke.boringssl;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.ByteBufAllocator;
import io.netty.incubator.codec.hpke.AEADContext;
import io.netty.incubator.codec.hpke.CryptoException;

/**
* BoringSSL based implementation of an {@link AEADContext}.
*/
final class BoringSSLAEADContext extends BoringSSLCryptoContext implements AEADContext {

private final Nonce nonce;
private final int aeadMaxOverhead;

Expand All @@ -35,9 +34,9 @@ int maxOutLen(long ctx, int inReadable) {
}

@Override
int execute(long ctx, long ad, int adLen, long in, int inLen, long out, int outLen) {
int execute(long ctx, ByteBufAllocator alloc, long ad, int adLen, long in, int inLen, long out, int outLen) {
int result = BoringSSL.EVP_AEAD_CTX_seal(
ctx, out, outLen, nonce.computeNext(), nonce.length(), in, inLen, ad, adLen);
ctx, out, outLen, nonce.computeNext(alloc), nonce.length(), in, inLen, ad, adLen);
if (result >= 0) {
nonce.incrementSequence();
}
Expand All @@ -52,9 +51,9 @@ int maxOutLen(long ctx, int inReadable) {
}

@Override
int execute(long ctx, long ad, int adLen, long in, int inLen, long out, int outLen) {
int execute(long ctx, ByteBufAllocator alloc, long ad, int adLen, long in, int inLen, long out, int outLen) {
int result = BoringSSL.EVP_AEAD_CTX_open(
ctx, out, outLen, nonce.computeNext(), nonce.length(), in, inLen, ad, adLen);
ctx, out, outLen, nonce.computeNext(alloc), nonce.length(), in, inLen, ad, adLen);
if (result >= 0) {
nonce.incrementSequence();
}
Expand All @@ -75,15 +74,15 @@ protected void destroyCtx(long ctx) {
}

@Override
public void open(ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException {
if (!open.execute(checkClosedAndReturnCtx(), aad, ct, out)) {
public void open(ByteBufAllocator alloc, ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException {
if (!open.execute(checkClosedAndReturnCtx(), alloc, aad, ct, out)) {
throw new CryptoException("open(...) failed");
}
}

@Override
public void seal(ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException {
if (!seal.execute(checkClosedAndReturnCtx(), aad, pt, out)) {
public void seal(ByteBufAllocator alloc, ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException {
if (!seal.execute(checkClosedAndReturnCtx(), alloc, aad, pt, out)) {
throw new CryptoException("seal(...) failed");
}
}
Expand All @@ -94,19 +93,16 @@ public boolean isDirectBufferPreferred() {
}

private static final class Nonce {
private final ByteBuf nonce;
private final long nonceAddress;
private final int nonceLen;
private final byte[] baseNonce;

private ByteBuf nonce;
private long nonceAddress;
private int seq;

Nonce(byte[] baseNonce) {
this.baseNonce = baseNonce.clone();

nonce = Unpooled.directBuffer(baseNonce.length).writeBytes(baseNonce);
this.nonceAddress = BoringSSL.memory_address(nonce);
this.nonceLen = nonce.readableBytes();
this.nonceLen = baseNonce.length;
}

int length() {
Expand All @@ -121,15 +117,23 @@ void incrementSequence() {
* <a href="https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2">Compute the nonce to use</a>
* @return memory address of the nonce buffer.
*/
long computeNext() {
long computeNext(ByteBufAllocator alloc) {
if (nonce == null) {
nonce = alloc.directBuffer(baseNonce.length).writeBytes(baseNonce);
nonceAddress = BoringSSL.memory_address(nonce);
}

for (int idx = 0, idx2 = baseNonce.length - 8 ; idx < 8; ++idx, ++idx2) {
nonce.setByte(idx2, baseNonce[idx2] ^ bigEndianByteAt(idx, seq));
}
return nonceAddress;
}

void destroy() {
nonce.release();
if (nonce != null) {
nonce.release();
nonce = null;
}
}

private static byte bigEndianByteAt(int idx, long value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.incubator.codec.hpke.boringssl;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;

/**
* Abstract base class to perform native crypto operations via BoringSSL.
Expand All @@ -28,29 +29,30 @@ abstract class BoringSSLCryptoOperation {
* accordingly.
*
* @param ctx the native {@code *_CTX} pointer.
* @param alloc {@link ByteBufAllocator} which might be used to do extra allocations.
* @param aad the AAD buffer.
* @param in the input data.
* @param out the buffer for writing into.
* @return {@code true} if successful, {@code false} otherwise.
*/
final boolean execute(long ctx, ByteBuf aad, ByteBuf in, ByteBuf out) {
final boolean execute(long ctx, ByteBufAllocator alloc, ByteBuf aad, ByteBuf in, ByteBuf out) {
ByteBuf directAad = null;
ByteBuf directIn = null;
ByteBuf directOut = null;
try {
directAad = directReadable(aad);
directIn = directReadable(in);
directAad = directReadable(alloc, aad);
directIn = directReadable(alloc, in);

int maxOutLen = maxOutLen(ctx, in.readableBytes());
directOut = directWritable(out, maxOutLen);
directOut = directWritable(alloc, out, maxOutLen);

long directAadAddress = BoringSSL.memory_address(directAad) + directAad.readerIndex();
int directAddReadableBytes = directAad.readableBytes();
long directInAddress = BoringSSL.memory_address(directIn) + directIn.readerIndex();
int directInReadableBytes = directIn.readableBytes();
long directOutAddress = BoringSSL.memory_address(directOut) + directOut.writerIndex();
int directOutWritableBytes = directOut.writableBytes();
int result = execute(ctx, directAadAddress, directAddReadableBytes,
int result = execute(ctx, alloc, directAadAddress, directAddReadableBytes,
directInAddress, directInReadableBytes,
directOutAddress, directOutWritableBytes);
if (result < 0) {
Expand All @@ -75,23 +77,24 @@ final boolean execute(long ctx, ByteBuf aad, ByteBuf in, ByteBuf out) {

abstract int maxOutLen(long ctx, int inReadable);

abstract int execute(long ctx, long ad, int adLen, long in, int inLen, long out, int outLen);
abstract int execute(long ctx, ByteBufAllocator alloc,
long ad, int adLen, long in, int inLen, long out, int outLen);

private static ByteBuf directReadable(ByteBuf in) {
private static ByteBuf directReadable(ByteBufAllocator alloc, ByteBuf in) {
if (in.isDirect()) {
return in;
}
ByteBuf directIn = in.alloc().directBuffer(in.readableBytes());
ByteBuf directIn = alloc.directBuffer(in.readableBytes());
directIn.writeBytes(in, in.readerIndex(), in.readableBytes());
return directIn;
}

private static ByteBuf directWritable(ByteBuf out, int minWritable) {
private static ByteBuf directWritable(ByteBufAllocator alloc, ByteBuf out, int minWritable) {
if (out.isDirect()) {
out.ensureWritable(minWritable);
return out;
}
return out.alloc().directBuffer(minWritable);
return alloc.directBuffer(minWritable);
}

private static void releaseIfNotTheSameInstance(ByteBuf buf, ByteBuf maybeOther) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.incubator.codec.hpke.boringssl;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.incubator.codec.hpke.CryptoException;
import io.netty.incubator.codec.hpke.HPKERecipientContext;

Expand All @@ -31,7 +32,7 @@ int maxOutLen(long ctx, int inReadable) {
}

@Override
int execute(long ctx, long ad, int adLen, long in, int inLen, long out, int outLen) {
int execute(long ctx, ByteBufAllocator alloc, long ad, int adLen, long in, int inLen, long out, int outLen) {
return BoringSSL.EVP_HPKE_CTX_open(ctx, out, outLen, in, inLen, ad, adLen);
}
};
Expand All @@ -41,8 +42,8 @@ int execute(long ctx, long ad, int adLen, long in, int inLen, long out, int outL
}

@Override
public void open(ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException {
if (!OPEN.execute(checkClosedAndReturnCtx(), aad, ct, out)) {
public void open(ByteBufAllocator alloc, ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException {
if (!OPEN.execute(checkClosedAndReturnCtx(), alloc, aad, ct, out)) {
throw new CryptoException("open(...) failed");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
package io.netty.incubator.codec.hpke.boringssl;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.incubator.codec.hpke.CryptoException;
import io.netty.incubator.codec.hpke.HPKERecipientContext;
import io.netty.incubator.codec.hpke.HPKESenderContext;

/**
Expand All @@ -34,7 +34,7 @@ int maxOutLen(long ctx, int inReadable) {
}

@Override
int execute(long ctx, long ad, int adLen, long in, int inLen, long out, int outLen) {
int execute(long ctx, ByteBufAllocator alloc, long ad, int adLen, long in, int inLen, long out, int outLen) {
return BoringSSL.EVP_HPKE_CTX_seal(ctx, out, outLen, in, inLen, ad, adLen);
}
};
Expand All @@ -52,8 +52,8 @@ public byte[] encapsulation() {
}

@Override
public void seal(ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException {
if (!SEAL.execute(checkClosedAndReturnCtx(), aad, pt, out)) {
public void seal(ByteBufAllocator alloc, ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException {
if (!SEAL.execute(checkClosedAndReturnCtx(), alloc, aad, pt, out)) {
throw new CryptoException("seal(...) failed");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.incubator.codec.hpke;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;

/**
* {@link CryptoContext} that can be used for decryption.
Expand All @@ -26,12 +27,13 @@ public interface CryptoDecryptContext extends CryptoContext {
* Authenticate and decrypt data. The {@link ByteBuf#readerIndex()} will be increased by the amount of
* data read and {@link ByteBuf#writerIndex()} by the bytes written.
*
* @param alloc {@link ByteBufAllocator} which might be used to do extra allocations.
* @param aad the AAD buffer
* @param ct the data to decrypt
* @param out the buffer for writing into.
* @throws CryptoException in case of an error.
*/
void open(ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException;
void open(ByteBufAllocator alloc, ByteBuf aad, ByteBuf ct, ByteBuf out) throws CryptoException;

/**
* Returns {@code true} if {@link ByteBuf}s that are direct should be used to avoid extra memory copies,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.incubator.codec.hpke;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;

/**
* {@link CryptoContext} that can be used for encryption.
Expand All @@ -26,12 +27,13 @@ public interface CryptoEncryptContext extends CryptoContext {
* Authenticate and encrypt data. The {@link ByteBuf#readerIndex()} will be increased by the amount of
* data read and {@link ByteBuf#writerIndex()} by the bytes written.
*
* @param alloc {@link ByteBufAllocator} which might be used to do extra allocations.
* @param aad the AAD buffer
* @param pt the data to encrypt.
* @param out the buffer for writing into
* @throws CryptoException in case of an error.
*/
void seal(ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException;
void seal(ByteBufAllocator alloc, ByteBuf aad, ByteBuf pt, ByteBuf out) throws CryptoException;

/**
* Returns {@code true} if {@link ByteBuf}s that are direct should be used to avoid extra memory copies,
Expand Down
Loading

0 comments on commit dbcbc74

Please sign in to comment.