diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index 4f77269bb2..cd5b8ae371 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -38,6 +38,7 @@ from .types.spanner import BatchWriteRequest from .types.spanner import BatchWriteResponse from .types.spanner import BeginTransactionRequest +from .types.spanner import ClientContext from .types.spanner import CommitRequest from .types.spanner import CreateSessionRequest from .types.spanner import DeleteSessionRequest @@ -110,6 +111,7 @@ "BatchWriteRequest", "BatchWriteResponse", "BeginTransactionRequest", + "ClientContext", "CommitRequest", "CommitResponse", "CreateSessionRequest", diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index a52c24e769..fb6bf49b8c 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -34,6 +34,8 @@ from google.cloud._helpers import _date_from_iso8601_date from google.cloud.spanner_v1.types import ExecuteSqlRequest from google.cloud.spanner_v1.types import TransactionOptions +from google.cloud.spanner_v1.types import ClientContext +from google.cloud.spanner_v1.types import RequestOptions from google.cloud.spanner_v1.data_types import JsonObject, Interval from google.cloud.spanner_v1.request_id_header import ( with_request_id, @@ -172,7 +174,7 @@ def _merge_query_options(base, merge): If the resultant object only has empty fields, returns None. """ combined = base or ExecuteSqlRequest.QueryOptions() - if type(combined) is dict: + if isinstance(combined, dict): combined = ExecuteSqlRequest.QueryOptions( optimizer_version=combined.get("optimizer_version", ""), optimizer_statistics_package=combined.get( @@ -180,7 +182,7 @@ def _merge_query_options(base, merge): ), ) merge = merge or ExecuteSqlRequest.QueryOptions() - if type(merge) is dict: + if isinstance(merge, dict): merge = ExecuteSqlRequest.QueryOptions( optimizer_version=merge.get("optimizer_version", ""), optimizer_statistics_package=merge.get("optimizer_statistics_package", ""), @@ -191,6 +193,84 @@ def _merge_query_options(base, merge): return combined +def _merge_client_context(base, merge): + """Merge higher precedence ClientContext with current ClientContext. + + :type base: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` or None + :param base: The current ClientContext that is intended for use. + + :type merge: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` or None + :param merge: + The ClientContext that has a higher priority than base. These options + should overwrite the fields in base. + + :rtype: :class:`~google.cloud.spanner_v1.types.ClientContext` + or None + :returns: + ClientContext object formed by merging the two given ClientContexts. + """ + if base is None and merge is None: + return None + + combined = base or ClientContext() + if isinstance(combined, dict): + combined = ClientContext(combined) + + merge = merge or ClientContext() + if isinstance(merge, dict): + merge = ClientContext(merge) + + # Avoid in-place modification of base + combined_pb = ClientContext()._pb + if base: + base_pb = ClientContext(base)._pb if isinstance(base, dict) else base._pb + combined_pb.MergeFrom(base_pb) + if merge: + merge_pb = ClientContext(merge)._pb if isinstance(merge, dict) else merge._pb + combined_pb.MergeFrom(merge_pb) + + combined = ClientContext(combined_pb) + + if not combined.secure_context: + return None + return combined + + +def _merge_request_options(request_options, client_context): + """Merge RequestOptions and ClientContext. + + :type request_options: :class:`~google.cloud.spanner_v1.types.RequestOptions` + or :class:`dict` or None + :param request_options: The current RequestOptions that is intended for use. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` or None + :param client_context: + The ClientContext to merge into request_options. + + :rtype: :class:`~google.cloud.spanner_v1.types.RequestOptions` + or None + :returns: + RequestOptions object formed by merging the given ClientContext. + """ + if request_options is None and client_context is None: + return None + + if request_options is None: + request_options = RequestOptions() + elif isinstance(request_options, dict): + request_options = RequestOptions(request_options) + + if client_context: + request_options.client_context = _merge_client_context( + client_context, request_options.client_context + ) + + return request_options + + def _assert_numeric_precision_and_scale(value): """ Asserts that input numeric field is within Spanner supported range. diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 6f67531c1e..9375c80956 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -28,6 +28,8 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _merge_Transaction_Options, + _merge_client_context, + _merge_request_options, AtomicCounter, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call @@ -37,6 +39,7 @@ from google.cloud.spanner_v1._helpers import _check_rst_stream_error from google.api_core.exceptions import InternalServerError from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture +from google.cloud.spanner_v1.types import ClientContext import time DEFAULT_RETRY_TIMEOUT_SECS = 30 @@ -47,9 +50,14 @@ class _BatchBase(_SessionWrapper): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform the commit + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch. """ - def __init__(self, session): + def __init__(self, session, client_context=None): super(_BatchBase, self).__init__(session) self._mutations: List[Mutation] = [] @@ -59,6 +67,13 @@ def __init__(self, session): """Timestamp at which the batch was successfully committed.""" self.commit_stats: Optional[CommitResponse.CommitStats] = None + if client_context is not None: + if isinstance(client_context, dict): + client_context = ClientContext(client_context) + elif not isinstance(client_context, ClientContext): + raise TypeError("client_context must be a ClientContext or a dict") + self._client_context = client_context + def insert(self, table, columns, values): """Insert one or more new table rows. @@ -227,10 +242,14 @@ def commit( txn_options, ) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag # Request tags are not supported for commit requests. @@ -317,13 +336,25 @@ class MutationGroups(_SessionWrapper): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform the commit + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this mutation group. """ - def __init__(self, session): + def __init__(self, session, client_context=None): super(MutationGroups, self).__init__(session) self._mutation_groups: List[MutationGroup] = [] self.committed: bool = False + if client_context is not None: + if isinstance(client_context, dict): + client_context = ClientContext(client_context) + elif not isinstance(client_context, ClientContext): + raise TypeError("client_context must be a ClientContext or a dict") + self._client_context = client_context + def group(self): """Returns a new `MutationGroup` to which mutations can be added.""" mutation_group = BatchWriteRequest.MutationGroup() @@ -365,10 +396,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) with trace_call( name="CloudSpanner.batch_write", diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index 82dbe936aa..9f2d08059a 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -50,6 +50,7 @@ from google.cloud.spanner_v1 import __version__ from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import DefaultTransactionOptions +from google.cloud.spanner_v1.types import ClientContext from google.cloud.spanner_v1._helpers import _merge_query_options from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.instance import Instance @@ -225,6 +226,10 @@ class Client(ClientWithProject): :param disable_builtin_metrics: (Optional) Default False. Set to True to disable the Spanner built-in metrics collection and exporting. + :type client_context: :class:`~google.cloud.spanner_v1.types.RequestOptions.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made by this client. + :raises: :class:`ValueError ` if both ``read_only`` and ``admin`` are :data:`True` """ @@ -251,6 +256,7 @@ def __init__( default_transaction_options: Optional[DefaultTransactionOptions] = None, experimental_host=None, disable_builtin_metrics=False, + client_context=None, ): self._emulator_host = _get_spanner_emulator_host() self._experimental_host = experimental_host @@ -288,6 +294,13 @@ def __init__( # Environment flag config has higher precedence than application config. self._query_options = _merge_query_options(query_options, env_query_options) + if client_context is not None: + if isinstance(client_context, dict): + client_context = ClientContext(client_context) + elif not isinstance(client_context, ClientContext): + raise TypeError("client_context must be a ClientContext or a dict") + self._client_context = client_context + if self._emulator_host is not None and ( "http://" in self._emulator_host or "https://" in self._emulator_host ): diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 4977a4abb9..8c55a1744f 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -946,6 +946,7 @@ def snapshot(self, **kw): :param kw: Passed through to :class:`~google.cloud.spanner_v1.snapshot.Snapshot` constructor. + Now includes ``client_context``. :rtype: :class:`~google.cloud.spanner_v1.database.SnapshotCheckout` :returns: new wrapper @@ -959,6 +960,7 @@ def batch( exclude_txn_from_change_streams=False, isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, read_lock_mode=TransactionOptions.ReadWrite.ReadLockMode.READ_LOCK_MODE_UNSPECIFIED, + client_context=None, **kw, ): """Return an object which wraps a batch. @@ -996,6 +998,11 @@ def batch( :param read_lock_mode: (Optional) Sets the read lock mode for this transaction. This overrides any default read lock mode set for the client. + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch. + :rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout` :returns: new wrapper """ @@ -1007,19 +1014,25 @@ def batch( exclude_txn_from_change_streams, isolation_level, read_lock_mode, + client_context=client_context, **kw, ) - def mutation_groups(self): + def mutation_groups(self, client_context=None): """Return an object which wraps a mutation_group. The wrapper *must* be used as a context manager, with the mutation group as the value returned by the wrapper. + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this mutation group. + :rtype: :class:`~google.cloud.spanner_v1.database.MutationGroupsCheckout` :returns: new wrapper """ - return MutationGroupsCheckout(self) + return MutationGroupsCheckout(self, client_context=client_context) def batch_snapshot( self, @@ -1027,6 +1040,7 @@ def batch_snapshot( exact_staleness=None, session_id=None, transaction_id=None, + client_context=None, ): """Return an object which wraps a batch read / query. @@ -1043,6 +1057,11 @@ def batch_snapshot( :type transaction_id: str :param transaction_id: id of the transaction + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch snapshot. + :rtype: :class:`~google.cloud.spanner_v1.database.BatchSnapshot` :returns: new wrapper """ @@ -1052,6 +1071,7 @@ def batch_snapshot( exact_staleness=exact_staleness, session_id=session_id, transaction_id=transaction_id, + client_context=client_context, ) def run_in_transaction(self, func, *args, **kw): @@ -1080,6 +1100,8 @@ def run_in_transaction(self, func, *args, **kw): the DDL option `allow_txn_exclusion` being false or unset. "isolation_level" sets the isolation level for the transaction. "read_lock_mode" sets the read lock mode for the transaction. + "client_context" (Optional) Client context to use for all requests + made by this transaction. :rtype: Any :returns: The return value of ``func``. @@ -1391,6 +1413,11 @@ class BatchCheckout(object): :param max_commit_delay: (Optional) The amount of latency this request is willing to incur in order to improve throughput. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch. """ def __init__( @@ -1401,6 +1428,7 @@ def __init__( exclude_txn_from_change_streams=False, isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, read_lock_mode=TransactionOptions.ReadWrite.ReadLockMode.READ_LOCK_MODE_UNSPECIFIED, + client_context=None, **kw, ): self._database: Database = database @@ -1417,6 +1445,7 @@ def __init__( self._exclude_txn_from_change_streams = exclude_txn_from_change_streams self._isolation_level = isolation_level self._read_lock_mode = read_lock_mode + self._client_context = client_context self._kw = kw def __enter__(self): @@ -1433,7 +1462,9 @@ def __enter__(self): event_attributes={"id": self._session.session_id}, ) - batch = self._batch = Batch(session=self._session) + batch = self._batch = Batch( + session=self._session, client_context=self._client_context + ) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag @@ -1478,18 +1509,26 @@ class MutationGroupsCheckout(object): :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: database to use + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this mutation group. """ - def __init__(self, database): + def __init__(self, database, client_context=None): self._database: Database = database self._session: Optional[Session] = None + self._client_context = client_context def __enter__(self): """Begin ``with`` block.""" transaction_type = TransactionType.READ_WRITE self._session = self._database.sessions_manager.get_session(transaction_type) - return MutationGroups(session=self._session) + return MutationGroups( + session=self._session, client_context=self._client_context + ) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" @@ -1555,6 +1594,11 @@ class BatchSnapshot(object): :type exact_staleness: :class:`datetime.timedelta` :param exact_staleness: Execute all reads at a timestamp that is ``exact_staleness`` old. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this batch snapshot. """ def __init__( @@ -1564,6 +1608,7 @@ def __init__( exact_staleness=None, session_id=None, transaction_id=None, + client_context=None, ): self._database: Database = database @@ -1575,6 +1620,7 @@ def __init__( self._read_timestamp = read_timestamp self._exact_staleness = exact_staleness + self._client_context = client_context @classmethod def from_dict(cls, database, mapping): @@ -1663,6 +1709,7 @@ def _get_snapshot(self): exact_staleness=self._exact_staleness, multi_use=True, transaction_id=self._transaction_id, + client_context=self._client_context, ) if self._transaction_id is None: diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index e7bc913c27..95db0f72d2 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -472,9 +472,14 @@ def batch(self): return Batch(self) - def transaction(self) -> Transaction: + def transaction(self, client_context=None) -> Transaction: """Create a transaction to perform a set of reads with shared staleness. + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this transaction. + :rtype: :class:`~google.cloud.spanner_v1.transaction.Transaction` :returns: a transaction bound to this session @@ -483,7 +488,7 @@ def transaction(self) -> Transaction: if self._session_id is None: raise ValueError("Session has not been created.") - return Transaction(self) + return Transaction(self, client_context=client_context) def run_in_transaction(self, func, *args, **kw): """Perform a unit of work in a transaction, retrying on abort. @@ -512,6 +517,8 @@ def run_in_transaction(self, func, *args, **kw): the DDL option `allow_txn_exclusion` being false or unset. "isolation_level" sets the isolation level for the transaction. "read_lock_mode" sets the read lock mode for the transaction. + "client_context" (Optional) Client context to use for all requests + made by this transaction. :rtype: Any :returns: The return value of ``func``. @@ -529,6 +536,7 @@ def run_in_transaction(self, func, *args, **kw): ) isolation_level = kw.pop("isolation_level", None) read_lock_mode = kw.pop("read_lock_mode", None) + client_context = kw.pop("client_context", None) database = self._database log_commit_stats = database.log_commit_stats @@ -554,7 +562,7 @@ def run_in_transaction(self, func, *args, **kw): previous_transaction_id: Optional[bytes] = None while True: - txn = self.transaction() + txn = self.transaction(client_context=client_context) txn.transaction_tag = transaction_tag txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams txn.isolation_level = isolation_level diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index a7abcdaaa3..42ae196f1b 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -41,6 +41,8 @@ from google.cloud.spanner_v1._helpers import ( _make_value_pb, _merge_query_options, + _merge_client_context, + _merge_request_options, _metadata_with_prefix, _metadata_with_leader_aware_routing, _retry, @@ -52,6 +54,7 @@ from google.cloud.spanner_v1._opentelemetry_tracing import trace_call, add_span_event from google.cloud.spanner_v1.streamed import StreamedResultSet from google.cloud.spanner_v1 import RequestOptions +from google.cloud.spanner_v1.types import ClientContext from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture from google.cloud.spanner_v1.types import MultiplexedSessionPrecommitToken @@ -196,14 +199,26 @@ class _SnapshotBase(_SessionWrapper): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform transaction operations. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this transaction. """ _read_only: bool = True _multi_use: bool = False - def __init__(self, session): + def __init__(self, session, client_context=None): super().__init__(session) + if client_context is not None: + if isinstance(client_context, dict): + client_context = ClientContext(client_context) + elif not isinstance(client_context, ClientContext): + raise TypeError("client_context must be a ClientContext or a dict") + self._client_context = client_context + # Counts for execute SQL requests and total read requests (including # execute SQL requests). Used to provide sequence numbers for # :class:`google.cloud.spanner_v1.types.ExecuteSqlRequest` and to @@ -348,10 +363,13 @@ def read( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) if self._read_only: # Transaction tags are not supported for read only transactions. @@ -543,10 +561,14 @@ def execute_sql( default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + if self._read_only: # Transaction tags are not supported for read only transactions. request_options.transaction_tag = None @@ -616,6 +638,10 @@ def _get_streamed_result_set( if self._transaction_id is None: is_inline_begin = True self._lock.acquire() + if self._transaction_id is not None: + is_inline_begin = False + self._lock.release() + request.transaction = TransactionSelector(id=self._transaction_id) iterator = _restart_on_unavailable( method=method, @@ -923,10 +949,19 @@ def _begin_transaction( "mutation_key": mutation, } + request_options = begin_request_kwargs.get("request_options") + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if transaction_tag: - begin_request_kwargs["request_options"] = RequestOptions( - transaction_tag=transaction_tag - ) + if request_options is None: + request_options = RequestOptions() + request_options.transaction_tag = transaction_tag + + if request_options: + begin_request_kwargs["request_options"] = request_options with trace_call( name=f"CloudSpanner.{type(self).__name__}.begin", @@ -1099,6 +1134,11 @@ class Snapshot(_SnapshotBase): context of a read-only transaction, used to ensure isolation / consistency. Incompatible with ``max_staleness`` and ``min_read_timestamp``. + + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this snapshot. """ def __init__( @@ -1110,8 +1150,9 @@ def __init__( exact_staleness=None, multi_use=False, transaction_id=None, + client_context=None, ): - super(Snapshot, self).__init__(session) + super(Snapshot, self).__init__(session, client_context=client_context) opts = [read_timestamp, min_read_timestamp, max_staleness, exact_staleness] flagged = [opt for opt in opts if opt is not None] diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 413ac0af1f..9aa2fce1ef 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -25,6 +25,8 @@ _retry, _check_rst_stream_error, _merge_Transaction_Options, + _merge_client_context, + _merge_request_options, ) from google.cloud.spanner_v1 import ( CommitRequest, @@ -54,6 +56,11 @@ class Transaction(_SnapshotBase, _BatchBase): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform the commit + :type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext` + or :class:`dict` + :param client_context: (Optional) Client context to use for all requests made + by this transaction. + :raises ValueError: if session has an existing transaction """ @@ -69,8 +76,8 @@ class Transaction(_SnapshotBase, _BatchBase): _multi_use: bool = True _read_only: bool = False - def __init__(self, session): - super(Transaction, self).__init__(session) + def __init__(self, session, client_context=None): + super(Transaction, self).__init__(session, client_context=client_context) self.rolled_back: bool = False # If this transaction is used to retry a previous aborted transaction with a @@ -266,10 +273,14 @@ def commit( else: raise ValueError("Transaction has not begun.") + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + if self.transaction_tag is not None: request_options.transaction_tag = self.transaction_tag @@ -479,10 +490,14 @@ def execute_update( default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag trace_attributes = { @@ -497,6 +512,9 @@ def execute_update( if self._transaction_id is None: is_inline_begin = True self._lock.acquire() + if self._transaction_id is not None: + is_inline_begin = False + self._lock.release() execute_sql_request = ExecuteSqlRequest( session=session.name, @@ -632,10 +650,14 @@ def batch_update( self._execute_sql_request_count + 1, ) + client_context = _merge_client_context( + database._instance._client._client_context, self._client_context + ) + request_options = _merge_request_options(request_options, client_context) + if request_options is None: request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag trace_attributes = { @@ -651,6 +673,9 @@ def batch_update( if self._transaction_id is None: is_inline_begin = True self._lock.acquire() + if self._transaction_id is not None: + is_inline_begin = False + self._lock.release() execute_batch_dml_request = ExecuteBatchDmlRequest( session=session.name, diff --git a/google/cloud/spanner_v1/types/__init__.py b/google/cloud/spanner_v1/types/__init__.py index 5a7ded16dd..5f1e9274b6 100644 --- a/google/cloud/spanner_v1/types/__init__.py +++ b/google/cloud/spanner_v1/types/__init__.py @@ -52,6 +52,7 @@ BatchWriteRequest, BatchWriteResponse, BeginTransactionRequest, + ClientContext, CommitRequest, CreateSessionRequest, DeleteSessionRequest, @@ -110,6 +111,7 @@ "BatchWriteRequest", "BatchWriteResponse", "BeginTransactionRequest", + "ClientContext", "CommitRequest", "CreateSessionRequest", "DeleteSessionRequest", diff --git a/google/cloud/spanner_v1/types/spanner.py b/google/cloud/spanner_v1/types/spanner.py index 6e363088de..c7085cda13 100644 --- a/google/cloud/spanner_v1/types/spanner.py +++ b/google/cloud/spanner_v1/types/spanner.py @@ -43,6 +43,7 @@ "ListSessionsResponse", "DeleteSessionRequest", "RequestOptions", + "ClientContext", "DirectedReadOptions", "ExecuteSqlRequest", "ExecuteBatchDmlRequest", @@ -395,6 +396,31 @@ class Priority(proto.Enum): proto.STRING, number=3, ) + client_context: ClientContext = proto.Field( + proto.MESSAGE, + number=4, + message="ClientContext", + ) + + +class ClientContext(proto.Message): + r"""Container for various pieces of client-owned context + attached to a request. + + Attributes: + secure_context (MutableMapping[str, google.protobuf.struct_pb2.Value]): + Optional. Map of parameter name to value for this request. + These values will be returned by any SECURE_CONTEXT() calls + invoked by this request (e.g., by queries against + Parameterized Secure Views). + """ + + secure_context: MutableMapping[str, struct_pb2.Value] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=1, + message=struct_pb2.Value, + ) class DirectedReadOptions(proto.Message): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 6e8159425f..6fc844183e 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -872,6 +872,7 @@ class _Client(object): def __init__(self, project="project_id"): self.project = project self.project_name = "projects/" + self.project + self._client_context = None def instance(self, instance_id="instance_id"): return _Instance(name=instance_id, client=self) diff --git a/tests/unit/test_backup.py b/tests/unit/test_backup.py index 00621c2148..8198a283e4 100644 --- a/tests/unit/test_backup.py +++ b/tests/unit/test_backup.py @@ -679,6 +679,7 @@ class _Client(object): def __init__(self, project=TestBackup.PROJECT_ID): self.project = project self.project_name = "projects/" + self.project + self._client_context = None class _Instance(object): diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index f00a45e8a5..b4690203f6 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -806,6 +806,9 @@ class _Database(object): def __init__(self, enable_end_to_end_tracing=False): self.name = "testing" + self._instance = mock.Mock() + self._instance._client = mock.Mock() + self._instance._client._client_context = None self._route_to_leader_enabled = True if enable_end_to_end_tracing: self.observability_options = dict(enable_end_to_end_tracing=True) diff --git a/tests/unit/test_client_context.py b/tests/unit/test_client_context.py new file mode 100644 index 0000000000..6c95b51946 --- /dev/null +++ b/tests/unit/test_client_context.py @@ -0,0 +1,438 @@ +# Copyright 2026 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock +from google.protobuf import struct_pb2 +from google.cloud.spanner_v1.types import ( + ClientContext, + RequestOptions, + ExecuteSqlRequest, +) +from google.cloud.spanner_v1._helpers import ( + _merge_client_context, + _merge_request_options, +) + + +class TestClientContext(unittest.TestCase): + def test__merge_client_context_both_none(self): + self.assertIsNone(_merge_client_context(None, None)) + + def test__merge_client_context_base_none(self): + merge = ClientContext(secure_context={"a": struct_pb2.Value(string_value="A")}) + result = _merge_client_context(None, merge) + self.assertEqual(result.secure_context["a"], "A") + + def test__merge_client_context_merge_none(self): + base = ClientContext(secure_context={"a": struct_pb2.Value(string_value="A")}) + result = _merge_client_context(base, None) + self.assertEqual(result.secure_context["a"], "A") + + def test__merge_client_context_both_set(self): + base = ClientContext( + secure_context={ + "a": struct_pb2.Value(string_value="A"), + "b": struct_pb2.Value(string_value="B1"), + } + ) + merge = ClientContext( + secure_context={ + "b": struct_pb2.Value(string_value="B2"), + "c": struct_pb2.Value(string_value="C"), + } + ) + result = _merge_client_context(base, merge) + self.assertEqual(result.secure_context["a"], "A") + self.assertEqual(result.secure_context["b"], "B2") + self.assertEqual(result.secure_context["c"], "C") + + def test__merge_request_options_with_client_context(self): + request_options = RequestOptions(priority=RequestOptions.Priority.PRIORITY_LOW) + client_context = ClientContext( + secure_context={"a": struct_pb2.Value(string_value="A")} + ) + + result = _merge_request_options(request_options, client_context) + + self.assertEqual(result.priority, RequestOptions.Priority.PRIORITY_LOW) + self.assertEqual(result.client_context.secure_context["a"], "A") + + def test_client_init_with_client_context(self): + from google.cloud.spanner_v1.client import Client + + project = "PROJECT" + credentials = mock.Mock(spec=["_resource_prefix__"]) + with mock.patch( + "google.auth.default", return_value=(credentials, project) + ), mock.patch( + "google.cloud.spanner_v1.client._get_spanner_enable_builtin_metrics_env", + return_value=False, + ): + client_context = { + "secure_context": {"a": struct_pb2.Value(string_value="A")} + } + client = Client( + project=project, + client_context=client_context, + disable_builtin_metrics=True, + ) + + self.assertIsInstance(client._client_context, ClientContext) + self.assertEqual(client._client_context.secure_context["a"], "A") + + def test_snapshot_execute_sql_propagates_client_context(self): + from google.cloud.spanner_v1.snapshot import Snapshot + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database._directed_read_options = None + + client = database._instance._client = mock.Mock() + client._query_options = None + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + snapshot_context = ClientContext( + secure_context={"snapshot": struct_pb2.Value(string_value="from-snapshot")} + ) + snapshot = Snapshot(session, client_context=snapshot_context) + + with mock.patch.object(snapshot, "_get_streamed_result_set") as mocked: + snapshot.execute_sql("SELECT 1") + kwargs = mocked.call_args.kwargs + request = kwargs["request"] + self.assertIsInstance(request, ExecuteSqlRequest) + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["snapshot"], + "from-snapshot", + ) + + def test_transaction_commit_propagates_client_context(self): + from google.cloud.spanner_v1.transaction import Transaction + from google.cloud.spanner_v1.types import ( + CommitRequest, + CommitResponse, + MultiplexedSessionPrecommitToken, + ) + + session = mock.Mock(spec=["name", "_database", "is_multiplexed"]) + session.name = "session-name" + session.is_multiplexed = False + database = session._database = mock.Mock() + database.name = "projects/p/instances/i/databases/d" + database._route_to_leader_enabled = False + database.log_commit_stats = False + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database._next_nth_request = 1 + + client = database._instance._client = mock.Mock() + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + transaction_context = ClientContext( + secure_context={"txn": struct_pb2.Value(string_value="from-txn")} + ) + transaction = Transaction(session, client_context=transaction_context) + transaction._transaction_id = b"tx-id" + + api = database.spanner_api = mock.Mock() + + token = MultiplexedSessionPrecommitToken(seq_num=1) + response = CommitResponse(precommit_token=token) + + def side_effect(f, **kw): + return f() + + api.commit.return_value = response + + with mock.patch( + "google.cloud.spanner_v1.transaction._retry", side_effect=side_effect + ): + transaction.commit() + + args, kwargs = api.commit.call_args + request = kwargs["request"] + self.assertIsInstance(request, CommitRequest) + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["txn"], "from-txn" + ) + + def test_snapshot_execute_sql_request_level_override(self): + from google.cloud.spanner_v1.snapshot import Snapshot + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database._directed_read_options = None + + client = database._instance._client = mock.Mock() + client._query_options = None + client._client_context = ClientContext( + secure_context={"a": struct_pb2.Value(string_value="from-client")} + ) + + snapshot_context = ClientContext( + secure_context={ + "a": struct_pb2.Value(string_value="from-snapshot"), + "b": struct_pb2.Value(string_value="B"), + } + ) + snapshot = Snapshot(session, client_context=snapshot_context) + + request_options = RequestOptions( + client_context=ClientContext( + secure_context={"a": struct_pb2.Value(string_value="from-request")} + ) + ) + + with mock.patch.object(snapshot, "_get_streamed_result_set") as mocked: + snapshot.execute_sql("SELECT 1", request_options=request_options) + kwargs = mocked.call_args.kwargs + request = kwargs["request"] + self.assertEqual( + request.request_options.client_context.secure_context["a"], + "from-request", + ) + self.assertEqual( + request.request_options.client_context.secure_context["b"], "B" + ) + + def test_batch_commit_propagates_client_context(self): + from google.cloud.spanner_v1.batch import Batch + from google.cloud.spanner_v1.types import ( + CommitRequest, + CommitResponse, + ) + from google.cloud.spanner_v1 import DefaultTransactionOptions + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database.log_commit_stats = False + database.default_transaction_options = DefaultTransactionOptions() + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database._next_nth_request = 1 + client = database._instance._client = mock.Mock() + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + batch_context = ClientContext( + secure_context={"batch": struct_pb2.Value(string_value="from-batch")} + ) + batch = Batch(session, client_context=batch_context) + + api = database.spanner_api = mock.Mock() + response = CommitResponse() + api.commit.return_value = response + + batch.commit() + + args, kwargs = api.commit.call_args + request = kwargs["request"] + self.assertIsInstance(request, CommitRequest) + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["batch"], "from-batch" + ) + + def test_transaction_execute_update_propagates_client_context(self): + from google.cloud.spanner_v1.transaction import Transaction + from google.cloud.spanner_v1.types import ( + ExecuteSqlRequest, + ResultSet, + MultiplexedSessionPrecommitToken, + ) + + session = mock.Mock(spec=["name", "_database", "_precommit_token"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database._next_nth_request = 1 + + client = database._instance._client = mock.Mock() + client._query_options = None + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + transaction_context = ClientContext( + secure_context={"txn": struct_pb2.Value(string_value="from-txn")} + ) + transaction = Transaction(session, client_context=transaction_context) + transaction._transaction_id = b"tx-id" + transaction._precommit_token = MultiplexedSessionPrecommitToken(seq_num=1) + + database.spanner_api = mock.Mock() + response = ResultSet( + precommit_token=MultiplexedSessionPrecommitToken(seq_num=2) + ) + + with mock.patch.object(transaction, "_execute_request", return_value=response): + transaction.execute_update("UPDATE T SET C = 1") + + args, kwargs = transaction._execute_request.call_args + request = args[1] + self.assertIsInstance(request, ExecuteSqlRequest) + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["txn"], "from-txn" + ) + + def test_mutation_groups_batch_write_propagates_client_context(self): + from google.cloud.spanner_v1.batch import MutationGroups + from google.cloud.spanner_v1.types import BatchWriteRequest + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database.metadata_with_request_id.return_value = [] + database._next_nth_request = 1 + + client = database._instance._client = mock.Mock() + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + mg_context = ClientContext( + secure_context={"mg": struct_pb2.Value(string_value="from-mg")} + ) + mg = MutationGroups(session, client_context=mg_context) + + api = database.spanner_api = mock.Mock() + + with mock.patch( + "google.cloud.spanner_v1.batch._retry", side_effect=lambda f, **kw: f() + ): + mg.batch_write() + + args, kwargs = api.batch_write.call_args + request = kwargs["request"] + self.assertIsInstance(request, BatchWriteRequest) + self.assertEqual( + request.request_options.client_context.secure_context["client"], + "from-client", + ) + self.assertEqual( + request.request_options.client_context.secure_context["mg"], "from-mg" + ) + + def test_batch_snapshot_propagates_client_context(self): + from google.cloud.spanner_v1.database import BatchSnapshot + + database = mock.Mock() + database.name = "database-name" + client = database._instance._client = mock.Mock() + client._query_options = None + client._client_context = ClientContext( + secure_context={"client": struct_pb2.Value(string_value="from-client")} + ) + + batch_context = ClientContext( + secure_context={"batch": struct_pb2.Value(string_value="from-batch")} + ) + batch_snapshot = BatchSnapshot(database, client_context=batch_context) + + session = mock.Mock(spec=["name", "_database", "session_id", "snapshot"]) + session.name = "session-name" + session.session_id = "session-id" + database.sessions_manager.get_session.return_value = session + + snapshot = mock.Mock() + session.snapshot.return_value = snapshot + + batch_snapshot.execute_sql("SELECT 1") + + session.snapshot.assert_called_once() + kwargs = session.snapshot.call_args.kwargs + self.assertEqual(kwargs["client_context"], batch_context) + + def test_database_snapshot_propagates_client_context(self): + from google.cloud.spanner_v1.database import Database + + instance = mock.Mock() + instance._client = mock.Mock() + instance._client._query_options = None + instance._client._client_context = None + + database = Database("db", instance) + with mock.patch( + "google.cloud.spanner_v1.database.SnapshotCheckout" + ) as mocked_checkout: + client_context = { + "secure_context": {"a": struct_pb2.Value(string_value="A")} + } + database.snapshot(client_context=client_context) + + mocked_checkout.assert_called_once_with( + database, client_context=client_context + ) + + def test_transaction_rollback_propagates_client_context_is_not_supported(self): + # Verify that rollback DOES NOT take client_context as it's not in RollbackRequest + from google.cloud.spanner_v1.transaction import Transaction + + session = mock.Mock(spec=["name", "_database"]) + session.name = "session-name" + database = session._database = mock.Mock() + database.name = "database-name" + database._route_to_leader_enabled = False + database.with_error_augmentation.return_value = (None, mock.MagicMock()) + database._next_nth_request = 1 + + transaction = Transaction(session) + transaction._transaction_id = b"tx-id" + + api = database.spanner_api = mock.Mock() + + transaction.rollback() + + args, kwargs = api.rollback.call_args + self.assertEqual(kwargs["session"], "session-name") + self.assertEqual(kwargs["transaction_id"], b"tx-id") + # Ensure no request_options or client_context passed to rollback + self.assertNotIn("request_options", kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 929f0c0010..6fe7dcd049 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -30,6 +30,7 @@ RequestOptions, DirectedReadOptions, DefaultTransactionOptions, + ExecuteSqlRequest, ) from google.cloud.spanner_v1._helpers import ( AtomicCounter, @@ -2599,6 +2600,7 @@ def test__get_snapshot_new_wo_staleness(self): exact_staleness=None, multi_use=True, transaction_id=None, + client_context=None, ) snapshot.begin.assert_called_once_with() @@ -2614,6 +2616,7 @@ def test__get_snapshot_w_read_timestamp(self): exact_staleness=None, multi_use=True, transaction_id=None, + client_context=None, ) snapshot.begin.assert_called_once_with() @@ -2629,6 +2632,7 @@ def test__get_snapshot_w_exact_staleness(self): exact_staleness=duration, multi_use=True, transaction_id=None, + client_context=None, ) snapshot.begin.assert_called_once_with() @@ -3540,6 +3544,7 @@ def __init__( self.directed_read_options = directed_read_options self.default_transaction_options = default_transaction_options self.observability_options = observability_options + self._client_context = None self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() @@ -3589,6 +3594,13 @@ class _Database(object): def __init__(self, name, instance=None): self.name = name self.database_id = name.rsplit("/", 1)[1] + if instance is None: + instance = mock.Mock() + instance._client = mock.Mock() + instance._client._client_context = None + instance._client._query_options = ExecuteSqlRequest.QueryOptions( + optimizer_version="1" + ) self._instance = instance from logging import Logger diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index f3bf6726c0..b76770cf13 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -1023,6 +1023,7 @@ def __init__(self, project, timeout_seconds=None): self.route_to_leader_enabled = True self.directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + self._client_context = None def copy(self): from copy import deepcopy diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index e0a236c86f..bfce743352 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -19,7 +19,21 @@ from datetime import datetime, timedelta import mock +from google.cloud.spanner_v1 import pool as MUT from google.cloud.spanner_v1 import _opentelemetry_tracing +from google.cloud.spanner_v1 import ExecuteSqlRequest +from google.cloud.spanner_v1 import BatchCreateSessionsResponse +from google.cloud.spanner_v1 import Session +from google.cloud.spanner_v1 import SpannerClient +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.pool import AbstractSessionPool +from google.cloud.spanner_v1.pool import SessionCheckout +from google.cloud.spanner_v1.pool import FixedSizePool +from google.cloud.spanner_v1.pool import BurstyPool +from google.cloud.spanner_v1.pool import PingingPool +from google.cloud.spanner_v1.transaction import Transaction +from google.cloud.exceptions import NotFound +from google.cloud._testing import _Monkey from google.cloud.spanner_v1._helpers import ( _metadata_with_request_id, _metadata_with_request_id_and_req_id, @@ -40,21 +54,15 @@ def _make_database(name="name"): - from google.cloud.spanner_v1.database import Database - return mock.create_autospec(Database, instance=True) def _make_session(): - from google.cloud.spanner_v1.database import Session - return mock.create_autospec(Session, instance=True) class TestAbstractSessionPool(unittest.TestCase): def _getTargetClass(self): - from google.cloud.spanner_v1.pool import AbstractSessionPool - return AbstractSessionPool def _make_one(self, *args, **kwargs): @@ -129,8 +137,6 @@ def test__new_session_w_database_role(self): self.assertEqual(new_session.database_role, database_role) def test_session_wo_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - pool = self._make_one() checkout = pool.session() self.assertIsInstance(checkout, SessionCheckout) @@ -139,8 +145,6 @@ def test_session_wo_kwargs(self): self.assertEqual(checkout._kwargs, {}) def test_session_w_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - pool = self._make_one() checkout = pool.session(foo="bar") self.assertIsInstance(checkout, SessionCheckout) @@ -164,8 +168,6 @@ class TestFixedSizePool(OpenTelemetryBase): enrich_with_otel_scope(BASE_ATTRIBUTES) def _getTargetClass(self): - from google.cloud.spanner_v1.pool import FixedSizePool - return FixedSizePool def _make_one(self, *args, **kwargs): @@ -559,8 +561,6 @@ class TestBurstyPool(OpenTelemetryBase): enrich_with_otel_scope(BASE_ATTRIBUTES) def _getTargetClass(self): - from google.cloud.spanner_v1.pool import BurstyPool - return BurstyPool def _make_one(self, *args, **kwargs): @@ -850,8 +850,6 @@ class TestPingingPool(OpenTelemetryBase): enrich_with_otel_scope(BASE_ATTRIBUTES) def _getTargetClass(self): - from google.cloud.spanner_v1.pool import PingingPool - return PingingPool def _make_one(self, *args, **kwargs): @@ -946,8 +944,6 @@ def test_get_hit_no_ping(self, mock_region): ) def test_get_hit_w_ping(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=4) database = _Database("name") @@ -974,8 +970,6 @@ def test_get_hit_w_ping(self, mock_region): ) def test_get_hit_w_ping_expired(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=4) database = _Database("name") @@ -1097,8 +1091,6 @@ def test_spans_put_full(self, mock_region): ) def test_put_non_full(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) session_queue = pool._sessions = _Queue() @@ -1172,8 +1164,6 @@ def test_ping_oldest_fresh(self, mock_region): ) def test_ping_oldest_stale_but_exists(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) database = _Database("name") @@ -1193,8 +1183,6 @@ def test_ping_oldest_stale_but_exists(self, mock_region): ) def test_ping_oldest_stale_and_not_exists(self, mock_region): import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) database = _Database("name") @@ -1257,8 +1245,6 @@ def test_spans_get_and_leave_empty_pool(self, mock_region): class TestSessionCheckout(unittest.TestCase): def _getTargetClass(self): - from google.cloud.spanner_v1.pool import SessionCheckout - return SessionCheckout def _make_one(self, *args, **kwargs): @@ -1314,8 +1300,6 @@ def test_context_manager_w_kwargs(self): def _make_transaction(*args, **kw): - from google.cloud.spanner_v1.transaction import Transaction - txn = mock.create_autospec(Transaction)(*args, **kw) txn.committed = None txn.rolled_back = False @@ -1352,15 +1336,11 @@ def exists(self): return self._exists def ping(self): - from google.cloud.exceptions import NotFound - self._pinged = True if not self._exists: raise NotFound("expired session") def delete(self): - from google.cloud.exceptions import NotFound - self._deleted = True if not self._exists: raise NotFound("unknown session") @@ -1391,9 +1371,6 @@ def mock_batch_create_sessions( metadata=[], labels={}, ): - from google.cloud.spanner_v1 import BatchCreateSessionsResponse - from google.cloud.spanner_v1 import Session - database_role = request.session_template.creator_role if request else None if request.session_count < 2: response = BatchCreateSessionsResponse( @@ -1408,10 +1385,15 @@ def mock_batch_create_sessions( ) return response - from google.cloud.spanner_v1 import SpannerClient - self.spanner_api = mock.create_autospec(SpannerClient, instance=True) self.spanner_api.batch_create_sessions.side_effect = mock_batch_create_sessions + self._instance = mock.Mock() + self._instance._client = mock.Mock() + self._instance._client._client_context = None + self._instance._client.spanner_api = self.spanner_api + self._instance._client._query_options = ExecuteSqlRequest.QueryOptions( + optimizer_version="1" + ) @property def database_role(self): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 86e4fe7e72..49a6f8297c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -194,6 +194,9 @@ def _make_database( database.database_role = database_role database._route_to_leader_enabled = True database.default_transaction_options = default_transaction_options + database._instance = mock.Mock() + database._instance._client = mock.Mock() + database._instance._client._client_context = None inject_into_mock_database(database) return database diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 81d2d01fa3..3d93488ab7 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -2182,6 +2182,7 @@ def __init__(self): from google.cloud.spanner_v1 import ExecuteSqlRequest self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") + self._client_context = None self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index ecd7d4fd86..ad1033d412 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -1280,6 +1280,7 @@ def __init__(self): self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") self.directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + self._client_context = None self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 9afc1130b4..769dcaf703 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -1384,6 +1384,7 @@ def __init__(self): self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") self.directed_read_options = None + self._client_context = None self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter()