diff --git a/src/index.test.ts b/src/index.test.ts index 5fcae94..93541f3 100644 --- a/src/index.test.ts +++ b/src/index.test.ts @@ -108,6 +108,58 @@ describe("on login", () => { }); }); + describe("on domain", () => { + const domain = "mydomain.org"; + const url = `/workos/authorize?domain=${domain}`; + + it("calls workos api with domain", async () => { + await supertest(app).get(url); + expect(getAuthorizationURL).toBeCalledTimes(1); + expect(getAuthorizationURL).toBeCalledWith( + expect.objectContaining({ + domain, + clientID, + redirectURI: callbackURL, + state: "...", + }) + ); + }); + + it("redirects to login url", async () => { + const res = await supertest(app).get(url); + expect(res.statusCode).toEqual(302); + expect(res.headers.location).toMatchInlineSnapshot( + `"https://workos.com/fake-auth-url"` + ); + }); + }); + + describe("on email", () => { + const email = "user@mydomain.org"; + const url = `/workos/authorize?email=${email}`; + + it("calls workos api with domain", async () => { + await supertest(app).get(url); + expect(getAuthorizationURL).toBeCalledTimes(1); + expect(getAuthorizationURL).toBeCalledWith( + expect.objectContaining({ + domain: email.substring(email.indexOf("@") + 1), + clientID, + redirectURI: callbackURL, + state: "...", + }) + ); + }); + + it("redirects to login url", async () => { + const res = await supertest(app).get(url); + expect(res.statusCode).toEqual(302); + expect(res.headers.location).toMatchInlineSnapshot( + `"https://workos.com/fake-auth-url"` + ); + }); + }); + describe("on provider", () => { describe("with 'GoogleOAuth'", () => { const provider = ConnectionType.GoogleOAuth; diff --git a/src/index.ts b/src/index.ts index 5340d82..5bd7e52 100644 --- a/src/index.ts +++ b/src/index.ts @@ -45,13 +45,15 @@ export class WorkOSSSOStrategy extends Strategy { private _loginAttempt(req: Request, options: AuthenticateOptions) { try { - const { connection, organization, provider } = req.query as Record< - string, - string - >; - if ([connection, organization, provider].every((a) => a === undefined)) { + const { connection, organization, domain, email, provider } = + req.query as Record; + if ( + [connection, organization, domain, email, provider].every( + (a) => a === undefined + ) + ) { throw Error( - "One of 'connection', 'organization', or 'provider' is required" + "One of 'connection', 'domain', 'organization', 'provider' and/or 'email' are required" ); } @@ -60,6 +62,7 @@ export class WorkOSSSOStrategy extends Strategy { connection, organization, provider, + domain: domain || email?.slice(email.indexOf("@") + 1), clientID: this.options.clientID, redirectURI: options.redirectURI || this.options.callbackURL, ...options,