Skip to content

Commit

Permalink
Merge pull request #1904 from ClickHouse/v2_jwt_auth
Browse files Browse the repository at this point in the history
[client-v2] Added implementation for Bearer token auth
  • Loading branch information
chernser authored Dec 18, 2024
2 parents 19172f6 + da03749 commit d458d60
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 13 deletions.
36 changes: 29 additions & 7 deletions client-v2/src/main/java/com/clickhouse/client/api/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@
* ...
* }
* }
*
* }
*
*
Expand All @@ -132,6 +131,9 @@ public class Client implements AutoCloseable {

private final Set<String> endpoints;
private final Map<String, String> configuration;

private final Map<String, String> readOnlyConfig;

private final List<ClickHouseNode> serverNodes = new ArrayList<>();

// POJO serializer mapping (class -> (schema -> (format -> serializer)))
Expand All @@ -158,6 +160,7 @@ private Client(Set<String> endpoints, Map<String,String> configuration, boolean
ExecutorService sharedOperationExecutor, ColumnToMethodMatchingStrategy columnToMethodMatchingStrategy) {
this.endpoints = endpoints;
this.configuration = configuration;
this.readOnlyConfig = Collections.unmodifiableMap(this.configuration);
this.endpoints.forEach(endpoint -> {
this.serverNodes.add(ClickHouseNode.of(endpoint, this.configuration));
});
Expand Down Expand Up @@ -853,7 +856,7 @@ public Builder allowBinaryReaderToReuseBuffers(boolean reuse) {
* @return same instance of the builder
*/
public Builder httpHeader(String key, String value) {
this.configuration.put(ClientConfigProperties.HTTP_HEADER_PREFIX + key.toUpperCase(Locale.US), value);
this.configuration.put(ClientConfigProperties.httpHeader(key), value);
return this;
}

Expand All @@ -864,7 +867,7 @@ public Builder httpHeader(String key, String value) {
* @return same instance of the builder
*/
public Builder httpHeader(String key, Collection<String> values) {
this.configuration.put(ClientConfigProperties.HTTP_HEADER_PREFIX + key.toUpperCase(Locale.US), ClientConfigProperties.commaSeparated(values));
this.configuration.put(ClientConfigProperties.httpHeader(key), ClientConfigProperties.commaSeparated(values));
return this;
}

Expand Down Expand Up @@ -955,6 +958,19 @@ public Builder setOptions(Map<String, String> options) {
return this;
}

/**
* Specifies whether to use Bearer Authentication and what token to use.
* The token will be sent as is, so it should be encoded before passing to this method.
*
* @param bearerToken - token to use
* @return same instance of the builder
*/
public Builder useBearerTokenAuth(String bearerToken) {
// Most JWT libraries (https://jwt.io/libraries?language=Java) compact tokens in proper way
this.httpHeader(HttpHeaders.AUTHORIZATION, "Bearer " + bearerToken);
return this;
}

public Client build() {
setDefaults();

Expand All @@ -965,8 +981,9 @@ public Client build() {
// check if username and password are empty. so can not initiate client?
if (!this.configuration.containsKey("access_token") &&
(!this.configuration.containsKey("user") || !this.configuration.containsKey("password")) &&
!MapUtils.getFlag(this.configuration, "ssl_authentication", false)) {
throw new IllegalArgumentException("Username and password (or access token, or SSL authentication) are required");
!MapUtils.getFlag(this.configuration, "ssl_authentication", false) &&
!this.configuration.containsKey(ClientConfigProperties.httpHeader(HttpHeaders.AUTHORIZATION))) {
throw new IllegalArgumentException("Username and password (or access token or SSL authentication or pre-define Authorization header) are required");
}

if (this.configuration.containsKey("ssl_authentication") &&
Expand Down Expand Up @@ -1012,7 +1029,8 @@ public Client build() {
throw new IllegalArgumentException("Nor server timezone nor specific timezone is set");
}

return new Client(this.endpoints, this.configuration, this.useNewImplementation, this.sharedOperationExecutor, this.columnToMethodMatchingStrategy);
return new Client(this.endpoints, this.configuration, this.useNewImplementation, this.sharedOperationExecutor,
this.columnToMethodMatchingStrategy);
}

private static final int DEFAULT_NETWORK_BUFFER_SIZE = 300_000;
Expand Down Expand Up @@ -2104,7 +2122,7 @@ public String toString() {
* @return - configuration options
*/
public Map<String, String> getConfiguration() {
return Collections.unmodifiableMap(configuration);
return readOnlyConfig;
}

/** Returns operation timeout in seconds */
Expand Down Expand Up @@ -2151,6 +2169,10 @@ public Collection<String> getDBRoles() {
return unmodifiableDbRolesView;
}

public void updateBearerToken(String bearer) {
this.configuration.put(ClientConfigProperties.httpHeader(HttpHeaders.AUTHORIZATION), "Bearer " + bearer);
}

private ClickHouseNode getNextAliveNode() {
return serverNodes.get(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -157,6 +158,10 @@ public static String serverSetting(String key) {
return SERVER_SETTING_PREFIX + key;
}

public static String httpHeader(String key) {
return HTTP_HEADER_PREFIX + key.toUpperCase(Locale.US);
}

public static String commaSeparated(Collection<?> values) {
StringBuilder sb = new StringBuilder();
for (Object value : values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,34 @@ public class ServerException extends RuntimeException {
public static final int TABLE_NOT_FOUND = 60;

private final int code;

private final int transportProtocolCode;

public ServerException(int code, String message) {
this(code, message, 500);
}

public ServerException(int code, String message, int transportProtocolCode) {
super(message);
this.code = code;
this.transportProtocolCode = transportProtocolCode;
}

/**
* Returns CH server error code. May return 0 if code is unknown.
* @return - error code from server response
*/
public int getCode() {
return code;
}

/**
* Returns error code of underlying transport protocol. For example, HTTP status.
* By default, will return {@code 500 } what is derived from HTTP Server Internal Error.
*
* @return - transport status code
*/
public int getTransportProtocolCode() {
return transportProtocolCode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.clickhouse.client.api.ClientConfigProperties.SOCKET_TCP_NO_DELAY_OPT;

Expand Down Expand Up @@ -335,10 +336,13 @@ public Exception readError(ClassicHttpResponse httpResponse) {

String msg = msgBuilder.toString().replaceAll("\\s+", " ").replaceAll("\\\\n", " ")
.replaceAll("\\\\/", "/");
return new ServerException(serverCode, msg);
if (msg.trim().isEmpty()) {
msg = String.format(ERROR_CODE_PREFIX_PATTERN, serverCode) + " <Unreadable error message> (transport error: " + httpResponse.getCode() + ")";
}
return new ServerException(serverCode, msg, httpResponse.getCode());
} catch (Exception e) {
LOG.error("Failed to read error message", e);
return new ServerException(serverCode, String.format(ERROR_CODE_PREFIX_PATTERN, serverCode) + " <Unreadable error message>");
return new ServerException(serverCode, String.format(ERROR_CODE_PREFIX_PATTERN, serverCode) + " <Unreadable error message> (transport error: " + httpResponse.getCode() + ")", httpResponse.getCode());
}
}

Expand Down Expand Up @@ -450,12 +454,12 @@ private void addHeaders(HttpPost req, Map<String, String> chConfig, Map<String,

for (Map.Entry<String, String> entry : chConfig.entrySet()) {
if (entry.getKey().startsWith(ClientConfigProperties.HTTP_HEADER_PREFIX)) {
req.addHeader(entry.getKey().substring(ClientConfigProperties.HTTP_HEADER_PREFIX.length()), entry.getValue());
req.setHeader(entry.getKey().substring(ClientConfigProperties.HTTP_HEADER_PREFIX.length()), entry.getValue());
}
}
for (Map.Entry<String, Object> entry : requestConfig.entrySet()) {
if (entry.getKey().startsWith(ClientConfigProperties.HTTP_HEADER_PREFIX)) {
req.addHeader(entry.getKey().substring(ClientConfigProperties.HTTP_HEADER_PREFIX.length()), entry.getValue().toString());
req.setHeader(entry.getKey().substring(ClientConfigProperties.HTTP_HEADER_PREFIX.length()), entry.getValue().toString());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.io.ByteArrayInputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
Expand All @@ -51,8 +52,11 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
import java.util.function.Supplier;

import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;

public class HttpTransportTests extends BaseIntegrationTest {
Expand All @@ -66,7 +70,6 @@ public void testConnectionTTL(Long connectionTtl, Long keepAlive, int openSocket
ClickHouseNode server = getServer(ClickHouseProtocol.HTTP);

int proxyPort = new Random().nextInt(1000) + 10000;
System.out.println("proxyPort: " + proxyPort);
ConnectionCounterListener connectionCounter = new ConnectionCounterListener();
WireMockServer proxy = new WireMockServer(WireMockConfiguration
.options().port(proxyPort)
Expand Down Expand Up @@ -154,7 +157,6 @@ public void closed(Socket socket) {
public void testConnectionRequestTimeout() {

int serverPort = new Random().nextInt(1000) + 10000;
System.out.println("proxyPort: " + serverPort);
ConnectionCounterListener connectionCounter = new ConnectionCounterListener();
WireMockServer proxy = new WireMockServer(WireMockConfiguration
.options().port(serverPort)
Expand Down Expand Up @@ -794,4 +796,75 @@ public static Object[][] testUserAgentHasCompleteProductName_dataProvider() {
{ "test-client/1.0", Pattern.compile("test-client/1.0 clickhouse-java-v2\\/.+ \\(.+\\) Apache HttpClient\\/[\\d\\.]+$")},
{ "test-client/", Pattern.compile("test-client/ clickhouse-java-v2\\/.+ \\(.+\\) Apache HttpClient\\/[\\d\\.]+$")}};
}

@Test(groups = { "integration" })
public void testBearerTokenAuth() throws Exception {
WireMockServer mockServer = new WireMockServer( WireMockConfiguration
.options().port(9090).notifier(new ConsoleNotifier(false)));
mockServer.start();

try {
String jwtToken1 = Arrays.stream(
new String[]{"header", "payload", "signature"})
.map(s -> Base64.getEncoder().encodeToString(s.getBytes(StandardCharsets.UTF_8)))
.reduce((s1, s2) -> s1 + "." + s2).get();
try (Client client = new Client.Builder().addEndpoint(Protocol.HTTP, "localhost", mockServer.port(), false)
.useBearerTokenAuth(jwtToken1)
.compressServerResponse(false)
.build()) {

mockServer.addStubMapping(WireMock.post(WireMock.anyUrl())
.withHeader("Authorization", WireMock.equalTo("Bearer " + jwtToken1))
.willReturn(WireMock.aResponse()
.withHeader("X-ClickHouse-Summary",
"{ \"read_bytes\": \"10\", \"read_rows\": \"1\"}")).build());

try (QueryResponse response = client.query("SELECT 1").get(1, TimeUnit.SECONDS)) {
Assert.assertEquals(response.getReadBytes(), 10);
} catch (Exception e) {
Assert.fail("Unexpected exception", e);
}
}

String jwtToken2 = Arrays.stream(
new String[]{"header2", "payload2", "signature2"})
.map(s -> Base64.getEncoder().encodeToString(s.getBytes(StandardCharsets.UTF_8)))
.reduce((s1, s2) -> s1 + "." + s2).get();

mockServer.resetAll();
mockServer.addStubMapping(WireMock.post(WireMock.anyUrl())
.withHeader("Authorization", WireMock.equalTo("Bearer " + jwtToken1))
.willReturn(WireMock.aResponse()
.withStatus(HttpStatus.SC_UNAUTHORIZED))
.build());

try (Client client = new Client.Builder().addEndpoint(Protocol.HTTP, "localhost", mockServer.port(), false)
.useBearerTokenAuth(jwtToken1)
.compressServerResponse(false)
.build()) {

try {
client.execute("SELECT 1").get();
fail("Exception expected");
} catch (ServerException e) {
Assert.assertEquals(e.getTransportProtocolCode(), HttpStatus.SC_UNAUTHORIZED);
}

mockServer.resetAll();
mockServer.addStubMapping(WireMock.post(WireMock.anyUrl())
.withHeader("Authorization", WireMock.equalTo("Bearer " + jwtToken2))
.willReturn(WireMock.aResponse()
.withHeader("X-ClickHouse-Summary",
"{ \"read_bytes\": \"10\", \"read_rows\": \"1\"}"))

.build());

client.updateBearerToken(jwtToken2);

client.execute("SELECT 1").get();
}
} finally {
mockServer.stop();
}
}
}

0 comments on commit d458d60

Please sign in to comment.