# Source code for ucsschool.kelvin.client.session

# -*- coding: utf-8 -*-
#
# Copyright 2020 Univention GmbH
#
# http://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <http://www.gnu.org/licenses/>.

import asyncio
import datetime
import logging
import uuid
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import httpx
import jwt
from async_property import async_property

from .exceptions import InvalidRequest, InvalidToken, NoObject, ServerError

# Alias used in annotations for DN strings (presumably LDAP distinguished names).
DN = str

API_VERSION = "v1"
TOKEN_HASH_ALGORITHM = "HS256"  # nosec

# URL templates: the "{host}" placeholder is filled in later with str.format(host=...).
URL_BASE = "https://{host}/ucsschool/kelvin"
URL_TOKEN = URL_BASE + "/token"
URL_RESOURCE_CLASS = URL_BASE + "/" + API_VERSION + "/classes/"
URL_RESOURCE_ROLE = URL_BASE + "/" + API_VERSION + "/roles/"
URL_RESOURCE_SCHOOL = URL_BASE + "/" + API_VERSION + "/schools/"
URL_RESOURCE_USER = URL_BASE + "/" + API_VERSION + "/users/"
URL_RESOURCE_WORKGROUP = URL_BASE + "/" + API_VERSION + "/workgroups/"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class KelvinClientWarning(Warning):
    """Base class for warnings issued by the Kelvin client."""
class BadSettingsWarning(KelvinClientWarning):
    """Issued when a setting is out of range and gets adjusted automatically.

    Example: ``Session`` raises a too small ``max_client_tasks`` to its minimum.
    """
@dataclass
class Token:
    """An encoded access token (JWT) together with its expiry time.

    ``expiry`` is a naive :class:`datetime.datetime` in UTC, taken from the
    token's ``exp`` claim.
    """

    expiry: datetime.datetime
    value: str

    @classmethod
    def from_str(cls, token_str: str) -> "Token":
        """Create a :class:`Token` from an encoded JWT string.

        The payload is decoded with ``verify_signature`` disabled; it is only
        read to obtain the ``exp`` claim.

        :param token_str: the encoded JWT
        :return: the token with ``expiry`` set from the ``exp`` claim
        :raises InvalidToken: if the token cannot be decoded, its payload lacks
            an ``exp`` entry, or ``exp`` is not a usable timestamp
        """
        try:
            payload = jwt.decode(
                token_str, algorithms=[TOKEN_HASH_ALGORITHM], options={"verify_signature": False}
            )
        except jwt.PyJWTError as exc:
            raise InvalidToken(f"Error decoding token ({token_str!r}): {exc!s}") from exc
        if not isinstance(payload, dict) or "exp" not in payload:
            raise InvalidToken(f"Payload in token not a dict or missing 'exp' entry ({token_str!r}).")
        try:
            # Same result as the deprecated utcfromtimestamp(): interpret as UTC, then
            # drop tzinfo so ``expiry`` stays a naive datetime as before.
            expiry = datetime.datetime.fromtimestamp(
                payload["exp"], tz=datetime.timezone.utc
            ).replace(tzinfo=None)
        except (ValueError, OverflowError, OSError, TypeError) as exc:
            # Out-of-range timestamps raise OverflowError/OSError, non-numeric ones TypeError.
            raise InvalidToken(f"Error parsing date in token ({token_str!r}): {exc!s}") from exc
        return cls(expiry=expiry, value=token_str)

    def is_valid(self) -> bool:
        """Return ``True`` if the token has an expiry and a value and has not expired yet."""
        if not self.expiry or not self.value:
            return False
        # Naive UTC "now" (replacement for the deprecated utcnow()), comparable to ``expiry``.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        return now <= self.expiry
