Skip to content

Commit

Permalink
Merge pull request #320 from cb1kenobi/shared-buf-refs-fix
Browse files Browse the repository at this point in the history
Store copy of default buffer to shared user buffer
  • Loading branch information
kriszyp authored Jan 4, 2025
2 parents 0168261 + c4bc187 commit b69091f
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 114 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Tests

on: [pull_request]

concurrency:
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
group: ${{ github.workflow }}-${{ github.ref }}

jobs:
test:
name: Test on Node.js ${{ matrix.node }} and ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
node: [16, 18, 20, 22]
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Setup node
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node }}

- name: Install dependencies
if: |
!(matrix.node == 16 && matrix.os == 'macos-latest') || (matrix.node == 22 && matrix.os == 'windows-latest'))
run: npm install

- name: Build
if: |
!(matrix.node == 16 && matrix.os == 'macos-latest') || (matrix.node == 22 && matrix.os == 'windows-latest'))
run: npm run build

- name: Run tests
if: |
!(matrix.node == 16 && matrix.os == 'macos-latest') || (matrix.node == 22 && matrix.os == 'windows-latest'))
run: npm test
26 changes: 10 additions & 16 deletions read.js
Original file line number Diff line number Diff line change
Expand Up @@ -401,23 +401,17 @@ export function addReadMethods(
keySize = this.writeKey(id, keyBytes, 4);
}
};
let userSharedBuffers =
this._userSharedBuffers || (this._userSharedBuffers = new Map());
let sharedBuffer = userSharedBuffers.get(id.toString());
if (!sharedBuffer) {
setKeyBytes();
let sharedBuffer = getUserSharedBuffer(
env.address,
keySize,
defaultBuffer,
options?.callback,
);
sharedBuffer.notify = () => {
setKeyBytes();
sharedBuffer = getUserSharedBuffer(
env.address,
keySize,
defaultBuffer,
options?.callback,
);
userSharedBuffers.set(id.toString(), sharedBuffer);
sharedBuffer.notify = () => {
setKeyBytes();
return notifyUserCallbacks(env.address, keySize);
};
}
return notifyUserCallbacks(env.address, keySize);
};
return sharedBuffer;
},

