From 4b68a143ce1395e85138b208fded8e773572300d Mon Sep 17 00:00:00 2001 From: Michael Graeb Date: Tue, 24 Sep 2024 16:00:04 -0700 Subject: [PATCH] Return NULL instead of Py_NONE (so it works better with Py_XDECREF) AWS_FATAL_ASSERT if these functions raise a python exception (which we never checked for before) --- source/module.c | 39 ++++++++++++++++++++++++--------- source/module.h | 18 ++++++++++++++- source/mqtt_client_connection.c | 35 +++++++++++++++-------------- 3 files changed, 64 insertions(+), 28 deletions(-) diff --git a/source/module.c b/source/module.c index 23d77cda9..7aa0502b1 100644 --- a/source/module.c +++ b/source/module.c @@ -516,18 +516,37 @@ PyObject *aws_py_memory_view_from_byte_buffer(struct aws_byte_buf *buf) { return PyMemoryView_FromMemory(mem_start, mem_size, PyBUF_WRITE); } -PyObject *aws_py_weakref_get_ref(PyObject *object) { - PyObject *self = Py_None; -#if PY_VERSION_HEX >= 0x030D0000 /* Check if Python version is 3.13 or higher */ - if (PyWeakref_GetRef(object, &self) < 0) { /* strong reference */ - return Py_None; - } +PyObject *aws_py_weakref_get_ref(PyObject *ref) { + /* If Python >= 3.13 */ +#if PY_VERSION_HEX >= 0x030D0000 + /* Use PyWeakref_GetRef() (new in Python 3.13), which gets you: + /* a new strong reference, + * or NULL because ref is dead, + * or -1 because you called it wrong */ + PyObject *obj = NULL; + if (PyWeakref_GetRef(ref, &obj) == -1) { + PyErr_WriteUnraisable(PyErr_Occurred()); + AWS_FATAL_ASSERT(0 && "expected a weakref"); + } + return obj; + #else - /* PyWeakref_GetObject is deprecated since python 3.13 */ - self = PyWeakref_GetObject(object); /* borrowed reference */ - Py_XINCREF(self); + /* Use PyWeakref_GetObject() (deprecated as of Python 3.13), which gets you: + * a borrowed reference, + * or Py_None because ref is dead, + * or NULL because you called it wrong */ + PyObject *obj = PyWeakref_GetObject(ref); /* borrowed reference */ + if (obj == NULL) { + PyErr_WriteUnraisable(PyErr_Occurred()); + AWS_FATAL_ASSERT(0 && "expected a weakref"); + } else if (obj == Py_None) { + return NULL; + } else { + /* Be like PyWeakref_GetRef() and make it new strong reference */ + Py_INCREF(obj); + return obj; + } #endif - return self; } int aws_py_gilstate_ensure(PyGILState_STATE *out_state) { diff --git a/source/module.h b/source/module.h index 7777ab8bc..ca73f2a5b 100644 --- a/source/module.h +++ b/source/module.h @@ -108,7 +108,23 @@ PyObject *aws_py_memory_view_from_byte_buffer(struct aws_byte_buf *buf); /* Python 3.13+ changed the function to get a reference from WeakRef. This function is an abstraction over two different * APIs since we support Python versions before 3.13. Returns a strong reference if non-null, which you must release. */ -PyObject *aws_py_weakref_get_ref(PyObject *object); + +/** + * Given a weak reference, returns a NEW strong reference to the referenced object, + * or NULL if the reference is dead (this function NEVER raises a python exception or AWS Error). + * + * This is a simplified version of PyWeakref_GetRef() / PyWeakref_GetObject(). + * Simpler because: + * - Python 3.13 adds PyWeakref_GetRef() and deprecates PyWeakref_GetObject(). + * This function calls the appropriate one. + * + * - This will AWS_FATAL_ASSERT if ref is not a weak reference, + * So you only need to handle 2 outcomes instead of 3 + * (the 3rd being a Python exception for calling it incorrectly). + * But this means it's only safe to call if we created the ref ourselves. + * Do not call if ref could have come from the user. + */ +PyObject *aws_py_weakref_get_ref(PyObject *ref); /* Allocator that calls into PyObject_[Malloc|Free|Realloc] */ struct aws_allocator *aws_py_get_allocator(void); diff --git a/source/mqtt_client_connection.c b/source/mqtt_client_connection.c index fe629a959..78a26057e 100644 --- a/source/mqtt_client_connection.c +++ b/source/mqtt_client_connection.c @@ -140,8 +140,8 @@ static void s_on_connection_success( return; /* Python has shut down. Nothing matters anymore, but don't crash */ } - PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); - if (self != Py_None) { + PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); /* new reference */ + if (self != NULL) { PyObject *success_result = PyObject_CallMethod(self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present)); if (success_result) { @@ -149,9 +149,9 @@ static void s_on_connection_success( } else { PyErr_WriteUnraisable(PyErr_Occurred()); } + Py_DECREF(self); } - Py_XDECREF(self); PyGILState_Release(state); } @@ -168,17 +168,17 @@ static void s_on_connection_failure(struct aws_mqtt_client_connection *connectio return; /* Python has shut down. Nothing matters anymore, but don't crash */ } - PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); - if (self != Py_None) { + PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); /* new reference */ + if (self != NULL) { PyObject *success_result = PyObject_CallMethod(self, "_on_connection_failure", "(i)", error_code); if (success_result) { Py_DECREF(success_result); } else { PyErr_WriteUnraisable(PyErr_Occurred()); } + Py_DECREF(self); } - Py_XDECREF(self); PyGILState_Release(state); } @@ -196,17 +196,17 @@ static void s_on_connection_interrupted(struct aws_mqtt_client_connection *conne } /* Ensure that python class is still alive */ - PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); - if (self != Py_None) { + PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); /* new reference */ + if (self != NULL) { PyObject *result = PyObject_CallMethod(self, "_on_connection_interrupted", "(i)", error_code); if (result) { Py_DECREF(result); } else { PyErr_WriteUnraisable(PyErr_Occurred()); } + Py_DECREF(self); } - Py_XDECREF(self); PyGILState_Release(state); } @@ -230,8 +230,8 @@ static void s_on_connection_resumed( } /* Ensure that python class is still alive */ - PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); - if (self != Py_None) { + PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); /* new reference */ + if (self != NULL) { PyObject *result = PyObject_CallMethod(self, "_on_connection_resumed", "(iN)", return_code, PyBool_FromLong(session_present)); if (result) { @@ -239,8 +239,9 @@ static void s_on_connection_resumed( } else { PyErr_WriteUnraisable(PyErr_Occurred()); } + Py_DECREF(self); } - Py_XDECREF(self); + PyGILState_Release(state); } @@ -261,17 +262,17 @@ static void s_on_connection_closed( struct mqtt_connection_binding *py_connection = userdata; /* Ensure that python class is still alive */ - PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); - if (self != Py_None) { + PyObject *self = aws_py_weakref_get_ref(py_connection->self_proxy); /* new reference */ + if (self != NULL) { PyObject *result = PyObject_CallMethod(self, "_on_connection_closed", "()"); if (result) { Py_DECREF(result); } else { PyErr_WriteUnraisable(PyErr_Occurred()); } + Py_DECREF(self); } - Py_XDECREF(self); PyGILState_Release(state); } @@ -540,8 +541,8 @@ static void s_ws_handshake_transform( /* Ensure python mqtt connection object is still alive */ - PyObject *connection_py = aws_py_weakref_get_ref(connection_binding->self_proxy); - if (connection_py == Py_None) { + PyObject *connection_py = aws_py_weakref_get_ref(connection_binding->self_proxy); /* new reference */ + if (connection_py == NULL) { aws_raise_error(AWS_ERROR_INVALID_STATE); goto done; }