diff --git a/.gitignore b/.gitignore index 21286094..def5770d 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,4 @@ docs/_build /.eggs /.vscode /.mypy_cache +.venv diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 4d1a3f7d..f36851b3 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -10,6 +10,9 @@ import enum import functools import getpass +import inspect +import logging +import random import os import pathlib import platform @@ -23,12 +26,10 @@ import typing import urllib.parse import warnings -import inspect -from . import compat -from . import exceptions -from . import protocol +from . import compat, exceptions, protocol +logger = logging.getLogger(__name__) class SSLMode(enum.IntEnum): disable = 0 @@ -882,21 +883,24 @@ async def __connect_addr( return con -async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): +async def _connect(*, loop, timeout, connection_class, record_class, connect_timeout=60, **kwargs): if loop is None: loop = asyncio.get_event_loop() - addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs) + addrs, params, config = _parse_connect_arguments(timeout=connect_timeout, **kwargs) - last_error = None + random.shuffle(addrs) + last_error = ConnectionError(f"Can't connect to all hosts {addrs}") addr = None for addr in addrs: + if timeout <= 0: + raise ConnectionError("Timeout") before = time.monotonic() try: return await _connect_addr( addr=addr, loop=loop, - timeout=timeout, + timeout=connect_timeout, params=params, config=config, connection_class=connection_class, @@ -904,6 +908,7 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): ) except (OSError, asyncio.TimeoutError, ConnectionError) as ex: last_error = ex + logger.warning("Can't connect to %s: %s", addr, ex, exc_info=True) finally: timeout -= time.monotonic() - before diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 73cb6e66..91667a28 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1784,6 +1784,7 @@ async def connect(dsn=None, *, database=None, loop=None, timeout=60, + connect_timeout=60, statement_cache_size=100, max_cached_statement_lifetime=300, max_cacheable_statement_size=1024 * 15, @@ -2104,6 +2105,7 @@ async def connect(dsn=None, *, ssl=ssl, direct_tls=direct_tls, database=database, + connect_timeout=connect_timeout, server_settings=server_settings, command_timeout=command_timeout, statement_cache_size=statement_cache_size,