class Session:
    """Asynchronous HTTP session for the UCS@school Kelvin REST API.

    Wraps an :class:`httpx.AsyncClient`. It obtains and caches an access token
    using *username* / *password* and turns error responses into the client's
    exceptions. It can be used as an async context manager::

        async with Session(username, password, host) as session:
            ...

    :param username: user name sent to the token endpoint
    :param password: password sent to the token endpoint
    :param host: API host, inserted into ``https://{host}/ucsschool/kelvin``
    :param max_client_tasks: size of the semaphore stored as
        ``_client_task_limiter`` (presumably used elsewhere to limit concurrent
        requests); values below 4 are raised to 4 with a :class:`BadSettingsWarning`
    :param request_id: ID sent with every request (random UUID hex if not given)
    :param request_id_header: name of the header carrying *request_id*
    :param language: if set, sent as ``Accept-Language`` header
    :param kwargs: passed on to :class:`httpx.AsyncClient`; ``timeout`` is also
        used as the default per-request timeout
    """

    def __init__(
        self,
        username: str,
        password: str,
        host: str,
        max_client_tasks: int = 10,
        request_id: Optional[str] = None,
        request_id_header: str = "X-Request-ID",
        language: Optional[str] = None,
        **kwargs,
    ):
        if max_client_tasks < 4:
            txt = "Raising value of 'max_client_tasks' to its minimum of 4."
            warnings.warn(txt, BadSettingsWarning)
            logger.warning(txt)
            max_client_tasks = 4
        self.max_client_tasks = max_client_tasks
        self._client: Optional[httpx.AsyncClient] = None
        self._client_task_limiter = asyncio.Semaphore(max_client_tasks)
        self.username = username
        self.password = password
        self.host = host
        self.request_id = request_id or uuid.uuid4().hex
        self.request_id_header = request_id_header
        self.language = language
        self.kwargs = kwargs
        self.urls = {
            "token": URL_TOKEN.format(host=host),
            "class": URL_RESOURCE_CLASS.format(host=host),
            "role": URL_RESOURCE_ROLE.format(host=host),
            "school": URL_RESOURCE_SCHOOL.format(host=host),
            "user": URL_RESOURCE_USER.format(host=host),
            "workgroup": URL_RESOURCE_WORKGROUP.format(host=host),
        }
        self._token: Optional[Token] = None

    async def __aenter__(self):
        self.open()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    def open(self) -> httpx.AsyncClient:
        """Create the HTTP client if there is none yet, and return it.

        The request ID headers are added to a *copy* of the ``headers`` passed
        to the constructor, so the caller's mapping is not modified.
        """
        if not self._client:
            headers = self.kwargs.get("headers")
            headers = {} if headers is None else headers.copy()
            # NOTE(review): Access-Control-Expose-Headers is normally a response header;
            # kept unchanged for compatibility.
            headers["Access-Control-Expose-Headers"] = self.request_id_header
            headers[self.request_id_header] = self.request_id
            self.kwargs["headers"] = headers
            self._client = httpx.AsyncClient(**self.kwargs)
        return self._client

    async def close(self) -> None:
        """Close the HTTP client (if open). :meth:`open` can create a new one later."""
        if self._client:
            await self._client.aclose()
            self._client = None

    @property
    def client(self) -> httpx.AsyncClient:
        """The open HTTP client.

        :raises RuntimeError: if the session is not open
        """
        if not self._client:
            raise RuntimeError("Session is closed.")
        return self._client

    @async_property
    async def token(self) -> str:
        """The access token, fetched from the token endpoint when missing or expired.

        :raises InvalidToken: if the token endpoint's response has no
            ``access_token`` entry or the token cannot be parsed
        """
        if not self._token or not self._token.is_valid():
            resp_json = await self.post(
                self.urls["token"],
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={"username": self.username, "password": self.password},
            )
            try:
                token_str = resp_json["access_token"]
            except (KeyError, TypeError) as exc:
                raise InvalidToken("Response of token endpoint contains no 'access_token'.") from exc
            self._token = Token.from_str(token_str)
        return self._token.value

    @async_property
    async def json_headers(self) -> Dict[str, str]:
        """Headers for JSON requests, including the bearer token (and language, if set)."""
        headers = {
            "accept": "application/json",
            "Authorization": f"Bearer {await self.token}",
            "Content-Type": "application/json",
        }
        if self.language:
            headers["Accept-Language"] = self.language
        return headers

    @staticmethod
    def _mask_secrets(request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Return a shallow copy of *request_kwargs* with credentials masked, for logging.

        The ``Authorization`` header (case-insensitive) and a ``password`` entry in
        dict-valued ``data`` / ``json`` are replaced. The caller's objects are not
        modified, so they can be reused for further requests.
        """
        mask = 10 * "*"
        masked = dict(request_kwargs)
        headers = masked.get("headers")
        if hasattr(headers, "items"):
            masked["headers"] = {
                name: mask if str(name).lower() == "authorization" else value
                for name, value in headers.items()
            }
        for key in ("data", "json"):
            payload = masked.get(key)
            if isinstance(payload, dict) and "password" in payload:
                masked[key] = {**payload, "password": mask}
        return masked

    async def request(
        self, async_request_method: Any, url: str, return_json: bool = True, **kwargs
    ) -> Union[str, int, Dict[str, Any], List[Any]]:
        """Send a request with a method of :attr:`client` and evaluate the response.

        :param async_request_method: bound method of :attr:`client`, e.g. ``self.client.get``
        :param url: URL to request
        :param return_json: for 2xx responses, return the decoded JSON body
            (``True``) or the response text (``False``)
        :param kwargs: passed on to *async_request_method*. If ``headers`` is
            missing, :attr:`json_headers` is used; if ``timeout`` is missing, the
            session's ``timeout`` (default ``10.0``) is used.
        :return: the status code for HEAD requests (never raises on status);
            otherwise, for 2xx, the decoded JSON (``{}`` if the body is not JSON)
            or the response text
        :raises NoObject: on status 404
        :raises InvalidRequest: on any other 4xx status
        :raises ServerError: on any other non-2xx status
        """
        if "headers" not in kwargs:
            kwargs["headers"] = await self.json_headers
        if "timeout" not in kwargs:
            kwargs["timeout"] = self.kwargs.get("timeout", 10.0)
        response: httpx.Response = await async_request_method(url, **kwargs)
        try:
            resp_json = response.json()
        except ValueError:  # empty or non-JSON body
            resp_json = {}
        # Only a JSON object can carry a "detail" entry; lists and scalars have none.
        detail = resp_json.get("detail", "") if isinstance(resp_json, dict) else ""
        method_name = async_request_method.__name__.upper()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[%s] %s %r (**%r) -> %r %r%s",
                self.request_id[:10],
                method_name,
                url,
                self._mask_secrets(kwargs),
                response.status_code,
                response.reason_phrase,
                f" ({detail})" if detail else "",
            )
        if async_request_method == self.client.head:
            return response.status_code
        if 200 <= response.status_code <= 299:
            return resp_json if return_json else response.text
        elif response.status_code == 404:
            raise NoObject(
                f"Object not found ({method_name} {url!r}).",
                reason=detail if detail else response.reason_phrase,
                status=response.status_code,
                url=url,
            )
        elif 400 <= response.status_code <= 499:
            raise InvalidRequest(
                f"Kelvin REST API returned status {response.status_code}, reason "
                f"{response.reason_phrase!r}{f' ({detail})' if detail else ''} for "
                f"{method_name} {url!r}.",
                reason=detail if detail else response.reason_phrase,
                status=response.status_code,
                url=url,
            )
        else:
            raise ServerError(
                reason=response.reason_phrase, status=response.status_code, url=url
            )  # pragma: no cover

    async def delete(self, url: str, **kwargs) -> None:
        """Send a DELETE request to *url*."""
        await self.request(self.client.delete, url, return_json=False, **kwargs)

    async def get(self, url: str, **kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Send a GET request to *url* and return the decoded JSON body."""
        return await self.request(self.client.get, url, **kwargs)

    async def head(self, url: str, **kwargs) -> int:
        """Send a HEAD request to *url* and return the HTTP status code."""
        return await self.request(self.client.head, url, **kwargs)

    # async def patch(self, url: str, **kwargs,) -> Dict[str, Any]:
    #     return await self.request(self.client.patch, url, **kwargs)

    async def post(self, url: str, **kwargs) -> Dict[str, Any]:
        """Send a POST request to *url* and return the decoded JSON body."""
        return await self.request(self.client.post, url, **kwargs)

    async def put(self, url: str, **kwargs) -> Dict[str, Any]:
        """Send a PUT request to *url* and return the decoded JSON body."""
        return await self.request(self.client.put, url, **kwargs)