Skip to content

Commit

Permalink
feat(api): add support for ClientCreds style cognito tokens
Browse files Browse the repository at this point in the history
Before, only user creds were supported by the oauth2/token route.
  • Loading branch information
acwrenn committed Jul 17, 2023
1 parent 5bcf24e commit f3adf76
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 44 deletions.
1 change: 1 addition & 0 deletions src/__tests__/mockTokenGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ import { TokenGenerator } from "../services/tokenGenerator";

export const newMockTokenGenerator = (): jest.Mocked<TokenGenerator> => ({
generate: jest.fn(),
generateWithClientCreds: jest.fn(),
});
23 changes: 22 additions & 1 deletion src/server/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,28 @@ export const createServer = (
req.on("end", function () {
const target = "GetToken";
const route = router(target);
route({ logger: req.log }, rawBody).then(

const parsed = new URLSearchParams(rawBody);
const params = {
grant_type: parsed.get("grant_type"),
client_id: parsed.get("client_id"),
client_secret: parsed.get("client_secret"),
refresh_token: parsed.get("refresh_token"),
};

const auth = req.get("Authorization");
if (auth && auth.startsWith("Basic ")) {
const sliced = auth.slice("Basic ".length);
const buff = new Buffer(sliced, "base64");
const decoded = buff.toString("ascii");
const creds = decoded.split(":");
if (creds.length == 2) {
params.client_id = creds[0];
params.client_secret = creds[1];
}
}

route({ logger: req.log }, params).then(
(output) => {
res.status(200).type("json").send(JSON.stringify(output));
},
Expand Down
42 changes: 40 additions & 2 deletions src/services/tokenGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ const applyTokenOverrides = (

export interface Tokens {
readonly AccessToken: string;
readonly IdToken: string;
readonly RefreshToken: string;
readonly IdToken?: string;
readonly RefreshToken?: string;
}

export interface TokenGenerator {
Expand All @@ -104,6 +104,10 @@ export interface TokenGenerator {
| "NewPasswordChallenge"
| "RefreshTokens"
): Promise<Tokens>;
generateWithClientCreds(
ctx: Context,
userPoolClient: AppClient
): Promise<Tokens>;
}

const formatExpiration = (
Expand Down Expand Up @@ -240,4 +244,38 @@ export class JwtTokenGenerator implements TokenGenerator {
),
};
}

public async generateWithClientCreds(
ctx: Context,
userPoolClient: AppClient
): Promise<Tokens> {
const eventId = uuid.v4();
const authTime = Math.floor(this.clock.get().getTime() / 1000);

const accessToken: RawToken = {
auth_time: authTime,
client_id: userPoolClient.ClientId,
event_id: eventId,
iat: authTime,
jti: uuid.v4(),
scope: "aws.cognito.signin.user.admin", // TODO: scopes
sub: userPoolClient.ClientId,
token_use: "access",
};

const issuer = `${this.tokenConfig.IssuerDomain}/${userPoolClient.UserPoolId}`;

return await Promise.resolve({
AccessToken: jwt.sign(accessToken, PrivateKey.pem, {
algorithm: "RS256",
issuer,
expiresIn: formatExpiration(
userPoolClient.AccessTokenValidity,
userPoolClient.TokenValidityUnits?.AccessToken ?? "hours",
"24h"
),
keyid: "CognitoLocal",
}),
});
}
}
92 changes: 64 additions & 28 deletions src/targets/getToken.test.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import {newMockCognitoService} from "../__tests__/mockCognitoService";
import {newMockTokenGenerator} from "../__tests__/mockTokenGenerator";
import {newMockTriggers} from "../__tests__/mockTriggers";
import {newMockUserPoolService} from "../__tests__/mockUserPoolService";
import {TestContext} from "../__tests__/testContext";
import { newMockCognitoService } from "../__tests__/mockCognitoService";
import { newMockTokenGenerator } from "../__tests__/mockTokenGenerator";
import { newMockTriggers } from "../__tests__/mockTriggers";
import { newMockUserPoolService } from "../__tests__/mockUserPoolService";
import { TestContext } from "../__tests__/testContext";
import * as TDB from "../__tests__/testDataBuilder";
import {CognitoService, Triggers, UserPoolService} from "../services";
import {TokenGenerator} from "../services/tokenGenerator";
import { CognitoService, Triggers, UserPoolService } from "../services";
import { TokenGenerator } from "../services/tokenGenerator";

import {
GetToken,
GetTokenTarget,
} from "./getToken";
import { GetToken, GetTokenTarget } from "./getToken";

describe("GetToken target", () => {
let target: GetTokenTarget;

let getToken: GetTokenTarget;
let mockCognitoService: jest.Mocked<CognitoService>;
let mockTokenGenerator: jest.Mocked<TokenGenerator>;
let mockTriggers: jest.Mocked<Triggers>;
Expand All @@ -23,42 +19,82 @@ describe("GetToken target", () => {

beforeEach(() => {
mockUserPoolService = newMockUserPoolService({
Id : userPoolClient.UserPoolId,
Id: userPoolClient.UserPoolId,
});
mockCognitoService = newMockCognitoService(mockUserPoolService);
mockCognitoService.getAppClient.mockResolvedValue(userPoolClient);
mockTriggers = newMockTriggers();
mockTokenGenerator = newMockTokenGenerator();
getToken = GetToken({
triggers : mockTriggers,
cognito : mockCognitoService,
tokenGenerator : mockTokenGenerator,
cognito: mockCognitoService,
tokenGenerator: mockTokenGenerator,
});
});

it("issues access tokens via refresh tokens", async () => {
mockTokenGenerator.generate.mockResolvedValue({
AccessToken : "access",
IdToken : "id",
RefreshToken : "refresh",
AccessToken: "access",
IdToken: "id",
RefreshToken: "refresh",
});

const existingUser = TDB.user({
RefreshTokens : [ "refresh-orig" ],
RefreshTokens: ["refresh-orig"],
});
mockUserPoolService.getUserByRefreshToken.mockResolvedValue(existingUser);
mockUserPoolService.listUserGroupMembership.mockResolvedValue([]);

const response = await getToken(
TestContext,
new URLSearchParams(`client_id=${
userPoolClient
.ClientId}&grant_type=refresh_token&refresh_token=refresh-orig`));
expect(mockUserPoolService.getUserByRefreshToken)
.toHaveBeenCalledWith(TestContext, "refresh-orig");
const response = await getToken(TestContext, {
client_id: userPoolClient.ClientId,
grant_type: "refresh_token",
refresh_token: "refresh-orig",
});
expect(mockUserPoolService.getUserByRefreshToken).toHaveBeenCalledWith(
TestContext,
"refresh-orig"
);
expect(mockUserPoolService.storeRefreshToken).not.toHaveBeenCalled();

expect(response.access_token).toEqual("access");
expect(response.refresh_token).toEqual("refresh");
});
});

describe("GetToken target - Client Creds", () => {
let getToken: GetTokenTarget;
let mockCognitoService: jest.Mocked<CognitoService>;
let mockTokenGenerator: jest.Mocked<TokenGenerator>;
let mockUserPoolService: jest.Mocked<UserPoolService>;
const userPoolClient = TDB.appClient({
ClientSecret: "secret",
ClientId: "id",
});

beforeEach(() => {
mockUserPoolService = newMockUserPoolService({
Id: userPoolClient.UserPoolId,
});
mockCognitoService = newMockCognitoService(mockUserPoolService);
mockCognitoService.getAppClient.mockResolvedValue(userPoolClient);
mockTokenGenerator = newMockTokenGenerator();
getToken = GetToken({
cognito: mockCognitoService,
tokenGenerator: mockTokenGenerator,
});
});

it("issues access tokens via client credentials", async () => {
mockTokenGenerator.generateWithClientCreds.mockResolvedValue({
AccessToken: "access",
RefreshToken: null,
IdToken: null,
});

const response = await getToken(TestContext, {
client_id: userPoolClient.ClientId,
client_secret: userPoolClient.ClientSecret,
grant_type: "client_credentials",
});
expect(response.access_token).toEqual("access");
});
});
80 changes: 67 additions & 13 deletions src/targets/getToken.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,42 @@ import { Target } from "../targets/Target";

type HandleTokenServices = Pick<Services, "cognito" | "tokenGenerator">;

type GetTokenRequest = URLSearchParams;
export type GetTokenRequest =
| GetTokenRequestClientCreds
| GetTokenRequestRefreshToken
| GetTokenRequestAuthCode;

interface GetTokenRequestGrantType {
grant_type: "authorization_code" | "client_credentials" | "refresh_token";
client_id: string;
}

interface GetTokenRequestClientCreds extends GetTokenRequestGrantType {
client_secret: string;
}

type GetTokenRequestAuthCode = GetTokenRequestGrantType;

interface GetTokenRequestRefreshToken extends GetTokenRequestGrantType {
refresh_token: string;
}

interface GetTokenResponse {
access_token: string;
refresh_token: string;
refresh_token?: string;
}

export type GetTokenTarget = Target<GetTokenRequest, GetTokenResponse>;

async function getRefreshToken(
async function getWithRefreshToken(
ctx: Context,
services: HandleTokenServices,
params: GetTokenRequest
params: GetTokenRequestRefreshToken
) {
const clientId = params.get("client_id");
const clientId = params.client_id;
const userPool = await services.cognito.getUserPoolForClientId(ctx, clientId);
const userPoolClient = await services.cognito.getAppClient(ctx, clientId);
const user = await userPool.getUserByRefreshToken(
ctx,
params.get("refresh_token")
);
const user = await userPool.getUserByRefreshToken(ctx, params.refresh_token);
if (!user || !userPoolClient) {
throw new NotAuthorizedError();
}
Expand All @@ -51,21 +66,60 @@ async function getRefreshToken(
};
}

async function getWithClientCredentials(
ctx: Context,
services: HandleTokenServices,
params: GetTokenRequestClientCreds
) {
const clientId = params.client_id;
const clientSecret = params.client_secret;
const userPoolClient = await services.cognito.getAppClient(ctx, clientId);
if (!userPoolClient) {
throw new NotAuthorizedError();
}
if (
userPoolClient.ClientSecret &&
userPoolClient.ClientSecret != clientSecret
) {
throw new NotAuthorizedError();
}

const tokens = await services.tokenGenerator.generateWithClientCreds(
ctx,
userPoolClient
);
if (!tokens) {
throw new NotAuthorizedError();
}

return {
access_token: tokens.AccessToken,
};
}

export const GetToken =
(services: HandleTokenServices): GetTokenTarget =>
async (ctx, req) => {
const params = new URLSearchParams(req);
switch (params.get("grant_type")) {
switch (req.grant_type) {
case "authorization_code": {
throw new NotImplementedError();
}
case "client_credentials": {
throw new NotImplementedError();
return getWithClientCredentials(
ctx,
services,
req as GetTokenRequestClientCreds
);
}
case "refresh_token": {
return getRefreshToken(ctx, services, params);
return getWithRefreshToken(
ctx,
services,
req as GetTokenRequestRefreshToken
);
}
default: {
console.log("Invalid grant type passed:", req.grant_type);
throw new InvalidParameterError();
}
}
Expand Down

0 comments on commit f3adf76

Please sign in to comment.