# Source code for SlyAPI.oauth2

'''
Implementation for OAuth2.0 with PKCE as the `Auth` interface
https://datatracker.ietf.org/doc/html/rfc7636
'''
import asyncio
import base64
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from hashlib import sha256
import json
import secrets
from copy import copy
from typing import Any, Callable, ParamSpec, TypeVar, cast

from warnings import warn

from .auth import Auth
from .web import ApiError, JsonMap, Request, TomlMap, serve_once


import aiohttp
from aiohttp import ClientSession as Client

import urllib.parse

# Generic parameters for `requires_scopes`: they let the decorator's type
# signature state that the decorated function comes back with the same
# parameters (F_Params) and the same return type (F_Return).
F_Params = ParamSpec('F_Params')
F_Return = TypeVar('F_Return')

# Purely declarative for now: the scopes passed in are neither checked nor enforced.
def requires_scopes(
    *_scopes: str,
) -> Callable[[Callable[F_Params, F_Return]], Callable[F_Params, F_Return]]:
    """Mark an endpoint as requiring the given OAuth2 scopes.

    Args:
        *_scopes: scope names the endpoint needs (currently ignored).

    Returns:
        An identity decorator that hands back the decorated function itself.
    """
    def identity(endpoint: Callable[F_Params, F_Return]) -> Callable[F_Params, F_Return]:
        # Return the very same object, so signature and metadata are untouched.
        return endpoint

    return identity
