Skip to content

BasicConnection #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ repos:
hooks:
- id: mypy
files: ^arangoasync/
additional_dependencies: ['types-requests', "types-setuptools"]
additional_dependencies: ["types-requests", "types-setuptools"]
4 changes: 4 additions & 0 deletions arangoasync/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
import logging

from .version import __version__

logger = logging.getLogger(__name__)
74 changes: 74 additions & 0 deletions arangoasync/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
__all__ = [
"Auth",
"JwtToken",
]

from dataclasses import dataclass

import jwt


@dataclass
class Auth:
"""Authentication details for the ArangoDB instance.

Attributes:
username (str): Username.
password (str): Password.
encoding (str): Encoding for the password (default: utf-8)
"""

username: str
password: str
encoding: str = "utf-8"


class JwtToken:
"""JWT token.

Args:
token (str | bytes): JWT token.

Raises:
TypeError: If the token type is not str or bytes.
JWTExpiredError: If the token expired.
"""

def __init__(self, token: str | bytes) -> None:
self._token = token
self._validate()

@property
def token(self) -> str | bytes:
"""Get token."""
return self._token

@token.setter
def token(self, token: str | bytes) -> None:
"""Set token.

Raises:
jwt.ExpiredSignatureError: If the token expired.
"""
self._token = token
self._validate()

def _validate(self) -> None:
"""Validate the token."""
if type(self._token) not in (str, bytes):
raise TypeError("Token must be str or bytes")

jwt_payload = jwt.decode(
self._token,
issuer="arangodb",
algorithms=["HS256"],
options={
"require_exp": True,
"require_iat": True,
"verify_iat": True,
"verify_exp": True,
"verify_signature": False,
},
)

self._token_exp = jwt_payload["exp"]
118 changes: 118 additions & 0 deletions arangoasync/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
__all__ = [
"AcceptEncoding",
"ContentEncoding",
"CompressionManager",
"DefaultCompressionManager",
]

import zlib
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Optional


class AcceptEncoding(Enum):
"""Valid accepted encodings for the Accept-Encoding header."""

DEFLATE = auto()
GZIP = auto()
IDENTITY = auto()


class ContentEncoding(Enum):
"""Valid content encodings for the Content-Encoding header."""

DEFLATE = auto()
GZIP = auto()


class CompressionManager(ABC): # pragma: no cover
"""Abstract base class for handling request/response compression."""

@abstractmethod
def needs_compression(self, data: str | bytes) -> bool:
"""Determine if the data needs to be compressed

Args:
data (str | bytes): Data to check

Returns:
bool: True if the data needs to be compressed
"""
raise NotImplementedError

@abstractmethod
def compress(self, data: str | bytes) -> bytes:
"""Compress the data

Args:
data (str | bytes): Data to compress

Returns:
bytes: Compressed data
"""
raise NotImplementedError

@abstractmethod
def content_encoding(self) -> str:
"""Return the content encoding.

This is the value of the Content-Encoding header in the HTTP request.
Must match the encoding used in the compress method.

Returns:
str: Content encoding
"""
raise NotImplementedError

@abstractmethod
def accept_encoding(self) -> str | None:
"""Return the accept encoding.

This is the value of the Accept-Encoding header in the HTTP request.
Currently, only deflate and "gzip" are supported.

Returns:
str: Accept encoding
"""
raise NotImplementedError


class DefaultCompressionManager(CompressionManager):
"""Compress requests using the deflate algorithm.

Args:
threshold (int): Will compress requests to the server if
the size of the request body (in bytes) is at least the value of this option.
Setting it to -1 will disable request compression (default).
level (int): Compression level. Defaults to 6.
accept (str | None): Accepted encoding. By default, there is
no compression of responses.
"""

def __init__(
self,
threshold: int = -1,
level: int = 6,
accept: Optional[AcceptEncoding] = None,
) -> None:
self._threshold = threshold
self._level = level
self._content_encoding = ContentEncoding.DEFLATE.name.lower()
self._accept_encoding = accept.name.lower() if accept else None

def needs_compression(self, data: str | bytes) -> bool:
return self._threshold != -1 and len(data) >= self._threshold

