diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 567a2e147..4a51205a4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -42,6 +42,7 @@ "Input is too long for requested model", "input length and `max_tokens` exceed context limit", "too many total text bytes", + "prompt is too long", ] # Models that should include tool result status (include_tool_result_status = True) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 833b14729..d15f10a59 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -19,7 +19,7 @@ DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, ) -from strands.types.exceptions import ModelThrottledException +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") @@ -1516,6 +1516,31 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model ] +@pytest.mark.parametrize( + "overflow_message", + [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", + "prompt is too long: 903884 tokens > 200000 maximum", + ], +) +@pytest.mark.asyncio +async def test_stream_context_window_overflow(overflow_message, bedrock_client, model, alist, messages): + """Test that ClientError with overflow messages raises ContextWindowOverflowException.""" + error_response = { + "Error": { + "Code": "ValidationException", + "Message": f"An error occurred (ValidationException) when calling the ConverseStream operation: " + f"The model returned the following errors: {overflow_message}", + } + } + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConverseStream") + + with pytest.raises(ContextWindowOverflowException): + await alist(model.stream(messages)) + + @pytest.mark.asyncio async def test_stream_logging(bedrock_client, model, messages, caplog, alist): """Test that stream method logs debug messages at the expected stages."""