Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent crash COM_STMT_EXECUTE and invalid stmt_id #4481

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions lib/MySQL_Session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3790,10 +3790,13 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
CurrentQuery.stmt_client_id=client_stmt_id;
stmt_global_id=client_myds->myconn->local_stmts->find_global_stmt_id_from_client(client_stmt_id);
if (stmt_global_id == 0) {
// FIXME: add error handling
// LCOV_EXCL_START
assert(0);
// LCOV_EXCL_STOP
l_free(pkt.size,pkt.ptr);
client_myds->setDSS_STATE_QUERY_SENT_NET();
string err_msg = "Unknown prepared statement handler (" + to_string(client_stmt_id) + ") given to mysql_stmt_precheck";
client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,1243,(char *)"HY000", err_msg.c_str());
client_myds->DSS=STATE_SLEEP;
status=WAITING_CLIENT_DATA;
return;
}
CurrentQuery.stmt_global_id=stmt_global_id;
// now we get the statement information
Expand All @@ -3803,7 +3806,8 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
// we couldn't find it
l_free(pkt.size,pkt.ptr);
client_myds->setDSS_STATE_QUERY_SENT_NET();
client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,1045,(char *)"28000",(char *)"Prepared statement doesn't exist", true);
string err_msg = "Unknown prepared statement handler (" + to_string(client_stmt_id) + ") given to mysql_stmt_precheck";
client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,1243,(char *)"HY000", err_msg.c_str());
client_myds->DSS=STATE_SLEEP;
status=WAITING_CLIENT_DATA;
return;
Expand Down
89 changes: 89 additions & 0 deletions test/tap/tests/reg_test_3371_prepared_statement_crash-t.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/**
* @file reg_test_3371_prepared_statement_crash-t.cpp
* @brief Tries to execute prepared statements with a not existing stmt_id.
* This used to crash ProxySQL , so this tap test verifies that ProxySQL
* doesn't crash
*/


#include <string>
#include <stdio.h>
#include <cstring>
#include <unistd.h>

#include "mysql.h"

#include "tap.h"
#include "command_line.h"
#include "utils.h"

using std::string;

const int NUM_LOOPS = 100; ///< Number of loops for statement execution.

int main(int argc, char** argv) {
CommandLine cl;

// Checking for required environmental variables
if (cl.getEnv()) {
diag("Failed to get the required environmental variables.");
return -1;
}

plan(1+NUM_LOOPS*2); // Plan for testing purposes

MYSQL* mysql = mysql_init(NULL); ///< MySQL connection object
if (!mysql) {
fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(mysql));
return exit_status();
}

// Connecting to ProxySQL
diag("Connecting to '%s@%s:%d'", cl.mysql_username, cl.mysql_host, cl.port);
if (!mysql_real_connect(mysql, cl.mysql_host, cl.mysql_username, cl.mysql_password, NULL, cl.port, NULL, 0)) {
fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(mysql));
return exit_status();
}

// Initialize and prepare all the statements
MYSQL_STMT* stmt = mysql_stmt_init(mysql);
if (!stmt) {
fprintf(stderr, "mysql_stmt_init(), out of memory\n");
return exit_status();
}

std::string select_query = "SELECT 1";
diag("select_query: %s", select_query.c_str());
if (mysql_stmt_prepare(stmt, select_query.c_str(), strlen(select_query.c_str()))) {
fprintf(stderr, "mysql_stmt_prepare at line %d failed: %s\n", __LINE__ , mysql_error(mysql));
mysql_close(mysql);
mysql_library_end();
return exit_status();
}

diag("Increasing stmt_id by 1, so that mysql_stmt_execute() must fail");
int rc = 0;
for (int i = 0; i < NUM_LOOPS ; i++) {
stmt->stmt_id += 1;
rc = mysql_stmt_execute(stmt);
ok (rc , "mysql_stmt_execute() must fail");
if (rc) {
unsigned int psrc = mysql_stmt_errno(stmt);
ok( psrc == 1243 , "mysql_stmt_execute at line %d failed: %d , %s", __LINE__ , psrc , mysql_stmt_error(stmt));
}
}

diag("Decreasing stmt_id by 1, so that mysql_stmt_execute() must succeed");
stmt->stmt_id -= NUM_LOOPS;
rc = mysql_stmt_execute(stmt);
ok (rc == 0 , "mysql_stmt_execute() succeeded");
if (rc) {
fprintf(stderr, "mysql_stmt_execute at line %d failed: %d , %s\n", __LINE__ , rc , mysql_stmt_error(stmt));
mysql_close(mysql);
mysql_library_end();
return exit_status();
}
mysql_close(mysql);

return exit_status();
}
Loading