# -*- coding: utf-8 -*-
#
# Copyright 2020 Univention GmbH
#
# http://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <http://www.gnu.org/licenses/>.
import asyncio
import datetime
import logging
import uuid
import warnings
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import httpx
import jwt
from async_property import async_property

from .exceptions import InvalidRequest, InvalidToken, NoObject, ServerError
# Type alias (a plain ``str``); NOTE(review): presumably used for LDAP DNs elsewhere.
DN = str

API_VERSION = "v1"
# Algorithm list handed to ``jwt.decode()`` when reading token payloads.
TOKEN_HASH_ALGORITHM = "HS256"  # nosec

# URL templates; the ``{host}`` placeholder is filled with ``str.format(host=...)``.
URL_BASE = "https://{host}/ucsschool/kelvin"
_URL_API_BASE = f"{URL_BASE}/{API_VERSION}"

URL_TOKEN = f"{URL_BASE}/token"
URL_RESOURCE_CLASS = f"{_URL_API_BASE}/classes/"
URL_RESOURCE_ROLE = f"{_URL_API_BASE}/roles/"
URL_RESOURCE_SCHOOL = f"{_URL_API_BASE}/schools/"
URL_RESOURCE_USER = f"{_URL_API_BASE}/users/"
URL_RESOURCE_WORKGROUP = f"{_URL_API_BASE}/workgroups/"

logger = logging.getLogger(__name__)
class KelvinClientWarning(Warning):
    """Common base class for the warning categories of this client."""
class BadSettingsWarning(KelvinClientWarning):
    """Issued when a supplied setting is out of range and gets adjusted (e.g. ``max_client_tasks``)."""
@dataclass
class Token:
    """
    Encoded access token together with its expiry time.

    ``expiry`` is the token's ``exp`` claim; :meth:`from_str` stores it as a
    *naive* datetime in UTC. ``value`` is the encoded token string itself.
    """

    expiry: datetime.datetime
    value: str

    @classmethod
    def from_str(cls, token_str: str) -> "Token":
        """
        Build a :class:`Token` from an encoded JWT string.

        The signature is *not* checked (``verify_signature`` is disabled); only
        the ``exp`` claim of the payload is read.

        :param str token_str: the encoded JWT
        :return: a Token with ``expiry`` as naive UTC datetime
        :raises InvalidToken: if the token cannot be decoded, its payload has no
            ``exp`` entry, or ``exp`` is not a usable timestamp
        """
        try:
            payload = jwt.decode(
                token_str, algorithms=[TOKEN_HASH_ALGORITHM], options={"verify_signature": False}
            )
        except jwt.PyJWTError as exc:
            raise InvalidToken(f"Error decoding token ({token_str!r}): {exc!s}") from exc
        if not isinstance(payload, dict) or "exp" not in payload:
            raise InvalidToken(f"Payload in token not a dict or missing 'exp' entry ({token_str!r}).")
        try:
            # Same value as the deprecated ``utcfromtimestamp()``: naive UTC.
            expiry = datetime.datetime.fromtimestamp(
                payload["exp"], tz=datetime.timezone.utc
            ).replace(tzinfo=None)
        except (TypeError, ValueError, OverflowError, OSError) as exc:
            # TypeError: non-numeric "exp"; OverflowError/OSError: out-of-range timestamp.
            raise InvalidToken(f"Error parsing date in token ({token_str!r}): {exc!s}") from exc
        return cls(expiry=expiry, value=token_str)

    def is_valid(self) -> bool:
        """
        Tell whether the token is set and has not expired yet.

        A naive ``expiry`` is interpreted as UTC (as produced by :meth:`from_str`);
        timezone-aware values are compared as they are. Per RFC 7519 the token is
        no longer accepted once the current time reaches ``exp``.

        :return: ``True`` if both ``expiry`` and ``value`` are set and now < expiry
        """
        if not self.expiry or not self.value:
            return False
        expiry = self.expiry
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=datetime.timezone.utc)
        return datetime.datetime.now(datetime.timezone.utc) < expiry