[docs]@dataclass class OAuth2User: token: str refresh_token: str expires_at: datetime token_type: str = 'Bearer' scopes: list[str] = field(default_factory=list)
[docs] @classmethod def from_json_obj(cls, obj: TomlMap) -> 'OAuth2User': '''Read an app from a JSON object''' match obj: case { # JSON / self.to_dict() 'token': str(token), 'refresh_token': str(refresh_token), 'expires_at': str(expires_at_str), 'token_type': str(token_type), 'scopes': scopes }: try: expires_at = datetime.strptime(expires_at_str, '%Y-%m-%dT%H:%M:%S.%fZ') except ValueError: expires_at = datetime.strptime(expires_at_str, '%Y-%m-%dT%H:%M:%SZ') return cls(token, refresh_token, expires_at, token_type, cast(list[str], scopes)) case { # asdict(self) # TODO: eliminate this case? 'token': str(token), 'refresh_token': str(refresh_token), 'expires_at': datetime() as expires_at, 'token_type': str(token_type), 'scopes': scopes }: return cls(token, refresh_token, expires_at, token_type, cast(list[str], scopes)) case { 'access_token': str(token), 'expires_in': str(expires_in) | int(expires_in), 'token_type': str(token_type), **others }: # OAuth 2 grant response expires_in = int(expires_in) expires_at = datetime.utcnow() + timedelta(seconds=expires_in) refresh_token = others.get('refresh_token', '') scopes = cast(str, others.get('scope', '')).split(' ') if refresh_token is None: warn( "Google doesn't re-issue refresh tokens when you authorize a new token from the same application for the same user. That might be the case here, since a token grant was recieved without `refresh_token`! Refreshing these credentials will fail, consider revoking access at https://myaccount.google.com/permissions and re-authorizing." ) return cls(token, cast(str, refresh_token), expires_at, token_type, scopes) case { # mastodon 'access_token': str(token), 'token_type': str(token_type), 'scope': str(scopes), 'created_at': int(_stamp) }: expires_at = datetime(2400, 1, 1) # TODO: does this never expire? return cls(token, '', expires_at, token_type, scopes.split(' ')) case _: raise ValueError(F"Unknown format for OAuth2User: {obj}")
[docs] @classmethod def from_json_file(cls, path: str) -> 'OAuth2User': '''Read an app from a JSON file path''' with open(path, 'rb') as f: return cls.from_json_obj(json.load(f))
[docs] def to_dict(self) -> dict[str, Any]: return { 'token': self.token, 'refresh_token': self.refresh_token, 'expires_at': self.expires_at.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), 'token_type': self.token_type, 'scopes': self.scopes }
[docs]@dataclass class OAuth2App: id: str secret: str auth_uri: str # flow step 1 token_uri: str # flow step 3
[docs] @classmethod def from_json_obj(cls, obj: JsonMap) -> 'OAuth2App': '''Read an app from a JSON object, either from the Google Console JSON format or a set of kwargs''' match obj: case { # to_dict(self) 'id': str(id), 'secret': str(secret), 'auth_uri': str(auth_uri), 'token_uri': str(token_uri) } | { # Google JSON 'web': { 'client_id': str(id), 'client_secret': str(secret), 'auth_uri': str(auth_uri), 'token_uri': str(token_uri) } }: return cls(id, secret, auth_uri, token_uri) case _: raise ValueError(F"Unknown format for OAuth2App: {obj}")
[docs] @classmethod def from_json_file(cls, path: str) -> 'OAuth2App': '''Read an app from a JSON file path''' with open(path, 'rb') as f: return cls.from_json_obj(json.load(f))
[docs] def auth_url_with_pkce(self, redirect_uri: str, state: str, scopes: str) -> tuple[str, str, str]: state_challenge = secrets.token_urlsafe(54) code_verifier = secrets.token_urlsafe(54) verifier_hash = sha256(code_verifier.encode('utf-8')).digest() code_challenge = base64.urlsafe_b64encode(verifier_hash).decode('utf-8').rstrip('=') params = { 'client_id': self.id, 'redirect_uri': redirect_uri, 'response_type': 'code', 'state': state+state_challenge, 'code_challenge': code_challenge, 'code_challenge_method': 'S256', 'scope': scopes, 'access_type': 'offline', 'prompt': 'consent', # re-issue refresh tokens! } return F"{self.auth_uri}?{urllib.parse.urlencode(params)}", code_verifier, state_challenge
[docs] async def refresh(self, client: Client, user: OAuth2User): data = { 'grant_type': 'refresh_token', 'refresh_token': user.refresh_token, 'client_id': self.id, 'client_secret': self.secret, } headers = {'Content-Type': 'application/x-www-form-urlencoded'} async with client.post(self.token_uri, data=data, headers=headers) as resp: if resp.status != 200: raise await ApiError.from_resposnse(resp) result = await resp.json() match result: case { 'access_token': token, 'expires_in': expires_str, 'token_type': token_type, **_others }: # OAuth 2 refresh response expiry = datetime.utcnow() + timedelta(seconds=int(expires_str)) new_user = copy(user) new_user.token = token new_user.expires_at = expiry new_user.token_type = token_type case _: raise ValueError(F"Invalid OAuth2 refresh response: {result}") return new_user
[docs] async def exchange_code(self, code: str, verifier: str, scopes: list[str], redirect_uri: str, client: Client|None=None) -> OAuth2User: grant_data = { 'grant_type': 'authorization_code', 'code': code, 'client_id': self.id, 'redirect_uri': redirect_uri, 'scope': scopes, 'code_verifier': verifier, } grant_headers = { 'Authorization': F"Basic {base64.b64encode(F'{self.id}:{self.secret}'.encode('utf-8')).decode('utf-8')}", 'Content-Type': 'application/x-www-form-urlencoded' } if client is None: do_req = aiohttp.request else: do_req = client.request async with do_req('POST', self.token_uri, data=grant_data, headers=grant_headers) as resp: if resp.status != 200: raise await ApiError.from_resposnse(resp) result = await resp.json() user = OAuth2User.from_json_obj(result) return user
@dataclass
class OAuth2(Auth):
    """Provides the Auth interface implementation for OAuth2"""
    app: OAuth2App
    user: OAuth2User
    _refreshed: asyncio.Semaphore  # serializes expiry checks / refreshes in sign()
    _refresh_callback: Callable[[OAuth2User], None] | None = None

    def __init__(self, app: OAuth2App|str, user: OAuth2User|str, on_refresh: Callable[[OAuth2User], None] | None = None):
        """Load an OAuth2 app and user from JSON files or existing objects.

        Args:
            app: an ``OAuth2App`` or a path to its JSON file.
            user: an ``OAuth2User`` or a path to its JSON file.
            on_refresh: called with the new user after each token refresh,
                e.g. to persist the refreshed credentials.
        """
        if isinstance(app, str):
            app = OAuth2App.from_json_file(app)
        if isinstance(user, str):
            user = OAuth2User.from_json_file(user)
        self.app = app
        self.user = user
        self._refresh_callback = on_refresh
        self._refreshed = asyncio.Semaphore()

    async def sign(self, client: Client, request: Request) -> Request:
        """Add an ``Authorization`` header, refreshing the token first if it has expired."""
        # `async with` releases the semaphore even if the refresh raises;
        # a manual acquire/release would leave it held and deadlock later calls.
        async with self._refreshed:
            if datetime.utcnow() > self.user.expires_at:
                # TODO: log refresh
                self.user = await self.app.refresh(client, self.user)
                if self._refresh_callback is not None:
                    self._refresh_callback(self.user)
        request.headers['Authorization'] = F"{self.user.token_type} {self.user.token}"
        return request
async def command_line_oauth2(
    app: OAuth2App,
    redirect_host: str,
    redirect_port: int,
    usePin: bool,
    scopes: list[str]
) -> OAuth2User:
    """Run the interactive OAuth2 authorization-code + PKCE flow.

    Args:
        app: the OAuth2 application to authorize.
        redirect_host: host for the local redirect listener; also used in redirect_uri.
        redirect_port: port for the local redirect listener; also used in redirect_uri.
        usePin: if true, skip the browser/redirect and prompt for the code instead.
        scopes: scopes to request.

    Returns:
        The authorized user, from ``app.exchange_code``.

    Raises:
        PermissionError: if the redirect lacks a state parameter, returns the
            wrong state, or carries no authorization code.
    """
    import webbrowser

    redirect_uri = F'http://{redirect_host}:{redirect_port}'

    # step 1: get the user to authorize the application
    grant_link, verifier, state = app.auth_url_with_pkce(redirect_uri, '', ' '.join(scopes))

    if usePin or not webbrowser.open(grant_link, new=1, autoraise=True):
        print("Please open the following link in your browser:")
        print(grant_link)
        print("Then enter the code below:")
        code = input("code")
    else:
        # step 2: wait for the user to be redirected with the code
        query = await serve_once(redirect_host, redirect_port, 'step2.html')
        # The state prefix passed above is '', so the redirect's state must
        # equal the returned random challenge exactly. Compare whole strings
        # rather than slicing: the challenge's length is that of its base64
        # text, not the 54 random bytes it encodes.
        if 'state' not in query:
            raise PermissionError("Redirect did not return any state parameter.")
        if query['state'] != state:
            raise PermissionError("Redirect did not return the correct state parameter.")
        if 'code' not in query:
            # Providers report a denied or failed authorization via an `error` parameter.
            if 'error' in query:
                raise PermissionError(F"Redirect did not authorize grant: {query['error']}")
            raise PermissionError("Redirect did not authorize grant.")
        code = query['code']

    # step 3: exchange the code for access token
    return await app.exchange_code(code, verifier, scopes, redirect_uri)