Skip to content

Commit

Permalink
Simplify API for users and also simplify internal usage (#4)
Browse files Browse the repository at this point in the history
Motivation:

How the user was suposed to use the OHttp*Codec was more complicated then needed. We can simplify things a lot by merging some classes and only expose things to the enduser that really matters.

Modifications:

- Merge serializer and parser
- Make OHttpClientCodec non abstract and just let the user pass in a Function
- Simplify both codecs

Result:

Simpler API
  • Loading branch information
normanmaurer authored Dec 7, 2023
1 parent 8e41b03 commit 6ccae52
Show file tree
Hide file tree
Showing 11 changed files with 404 additions and 446 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
* Interface that defines how an Oblivious HTTP implementation handles the framing of chunks.
* <br>
* Instances of {@link OHttpChunkFramer} are stateless. The state management and encryption is delegated to
* the {@link Decoder} and {@link Encoder} interfaces, which are typically implemented by
* {@link OHttpContentParser} and {@link OHttpContentSerializer}, respectively.
* the {@link Decoder} and {@link Encoder} interfaces.
*/
public interface OHttpChunkFramer<T> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ public final class OHttpCiphersuite {

private static final int ENCODED_LENGTH = 7;

public OHttpCiphersuite(byte keyId, HybridPublicKeyEncryption.KEM kem, HybridPublicKeyEncryption.KDF kdf, HybridPublicKeyEncryption.AEAD aead) {
public OHttpCiphersuite(byte keyId, HybridPublicKeyEncryption.KEM kem, HybridPublicKeyEncryption.KDF kdf,
HybridPublicKeyEncryption.AEAD aead) {
this.keyId = keyId;
this.kem = requireNonNull(kem, "kem");
this.kdf = requireNonNull(kdf, "kdf");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,17 @@
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.incubator.codec.hpke.HybridPublicKeyEncryption;
import io.netty.util.AsciiString;
import io.netty.util.ReferenceCountUtil;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.function.Function;

import static io.netty.handler.codec.ByteToMessageDecoder.MERGE_CUMULATOR;
import static java.util.Objects.requireNonNull;

/**
* {@link MessageToMessageCodec} that HTTP clients can use to encrypt outgoing HTTP requests into
Expand All @@ -50,15 +55,48 @@
* <br><br>
* Both incoming and outgoing messages are {@link HttpObject}s.
*/
public abstract class OHttpClientCodec extends MessageToMessageCodec<HttpObject, HttpObject> {
public final class OHttpClientCodec extends MessageToMessageCodec<HttpObject, HttpObject> {

private OHttpContentSerializer serializer;
private OHttpContentParser parser;
private final Deque<OHttpRequestResponseContextHolder> contextHolders = new ArrayDeque<>();

private static final class OHttpRequestResponseContextHolder {

static final OHttpRequestResponseContextHolder NONE = new OHttpRequestResponseContextHolder(null);

final OHttpRequestResponseContext handler;

OHttpRequestResponseContextHolder(OHttpRequestResponseContext handler) {
this.handler = handler;
}

void destroy() {
if (handler != null) {
handler.destroy();
}
}
}

private final HybridPublicKeyEncryption encryption;
private final Function<HttpRequest, EncapsulationParameters> encapsulationFunc;

private OHttpClientContext context;
private ByteBuf cumulationBuffer = Unpooled.EMPTY_BUFFER;
private boolean destroyed;

/**
* Creates a new instance
*
* @param encryption the {@link HybridPublicKeyEncryption} to use for all the crypto.
* @param encapsulationFunc the {@link Function} that will be used to return the correct
* {@link EncapsulationParameters} for a given {@link HttpRequest}.
* If {@link Function} returns {@code null} no encapsulation will
* take place.
*/
public OHttpClientCodec(HybridPublicKeyEncryption encryption, Function<HttpRequest,
EncapsulationParameters> encapsulationFunc) {
this.encryption = requireNonNull(encryption, "encryption");
this.encapsulationFunc = requireNonNull(encapsulationFunc, "encapsulationFunc");
}

/**
* Parameters that control the OHTTP encapsulation of an HTTP request.
*/
Expand All @@ -75,25 +113,80 @@ public interface EncapsulationParameters {
String outerRequestAuthority();

/**
* Update outer HTTP request headers, if necessary.
* @param headers {@link HttpHeaders} to be updated.
* Create the headers for the other HTTP request.
* @return headers
*/
default void outerRequestUpdateHeaders(HttpHeaders headers) {
default HttpHeaders outerRequestHeaders() {
return new DefaultHttpHeaders();
}

/**
* @return {@link OHttpClientContext}.
* Return the {@link OHttpCiphersuite}s to use.
*
* @return the ciphersuites.
*/
OHttpClientContext context();
}
OHttpCiphersuite ciphersuite();

/**
* Get the parameters to encapsulate a {@link HttpRequest} into OHTTP.
* <br>
* @param request outbound {@link HttpRequest} intercepted by the handler.
* @return {@link EncapsulationParameters} object if OHTTP encapsulation is required, or null otherwise.
*/
protected abstract EncapsulationParameters encapsulationParameters(HttpRequest request);
/**
* The public key bytes of the server.
*
* @return bytes.
*/
byte[] serverPublicKeyBytes();

/**
* The {@link OHttpVersion} to use.
*
* @return the version.
*/
OHttpVersion version();

/**
* Create a simple {@link EncapsulationParameters} instance.
*
* @param version the version to use.
* @param ciphersuite the suite to use.
* @param serverPublicKeyBytes the public key to use.
* @param outerRequestUri the outer requst uri.
* @param outerRequestAuthority the authority.
* @return created params.
*/
static EncapsulationParameters newInstance(OHttpVersion version, OHttpCiphersuite ciphersuite,
byte[] serverPublicKeyBytes, String outerRequestUri,
String outerRequestAuthority) {
requireNonNull(version, "version");
requireNonNull(ciphersuite, "ciphersuite");
requireNonNull(serverPublicKeyBytes, "serverPublicKeysBytes");
requireNonNull(outerRequestUri, "outerRequestUri");
requireNonNull(outerRequestAuthority, "outerRequestAuthority");
return new EncapsulationParameters() {
@Override
public String outerRequestUri() {
return outerRequestUri;
}

@Override
public String outerRequestAuthority() {
return outerRequestAuthority;
}

@Override
public OHttpCiphersuite ciphersuite() {
return ciphersuite;
}

@Override
public byte[] serverPublicKeyBytes() {
return serverPublicKeyBytes;
}

@Override
public OHttpVersion version() {
return version;
}
};
}
}

@Override
public final boolean isSharable() {
Expand All @@ -106,30 +199,35 @@ protected final void decode(ChannelHandlerContext ctx, HttpObject msg, List<Obje
throw new IllegalStateException("Already destroyed");
}
try {
assert !contextHolders.isEmpty();
OHttpRequestResponseContext ohttpContext = contextHolders.peekFirst().handler;
if (msg instanceof HttpResponse) {
HttpResponse resp = (HttpResponse) msg;
parser = null;
if (context != null) {
if (ohttpContext != null) {
if (resp.status() != HttpResponseStatus.OK) {
throw new DecoderException("OHTTP response status is not OK");
}
String contentTypeValue = resp.headers().get(HttpHeaderNames.CONTENT_TYPE);
AsciiString expectedContentType = context.version().responseContentType();
AsciiString expectedContentType = ohttpContext.version().responseContentType();
if (!expectedContentType.contentEqualsIgnoreCase(contentTypeValue)) {
throw new DecoderException("OHTTP response has unexpected content type");
}
parser = context.newContentParser();
}
}
if (parser != null) {

boolean isLast = msg instanceof LastHttpContent;
if (ohttpContext != null) {
if (msg instanceof HttpContent) {
ByteBuf content = ((HttpContent) msg).content();
cumulationBuffer = MERGE_CUMULATOR.cumulate(content.alloc(), cumulationBuffer, content.retain());
parser.parse(cumulationBuffer, msg instanceof LastHttpContent, out);
ohttpContext.parse(cumulationBuffer, isLast, out);
}
} else {
out.add(ReferenceCountUtil.retain(msg));
}
if (isLast) {
contextHolders.poll().destroy();
}
} catch (CryptoException e) {
throw new DecoderException("failed to decrypt bytes", e);
}
Expand All @@ -140,30 +238,36 @@ protected final void encode(ChannelHandlerContext ctx, HttpObject msg, List<Obje
try {
if (msg instanceof HttpRequest) {
HttpRequest innerRequest = (HttpRequest) msg;
context = null;
serializer = null;
EncapsulationParameters encapsulation = encapsulationParameters(innerRequest);
EncapsulationParameters encapsulation = encapsulationFunc.apply(innerRequest);
if (encapsulation != null) {
context = encapsulation.context();
serializer = context.newContentSerializer();
HttpHeaders outerHeaders = new DefaultHttpHeaders();

OHttpClientRequestResponseContext oHttpContext =
new OHttpClientRequestResponseContext(encapsulation, encryption);
HttpHeaders outerHeaders = encapsulation.outerRequestHeaders();
DefaultHttpRequest outerRequest = new DefaultHttpRequest(
innerRequest.protocolVersion(),
HttpMethod.POST,
encapsulation.outerRequestUri(), outerHeaders);
encapsulation.outerRequestUpdateHeaders(outerHeaders);
outerHeaders
.set(HttpHeaderNames.HOST, encapsulation.outerRequestAuthority())
.add(HttpHeaderNames.CONTENT_TYPE, context.version().requestContentType());
.add(HttpHeaderNames.CONTENT_TYPE, oHttpContext.version().requestContentType());
HttpUtil.setTransferEncodingChunked(outerRequest, true);

contextHolders.addLast(new OHttpRequestResponseContextHolder(oHttpContext));

out.add(outerRequest);
} else {
contextHolders.addLast(OHttpRequestResponseContextHolder.NONE);
}
}
if (serializer != null) {

assert !contextHolders.isEmpty();
OHttpRequestResponseContext contentHandler = contextHolders.peekLast().handler;
if (contentHandler != null) {
ByteBuf contentBytes = ctx.alloc().buffer();
try {
boolean isLast = msg instanceof LastHttpContent;
serializer.serialize(msg, contentBytes);
contentHandler.serialize(msg, contentBytes);
// Use the correct version of HttpContent depending on if it was the last or not.
HttpContent content = isLast ? new DefaultLastHttpContent(contentBytes) :
new DefaultHttpContent(contentBytes);
Expand All @@ -190,10 +294,57 @@ public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
cumulationBuffer.release();
cumulationBuffer = Unpooled.EMPTY_BUFFER;

if (parser != null) {
parser.destroy();
for (;;) {
OHttpRequestResponseContextHolder h = contextHolders.poll();
if (h == null) {
break;
}
h.destroy();
}
}
super.handlerRemoved(ctx);
}

private static final class OHttpClientRequestResponseContext extends OHttpRequestResponseContext {

private final OHttpCryptoSender sender;

OHttpClientRequestResponseContext(EncapsulationParameters parameters, HybridPublicKeyEncryption encryption) {
super(parameters.version());
this.sender = OHttpCryptoSender.newBuilder()
.setHybridPublicKeyEncryption(encryption)
.setConfiguration(parameters.version())
.setCiphersuite(requireNonNull(parameters.ciphersuite(), "ciphersuite"))
.setReceiverPublicKeyBytes(requireNonNull(parameters.serverPublicKeyBytes(), "serverPublicKeyBytes"))
.build();
}

@Override
public boolean decodePrefix(ByteBuf in) {
if (in.readableBytes() < sender.ciphersuite().responseNonceLength()) {
return false;
}
byte[] responseNonce = new byte[sender.ciphersuite().responseNonceLength()];
in.readBytes(responseNonce);
sender.setResponseNonce(responseNonce);
return true;
}

@Override
protected void decryptChunk(ByteBuf chunk, int chunkLength, boolean isFinal, ByteBuf out)
throws CryptoException {
sender.decrypt(chunk, chunkLength, isFinal, out);
}

@Override
public void encodePrefixNow(ByteBuf out) {
out.writeBytes(sender.header());
}

@Override
protected void encryptChunk(ByteBuf chunk, int chunkLength, boolean isFinal, ByteBuf out)
throws CryptoException {
sender.encrypt(chunk, chunkLength, isFinal, out);
}
};
}
Loading

0 comments on commit 6ccae52

Please sign in to comment.