Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions awscrt/mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,8 +1228,20 @@ class PublishReceivedData:

Args:
publish_packet (PublishPacket): Data model of an `MQTT5 PUBLISH <https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901100>`_ packet.
acquire_puback_control (Callable): For QoS 1 messages only: call this function within the
on_publish_callback_fn callback to take manual control of the PUBACK for this message, preventing
the client from automatically sending a PUBACK. Returns an opaque handle that can be passed to
Client.invoke_puback() to send the PUBACK to the broker.

Important: This function must be called within the on_publish_callback_fn callback. Calling it after the
callback returns will raise a RuntimeError. This function may only be called once per received PUBLISH;
calling it a second time will also raise a RuntimeError. If this function is not called, the client will
automatically send a PUBACK for QoS 1 messages when the callback returns.

For QoS 0 messages, this field is None.
"""
publish_packet: PublishPacket = None
acquire_puback_control: Callable = None


@dataclass
Expand Down Expand Up @@ -1434,7 +1446,8 @@ def _on_publish(
correlation_data,
subscription_identifiers_tuples,
content_type,
user_properties_tuples):
user_properties_tuples,
acquire_puback_control_fn):
if self._on_publish_cb is None:
return

Expand Down Expand Up @@ -1468,9 +1481,13 @@ def _on_publish(
publish_packet.content_type = content_type
publish_packet.user_properties = _init_user_properties(user_properties_tuples)

self._on_publish_cb(PublishReceivedData(publish_packet=publish_packet))
# Create PublishReceivedData with the manual control callback
publish_data = PublishReceivedData(
publish_packet=publish_packet,
acquire_puback_control=acquire_puback_control_fn
)

return
self._on_publish_cb(publish_data)

def _on_lifecycle_stopped(self):
if self._on_lifecycle_stopped_cb:
Expand Down Expand Up @@ -1957,6 +1974,25 @@ def get_stats(self):
result = _awscrt.mqtt5_client_get_stats(self._binding)
return OperationStatisticsData(result[0], result[1], result[2], result[3])

def invoke_puback(self, puback_control_handle):
"""Sends a PUBACK packet for a QoS 1 PUBLISH that was previously acquired for manual control.

To use manual PUBACK control, call acquire_puback_control() within the on_publish_callback_fn
callback to obtain a handle. Then call this method to send the PUBACK.

Args:
puback_control_handle: An opaque handle obtained from acquire_puback_control() within
PublishReceivedData. This handle cannot be created manually.

