Skip to content

Commit

Permalink
Restore GC list before calling serial version of deduce_unreachable
Browse files Browse the repository at this point in the history
Summary:
In debug mode the serial version of deduce_unreachable verifies that
the GC list has not been tampered with. We call `update_refs` beforehand,
which copies refcounts into the previous pointer, effectively making
it a singly linked list.

Reviewed By: DinoV

Differential Revision: D50996948

fbshipit-source-id: fa529158a50331178ff816397f96c145a24b19cb
  • Loading branch information
mpage authored and facebook-github-bot committed Nov 7, 2023
1 parent 9e1211f commit ae834b1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 30 deletions.
15 changes: 15 additions & 0 deletions CinderX/ParallelGC/parallel_gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -2166,6 +2166,19 @@ Ci_move_unreachable_parallel(PyGC_Head *base, PyGC_Head *unreachable)
unreachable->_gc_next &= ~NEXT_MASK_UNREACHABLE;
}

static void
Ci_restore_prev_ptrs(PyGC_Head *containers)
{
PyGC_Head *prev = containers;
for (PyGC_Head *gc = GC_NEXT(containers); gc != containers; gc = GC_NEXT(gc)) {
// Clear refcount saved in top bits (gc_refs)
_PyGCHead_SET_PREV(gc, prev);
// Clear the collecting bit
gc_clear_collecting(gc);
prev = gc;
}
}

/* Deduce which objects among "base" are unreachable from outside the list in
parallel and move them to 'unreachable'.
Expand Down Expand Up @@ -2241,6 +2254,8 @@ Ci_deduce_unreachable_parallel(Ci_ParGCState *par_gc, PyGC_Head *base, PyGC_Head
unsigned int num_objects = update_refs(base);
if (num_objects < par_gc->num_workers) {
CI_DLOG("Too few objects to justify parallel collection. Collecting serially.");
// Restore the prev pointer of each node that was clobbered by update_refs
Ci_restore_prev_ptrs(base);
deduce_unreachable(base, unreachable);
return;
}
Expand Down
76 changes: 46 additions & 30 deletions CinderX/test_cinderx/test_parallel_gc.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. (http://www.meta.com)

import gc
import test.test_gc
import unittest

try:
import cinder
except ImportError:
cinder = None
# TODO(T168696028): Remove this once parallel gc functions are moved to the
# cinderx module
raise unittest.SkipTest("Tests CinderX features")

# TODO(T168696028): Remove this once parallel gc functions are moved to the
# cinderx module
@unittest.skipIf(cinder is None, "Tests CinderX features")
class ParallelGCTests(unittest.TestCase):

def _restore_parallel_gc(settings):
cinder.disable_parallel_gc()
if settings is not None:
cinder.enable_parallel_gc(
settings["min_generation"],
settings["num_threads"],
)


class ParallelGCAPITests(unittest.TestCase):
def setUp(self):
self.par_gc_settings = cinder.get_parallel_gc_settings()
self.old_par_gc_settings = cinder.get_parallel_gc_settings()
cinder.disable_parallel_gc()

def tearDown(self):
if self.par_gc_settings is not None:
cinder.disable_parallel_gc()
cinder.enable_parallel_gc(
self.par_gc_settings["min_generation"],
self.par_gc_settings["num_threads"],
)
else:
cinder.disable_parallel_gc()
_restore_parallel_gc(self.old_par_gc_settings)

def test_get_settings_when_disabled(self):
cinder.disable_parallel_gc()
self.assertEqual(cinder.get_parallel_gc_settings(), None)

def test_get_settings_when_enabled(self):
cinder.disable_parallel_gc()
cinder.enable_parallel_gc(2, 8)
settings = cinder.get_parallel_gc_settings()
expected = {
Expand All @@ -47,23 +49,37 @@ def test_set_invalid_num_threads(self):
with self.assertRaisesRegex(ValueError, "invalid num_threads"):
cinder.enable_parallel_gc(2, -1)

def test_collection(self):
collected = False

class Cycle:
def __init__(self):
self.ref = self
# Run all the GC tests with parallel GC enabled


class ParallelGCTests(test.test_gc.GCTests):
pass


class ParallelGCCallbackTests(test.test_gc.GCCallbackTests):
@unittest.skip("Tests implementation details of serial collector")
def test_refcount_errors(self):
pass


class ParallelGCFinalizationTests(test.test_gc.PythonFinalizationTests):
pass


def setUpModule():
test.test_gc.setUpModule()

global old_par_gc_settings
old_par_gc_settings = cinder.get_parallel_gc_settings()
cinder.enable_parallel_gc(0, 8)


def __del__(self):
nonlocal collected
collected = True
def tearDownModule():
test.test_gc.tearDownModule()

cinder.enable_parallel_gc()
gc.collect()
c = Cycle()
del c
gc.collect()
self.assertTrue(collected)
global old_par_gc_settings
_restore_parallel_gc(old_par_gc_settings)


if __name__ == "__main__":
Expand Down

0 comments on commit ae834b1

Please sign in to comment.