Source code for httpcore._async.connection_pool

import warnings
from ssl import SSLContext
from typing import (
    AsyncIterator,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore
from .._backends.base import lookup_async_backend
from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol
from .._threadlock import ThreadLock
from .._types import URL, Headers, Origin, TimeoutDict
from .._utils import get_logger, origin_to_url_string, url_to_origin
from .base import (
    AsyncByteStream,
    AsyncHTTPTransport,
    ConnectionState,
    NewConnectionRequired,
)
from .connection import AsyncHTTPConnection

logger = get_logger(__name__)


class NullSemaphore(AsyncSemaphore):
    """
    A semaphore whose operations do nothing.

    `AsyncConnectionPool` uses this in place of a real semaphore when no
    `max_connections` limit has been configured.
    """

    def __init__(self) -> None:
        """Set up nothing; the base class initialiser is not invoked."""

    async def acquire(self, timeout: float = None) -> None:
        """Return immediately. The `timeout` argument is ignored."""
        return None

    async def release(self) -> None:
        """Return immediately; there is nothing to hand back."""
        return None


class ResponseByteStream(AsyncByteStream):
    """
    The response stream wrapper that `.arequest()` hands back to its caller.

    Iterating yields the chunks of the wrapped stream unchanged. Closing it
    closes the wrapped stream and then, whether or not that succeeded,
    notifies the connection pool by awaiting `callback(connection)`.
    """

    def __init__(
        self,
        stream: AsyncByteStream,
        connection: AsyncHTTPConnection,
        callback: Callable,
    ) -> None:
        self.callback = callback
        self.connection = connection
        self.stream = stream

    async def __aiter__(self) -> AsyncIterator[bytes]:
        # Pure pass-through of the underlying stream's chunks.
        async for data in self.stream:
            yield data

    async def aclose(self) -> None:
        try:
            # First let the underlying stream run its own close callback,
            # i.e. `AsyncHTTP11Connection._response_closed()` or
            # `AsyncHTTP2Stream._response_closed()`.
            await self.stream.aclose()
        finally:
            # Then always tell the pool, even if the close above raised.
            # This ends up in `AsyncConnectionPool._response_closed()`.
            await self.callback(self.connection)


class AsyncConnectionPool(AsyncHTTPTransport):
    """
    A connection pool for making HTTP requests.

    Parameters
    ----------
    ssl_context:
        An SSL context to use for verifying connections.
    max_connections:
        The maximum number of concurrent connections to allow.
    max_keepalive_connections:
        The maximum number of connections to allow before closing keep-alive
        connections.
    keepalive_expiry:
        The maximum time to allow before closing a keep-alive connection.
    http2:
        Enable HTTP/2 support.
    uds:
        Path to a Unix Domain Socket to use instead of TCP sockets.
    local_address:
        Local address to connect from. Can also be used to connect using a
        particular address family. Using ``local_address="0.0.0.0"`` will
        connect using an ``AF_INET`` address (IPv4), while using
        ``local_address="::"`` will connect using an ``AF_INET6`` address
        (IPv6).
    retries:
        The maximum number of retries when trying to establish a connection.
    max_keepalive:
        Deprecated alias for ``max_keepalive_connections``. When given, it
        emits a ``DeprecationWarning`` and overrides
        ``max_keepalive_connections``.
    backend:
        A name indicating which concurrency backend to use.
    """

    def __init__(
        self,
        ssl_context: SSLContext = None,
        max_connections: int = None,
        max_keepalive_connections: int = None,
        keepalive_expiry: float = None,
        http2: bool = False,
        uds: str = None,
        local_address: str = None,
        retries: int = 0,
        max_keepalive: int = None,
        backend: Union[AsyncBackend, str] = "auto",
    ):
        if max_keepalive is not None:
            warnings.warn(
                "'max_keepalive' is deprecated. Use 'max_keepalive_connections'.",
                DeprecationWarning,
            )
            max_keepalive_connections = max_keepalive

        if isinstance(backend, str):
            backend = lookup_async_backend(backend)

        self._ssl_context = SSLContext() if ssl_context is None else ssl_context
        self._max_connections = max_connections
        self._max_keepalive_connections = max_keepalive_connections
        self._keepalive_expiry = keepalive_expiry
        self._http2 = http2
        self._uds = uds
        self._local_address = local_address
        self._retries = retries
        # Pooled connections, grouped by the origin they are connected to.
        self._connections: Dict[Origin, Set[AsyncHTTPConnection]] = {}
        self._thread_lock = ThreadLock()
        self._backend = backend
        # Backend time before which `_keepalive_sweep()` is a no-op.
        self._next_keepalive_check = 0.0

        if http2:
            # Fail at construction time rather than on the first request.
            try:
                import h2  # noqa: F401
            except ImportError:
                raise ImportError(
                    "Attempted to use http2=True, but the 'h2' "
                    "package is not installed. Use 'pip install httpcore[http2]'."
                )

    @property
    def _connection_semaphore(self) -> AsyncSemaphore:
        """
        The semaphore limiting the number of pooled connections.

        A `NullSemaphore` is used when `max_connections` is `None`.
        """
        # We do this lazily, to make sure backend autodetection always
        # runs within an async context.
        if not hasattr(self, "_internal_semaphore"):
            if self._max_connections is not None:
                self._internal_semaphore = self._backend.create_semaphore(
                    self._max_connections, exc_class=PoolTimeout
                )
            else:
                self._internal_semaphore = NullSemaphore()

        return self._internal_semaphore

    @property
    def _connection_acquiry_lock(self) -> AsyncLock:
        """
        The lock held while getting or creating a connection in `arequest()`.

        Created lazily on first access.
        """
        if not hasattr(self, "_internal_connection_acquiry_lock"):
            self._internal_connection_acquiry_lock = self._backend.create_lock()
        return self._internal_connection_acquiry_lock

    def _create_connection(
        self,
        origin: Tuple[bytes, bytes, int],
    ) -> AsyncHTTPConnection:
        """
        Build a new, not yet pooled, connection for `origin` using the pool's
        configuration.
        """
        return AsyncHTTPConnection(
            origin=origin,
            http2=self._http2,
            uds=self._uds,
            ssl_context=self._ssl_context,
            local_address=self._local_address,
            retries=self._retries,
            backend=self._backend,
        )

    async def arequest(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: AsyncByteStream = None,
        ext: dict = None,
    ) -> Tuple[int, Headers, AsyncByteStream, dict]:
        """
        Send a request over a pooled connection, creating one if required.

        The optional ``ext["timeout"]`` dict is consulted for a ``"pool"``
        entry, which is used as the timeout when waiting for a free
        connection slot.

        Returns the status code, response headers, response stream and
        extensions dict. The returned stream is a `ResponseByteStream`, so
        that closing it notifies the pool.

        Raises `UnsupportedProtocol` if the URL scheme is not ``http`` or
        ``https``, and `LocalProtocolError` if the URL has no hostname. Any
        exception raised while sending the request is re-raised after the
        connection has been removed from the pool.
        """
        if url[0] not in (b"http", b"https"):
            scheme = url[0].decode("latin-1")
            raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}")
        if not url[1]:
            raise LocalProtocolError("Missing hostname in URL.")

        origin = url_to_origin(url)
        ext = {} if ext is None else ext
        timeout = cast(TimeoutDict, ext.get("timeout", {}))

        await self._keepalive_sweep()

        connection: Optional[AsyncHTTPConnection] = None
        while connection is None:
            async with self._connection_acquiry_lock:
                # We get-or-create a connection as an atomic operation, to ensure
                # that HTTP/2 requests issued in close concurrency will end up
                # on the same connection.
                logger.trace("get_connection_from_pool=%r", origin)
                connection = await self._get_connection_from_pool(origin)

                if connection is None:
                    connection = self._create_connection(origin=origin)
                    logger.trace("created connection=%r", connection)
                    await self._add_to_pool(connection, timeout=timeout)
                else:
                    logger.trace("reuse connection=%r", connection)

            try:
                response = await connection.arequest(
                    method, url, headers=headers, stream=stream, ext=ext
                )
            except NewConnectionRequired:
                # Go around the loop again and acquire another connection.
                connection = None
            except BaseException:  # noqa: PIE786
                # Catch `BaseException` rather than `Exception`: task
                # cancellation (`asyncio.CancelledError` derives from
                # `BaseException` on Python 3.8+) must also remove the
                # connection, otherwise its pool slot is never released.
                logger.trace("remove from pool connection=%r", connection)
                await self._remove_from_pool(connection)
                raise

        status_code, headers, stream, ext = response
        wrapped_stream = ResponseByteStream(
            stream, connection=connection, callback=self._response_closed
        )
        return status_code, headers, wrapped_stream, ext

    async def _get_connection_from_pool(
        self, origin: Origin
    ) -> Optional[AsyncHTTPConnection]:
        """
        Return a pooled connection to `origin` that can carry another
        request, or `None` if there is none.

        IDLE connections whose socket has become readable are treated as
        dropped by the peer: they are removed from the pool and closed.
        """
        seen_http11 = False
        pending_connection = None
        reuse_connection = None
        connections_to_close = set()

        for connection in self._connections_for_origin(origin):
            if connection.is_http11:
                seen_http11 = True

            if connection.state == ConnectionState.IDLE:
                if connection.is_socket_readable():
                    # If the socket is readable while the connection is idle
                    # (meaning we don't expect the server to send any data),
                    # then the only valid reason is that the other end has
                    # disconnected, which means we should drop the connection
                    # too.
                    # (For a detailed run-through of what a "readable" socket
                    # is, and why this is the best thing for us to do here,
                    # see:
                    # https://github.com/encode/httpx/pull/143#issuecomment-515181778)
                    logger.trace("removing dropped idle connection=%r", connection)
                    # IDLE connections that have been dropped should be
                    # removed from the pool.
                    connections_to_close.add(connection)
                    await self._remove_from_pool(connection)
                else:
                    # IDLE connections that are still maintained may
                    # be reused.
                    logger.trace("reusing idle http11 connection=%r", connection)
                    reuse_connection = connection
            elif connection.state == ConnectionState.ACTIVE and connection.is_http2:
                # HTTP/2 connections may be reused.
                logger.trace("reusing active http2 connection=%r", connection)
                reuse_connection = connection
            elif connection.state == ConnectionState.PENDING:
                # Pending connections may potentially be reused.
                pending_connection = connection

        if reuse_connection is not None:
            # Mark the connection as READY before we return it, to indicate
            # that if it is HTTP/1.1 then it should not be re-acquired.
            reuse_connection.mark_as_ready()
            reuse_connection.expires_at = None
        elif self._http2 and pending_connection is not None and not seen_http11:
            # If we have a PENDING connection, and no HTTP/1.1 connections
            # on this origin, then we can attempt to share the connection.
            # Log the connection actually being shared, not whatever the
            # loop variable above happened to end on.
            logger.trace("reusing pending connection=%r", pending_connection)
            reuse_connection = pending_connection

        # Close any dropped connections.
        for connection in connections_to_close:
            await connection.aclose()

        return reuse_connection

    async def _response_closed(self, connection: AsyncHTTPConnection) -> None:
        """
        Callback run when a response stream obtained from this pool is closed.

        CLOSED connections are removed from the pool. IDLE connections are
        removed and closed if the pool holds more than
        `max_keepalive_connections` connections; otherwise, when a
        `keepalive_expiry` is configured, their expiry time is set.
        """
        remove_from_pool = False
        close_connection = False

        if connection.state == ConnectionState.CLOSED:
            remove_from_pool = True
        elif connection.state == ConnectionState.IDLE:
            num_connections = len(self._get_all_connections())
            if (
                self._max_keepalive_connections is not None
                and num_connections > self._max_keepalive_connections
            ):
                remove_from_pool = True
                close_connection = True
            elif self._keepalive_expiry is not None:
                now = await self._backend.time()
                connection.expires_at = now + self._keepalive_expiry

        if remove_from_pool:
            await self._remove_from_pool(connection)

        if close_connection:
            await connection.aclose()

    async def _keepalive_sweep(self) -> None:
        """
        Remove any IDLE connections that have expired past their keep-alive time.

        Does nothing when no `keepalive_expiry` is configured, or when the
        time of the next scheduled check has not been reached yet.
        """
        if self._keepalive_expiry is None:
            return

        now = await self._backend.time()
        if now < self._next_keepalive_check:
            return

        # Schedule the next sweep at most one second from now.
        self._next_keepalive_check = now + min(1.0, self._keepalive_expiry)
        connections_to_close = set()

        for connection in self._get_all_connections():
            if (
                connection.state == ConnectionState.IDLE
                and connection.expires_at is not None
                and now >= connection.expires_at
            ):
                connections_to_close.add(connection)
                await self._remove_from_pool(connection)

        # Close only after every expired connection has left the pool.
        for connection in connections_to_close:
            await connection.aclose()

    async def _add_to_pool(
        self, connection: AsyncHTTPConnection, timeout: TimeoutDict
    ) -> None:
        """
        Acquire a connection slot, waiting up to ``timeout["pool"]`` if that
        key is present, then register `connection` under its origin.
        """
        logger.trace("adding connection to pool=%r", connection)
        await self._connection_semaphore.acquire(timeout=timeout.get("pool", None))
        async with self._thread_lock:
            self._connections.setdefault(connection.origin, set())
            self._connections[connection.origin].add(connection)

    async def _remove_from_pool(self, connection: AsyncHTTPConnection) -> None:
        """
        Unregister `connection` and release its connection slot.

        Does nothing if the connection is not currently in the pool, so the
        slot is released at most once per pooled connection.
        """
        logger.trace("removing connection from pool=%r", connection)
        async with self._thread_lock:
            if connection in self._connections.get(connection.origin, set()):
                await self._connection_semaphore.release()
                self._connections[connection.origin].remove(connection)
                # Drop the origin entry once its last connection is gone.
                if not self._connections[connection.origin]:
                    del self._connections[connection.origin]

    def _connections_for_origin(self, origin: Origin) -> Set[AsyncHTTPConnection]:
        """
        Return a copy of the set of pooled connections for `origin`.

        Being a copy, it is safe to iterate while the pool is modified.
        """
        return set(self._connections.get(origin, set()))

    def _get_all_connections(self) -> Set[AsyncHTTPConnection]:
        """Return a new set holding every pooled connection, for all origins."""
        connections: Set[AsyncHTTPConnection] = set()
        for connection_set in self._connections.values():
            connections |= connection_set
        return connections

    async def aclose(self) -> None:
        """Remove every connection from the pool, then close each of them."""
        connections = self._get_all_connections()
        for connection in connections:
            await self._remove_from_pool(connection)

        # Close all connections
        for connection in connections:
            await connection.aclose()

    async def get_connection_info(self) -> Dict[str, List[str]]:
        """
        Returns a dict of origin URLs to a list of summary strings for each
        connection.
        """
        await self._keepalive_sweep()

        stats = {}
        for origin, connections in self._connections.items():
            stats[origin_to_url_string(origin)] = sorted(
                [connection.info() for connection in connections]
            )
        return stats