diff --git a/kafka/protocol/__init__.py b/kafka/protocol/__init__.py index 578cf489a..ff9c68306 100644 --- a/kafka/protocol/__init__.py +++ b/kafka/protocol/__init__.py @@ -1,12 +1,3 @@ -from . import ( - produce, fetch, list_offsets, metadata, - commit, find_coordinator, group, - sasl_handshake, api_versions, admin, - init_producer_id, offset_for_leader_epoch, - add_partitions_to_txn, add_offsets_to_txn, end_txn, - txn_offset_commit, sasl_authenticate, -) - API_KEYS = { 0: 'Produce', 1: 'Fetch', @@ -53,47 +44,3 @@ 46: 'ListPartitionReassignments', 48: 'DescribeClientQuotas', } - -# Mapping of Api_key to a tuple of (request_classes, response_classes) -REQUEST_TYPES = { - 0: (produce.ProduceRequest, produce.ProduceResponse), - 1: (fetch.FetchRequest, fetch.FetchResponse), - 2: (list_offsets.ListOffsetsRequest, list_offsets.ListOffsetsResponse), - 3: (metadata.MetadataRequest, metadata.MetadataResponse), - 8: (commit.OffsetCommitRequest, commit.OffsetCommitResponse), - 9: (commit.OffsetFetchRequest, commit.OffsetFetchResponse), - 10: (find_coordinator.FindCoordinatorRequest, find_coordinator.FindCoordinatorResponse), - 11: (group.JoinGroupRequest, group.JoinGroupResponse), - 12: (group.HeartbeatRequest, group.HeartbeatResponse), - 13: (group.LeaveGroupRequest, group.LeaveGroupResponse), - 14: (group.SyncGroupRequest, group.SyncGroupResponse), - 15: (admin.DescribeGroupsRequest, admin.DescribeGroupsResponse), - 16: (admin.ListGroupsRequest, admin.ListGroupsResponse), - 17: (sasl_handshake.SaslHandshakeRequest, sasl_handshake.SaslHandshakeResponse), - 18: (api_versions.ApiVersionsRequest, api_versions.ApiVersionsResponse), - 19: (admin.CreateTopicsRequest, admin.CreateTopicsResponse), - 20: (admin.DeleteTopicsRequest, admin.DeleteTopicsResponse), - 21: (admin.DeleteRecordsRequest, admin.DeleteRecordsResponse), - 22: (init_producer_id.InitProducerIdRequest, init_producer_id.InitProducerIdResponse), - 23: (offset_for_leader_epoch.OffsetForLeaderEpochRequest, offset_for_leader_epoch.OffsetForLeaderEpochResponse), - 24: (add_partitions_to_txn.AddPartitionsToTxnRequest, add_partitions_to_txn.AddPartitionsToTxnResponse), - 25: (add_offsets_to_txn.AddOffsetsToTxnRequest, add_offsets_to_txn.AddOffsetsToTxnResponse), - 26: (end_txn.EndTxnRequest, end_txn.EndTxnResponse), - 28: (txn_offset_commit.TxnOffsetCommitRequest, txn_offset_commit.TxnOffsetCommitResponse), - 29: (admin.DescribeAclsRequest, admin.DescribeAclsResponse), - 30: (admin.CreateAclsRequest, admin.CreateAclsResponse), - 31: (admin.DeleteAclsRequest, admin.DeleteAclsResponse), - 32: (admin.DescribeConfigsRequest, admin.DescribeConfigsResponse), - 33: (admin.AlterConfigsRequest, admin.AlterConfigsResponse), - 36: (sasl_authenticate.SaslAuthenticateRequest, sasl_authenticate.SaslAuthenticateResponse), - 37: (admin.CreatePartitionsRequest, admin.CreatePartitionsResponse), - 42: (admin.DeleteGroupsRequest, admin.DeleteGroupsResponse) -} - -def get_response_class(api_key, api_version): - request_type, response_type = REQUEST_TYPES.get(api_key, (None, None)) - if response_type: - if hasattr(response_type, '__getitem__'): - return response_type[api_version] - return response_type - return None diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index 60e6c6b6d..a88f43927 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -1,11 +1,26 @@ import abc from io import BytesIO +import weakref from kafka.protocol.struct import Struct from kafka.protocol.types import Int16, Int32, String, Schema, Array, TaggedFields -class RequestHeader(Struct): +class ResponseClassRegistry: + _response_class_registry = {} + + @classmethod + def register_response_class(cls, response_class): + cls._response_class_registry[(response_class.API_KEY, response_class.API_VERSION)] = response_class + + @classmethod + def get_response_class(cls, header): + key = (header.api_key, header.api_version) + if key in cls._response_class_registry: + return cls._response_class_registry[key] + + +class RequestHeader(ResponseClassRegistry, Struct): SCHEMA = Schema( ('api_key', Int16), ('api_version', Int16), @@ -13,8 +28,11 @@ class RequestHeader(Struct): ('client_id', String('utf-8')) ) + def get_response_class(self): + return ResponseClassRegistry.get_response_class(self) + -class RequestHeaderV2(Struct): +class RequestHeaderV2(ResponseClassRegistry, Struct): # Flexible response / request headers end in field buffer SCHEMA = Schema( ('api_key', Int16), @@ -24,6 +42,11 @@ class RequestHeaderV2(Struct): ('tags', TaggedFields), ) + def get_response_class(self): + key = (self.api_key, self.api_version) # pylint: disable=E1101 + if key in ResponseClassRegistry._response_class_registry: + return ResponseClassRegistry._response_class_registry[key] + class ResponseHeader(Struct): SCHEMA = Schema( @@ -132,6 +155,10 @@ def header_class(cls): class Response(RequestResponse): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + ResponseClassRegistry.register_response_class(weakref.proxy(cls)) + @classmethod def is_request(cls): return False diff --git a/kafka/protocol/parser.py b/kafka/protocol/parser.py index 847a47dc0..586b51fa1 100644 --- a/kafka/protocol/parser.py +++ b/kafka/protocol/parser.py @@ -2,7 +2,6 @@ import logging import kafka.errors as Errors -from kafka.protocol import get_response_class from kafka.protocol.find_coordinator import FindCoordinatorResponse from kafka.protocol.frame import KafkaBytes from kafka.protocol.types import Int32 @@ -138,7 +137,7 @@ def _process_response(self, read_buffer): raise Errors.CorrelationIdError('No in-flight-request found for server response') header = self.in_flight_requests.popleft() correlation_id = header.correlation_id - response_type = get_response_class(header.api_key, header.api_version) + response_type = header.get_response_class() if response_type is None: log.error('Unable to find ResponseType for api=%d version=%d', header.api_key, header.api_version)