From aeffb4d35eecfdc760fe77094f9db16193fdee79 Mon Sep 17 00:00:00 2001 From: Bernd Storath Date: Fri, 29 May 2026 15:37:33 +0200 Subject: [PATCH] support auto register --- .../config/external-authentication.md | 6 +- .../api/auth/[provider]/callback.get.ts | 40 ++-- .../database/repositories/user/service.ts | 179 ++++++++++-------- src/server/utils/config.ts | 4 + src/server/utils/oauth.ts | 20 +- 5 files changed, 136 insertions(+), 113 deletions(-) diff --git a/docs/content/advanced/config/external-authentication.md b/docs/content/advanced/config/external-authentication.md index 1243f58c..997637a7 100644 --- a/docs/content/advanced/config/external-authentication.md +++ b/docs/content/advanced/config/external-authentication.md @@ -31,12 +31,14 @@ If your provider does not support multiple redirect URIs (e.g. GitHub) but allow - `https:///api/auth//` - - ### Auto Register To automatically register users that log in with an OAuth provider, set the env var `OAUTH_AUTO_REGISTER` to `true`. +If a user logs in with an email address that is not yet registered, a new account will be created for them. + +If a user logs in with an email address that is already registered, their account will be linked to the OAuth provider (if not already linked), regardless of the value of `OAUTH_AUTO_REGISTER`. + /// warning | Security Users will be created with Admin Permissions, as the permissions system is not yet implemented. Only enable this if you trust all users that can log in with the OAuth provider. diff --git a/src/server/api/auth/[provider]/callback.get.ts b/src/server/api/auth/[provider]/callback.get.ts index 3f8b5252..7e58ba0a 100644 --- a/src/server/api/auth/[provider]/callback.get.ts +++ b/src/server/api/auth/[provider]/callback.get.ts @@ -24,7 +24,7 @@ export default defineEventHandler(async (event) => { providerConfig ); - const result = await Database.users.findOrCreateByProvider( + const result = await Database.users.loginWithOAuth( provider, userInfo.sub, userInfo.preferred_username || userInfo.email, @@ -33,22 +33,30 @@ export default defineEventHandler(async (event) => { ); if (!result.success) { - if (result.error === 'USER_DISABLED') { - throw createError({ - statusCode: 401, - statusMessage: 'User disabled', - }); + switch (result.error) { + case 'USER_DISABLED': + throw createError({ + statusCode: 401, + statusMessage: 'User disabled', + }); + case 'USER_ALREADY_LINKED': + throw createError({ + statusCode: 401, + statusMessage: + 'User already linked with different account or provider', + }); + case 'AUTO_REGISTER_DISABLED': + throw createError({ + statusCode: 401, + statusMessage: 'Auto registration is disabled', + }); + case 'UNEXPECTED_ERROR': + throw createError({ + statusCode: 500, + statusMessage: 'Unexpected error', + }); } - if (result.error === 'USER_ALREADY_LINKED') { - throw createError({ - statusCode: 401, - statusMessage: 'User already linked with different account or provider', - }); - } - throw createError({ - statusCode: 500, - statusMessage: 'Unexpected error', - }); + assertUnreachable(result.error); } // Create session diff --git a/src/server/database/repositories/user/service.ts b/src/server/database/repositories/user/service.ts index 2b218f78..38dc73c7 100644 --- a/src/server/database/repositories/user/service.ts +++ b/src/server/database/repositories/user/service.ts @@ -19,6 +19,20 @@ type LoginResult = | 'UNEXPECTED_ERROR'; }; +type LoginWithOAuthResult = + | { + success: true; + user: UserType; + } + | { + success: false; + error: + | 'USER_DISABLED' + | 'USER_ALREADY_LINKED' + | 'UNEXPECTED_ERROR' + | 'AUTO_REGISTER_DISABLED'; + }; + function createPreparedStatement(db: DBType) { return { findAll: db.query.user.findMany().prepare(), @@ -30,19 +44,6 @@ function createPreparedStatement(db: DBType) { where: eq(user.username, sql.placeholder('username')), }) .prepare(), - findByProviderId: db.query.user - .findFirst({ - where: and( - eq(user.oauthProvider, sql.placeholder('oauthProvider')), - eq(user.oauthId, sql.placeholder('oauthId')) - ), - }) - .prepare(), - findByEmail: db.query.user - .findFirst({ - where: eq(user.email, sql.placeholder('email')), - }) - .prepare(), update: db .update(user) .set({ @@ -83,74 +84,6 @@ export class UserService { return this.#statements.findByUsername.execute({ username }); } - async getByProviderId(provider: OAUTH_PROVIDER, oauthId: string) { - return this.#statements.findByProviderId.execute({ - oauthProvider: provider, - oauthId, - }); - } - - async getByEmail(email: string) { - return this.#statements.findByEmail.execute({ email }); - } - - // TODO: improve, use transaction - async findOrCreateByProvider( - provider: OAUTH_PROVIDER, - oauthId: string, - username: string, - email: string, - name: string - ) { - // Try to find by id - let existingUser = await this.getByProviderId(provider, oauthId); - if (existingUser) { - if (!existingUser.enabled) { - return { success: false as const, error: 'USER_DISABLED' as const }; - } - return { success: true as const, user: existingUser }; - } - - // Try to find by email - existingUser = await this.getByEmail(email); - if (existingUser) { - if (!existingUser.enabled) { - return { success: false as const, error: 'USER_DISABLED' as const }; - } - if (existingUser.oauthProvider && existingUser.oauthId) { - return { - success: false as const, - error: 'USER_ALREADY_LINKED' as const, - }; - } - await this.#db - .update(user) - .set({ oauthProvider: provider, oauthId: oauthId }) - .where(eq(user.id, existingUser.id)) - .execute(); - return { success: true as const, user: existingUser }; - } - - // Create new user - await this.#db.insert(user).values({ - username, - password: null, - email, - name, - role: roles.ADMIN, - totpVerified: false, - enabled: true, - oauthProvider: provider, - oauthId, - }); - - const newUser = await this.getByProviderId(provider, oauthId); - if (!newUser) { - return { success: false as const, error: 'UNEXPECTED_ERROR' as const }; - } - return { success: true as const, user: newUser }; - } - async create(username: string, password: string) { const hash = await hashPassword(password); @@ -339,6 +272,90 @@ export class UserService { }); } + /** + * Login or register user with OAuth provider. + * If user with the same email already exists, link account with OAuth provider. + * Otherwise, create new user. + */ + async loginWithOAuth( + provider: OAUTH_PROVIDER, + oauthId: string, + username: string, + email: string, + name: string + ): Promise { + return this.#db.transaction(async (tx) => { + const userById = await tx.query.user + .findFirst({ + where: and( + eq(user.oauthProvider, provider), + eq(user.oauthId, oauthId) + ), + }) + .execute(); + + if (userById) { + if (!userById.enabled) { + return { success: false, error: 'USER_DISABLED' }; + } + return { success: true, user: userById }; + } + + const userByEmail = await tx.query.user + .findFirst({ + where: eq(user.email, email), + }) + .execute(); + + if (userByEmail) { + if (!userByEmail.enabled) { + return { success: false, error: 'USER_DISABLED' }; + } + if (userByEmail.oauthProvider && userByEmail.oauthId) { + return { + success: false, + error: 'USER_ALREADY_LINKED', + }; + } + + await tx + .update(user) + .set({ oauthProvider: provider, oauthId: oauthId }) + .where(eq(user.id, userByEmail.id)) + .execute(); + + // TODO: return updated user + return { success: true, user: userByEmail }; + } + + if (!WG_ENV.OAUTH_AUTO_REGISTER) { + return { success: false, error: 'AUTO_REGISTER_DISABLED' }; + } + + // Create new user + const newUsers = await tx + .insert(user) + .values({ + username, + password: null, + email, + name, + role: roles.ADMIN, + totpVerified: false, + enabled: true, + oauthProvider: provider, + oauthId, + }) + .returning(); + const newUser = newUsers[0]; + + if (!newUser) { + return { success: false as const, error: 'UNEXPECTED_ERROR' as const }; + } + return { success: true as const, user: newUser }; + }); + } + unlinkOauth(id: ID) { return this.#db.transaction(async (tx) => { const txUser = await tx.query.user diff --git a/src/server/utils/config.ts b/src/server/utils/config.ts index ad6e80dd..e9e22234 100644 --- a/src/server/utils/config.ts +++ b/src/server/utils/config.ts @@ -38,13 +38,17 @@ export const WG_ENV = { /** If IPv6 should be disabled */ DISABLE_IPV6: process.env.DISABLE_IPV6 === 'true', WG_EXECUTABLE: await detectAwg(), + /** List of enabled OAuth providers */ OAUTH_PROVIDERS: process.env.OAUTH_PROVIDERS?.split(',') .map((v) => v.trim()) .filter((v) => isValidOauthProvider(v)) .filter((v) => isConfiguredOauthProvider(OAUTH_PROVIDERS[v])), + /** List of allowed OAuth domains */ OAUTH_ALLOWED_DOMAINS: process.env.OAUTH_ALLOWED_DOMAINS?.split(',').map( (v) => v.trim() ), + /** Automatically register users that log in with an OAuth provider */ + OAUTH_AUTO_REGISTER: process.env.OAUTH_AUTO_REGISTER === 'true', }; if (WG_ENV.OAUTH_PROVIDERS && WG_ENV.OAUTH_PROVIDERS.length > 1) { diff --git a/src/server/utils/oauth.ts b/src/server/utils/oauth.ts index 14036e19..dddbef26 100644 --- a/src/server/utils/oauth.ts +++ b/src/server/utils/oauth.ts @@ -1,5 +1,4 @@ import type { H3Event } from 'h3'; -import type { Configuration } from 'openid-client'; import * as client from 'openid-client'; type OAuthConfig = { @@ -174,7 +173,7 @@ type OauthState = { export async function getUserInfo( event: H3Event, - config: Configuration, + config: client.Configuration, state: OauthState, providerConfig: OAuthConfig ) { @@ -208,12 +207,7 @@ export async function getUserInfo( userInfo = await client.fetchUserInfo(config, tokens.access_token, subject); } - if (!hasOauthProps(userInfo)) { - throw createError({ - statusCode: 400, - statusMessage: 'Invalid user info', - }); - } + assertHasOauthProps(userInfo); if (!isAllowedDomain(userInfo.email)) { throw createError({ @@ -225,11 +219,11 @@ export async function getUserInfo( return userInfo; } -function hasOauthProps< - T extends { sub?: string; email?: string; email_verified?: boolean }, ->( +type RequireKeys = Required>; + +function assertHasOauthProps( userInfo: T -): userInfo is T & { sub: string; email: string; email_verified: boolean } { +): asserts userInfo is T & RequireKeys { if (!userInfo.sub) { throw createError({ statusCode: 400, @@ -250,8 +244,6 @@ function hasOauthProps< statusMessage: 'Email is not verified', }); } - - return true; } function isAllowedDomain(email: string) {