Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .types.spanner import BatchWriteRequest
from .types.spanner import BatchWriteResponse
from .types.spanner import BeginTransactionRequest
from .types.spanner import ClientContext
from .types.spanner import CommitRequest
from .types.spanner import CreateSessionRequest
from .types.spanner import DeleteSessionRequest
Expand Down Expand Up @@ -110,6 +111,7 @@
"BatchWriteRequest",
"BatchWriteResponse",
"BeginTransactionRequest",
"ClientContext",
"CommitRequest",
"CommitResponse",
"CreateSessionRequest",
Expand Down
84 changes: 82 additions & 2 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1.types import ExecuteSqlRequest
from google.cloud.spanner_v1.types import TransactionOptions
from google.cloud.spanner_v1.types import ClientContext
from google.cloud.spanner_v1.types import RequestOptions
from google.cloud.spanner_v1.data_types import JsonObject, Interval
from google.cloud.spanner_v1.request_id_header import (
with_request_id,
Expand Down Expand Up @@ -172,15 +174,15 @@ def _merge_query_options(base, merge):
If the resultant object only has empty fields, returns None.
"""
combined = base or ExecuteSqlRequest.QueryOptions()
if type(combined) is dict:
if isinstance(combined, dict):
combined = ExecuteSqlRequest.QueryOptions(
optimizer_version=combined.get("optimizer_version", ""),
optimizer_statistics_package=combined.get(
"optimizer_statistics_package", ""
),
)
merge = merge or ExecuteSqlRequest.QueryOptions()
if type(merge) is dict:
if isinstance(merge, dict):
merge = ExecuteSqlRequest.QueryOptions(
optimizer_version=merge.get("optimizer_version", ""),
optimizer_statistics_package=merge.get("optimizer_statistics_package", ""),
Expand All @@ -191,6 +193,84 @@ def _merge_query_options(base, merge):
return combined


def _merge_client_context(base, merge):
"""Merge higher precedence ClientContext with current ClientContext.

:type base: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict` or None
:param base: The current ClientContext that is intended for use.

:type merge: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict` or None
:param merge:
The ClientContext that has a higher priority than base. These options
should overwrite the fields in base.

:rtype: :class:`~google.cloud.spanner_v1.types.ClientContext`
or None
:returns:
ClientContext object formed by merging the two given ClientContexts.
"""
if base is None and merge is None:
return None

combined = base or ClientContext()
if isinstance(combined, dict):
combined = ClientContext(combined)

merge = merge or ClientContext()
if isinstance(merge, dict):
merge = ClientContext(merge)

# Avoid in-place modification of base
combined_pb = ClientContext()._pb
if base:
base_pb = ClientContext(base)._pb if isinstance(base, dict) else base._pb
combined_pb.MergeFrom(base_pb)
if merge:
merge_pb = ClientContext(merge)._pb if isinstance(merge, dict) else merge._pb
combined_pb.MergeFrom(merge_pb)

combined = ClientContext(combined_pb)

if not combined.secure_context:
return None
return combined


def _merge_request_options(request_options, client_context):
"""Merge RequestOptions and ClientContext.

:type request_options: :class:`~google.cloud.spanner_v1.types.RequestOptions`
or :class:`dict` or None
:param request_options: The current RequestOptions that is intended for use.

:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict` or None
:param client_context:
The ClientContext to merge into request_options.

:rtype: :class:`~google.cloud.spanner_v1.types.RequestOptions`
or None
:returns:
RequestOptions object formed by merging the given ClientContext.
"""
if request_options is None and client_context is None:
return None

if request_options is None:
request_options = RequestOptions()
elif isinstance(request_options, dict):
request_options = RequestOptions(request_options)

if client_context:
request_options.client_context = _merge_client_context(
client_context, request_options.client_context
)

return request_options


def _assert_numeric_precision_and_scale(value):
"""
Asserts that input numeric field is within Spanner supported range.
Expand Down
46 changes: 40 additions & 6 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_merge_Transaction_Options,
_merge_client_context,
_merge_request_options,
AtomicCounter,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
Expand All @@ -37,6 +39,7 @@
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
from google.cloud.spanner_v1.types import ClientContext
import time

DEFAULT_RETRY_TIMEOUT_SECS = 30
Expand All @@ -47,9 +50,14 @@ class _BatchBase(_SessionWrapper):

:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: the session used to perform the commit

:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict`
:param client_context: (Optional) Client context to use for all requests made
by this batch.
"""

def __init__(self, session):
def __init__(self, session, client_context=None):
super(_BatchBase, self).__init__(session)

self._mutations: List[Mutation] = []
Expand All @@ -59,6 +67,13 @@ def __init__(self, session):
"""Timestamp at which the batch was successfully committed."""
self.commit_stats: Optional[CommitResponse.CommitStats] = None

if client_context is not None:
if isinstance(client_context, dict):
client_context = ClientContext(client_context)
Comment on lines 70 to 72

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For type checking, it's generally better to use isinstance(obj, cls) rather than type(obj) is cls. isinstance correctly handles subclasses, making the code more robust.

Suggested change
if client_context is not None:
if type(client_context) is dict:
client_context = ClientContext(client_context)
if isinstance(client_context, dict):
client_context = ClientContext(client_context)
elif not isinstance(client_context, ClientContext):

elif not isinstance(client_context, ClientContext):
raise TypeError("client_context must be a ClientContext or a dict")
self._client_context = client_context

def insert(self, table, columns, values):
"""Insert one or more new table rows.

Expand Down Expand Up @@ -227,10 +242,14 @@ def commit(
txn_options,
)

client_context = _merge_client_context(
database._instance._client._client_context, self._client_context
)
request_options = _merge_request_options(request_options, client_context)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)

