from http import HTTPStatus
from ssl import SSLContext
from typing import Tuple, cast
from .._exceptions import ProxyError
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger, url_to_origin
from .base import AsyncByteStream
from .connection import AsyncHTTPConnection
from .connection_pool import AsyncConnectionPool, ResponseByteStream
# Module-level logger; the proxy pool below emits its trace-level request and
# tunnel diagnostics (``logger.trace(...)``) through it.
logger = get_logger(__name__)
def get_reason_phrase(status_code: int) -> str:
    """
    Return the standard reason phrase for an HTTP status code.

    Codes that the ``http.HTTPStatus`` enum does not define yield an empty
    string instead of raising.
    """
    try:
        status = HTTPStatus(status_code)
    except ValueError:
        # Non-standard / unknown status code: no phrase available.
        return ""
    return status.phrase
def merge_headers(
    default_headers: Headers = None, override_headers: Headers = None
) -> Headers:
    """
    Combine two header lists, letting ``override_headers`` win on conflicts.

    Every entry of ``default_headers`` whose name also appears in
    ``override_headers`` (compared case-insensitively) is dropped. The
    surviving defaults come first, followed by all of ``override_headers`` in
    their original order. Duplicate names *within* a single list are kept.

    Either argument may be ``None``, which is treated as an empty list.
    A new list is returned; neither input list is mutated.
    """
    default_headers = [] if default_headers is None else default_headers
    override_headers = [] if override_headers is None else override_headers
    # Header names are case-insensitive, so compare on the lower-cased name.
    override_keys = {key.lower() for key, _ in override_headers}
    kept_defaults = [
        (key, value)
        for key, value in default_headers
        if key.lower() not in override_keys
    ]
    return kept_defaults + override_headers