class Session:
    """
    Asynchronous HTTP session for the UCS@school Kelvin REST API.

    Holds an :class:`httpx.AsyncClient`, obtains and caches the access token and
    turns HTTP error responses into this package's exceptions. Use it as an
    async context manager, or call :meth:`open` / :meth:`close` explicitly.
    """

    def __init__(
        self,
        username: str,
        password: str,
        host: str,
        max_client_tasks: int = 10,
        request_id: Optional[str] = None,
        request_id_header: str = "X-Request-ID",
        language: Optional[str] = None,
        **kwargs,
    ):
        """
        :param str username: user to request the access token for
        :param str password: password of ``username``
        :param str host: host name of the Kelvin API server, inserted into all URLs
        :param int max_client_tasks: size of the ``_client_task_limiter``
            semaphore; values below 4 are raised to 4 (with a warning)
        :param str request_id: value sent in the ``request_id_header`` header;
            a random UUID hex string if not given
        :param str request_id_header: name of the request ID header
        :param str language: stored as ``self.language``; not read by the methods
            of this class shown here (NOTE(review): presumably used when building
            request headers — confirm)
        :param kwargs: passed to :class:`httpx.AsyncClient`; a ``timeout`` entry is
            also used as the default per-request timeout in :meth:`request`
        """
        if max_client_tasks < 4:
            txt = "Raising value of 'max_client_tasks' to its minimum of 4."
            # stacklevel=2 attributes the warning to the code constructing the Session.
            warnings.warn(txt, BadSettingsWarning, stacklevel=2)
            logger.warning(txt)
            max_client_tasks = 4
        self.max_client_tasks = max_client_tasks
        self._client: Optional[httpx.AsyncClient] = None
        # Not acquired by any method of this class; NOTE(review): presumably used by
        # callers to bound the number of concurrent requests — confirm.
        self._client_task_limiter = asyncio.Semaphore(max_client_tasks)
        self.username = username
        self.password = password
        self.host = host
        self.request_id = request_id or uuid.uuid4().hex
        self.request_id_header = request_id_header
        self.language = language
        self.kwargs = kwargs
        self.urls = {
            "token": URL_TOKEN.format(host=host),
            "class": URL_RESOURCE_CLASS.format(host=host),
            "role": URL_RESOURCE_ROLE.format(host=host),
            "school": URL_RESOURCE_SCHOOL.format(host=host),
            "user": URL_RESOURCE_USER.format(host=host),
            "workgroup": URL_RESOURCE_WORKGROUP.format(host=host),
        }
        self._token: Optional[Token] = None

    async def __aenter__(self):
        self.open()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    def open(self) -> httpx.AsyncClient:
        """
        Create the underlying :class:`httpx.AsyncClient` unless it already exists.

        The request ID header (and an ``Access-Control-Expose-Headers`` header
        naming it) is added to the client's default headers.

        :return: the open client
        """
        if self._client is None:
            # Work on a copy: the caller's ``headers`` mapping must not be mutated,
            # and ``headers=None`` must not break the item assignments below.
            headers = httpx.Headers(self.kwargs.get("headers"))
            headers["Access-Control-Expose-Headers"] = self.request_id_header
            headers[self.request_id_header] = self.request_id
            self.kwargs["headers"] = headers
            self._client = httpx.AsyncClient(**self.kwargs)
        return self._client

    async def close(self) -> None:
        """Close the underlying client, if open. :meth:`open` creates a new one afterwards."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    @property
    def client(self) -> httpx.AsyncClient:
        """
        The open :class:`httpx.AsyncClient`.

        :raises RuntimeError: if the session was not opened or has been closed
        """
        if self._client is None:
            raise RuntimeError("Session is closed.")
        return self._client

    @async_property
    async def token(self) -> str:
        """
        Encoded access token; a new one is fetched from the token endpoint when
        none is cached or the cached one is no longer valid.
        """
        if not self._token or not self._token.is_valid():
            # Explicit headers, so request() does not fall back to ``json_headers``.
            resp_json = await self.post(
                self.urls["token"],
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={"username": self.username, "password": self.password},
            )
            self._token = Token.from_str(resp_json["access_token"])
        return self._token.value

    @staticmethod
    def _redact_for_log(request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of *request_kwargs* with credentials masked, for logging only.

        The ``Authorization`` header (any capitalization) and a ``password`` entry
        in mapping-type ``data`` are replaced by asterisks. The passed-in dict and
        the mappings it contains are left untouched.
        """
        masked = dict(request_kwargs)
        headers = masked.get("headers")
        if isinstance(headers, Mapping):
            masked["headers"] = {
                key: (10 * "*" if str(key).lower() == "authorization" else value)
                for key, value in headers.items()
            }
        data = masked.get("data")
        if isinstance(data, Mapping) and "password" in data:
            masked["data"] = {**data, "password": 10 * "*"}
        return masked

    async def request(
        self, async_request_method: Any, url: str, return_json: bool = True, **kwargs
    ) -> Union[str, int, Dict[str, Any], List[Dict[str, Any]]]:
        """
        Send a request and map the response to a return value or an exception.

        :param async_request_method: a request coroutine method of :attr:`client`
            (e.g. ``self.client.get``)
        :param str url: request URL
        :param bool return_json: on 2xx return the decoded JSON body (``{}`` if the
            body is not JSON) if ``True``, else the body text
        :param kwargs: passed to *async_request_method*; ``headers`` defaults to
            ``json_headers`` and ``timeout`` to the session's ``timeout`` (10s)
        :return: for HEAD the status code (never raises on status); otherwise the
            JSON body or text of a 2xx response
        :raises NoObject: on status 404
        :raises InvalidRequest: on any other 4xx status
        :raises ServerError: on any other non-2xx status
        """
        if "headers" not in kwargs:
            kwargs["headers"] = await self.json_headers
        if "timeout" not in kwargs:
            kwargs["timeout"] = self.kwargs.get("timeout", 10.0)
        response: httpx.Response = await async_request_method(url, **kwargs)

        try:
            resp_json = response.json()
        except ValueError:
            resp_json = {}
        # Only a JSON object can carry a "detail" entry; lists, strings and numbers
        # must not be indexed (or tested with ``in``) as if they were dicts.
        detail = resp_json.get("detail", "") if isinstance(resp_json, dict) else ""
        detail_suffix = f" ({detail})" if detail else ""
        method_name = async_request_method.__name__.upper()

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[%s] %s %r (**%r) -> %r %r%s",
                self.request_id[:10],
                method_name,
                url,
                self._redact_for_log(kwargs),
                response.status_code,
                response.reason_phrase,
                detail_suffix,
            )

        if async_request_method == self.client.head:
            return response.status_code

        status = response.status_code
        if 200 <= status <= 299:
            return resp_json if return_json else response.text
        reason = detail or response.reason_phrase
        if status == 404:
            raise NoObject(
                f"Object not found ({method_name} {url!r}).",
                reason=reason,
                status=status,
                url=url,
            )
        if 400 <= status <= 499:
            raise InvalidRequest(
                f"Kelvin REST API returned status {status}, reason "
                f"{response.reason_phrase!r}{detail_suffix} for "
                f"{method_name} {url!r}.",
                reason=reason,
                status=status,
                url=url,
            )
        raise ServerError(
            reason=response.reason_phrase, status=status, url=url
        )  # pragma: no cover

    async def delete(self, url: str, **kwargs) -> None:
        """Send a DELETE request; raises like :meth:`request` on error statuses."""
        await self.request(self.client.delete, url, return_json=False, **kwargs)

    async def get(self, url: str, **kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Send a GET request and return the decoded JSON body."""
        return await self.request(self.client.get, url, **kwargs)

    async def head(self, url: str, **kwargs) -> int:
        """Send a HEAD request and return its HTTP status code (no exception on error statuses)."""
        return await self.request(self.client.head, url, **kwargs)

    async def post(self, url: str, **kwargs) -> Dict[str, Any]:
        """Send a POST request and return the decoded JSON body."""
        return await self.request(self.client.post, url, **kwargs)

    async def put(self, url: str, **kwargs) -> Dict[str, Any]:
        """Send a PUT request and return the decoded JSON body."""
        return await self.request(self.client.put, url, **kwargs)