request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
Expand Down Expand Up @@ -317,13 +336,25 @@ class MutationGroups(_SessionWrapper):

:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: the session used to perform the commit

:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict`
:param client_context: (Optional) Client context to use for all requests made
by this mutation group.
"""

def __init__(self, session):
def __init__(self, session, client_context=None):
super(MutationGroups, self).__init__(session)
self._mutation_groups: List[MutationGroup] = []
self.committed: bool = False

if client_context is not None:
if isinstance(client_context, dict):
client_context = ClientContext(client_context)
Comment on lines 351 to 353

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For type checking, it's generally better to use isinstance(obj, cls) rather than type(obj) is cls. isinstance correctly handles subclasses, making the code more robust.

Suggested change
if client_context is not None:
if type(client_context) is dict:
client_context = ClientContext(client_context)
if isinstance(client_context, dict):
client_context = ClientContext(client_context)
elif not isinstance(client_context, ClientContext):

elif not isinstance(client_context, ClientContext):
raise TypeError("client_context must be a ClientContext or a dict")
self._client_context = client_context

def group(self):
"""Returns a new `MutationGroup` to which mutations can be added."""
mutation_group = BatchWriteRequest.MutationGroup()
Expand Down Expand Up @@ -365,10 +396,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

client_context = _merge_client_context(
database._instance._client._client_context, self._client_context
)
request_options = _merge_request_options(request_options, client_context)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)

with trace_call(
name="CloudSpanner.batch_write",
Expand Down
13 changes: 13 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from google.cloud.spanner_v1 import __version__
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import DefaultTransactionOptions
from google.cloud.spanner_v1.types import ClientContext
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.instance import Instance
Expand Down Expand Up @@ -225,6 +226,10 @@ class Client(ClientWithProject):
:param disable_builtin_metrics: (Optional) Default False. Set to True to disable
the Spanner built-in metrics collection and exporting.

:type client_context: :class:`~google.cloud.spanner_v1.types.RequestOptions.ClientContext`
or :class:`dict`
:param client_context: (Optional) Client context to use for all requests made by this client.

:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
and ``admin`` are :data:`True`
"""
Expand All @@ -251,6 +256,7 @@ def __init__(
default_transaction_options: Optional[DefaultTransactionOptions] = None,
experimental_host=None,
disable_builtin_metrics=False,
client_context=None,
Comment on lines 257 to +259

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For type checking, it's generally better to use isinstance(obj, cls) rather than type(obj) is cls. isinstance correctly handles subclasses, making the code more robust. This pattern of using type() is dict appears in several places in this PR and should be updated for consistency.

Suggested change
experimental_host=None,
disable_builtin_metrics=False,
client_context=None,
if isinstance(client_context, dict):
client_context = ClientContext(client_context)
elif not isinstance(client_context, ClientContext):

):
self._emulator_host = _get_spanner_emulator_host()
self._experimental_host = experimental_host
Expand Down Expand Up @@ -288,6 +294,13 @@ def __init__(
# Environment flag config has higher precedence than application config.
self._query_options = _merge_query_options(query_options, env_query_options)

if client_context is not None:
if isinstance(client_context, dict):
client_context = ClientContext(client_context)
elif not isinstance(client_context, ClientContext):
raise TypeError("client_context must be a ClientContext or a dict")
self._client_context = client_context

if self._emulator_host is not None and (
"http://" in self._emulator_host or "https://" in self._emulator_host
):
Expand Down
Loading
Loading