Expand Down
40 changes: 26 additions & 14 deletions src/env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1085,14 +1085,13 @@ uint64_t ExtendedEnv::getNextTime() {
uint64_t ExtendedEnv::getLastTime() {
return bswap_64(lastTime);
}

NAPI_FUNCTION(getUserSharedBuffer) {
ARGS(4)
GET_INT64_ARG(0)
EnvWrap* ew = (EnvWrap*) i64;
uint32_t size;
GET_UINT32_ARG(size, 1);
MDB_val default_buffer;
napi_get_arraybuffer_info(env, args[2], &default_buffer.mv_data, &default_buffer.mv_size);
ExtendedEnv* extend_env = (ExtendedEnv*) mdb_env_get_userctx(ew->env);
std::string key(ew->keyBuffer, size);
napi_value as_bool;
Expand All @@ -1101,23 +1100,31 @@ NAPI_FUNCTION(getUserSharedBuffer) {
napi_get_value_bool(env, as_bool, &has_callback);

// get a shared buffer with the key, starting value, and convert pointer to an array buffer
MDB_val buffer = extend_env->getUserSharedBuffer(key, default_buffer, args[3], has_callback, env, ew);
if (buffer.mv_data == default_buffer.mv_data) return args[2];
napi_value return_value;
napi_create_external_arraybuffer(env, buffer.mv_data, buffer.mv_size, cleanupLMDB, buffer.mv_data, &return_value);
return return_value;
napi_value buffer = extend_env->getUserSharedBuffer(key, args[2], args[3], has_callback, env, ew);
return buffer;
}
/*napi_finalize cleanup_callback = [](napi_env env, void* data, void* buffer_info) {
// Data belongs to LMDB, we shouldn't free it here
}*/
MDB_val ExtendedEnv::getUserSharedBuffer(std::string key, MDB_val default_buffer, napi_value func, bool has_callback, napi_env env, EnvWrap* ew) {

napi_value ExtendedEnv::getUserSharedBuffer(std::string key, napi_value default_buffer, napi_value func, bool has_callback, napi_env env, EnvWrap* ew) {
pthread_mutex_lock(&userBuffersLock);

auto resolution = userSharedBuffers.find(key);
if (resolution == userSharedBuffers.end()) {
void* default_buffer_data;
size_t default_buffer_size;
napi_get_arraybuffer_info(env, default_buffer, &default_buffer_data, &default_buffer_size);

char* buffer_data = new char[default_buffer_size];
memcpy(buffer_data, default_buffer_data, default_buffer_size);

MDB_val buffer;
buffer.mv_data = (void*)buffer_data;
buffer.mv_size = default_buffer_size;

user_buffer_t user_shared_buffer;
user_shared_buffer.buffer = default_buffer;
user_shared_buffer.buffer = buffer;
resolution = userSharedBuffers.emplace(key, user_shared_buffer).first;
}

if (has_callback) {
napi_threadsafe_function callback;
napi_value resource;
Expand All @@ -1130,10 +1137,15 @@ MDB_val ExtendedEnv::getUserSharedBuffer(std::string key, MDB_val default_buffer
napi_unref_threadsafe_function(env, callback);
resolution->second.callbacks.push_back(callback);
}
MDB_val buffer = resolution->second.buffer;

napi_value buffer_value;
napi_create_external_arraybuffer(env, resolution->second.buffer.mv_data, resolution->second.buffer.mv_size, nullptr, nullptr, &buffer_value);

pthread_mutex_unlock(&userBuffersLock);
return buffer;

return buffer_value;
}

/**
* Notify the user callbacks associated with a user buffer for a given key
* @param key
Expand Down
2 changes: 1 addition & 1 deletion src/lmdb-js.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ class ExtendedEnv {
pthread_mutex_t userBuffersLock;
uint64_t lastTime; // actually encoded as double
uint64_t previousTime; // actually encoded as double
MDB_val getUserSharedBuffer(std::string key, MDB_val default_buffer, napi_value func, bool has_callback, napi_env env, EnvWrap* ew);
napi_value getUserSharedBuffer(std::string key, napi_value default_buffer, napi_value func, bool has_callback, napi_env env, EnvWrap* ew);
bool notifyUserCallbacks(std::string key);
bool attemptLock(std::string key, napi_env env, napi_value func, bool has_callback, EnvWrap* ew);
bool unlock(std::string key, bool only_check);
Expand Down
12 changes: 6 additions & 6 deletions test/index.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -914,21 +914,21 @@ describe('lmdb-js', function () {
});
it('getUserSharedBuffer', function () {
let defaultIncrementer = new BigInt64Array(1);
defaultIncrementer[0] = 4n;
let incrementer = new BigInt64Array(
db.getUserSharedBuffer('incrementer-test', defaultIncrementer.buffer),
);
incrementer[0] = 4n;
should.equal(Atomics.add(incrementer, 0, 1n), 4n);
let secondDefaultIncrementer = new BigInt64Array(1); //should not get used
incrementer = new BigInt64Array( // should return same incrementer
let nextIncrementer = new BigInt64Array( // should return same incrementer
db.getUserSharedBuffer(
'incrementer-test',
secondDefaultIncrementer.buffer,
),
);
should.equal(defaultIncrementer[0], 5n);
should.equal(Atomics.add(incrementer, 0, 1n), 5n);
should.equal(defaultIncrementer[0], 6n);
should.equal(incrementer[0], 5n);
should.equal(Atomics.add(nextIncrementer, 0, 1n), 5n);
should.equal(incrementer[0], 6n);
should.equal(secondDefaultIncrementer[0], 0n);
});
it('getUserSharedBuffer with callbacks', async function () {
Expand Down Expand Up @@ -2109,4 +2109,4 @@ describe('lmdb-js', function () {

function delay(ms) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
}
181 changes: 104 additions & 77 deletions test/threads.cjs
Original file line number Diff line number Diff line change
@@ -1,95 +1,122 @@
var assert = require('assert');
const { Worker, isMainThread, parentPort, threadId } = require('worker_threads');
const {
Worker,
isMainThread,
parentPort,
threadId,
} = require('worker_threads');
var path = require('path');
var numCPUs = require('os').cpus().length;
const { setFlagsFromString } = require('v8');
const { runInNewContext } = require('vm');

setFlagsFromString('--expose_gc');
const gc = runInNewContext('gc');

const { open } = require('../dist/index.cjs');
const MAX_DB_SIZE = 256 * 1024 * 1024;
if (isMainThread) {
var inspector = require('inspector')
// inspector.open(9331, null, true);debugger
var inspector = require('inspector');
// inspector.open(9331, null, true);debugger

// The main thread
// The main thread

let db = open({
path: path.resolve(__dirname, './testdata'),
maxDbs: 10,
mapSize: MAX_DB_SIZE,
maxReaders: 126,
overlappingSync: true,
});
let db = open({
path: path.resolve(__dirname, './testdata'),
maxDbs: 10,
mapSize: MAX_DB_SIZE,
maxReaders: 126,
overlappingSync: true,
});

var workerCount = Math.min(numCPUs * 2, 20);
var value = {test: '48656c6c6f2c20776f726c6421'};
var str = 'this is supposed to be bigger than 16KB threshold for shared memory buffers';
for (let i = 0; i < 9; i++) {
str += str;
}
var bigValue = {test: str};
// This will start as many workers as there are CPUs available.
var workers = [];
for (var i = 0; i < workerCount; i++) {
var worker = new Worker(__filename);
workers.push(worker);
}
let incrementer = new BigInt64Array(1);
let incrementerBuffer = db.getUserSharedBuffer('test', incrementer.buffer);
incrementer = new BigInt64Array(incrementerBuffer);
incrementer[0] = 10000n;

var messages = [];
workers.forEach(function(worker) {
worker.on('message', function(msg) {
messages.push(msg);
// Once every worker has replied with a response for the value
// we can exit the test.
var workerCount = Math.min(numCPUs * 2, 20);
var value = { test: '48656c6c6f2c20776f726c6421' };
var str =
'this is supposed to be bigger than 16KB threshold for shared memory buffers';
for (let i = 0; i < 9; i++) {
str += str;
}
var bigValue = { test: str };
// This will start as many workers as there are CPUs available.
var workers = [];
for (var i = 0; i < workerCount; i++) {
var worker = new Worker(__filename);
workers.push(worker);
}

setTimeout(() => {
worker.terminate()
}, 100);
if (messages.length === workerCount) {
db.close();
for (var i = 0; i < messages.length; i ++) {
assert(messages[i] === value.toString('hex'));
}
console.log("done", threadId)
//setTimeout(() =>
//process.exit(0), 200);
}
});
});
var messages = [];
workers.forEach(function (worker) {
worker.on('message', function (msg) {
messages.push(msg);
// Once every worker has replied with a response for the value
// we can exit the test.

let last
for (var i = 0; i < workers.length; i++) {
last = db.put('key' + i, i % 2 === 1 ? bigValue : value);
}
setTimeout(() => {
worker.terminate();
}, 100);
if (messages.length === workerCount) {
db.close();
for (var i = 0; i < messages.length; i++) {
assert(messages[i] === value.toString('hex'));
}
assert(incrementer[0] === 10000n + BigInt(workerCount) * 10n);
console.log('done', threadId, incrementer[0]);
//setTimeout(() =>
//process.exit(0), 200);
}
});
});

last.then(() => {
for (var i = 0; i < workers.length; i++) {
var worker = workers[i];
worker.postMessage({key: 'key' + i});
};
});
let last;
for (var i = 0; i < workers.length; i++) {
last = db.put('key' + i, i % 2 === 1 ? bigValue : value);
}

last.then(() => {
for (var i = 0; i < workers.length; i++) {
var worker = workers[i];
worker.postMessage({ key: 'key' + i });
}
});
} else {
// The worker process
let db = open({
path: path.resolve(__dirname, './testdata'),
maxDbs: 10,
mapSize: MAX_DB_SIZE,
maxReaders: 126,
overlappingSync: true,
});
// The worker process
let db = open({
path: path.resolve(__dirname, './testdata'),
maxDbs: 10,
mapSize: MAX_DB_SIZE,
maxReaders: 126,
overlappingSync: true,
});

parentPort.on('message', async function (msg) {
if (msg.key) {
for (let i = 0; i < 10; i++) {
let incrementer = new BigInt64Array(1);
incrementer[0] = 1n; // should be ignored
let incrementerBuffer = db.getUserSharedBuffer(
'test',
incrementer.buffer,
);
incrementer = new BigInt64Array(incrementerBuffer);
Atomics.add(incrementer, 0, 1n);
gc();
await new Promise((resolve) => setTimeout(resolve, 100));
}

parentPort.on('message', async function(msg) {
if (msg.key) {
var value = db.get(msg.key);
if (msg.key == 'key1' || msg.key == 'key3') {
await db.put(msg.key, 'updated');
}
if (value === null) {
parentPort.postMessage("");
} else {
parentPort.postMessage(value.toString('hex'));
}

}
});
}
var value = db.get(msg.key);
if (msg.key == 'key1' || msg.key == 'key3') {
await db.put(msg.key, 'updated');
}
if (value === null) {
parentPort.postMessage('');
} else {
parentPort.postMessage(value.toString('hex'));
}
}
});
}

0 comments on commit b69091f

Please sign in to comment.