From e96d62a44268c1f970070766a7c2c19af165086c Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Wed, 4 Feb 2026 16:05:45 +0000 Subject: [PATCH 1/2] feat: Implement PushNotifications as per the 1.0 spec --- ...otificationConfigStoreIntegrationTest.java | 30 +-- .../jpa/MockPushNotificationSender.java | 23 +- .../server/events/MainEventBusProcessor.java | 41 ++-- .../tasks/BasePushNotificationSender.java | 52 ++++- .../server/tasks/PushNotificationSender.java | 36 +++- .../AbstractA2ARequestHandlerTest.java | 27 ++- .../tasks/PushNotificationSenderTest.java | 201 +++++++++++++++--- .../grpc/handler/GrpcHandlerTest.java | 48 +++-- .../jsonrpc/handler/JSONRPCHandlerTest.java | 48 +++-- 9 files changed, 364 insertions(+), 142 deletions(-) diff --git a/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStoreIntegrationTest.java b/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStoreIntegrationTest.java index ab06747de..eb0c35014 100644 --- a/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStoreIntegrationTest.java +++ b/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStoreIntegrationTest.java @@ -151,11 +151,12 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep assertTrue(updateLatch.await(10, TimeUnit.SECONDS), "Timeout waiting for task update"); // Step 5: Poll for the async notification to be captured + // With the new StreamingEventKind support, we receive all event types (Task, Message, TaskArtifactUpdateEvent, etc.) long end = System.currentTimeMillis() + 5000; boolean notificationReceived = false; while (System.currentTimeMillis() < end) { - if (!mockPushNotificationSender.getCapturedTasks().isEmpty()) { + if (!mockPushNotificationSender.getCapturedEvents().isEmpty()) { notificationReceived = true; break; } @@ -165,17 +166,22 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep assertTrue(notificationReceived, "Timeout waiting for push notification."); // Step 6: Verify the captured notification - Queue capturedTasks = mockPushNotificationSender.getCapturedTasks(); - - // Verify the notification contains the correct task with artifacts - Task notifiedTaskWithArtifact = capturedTasks.stream() - .filter(t -> taskId.equals(t.id()) && t.artifacts() != null && t.artifacts().size() > 0) - .findFirst() - .orElse(null); - - assertNotNull(notifiedTaskWithArtifact, "Notification should contain the updated task with artifacts"); - assertEquals(taskId, notifiedTaskWithArtifact.id()); - assertEquals(1, notifiedTaskWithArtifact.artifacts().size(), "Task should have one artifact from the update"); + // Check if we received events for this task (could be Task, TaskArtifactUpdateEvent, etc.) + Queue capturedEvents = mockPushNotificationSender.getCapturedEvents(); + + // Look for Task events with artifacts OR TaskArtifactUpdateEvent for this task + boolean hasTaskWithArtifact = capturedEvents.stream() + .filter(e -> e instanceof Task) + .map(e -> (Task) e) + .anyMatch(t -> taskId.equals(t.id()) && t.artifacts() != null && t.artifacts().size() > 0); + + boolean hasArtifactUpdateEvent = capturedEvents.stream() + .filter(e -> e instanceof io.a2a.spec.TaskArtifactUpdateEvent) + .map(e -> (io.a2a.spec.TaskArtifactUpdateEvent) e) + .anyMatch(e -> taskId.equals(e.taskId())); + + assertTrue(hasTaskWithArtifact || hasArtifactUpdateEvent, + "Notification should contain either Task with artifacts or TaskArtifactUpdateEvent for task " + taskId); // Step 7: Clean up - delete the push notification configuration client.deleteTaskPushNotificationConfigurations( diff --git a/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/MockPushNotificationSender.java b/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/MockPushNotificationSender.java index 0a6bba415..2275388a9 100644 --- a/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/MockPushNotificationSender.java +++ b/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/MockPushNotificationSender.java @@ -8,6 +8,7 @@ import jakarta.enterprise.inject.Alternative; import io.a2a.server.tasks.PushNotificationSender; +import io.a2a.spec.StreamingEventKind; import io.a2a.spec.Task; /** @@ -19,18 +20,30 @@ @Priority(100) public class MockPushNotificationSender implements PushNotificationSender { - private final Queue capturedTasks = new ConcurrentLinkedQueue<>(); + private final Queue capturedEvents = new ConcurrentLinkedQueue<>(); @Override - public void sendNotification(Task task) { - capturedTasks.add(task); + public void sendNotification(StreamingEventKind event) { + capturedEvents.add(event); } + public Queue getCapturedEvents() { + return capturedEvents; + } + + /** + * For backward compatibility - provides access to Task events only. + */ public Queue getCapturedTasks() { - return capturedTasks; + Queue tasks = new ConcurrentLinkedQueue<>(); + capturedEvents.stream() + .filter(e -> e instanceof Task) + .map(e -> (Task) e) + .forEach(tasks::add); + return tasks; } public void clear() { - capturedTasks.clear(); + capturedEvents.clear(); } } diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java index 8b3dc6fa3..19d11b819 100644 --- a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java @@ -16,6 +16,7 @@ import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; +import io.a2a.spec.StreamingEventKind; import io.a2a.spec.TaskArtifactUpdateEvent; import io.a2a.spec.TaskStatusUpdateEvent; import org.slf4j.Logger; @@ -211,16 +212,11 @@ private void processEvent(MainEventBusContext context) { // Step 2: Send push notification AFTER successful persistence (only from active node) // Skip push notifications for replicated events to avoid duplicate notifications in multi-instance deployments - if (eventToDistribute == event && !isReplicated) { - // Capture task state immediately after persistence, before going async - // This ensures we send the task as it existed when THIS event was processed, - // not whatever state might exist later when the async callback executes - Task taskSnapshot = taskStore.get(taskId); - if (taskSnapshot != null) { - sendPushNotification(taskId, taskSnapshot); - } else { - LOGGER.warn("Task {} not found in TaskStore after successful persistence, skipping push notification", taskId); - } + // Push notifications are sent for all StreamingEventKind events (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) + // per A2A spec section 4.3.3 + if (eventToDistribute == event && !isReplicated && event instanceof StreamingEventKind streamingEvent) { + // Send the streaming event directly - it will be wrapped in StreamResponse format by PushNotificationSender + sendPushNotification(taskId, streamingEvent); } // Step 3: Then distribute to ChildQueues (clients see either event or error AFTER persistence attempt) @@ -304,7 +300,7 @@ private boolean updateTaskStore(String taskId, Event event, boolean isReplicated } /** - * Sends push notification for the task AFTER persistence. + * Sends push notification for the streaming event AFTER persistence. *

