Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,12 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
assertTrue(updateLatch.await(10, TimeUnit.SECONDS), "Timeout waiting for task update");

// Step 5: Poll for the async notification to be captured
// With the new StreamingEventKind support, we receive all event types (Task, Message, TaskArtifactUpdateEvent, etc.)
long end = System.currentTimeMillis() + 5000;
boolean notificationReceived = false;

while (System.currentTimeMillis() < end) {
if (!mockPushNotificationSender.getCapturedTasks().isEmpty()) {
if (!mockPushNotificationSender.getCapturedEvents().isEmpty()) {
notificationReceived = true;
break;
}
Expand All @@ -165,17 +166,22 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
assertTrue(notificationReceived, "Timeout waiting for push notification.");

// Step 6: Verify the captured notification
Queue<Task> capturedTasks = mockPushNotificationSender.getCapturedTasks();

// Verify the notification contains the correct task with artifacts
Task notifiedTaskWithArtifact = capturedTasks.stream()
.filter(t -> taskId.equals(t.id()) && t.artifacts() != null && t.artifacts().size() > 0)
.findFirst()
.orElse(null);

assertNotNull(notifiedTaskWithArtifact, "Notification should contain the updated task with artifacts");
assertEquals(taskId, notifiedTaskWithArtifact.id());
assertEquals(1, notifiedTaskWithArtifact.artifacts().size(), "Task should have one artifact from the update");
// Check if we received events for this task (could be Task, TaskArtifactUpdateEvent, etc.)
Queue<io.a2a.spec.StreamingEventKind> capturedEvents = mockPushNotificationSender.getCapturedEvents();

// Look for Task events with artifacts OR TaskArtifactUpdateEvent for this task
boolean hasTaskWithArtifact = capturedEvents.stream()
.filter(e -> e instanceof Task)
.map(e -> (Task) e)
.anyMatch(t -> taskId.equals(t.id()) && t.artifacts() != null && t.artifacts().size() > 0);

boolean hasArtifactUpdateEvent = capturedEvents.stream()
.filter(e -> e instanceof io.a2a.spec.TaskArtifactUpdateEvent)
.map(e -> (io.a2a.spec.TaskArtifactUpdateEvent) e)
.anyMatch(e -> taskId.equals(e.taskId()));

assertTrue(hasTaskWithArtifact || hasArtifactUpdateEvent,
"Notification should contain either Task with artifacts or TaskArtifactUpdateEvent for task " + taskId);

// Step 7: Clean up - delete the push notification configuration
client.deleteTaskPushNotificationConfigurations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jakarta.enterprise.inject.Alternative;

import io.a2a.server.tasks.PushNotificationSender;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;

/**
Expand All @@ -19,18 +20,30 @@
@Priority(100)
public class MockPushNotificationSender implements PushNotificationSender {

private final Queue<Task> capturedTasks = new ConcurrentLinkedQueue<>();
private final Queue<StreamingEventKind> capturedEvents = new ConcurrentLinkedQueue<>();

@Override
public void sendNotification(Task task) {
capturedTasks.add(task);
public void sendNotification(StreamingEventKind event) {
capturedEvents.add(event);
}

public Queue<StreamingEventKind> getCapturedEvents() {
return capturedEvents;
}

/**
* For backward compatibility - provides access to Task events only.
*/
public Queue<Task> getCapturedTasks() {
return capturedTasks;
Queue<Task> tasks = new ConcurrentLinkedQueue<>();
capturedEvents.stream()
.filter(e -> e instanceof Task)
.map(e -> (Task) e)
.forEach(tasks::add);
return tasks;
}

public void clear() {
capturedTasks.clear();
capturedEvents.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.a2a.spec.InternalError;
import io.a2a.spec.Message;
import io.a2a.spec.Task;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.TaskArtifactUpdateEvent;
import io.a2a.spec.TaskStatusUpdateEvent;
import org.slf4j.Logger;
Expand Down Expand Up @@ -211,16 +212,11 @@ private void processEvent(MainEventBusContext context) {

// Step 2: Send push notification AFTER successful persistence (only from active node)
// Skip push notifications for replicated events to avoid duplicate notifications in multi-instance deployments
if (eventToDistribute == event && !isReplicated) {
// Capture task state immediately after persistence, before going async
// This ensures we send the task as it existed when THIS event was processed,
// not whatever state might exist later when the async callback executes
Task taskSnapshot = taskStore.get(taskId);
if (taskSnapshot != null) {
sendPushNotification(taskId, taskSnapshot);
} else {
LOGGER.warn("Task {} not found in TaskStore after successful persistence, skipping push notification", taskId);
}
// Push notifications are sent for all StreamingEventKind events (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent)
// per A2A spec section 4.3.3
if (eventToDistribute == event && !isReplicated && event instanceof StreamingEventKind streamingEvent) {
// Send the streaming event directly - it will be wrapped in StreamResponse format by PushNotificationSender
sendPushNotification(taskId, streamingEvent);
}

// Step 3: Then distribute to ChildQueues (clients see either event or error AFTER persistence attempt)
Expand Down Expand Up @@ -304,7 +300,7 @@ private boolean updateTaskStore(String taskId, Event event, boolean isReplicated
}

/**
* Sends push notification for the task AFTER persistence.
* Sends push notification for the streaming event AFTER persistence.
* <p>
* This is called after updateTaskStore() to ensure the notification contains
* the latest persisted state, avoiding race conditions.
Expand All @@ -315,27 +311,32 @@ private boolean updateTaskStore(String taskId, Event event, boolean isReplicated
* PushNotificationSender.sendNotification() was causing streaming delays.
* </p>
* <p>
* <b>IMPORTANT:</b> The task parameter is a snapshot captured immediately after
* persistence. This ensures we send the task state as it existed when THIS event
* was processed, not whatever state might exist in TaskStore when the async
* callback executes (subsequent events may have already updated the store).
* <b>IMPORTANT:</b> The event parameter is the actual event being processed.
* This ensures we send the event as it was when processed, not whatever state
* might exist in TaskStore when the async callback executes (subsequent events
* may have already updated the store).
* </p>
* <p>
* Supports all StreamingEventKind event types per A2A spec section 4.3.3:
* Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent.
* The event will be automatically wrapped in StreamResponse format by JsonUtil.
* </p>
* <p>
* <b>NOTE:</b> Tests can inject a synchronous executor via setPushNotificationExecutor()
* to ensure deterministic ordering of push notifications in the test environment.
* </p>
*
* @param taskId the task ID
* @param task the task snapshot to send (captured immediately after persistence)
* @param event the streaming event to send (Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent)
*/
private void sendPushNotification(String taskId, Task task) {
private void sendPushNotification(String taskId, StreamingEventKind event) {
Runnable pushTask = () -> {
try {
if (task != null) {
if (event != null) {
LOGGER.debug("Sending push notification for task {}", taskId);
pushSender.sendNotification(task);
pushSender.sendNotification(event);
} else {
LOGGER.debug("Skipping push notification - task snapshot is null for task {}", taskId);
LOGGER.debug("Skipping push notification - event is null for task {}", taskId);
}
} catch (Exception e) {
LOGGER.error("Error sending push notification for task {}", taskId, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN;

import io.a2a.spec.TaskPushNotificationConfig;
import jakarta.annotation.Nullable;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

Expand All @@ -20,8 +21,12 @@
import io.a2a.jsonrpc.common.json.JsonUtil;
import io.a2a.spec.ListTaskPushNotificationConfigParams;
import io.a2a.spec.ListTaskPushNotificationConfigResult;
import io.a2a.spec.Message;
import io.a2a.spec.PushNotificationConfig;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;
import io.a2a.spec.TaskArtifactUpdateEvent;
import io.a2a.spec.TaskStatusUpdateEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -62,11 +67,17 @@ public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHt
}

@Override
public void sendNotification(Task task) {
public void sendNotification(StreamingEventKind event) {
String taskId = extractTaskId(event);
if (taskId == null) {
LOGGER.warn("Cannot send push notification: event does not contain taskId");
return;
}

List<TaskPushNotificationConfig> configs = new ArrayList<>();
String nextPageToken = null;
do {
ListTaskPushNotificationConfigResult pageResult = configStore.getInfo(new ListTaskPushNotificationConfigParams(task.id(),
ListTaskPushNotificationConfigResult pageResult = configStore.getInfo(new ListTaskPushNotificationConfigParams(taskId,
DEFAULT_PAGE_SIZE, nextPageToken, ""));
if (!pageResult.configs().isEmpty()) {
configs.addAll(pageResult.configs());
Expand All @@ -76,26 +87,45 @@ public void sendNotification(Task task) {

List<CompletableFuture<Boolean>> dispatchResults = configs
.stream()
.map(pushConfig -> dispatch(task, pushConfig.pushNotificationConfig()))
.map(pushConfig -> dispatch(event, pushConfig.pushNotificationConfig()))
.toList();
CompletableFuture<Void> allFutures = CompletableFuture.allOf(dispatchResults.toArray(new CompletableFuture[0]));
CompletableFuture<Boolean> dispatchResult = allFutures.thenApply(v -> dispatchResults.stream()
.allMatch(CompletableFuture::join));
try {
boolean allSent = dispatchResult.get();
if (!allSent) {
LOGGER.warn("Some push notifications failed to send for taskId: " + task.id());
LOGGER.warn("Some push notifications failed to send for taskId: " + taskId);
}
} catch (InterruptedException | ExecutionException e) {
LOGGER.warn("Some push notifications failed to send for taskId " + task.id() + ": {}", e.getMessage(), e);
LOGGER.warn("Some push notifications failed to send for taskId " + taskId + ": {}", e.getMessage(), e);
}
}

/**
* Extracts the task ID from a StreamingEventKind event.
*
* @param event the streaming event
* @return the task ID, or null if not available
*/
private @Nullable String extractTaskId(StreamingEventKind event) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve extensibility, consider changing the visibility of extractTaskId from private to protected. This would allow custom PushNotificationSender implementations that extend BasePushNotificationSender to reuse this helpful utility method. The Javadoc example in the PushNotificationSender interface already implies the existence of such a reusable helper.

Suggested change
private @Nullable String extractTaskId(StreamingEventKind event) {
protected @Nullable String extractTaskId(StreamingEventKind event) {

if (event instanceof Task task) {
return task.id();
} else if (event instanceof Message message) {
return message.taskId();
} else if (event instanceof TaskStatusUpdateEvent statusUpdate) {
return statusUpdate.taskId();
} else if (event instanceof TaskArtifactUpdateEvent artifactUpdate) {
return artifactUpdate.taskId();
}
return null;
}

private CompletableFuture<Boolean> dispatch(Task task, PushNotificationConfig pushInfo) {
return CompletableFuture.supplyAsync(() -> dispatchNotification(task, pushInfo));
private CompletableFuture<Boolean> dispatch(StreamingEventKind event, PushNotificationConfig pushInfo) {
return CompletableFuture.supplyAsync(() -> dispatchNotification(event, pushInfo));
}

private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) {
private boolean dispatchNotification(StreamingEventKind event, PushNotificationConfig pushInfo) {
String url = pushInfo.url();
String token = pushInfo.token();

Expand All @@ -106,9 +136,11 @@ private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo)

String body;
try {
body = JsonUtil.toJson(task);
// JsonUtil.toJson automatically wraps StreamingEventKind in StreamResponse format
// (task/message/statusUpdate/artifactUpdate) per A2A spec section 4.3.3
body = JsonUtil.toJson(event);
} catch (Throwable throwable) {
LOGGER.debug("Error writing value as string: {}", throwable.getMessage(), throwable);
LOGGER.debug("Error serializing StreamingEventKind to JSON: {}", throwable.getMessage(), throwable);
return false;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.a2a.server.tasks;

import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;

/**
Expand Down Expand Up @@ -27,7 +28,8 @@
* {@link BasePushNotificationSender} provides HTTP webhook delivery:
* <ul>
* <li>Retrieves webhook URLs from {@link PushNotificationConfigStore}</li>
* <li>Sends HTTP POST requests with task JSON payload</li>
* <li>Wraps events in StreamResponse format (per A2A spec section 4.3.3)</li>
* <li>Sends HTTP POST requests with StreamResponse JSON payload</li>
* <li>Logs errors but doesn't fail the request</li>
* </ul>
*
Expand All @@ -47,11 +49,12 @@
* @Priority(100)
* public class KafkaPushNotificationSender implements PushNotificationSender {
* @Inject
* KafkaProducer<String, Task> producer;
* KafkaProducer<String, StreamingEventKind> producer;
*
* @Override
* public void sendNotification(Task task) {
* producer.send("task-updates", task.id(), task);
* public void sendNotification(StreamingEventKind event) {
* String taskId = extractTaskId(event);
* producer.send("task-updates", taskId, event);
* }
* }
* }</pre>
Expand All @@ -78,18 +81,31 @@
public interface PushNotificationSender {

/**
* Sends a push notification containing the latest task state.
* Sends a push notification containing a streaming event.
* <p>
* Called after the task has been persisted to {@link TaskStore}. Retrieve push
* notification URLs or messaging configurations from {@link PushNotificationConfigStore}
* using {@code task.id()}.
* Called after the event has been persisted to {@link TaskStore}. The event is wrapped
* in a StreamResponse format (per A2A spec section 4.3.3) with the appropriate oneof
* field set (task, message, statusUpdate, or artifactUpdate).
* </p>
* <p>
* Retrieve push notification URLs or messaging configurations from
* {@link PushNotificationConfigStore} using the task ID extracted from the event.
* </p>
* <p>
* Supported event types:
* <ul>
* <li>{@link Task} - wrapped in StreamResponse.task</li>
* <li>{@link io.a2a.spec.Message} - wrapped in StreamResponse.message</li>
* <li>{@link io.a2a.spec.TaskStatusUpdateEvent} - wrapped in StreamResponse.statusUpdate</li>
* <li>{@link io.a2a.spec.TaskArtifactUpdateEvent} - wrapped in StreamResponse.artifactUpdate</li>
* </ul>
* </p>
* <p>
* <b>Error Handling:</b> Log errors but don't throw exceptions. Notifications are
* best-effort and should not fail the primary request.
* </p>
*
* @param task the task with current state and artifacts to send
* @param event the streaming event to send (Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent)
*/
void sendNotification(Task task);
void sendNotification(StreamingEventKind event);
}
Loading