Skip to content

Target session attr (2) #987

New issue

Have a question about this project? No Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “No Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? No Sign in to your account

Merged
merged 26 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,93 @@ def tearDown(self):
self.con = None
finally:
super().tearDown()


class HotStandbyTestCase(ClusterTestCase):

@classmethod
def setup_cluster(cls):
cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
cls.start_cluster(
cls.master_cluster,
server_settings={
'max_wal_senders': 10,
'wal_level': 'hot_standby'
}
)

con = None

try:
con = cls.loop.run_until_complete(
cls.master_cluster.connect(
database='postgres', user='postgres', loop=cls.loop))

cls.loop.run_until_complete(
con.execute('''
CREATE ROLE replication WITH LOGIN REPLICATION
'''))

cls.master_cluster.trust_local_replication_by('replication')

conn_spec = cls.master_cluster.get_connection_spec()

cls.standby_cluster = cls.new_cluster(
pg_cluster.HotStandbyCluster,
cluster_kwargs={
'master': conn_spec,
'replication_user': 'replication'
}
)
cls.start_cluster(
cls.standby_cluster,
server_settings={
'hot_standby': True
}
)

finally:
if con is not None:
cls.loop.run_until_complete(con.close())

@classmethod
def get_cluster_connection_spec(cls, cluster, kwargs={}):
conn_spec = cluster.get_connection_spec()
if kwargs.get('dsn'):
conn_spec.pop('host')
conn_spec.update(kwargs)
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
if 'database' not in conn_spec:
conn_spec['database'] = 'postgres'
if 'user' not in conn_spec:
conn_spec['user'] = 'postgres'
return conn_spec

@classmethod
def get_connection_spec(cls, kwargs={}):
primary_spec = cls.get_cluster_connection_spec(
cls.master_cluster, kwargs
)
standby_spec = cls.get_cluster_connection_spec(
cls.standby_cluster, kwargs
)
return {
'host': [primary_spec['host'], standby_spec['host']],
'port': [primary_spec['port'], standby_spec['port']],
'database': primary_spec['database'],
'user': primary_spec['user'],
**kwargs
}

@classmethod
def connect_primary(cls, **kwargs):
conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
return pg_connection.connect(**conn_spec, loop=cls.loop)

@classmethod
def connect_standby(cls, **kwargs):
conn_spec = cls.get_cluster_connection_spec(
cls.standby_cluster,
kwargs
)
return pg_connection.connect(**conn_spec, loop=cls.loop)
2 changes: 1 addition & 1 deletion asyncpg/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def init(self, **settings):
'pg_basebackup init exited with status {:d}:\n{}'.format(
process.returncode, output.decode()))

