From 9b24fdf4b832ea7b40b2facd2587ae6efec676cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Canna=C3=B2?= Date: Tue, 26 Mar 2024 21:47:52 +0000 Subject: [PATCH] Prevent crash COM_STMT_EXECUTE and invalid stmt_id If an invalid stmt_id is passed to COM_STMT_EXECUTE , ProxySQL used to crash due to an assert(). It seems that some buggy clients execute COM_STMT_EXECUTE with an invalid stmt_id after an auto-reconnect. This commit returns an error to the client. Closes #3371 Closes #3808 Closes #4474 --- lib/MySQL_Session.cpp | 14 +-- ...g_test_3371_prepared_statement_crash-t.cpp | 89 +++++++++++++++++++ 2 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 test/tap/tests/reg_test_3371_prepared_statement_crash-t.cpp diff --git a/lib/MySQL_Session.cpp b/lib/MySQL_Session.cpp index d84e1886ba..1dc1c19436 100644 --- a/lib/MySQL_Session.cpp +++ b/lib/MySQL_Session.cpp @@ -3790,10 +3790,13 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C CurrentQuery.stmt_client_id=client_stmt_id; stmt_global_id=client_myds->myconn->local_stmts->find_global_stmt_id_from_client(client_stmt_id); if (stmt_global_id == 0) { - // FIXME: add error handling - // LCOV_EXCL_START - assert(0); - // LCOV_EXCL_STOP + l_free(pkt.size,pkt.ptr); + client_myds->setDSS_STATE_QUERY_SENT_NET(); + string err_msg = "Unknown prepared statement handler (" + to_string(client_stmt_id) + ") given to mysql_stmt_precheck"; + client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,1243,(char *)"HY000", err_msg.c_str()); + client_myds->DSS=STATE_SLEEP; + status=WAITING_CLIENT_DATA; + return; } CurrentQuery.stmt_global_id=stmt_global_id; // now we get the statement information @@ -3803,7 +3806,8 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C // we couldn't find it l_free(pkt.size,pkt.ptr); client_myds->setDSS_STATE_QUERY_SENT_NET(); - client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,1045,(char *)"28000",(char *)"Prepared statement doesn't exist", true); + string err_msg = "Unknown prepared statement handler (" + to_string(client_stmt_id) + ") given to mysql_stmt_precheck"; + client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,1243,(char *)"HY000", err_msg.c_str()); client_myds->DSS=STATE_SLEEP; status=WAITING_CLIENT_DATA; return; diff --git a/test/tap/tests/reg_test_3371_prepared_statement_crash-t.cpp b/test/tap/tests/reg_test_3371_prepared_statement_crash-t.cpp new file mode 100644 index 0000000000..0f929cba53 --- /dev/null +++ b/test/tap/tests/reg_test_3371_prepared_statement_crash-t.cpp @@ -0,0 +1,89 @@ +/** + * @file reg_test_3371_prepared_statement_crash-t.cpp + * @brief Tries to execute prepared statements with a not existing stmt_id. + * This used to crash ProxySQL , so this tap test verifies that ProxySQL + * doesn't crash + */ + + +#include +#include +#include +#include + +#include "mysql.h" + +#include "tap.h" +#include "command_line.h" +#include "utils.h" + +using std::string; + +const int NUM_LOOPS = 100; ///< Number of loops for statement execution. + +int main(int argc, char** argv) { + CommandLine cl; + + // Checking for required environmental variables + if (cl.getEnv()) { + diag("Failed to get the required environmental variables."); + return -1; + } + + plan(1+NUM_LOOPS*2); // Plan for testing purposes + + MYSQL* mysql = mysql_init(NULL); ///< MySQL connection object + if (!mysql) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(mysql)); + return exit_status(); + } + + // Connecting to ProxySQL + diag("Connecting to '%s@%s:%d'", cl.mysql_username, cl.mysql_host, cl.port); + if (!mysql_real_connect(mysql, cl.mysql_host, cl.mysql_username, cl.mysql_password, NULL, cl.port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(mysql)); + return exit_status(); + } + + // Initialize and prepare all the statements + MYSQL_STMT* stmt = mysql_stmt_init(mysql); + if (!stmt) { + fprintf(stderr, "mysql_stmt_init(), out of memory\n"); + return exit_status(); + } + + std::string select_query = "SELECT 1"; + diag("select_query: %s", select_query.c_str()); + if (mysql_stmt_prepare(stmt, select_query.c_str(), strlen(select_query.c_str()))) { + fprintf(stderr, "mysql_stmt_prepare at line %d failed: %s\n", __LINE__ , mysql_error(mysql)); + mysql_close(mysql); + mysql_library_end(); + return exit_status(); + } + + diag("Increasing stmt_id by 1, so that mysql_stmt_execute() must fail"); + int rc = 0; + for (int i = 0; i < NUM_LOOPS ; i++) { + stmt->stmt_id += 1; + rc = mysql_stmt_execute(stmt); + ok (rc , "mysql_stmt_execute() must fail"); + if (rc) { + unsigned int psrc = mysql_stmt_errno(stmt); + ok( psrc == 1243 , "mysql_stmt_execute at line %d failed: %d , %s", __LINE__ , psrc , mysql_stmt_error(stmt)); + } + } + + diag("Decreasing stmt_id by 1, so that mysql_stmt_execute() must succeed"); + stmt->stmt_id -= NUM_LOOPS; + rc = mysql_stmt_execute(stmt); + ok (rc == 0 , "mysql_stmt_execute() succeeded"); + if (rc) { + fprintf(stderr, "mysql_stmt_execute at line %d failed: %d , %s\n", __LINE__ , rc , mysql_stmt_error(stmt)); + mysql_close(mysql); + mysql_library_end(); + return exit_status(); + } + mysql_close(mysql); + + return exit_status(); +}