class AsyncHTTPProxy(AsyncConnectionPool):
    """
    A connection pool for making HTTP requests via an HTTP proxy.

    In "DEFAULT" mode, ``http`` requests are forwarded to the proxy using the
    absolute URL as the request target, and all other requests (e.g.
    ``https``) are tunnelled through the proxy with a ``CONNECT`` request.
    "FORWARD_ONLY" and "TUNNEL_ONLY" force one strategy for every request.

    Parameters
    ----------
    proxy_url:
        The URL of the proxy service as a 4-tuple of (scheme, host, port, path).
    proxy_headers:
        A list of proxy headers to include.
    proxy_mode:
        A proxy mode to operate in. May be "DEFAULT", "FORWARD_ONLY", or "TUNNEL_ONLY".
    ssl_context:
        An SSL context to use for verifying connections.
    max_connections:
        The maximum number of concurrent connections to allow.
    max_keepalive_connections:
        The maximum number of connections to allow before closing keep-alive
        connections.
    keepalive_expiry:
        Passed to the underlying connection pool. When not ``None``, the
        pool's ``_keepalive_sweep()`` is run before every request.
    http2:
        Enable HTTP/2 support.
    backend:
        Name of the async backend, passed to the underlying connection pool.
    max_keepalive:
        Deprecated argument style; passed through to the connection pool.
    """

    def __init__(
        self,
        proxy_url: URL,
        proxy_headers: Headers = None,
        proxy_mode: str = "DEFAULT",
        ssl_context: SSLContext = None,
        max_connections: int = None,
        max_keepalive_connections: int = None,
        keepalive_expiry: float = None,
        http2: bool = False,
        backend: str = "auto",
        # Deprecated argument style:
        max_keepalive: int = None,
    ):
        # NOTE(review): this validation uses ``assert`` and therefore
        # disappears when Python runs with ``-O``.
        assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")

        self.proxy_origin = url_to_origin(proxy_url)
        self.proxy_headers = [] if proxy_headers is None else proxy_headers
        self.proxy_mode = proxy_mode
        super().__init__(
            ssl_context=ssl_context,
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
            http2=http2,
            backend=backend,
            max_keepalive=max_keepalive,
        )

    async def arequest(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: AsyncByteStream = None,
        ext: dict = None,
    ) -> Tuple[int, Headers, AsyncByteStream, dict]:
        """
        Send a request through the proxy, choosing forwarding or tunnelling.

        ``url`` is a 4-tuple whose first element is the scheme. Returns
        ``(status_code, headers, stream, ext)``, where ``stream`` is wrapped
        in a ``ResponseByteStream`` with ``self._response_closed`` as its
        callback.
        """
        if self._keepalive_expiry is not None:
            await self._keepalive_sweep()

        # NOTE(review): proxy_headers are logged verbatim below; if they carry
        # Proxy-Authorization credentials those end up in trace logs.
        if (
            self.proxy_mode == "DEFAULT" and url[0] == b"http"
        ) or self.proxy_mode == "FORWARD_ONLY":
            # By default HTTP requests should be forwarded.
            logger.trace(
                "forward_request proxy_origin=%r proxy_headers=%r method=%r url=%r",
                self.proxy_origin,
                self.proxy_headers,
                method,
                url,
            )
            return await self._forward_request(
                method, url, headers=headers, stream=stream, ext=ext
            )
        else:
            # By default HTTPS should be tunnelled.
            logger.trace(
                "tunnel_request proxy_origin=%r proxy_headers=%r method=%r url=%r",
                self.proxy_origin,
                self.proxy_headers,
                method,
                url,
            )
            return await self._tunnel_request(
                method, url, headers=headers, stream=stream, ext=ext
            )

    async def _forward_request(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: AsyncByteStream = None,
        ext: dict = None,
    ) -> Tuple[int, Headers, AsyncByteStream, dict]:
        """
        Forwarded proxy requests include the entire URL as the HTTP target,
        rather than just the path.

        Connections are pooled by the *proxy* origin, so every forwarded
        request may share a pooled connection to the proxy. The proxy headers
        are merged in with the caller's ``headers`` taking precedence.
        """
        ext = {} if ext is None else ext
        timeout = cast(TimeoutDict, ext.get("timeout", {}))
        origin = self.proxy_origin
        connection = await self._get_connection_from_pool(origin)

        if connection is None:
            connection = AsyncHTTPConnection(
                origin=origin, http2=self._http2, ssl_context=self._ssl_context
            )
            await self._add_to_pool(connection, timeout)

        # Issue a forwarded proxy request...

        # GET https://www.example.org/path HTTP/1.1
        # [proxy headers]
        # [headers]
        scheme, host, port, path = url
        if port is None:
            target = b"%b://%b%b" % (scheme, host, path)
        else:
            target = b"%b://%b:%d%b" % (scheme, host, port, path)

        # The request is sent to the proxy's (scheme, host, port), with the
        # absolute target URL in place of the path.
        url = self.proxy_origin + (target,)
        headers = merge_headers(self.proxy_headers, headers)

        (status_code, headers, stream, ext) = await connection.arequest(
            method, url, headers=headers, stream=stream, ext=ext
        )

        wrapped_stream = ResponseByteStream(
            stream, connection=connection, callback=self._response_closed
        )

        return status_code, headers, wrapped_stream, ext

    async def _tunnel_request(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: AsyncByteStream = None,
        ext: dict = None,
    ) -> Tuple[int, Headers, AsyncByteStream, dict]:
        """
        Tunnelled proxy requests require an initial CONNECT request to
        establish the connection, and then send regular requests.

        Connections are pooled by the *target* origin, so a tunnel already in
        the pool is reused without issuing another CONNECT.

        Raises ``ProxyError`` if the CONNECT exchange fails, returns a non-2xx
        status, or the TLS upgrade fails.
        """
        ext = {} if ext is None else ext
        timeout = cast(TimeoutDict, ext.get("timeout", {}))
        origin = url_to_origin(url)
        connection = await self._get_connection_from_pool(origin)

        if connection is None:
            scheme, host, port = origin

            # First, create a connection to the proxy server
            proxy_connection = AsyncHTTPConnection(
                origin=self.proxy_origin,
                http2=self._http2,
                ssl_context=self._ssl_context,
            )

            # Issue a CONNECT request...

            # CONNECT www.example.org:80 HTTP/1.1
            # [proxy-headers]
            target = b"%b:%d" % (host, port)
            connect_url = self.proxy_origin + (target,)
            connect_headers = [(b"Host", target), (b"Accept", b"*/*")]
            connect_headers = merge_headers(connect_headers, self.proxy_headers)

            try:
                (
                    proxy_status_code,
                    _,
                    proxy_stream,
                    _,
                ) = await proxy_connection.arequest(
                    b"CONNECT", connect_url, headers=connect_headers, ext=ext
                )

                proxy_reason = get_reason_phrase(proxy_status_code)
                logger.trace(
                    "tunnel_response proxy_status_code=%r proxy_reason=%r ",
                    proxy_status_code,
                    proxy_reason,
                )
                # Read the response data without closing the socket
                async for _ in proxy_stream:
                    pass

                # See if the tunnel was successfully established.
                if not 200 <= proxy_status_code <= 299:
                    msg = "%d %s" % (proxy_status_code, proxy_reason)
                    raise ProxyError(msg)

                # Upgrade to TLS if required
                # We assume the target speaks TLS on the specified port
                if scheme == b"https":
                    await proxy_connection.start_tls(host, timeout)
            except ProxyError:
                # Already a ProxyError (e.g. the non-2xx check above): release
                # the proxy connection and propagate it as-is rather than
                # nesting it inside a second ProxyError.
                await proxy_connection.aclose()
                raise
            except Exception as exc:
                await proxy_connection.aclose()
                raise ProxyError(exc) from exc

            # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS.
            # This means the proxy connection is now unusable, and we must create
            # a new one for regular requests, making sure to use the same socket to
            # retain the tunnel.
            connection = AsyncHTTPConnection(
                origin=origin,
                http2=self._http2,
                ssl_context=self._ssl_context,
                socket=proxy_connection.socket,
            )
            try:
                await self._add_to_pool(connection, timeout)
            except Exception:
                # The tunnel socket came from proxy_connection; close it the
                # same way a failed CONNECT does, so it is not left open when
                # the new connection cannot be added to the pool.
                await proxy_connection.aclose()
                raise

        # Once the connection has been established we can send requests on
        # it as normal.
        (status_code, headers, stream, ext) = await connection.arequest(
            method,
            url,
            headers=headers,
            stream=stream,
            ext=ext,
        )

        wrapped_stream = ResponseByteStream(
            stream, connection=connection, callback=self._response_closed
        )

        return status_code, headers, wrapped_stream, ext