if self._pg_version <= (11, 0):
if self._pg_version < (12, 0):
with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
f.write(textwrap.dedent("""\
standby_mode = 'on'
Expand Down
122 changes: 114 additions & 8 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
import pathlib
import platform
import random
import re
import socket
import ssl as ssl_module
Expand Down Expand Up @@ -56,6 +57,7 @@ def parse(cls, sslmode):
'direct_tls',
'connect_timeout',
'server_settings',
'target_session_attrs',
])


Expand Down Expand Up @@ -260,7 +262,8 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:

def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, connect_timeout, server_settings):
direct_tls, connect_timeout, server_settings,
target_session_attrs):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -607,10 +610,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
'server_settings is expected to be None or '
'a Dict[str, str]')

if target_session_attrs is None:

target_session_attrs = os.getenv(
"PGTARGETSESSIONATTRS", SessionAttribute.any
)
try:

target_session_attrs = SessionAttribute(target_session_attrs)
except ValueError as exc:
raise exceptions.InterfaceError(
"target_session_attrs is expected to be one of "
"{!r}"
", got {!r}".format(
SessionAttribute.__members__.values, target_session_attrs
)
) from exc

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
connect_timeout=connect_timeout, server_settings=server_settings)
connect_timeout=connect_timeout, server_settings=server_settings,
target_session_attrs=target_session_attrs)

return addrs, params

Expand All @@ -620,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings):

ssl, direct_tls, server_settings,
target_session_attrs):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -649,7 +670,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
connect_timeout=timeout, server_settings=server_settings)
connect_timeout=timeout, server_settings=server_settings,
target_session_attrs=target_session_attrs)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down Expand Up @@ -882,18 +904,84 @@ async def __connect_addr(
return con


class SessionAttribute(str, enum.Enum):
any = 'any'
primary = 'primary'
standby = 'standby'
prefer_standby = 'prefer-standby'
read_write = "read-write"
read_only = "read-only"


def _accept_in_hot_standby(should_be_in_hot_standby: bool):
"""
If the server didn't report "in_hot_standby" at startup, we must determine
the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
If the server allows a connection and states it is in recovery it must
be a replica/standby server.
"""
async def can_be_used(connection):
settings = connection.get_settings()
hot_standby_status = getattr(settings, 'in_hot_standby', None)
if hot_standby_status is not None:
is_in_hot_standby = hot_standby_status == 'on'
else:
is_in_hot_standby = await connection.fetchval(
"SELECT pg_catalog.pg_is_in_recovery()"
)
return is_in_hot_standby == should_be_in_hot_standby

return can_be_used


def _accept_read_only(should_be_read_only: bool):
"""
Verify the server has not set default_transaction_read_only=True
"""
async def can_be_used(connection):
settings = connection.get_settings()
is_readonly = getattr(settings, 'default_transaction_read_only', 'off')

if is_readonly == "on":
return should_be_read_only

return await _accept_in_hot_standby(should_be_read_only)(connection)
return can_be_used


async def _accept_any(_):
return True


target_attrs_check = {
SessionAttribute.any: _accept_any,
SessionAttribute.primary: _accept_in_hot_standby(False),
SessionAttribute.standby: _accept_in_hot_standby(True),
SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
SessionAttribute.read_write: _accept_read_only(False),
SessionAttribute.read_only: _accept_read_only(True),
}


async def _can_use_connection(connection, attr: SessionAttribute):
can_use = target_attrs_check[attr]
return await can_use(connection)


async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()

addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
target_attr = params.target_session_attrs

candidates = []
chosen_connection = None
last_error = None
addr = None
for addr in addrs:
before = time.monotonic()
try:
return await _connect_addr(
conn = await _connect_addr(
addr=addr,
loop=loop,
timeout=timeout,
Expand All @@ -902,12 +990,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
connection_class=connection_class,
record_class=record_class,
)
candidates.append(conn)
if await _can_use_connection(conn, target_attr):
chosen_connection = conn
break
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
last_error = ex
finally:
timeout -= time.monotonic() - before
else:
if target_attr == SessionAttribute.prefer_standby and candidates:
chosen_connection = random.choice(candidates)

await asyncio.gather(
(c.close() for c in candidates if c is not chosen_connection),
return_exceptions=True
)

if chosen_connection:
return chosen_connection

raise last_error
raise last_error or exceptions.TargetServerAttributeNotMatched(
'None of the hosts match the target attribute requirement '
'{!r}'.format(target_attr)
)


async def _cancel(*, loop, addr, params: _ConnectionParameters,
Expand Down
20 changes: 19 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,8 @@ async def connect(dsn=None, *,
direct_tls=False,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None):
server_settings=None,
target_session_attrs=None):
r"""A coroutine to establish a connection to a PostgreSQL server.

The connection parameters may be specified either as a connection
Expand Down Expand Up @@ -2003,6 +2004,22 @@ async def connect(dsn=None, *,
this connection object. Must be a subclass of
:class:`~asyncpg.Record`.

:param SessionAttribute target_session_attrs:
If specified, check that the host has the correct attribute.
Can be one of:
"any": the first successfully connected host
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also read-write and read-only.

"primary": the host must NOT be in hot standby mode
"standby": the host must be in hot standby mode
"read-write": the host must allow writes
"read-only": the host most NOT allow writes
"prefer-standby": first try to find a standby host, but if
none of the listed hosts is a standby server,
return any of them.

If not specified will try to use PGTARGETSESSIONATTRS
from the environment.
Defaults to "any" if no value is set.

:return: A :class:`~asyncpg.connection.Connection` instance.

Example:
Expand Down Expand Up @@ -2109,6 +2126,7 @@ async def connect(dsn=None, *,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
)


Expand Down
6 changes: 5 additions & 1 deletion asyncpg/exceptions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError')
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -244,6 +244,10 @@ class ProtocolError(InternalClientError):
"""Unexpected condition in the handling of PostgreSQL protocol input."""


class TargetServerAttributeNotMatched(InternalClientError):
"""Could not find a host that satisfies the target attribute requirement"""


class OutdatedSchemaCacheError(InternalClientError):
"""A value decoding error caused by a schema change before row fetching."""

Expand Down
Loading