Skip to content

Commit d09bed0

Browse files
committed
fix: UB around SpringSinkRow
1 parent 98bc2ce commit d09bed0

File tree

3 files changed

+46
-51
lines changed

3 files changed

+46
-51
lines changed

springql.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ typedef struct SpringConfig SpringConfig;
4949
*/
5050
typedef struct SpringPipeline SpringPipeline;
5151

52+
/**
53+
* Row object to pop from an in memory queue.
54+
*/
55+
typedef struct SpringSinkRow SpringSinkRow;
56+
5257
/**
5358
* Row object to push into an in memory queue.
5459
*/
@@ -59,11 +64,6 @@ typedef struct SpringSourceRow SpringSourceRow;
5964
*/
6065
typedef struct SpringSourceRowBuilder SpringSourceRowBuilder;
6166

62-
/**
63-
* Row object to pop from an in memory queue.
64-
*/
65-
typedef void *SpringSinkRow;
66-
6767
/**
6868
* Returns default configuration.
6969
*
@@ -156,7 +156,7 @@ enum SpringErrno spring_command(const struct SpringPipeline *pipeline, const cha
156156
*
157157
* - `Unavailable`: queue named `queue` does not exist.
158158
*/
159-
SpringSinkRow *spring_pop(const struct SpringPipeline *pipeline, const char *queue);
159+
struct SpringSinkRow *spring_pop(const struct SpringPipeline *pipeline, const char *queue);
160160

161161
/**
162162
* Pop a row from an in memory queue. This is a non-blocking function.
@@ -170,9 +170,9 @@ SpringSinkRow *spring_pop(const struct SpringPipeline *pipeline, const char *que
170170
*
171171
* - `Unavailable`: queue named `queue` does not exist.
172172
*/
173-
SpringSinkRow *spring_pop_non_blocking(const struct SpringPipeline *pipeline,
174-
const char *queue,
175-
bool *is_err);
173+
struct SpringSinkRow *spring_pop_non_blocking(const struct SpringPipeline *pipeline,
174+
const char *queue,
175+
bool *is_err);
176176

177177
/**
178178
* Push a row into an in memory queue. This is a non-blocking function.
@@ -256,7 +256,7 @@ struct SpringSourceRow *spring_source_row_build(struct SpringSourceRowBuilder *b
256256
* - `Ok`: on success.
257257
* - `CNull`: `pipeline` is a NULL pointer.
258258
*/
259-
enum SpringErrno spring_sink_row_close(SpringSinkRow *row);
259+
enum SpringErrno spring_sink_row_close(struct SpringSinkRow *row);
260260

261261
/**
262262
* Get a 2-byte integer column.
@@ -275,7 +275,7 @@ enum SpringErrno spring_sink_row_close(SpringSinkRow *row);
275275
* - `i_col` is out of range.
276276
* - `CNull`: Column value is NULL.
277277
*/
278-
enum SpringErrno spring_column_short(const SpringSinkRow *row, uint16_t i_col, short *out);
278+
enum SpringErrno spring_column_short(const struct SpringSinkRow *row, uint16_t i_col, short *out);
279279

280280
/**
281281
* Get a 4-byte integer column.
@@ -294,7 +294,7 @@ enum SpringErrno spring_column_short(const SpringSinkRow *row, uint16_t i_col, s
294294
* - `i_col` is out of range.
295295
* - `CNull`: Column value is NULL.
296296
*/
297-
enum SpringErrno spring_column_int(const SpringSinkRow *row, uint16_t i_col, int *out);
297+
enum SpringErrno spring_column_int(const struct SpringSinkRow *row, uint16_t i_col, int *out);
298298

299299
/**
300300
* Get an 8-byte integer column.
@@ -313,7 +313,7 @@ enum SpringErrno spring_column_int(const SpringSinkRow *row, uint16_t i_col, int
313313
* - `i_col` is out of range.
314314
* - `CNull`: Column value is NULL.
315315
*/
316-
enum SpringErrno spring_column_long(const SpringSinkRow *row, uint16_t i_col, long *out);
316+
enum SpringErrno spring_column_long(const struct SpringSinkRow *row, uint16_t i_col, long *out);
317317

