diff --git a/gel/_internal/_auth/_magic_link.py b/gel/_internal/_auth/_magic_link.py new file mode 100644 index 000000000..c0dda8bff --- /dev/null +++ b/gel/_internal/_auth/_magic_link.py @@ -0,0 +1,698 @@ +# SPDX-PackageName: gel-python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. + +from __future__ import annotations +from typing import Any, Optional, TypeVar, overload, Literal, TYPE_CHECKING + +import dataclasses +import http +import logging + +import httpx + +import gel +from gel import blocking_client +from gel._internal._polyfills._strenum import StrEnum + +from . import _base as base +from . import _pkce as pkce_mod +from . import _token_data as td_mod + +logger = logging.getLogger("gel.auth") +C = TypeVar("C", bound=httpx.Client | httpx.AsyncClient) + + +class VerificationMethod(StrEnum): + LINK = "Link" + CODE = "Code" + + +@dataclasses.dataclass +class MagicLinkSentResponse: + email_sent: str + verifier: str + signup: bool + + +@dataclasses.dataclass +class MagicLinkFailedResponse(base.BaseServerFailedResponse): + verifier: str + + +MagicLinkResponse = MagicLinkSentResponse | MagicLinkFailedResponse + + +@dataclasses.dataclass +class MagicCodeSentResponse: + signup: bool + email: str + + +@dataclasses.dataclass +class MagicCodeFailedResponse(base.BaseServerFailedResponse): + pass + + +MagicCodeResponse = MagicCodeSentResponse | MagicCodeFailedResponse + + +@dataclasses.dataclass +class AuthenticateLinkResultResponse: + code: str + + +@dataclasses.dataclass +class AuthenticateLinkFailedResponse(base.BaseServerFailedResponse): + pass + + +AuthenticateLinkResponse = ( + AuthenticateLinkResultResponse | AuthenticateLinkFailedResponse +) + + +@dataclasses.dataclass +class AuthenticateCodeResultResponse: + code: str + verifier: str + + +@dataclasses.dataclass +class AuthenticateCodeFailedResponse(base.BaseServerFailedResponse): + verifier: str + + +AuthenticateCodeResponse = ( + AuthenticateCodeResultResponse | AuthenticateCodeFailedResponse +) + + +class BaseMagicLink(base.BaseClient[C]): + def __init__( + self, + *, + connection_info: gel.ConnectionInfo, + **kwargs: Any, + ) -> None: + self.provider = "builtin::local_magic_link" + super().__init__(connection_info=connection_info, **kwargs) + + async def _request_link( + self, + email: str, + *, + is_sign_up: bool, + callback_url: str, + redirect_on_failure: str, + link_url: Optional[str] = None, + parse_redirect_as_error: bool = False, + ) -> MagicLinkResponse: + title = "register" if is_sign_up else "sign in" + logger.info("signing %s user: %s", "up" if is_sign_up else "in", email) + pkce = self._generate_pkce() + data = { + "provider": self.provider, + "email": email, + "challenge": pkce.challenge, + "callback_url": callback_url, + "redirect_on_failure": redirect_on_failure, + } + if link_url is not None: + data["link_url"] = link_url + register_response = await self._http_request( + "POST", + "/magic-link/register" if is_sign_up else "/magic-link/email", + json=data, + headers={"Accept": "application/json"}, + ) + if parse_redirect_as_error and register_response.has_redirect_location: + # On Gel 5.x/6.x, the /magic-link/email endpoint does not return + # JSON error responses, but instead redirects to the failure URL + # with an error query parameter. This bug is fixed in Gel 7.0+. + failure_url = httpx.URL(register_response.headers["Location"]) + error = failure_url.params.get("error", "unknown error") + logger.error("%s error: %s", title, error) + return MagicLinkFailedResponse( + verifier=pkce.verifier, + status_code=http.HTTPStatus.BAD_REQUEST, + message=error, + ) + try: + register_response.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error("%s error: %s", title, e) + return MagicLinkFailedResponse( + verifier=pkce.verifier, + status_code=e.response.status_code, + message=e.response.text, + ) + register_json = register_response.json() + if "error" in register_json: + error = register_json["error"] + logger.error("%s error: %s", title, error) + return MagicLinkFailedResponse( + verifier=pkce.verifier, + status_code=register_response.status_code, + message=error, + ) + else: + email_sent = register_json["email_sent"] + signup = ( + register_json.get("signup", str(is_sign_up)).lower() == "true" + ) + logger.info("the magic link is sent to: %r", email_sent) + logger.debug( + "Sign-up: %s, PKCE verifier: %s", signup, pkce.verifier + ) + return MagicLinkSentResponse( + email_sent=email_sent, + verifier=pkce.verifier, + signup=signup, + ) + + async def _request_code( + self, + email: str, + *, + is_sign_up: bool, + ) -> MagicCodeResponse: + title = "register" if is_sign_up else "sign in" + logger.info("signing %s user: %s", "up" if is_sign_up else "in", email) + data = { + "provider": self.provider, + "email": email, + } + register_response = await self._http_request( + "POST", + "/magic-link/register" if is_sign_up else "/magic-link/email", + json=data, + headers={"Accept": "application/json"}, + ) + try: + register_response.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error("%s error: %s", title, e) + return MagicCodeFailedResponse( + status_code=e.response.status_code, + message=e.response.text, + ) + register_json = register_response.json() + if "error" in register_json: + error = register_json["error"] + logger.error("%s error: %s", title, error) + return MagicCodeFailedResponse( + status_code=register_response.status_code, + message=error, + ) + else: + email = register_json["email"] + signup = ( + register_json.get("signup", str(is_sign_up)).lower() == "true" + ) + logger.info("the magic code is sent to: %r", email) + logger.debug("Sign-up: %s", signup) + return MagicCodeSentResponse(email=email, signup=signup) + + async def _authenticate_link(self, token: str) -> AuthenticateLinkResponse: + logger.info("authenticating magic link token") + logger.debug("token: %s", token) + response = await self._http_request( + "GET", + httpx.URL("/magic-link/authenticate").copy_add_param( + "token", token + ), + headers={"Accept": "application/json"}, + ) + if response.has_redirect_location: + # /magic-link/authenticate redirects to the callback URL, but we + # want to return the code directly. + redirect_url = httpx.URL(response.headers["Location"]) + code = redirect_url.params.get("code") + if code is not None: + logger.info("authentication succeeded") + logger.debug("code: %s", code) + return AuthenticateLinkResultResponse(code=code) + else: + logger.error("authentication failed: missing code") + return AuthenticateLinkFailedResponse( + status_code=http.HTTPStatus.BAD_GATEWAY, + message="missing code in redirect URL", + ) + elif response.is_success: + logger.error("authentication failed: expected redirect") + return AuthenticateLinkFailedResponse( + status_code=response.status_code, + message="expected redirect but got response", + ) + else: + logger.error( + "authentication failed: [%d] %s", + response.status_code, + response.text, + ) + return AuthenticateLinkFailedResponse( + status_code=response.status_code, + message=response.text, + ) + + async def _authenticate_code( + self, email: str, code: str + ) -> AuthenticateCodeResponse: + logger.info("authenticating by magic code") + logger.debug("email: %r, code: %s", email, code) + pkce = self._generate_pkce() + response = await self._http_request( + "POST", + "/magic-link/authenticate", + json={"email": email, "code": code, "challenge": pkce.challenge}, + headers={"Accept": "application/json"}, + ) + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error("authentication failed: %s", e) + return AuthenticateCodeFailedResponse( + verifier=pkce.verifier, + status_code=e.response.status_code, + message=e.response.text, + ) + response_json = response.json() + if "error" in response_json: + error = response_json["error"] + logger.error("authentication failed: %s", error) + return AuthenticateCodeFailedResponse( + verifier=pkce.verifier, + status_code=response.status_code, + message=error, + ) + elif "code" in response_json: + pkce_code = response_json["code"] + logger.info("authentication succeeded") + logger.debug("PKCE code: %s", pkce_code) + return AuthenticateCodeResultResponse( + code=pkce_code, verifier=pkce.verifier + ) + else: + logger.error("authentication failed: missing code") + return AuthenticateCodeFailedResponse( + verifier=pkce.verifier, + status_code=http.HTTPStatus.BAD_GATEWAY, + message="missing code in response", + ) + + async def _get_token( + self, *, verifier: Optional[str], code: str + ) -> td_mod.TokenData: + if verifier is None: + raise ValueError("verifier is required to get token") + pkce = self._pkce_from_verifier(verifier) + logger.info("exchanging code for token: %s", code) + return await pkce.internal_exchange_code_for_token(code) + + +class BaseBlockingIOMagicLink(BaseMagicLink[httpx.Client]): + def _init_http_client(self, **kwargs: Any) -> httpx.Client: + return httpx.Client(**kwargs) + + def _generate_pkce(self) -> pkce_mod.PKCE: + return pkce_mod.generate_pkce(self._client) + + def _pkce_from_verifier(self, verifier: str) -> pkce_mod.PKCE: + return pkce_mod.PKCE(self._client, verifier) + + async def _send_http_request( + self, request: httpx.Request + ) -> httpx.Response: + return self._client.send(request) + + def get_token( + self, *, verifier: Optional[str], code: str + ) -> td_mod.TokenData: + return blocking_client.iter_coroutine( + self._get_token(verifier=verifier, code=code) + ) + + +class LegacyMagicLink(BaseBlockingIOMagicLink): + def sign_up( + self, + email: str, + *, + callback_url: str, + redirect_on_failure: str, + ) -> MagicLinkResponse: + return blocking_client.iter_coroutine( + self._request_link( + email, + is_sign_up=True, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + ) + ) + + def sign_in( + self, + email: str, + *, + callback_url: str, + redirect_on_failure: str, + ) -> MagicLinkResponse: + return blocking_client.iter_coroutine( + self._request_link( + email, + is_sign_up=False, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + parse_redirect_as_error=True, + ) + ) + + +class MagicLink(BaseBlockingIOMagicLink): + def __init__( + self, + *, + parse_redirect_as_error: bool = False, + connection_info: gel.ConnectionInfo, + **kwargs: Any, + ) -> None: + self.parse_redirect_as_error = parse_redirect_as_error + super().__init__(connection_info=connection_info, **kwargs) + + def sign_up( + self, + email: str, + *, + callback_url: str, + redirect_on_failure: str, + link_url: Optional[str] = None, + ) -> MagicLinkResponse: + return blocking_client.iter_coroutine( + self._request_link( + email, + is_sign_up=True, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + link_url=link_url, + ) + ) + + def sign_in( + self, + email: str, + *, + callback_url: str, + redirect_on_failure: str, + link_url: Optional[str] = None, + ) -> MagicLinkResponse: + return blocking_client.iter_coroutine( + self._request_link( + email, + is_sign_up=False, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + link_url=link_url, + parse_redirect_as_error=self.parse_redirect_as_error, + ) + ) + + def authenticate(self, token: str) -> AuthenticateLinkResponse: + return blocking_client.iter_coroutine(self._authenticate_link(token)) + + +class MagicCode(BaseBlockingIOMagicLink): + def sign_up(self, email: str) -> MagicCodeResponse: + return blocking_client.iter_coroutine( + self._request_code(email, is_sign_up=True) + ) + + def sign_in(self, email: str) -> MagicCodeResponse: + return blocking_client.iter_coroutine( + self._request_code(email, is_sign_up=False) + ) + + def authenticate(self, email: str, code: str) -> AuthenticateCodeResponse: + return blocking_client.iter_coroutine( + self._authenticate_code(email, code) + ) + + +def _validate_server_version( + server_major_version: int, + verification_method: VerificationMethod, +) -> None: + if server_major_version < 5: + raise gel.UnsupportedFeatureError( + "Magic link is not supported on Gel < 5.0" + ) + if ( + server_major_version < 7 + and verification_method == VerificationMethod.CODE + ): + raise gel.UnsupportedFeatureError( + "Magic code verification is not supported on Gel < 7.0" + ) + + +if TYPE_CHECKING: + AnyMagicLink = LegacyMagicLink | MagicLink | MagicCode + + +@overload +def make( + client: gel.Client, *, server_major_version: Literal[5] +) -> LegacyMagicLink: ... + + +@overload +def make( + client: gel.Client, + *, + server_major_version: int, + verification_method: Literal[VerificationMethod.LINK], +) -> MagicLink: ... + + +@overload +def make( + client: gel.Client, + *, + server_major_version: int, + verification_method: Literal[VerificationMethod.CODE], +) -> MagicCode: ... + + +@overload +def make( + client: gel.Client, + *, + server_major_version: Optional[int] = None, + verification_method: VerificationMethod = VerificationMethod.LINK, + cls: Optional[type[AnyMagicLink]] = None, +) -> AnyMagicLink: ... + + +def make( + client: gel.Client, + *, + server_major_version: Optional[int] = None, + verification_method: VerificationMethod = VerificationMethod.LINK, + cls: Optional[type[AnyMagicLink]] = None, +) -> AnyMagicLink: + if server_major_version is None: + server_major_version = client.query_required_single( + "select sys::get_version().major" + ) + assert isinstance(server_major_version, int) + _validate_server_version(server_major_version, verification_method) + args = {} + if cls is None: + if server_major_version < 6: + cls = LegacyMagicLink + elif server_major_version < 7: + cls = MagicLink + args["parse_redirect_as_error"] = True + elif verification_method == VerificationMethod.LINK: + cls = MagicLink + else: + cls = MagicCode + return cls(connection_info=client.check_connection(), **args) + + +class BaseAsyncMagicLink(BaseMagicLink[httpx.AsyncClient]): + def _init_http_client(self, **kwargs: Any) -> httpx.AsyncClient: + return httpx.AsyncClient(**kwargs) + + def _generate_pkce(self) -> pkce_mod.AsyncPKCE: + return pkce_mod.generate_async_pkce(self._client) + + def _pkce_from_verifier(self, verifier: str) -> pkce_mod.AsyncPKCE: + return pkce_mod.AsyncPKCE(self._client, verifier) + + async def _send_http_request( + self, request: httpx.Request + ) -> httpx.Response: + return await self._client.send(request) + + async def get_token( + self, *, verifier: Optional[str], code: str + ) -> td_mod.TokenData: + return await self._get_token(verifier=verifier, code=code) + + +class AsyncLegacyMagicLink(BaseAsyncMagicLink): + async def sign_up( + self, + email: str, + *, + callback_url: str, + redirect_on_failure: str, + ) -> MagicLinkResponse: + return await self._request_link( + email, + is_sign_up=True, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + ) + + async def sign_in( + self, + email: str, + *, + callback_url: str, + redirect_on_failure: str, + ) -> MagicLinkResponse: + return await self._request_link( + email, + is_sign_up=False, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + parse_redirect_as_error=True, + ) + + +class AsyncMagicLink(BaseAsyncMagicLink): + def __init__( + self, + *, + parse_redirect_as_error: bool = False, + connection_info: gel.ConnectionInfo, + **kwargs: Any, + ) -> None: + self.parse_redirect_as_error = parse_redirect_as_error + super().__init__(connection_info=connection_info, **kwargs) + + async def sign_up( + self, + email: str, + *, + callback_url: str, + redirect_on_failure: str, + link_url: Optional[str] = None, + ) -> MagicLinkResponse: + return await self._request_link( + email, + is_sign_up=True, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + link_url=link_url, + ) + + async def sign_in( + self, + email: str, + *, + callback_url: str, + redirect_on_failure: str, + link_url: Optional[str] = None, + ) -> MagicLinkResponse: + return await self._request_link( + email, + is_sign_up=False, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + link_url=link_url, + parse_redirect_as_error=self.parse_redirect_as_error, + ) + + async def authenticate(self, token: str) -> AuthenticateLinkResponse: + return await self._authenticate_link(token) + + +class AsyncMagicCode(BaseAsyncMagicLink): + async def sign_up(self, email: str) -> MagicCodeResponse: + return await self._request_code(email, is_sign_up=True) + + async def sign_in(self, email: str) -> MagicCodeResponse: + return await self._request_code(email, is_sign_up=False) + + async def authenticate( + self, email: str, code: str + ) -> AuthenticateCodeResponse: + return await self._authenticate_code(email, code) + + +if TYPE_CHECKING: + AnyAsyncMagicLink = AsyncLegacyMagicLink | AsyncMagicLink | AsyncMagicCode + + +@overload +async def make_async( + client: gel.AsyncIOClient, *, server_major_version: Literal[5] +) -> AsyncLegacyMagicLink: ... + + +@overload +async def make_async( + client: gel.AsyncIOClient, + *, + server_major_version: int, + verification_method: Literal[VerificationMethod.LINK], +) -> AsyncMagicLink: ... + + +@overload +async def make_async( + client: gel.AsyncIOClient, + *, + server_major_version: int, + verification_method: Literal[VerificationMethod.CODE], +) -> AsyncMagicCode: ... + + +@overload +async def make_async( + client: gel.AsyncIOClient, + *, + server_major_version: Optional[int] = None, + verification_method: VerificationMethod = VerificationMethod.LINK, + cls: Optional[type[AnyAsyncMagicLink]] = None, +) -> AnyAsyncMagicLink: ... + + +async def make_async( + client: gel.AsyncIOClient, + *, + server_major_version: Optional[int] = None, + verification_method: VerificationMethod = VerificationMethod.LINK, + cls: Optional[type[AnyAsyncMagicLink]] = None, +) -> AnyAsyncMagicLink: + if server_major_version is None: + server_major_version = await client.query_required_single( + "select sys::get_version().major" + ) + assert isinstance(server_major_version, int) + _validate_server_version(server_major_version, verification_method) + args = {} + if cls is None: + if server_major_version < 6: + cls = AsyncLegacyMagicLink + elif server_major_version < 7: + cls = AsyncMagicLink + args["parse_redirect_as_error"] = True + elif verification_method == VerificationMethod.LINK: + cls = AsyncMagicLink + else: + cls = AsyncMagicCode + return cls(connection_info=await client.check_connection(), **args) diff --git a/gel/_internal/_integration/_fastapi/_auth/__init__.py b/gel/_internal/_integration/_fastapi/_auth/__init__.py index c12fe57e7..d8adc1143 100644 --- a/gel/_internal/_integration/_fastapi/_auth/__init__.py +++ b/gel/_internal/_integration/_fastapi/_auth/__init__.py @@ -29,6 +29,7 @@ from ._email_password import EmailPassword from ._builtin_ui import BuiltinUI from ._oidc import OpenIDConnect + from ._magic_link import MagicLink _BUILTIN_OAUTH2_PROVIDERS = { @@ -74,6 +75,8 @@ class GelAuth(client_mod.Extension): _auto_builtin_ui: bool = True _manual_oidc_providers: list[str] _oidc_providers: dict[str, OpenIDConnect] + _magic_link: Optional[MagicLink] = None + _auto_magic_link: bool = True _on_new_identity_path = utils.Config("/") _on_new_identity_name = utils.Config("gel.fastapi.auth.on_new_identity") @@ -348,6 +351,13 @@ async def on_startup(self, app: fastapi.FastAPI) -> None: ): _ = self.email_password + case "builtin::local_magic_link": + if ( + self._auto_magic_link + and self._magic_link is None + ): + _ = self.magic_link + if ( config.ui is not None and self._auto_builtin_ui diff --git a/gel/_internal/_integration/_fastapi/_auth/_magic_link.py b/gel/_internal/_integration/_fastapi/_auth/_magic_link.py new file mode 100644 index 000000000..3c5c7b091 --- /dev/null +++ b/gel/_internal/_integration/_fastapi/_auth/_magic_link.py @@ -0,0 +1,558 @@ +# SPDX-PackageName: gel-python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. + +from __future__ import annotations +from typing import Annotated, Optional + +import http +import logging + +import fastapi +import pydantic +from fastapi import responses +from starlette import concurrency + +from gel.auth import magic_link as core, TokenData + +from . import GelAuth, Installable +from .. import _utils as utils + + +logger = logging.getLogger("gel.auth") + + +class RequestBody(pydantic.BaseModel): + email: str + + +class AuthenticatedBody(pydantic.BaseModel): + email: str + code: str + + +class MagicLink(Installable): + sign_in_page_name = utils.Config("sign_in_page") + + _auth: GelAuth + _core: ( + core.AsyncLegacyMagicLink | core.AsyncMagicLink | core.AsyncMagicCode + ) + _blocking_io_core: core.LegacyMagicLink | core.MagicLink | core.MagicCode + + install_endpoints = utils.Config(True) # noqa: FBT003 + + # Request for magic link/code + request_body: utils.ConfigDecorator[type[RequestBody]] = ( + utils.ConfigDecorator(RequestBody) + ) + request_path = utils.Config("/magic-link") + request_name = utils.Config("gel.auth.magic_link.request") + request_summary = utils.Config("Request for magic link or code") + request_default_response_class = utils.Config(responses.RedirectResponse) + request_default_status_code = utils.Config(http.HTTPStatus.SEE_OTHER) + on_magic_link_sent: utils.Hook[RequestBody, core.MagicLinkSentResponse] = ( + utils.Hook("request") + ) + on_magic_link_failed: utils.Hook[ + RequestBody, core.MagicLinkFailedResponse + ] = utils.Hook("request") + on_magic_code_sent: utils.Hook[RequestBody, core.MagicCodeSentResponse] = ( + utils.Hook("request") + ) + on_magic_code_failed: utils.Hook[ + RequestBody, core.MagicCodeFailedResponse + ] = utils.Hook("request") + + # Callback + authenticate_body: utils.ConfigDecorator[type[AuthenticatedBody]] = ( + utils.ConfigDecorator(AuthenticatedBody) + ) + authenticate_path = utils.Config("/magic-link") + authenticate_name = utils.Config("gel.auth.magic_link.authenticate") + authenticate_summary = utils.Config("Handle the magic link authentication") + authenticate_default_response_class = utils.Config( + responses.RedirectResponse + ) + authenticate_default_status_code = utils.Config(http.HTTPStatus.SEE_OTHER) + on_authenticated: utils.Hook[TokenData] = utils.Hook("authenticate") + + def __init__(self, auth: GelAuth): + self._auth = auth + + def _redirect_success( + self, + request: fastapi.Request, + key: str, + ) -> fastapi.Response: + response_class: type[responses.RedirectResponse] = getattr( + self, f"{key}_default_response_class" + ).value + response_code = getattr(self, f"{key}_default_status_code").value + redirect_to = self._auth.redirect_to.value + redirect_to_page_name = self._auth.redirect_to_page_name.value + if redirect_to_page_name is not None: + return response_class( + url=request.url_for(redirect_to_page_name), + status_code=response_code, + ) + elif redirect_to is not None: + return response_class(url=redirect_to, status_code=response_code) + else: + raise RuntimeError( + "GelAuth should have either redirect_to or " + "redirect_to_page_name set" + ) + + def _redirect_error( + self, + request: fastapi.Request, + key: str, + **query_params: str, + ) -> fastapi.Response: + response_class: type[responses.RedirectResponse] = getattr( + self, f"{key}_default_response_class" + ).value + return response_class( + url=request.url_for( + self._auth.error_page_name.value + ).include_query_params(**query_params), + status_code=getattr(self, f"{key}_default_status_code").value, + ) + + def _redirect_sign_in( + self, + request: fastapi.Request, + key: str, + **query_params: str, + ) -> fastapi.Response: + response_class: type[responses.RedirectResponse] = getattr( + self, f"{key}_default_response_class" + ).value + return response_class( + url=request.url_for( + self.sign_in_page_name.value + ).include_query_params(**query_params), + status_code=getattr(self, f"{key}_default_status_code").value, + ) + + async def handle_request_link_result( + self, + request: fastapi.Request, + body: RequestBody, + result: core.MagicLinkSentResponse | core.MagicLinkFailedResponse, + ) -> fastapi.Response: + match result: + case core.MagicLinkSentResponse(email_sent=email, signup=signup): + if signup: + # The Gel server is not returning the identity_id on + # sign-up magic link request, so we need to fetch it here. + identity_id = ( + await self._auth.client.query_required_single( + """ + select ( + select ext::auth::MagicLinkFactor + filter .email = $email + ).identity.id; + """, + email=email, + ) + ) + id_response = await self._auth.handle_new_identity( + request, identity_id, None + ) + else: + id_response = None + if id_response is not None: + response = id_response + elif self.on_magic_link_sent.is_set(): + response = await self.on_magic_link_sent.call( + request, body, result + ) + else: + response = self._redirect_sign_in( + request, "request", incomplete="magic_link_sent" + ) + + case core.MagicLinkFailedResponse(): + logger.info( + "[%d] failed requesting for magic link: %s", + result.status_code, + result.message, + ) + logger.debug("%r", result) + + if self.on_magic_link_failed.is_set(): + response = await self.on_magic_link_failed.call( + request, body, result + ) + else: + response = self._redirect_error( + request, "request", error=result.message + ) + + case _: + raise AssertionError("request returned unknown response") + + self._auth.set_verifier_cookie(result.verifier, response) + return response + + async def handle_pkce_code_with_verifier( + self, + code: str, + *, + request: fastapi.Request, + verifier: Optional[str] = None, + ) -> fastapi.Response: + try: + token_data = await self._core.get_token( + verifier=verifier, code=code + ) + except Exception as e: + response = self._redirect_error( + request, "authenticate", error=str(e) + ) + else: + if self.on_authenticated.is_set(): + with self._auth.with_auth_token( + token_data.auth_token, request + ): + response = await self.on_authenticated.call( + request, token_data + ) + else: + response = self._redirect_success(request, "on_authenticated") + self._auth.set_auth_cookie( + token_data.auth_token, response=response + ) + return response + + def __install_legacy_request_link( + self, + router: fastapi.APIRouter, + magic_link: core.AsyncLegacyMagicLink, + ) -> None: + async def request_link( + request_body: Annotated[ + RequestBody, utils.OneOf(fastapi.Form(), fastapi.Body()) + ], + is_sign_up: Annotated[ + bool, + fastapi.Query(alias="isSignUp", default=False), + ], + request: fastapi.Request, + ) -> fastapi.Response: + callback_url = str(request.url_for(self.authenticate_name.value)) + redirect_on_failure = str( + request.url_for(self._auth.error_page_name.value) + ) + if is_sign_up: + result = await magic_link.sign_up( + request_body.email, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + ) + else: + result = await magic_link.sign_in( + request_body.email, + callback_url=callback_url, + redirect_on_failure=redirect_on_failure, + ) + return await self.handle_request_link_result( + request, request_body, result + ) + + request_link.__globals__["RequestBody"] = self.request_body.value + + router.post( + self.request_path.value, + name=self.request_name.value, + summary=self.request_summary.value, + )(request_link) + + def __install_legacy_authenticate_link( + self, router: fastapi.APIRouter + ) -> None: + @router.get( + self.authenticate_path.value, + name=self.authenticate_name.value, + summary=self.authenticate_summary.value, + ) + async def authenticate( + request: fastapi.Request, + code: str, + verifier: Optional[str] = fastapi.Depends( + self._auth.pkce_verifier + ), + ) -> fastapi.Response: + return await self.handle_pkce_code_with_verifier( + code, request=request, verifier=verifier + ) + + def __install_request_link( + self, + router: fastapi.APIRouter, + magic_link: core.AsyncMagicLink, + ) -> None: + async def request_link( + request_body: Annotated[ + RequestBody, utils.OneOf(fastapi.Form(), fastapi.Body()) + ], + is_sign_up: Annotated[ + bool, + fastapi.Query(alias="isSignUp", default=False), + ], + request: fastapi.Request, + ) -> fastapi.Response: + callback_url = str(request.url_for(self.authenticate_name.value)) + redirect_on_failure = str( + request.url_for(self._auth.error_page_name.value) + ) + if is_sign_up: + result = await magic_link.sign_up( + request_body.email, + callback_url=callback_url, # not used in favor of link_url + redirect_on_failure=redirect_on_failure, + link_url=callback_url, + ) + else: + result = await magic_link.sign_in( + request_body.email, + callback_url=callback_url, # not used in favor of link_url + redirect_on_failure=redirect_on_failure, + link_url=callback_url, + ) + return await self.handle_request_link_result( + request, request_body, result + ) + + request_link.__globals__["RequestBody"] = self.request_body.value + + router.post( + self.request_path.value, + name=self.request_name.value, + summary=self.request_summary.value, + )(request_link) + + def __install_authenticate_link( + self, router: fastapi.APIRouter, magic_link: core.AsyncMagicLink + ) -> None: + @router.get( + self.authenticate_path.value, + name=self.authenticate_name.value, + summary=self.authenticate_summary.value, + ) + async def authenticate( + request: fastapi.Request, + token: str, + verifier: Optional[str] = fastapi.Depends( + self._auth.pkce_verifier + ), + ) -> fastapi.Response: + response = await magic_link.authenticate(token) + match response: + case core.AuthenticateLinkResultResponse(code=code): + return await self.handle_pkce_code_with_verifier( + code, request=request, verifier=verifier + ) + + case core.AuthenticateLinkFailedResponse(): + logger.info( + "[%d] failed authenticating magic link: %s", + response.status_code, + response.message, + ) + return self._redirect_error( + request, "authenticate", error=response.message + ) + + case _: + raise AssertionError( + "authenticate returned unknown response" + ) + + async def _install_link( + self, router: fastapi.APIRouter, server_major_version: int + ) -> None: + self._blocking_io_core = await concurrency.run_in_threadpool( + core.make, + self._auth.blocking_io_core, + server_major_version=server_major_version, + verification_method=core.VerificationMethod.LINK, + ) + if server_major_version == 5: + self._core = legacy_magic_link = await core.make_async( + self._auth.client, + server_major_version=5, + ) + if self.install_endpoints.value: + self.__install_legacy_request_link(router, legacy_magic_link) + self.__install_legacy_authenticate_link(router) + else: + self._core = magic_link = await core.make_async( + self._auth.client, + server_major_version=server_major_version, + verification_method=core.VerificationMethod.LINK, + ) + if self.install_endpoints.value: + self.__install_request_link(router, magic_link) + self.__install_authenticate_link(router, magic_link) + + def __install_request_code( + self, + router: fastapi.APIRouter, + magic_code: core.AsyncMagicCode, + ) -> None: + async def request_code( + request_body: Annotated[ + RequestBody, utils.OneOf(fastapi.Form(), fastapi.Body()) + ], + is_sign_up: Annotated[ + bool, + fastapi.Query(alias="isSignUp", default=False), + ], + request: fastapi.Request, + ) -> fastapi.Response: + if is_sign_up: + result = await magic_code.sign_up(request_body.email) + else: + result = await magic_code.sign_in(request_body.email) + match result: + case core.MagicCodeSentResponse(signup=signup, email=email): + if signup: + # The Gel server is not returning the identity_id on + # sign-up magic code request, so we need to fetch it. + identity_id = ( + await self._auth.client.query_required_single( + """ + select ( + select ext::auth::MagicCodeFactor + filter .email = $email + ).identity.id; + """, + email=email, + ) + ) + id_response = await self._auth.handle_new_identity( + request, identity_id, None + ) + else: + id_response = None + if id_response is not None: + return id_response + elif self.on_magic_code_sent.is_set(): + return await self.on_magic_code_sent.call( + request, request_body, result + ) + else: + return self._redirect_sign_in( + request, "request", incomplete="magic_code_sent" + ) + + case core.MagicCodeFailedResponse(): + logger.info( + "[%d] failed requesting for magic code: %s", + result.status_code, + result.message, + ) + logger.debug("%r", result) + + if self.on_magic_code_failed.is_set(): + return await self.on_magic_code_failed.call( + request, request_body, result + ) + else: + return self._redirect_error( + request, "request", error=result.message + ) + + case _: + raise AssertionError("request returned unknown response") + + request_code.__globals__["RequestBody"] = self.request_body.value + + router.post( + self.request_path.value, + name=self.request_name.value, + summary=self.request_summary.value, + )(request_code) + + def __install_authenticate_code( + self, router: fastapi.APIRouter, magic_code: core.AsyncMagicCode + ) -> None: + @router.put( + self.authenticate_path.value, + name=self.authenticate_name.value, + summary=self.authenticate_summary.value, + ) + async def authenticate( + request: fastapi.Request, + body: Annotated[ + AuthenticatedBody, + utils.OneOf(fastapi.Form(), fastapi.Body()), + ], + ) -> fastapi.Response: + result = await magic_code.authenticate(body.email, body.code) + match result: + case core.AuthenticateCodeResultResponse(): + return await self.handle_pkce_code_with_verifier( + result.code, request=request, verifier=result.verifier + ) + + case core.AuthenticateCodeFailedResponse(): + logger.info( + "[%d] failed authenticating magic code: %s", + result.status_code, + result.message, + ) + logger.debug("%r", result) + return self._redirect_error( + request, "authenticate", error=result.message + ) + + case _: + raise AssertionError( + "authenticate returned unknown response" + ) + + async def _install_code( + self, router: fastapi.APIRouter, server_major_version: int + ) -> None: + self._blocking_io_core = await concurrency.run_in_threadpool( + core.make, + self._auth.blocking_io_core, + server_major_version=server_major_version, + verification_method=core.VerificationMethod.CODE, + ) + self._core = magic_code = await core.make_async( + self._auth.client, + server_major_version=server_major_version, + verification_method=core.VerificationMethod.CODE, + ) + if self.install_endpoints.value: + self.__install_request_code(router, magic_code) + self.__install_authenticate_code(router, magic_code) + + async def install(self, router: fastapi.APIRouter) -> None: + config = await self._auth.client.query_required_single( + """ + select assert_single( + cfg::Config.extensions[is ext::auth::AuthConfig] + .providers[is ext::auth::MagicLinkProviderConfig] + ) { + *, + server_major_version := (select sys::get_version()).major, + } + """ + ) + server_major_version = config.server_major_version + if hasattr(config, "verification_method"): + method = core.VerificationMethod(config.verification_method) + else: + method = core.VerificationMethod.LINK + if method == core.VerificationMethod.LINK: + await self._install_link(router, server_major_version) + else: + await self._install_code(router, server_major_version) + + await super().install(router) diff --git a/gel/auth/magic_link.py b/gel/auth/magic_link.py new file mode 100644 index 000000000..7054a5e23 --- /dev/null +++ b/gel/auth/magic_link.py @@ -0,0 +1,43 @@ +# SPDX-PackageName: gel-python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. + +from gel._internal._auth._magic_link import ( + AsyncLegacyMagicLink, + AsyncMagicLink, + AsyncMagicCode, + AuthenticateCodeResultResponse, + AuthenticateCodeFailedResponse, + AuthenticateLinkResultResponse, + AuthenticateLinkFailedResponse, + LegacyMagicLink, + MagicLink, + MagicLinkSentResponse, + MagicLinkFailedResponse, + MagicCode, + MagicCodeSentResponse, + MagicCodeFailedResponse, + make, + make_async, + VerificationMethod, +) + +__all__ = [ + "AsyncLegacyMagicLink", + "AsyncMagicLink", + "AsyncMagicCode", + "AuthenticateCodeResultResponse", + "AuthenticateCodeFailedResponse", + "AuthenticateLinkResultResponse", + "AuthenticateLinkFailedResponse", + "LegacyMagicLink", + "MagicLink", + "MagicLinkSentResponse", + "MagicLinkFailedResponse", + "MagicCode", + "MagicCodeSentResponse", + "MagicCodeFailedResponse", + "make", + "make_async", + "VerificationMethod", +] diff --git a/gel/fastapi/auth/magic_link.py b/gel/fastapi/auth/magic_link.py new file mode 100644 index 000000000..6c711526b --- /dev/null +++ b/gel/fastapi/auth/magic_link.py @@ -0,0 +1,11 @@ +# SPDX-PackageName: gel-python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. + +from gel._internal._integration._fastapi._auth._magic_link import ( + RequestBody, +) + +__all__ = [ + "RequestBody", +]