Raises:
Exception: If the native client returns an error when invoking the PUBACK.
"""

_awscrt.mqtt5_client_invoke_puback(
self._binding,
puback_control_handle
)

def new_connection(self, on_connection_interrupted=None, on_connection_resumed=None,
on_connection_success=None, on_connection_failure=None, on_connection_closed=None):
from awscrt.mqtt import Connection
Expand Down
1 change: 1 addition & 0 deletions source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ static PyMethodDef s_module_methods[] = {
AWS_PY_METHOD_DEF(mqtt5_client_subscribe, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt5_client_unsubscribe, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt5_client_get_stats, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt5_client_invoke_puback, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt5_ws_handshake_transform_complete, METH_VARARGS),

/* MQTT Request Response Client */
Expand Down
162 changes: 160 additions & 2 deletions source/mqtt5_client.c
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,73 @@ static PyObject *s_aws_set_user_properties_to_PyObject(
* Publish Handler
******************************************************************************/

static const char *s_capsule_name_puback_control_handle = "aws_puback_control_handle";

struct puback_control_handle {
uint64_t control_id;
};

static void s_puback_control_handle_destructor(PyObject *capsule) {
struct puback_control_handle *handle = PyCapsule_GetPointer(capsule, s_capsule_name_puback_control_handle);
if (handle) {
aws_mem_release(aws_py_get_allocator(), handle);
}
}

/* Callback context for manual PUBACK control */
struct manual_puback_control_context {
struct aws_mqtt5_client *client;
struct aws_mqtt5_packet_publish_view *publish_packet;
};

static void s_manual_puback_control_context_destructor(PyObject *capsule) {
struct manual_puback_control_context *context = PyCapsule_GetPointer(capsule, "manual_puback_control_context");
if (context) {
aws_mem_release(aws_py_get_allocator(), context);
}
}

/* Function called from Python to set manual PUBACK control and return puback_control_id */
PyObject *aws_py_mqtt5_client_acquire_puback(PyObject *self, PyObject *args) {
(void)args;

struct manual_puback_control_context *context = PyCapsule_GetPointer(self, "manual_puback_control_context");
if (!context) {
PyErr_SetString(PyExc_ValueError, "Invalid manual PUBACK control context");
return NULL;
}

/* If the publish_packet pointer has been zeroed out, the callback has already returned (post-callback call)
* or this function was already called once (double-call). Both are usage errors. */
if (!context->publish_packet) {
PyErr_SetString(
PyExc_RuntimeError,
"acquire_puback_control() must be called within the on_publish_callback_fn callback "
"and may only be called once.");
return NULL;
}

uint64_t puback_control_id = aws_mqtt5_client_acquire_puback(context->client, context->publish_packet);

/* Zero out the publish_packet pointer to prevent double-calls. */
context->publish_packet = NULL;

/* Create handle struct */
struct puback_control_handle *handle =
aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct puback_control_handle));

handle->control_id = puback_control_id;

/* Wrap in capsule */
PyObject *capsule = PyCapsule_New(handle, s_capsule_name_puback_control_handle, s_puback_control_handle_destructor);
if (!capsule) {
aws_mem_release(aws_py_get_allocator(), handle);
return NULL;
}

return capsule;
}

static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *publish_packet, void *user_data) {

if (!user_data) {
Expand All @@ -234,10 +301,50 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
PyObject *result = NULL;
PyObject *subscription_identifier_list = NULL;
PyObject *user_properties_list = NULL;
PyObject *manual_control_callback = NULL;
PyObject *control_context_capsule = NULL;

size_t subscription_identifier_count = publish_packet->subscription_identifier_count;
size_t user_property_count = publish_packet->user_property_count;

/* Create manual PUBACK control context */
struct manual_puback_control_context *control_context =
aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct manual_puback_control_context));
if (!control_context) {
PyErr_WriteUnraisable(PyErr_Occurred());
goto cleanup;
}

/* Set up the context with both client and publish packet */
control_context->client = client->native;
control_context->publish_packet = (struct aws_mqtt5_packet_publish_view *)publish_packet;

control_context_capsule =
PyCapsule_New(control_context, "manual_puback_control_context", s_manual_puback_control_context_destructor);
if (!control_context_capsule) {
aws_mem_release(aws_py_get_allocator(), control_context);
PyErr_WriteUnraisable(PyErr_Occurred());
goto cleanup;
}

/* Method definition for the manual control callback */
static PyMethodDef method_def = {
"acquire_puback_control",
aws_py_mqtt5_client_acquire_puback,
METH_NOARGS,
"Take manual control of PUBACK for this message"};

/* Only create the manual control callback for QoS 1 messages.
* For QoS 0, acquire_puback_control is passed as None.
* acquirePubackControl is only set for QoS 1). */
if (publish_packet->qos == AWS_MQTT5_QOS_AT_LEAST_ONCE) {
manual_control_callback = PyCFunction_New(&method_def, control_context_capsule);
if (!manual_control_callback) {
PyErr_WriteUnraisable(PyErr_Occurred());
goto cleanup;
}
}

/* Create list of uint32_t subscription identifier tuples */
subscription_identifier_list = PyList_New(subscription_identifier_count);
if (!subscription_identifier_list) {
Expand All @@ -261,7 +368,7 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
result = PyObject_CallMethod(
client->client_core,
"_on_publish",
"(y#iOs#OiOIOHs#y#Os#O)",
"(y#iOs#OiOIOHs#y#Os#OO)",
/* y */ publish_packet->payload.ptr,
/* # */ publish_packet->payload.len,
/* i */ (int)publish_packet->qos,
Expand All @@ -284,15 +391,23 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
/* O */ subscription_identifier_count > 0 ? subscription_identifier_list : Py_None,
/* s */ publish_packet->content_type ? publish_packet->content_type->ptr : NULL,
/* # */ publish_packet->content_type ? publish_packet->content_type->len : 0,
/* O */ user_property_count > 0 ? user_properties_list : Py_None);
/* O */ user_property_count > 0 ? user_properties_list : Py_None,
/* O */ manual_control_callback ? manual_control_callback : Py_None);

if (!result) {
PyErr_WriteUnraisable(PyErr_Occurred());
}

/* Invalidate the publish_packet pointer now that the callback has returned.
* This prevents use-after-free if acquire_puback_control() is called after the callback. */
control_context->publish_packet = NULL;

cleanup:
Py_XDECREF(result);
Py_XDECREF(subscription_identifier_list);
Py_XDECREF(user_properties_list);
Py_XDECREF(manual_control_callback);
Py_XDECREF(control_context_capsule);
PyGILState_Release(state);
}

Expand Down Expand Up @@ -1683,6 +1798,49 @@ PyObject *aws_py_mqtt5_client_publish(PyObject *self, PyObject *args) {
return NULL;
}

/*******************************************************************************
* Invoke Puback
******************************************************************************/

PyObject *aws_py_mqtt5_client_invoke_puback(PyObject *self, PyObject *args) {
(void)self;
bool success = true;

PyObject *impl_capsule;
PyObject *puback_handle_capsule;

if (!PyArg_ParseTuple(
args,
"OO",
/* O */ &impl_capsule,
/* O */ &puback_handle_capsule)) {
return NULL;
}

struct mqtt5_client_binding *client = PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt5_client);
if (!client) {
return NULL;
}

/* Extract handle from capsule */
struct puback_control_handle *handle =
PyCapsule_GetPointer(puback_handle_capsule, s_capsule_name_puback_control_handle);
if (!handle) {
PyErr_SetString(PyExc_TypeError, "Invalid PUBACK control handle");
return NULL;
}

if (aws_mqtt5_client_invoke_puback(client->native, handle->control_id, NULL)) {
PyErr_SetAwsLastError();
success = false;
}

if (success) {
Py_RETURN_NONE;
}
return NULL;
}

/*******************************************************************************
* Subscribe
******************************************************************************/
Expand Down
1 change: 1 addition & 0 deletions source/mqtt5_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ PyObject *aws_py_mqtt5_client_publish(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt5_client_subscribe(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt5_client_unsubscribe(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt5_client_get_stats(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt5_client_invoke_puback(PyObject *self, PyObject *args);

PyObject *aws_py_mqtt5_ws_handshake_transform_complete(PyObject *self, PyObject *args);

Expand Down
Loading
Loading