diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 14ac098d2..57272c89e 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -230,6 +230,8 @@ def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: return a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED case types.TaskState.auth_required: return a2a_pb2.TaskState.TASK_STATE_AUTH_REQUIRED + case types.TaskState.rejected: + return a2a_pb2.TaskState.TASK_STATE_REJECTED case _: return a2a_pb2.TaskState.TASK_STATE_UNSPECIFIED @@ -703,6 +705,8 @@ def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: return types.TaskState.input_required case a2a_pb2.TaskState.TASK_STATE_AUTH_REQUIRED: return types.TaskState.auth_required + case a2a_pb2.TaskState.TASK_STATE_REJECTED: + return types.TaskState.rejected case _: return types.TaskState.unknown diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index 3010c3a56..9ea8c9686 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -352,5 +352,52 @@ async def mock_stream_response(): assert response.status_code == 200 +@pytest.mark.anyio +async def test_send_message_rejected_task( + client: AsyncClient, request_handler: MagicMock +) -> None: + expected_response = a2a_pb2.SendMessageResponse( + task=a2a_pb2.Task( + id='test_task_id', + context_id='test_context_id', + status=a2a_pb2.TaskStatus( + state=a2a_pb2.TaskState.TASK_STATE_REJECTED, + update=a2a_pb2.Message( + message_id='test', + role=a2a_pb2.ROLE_AGENT, + content=[ + a2a_pb2.Part(text="I don't want to work"), + ], + ), + ), + ), + ) + request_handler.on_message_send.return_value = Task( + id='test_task_id', + context_id='test_context_id', + status=TaskStatus( + state=TaskState.rejected, + message=Message( + message_id='test', + role=Role.agent, + parts=[Part(TextPart(text="I don't want to work"))], + ), + ), + ) + request = a2a_pb2.SendMessageRequest( + request=a2a_pb2.Message(), + configuration=a2a_pb2.SendMessageConfiguration(), + ) + + response = await client.post( + '/v1/message:send', json=json_format.MessageToDict(request) + ) + + response.raise_for_status() + actual_response = a2a_pb2.SendMessageResponse() + json_format.Parse(response.text, actual_response) + assert expected_response == actual_response + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index f68d5c100..7fc82aad7 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -183,9 +183,8 @@ def test_enum_conversions(self): ) for state in types.TaskState: - if state not in (types.TaskState.unknown, types.TaskState.rejected): - proto_state = proto_utils.ToProto.task_state(state) - assert proto_utils.FromProto.task_state(proto_state) == state + proto_state = proto_utils.ToProto.task_state(state) + assert proto_utils.FromProto.task_state(proto_state) == state # Test unknown state case assert (