def compress(self, data: str | bytes) -> bytes:
if data is not None:
if isinstance(data, bytes):
return zlib.compress(data, self._level)
return zlib.compress(data.encode("utf-8"), self._level)
return b""

def content_encoding(self) -> str:
return self._content_encoding

def accept_encoding(self) -> str | None:
return self._accept_encoding
174 changes: 174 additions & 0 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
__all__ = [
"BaseConnection",
"BasicConnection",
]

from abc import ABC, abstractmethod
from typing import Any, List, Optional

from arangoasync.auth import Auth
from arangoasync.compression import CompressionManager, DefaultCompressionManager
from arangoasync.exceptions import (
ClientConnectionError,
ConnectionAbortedError,
ServerConnectionError,
)
from arangoasync.http import HTTPClient
from arangoasync.request import Method, Request
from arangoasync.resolver import HostResolver
from arangoasync.response import Response


class BaseConnection(ABC):
"""Blueprint for connection to a specific ArangoDB database.

Args:
sessions (list): List of client sessions.
host_resolver (HostResolver): Host resolver.
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
"""

def __init__(
self,
sessions: List[Any],
host_resolver: HostResolver,
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
) -> None:
self._sessions = sessions
self._db_endpoint = f"/_db/{db_name}"
self._host_resolver = host_resolver
self._http_client = http_client
self._db_name = db_name
self._compression = compression or DefaultCompressionManager()

@property
def db_name(self) -> str:
"""Return the database name."""
return self._db_name

def prep_response(self, request: Request, resp: Response) -> Response:
"""Prepare response for return.

Args:
request (Request): Request object.
resp (Response): Response object.

Returns:
Response: Response object

Raises:
ServerConnectionError: If the response status code is not successful.
"""
resp.is_success = 200 <= resp.status_code < 300
if not resp.is_success:
raise ServerConnectionError(resp, request)
return resp

async def process_request(self, request: Request) -> Response:
"""Process request, potentially trying multiple hosts.

Args:
request (Request): Request object.

Returns:
Response: Response object.

Raises:
ConnectionAbortedError: If can't connect to host(s) within limit.
"""

ex_host_index = -1
host_index = self._host_resolver.get_host_index()
for tries in range(self._host_resolver.max_tries):
try:
resp = await self._http_client.send_request(
self._sessions[host_index], request
)
return self.prep_response(request, resp)
except ClientConnectionError:
ex_host_index = host_index
host_index = self._host_resolver.get_host_index()
if ex_host_index == host_index:
self._host_resolver.change_host()
host_index = self._host_resolver.get_host_index()

raise ConnectionAbortedError(
f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
)

async def ping(self) -> int:
"""Ping host to check if connection is established.

Returns:
int: Response status code.

Raises:
ServerConnectionError: If the response status code is not successful.
"""
request = Request(method=Method.GET, endpoint="/_api/collection")
resp = await self.send_request(request)
if resp.status_code in {401, 403}:
raise ServerConnectionError(resp, request, "Authentication failed.")
if not resp.is_success:
raise ServerConnectionError(resp, request, "Bad server response.")
return resp.status_code

@abstractmethod
async def send_request(self, request: Request) -> Response: # pragma: no cover
"""Send an HTTP request to the ArangoDB server.

Args:
request (Request): HTTP request.

Returns:
Response: HTTP response.
"""
raise NotImplementedError


class BasicConnection(BaseConnection):
"""Connection to a specific ArangoDB database.

Allows for basic authentication to be used (username and password).

Args:
sessions (list): List of client sessions.
host_resolver (HostResolver): Host resolver.
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
auth (Auth | None): Authentication information.
"""

def __init__(
self,
sessions: List[Any],
host_resolver: HostResolver,
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
auth: Optional[Auth] = None,
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._auth = auth

async def send_request(self, request: Request) -> Response:
"""Send an HTTP request to the ArangoDB server."""
if request.data is not None and self._compression.needs_compression(
request.data
):
request.data = self._compression.compress(request.data)
request.headers["content-encoding"] = self._compression.content_encoding()

accept_encoding: str | None = self._compression.accept_encoding()
if accept_encoding is not None:
request.headers["accept-encoding"] = accept_encoding

if self._auth:
request.auth = self._auth

return await self.process_request(request)
Loading
Loading