318318
/**
319319
* Get a 4-byte unsigned integer column.
@@ -332,7 +332,7 @@ enum SpringErrno spring_column_long(const SpringSinkRow *row, uint16_t i_col, lo
332332
* - `i_col` is out of range.
333333
* - `CNull`: Column value is NULL.
334334
*/
335-
enum SpringErrno spring_column_unsigned_int(const SpringSinkRow *row,
335+
enum SpringErrno spring_column_unsigned_int(const struct SpringSinkRow *row,
336336
uint16_t i_col,
337337
unsigned int *out);
338338

@@ -354,7 +354,7 @@ enum SpringErrno spring_column_unsigned_int(const SpringSinkRow *row,
354354
* - `i_col` is out of range.
355355
* - `CNull`: Column value is NULL.
356356
*/
357-
int spring_column_text(const SpringSinkRow *row, uint16_t i_col, char *out, int out_len);
357+
int spring_column_text(const struct SpringSinkRow *row, uint16_t i_col, char *out, int out_len);
358358

359359
/**
360360
* Get a BLOB column.
@@ -374,7 +374,7 @@ int spring_column_text(const SpringSinkRow *row, uint16_t i_col, char *out, int
374374
* - `i_col` is out of range.
375375
* - `CNull`: Column value is NULL.
376376
*/
377-
int spring_column_blob(const SpringSinkRow *row, uint16_t i_col, void *out, int out_len);
377+
int spring_column_blob(const struct SpringSinkRow *row, uint16_t i_col, void *out, int out_len);
378378

379379
/**
380380
* Get a bool column.
@@ -393,7 +393,7 @@ int spring_column_blob(const SpringSinkRow *row, uint16_t i_col, void *out, int
393393
* - `i_col` is out of range.
394394
* - `CNull`: Column value is NULL.
395395
*/
396-
enum SpringErrno spring_column_bool(const SpringSinkRow *row, uint16_t i_col, bool *out);
396+
enum SpringErrno spring_column_bool(const struct SpringSinkRow *row, uint16_t i_col, bool *out);
397397

398398
/**
399399
* Get a 4-byte floating point column.
@@ -412,7 +412,7 @@ enum SpringErrno spring_column_bool(const SpringSinkRow *row, uint16_t i_col, bo
412412
* - `i_col` is out of range.
413413
* - `CNull`: Column value is NULL.
414414
*/
415-
enum SpringErrno spring_column_float(const SpringSinkRow *row, uint16_t i_col, float *out);
415+
enum SpringErrno spring_column_float(const struct SpringSinkRow *row, uint16_t i_col, float *out);
416416

417417
/**
418418
* Write the most recent error number into `errno_` and message into a caller-provided buffer as a UTF-8

src/lib.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ pub unsafe extern "C" fn spring_pop(
176176
let queue = CStr::from_ptr(queue).to_string_lossy().into_owned();
177177
let result = with_catch(|| ru_pipeline.pop(&queue));
178178
match result {
179-
Ok(row) => {
180-
let row = SpringSinkRow::new(row);
179+
Ok(ru_row) => {
180+
let row = SpringSinkRow::from(ru_row);
181181
row.into_ptr()
182182
}
183183
Err(_) => ptr::null_mut(),
@@ -205,9 +205,8 @@ pub unsafe extern "C" fn spring_pop_non_blocking(
205205
let result = with_catch(|| ru_pipeline.pop_non_blocking(&queue));
206206
match result {
207207
Ok(Some(row)) => {
208-
let ptr = SpringSinkRow::new(row);
209208
*is_err = false;
210-
ptr.into_ptr()
209+
SpringSinkRow::from(row).into_ptr()
211210
}
212211
Ok(None) => {
213212
*is_err = false;
@@ -223,7 +222,7 @@ pub unsafe extern "C" fn spring_pop_non_blocking(
223222
/// Push a row into an in memory queue. This is a non-blocking function.
224223
///
225224
/// `row` is freed internally.
226-
///
225+
///
227226
/// # Returns
228227
///
229228
/// - `Ok`: on success.
@@ -338,11 +337,11 @@ pub unsafe extern "C" fn spring_source_row_build(
338337
/// - `Ok`: on success.
339338
/// - `CNull`: `pipeline` is a NULL pointer.
340339
#[no_mangle]
341-
pub extern "C" fn spring_sink_row_close(row: *mut SpringSinkRow) -> SpringErrno {
340+
pub unsafe extern "C" fn spring_sink_row_close(row: *mut SpringSinkRow) -> SpringErrno {
342341
if row.is_null() {
343342
SpringErrno::CNull
344343
} else {
345-
SpringSinkRow::drop(row);
344+
let _ = Box::from_raw(row);
346345
SpringErrno::Ok
347346
}
348347
}
@@ -368,7 +367,7 @@ pub unsafe extern "C" fn spring_column_short(
368367
i_col: u16,
369368
out: *mut c_short,
370369
) -> SpringErrno {
371-
let row = (*row).as_row();
370+
let row = &*row;
372371
let i_col = i_col as usize;
373372
let result = with_catch(|| row.get_not_null_by_index(i_col as usize));
374373
match result {
@@ -401,7 +400,7 @@ pub unsafe extern "C" fn spring_column_int(
401400
i_col: u16,
402401
out: *mut c_int,
403402
) -> SpringErrno {
404-
let row = (*row).as_row();
403+
let row = &*row;
405404
let i_col = i_col as usize;
406405
let result = with_catch(|| row.get_not_null_by_index(i_col as usize));
407406
match result {
@@ -434,7 +433,7 @@ pub unsafe extern "C" fn spring_column_long(
434433
i_col: u16,
435434
out: *mut c_long,
436435
) -> SpringErrno {
437-
let row = (*row).as_row();
436+
let row = &*row;
438437
let i_col = i_col as usize;
439438
let result = with_catch(|| row.get_not_null_by_index(i_col as usize));
440439
match result {
@@ -467,7 +466,7 @@ pub unsafe extern "C" fn spring_column_unsigned_int(
467466
i_col: u16,
468467
out: *mut c_uint,
469468
) -> SpringErrno {
470-
let row = (*row).as_row();
469+
let row = &*row;
471470
let i_col = i_col as usize;
472471
let result = with_catch(|| row.get_not_null_by_index(i_col as usize));
473472
match result {
@@ -502,7 +501,7 @@ pub unsafe extern "C" fn spring_column_text(
502501
out: *mut c_char,
503502
out_len: c_int,
504503
) -> c_int {
505-
let row = (*row).as_row();
504+
let row = &*row;
506505
let i_col = i_col as usize;
507506
let result: Result<String, SpringErrno> =
508507
with_catch(|| row.get_not_null_by_index(i_col as usize));
@@ -538,7 +537,7 @@ pub unsafe extern "C" fn spring_column_blob(
538537
out: *mut c_void,
539538
out_len: c_int,
540539
) -> c_int {
541-
let row = (*row).as_row();
540+
let row = &*row;
542541
let i_col = i_col as usize;
543542
let result: Result<Vec<u8>, SpringErrno> =
544543
with_catch(|| row.get_not_null_by_index(i_col as usize));
@@ -572,7 +571,7 @@ pub unsafe extern "C" fn spring_column_bool(
572571
i_col: u16,
573572
out: *mut bool,
574573
) -> SpringErrno {
575-
let row = (*row).as_row();
574+
let row = &*row;
576575
let i_col = i_col as usize;
577576
let result = with_catch(|| row.get_not_null_by_index(i_col as usize));
578577
match result {
@@ -605,7 +604,7 @@ pub unsafe extern "C" fn spring_column_float(
605604
i_col: u16,
606605
out: *mut c_float,
607606
) -> SpringErrno {
608-
let row = (*row).as_row();
607+
let row = &*row;
609608
let i_col = i_col as usize;
610609
let result = with_catch(|| row.get_not_null_by_index(i_col as usize));
611610
match result {

src/spring_sink_row.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,27 @@
11
// This file is part of https://github.com/SpringQL/SpringQL-client-c which is licensed under MIT OR Apache-2.0. See file LICENSE-MIT or LICENSE-APACHE for full license details.
22

3-
use ::springql::SpringSinkRow as SinkRow;
4-
5-
use std::{ffi::c_void, mem};
3+
use springql::{Result, SpringSinkRow as RuSpringSinkRow, SpringValue};
64

75
/// Row object to pop from an in memory queue.
86
#[non_exhaustive]
9-
#[repr(transparent)]
10-
pub struct SpringSinkRow(*mut c_void);
11-
12-
impl SpringSinkRow {
13-
pub fn new(sink_row: SinkRow) -> Self {
14-
SpringSinkRow(unsafe { mem::transmute(Box::new(sink_row)) })
15-
}
7+
#[derive(Debug)]
8+
pub struct SpringSinkRow(RuSpringSinkRow);
169

17-
pub fn as_row(&self) -> &SinkRow {
18-
unsafe { &*(self.0 as *const SinkRow) }
10+
impl From<RuSpringSinkRow> for SpringSinkRow {
11+
fn from(sink_row: RuSpringSinkRow) -> Self {
12+
SpringSinkRow(sink_row)
1913
}
14+
}
2015

21-
pub fn drop(ptr: *mut SpringSinkRow) {
22-
let outer = unsafe { Box::from_raw(ptr) };
23-
let inner = unsafe { Box::from_raw(outer.0) };
24-
drop(inner);
25-
drop(outer);
16+
impl SpringSinkRow {
17+
pub(crate) fn get_not_null_by_index<T>(&self, i_col: usize) -> Result<T>
18+
where
19+
T: SpringValue,
20+
{
21+
self.0.get_not_null_by_index(i_col)
2622
}
2723

28-
pub fn into_ptr(self) -> *mut SpringSinkRow {
24+
pub(crate) fn into_ptr(self) -> *mut Self {
2925
Box::into_raw(Box::new(self))
3026
}
3127
}

0 commit comments

Comments
 (0)