* This is called after updateTaskStore() to ensure the notification contains * the latest persisted state, avoiding race conditions. @@ -315,10 +311,15 @@ private boolean updateTaskStore(String taskId, Event event, boolean isReplicated * PushNotificationSender.sendNotification() was causing streaming delays. *

*

- * IMPORTANT: The task parameter is a snapshot captured immediately after - * persistence. This ensures we send the task state as it existed when THIS event - * was processed, not whatever state might exist in TaskStore when the async - * callback executes (subsequent events may have already updated the store). + * IMPORTANT: The event parameter is the actual event being processed. + * This ensures we send the event as it was when processed, not whatever state + * might exist in TaskStore when the async callback executes (subsequent events + * may have already updated the store). + *

+ *

+ * Supports all StreamingEventKind event types per A2A spec section 4.3.3: + * Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent. + * The event will be automatically wrapped in StreamResponse format by JsonUtil. *

*

* NOTE: Tests can inject a synchronous executor via setPushNotificationExecutor() @@ -326,16 +327,16 @@ private boolean updateTaskStore(String taskId, Event event, boolean isReplicated *

* * @param taskId the task ID - * @param task the task snapshot to send (captured immediately after persistence) + * @param event the streaming event to send (Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent) */ - private void sendPushNotification(String taskId, Task task) { + private void sendPushNotification(String taskId, StreamingEventKind event) { Runnable pushTask = () -> { try { - if (task != null) { + if (event != null) { LOGGER.debug("Sending push notification for task {}", taskId); - pushSender.sendNotification(task); + pushSender.sendNotification(event); } else { - LOGGER.debug("Skipping push notification - task snapshot is null for task {}", taskId); + LOGGER.debug("Skipping push notification - event is null for task {}", taskId); } } catch (Exception e) { LOGGER.error("Error sending push notification for task {}", taskId, e); diff --git a/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java b/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java index d6d4a2369..82aa6ce68 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java +++ b/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java @@ -5,6 +5,7 @@ import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN; import io.a2a.spec.TaskPushNotificationConfig; +import jakarta.annotation.Nullable; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; @@ -20,8 +21,12 @@ import io.a2a.jsonrpc.common.json.JsonUtil; import io.a2a.spec.ListTaskPushNotificationConfigParams; import io.a2a.spec.ListTaskPushNotificationConfigResult; +import io.a2a.spec.Message; import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.StreamingEventKind; import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -62,11 +67,17 @@ public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHt } @Override - public void sendNotification(Task task) { + public void sendNotification(StreamingEventKind event) { + String taskId = extractTaskId(event); + if (taskId == null) { + LOGGER.warn("Cannot send push notification: event does not contain taskId"); + return; + } + List configs = new ArrayList<>(); String nextPageToken = null; do { - ListTaskPushNotificationConfigResult pageResult = configStore.getInfo(new ListTaskPushNotificationConfigParams(task.id(), + ListTaskPushNotificationConfigResult pageResult = configStore.getInfo(new ListTaskPushNotificationConfigParams(taskId, DEFAULT_PAGE_SIZE, nextPageToken, "")); if (!pageResult.configs().isEmpty()) { configs.addAll(pageResult.configs()); @@ -76,7 +87,7 @@ public void sendNotification(Task task) { List> dispatchResults = configs .stream() - .map(pushConfig -> dispatch(task, pushConfig.pushNotificationConfig())) + .map(pushConfig -> dispatch(event, pushConfig.pushNotificationConfig())) .toList(); CompletableFuture allFutures = CompletableFuture.allOf(dispatchResults.toArray(new CompletableFuture[0])); CompletableFuture dispatchResult = allFutures.thenApply(v -> dispatchResults.stream() @@ -84,18 +95,37 @@ public void sendNotification(Task task) { try { boolean allSent = dispatchResult.get(); if (!allSent) { - LOGGER.warn("Some push notifications failed to send for taskId: " + task.id()); + LOGGER.warn("Some push notifications failed to send for taskId: " + taskId); } } catch (InterruptedException | ExecutionException e) { - LOGGER.warn("Some push notifications failed to send for taskId " + task.id() + ": {}", e.getMessage(), e); + LOGGER.warn("Some push notifications failed to send for taskId " + taskId + ": {}", e.getMessage(), e); + } + } + + /** + * Extracts the task ID from a StreamingEventKind event. + * + * @param event the streaming event + * @return the task ID, or null if not available + */ + private @Nullable String extractTaskId(StreamingEventKind event) { + if (event instanceof Task task) { + return task.id(); + } else if (event instanceof Message message) { + return message.taskId(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.taskId(); + } else if (event instanceof TaskArtifactUpdateEvent artifactUpdate) { + return artifactUpdate.taskId(); } + return null; } - private CompletableFuture dispatch(Task task, PushNotificationConfig pushInfo) { - return CompletableFuture.supplyAsync(() -> dispatchNotification(task, pushInfo)); + private CompletableFuture dispatch(StreamingEventKind event, PushNotificationConfig pushInfo) { + return CompletableFuture.supplyAsync(() -> dispatchNotification(event, pushInfo)); } - private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) { + private boolean dispatchNotification(StreamingEventKind event, PushNotificationConfig pushInfo) { String url = pushInfo.url(); String token = pushInfo.token(); @@ -106,9 +136,11 @@ private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) String body; try { - body = JsonUtil.toJson(task); + // JsonUtil.toJson automatically wraps StreamingEventKind in StreamResponse format + // (task/message/statusUpdate/artifactUpdate) per A2A spec section 4.3.3 + body = JsonUtil.toJson(event); } catch (Throwable throwable) { - LOGGER.debug("Error writing value as string: {}", throwable.getMessage(), throwable); + LOGGER.debug("Error serializing StreamingEventKind to JSON: {}", throwable.getMessage(), throwable); return false; } diff --git a/server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java b/server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java index 2013d6a22..ef54266a5 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java +++ b/server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java @@ -1,5 +1,6 @@ package io.a2a.server.tasks; +import io.a2a.spec.StreamingEventKind; import io.a2a.spec.Task; /** @@ -27,7 +28,8 @@ * {@link BasePushNotificationSender} provides HTTP webhook delivery: *
    *
  • Retrieves webhook URLs from {@link PushNotificationConfigStore}
  • - *
  • Sends HTTP POST requests with task JSON payload
  • + *
  • Wraps events in StreamResponse format (per A2A spec section 4.3.3)
  • + *
  • Sends HTTP POST requests with StreamResponse JSON payload
  • *
  • Logs errors but doesn't fail the request
  • *
* @@ -47,11 +49,12 @@ * @Priority(100) * public class KafkaPushNotificationSender implements PushNotificationSender { * @Inject - * KafkaProducer producer; + * KafkaProducer producer; * * @Override - * public void sendNotification(Task task) { - * producer.send("task-updates", task.id(), task); + * public void sendNotification(StreamingEventKind event) { + * String taskId = extractTaskId(event); + * producer.send("task-updates", taskId, event); * } * } * } @@ -78,18 +81,31 @@ public interface PushNotificationSender { /** - * Sends a push notification containing the latest task state. + * Sends a push notification containing a streaming event. *

- * Called after the task has been persisted to {@link TaskStore}. Retrieve push - * notification URLs or messaging configurations from {@link PushNotificationConfigStore} - * using {@code task.id()}. + * Called after the event has been persisted to {@link TaskStore}. The event is wrapped + * in a StreamResponse format (per A2A spec section 4.3.3) with the appropriate oneof + * field set (task, message, statusUpdate, or artifactUpdate). + *

+ *

+ * Retrieve push notification URLs or messaging configurations from + * {@link PushNotificationConfigStore} using the task ID extracted from the event. + *

+ *

+ * Supported event types: + *

    + *
  • {@link Task} - wrapped in StreamResponse.task
  • + *
  • {@link io.a2a.spec.Message} - wrapped in StreamResponse.message
  • + *
  • {@link io.a2a.spec.TaskStatusUpdateEvent} - wrapped in StreamResponse.statusUpdate
  • + *
  • {@link io.a2a.spec.TaskArtifactUpdateEvent} - wrapped in StreamResponse.artifactUpdate
  • + *
*

*

* Error Handling: Log errors but don't throw exceptions. Notifications are * best-effort and should not fail the primary request. *

* - * @param task the task with current state and artifacts to send + * @param event the streaming event to send (Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent) */ - void sendNotification(Task task); + void sendNotification(StreamingEventKind event); } diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index 274453fd1..1892201aa 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -42,6 +42,7 @@ import io.a2a.spec.AgentInterface; import io.a2a.spec.Event; import io.a2a.spec.Message; +import io.a2a.spec.StreamingEventKind; import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; @@ -191,7 +192,23 @@ public boolean isReplicated() { @Dependent @IfBuildProfile("test") protected static class TestHttpClient implements A2AHttpClient { - public final List tasks = Collections.synchronizedList(new ArrayList<>()); + public final List events = Collections.synchronizedList(new ArrayList<>()); + // For backward compatibility - provides access to tasks that are sent as StreamingEventKind + public final List tasks = Collections.synchronizedList(new ArrayList<>() { + @Override + public int size() { + return (int) events.stream().filter(e -> e instanceof Task).count(); + } + + @Override + public Task get(int index) { + return (Task) events.stream() + .filter(e -> e instanceof Task) + .skip(index) + .findFirst() + .orElseThrow(() -> new IndexOutOfBoundsException("Index: " + index)); + } + }); public volatile CountDownLatch latch; @Override @@ -220,8 +237,10 @@ public PostBuilder body(String body) { @Override public A2AHttpResponse post() throws IOException, InterruptedException { try { - Task task = JsonUtil.fromJson(body, Task.class); - tasks.add(task); + // Parse StreamResponse format to extract the streaming event + // The body contains a wrapper with one of: task, message, statusUpdate, artifactUpdate + StreamingEventKind event = JsonUtil.fromJson(body, StreamingEventKind.class); + events.add(event); return new A2AHttpResponse() { @Override public int status() { @@ -239,7 +258,7 @@ public String body() { } }; } catch (JsonProcessingException e) { - throw new IOException("Failed to parse task JSON", e); + throw new IOException("Failed to parse StreamingEventKind JSON", e); } finally { if (latch != null) { latch.countDown(); diff --git a/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java b/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java index 7bb67f681..e58ad384b 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java @@ -21,10 +21,17 @@ import io.a2a.common.A2AHeaders; import io.a2a.jsonrpc.common.json.JsonProcessingException; import io.a2a.jsonrpc.common.json.JsonUtil; +import io.a2a.spec.Artifact; +import io.a2a.spec.Message; +import io.a2a.spec.Part; import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.StreamingEventKind; import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -35,12 +42,14 @@ public class PushNotificationSenderTest { private BasePushNotificationSender sender; /** - * Simple test implementation of A2AHttpClient that captures HTTP calls for verification + * Simple test implementation of A2AHttpClient that captures HTTP calls for verification. + * Now captures StreamingEventKind events wrapped in StreamResponse format. */ private static class TestHttpClient implements A2AHttpClient { - final List tasks = Collections.synchronizedList(new ArrayList<>()); + final List events = Collections.synchronizedList(new ArrayList<>()); final List urls = Collections.synchronizedList(new ArrayList<>()); final List> headers = Collections.synchronizedList(new ArrayList<>()); + final List rawBodies = Collections.synchronizedList(new ArrayList<>()); volatile CountDownLatch latch; volatile boolean shouldThrowException = false; @@ -77,8 +86,13 @@ public A2AHttpResponse post() throws IOException, InterruptedException { } try { - Task task = JsonUtil.fromJson(body, Task.class); - tasks.add(task); + // Store raw body for verification + rawBodies.add(body); + + // Parse StreamResponse format to extract the event + // The body contains a wrapper with one of: task, message, statusUpdate, artifactUpdate + StreamingEventKind event = JsonUtil.fromJson(body, StreamingEventKind.class); + events.add(event); urls.add(url); headers.add(new java.util.HashMap<>(requestHeaders)); @@ -99,7 +113,7 @@ public String body() { } }; } catch (JsonProcessingException e) { - throw new IOException("Failed to parse task JSON", e); + throw new IOException("Failed to parse StreamingEventKind JSON", e); } finally { if (latch != null) { latch.countDown(); @@ -154,12 +168,14 @@ private void testSendNotificationWithInvalidToken(String token, String testName) // Wait for the async operation to complete assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP call should complete within 5 seconds"); - - // Verify the task was sent via HTTP - assertEquals(1, testHttpClient.tasks.size()); - Task sentTask = testHttpClient.tasks.get(0); + + // Verify the task was sent via HTTP wrapped in StreamResponse format + assertEquals(1, testHttpClient.events.size()); + StreamingEventKind sentEvent = testHttpClient.events.get(0); + assertTrue(sentEvent instanceof Task, "Event should be a Task"); + Task sentTask = (Task) sentEvent; assertEquals(taskData.id(), sentTask.id()); - + // Verify that no authentication header was sent (invalid token should not add header) assertEquals(1, testHttpClient.headers.size()); Map sentHeaders = testHttpClient.headers.get(0); @@ -192,10 +208,10 @@ public void testSendNotificationSuccess() throws InterruptedException { String taskId = "task_send_success"; Task taskData = createSampleTask(taskId, TaskState.COMPLETED); PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", null); - + // Set up the configuration in the store configStore.setInfo(taskId, config); - + // Set up latch to wait for async completion testHttpClient.latch = new CountDownLatch(1); @@ -203,13 +219,19 @@ public void testSendNotificationSuccess() throws InterruptedException { // Wait for the async operation to complete assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP call should complete within 5 seconds"); - - // Verify the task was sent via HTTP - assertEquals(1, testHttpClient.tasks.size()); - Task sentTask = testHttpClient.tasks.get(0); + + // Verify the task was sent via HTTP wrapped in StreamResponse format + assertEquals(1, testHttpClient.events.size()); + StreamingEventKind sentEvent = testHttpClient.events.get(0); + assertTrue(sentEvent instanceof Task, "Event should be a Task"); + Task sentTask = (Task) sentEvent; assertEquals(taskData.id(), sentTask.id()); assertEquals(taskData.contextId(), sentTask.contextId()); assertEquals(taskData.status().state(), sentTask.status().state()); + + // Verify StreamResponse wrapper is present in raw body + String rawBody = testHttpClient.rawBodies.get(0); + assertTrue(rawBody.contains("\"task\""), "Raw body should contain 'task' discriminator for StreamResponse"); } @Test @@ -217,10 +239,10 @@ public void testSendNotificationWithTokenSuccess() throws InterruptedException { String taskId = "task_send_with_token"; Task taskData = createSampleTask(taskId, TaskState.COMPLETED); PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", "unique_token"); - + // Set up the configuration in the store configStore.setInfo(taskId, config); - + // Set up latch to wait for async completion testHttpClient.latch = new CountDownLatch(1); @@ -228,12 +250,14 @@ public void testSendNotificationWithTokenSuccess() throws InterruptedException { // Wait for the async operation to complete assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP call should complete within 5 seconds"); - - // Verify the task was sent via HTTP - assertEquals(1, testHttpClient.tasks.size()); - Task sentTask = testHttpClient.tasks.get(0); + + // Verify the task was sent via HTTP wrapped in StreamResponse format + assertEquals(1, testHttpClient.events.size()); + StreamingEventKind sentEvent = testHttpClient.events.get(0); + assertTrue(sentEvent instanceof Task, "Event should be a Task"); + Task sentTask = (Task) sentEvent; assertEquals(taskData.id(), sentTask.id()); - + // Verify that the X-A2A-Notification-Token header is sent with the correct token assertEquals(1, testHttpClient.headers.size()); Map sentHeaders = testHttpClient.headers.get(0); @@ -250,12 +274,12 @@ public void testSendNotificationWithTokenSuccess() throws InterruptedException { public void testSendNotificationNoConfig() { String taskId = "task_send_no_config"; Task taskData = createSampleTask(taskId, TaskState.COMPLETED); - + // Don't set any configuration in the store sender.sendNotification(taskData); // Verify no HTTP calls were made - assertEquals(0, testHttpClient.tasks.size()); + assertEquals(0, testHttpClient.events.size()); } @Test @@ -274,11 +298,11 @@ public void testSendNotificationMultipleConfigs() throws InterruptedException { Task taskData = createSampleTask(taskId, TaskState.COMPLETED); PushNotificationConfig config1 = createSamplePushConfig("http://notify.me/cfg1", "cfg1", null); PushNotificationConfig config2 = createSamplePushConfig("http://notify.me/cfg2", "cfg2", null); - + // Set up multiple configurations in the store configStore.setInfo(taskId, config1); configStore.setInfo(taskId, config2); - + // Set up latch to wait for async completion (2 calls expected) testHttpClient.latch = new CountDownLatch(2); @@ -286,14 +310,16 @@ public void testSendNotificationMultipleConfigs() throws InterruptedException { // Wait for the async operations to complete assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP calls should complete within 5 seconds"); - - // Verify both tasks were sent via HTTP - assertEquals(2, testHttpClient.tasks.size()); + + // Verify both events were sent via HTTP wrapped in StreamResponse format + assertEquals(2, testHttpClient.events.size()); assertEquals(2, testHttpClient.urls.size()); assertTrue(testHttpClient.urls.containsAll(java.util.List.of("http://notify.me/cfg1", "http://notify.me/cfg2"))); - - // Both tasks should be identical (same task sent to different endpoints) - for (Task sentTask : testHttpClient.tasks) { + + // Both events should be identical (same event sent to different endpoints) + for (StreamingEventKind sentEvent : testHttpClient.events) { + assertTrue(sentEvent instanceof Task, "Event should be a Task"); + Task sentTask = (Task) sentEvent; assertEquals(taskData.id(), sentTask.id()); assertEquals(taskData.contextId(), sentTask.contextId()); assertEquals(taskData.status().state(), sentTask.status().state()); @@ -315,7 +341,112 @@ public void testSendNotificationHttpError() { // This should not throw an exception - errors should be handled gracefully sender.sendNotification(taskData); - // Verify no tasks were successfully processed due to the error - assertEquals(0, testHttpClient.tasks.size()); + // Verify no events were successfully processed due to the error + assertEquals(0, testHttpClient.events.size()); + } + + @Test + public void testSendNotificationMessage() throws InterruptedException { + String taskId = "task_send_message"; + Message message = Message.builder() + .taskId(taskId) + .role(Message.Role.AGENT) + .parts(new TextPart("Hello from agent")) + .build(); + PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", null); + + // Set up the configuration in the store + configStore.setInfo(taskId, config); + + // Set up latch to wait for async completion + testHttpClient.latch = new CountDownLatch(1); + + sender.sendNotification(message); + + // Wait for the async operation to complete + assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP call should complete within 5 seconds"); + + // Verify the message was sent via HTTP wrapped in StreamResponse format + assertEquals(1, testHttpClient.events.size()); + StreamingEventKind sentEvent = testHttpClient.events.get(0); + assertTrue(sentEvent instanceof Message, "Event should be a Message"); + Message sentMessage = (Message) sentEvent; + assertEquals(taskId, sentMessage.taskId()); + + // Verify StreamResponse wrapper with 'message' discriminator + String rawBody = testHttpClient.rawBodies.get(0); + assertTrue(rawBody.contains("\"message\""), "Raw body should contain 'message' discriminator for StreamResponse"); + } + + @Test + public void testSendNotificationTaskStatusUpdate() throws InterruptedException { + String taskId = "task_send_status_update"; + TaskStatusUpdateEvent statusUpdate = TaskStatusUpdateEvent.builder() + .taskId(taskId) + .contextId("ctx456") + .status(new TaskStatus(TaskState.WORKING)) + .build(); + PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", null); + + // Set up the configuration in the store + configStore.setInfo(taskId, config); + + // Set up latch to wait for async completion + testHttpClient.latch = new CountDownLatch(1); + + sender.sendNotification(statusUpdate); + + // Wait for the async operation to complete + assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP call should complete within 5 seconds"); + + // Verify the status update was sent via HTTP wrapped in StreamResponse format + assertEquals(1, testHttpClient.events.size()); + StreamingEventKind sentEvent = testHttpClient.events.get(0); + assertTrue(sentEvent instanceof TaskStatusUpdateEvent, "Event should be a TaskStatusUpdateEvent"); + TaskStatusUpdateEvent sentUpdate = (TaskStatusUpdateEvent) sentEvent; + assertEquals(taskId, sentUpdate.taskId()); + assertEquals(TaskState.WORKING, sentUpdate.status().state()); + + // Verify StreamResponse wrapper with 'statusUpdate' discriminator + String rawBody = testHttpClient.rawBodies.get(0); + assertTrue(rawBody.contains("\"statusUpdate\""), "Raw body should contain 'statusUpdate' discriminator for StreamResponse"); + } + + @Test + public void testSendNotificationTaskArtifactUpdate() throws InterruptedException { + String taskId = "task_send_artifact_update"; + Artifact artifact = Artifact.builder() + .artifactId("artifact-1") + .name("test-artifact") + .parts(Collections.singletonList(new TextPart("Artifact chunk"))) + .build(); + TaskArtifactUpdateEvent artifactUpdate = TaskArtifactUpdateEvent.builder() + .taskId(taskId) + .contextId("ctx456") + .artifact(artifact) + .build(); + PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", null); + + // Set up the configuration in the store + configStore.setInfo(taskId, config); + + // Set up latch to wait for async completion + testHttpClient.latch = new CountDownLatch(1); + + sender.sendNotification(artifactUpdate); + + // Wait for the async operation to complete + assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP call should complete within 5 seconds"); + + // Verify the artifact update was sent via HTTP wrapped in StreamResponse format + assertEquals(1, testHttpClient.events.size()); + StreamingEventKind sentEvent = testHttpClient.events.get(0); + assertTrue(sentEvent instanceof TaskArtifactUpdateEvent, "Event should be a TaskArtifactUpdateEvent"); + TaskArtifactUpdateEvent sentUpdate = (TaskArtifactUpdateEvent) sentEvent; + assertEquals(taskId, sentUpdate.taskId()); + + // Verify StreamResponse wrapper with 'artifactUpdate' discriminator + String rawBody = testHttpClient.rawBodies.get(0); + assertTrue(rawBody.contains("\"artifactUpdate\""), "Raw body should contain 'artifactUpdate' discriminator for StreamResponse"); } } diff --git a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java index c7cf79524..1b900105a 100644 --- a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java +++ b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java @@ -460,29 +460,31 @@ public void onCompleted() { Assertions.assertTrue(latch.await(5, TimeUnit.SECONDS)); Assertions.assertTrue(errors.isEmpty()); Assertions.assertEquals(3, results.size()); - Assertions.assertEquals(3, httpClient.tasks.size()); - - io.a2a.spec.Task curr = httpClient.tasks.get(0); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.status().state(), curr.status().state()); - Assertions.assertEquals(0, curr.artifacts() == null ? 0 : curr.artifacts().size()); - - curr = httpClient.tasks.get(1); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.status().state(), curr.status().state()); - Assertions.assertEquals(1, curr.artifacts().size()); - Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); - Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); - - curr = httpClient.tasks.get(2); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); - Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); - Assertions.assertEquals(1, curr.artifacts().size()); - Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); - Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + // Push notifications now send the actual StreamingEventKind events, not Task snapshots + Assertions.assertEquals(3, httpClient.events.size()); + + // Event 0: Task event + Assertions.assertTrue(httpClient.events.get(0) instanceof io.a2a.spec.Task, "First event should be Task"); + io.a2a.spec.Task task1 = (io.a2a.spec.Task) httpClient.events.get(0); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), task1.id()); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), task1.contextId()); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.status().state(), task1.status().state()); + Assertions.assertEquals(0, task1.artifacts() == null ? 0 : task1.artifacts().size()); + + // Event 1: TaskArtifactUpdateEvent + Assertions.assertTrue(httpClient.events.get(1) instanceof TaskArtifactUpdateEvent, "Second event should be TaskArtifactUpdateEvent"); + TaskArtifactUpdateEvent artifactUpdate = (TaskArtifactUpdateEvent) httpClient.events.get(1); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), artifactUpdate.taskId()); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), artifactUpdate.contextId()); + Assertions.assertEquals(1, artifactUpdate.artifact().parts().size()); + Assertions.assertEquals("text", ((TextPart) artifactUpdate.artifact().parts().get(0)).text()); + + // Event 2: TaskStatusUpdateEvent + Assertions.assertTrue(httpClient.events.get(2) instanceof TaskStatusUpdateEvent, "Third event should be TaskStatusUpdateEvent"); + TaskStatusUpdateEvent statusUpdate = (TaskStatusUpdateEvent) httpClient.events.get(2); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), statusUpdate.taskId()); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), statusUpdate.contextId()); + Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, statusUpdate.status().state()); } finally { mainEventBusProcessor.setPushNotificationExecutor(null); } diff --git a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java index 8d2751627..4e60cc0a2 100644 --- a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java +++ b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java @@ -773,29 +773,31 @@ public void onComplete() { subscriptionRef.get().cancel(); assertEquals(3, results.size()); - assertEquals(3, httpClient.tasks.size()); - - Task curr = httpClient.tasks.get(0); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); - assertEquals(0, curr.artifacts() == null ? 0 : curr.artifacts().size()); - - curr = httpClient.tasks.get(1); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); - assertEquals(1, curr.artifacts().size()); - assertEquals(1, curr.artifacts().get(0).parts().size()); - assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); - - curr = httpClient.tasks.get(2); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(TaskState.COMPLETED, curr.status().state()); - assertEquals(1, curr.artifacts().size()); - assertEquals(1, curr.artifacts().get(0).parts().size()); - assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); + // Push notifications now send the actual StreamingEventKind events, not Task snapshots + assertEquals(3, httpClient.events.size()); + + // Event 0: Task event + assertTrue(httpClient.events.get(0) instanceof Task, "First event should be Task"); + Task task1 = (Task) httpClient.events.get(0); + assertEquals(MINIMAL_TASK.id(), task1.id()); + assertEquals(MINIMAL_TASK.contextId(), task1.contextId()); + assertEquals(MINIMAL_TASK.status().state(), task1.status().state()); + assertEquals(0, task1.artifacts() == null ? 0 : task1.artifacts().size()); + + // Event 1: TaskArtifactUpdateEvent + assertTrue(httpClient.events.get(1) instanceof TaskArtifactUpdateEvent, "Second event should be TaskArtifactUpdateEvent"); + TaskArtifactUpdateEvent artifactUpdate = (TaskArtifactUpdateEvent) httpClient.events.get(1); + assertEquals(MINIMAL_TASK.id(), artifactUpdate.taskId()); + assertEquals(MINIMAL_TASK.contextId(), artifactUpdate.contextId()); + assertEquals(1, artifactUpdate.artifact().parts().size()); + assertEquals("text", ((TextPart) artifactUpdate.artifact().parts().get(0)).text()); + + // Event 2: TaskStatusUpdateEvent + assertTrue(httpClient.events.get(2) instanceof TaskStatusUpdateEvent, "Third event should be TaskStatusUpdateEvent"); + TaskStatusUpdateEvent statusUpdate = (TaskStatusUpdateEvent) httpClient.events.get(2); + assertEquals(MINIMAL_TASK.id(), statusUpdate.taskId()); + assertEquals(MINIMAL_TASK.contextId(), statusUpdate.contextId()); + assertEquals(TaskState.COMPLETED, statusUpdate.status().state()); } finally { mainEventBusProcessor.setPushNotificationExecutor(null); } From 6cc37379971989633d534c8f329e869376dbd5ab Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Thu, 5 Feb 2026 11:41:12 +0000 Subject: [PATCH 2/2] Review and Javadoc fixes --- .../server/tasks/BasePushNotificationSender.java | 2 +- .../a2a/server/tasks/PushNotificationSender.java | 2 -- .../AbstractA2ARequestHandlerTest.java | 16 ---------------- 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java b/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java index 82aa6ce68..b09422f87 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java +++ b/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java @@ -108,7 +108,7 @@ public void sendNotification(StreamingEventKind event) { * @param event the streaming event * @return the task ID, or null if not available */ - private @Nullable String extractTaskId(StreamingEventKind event) { + protected @Nullable String extractTaskId(StreamingEventKind event) { if (event instanceof Task task) { return task.id(); } else if (event instanceof Message message) { diff --git a/server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java b/server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java index ef54266a5..f8b7b018d 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java +++ b/server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java @@ -91,7 +91,6 @@ public interface PushNotificationSender { * Retrieve push notification URLs or messaging configurations from * {@link PushNotificationConfigStore} using the task ID extracted from the event. *

- *

* Supported event types: *

    *
  • {@link Task} - wrapped in StreamResponse.task
  • @@ -99,7 +98,6 @@ public interface PushNotificationSender { *
  • {@link io.a2a.spec.TaskStatusUpdateEvent} - wrapped in StreamResponse.statusUpdate
  • *
  • {@link io.a2a.spec.TaskArtifactUpdateEvent} - wrapped in StreamResponse.artifactUpdate
  • *
- *

*

* Error Handling: Log errors but don't throw exceptions. Notifications are * best-effort and should not fail the primary request. diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index 1892201aa..664202332 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -193,22 +193,6 @@ public boolean isReplicated() { @IfBuildProfile("test") protected static class TestHttpClient implements A2AHttpClient { public final List events = Collections.synchronizedList(new ArrayList<>()); - // For backward compatibility - provides access to tasks that are sent as StreamingEventKind - public final List tasks = Collections.synchronizedList(new ArrayList<>() { - @Override - public int size() { - return (int) events.stream().filter(e -> e instanceof Task).count(); - } - - @Override - public Task get(int index) { - return (Task) events.stream() - .filter(e -> e instanceof Task) - .skip(index) - .findFirst() - .orElseThrow(() -> new IndexOutOfBoundsException("Index: " + index)); - } - }); public volatile CountDownLatch latch; @Override