From 8ea7d84e814a73e81cba4b4233adfcd8921d9ca8 Mon Sep 17 00:00:00 2001 From: Adam Seering Date: Wed, 11 Feb 2026 20:18:10 +0000 Subject: [PATCH 1/4] feat(spanner): add ClientContext support to options This change adds support for ClientContext in Options and ensures it is propagated to ExecuteSql, Read, Commit, and BeginTransaction requests. It aligns with go/spanner-client-scoped-session-state design. ClientContext allows passing opaque, RPC-scoped side-channel information (like application-level user context) to Spanner. This implementation supports setting ClientContext at the Client, Database, and Request levels, with request-level options taking precedence. Key changes: - Added ClientContext to types/spanner.py and exposed it. - Updated Client.__init__ to accept a default client_context. - Added helpers for merging ClientContext with correct precedence. - Updated Snapshot, Transaction, Batch, and Database wrappers to propagate the context. - Added comprehensive unit tests in tests/unit/test_client_context.py. --- google/cloud/spanner_v1/__init__.py | 2 + google/cloud/spanner_v1/_helpers.py | 70 +++++ google/cloud/spanner_v1/batch.py | 46 ++- google/cloud/spanner_v1/client.py | 13 + google/cloud/spanner_v1/database.py | 55 +++- google/cloud/spanner_v1/snapshot.py | 57 +++- google/cloud/spanner_v1/transaction.py | 35 ++- google/cloud/spanner_v1/types/__init__.py | 2 + google/cloud/spanner_v1/types/spanner.py | 26 ++ tests/unit/test_client_context.py | 327 ++++++++++++++++++++++ 10 files changed, 605 insertions(+), 28 deletions(-) create mode 100644 tests/unit/test_client_context.py diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index 4f77269bb2..cd5b8ae371 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -38,6 +38,7 @@ from .types.spanner import BatchWriteRequest from .types.spanner import BatchWriteResponse from .types.spanner import BeginTransactionRequest +from .types.spanner import ClientContext from .types.spanner import CommitRequest from .types.spanner import CreateSessionRequest from .types.spanner import DeleteSessionRequest @@ -110,6 +111,7 @@ "BatchWriteRequest", "BatchWriteResponse", "BeginTransactionRequest", + "ClientContext", "CommitRequest", "CommitResponse", "CreateSessionRequest", diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index a52c24e769..f45108e4fa 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -34,6 +34,8 @@ from google.cloud._helpers import _date_from_iso8601_date from google.cloud.spanner_v1.types import ExecuteSqlRequest from google.cloud.spanner_v1.types import TransactionOptions +from google.cloud.spanner_v1.types import ClientContext +from google.cloud.spanner_v1.types import RequestOptions from google.cloud.spanner_v1.data_types import JsonObject, Interval from google.cloud.spanner_v1.request_id_header import ( with_request_id, @@ -191,6 +193,74 @@ def _merge_query_options(base, merge): return combined +def _merge_client_context(base, merge): + """Merge higher precedence ClientContext with current ClientContext. + + :type base: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` or None + :param base: The current ClientContext that is intended for use. + + :type merge: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` or None + :param merge: + The ClientContext that has a higher priority than base. These options + should overwrite the fields in base. + + :rtype: :class:`~google.cloud.spanner_v1.types.ClientContext` + or None + :returns: + ClientContext object formed by merging the two given ClientContexts. + """ + if base is None and merge is None: + return None + + combined = base or ClientContext() + if type(combined) is dict: + combined = ClientContext(combined) + + merge = merge or ClientContext() + if type(merge) is dict: + merge = ClientContext(merge) + + type(combined).pb(combined).MergeFrom(type(merge).pb(merge)) + if not combined.secure_context: + return None + return combined + + +def _merge_request_options(request_options, client_context): + """Merge RequestOptions and ClientContext. + + :type request_options: :class:`~google.cloud.spanner_v1.types.RequestOptions` + or :class:`dict` or None + :param request_options: The current RequestOptions that is intended for use. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` or None + :param client_context: + The ClientContext to merge into request_options. + + :rtype: :class:`~google.cloud.spanner_v1.types.RequestOptions` + or None + :returns: + RequestOptions object formed by merging the given ClientContext. + """ + if request_options is None and client_context is None: + return None + + if request_options is None: + request_options = RequestOptions() + elif type(request_options) is dict: + request_options = RequestOptions(request_options) + + if client_context: + request_options.client_context = _merge_client_context( + client_context, request_options.client_context + ) + + return request_options + + def _assert_numeric_precision_and_scale(value): """ Asserts that input numeric field is within Spanner supported range. diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index e70d214783..635c842613 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -27,6 +27,8 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _merge_Transaction_Options, + _merge_client_context, + _merge_request_options, AtomicCounter, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call @@ -36,6 +38,7 @@ from google.cloud.spanner_v1._helpers import _check_rst_stream_error from google.api_core.exceptions import InternalServerError from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture +from google.cloud.spanner_v1.types import ClientContext import time DEFAULT_RETRY_TIMEOUT_SECS = 30 @@ -46,9 +49,14 @@ class _BatchBase(_SessionWrapper): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform the commit + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch. """ - def __init__(self, session): + def __init__(self, session, client_context=None): super(_BatchBase, self).__init__(session) self._mutations: List[Mutation] = [] @@ -58,6 +66,13 @@ def __init__(self, session): """Timestamp at which the batch was successfully committed.""" self.commit_stats: Optional[CommitResponse.CommitStats] = None + if client_context is not None: + if type(client_context) is dict: + client_context = ClientContext(client_context) + elif not isinstance(client_context, ClientContext): + raise TypeError("client_context must be a ClientContext or a dict") + self._client_context = client_context + def insert(self, table, columns, values): """Insert one or more new table rows. @@ -226,10 +241,14 @@ def commit( txn_options, ) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag # Request tags are not supported for commit requests. @@ -316,13 +335,25 @@ class MutationGroups(_SessionWrapper): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform the commit + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this mutation group. """ - def __init__(self, session): + def __init__(self, session, client_context=None): super(MutationGroups, self).__init__(session) self._mutation_groups: List[MutationGroup] = [] self.committed: bool = False + if client_context is not None: + if type(client_context) is dict: + client_context = ClientContext(client_context) + elif not isinstance(client_context, ClientContext): + raise TypeError("client_context must be a ClientContext or a dict") + self._client_context = client_context + def group(self): """Returns a new `MutationGroup` to which mutations can be added.""" mutation_group = BatchWriteRequest.MutationGroup() @@ -364,10 +395,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) with trace_call( name="CloudSpanner.batch_write", diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index 5f72905616..49f89c807d 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -48,6 +48,7 @@ from google.cloud.spanner_v1 import __version__ from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import DefaultTransactionOptions +from google.cloud.spanner_v1.types import ClientContext from google.cloud.spanner_v1._helpers import _merge_query_options from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.instance import Instance @@ -184,6 +185,10 @@ class Client(ClientWithProject): :param disable_builtin_metrics: (Optional) Default False. Set to True to disable the Spanner built-in metrics collection and exporting. + :type client_context: :class:`~google.cloud.spanner_v1.types.RequestOptions.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made by this client. + :raises: :class:`ValueError ` if both ``read_only`` and ``admin`` are :data:`True` """ @@ -210,6 +215,7 @@ def __init__( default_transaction_options: Optional[DefaultTransactionOptions] = None, experimental_host=None, disable_builtin_metrics=False, + client_context=None, ): self._emulator_host = _get_spanner_emulator_host() self._experimental_host = experimental_host @@ -247,6 +253,13 @@ def __init__( # Environment flag config has higher precedence than application config. self._query_options = _merge_query_options(query_options, env_query_options) + if client_context is not None: + if type(client_context) is dict: + client_context = ClientContext(client_context) + elif not isinstance(client_context, ClientContext): + raise TypeError("client_context must be a ClientContext or a dict") + self._client_context = client_context + if self._emulator_host is not None and ( "http://" in self._emulator_host or "https://" in self._emulator_host ): diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 4977a4abb9..156f99acb1 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -946,6 +946,7 @@ def snapshot(self, **kw): :param kw: Passed through to :class:`~google.cloud.spanner_v1.snapshot.Snapshot` constructor. + Now includes ``client_context``. :rtype: :class:`~google.cloud.spanner_v1.database.SnapshotCheckout` :returns: new wrapper @@ -959,6 +960,7 @@ def batch( exclude_txn_from_change_streams=False, isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, read_lock_mode=TransactionOptions.ReadWrite.ReadLockMode.READ_LOCK_MODE_UNSPECIFIED, + client_context=None, **kw, ): """Return an object which wraps a batch. @@ -996,6 +998,11 @@ def batch( :param read_lock_mode: (Optional) Sets the read lock mode for this transaction. This overrides any default read lock mode set for the client. + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch. + :rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout` :returns: new wrapper """ @@ -1007,19 +1014,25 @@ def batch( exclude_txn_from_change_streams, isolation_level, read_lock_mode, + client_context=client_context, **kw, ) - def mutation_groups(self): + def mutation_groups(self, client_context=None): """Return an object which wraps a mutation_group. The wrapper *must* be used as a context manager, with the mutation group as the value returned by the wrapper. + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this mutation group. + :rtype: :class:`~google.cloud.spanner_v1.database.MutationGroupsCheckout` :returns: new wrapper """ - return MutationGroupsCheckout(self) + return MutationGroupsCheckout(self, client_context=client_context) def batch_snapshot( self, @@ -1027,6 +1040,7 @@ def batch_snapshot( exact_staleness=None, session_id=None, transaction_id=None, + client_context=None, ): """Return an object which wraps a batch read / query. @@ -1043,6 +1057,11 @@ def batch_snapshot( :type transaction_id: str :param transaction_id: id of the transaction + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch snapshot. + :rtype: :class:`~google.cloud.spanner_v1.database.BatchSnapshot` :returns: new wrapper """ @@ -1052,6 +1071,7 @@ def batch_snapshot( exact_staleness=exact_staleness, session_id=session_id, transaction_id=transaction_id, + client_context=client_context, ) def run_in_transaction(self, func, *args, **kw): @@ -1080,6 +1100,8 @@ def run_in_transaction(self, func, *args, **kw): the DDL option `allow_txn_exclusion` being false or unset. "isolation_level" sets the isolation level for the transaction. "read_lock_mode" sets the read lock mode for the transaction. + "client_context" (Optional) Client context to use for all requests made + by this transaction. :rtype: Any :returns: The return value of ``func``. @@ -1391,6 +1413,11 @@ class BatchCheckout(object): :param max_commit_delay: (Optional) The amount of latency this request is willing to incur in order to improve throughput. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch. """ def __init__( @@ -1401,6 +1428,7 @@ def __init__( exclude_txn_from_change_streams=False, isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, read_lock_mode=TransactionOptions.ReadWrite.ReadLockMode.READ_LOCK_MODE_UNSPECIFIED, + client_context=None, **kw, ): self._database: Database = database @@ -1417,6 +1445,7 @@ def __init__( self._exclude_txn_from_change_streams = exclude_txn_from_change_streams self._isolation_level = isolation_level self._read_lock_mode = read_lock_mode + self._client_context = client_context self._kw = kw def __enter__(self): @@ -1433,7 +1462,9 @@ def __enter__(self): event_attributes={"id": self._session.session_id}, ) - batch = self._batch = Batch(session=self._session) + batch = self._batch = Batch( + session=self._session, client_context=self._client_context + ) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag @@ -1478,18 +1509,24 @@ class MutationGroupsCheckout(object): :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: database to use + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this mutation group. """ - def __init__(self, database): + def __init__(self, database, client_context=None): self._database: Database = database self._session: Optional[Session] = None + self._client_context = client_context def __enter__(self): """Begin ``with`` block.""" transaction_type = TransactionType.READ_WRITE self._session = self._database.sessions_manager.get_session(transaction_type) - return MutationGroups(session=self._session) + return MutationGroups(session=self._session, client_context=self._client_context) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" @@ -1555,6 +1592,11 @@ class BatchSnapshot(object): :type exact_staleness: :class:`datetime.timedelta` :param exact_staleness: Execute all reads at a timestamp that is ``exact_staleness`` old. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch snapshot. """ def __init__( @@ -1564,6 +1606,7 @@ def __init__( exact_staleness=None, session_id=None, transaction_id=None, + client_context=None, ): self._database: Database = database @@ -1575,6 +1618,7 @@ def __init__( self._read_timestamp = read_timestamp self._exact_staleness = exact_staleness + self._client_context = client_context @classmethod def from_dict(cls, database, mapping): @@ -1663,6 +1707,7 @@ def _get_snapshot(self): exact_staleness=self._exact_staleness, multi_use=True, transaction_id=self._transaction_id, + client_context=self._client_context, ) if self._transaction_id is None: diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index a7abcdaaa3..85481324a9 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -41,6 +41,8 @@ from google.cloud.spanner_v1._helpers import ( _make_value_pb, _merge_query_options, + _merge_client_context, + _merge_request_options, _metadata_with_prefix, _metadata_with_leader_aware_routing, _retry, @@ -52,6 +54,7 @@ from google.cloud.spanner_v1._opentelemetry_tracing import trace_call, add_span_event from google.cloud.spanner_v1.streamed import StreamedResultSet from google.cloud.spanner_v1 import RequestOptions +from google.cloud.spanner_v1.types import ClientContext from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture from google.cloud.spanner_v1.types import MultiplexedSessionPrecommitToken @@ -196,14 +199,26 @@ class _SnapshotBase(_SessionWrapper): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform transaction operations. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this transaction. """ _read_only: bool = True _multi_use: bool = False - def __init__(self, session): + def __init__(self, session, client_context=None): super().__init__(session) + if client_context is not None: + if type(client_context) is dict: + client_context = ClientContext(client_context) + elif not isinstance(client_context, ClientContext): + raise TypeError("client_context must be a ClientContext or a dict") + self._client_context = client_context + # Counts for execute SQL requests and total read requests (including # execute SQL requests). Used to provide sequence numbers for # :class:`google.cloud.spanner_v1.types.ExecuteSqlRequest` and to @@ -348,10 +363,13 @@ def read( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) if self._read_only: # Transaction tags are not supported for read only transactions. @@ -543,10 +561,14 @@ def execute_sql( default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + if self._read_only: # Transaction tags are not supported for read only transactions. request_options.transaction_tag = None @@ -923,11 +945,22 @@ def _begin_transaction( "mutation_key": mutation, } - if transaction_tag: - begin_request_kwargs["request_options"] = RequestOptions( - transaction_tag=transaction_tag + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + if client_context: + begin_request_kwargs["request_options"] = _merge_request_options( + begin_request_kwargs.get("request_options"), client_context ) + if transaction_tag: + request_options = begin_request_kwargs.get("request_options") + if request_options is None: + request_options = RequestOptions(transaction_tag=transaction_tag) + else: + request_options.transaction_tag = transaction_tag + begin_request_kwargs["request_options"] = request_options + with trace_call( name=f"CloudSpanner.{type(self).__name__}.begin", session=session, @@ -1099,6 +1132,11 @@ class Snapshot(_SnapshotBase): context of a read-only transaction, used to ensure isolation / consistency. Incompatible with ``max_staleness`` and ``min_read_timestamp``. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this snapshot. """ def __init__( @@ -1110,8 +1148,9 @@ def __init__( exact_staleness=None, multi_use=False, transaction_id=None, + client_context=None, ): - super(Snapshot, self).__init__(session) + super(Snapshot, self).__init__(session, client_context=client_context) opts = [read_timestamp, min_read_timestamp, max_staleness, exact_staleness] flagged = [opt for opt in opts if opt is not None] diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 413ac0af1f..0b0dc7dd51 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -25,6 +25,8 @@ _retry, _check_rst_stream_error, _merge_Transaction_Options, + _merge_client_context, + _merge_request_options, ) from google.cloud.spanner_v1 import ( CommitRequest, @@ -54,6 +56,11 @@ class Transaction(_SnapshotBase, _BatchBase): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform the commit + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this transaction. + :raises ValueError: if session has an existing transaction """ @@ -69,8 +76,8 @@ class Transaction(_SnapshotBase, _BatchBase): _multi_use: bool = True _read_only: bool = False - def __init__(self, session): - super(Transaction, self).__init__(session) + def __init__(self, session, client_context=None): + super(Transaction, self).__init__(session, client_context=client_context) self.rolled_back: bool = False # If this transaction is used to retry a previous aborted transaction with a @@ -266,10 +273,14 @@ def commit( else: raise ValueError("Transaction has not begun.") + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + if self.transaction_tag is not None: request_options.transaction_tag = self.transaction_tag @@ -479,10 +490,14 @@ def execute_update( default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag trace_attributes = { @@ -632,10 +647,14 @@ def batch_update( self._execute_sql_request_count + 1, ) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag trace_attributes = { diff --git a/google/cloud/spanner_v1/types/__init__.py b/google/cloud/spanner_v1/types/__init__.py index 5a7ded16dd..5f1e9274b6 100644 --- a/google/cloud/spanner_v1/types/__init__.py +++ b/google/cloud/spanner_v1/types/__init__.py @@ -52,6 +52,7 @@ BatchWriteRequest, BatchWriteResponse, BeginTransactionRequest, + ClientContext, CommitRequest, CreateSessionRequest, DeleteSessionRequest, @@ -110,6 +111,7 @@ "BatchWriteRequest", "BatchWriteResponse", "BeginTransactionRequest", + "ClientContext", "CommitRequest", "CreateSessionRequest", "DeleteSessionRequest", diff --git a/google/cloud/spanner_v1/types/spanner.py b/google/cloud/spanner_v1/types/spanner.py index 6e363088de..c7085cda13 100644 --- a/google/cloud/spanner_v1/types/spanner.py +++ b/google/cloud/spanner_v1/types/spanner.py @@ -43,6 +43,7 @@ "ListSessionsResponse", "DeleteSessionRequest", "RequestOptions", + "ClientContext", "DirectedReadOptions", "ExecuteSqlRequest", "ExecuteBatchDmlRequest", @@ -395,6 +396,31 @@ class Priority(proto.Enum): proto.STRING, number=3, ) + client_context: ClientContext = proto.Field( + proto.MESSAGE, + number=4, + message="ClientContext", + ) + + +class ClientContext(proto.Message): + r"""Container for various pieces of client-owned context + attached to a request. + + Attributes: + secure_context (MutableMapping[str, google.protobuf.struct_pb2.Value]): + Optional. Map of parameter name to value for this request. + These values will be returned by any SECURE_CONTEXT() calls + invoked by this request (e.g., by queries against + Parameterized Secure Views). + """ + + secure_context: MutableMapping[str, struct_pb2.Value] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=1, + message=struct_pb2.Value, + ) class DirectedReadOptions(proto.Message): diff --git a/tests/unit/test_client_context.py b/tests/unit/test_client_context.py new file mode 100644 index 0000000000..d11850bd2b --- /dev/null +++ b/tests/unit/test_client_context.py @@ -0,0 +1,327 @@ +# Copyright 2026 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock +from google.protobuf import struct_pb2 +from google.cloud.spanner_v1.types import ClientContext, RequestOptions +from google.cloud.spanner_v1._helpers import _merge_client_context, _merge_request_options + +class TestClientContext(unittest.TestCase): + + def test__merge_client_context_both_none(self): + self.assertIsNone(_merge_client_context(None, None)) + + def test__merge_client_context_base_none(self): + merge = ClientContext(secure_context={"a": struct_pb2.Value(string_value="A")}) + result = _merge_client_context(None, merge) + self.assertEqual(result.secure_context["a"], "A") + + def test__merge_client_context_merge_none(self): + base = ClientContext(secure_context={"a": struct_pb2.Value(string_value="A")}) + result = _merge_client_context(base, None) + self.assertEqual(result.secure_context["a"], "A") + + def test__merge_client_context_both_set(self): + base = ClientContext(secure_context={ + "a": struct_pb2.Value(string_value="A"), + "b": struct_pb2.Value(string_value="B1") + }) + merge = ClientContext(secure_context={ + "b": struct_pb2.Value(string_value="B2"), + "c": struct_pb2.Value(string_value="C") + }) + result = _merge_client_context(base, merge) + self.assertEqual(result.secure_context["a"], "A") + self.assertEqual(result.secure_context["b"], "B2") + self.assertEqual(result.secure_context["c"], "C") + + def test__merge_request_options_with_client_context(self): + request_options = RequestOptions(priority=RequestOptions.Priority.PRIORITY_LOW) + client_context = ClientContext(secure_context={"a": struct_pb2.Value(string_value="A")}) + + result = _merge_request_options(request_options, client_context) + + self.assertEqual(result.priority, RequestOptions.Priority.PRIORITY_LOW) + self.assertEqual(result.client_context.secure_context["a"], "A") + + def test_client_init_with_client_context(self): + from google.cloud.spanner_v1.client import Client + + project = "PROJECT" + with mock.patch("google.cloud.spanner_v1.client._get_spanner_enable_builtin_metrics_env", return_value=False): + client_context = {"secure_context": {"a": struct_pb2.Value(string_value="A")}} + client = Client(project=project, client_context=client_context, disable_builtin_metrics=True) + + self.assertIsInstance(client._client_context, ClientContext) + self.assertEqual(client._client_context.secure_context["a"], "A") + + def test_snapshot_execute_sql_propagates_client_context(self): + from google.cloud.spanner_v1.snapshot import Snapshot + from google.cloud.spanner_v1.types import ExecuteSqlRequest + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database._directed_read_options = None + + client = database._instance._client = mock.Mock() + client._query_options = None + client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) + + snapshot_context = ClientContext(secure_context={"snapshot": struct_pb2.Value(string_value="from-snapshot")}) + snapshot = Snapshot(session, client_context=snapshot_context) + + with mock.patch.object(snapshot, "_get_streamed_result_set") as mocked: + snapshot.execute_sql("SELECT 1") + kwargs = mocked.call_args.kwargs + request = kwargs['request'] + self.assertIsInstance(request, ExecuteSqlRequest) + self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") + self.assertEqual(request.request_options.client_context.secure_context["snapshot"], "from-snapshot") + + def test_transaction_commit_propagates_client_context(self): + from google.cloud.spanner_v1.transaction import Transaction + from google.cloud.spanner_v1.types import CommitRequest, CommitResponse, MultiplexedSessionPrecommitToken + + session = mock.Mock(spec=["name", "_database", "is_multiplexed"]) + session.name = "session-name" + session.is_multiplexed = False + database = session._database = mock.Mock() + database.name = "projects/p/instances/i/databases/d" + database._route_to_leader_enabled = False + database.log_commit_stats = False + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database._next_nth_request = 1 + + client = database._instance._client = mock.Mock() + client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) + + transaction_context = ClientContext(secure_context={"txn": struct_pb2.Value(string_value="from-txn")}) + transaction = Transaction(session, client_context=transaction_context) + transaction._transaction_id = b"tx-id" + + api = database.spanner_api = mock.Mock() + + token = MultiplexedSessionPrecommitToken(seq_num=1) + response = CommitResponse(precommit_token=token) + + def side_effect(f, **kw): + return f() + + api.commit.return_value = response + + with mock.patch("google.cloud.spanner_v1.transaction._retry", side_effect=side_effect): + transaction.commit() + + args, kwargs = api.commit.call_args + request = kwargs['request'] + self.assertIsInstance(request, CommitRequest) + self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") + self.assertEqual(request.request_options.client_context.secure_context["txn"], "from-txn") + + def test_snapshot_execute_sql_request_level_override(self): + from google.cloud.spanner_v1.snapshot import Snapshot + from google.cloud.spanner_v1.types import ExecuteSqlRequest + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database._directed_read_options = None + + client = database._instance._client = mock.Mock() + client._query_options = None + client._client_context = ClientContext(secure_context={"a": struct_pb2.Value(string_value="from-client")}) + + snapshot_context = ClientContext(secure_context={"a": struct_pb2.Value(string_value="from-snapshot"), "b": struct_pb2.Value(string_value="B")}) + snapshot = Snapshot(session, client_context=snapshot_context) + + request_options = RequestOptions(client_context=ClientContext(secure_context={"a": struct_pb2.Value(string_value="from-request")})) + + with mock.patch.object(snapshot, "_get_streamed_result_set") as mocked: + snapshot.execute_sql("SELECT 1", request_options=request_options) + kwargs = mocked.call_args.kwargs + request = kwargs['request'] + self.assertEqual(request.request_options.client_context.secure_context["a"], "from-request") + self.assertEqual(request.request_options.client_context.secure_context["b"], "B") + + def test_batch_commit_propagates_client_context(self): + from google.cloud.spanner_v1.batch import Batch + from google.cloud.spanner_v1.types import CommitRequest, CommitResponse, TransactionOptions + from google.cloud.spanner_v1 import DefaultTransactionOptions + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database.log_commit_stats = False + database.default_transaction_options = DefaultTransactionOptions() + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database._next_nth_request = 1 + client = database._instance._client = mock.Mock() + client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) + + batch_context = ClientContext(secure_context={"batch": struct_pb2.Value(string_value="from-batch")}) + batch = Batch(session, client_context=batch_context) + + api = database.spanner_api = mock.Mock() + response = CommitResponse() + api.commit.return_value = response + + batch.commit() + + args, kwargs = api.commit.call_args + request = kwargs['request'] + self.assertIsInstance(request, CommitRequest) + self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") + self.assertEqual(request.request_options.client_context.secure_context["batch"], "from-batch") + + def test_transaction_execute_update_propagates_client_context(self): + from google.cloud.spanner_v1.transaction import Transaction + from google.cloud.spanner_v1.types import ExecuteSqlRequest, ResultSet, MultiplexedSessionPrecommitToken + + session = mock.Mock(spec=["name", "_database", "_precommit_token"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database._next_nth_request = 1 + + client = database._instance._client = mock.Mock() + client._query_options = None + client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) + + transaction_context = ClientContext(secure_context={"txn": struct_pb2.Value(string_value="from-txn")}) + transaction = Transaction(session, client_context=transaction_context) + transaction._transaction_id = b"tx-id" + transaction._precommit_token = MultiplexedSessionPrecommitToken(seq_num=1) + + api = database.spanner_api = mock.Mock() + response = ResultSet(precommit_token=MultiplexedSessionPrecommitToken(seq_num=2)) + + with mock.patch.object(transaction, "_execute_request", return_value=response): + transaction.execute_update("UPDATE T SET C = 1") + + args, kwargs = transaction._execute_request.call_args + request = args[1] + self.assertIsInstance(request, ExecuteSqlRequest) + self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") + self.assertEqual(request.request_options.client_context.secure_context["txn"], "from-txn") + + def test_mutation_groups_batch_write_propagates_client_context(self): + from google.cloud.spanner_v1.batch import MutationGroups + from google.cloud.spanner_v1.types import BatchWriteRequest + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database.metadata_with_request_id.return_value = [] + database._next_nth_request = 1 + + client = database._instance._client = mock.Mock() + client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) + + mg_context = ClientContext(secure_context={"mg": struct_pb2.Value(string_value="from-mg")}) + mg = MutationGroups(session, client_context=mg_context) + + api = database.spanner_api = mock.Mock() + + with mock.patch("google.cloud.spanner_v1.batch._retry", side_effect=lambda f, **kw: f()): + mg.batch_write() + + args, kwargs = api.batch_write.call_args + request = kwargs['request'] + self.assertIsInstance(request, BatchWriteRequest) + self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") + self.assertEqual(request.request_options.client_context.secure_context["mg"], "from-mg") + + def test_batch_snapshot_propagates_client_context(self): + from google.cloud.spanner_v1.database import BatchSnapshot + from google.cloud.spanner_v1.types import ExecuteSqlRequest + + database = mock.Mock() + database.name = "database-name" + client = database._instance._client = mock.Mock() + client._query_options = None + client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) + + batch_context = ClientContext(secure_context={"batch": struct_pb2.Value(string_value="from-batch")}) + batch_snapshot = BatchSnapshot(database, client_context=batch_context) + + session = mock.Mock(spec=["name", "_database", "session_id", "snapshot"]) + session.name = "session-name" + session.session_id = "session-id" + database.sessions_manager.get_session.return_value = session + + snapshot = mock.Mock() + session.snapshot.return_value = snapshot + + batch_snapshot.execute_sql("SELECT 1") + + session.snapshot.assert_called_once() + kwargs = session.snapshot.call_args.kwargs + self.assertEqual(kwargs["client_context"], batch_context) + + def test_database_snapshot_propagates_client_context(self): + from google.cloud.spanner_v1.database import Database, SnapshotCheckout + + instance = mock.Mock() + instance._client = mock.Mock() + instance._client._query_options = None + instance._client._client_context = None + + database = Database("db", instance) + with mock.patch("google.cloud.spanner_v1.database.SnapshotCheckout") as mocked_checkout: + client_context = {"secure_context": {"a": struct_pb2.Value(string_value="A")}} + database.snapshot(client_context=client_context) + + mocked_checkout.assert_called_once_with(database, client_context=client_context) + + def test_transaction_rollback_propagates_client_context_is_not_supported(self): + # Verify that rollback DOES NOT take client_context as it's not in RollbackRequest + from google.cloud.spanner_v1.transaction import Transaction + from google.cloud.spanner_v1.types import RollbackRequest + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database._next_nth_request = 1 + + transaction = Transaction(session) + transaction._transaction_id = b"tx-id" + + api = database.spanner_api = mock.Mock() + + transaction.rollback() + + args, kwargs = api.rollback.call_args + self.assertEqual(kwargs["session"], "session-name") + self.assertEqual(kwargs["transaction_id"], b"tx-id") + # Ensure no request_options or client_context passed to rollback + self.assertNotIn("request_options", kwargs) + +if __name__ == "__main__": + unittest.main() From 869c2520353933c464b6fb15afbb39eeab8b1536 Mon Sep 17 00:00:00 2001 From: Adam Seering Date: Wed, 18 Feb 2026 02:51:37 +0000 Subject: [PATCH 2/4] feat: Propagate client_context in Session and update tests - Update Session.transaction to accept client_context. - Update unit tests to support client_context propagation. - Update mock objects in tests to match the expected attribute hierarchy. - Clean up nested imports in test files. --- google/cloud/spanner_v1/database.py | 8 +- google/cloud/spanner_v1/session.py | 14 +- tests/unit/spanner_dbapi/test_connection.py | 1 + tests/unit/test_backup.py | 1 + tests/unit/test_batch.py | 3 + tests/unit/test_client_context.py | 327 +++++++++++++------- tests/unit/test_database.py | 12 + tests/unit/test_instance.py | 1 + tests/unit/test_pool.py | 60 ++-- tests/unit/test_session.py | 3 + tests/unit/test_snapshot.py | 1 + tests/unit/test_spanner.py | 1 + tests/unit/test_transaction.py | 1 + 13 files changed, 280 insertions(+), 153 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 156f99acb1..8c55a1744f 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -1100,8 +1100,8 @@ def run_in_transaction(self, func, *args, **kw): the DDL option `allow_txn_exclusion` being false or unset. "isolation_level" sets the isolation level for the transaction. "read_lock_mode" sets the read lock mode for the transaction. - "client_context" (Optional) Client context to use for all requests made - by this transaction. + "client_context" (Optional) Client context to use for all requests + made by this transaction. :rtype: Any :returns: The return value of ``func``. @@ -1526,7 +1526,9 @@ def __enter__(self): transaction_type = TransactionType.READ_WRITE self._session = self._database.sessions_manager.get_session(transaction_type) - return MutationGroups(session=self._session, client_context=self._client_context) + return MutationGroups( + session=self._session, client_context=self._client_context + ) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index e7bc913c27..95db0f72d2 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -472,9 +472,14 @@ def batch(self): return Batch(self) - def transaction(self) -> Transaction: + def transaction(self, client_context=None) -> Transaction: """Create a transaction to perform a set of reads with shared staleness. + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this transaction. + :rtype: :class:`~google.cloud.spanner_v1.transaction.Transaction` :returns: a transaction bound to this session @@ -483,7 +488,7 @@ def transaction(self) -> Transaction: if self._session_id is None: raise ValueError("Session has not been created.") - return Transaction(self) + return Transaction(self, client_context=client_context) def run_in_transaction(self, func, *args, **kw): """Perform a unit of work in a transaction, retrying on abort. @@ -512,6 +517,8 @@ def run_in_transaction(self, func, *args, **kw): the DDL option `allow_txn_exclusion` being false or unset. "isolation_level" sets the isolation level for the transaction. "read_lock_mode" sets the read lock mode for the transaction. + "client_context" (Optional) Client context to use for all requests + made by this transaction. :rtype: Any :returns: The return value of ``func``. @@ -529,6 +536,7 @@ def run_in_transaction(self, func, *args, **kw): ) isolation_level = kw.pop("isolation_level", None) read_lock_mode = kw.pop("read_lock_mode", None) + client_context = kw.pop("client_context", None) database = self._database log_commit_stats = database.log_commit_stats @@ -554,7 +562,7 @@ def run_in_transaction(self, func, *args, **kw): previous_transaction_id: Optional[bytes] = None while True: - txn = self.transaction() + txn = self.transaction(client_context=client_context) txn.transaction_tag = transaction_tag txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams txn.isolation_level = isolation_level diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 6e8159425f..6fc844183e 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -872,6 +872,7 @@ class _Client(object): def __init__(self, project="project_id"): self.project = project self.project_name = "projects/" + self.project + self._client_context = None def instance(self, instance_id="instance_id"): return _Instance(name=instance_id, client=self) diff --git a/tests/unit/test_backup.py b/tests/unit/test_backup.py index 00621c2148..8198a283e4 100644 --- a/tests/unit/test_backup.py +++ b/tests/unit/test_backup.py @@ -679,6 +679,7 @@ class _Client(object): def __init__(self, project=TestBackup.PROJECT_ID): self.project = project self.project_name = "projects/" + self.project + self._client_context = None class _Instance(object): diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index f00a45e8a5..b4690203f6 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -806,6 +806,9 @@ class _Database(object): def __init__(self, enable_end_to_end_tracing=False): self.name = "testing" + self._instance = mock.Mock() + self._instance._client = mock.Mock() + self._instance._client._client_context = None self._route_to_leader_enabled = True if enable_end_to_end_tracing: self.observability_options = dict(enable_end_to_end_tracing=True) diff --git a/tests/unit/test_client_context.py b/tests/unit/test_client_context.py index d11850bd2b..6c95b51946 100644 --- a/tests/unit/test_client_context.py +++ b/tests/unit/test_client_context.py @@ -15,11 +15,18 @@ import unittest from unittest import mock from google.protobuf import struct_pb2 -from google.cloud.spanner_v1.types import ClientContext, RequestOptions -from google.cloud.spanner_v1._helpers import _merge_client_context, _merge_request_options +from google.cloud.spanner_v1.types import ( + ClientContext, + RequestOptions, + ExecuteSqlRequest, +) +from google.cloud.spanner_v1._helpers import ( + _merge_client_context, + _merge_request_options, +) -class TestClientContext(unittest.TestCase): +class TestClientContext(unittest.TestCase): def test__merge_client_context_both_none(self): self.assertIsNone(_merge_client_context(None, None)) @@ -34,14 +41,18 @@ def test__merge_client_context_merge_none(self): self.assertEqual(result.secure_context["a"], "A") def test__merge_client_context_both_set(self): - base = ClientContext(secure_context={ - "a": struct_pb2.Value(string_value="A"), - "b": struct_pb2.Value(string_value="B1") - }) - merge = ClientContext(secure_context={ - "b": struct_pb2.Value(string_value="B2"), - "c": struct_pb2.Value(string_value="C") - }) + base = ClientContext( + secure_context={ + "a": struct_pb2.Value(string_value="A"), + "b": struct_pb2.Value(string_value="B1"), + } + ) + merge = ClientContext( + secure_context={ + "b": struct_pb2.Value(string_value="B2"), + "c": struct_pb2.Value(string_value="C"), + } + ) result = _merge_client_context(base, merge) self.assertEqual(result.secure_context["a"], "A") self.assertEqual(result.secure_context["b"], "B2") @@ -49,27 +60,40 @@ def test__merge_client_context_both_set(self): def test__merge_request_options_with_client_context(self): request_options = RequestOptions(priority=RequestOptions.Priority.PRIORITY_LOW) - client_context = ClientContext(secure_context={"a": struct_pb2.Value(string_value="A")}) - + client_context = ClientContext( + secure_context={"a": struct_pb2.Value(string_value="A")} + ) + result = _merge_request_options(request_options, client_context) - + self.assertEqual(result.priority, RequestOptions.Priority.PRIORITY_LOW) self.assertEqual(result.client_context.secure_context["a"], "A") def test_client_init_with_client_context(self): from google.cloud.spanner_v1.client import Client - + project = "PROJECT" - with mock.patch("google.cloud.spanner_v1.client._get_spanner_enable_builtin_metrics_env", return_value=False): - client_context = {"secure_context": {"a": struct_pb2.Value(string_value="A")}} - client = Client(project=project, client_context=client_context, disable_builtin_metrics=True) - + credentials = mock.Mock(spec=["_resource_prefix__"]) + with mock.patch( + "google.auth.default", return_value=(credentials, project) + ), mock.patch( + "google.cloud.spanner_v1.client._get_spanner_enable_builtin_metrics_env", + return_value=False, + ): + client_context = { + "secure_context": {"a": struct_pb2.Value(string_value="A")} + } + client = Client( + project=project, + client_context=client_context, + disable_builtin_metrics=True, + ) + self.assertIsInstance(client._client_context, ClientContext) self.assertEqual(client._client_context.secure_context["a"], "A") def test_snapshot_execute_sql_propagates_client_context(self): from google.cloud.spanner_v1.snapshot import Snapshot - from google.cloud.spanner_v1.types import ExecuteSqlRequest session = mock.Mock(spec=["name", "_database"]) session.name = "session-name" @@ -77,25 +101,39 @@ def test_snapshot_execute_sql_propagates_client_context(self): database.name = "database-name" database._route_to_leader_enabled = False database._directed_read_options = None - + client = database._instance._client = mock.Mock() client._query_options = None - client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) - - snapshot_context = ClientContext(secure_context={"snapshot": struct_pb2.Value(string_value="from-snapshot")}) + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + snapshot_context = ClientContext( + secure_context={"snapshot": struct_pb2.Value(string_value="from-snapshot")} + ) snapshot = Snapshot(session, client_context=snapshot_context) - + with mock.patch.object(snapshot, "_get_streamed_result_set") as mocked: snapshot.execute_sql("SELECT 1") kwargs = mocked.call_args.kwargs - request = kwargs['request'] + request = kwargs["request"] self.assertIsInstance(request, ExecuteSqlRequest) - self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") - self.assertEqual(request.request_options.client_context.secure_context["snapshot"], "from-snapshot") + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["snapshot"], + "from-snapshot", + ) def test_transaction_commit_propagates_client_context(self): from google.cloud.spanner_v1.transaction import Transaction - from google.cloud.spanner_v1.types import CommitRequest, CommitResponse, MultiplexedSessionPrecommitToken + from google.cloud.spanner_v1.types import ( + CommitRequest, + CommitResponse, + MultiplexedSessionPrecommitToken, + ) session = mock.Mock(spec=["name", "_database", "is_multiplexed"]) session.name = "session-name" @@ -106,36 +144,46 @@ def test_transaction_commit_propagates_client_context(self): database.log_commit_stats = False database.with_error_augmentation.return_value = (None, mock.MagicMock()) database._next_nth_request = 1 - + client = database._instance._client = mock.Mock() - client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) - - transaction_context = ClientContext(secure_context={"txn": struct_pb2.Value(string_value="from-txn")}) + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + transaction_context = ClientContext( + secure_context={"txn": struct_pb2.Value(string_value="from-txn")} + ) transaction = Transaction(session, client_context=transaction_context) transaction._transaction_id = b"tx-id" - + api = database.spanner_api = mock.Mock() - + token = MultiplexedSessionPrecommitToken(seq_num=1) response = CommitResponse(precommit_token=token) - + def side_effect(f, **kw): return f() - + api.commit.return_value = response - - with mock.patch("google.cloud.spanner_v1.transaction._retry", side_effect=side_effect): + + with mock.patch( + "google.cloud.spanner_v1.transaction._retry", side_effect=side_effect + ): transaction.commit() - + args, kwargs = api.commit.call_args - request = kwargs['request'] + request = kwargs["request"] self.assertIsInstance(request, CommitRequest) - self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") - self.assertEqual(request.request_options.client_context.secure_context["txn"], "from-txn") + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["txn"], "from-txn" + ) def test_snapshot_execute_sql_request_level_override(self): from google.cloud.spanner_v1.snapshot import Snapshot - from google.cloud.spanner_v1.types import ExecuteSqlRequest session = mock.Mock(spec=["name", "_database"]) session.name = "session-name" @@ -143,26 +191,45 @@ def test_snapshot_execute_sql_request_level_override(self): database.name = "database-name" database._route_to_leader_enabled = False database._directed_read_options = None - + client = database._instance._client = mock.Mock() client._query_options = None - client._client_context = ClientContext(secure_context={"a": struct_pb2.Value(string_value="from-client")}) - - snapshot_context = ClientContext(secure_context={"a": struct_pb2.Value(string_value="from-snapshot"), "b": struct_pb2.Value(string_value="B")}) + client._client_context = ClientContext( + secure_context={"a": struct_pb2.Value(string_value="from-client")} + ) + + snapshot_context = ClientContext( + secure_context={ + "a": struct_pb2.Value(string_value="from-snapshot"), + "b": struct_pb2.Value(string_value="B"), + } + ) snapshot = Snapshot(session, client_context=snapshot_context) - - request_options = RequestOptions(client_context=ClientContext(secure_context={"a": struct_pb2.Value(string_value="from-request")})) - + + request_options = RequestOptions( + client_context=ClientContext( + secure_context={"a": struct_pb2.Value(string_value="from-request")} + ) + ) + with mock.patch.object(snapshot, "_get_streamed_result_set") as mocked: snapshot.execute_sql("SELECT 1", request_options=request_options) kwargs = mocked.call_args.kwargs - request = kwargs['request'] - self.assertEqual(request.request_options.client_context.secure_context["a"], "from-request") - self.assertEqual(request.request_options.client_context.secure_context["b"], "B") + request = kwargs["request"] + self.assertEqual( + request.request_options.client_context.secure_context["a"], + "from-request", + ) + self.assertEqual( + request.request_options.client_context.secure_context["b"], "B" + ) def test_batch_commit_propagates_client_context(self): from google.cloud.spanner_v1.batch import Batch - from google.cloud.spanner_v1.types import CommitRequest, CommitResponse, TransactionOptions + from google.cloud.spanner_v1.types import ( + CommitRequest, + CommitResponse, + ) from google.cloud.spanner_v1 import DefaultTransactionOptions session = mock.Mock(spec=["name", "_database"]) @@ -175,26 +242,39 @@ def test_batch_commit_propagates_client_context(self): database.with_error_augmentation.return_value = (None, mock.MagicMock()) database._next_nth_request = 1 client = database._instance._client = mock.Mock() - client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) - - batch_context = ClientContext(secure_context={"batch": struct_pb2.Value(string_value="from-batch")}) + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + batch_context = ClientContext( + secure_context={"batch": struct_pb2.Value(string_value="from-batch")} + ) batch = Batch(session, client_context=batch_context) - + api = database.spanner_api = mock.Mock() response = CommitResponse() api.commit.return_value = response - + batch.commit() - + args, kwargs = api.commit.call_args - request = kwargs['request'] + request = kwargs["request"] self.assertIsInstance(request, CommitRequest) - self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") - self.assertEqual(request.request_options.client_context.secure_context["batch"], "from-batch") + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["batch"], "from-batch" + ) def test_transaction_execute_update_propagates_client_context(self): from google.cloud.spanner_v1.transaction import Transaction - from google.cloud.spanner_v1.types import ExecuteSqlRequest, ResultSet, MultiplexedSessionPrecommitToken + from google.cloud.spanner_v1.types import ( + ExecuteSqlRequest, + ResultSet, + MultiplexedSessionPrecommitToken, + ) session = mock.Mock(spec=["name", "_database", "_precommit_token"]) session.name = "session-name" @@ -203,27 +283,38 @@ def test_transaction_execute_update_propagates_client_context(self): database._route_to_leader_enabled = False database.with_error_augmentation.return_value = (None, mock.MagicMock()) database._next_nth_request = 1 - + client = database._instance._client = mock.Mock() client._query_options = None - client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) - - transaction_context = ClientContext(secure_context={"txn": struct_pb2.Value(string_value="from-txn")}) + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + transaction_context = ClientContext( + secure_context={"txn": struct_pb2.Value(string_value="from-txn")} + ) transaction = Transaction(session, client_context=transaction_context) transaction._transaction_id = b"tx-id" transaction._precommit_token = MultiplexedSessionPrecommitToken(seq_num=1) - - api = database.spanner_api = mock.Mock() - response = ResultSet(precommit_token=MultiplexedSessionPrecommitToken(seq_num=2)) - + + database.spanner_api = mock.Mock() + response = ResultSet( + precommit_token=MultiplexedSessionPrecommitToken(seq_num=2) + ) + with mock.patch.object(transaction, "_execute_request", return_value=response): transaction.execute_update("UPDATE T SET C = 1") - + args, kwargs = transaction._execute_request.call_args - request = args[1] + request = args[1] self.assertIsInstance(request, ExecuteSqlRequest) - self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") - self.assertEqual(request.request_options.client_context.secure_context["txn"], "from-txn") + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["txn"], "from-txn" + ) def test_mutation_groups_batch_write_propagates_client_context(self): from google.cloud.spanner_v1.batch import MutationGroups @@ -237,70 +328,89 @@ def test_mutation_groups_batch_write_propagates_client_context(self): database.with_error_augmentation.return_value = (None, mock.MagicMock()) database.metadata_with_request_id.return_value = [] database._next_nth_request = 1 - + client = database._instance._client = mock.Mock() - client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) - - mg_context = ClientContext(secure_context={"mg": struct_pb2.Value(string_value="from-mg")}) + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + mg_context = ClientContext( + secure_context={"mg": struct_pb2.Value(string_value="from-mg")} + ) mg = MutationGroups(session, client_context=mg_context) - + api = database.spanner_api = mock.Mock() - - with mock.patch("google.cloud.spanner_v1.batch._retry", side_effect=lambda f, **kw: f()): + + with mock.patch( + "google.cloud.spanner_v1.batch._retry", side_effect=lambda f, **kw: f() + ): mg.batch_write() - + args, kwargs = api.batch_write.call_args - request = kwargs['request'] + request = kwargs["request"] self.assertIsInstance(request, BatchWriteRequest) - self.assertEqual(request.request_options.client_context.secure_context["client"], "from-client") - self.assertEqual(request.request_options.client_context.secure_context["mg"], "from-mg") + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["mg"], "from-mg" + ) def test_batch_snapshot_propagates_client_context(self): from google.cloud.spanner_v1.database import BatchSnapshot - from google.cloud.spanner_v1.types import ExecuteSqlRequest database = mock.Mock() database.name = "database-name" client = database._instance._client = mock.Mock() client._query_options = None - client._client_context = ClientContext(secure_context={"client": struct_pb2.Value(string_value="from-client")}) - - batch_context = ClientContext(secure_context={"batch": struct_pb2.Value(string_value="from-batch")}) + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + batch_context = ClientContext( + secure_context={"batch": struct_pb2.Value(string_value="from-batch")} + ) batch_snapshot = BatchSnapshot(database, client_context=batch_context) - + session = mock.Mock(spec=["name", "_database", "session_id", "snapshot"]) session.name = "session-name" session.session_id = "session-id" database.sessions_manager.get_session.return_value = session - + snapshot = mock.Mock() session.snapshot.return_value = snapshot - + batch_snapshot.execute_sql("SELECT 1") - + session.snapshot.assert_called_once() kwargs = session.snapshot.call_args.kwargs self.assertEqual(kwargs["client_context"], batch_context) def test_database_snapshot_propagates_client_context(self): - from google.cloud.spanner_v1.database import Database, SnapshotCheckout - + from google.cloud.spanner_v1.database import Database + instance = mock.Mock() instance._client = mock.Mock() instance._client._query_options = None instance._client._client_context = None - + database = Database("db", instance) - with mock.patch("google.cloud.spanner_v1.database.SnapshotCheckout") as mocked_checkout: - client_context = {"secure_context": {"a": struct_pb2.Value(string_value="A")}} + with mock.patch( + "google.cloud.spanner_v1.database.SnapshotCheckout" + ) as mocked_checkout: + client_context = { + "secure_context": {"a": struct_pb2.Value(string_value="A")} + } database.snapshot(client_context=client_context) - - mocked_checkout.assert_called_once_with(database, client_context=client_context) + + mocked_checkout.assert_called_once_with( + database, client_context=client_context + ) def test_transaction_rollback_propagates_client_context_is_not_supported(self): # Verify that rollback DOES NOT take client_context as it's not in RollbackRequest from google.cloud.spanner_v1.transaction import Transaction - from google.cloud.spanner_v1.types import RollbackRequest session = mock.Mock(spec=["name", "_database"]) session.name = "session-name" @@ -309,19 +419,20 @@ def test_transaction_rollback_propagates_client_context_is_not_supported(self): database._route_to_leader_enabled = False database.with_error_augmentation.return_value = (None, mock.MagicMock()) database._next_nth_request = 1 - + transaction = Transaction(session) transaction._transaction_id = b"tx-id" - + api = database.spanner_api = mock.Mock() - + transaction.rollback() - + args, kwargs = api.rollback.call_args self.assertEqual(kwargs["session"], "session-name") self.assertEqual(kwargs["transaction_id"], b"tx-id") # Ensure no request_options or client_context passed to rollback self.assertNotIn("request_options", kwargs) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 929f0c0010..6fe7dcd049 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -30,6 +30,7 @@ RequestOptions, DirectedReadOptions, DefaultTransactionOptions, + ExecuteSqlRequest, ) from google.cloud.spanner_v1._helpers import ( AtomicCounter, @@ -2599,6 +2600,7 @@ def test__get_snapshot_new_wo_staleness(self): exact_staleness=None, multi_use=True, transaction_id=None, + client_context=None, ) snapshot.begin.assert_called_once_with() @@ -2614,6 +2616,7 @@ def test__get_snapshot_w_read_timestamp(self): exact_staleness=None, multi_use=True, transaction_id=None, + client_context=None, ) snapshot.begin.assert_called_once_with() @@ -2629,6 +2632,7 @@ def test__get_snapshot_w_exact_staleness(self): exact_staleness=duration, multi_use=True, transaction_id=None, + client_context=None, ) snapshot.begin.assert_called_once_with() @@ -3540,6 +3544,7 @@ def __init__( self.directed_read_options = directed_read_options self.default_transaction_options = default_transaction_options self.observability_options = observability_options + self._client_context = None self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() @@ -3589,6 +3594,13 @@ class _Database(object): def __init__(self, name, instance=None): self.name = name self.database_id = name.rsplit("/", 1)[1] + if instance is None: + instance = mock.Mock() + instance._client = mock.Mock() + instance._client._client_context = None + instance._client._query_options = ExecuteSqlRequest.QueryOptions( + optimizer_version="1" + ) self._instance = instance from logging import Logger diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index f3bf6726c0..b76770cf13 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -1023,6 +1023,7 @@ def __init__(self, project, timeout_seconds=None): self.route_to_leader_enabled = True self.directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + self._client_context = None def copy(self): from copy import deepcopy diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index e0a236c86f..bfce743352 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -19,7 +19,21 @@ from datetime import datetime, timedelta import mock +from google.cloud.spanner_v1 import pool as MUT from google.cloud.spanner_v1 import _opentelemetry_tracing +from google.cloud.spanner_v1 import ExecuteSqlRequest +from google.cloud.spanner_v1 import BatchCreateSessionsResponse +from google.cloud.spanner_v1 import Session +from google.cloud.spanner_v1 import SpannerClient +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.pool import AbstractSessionPool +from google.cloud.spanner_v1.pool import SessionCheckout +from google.cloud.spanner_v1.pool import FixedSizePool +from google.cloud.spanner_v1.pool import BurstyPool +from google.cloud.spanner_v1.pool import PingingPool +from google.cloud.spanner_v1.transaction import Transaction +from google.cloud.exceptions import NotFound +from google.cloud._testing import _Monkey from google.cloud.spanner_v1._helpers import ( _metadata_with_request_id, _metadata_with_request_id_and_req_id, @@ -40,21 +54,15 @@ def _make_database(name="name"): - from google.cloud.spanner_v1.database import Database - return mock.create_autospec(Database, instance=True) def _make_session(): - from google.cloud.spanner_v1.database import Session - return mock.create_autospec(Session, instance=True) class TestAbstractSessionPool(unittest.TestCase): def _getTargetClass(self): - from google.cloud.spanner_v1.pool import AbstractSessionPool - return AbstractSessionPool def _make_one(self, *args, **kwargs): @@ -129,8 +137,6 @@ def test__new_session_w_database_role(self): self.assertEqual(new_session.database_role, database_role) def test_session_wo_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - pool = self._make_one() checkout = pool.session() self.assertIsInstance(checkout, SessionCheckout) @@ -139,8 +145,6 @@ def test_session_wo_kwargs(self): self.assertEqual(checkout._kwargs, {}) def test_session_w_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - pool = self._make_one() checkout = pool.session(foo="bar") self.assertIsInstance(checkout, SessionCheckout) @@ -164,8 +168,6 @@ class TestFixedSizePool(OpenTelemetryBase): enrich_with_otel_scope(BASE_ATTRIBUTES) def _getTargetClass(self): - from google.cloud.spanner_v1.pool import FixedSizePool - return FixedSizePool def _make_one(self, *args, **kwargs): @@ -559,8 +561,6 @@ class TestBurstyPool(OpenTelemetryBase): enrich_with_otel_scope(BASE_ATTRIBUTES) def _getTargetClass(self): - from google.cloud.spanner_v1.pool import BurstyPool - return BurstyPool def _make_one(self, *args, **kwargs): @@ -850,8 +850,6 @@ class TestPingingPool(OpenTelemetryBase): enrich_with_otel_scope(BASE_ATTRIBUTES) def _getTargetClass(self): - from google.cloud.spanner_v1.pool import PingingPool - return PingingPool def _make_one(self, *args, **kwargs): @@ -946,8 +944,6 @@ def test_get_hit_no_ping(self, mock_region): ) def test_get_hit_w_ping(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=4) database = _Database("name") @@ -974,8 +970,6 @@ def test_get_hit_w_ping(self, mock_region): ) def test_get_hit_w_ping_expired(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=4) database = _Database("name") @@ -1097,8 +1091,6 @@ def test_spans_put_full(self, mock_region): ) def test_put_non_full(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) session_queue = pool._sessions = _Queue() @@ -1172,8 +1164,6 @@ def test_ping_oldest_fresh(self, mock_region): ) def test_ping_oldest_stale_but_exists(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) database = _Database("name") @@ -1193,8 +1183,6 @@ def test_ping_oldest_stale_but_exists(self, mock_region): ) def test_ping_oldest_stale_and_not_exists(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) database = _Database("name") @@ -1257,8 +1245,6 @@ def test_spans_get_and_leave_empty_pool(self, mock_region): class TestSessionCheckout(unittest.TestCase): def _getTargetClass(self): - from google.cloud.spanner_v1.pool import SessionCheckout - return SessionCheckout def _make_one(self, *args, **kwargs): @@ -1314,8 +1300,6 @@ def test_context_manager_w_kwargs(self): def _make_transaction(*args, **kw): - from google.cloud.spanner_v1.transaction import Transaction - txn = mock.create_autospec(Transaction)(*args, **kw) txn.committed = None txn.rolled_back = False @@ -1352,15 +1336,11 @@ def exists(self): return self._exists def ping(self): - from google.cloud.exceptions import NotFound - self._pinged = True if not self._exists: raise NotFound("expired session") def delete(self): - from google.cloud.exceptions import NotFound - self._deleted = True if not self._exists: raise NotFound("unknown session") @@ -1391,9 +1371,6 @@ def mock_batch_create_sessions( metadata=[], labels={}, ): - from google.cloud.spanner_v1 import BatchCreateSessionsResponse - from google.cloud.spanner_v1 import Session - database_role = request.session_template.creator_role if request else None if request.session_count < 2: response = BatchCreateSessionsResponse( @@ -1408,10 +1385,15 @@ def mock_batch_create_sessions( ) return response - from google.cloud.spanner_v1 import SpannerClient - self.spanner_api = mock.create_autospec(SpannerClient, instance=True) self.spanner_api.batch_create_sessions.side_effect = mock_batch_create_sessions + self._instance = mock.Mock() + self._instance._client = mock.Mock() + self._instance._client._client_context = None + self._instance._client.spanner_api = self.spanner_api + self._instance._client._query_options = ExecuteSqlRequest.QueryOptions( + optimizer_version="1" + ) @property def database_role(self): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 86e4fe7e72..49a6f8297c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -194,6 +194,9 @@ def _make_database( database.database_role = database_role database._route_to_leader_enabled = True database.default_transaction_options = default_transaction_options + database._instance = mock.Mock() + database._instance._client = mock.Mock() + database._instance._client._client_context = None inject_into_mock_database(database) return database diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 81d2d01fa3..3d93488ab7 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -2182,6 +2182,7 @@ def __init__(self): from google.cloud.spanner_v1 import ExecuteSqlRequest self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") + self._client_context = None self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index ecd7d4fd86..ad1033d412 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -1280,6 +1280,7 @@ def __init__(self): self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") self.directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + self._client_context = None self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 9afc1130b4..769dcaf703 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -1384,6 +1384,7 @@ def __init__(self): self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") self.directed_read_options = None + self._client_context = None self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() From 3611aec6b03b056875dbc681297dccae5f4c4e84 Mon Sep 17 00:00:00 2001 From: Adam Seering Date: Wed, 18 Feb 2026 07:03:42 +0000 Subject: [PATCH 3/4] fix: Resolve race condition in concurrent Transaction initialization Implement a double-checked locking pattern in Transaction and _SnapshotBase methods. When multiple threads attempt to use a lazy transaction simultaneously, they race to acquire the lock. Previously, losing threads would acquire the lock and blindly send another 'begin transaction' request, ignoring that the winner had already initialized the transaction ID. This change ensures that threads re-check self._transaction_id after acquiring the lock. If the ID is present, they skip the initialization request and use the established ID. --- google/cloud/spanner_v1/snapshot.py | 4 ++++ google/cloud/spanner_v1/transaction.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 85481324a9..3d530a4d09 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -638,6 +638,10 @@ def _get_streamed_result_set( if self._transaction_id is None: is_inline_begin = True self._lock.acquire() + if self._transaction_id is not None: + is_inline_begin = False + self._lock.release() + request.transaction = TransactionSelector(id=self._transaction_id) iterator = _restart_on_unavailable( method=method, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 0b0dc7dd51..9aa2fce1ef 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -512,6 +512,9 @@ def execute_update( if self._transaction_id is None: is_inline_begin = True self._lock.acquire() + if self._transaction_id is not None: + is_inline_begin = False + self._lock.release() execute_sql_request = ExecuteSqlRequest( session=session.name, @@ -670,6 +673,9 @@ def batch_update( if self._transaction_id is None: is_inline_begin = True self._lock.acquire() + if self._transaction_id is not None: + is_inline_begin = False + self._lock.release() execute_batch_dml_request = ExecuteBatchDmlRequest( session=session.name, From d249c67806c9924cc86fd3deb48232c5c35c884d Mon Sep 17 00:00:00 2001 From: Adam Seering Date: Wed, 18 Feb 2026 07:09:42 +0000 Subject: [PATCH 4/4] fix: Secure ClientContext merging and improve type safety - Fix critical security vulnerability in where in-place modification of the base object could lead to context leakage across requests. - Replace with throughout , , , and for better robustness and subclass support. - Simplify construction logic in for better readability. Based on suggestions from Gemini Code Assist. --- google/cloud/spanner_v1/_helpers.py | 22 ++++++++++++++++------ google/cloud/spanner_v1/batch.py | 4 ++-- google/cloud/spanner_v1/client.py | 2 +- google/cloud/spanner_v1/snapshot.py | 16 +++++++--------- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index f45108e4fa..fb6bf49b8c 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -174,7 +174,7 @@ def _merge_query_options(base, merge): If the resultant object only has empty fields, returns None. """ combined = base or ExecuteSqlRequest.QueryOptions() - if type(combined) is dict: + if isinstance(combined, dict): combined = ExecuteSqlRequest.QueryOptions( optimizer_version=combined.get("optimizer_version", ""), optimizer_statistics_package=combined.get( @@ -182,7 +182,7 @@ def _merge_query_options(base, merge): ), ) merge = merge or ExecuteSqlRequest.QueryOptions() - if type(merge) is dict: + if isinstance(merge, dict): merge = ExecuteSqlRequest.QueryOptions( optimizer_version=merge.get("optimizer_version", ""), optimizer_statistics_package=merge.get("optimizer_statistics_package", ""), @@ -215,14 +215,24 @@ def _merge_client_context(base, merge): return None combined = base or ClientContext() - if type(combined) is dict: + if isinstance(combined, dict): combined = ClientContext(combined) merge = merge or ClientContext() - if type(merge) is dict: + if isinstance(merge, dict): merge = ClientContext(merge) - type(combined).pb(combined).MergeFrom(type(merge).pb(merge)) + # Avoid in-place modification of base + combined_pb = ClientContext()._pb + if base: + base_pb = ClientContext(base)._pb if isinstance(base, dict) else base._pb + combined_pb.MergeFrom(base_pb) + if merge: + merge_pb = ClientContext(merge)._pb if isinstance(merge, dict) else merge._pb + combined_pb.MergeFrom(merge_pb) + + combined = ClientContext(combined_pb) + if not combined.secure_context: return None return combined @@ -250,7 +260,7 @@ def _merge_request_options(request_options, client_context): if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: + elif isinstance(request_options, dict): request_options = RequestOptions(request_options) if client_context: diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 7eeb83c12f..9375c80956 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -68,7 +68,7 @@ def __init__(self, session, client_context=None): self.commit_stats: Optional[CommitResponse.CommitStats] = None if client_context is not None: - if type(client_context) is dict: + if isinstance(client_context, dict): client_context = ClientContext(client_context) elif not isinstance(client_context, ClientContext): raise TypeError("client_context must be a ClientContext or a dict") @@ -349,7 +349,7 @@ def __init__(self, session, client_context=None): self.committed: bool = False if client_context is not None: - if type(client_context) is dict: + if isinstance(client_context, dict): client_context = ClientContext(client_context) elif not isinstance(client_context, ClientContext): raise TypeError("client_context must be a ClientContext or a dict") diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index cd288bc260..9f2d08059a 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -295,7 +295,7 @@ def __init__( self._query_options = _merge_query_options(query_options, env_query_options) if client_context is not None: - if type(client_context) is dict: + if isinstance(client_context, dict): client_context = ClientContext(client_context) elif not isinstance(client_context, ClientContext): raise TypeError("client_context must be a ClientContext or a dict") diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 3d530a4d09..42ae196f1b 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -213,7 +213,7 @@ def __init__(self, session, client_context=None): super().__init__(session) if client_context is not None: - if type(client_context) is dict: + if isinstance(client_context, dict): client_context = ClientContext(client_context) elif not isinstance(client_context, ClientContext): raise TypeError("client_context must be a ClientContext or a dict") @@ -949,20 +949,18 @@ def _begin_transaction( "mutation_key": mutation, } + request_options = begin_request_kwargs.get("request_options") client_context = _merge_client_context( database._instance._client._client_context, self._client_context ) - if client_context: - begin_request_kwargs["request_options"] = _merge_request_options( - begin_request_kwargs.get("request_options"), client_context - ) + request_options = _merge_request_options(request_options, client_context) if transaction_tag: - request_options = begin_request_kwargs.get("request_options") if request_options is None: - request_options = RequestOptions(transaction_tag=transaction_tag) - else: - request_options.transaction_tag = transaction_tag + request_options = RequestOptions() + request_options.transaction_tag = transaction_tag + + if request_options: begin_request_kwargs["request_options"] = request_options with trace_call(