Skip to content

Commit 38b5c1a

Browse files
authored
[TST] Add another test for log failover. (#4613)
## Description of changes I'm adding a test. ## Test plan CI; new test passes locally. - [X] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes N/A
1 parent 55f0c34 commit 38b5c1a

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

chromadb/test/distributed/test_log_failover.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,99 @@ def test_log_failover_with_compaction(
189189
assert all([math.fabs(x - y) < 0.001 for (x, y) in zip(result["embeddings"][0], embeddings[i])])
190190
else:
191191
assert False, "missing a result"
192+
193+
@skip_if_not_cluster()
194+
def test_log_failover_with_query_operations(
195+
client: ClientAPI,
196+
) -> None:
197+
seed = time.time()
198+
random.seed(seed)
199+
print("Generating data with seed ", seed)
200+
reset(client)
201+
collection = client.create_collection(
202+
name="test",
203+
metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128},
204+
)
205+
206+
time.sleep(1)
207+
208+
# Add initial RECORDS records with known embeddings for querying
209+
ids = []
210+
embeddings = []
211+
for i in range(RECORDS):
212+
ids.append(str(i))
213+
embeddings.append(np.random.rand(1, 3)[0])
214+
collection.add(
215+
ids=[str(i)],
216+
embeddings=[embeddings[-1]],
217+
)
218+
219+
# Perform baseline similarity queries before failover
220+
query_embeddings = [embeddings[0], embeddings[RECORDS//2], embeddings[-1]]
221+
baseline_results = []
222+
for query_embedding in query_embeddings:
223+
result = collection.query(
224+
query_embeddings=[query_embedding],
225+
n_results=5,
226+
include=["embeddings", "distances"]
227+
)
228+
baseline_results.append(result)
229+
230+
print('failing over for', collection.id)
231+
channel = grpc.insecure_channel('localhost:50052')
232+
log_service_stub = LogServiceStub(channel)
233+
234+
# Trigger log failover
235+
request = SealLogRequest(collection_id=str(collection.id))
236+
response = log_service_stub.SealLog(request, timeout=60)
237+
238+
# Re-run the same queries after failover and verify results consistency
239+
post_failover_results = []
240+
for query_embedding in query_embeddings:
241+
result = collection.query(
242+
query_embeddings=[query_embedding],
243+
n_results=5,
244+
include=["embeddings", "distances"]
245+
)
246+
post_failover_results.append(result)
247+
248+
# Verify that query results are consistent before and after failover
249+
for i, (baseline, post_failover) in enumerate(zip(baseline_results, post_failover_results)):
250+
assert len(baseline["ids"][0]) == len(post_failover["ids"][0]), f"Query {i} returned different number of results"
251+
assert baseline["ids"][0] == post_failover["ids"][0], f"Query {i} returned different IDs"
252+
# Verify embeddings match (allowing for small floating point differences)
253+
for j, (base_emb, post_emb) in enumerate(zip(baseline["embeddings"][0], post_failover["embeddings"][0])):
254+
assert all([math.fabs(x - y) < 0.001 for (x, y) in zip(base_emb, post_emb)]), f"Query {i} result {j} embeddings differ"
255+
256+
# Add more data post-failover
257+
post_failover_start = RECORDS
258+
for i in range(post_failover_start, post_failover_start + RECORDS):
259+
ids.append(str(i))
260+
embeddings.append(np.random.rand(1, 3)[0])
261+
collection.add(
262+
ids=[str(i)],
263+
embeddings=[embeddings[-1]],
264+
)
265+
266+
# Query for both old and new data to ensure full functionality
267+
# Test that we can find old data
268+
old_data_query = collection.query(
269+
query_embeddings=[embeddings[0]],
270+
n_results=3,
271+
include=["embeddings"]
272+
)
273+
assert len(old_data_query["ids"][0]) == 3, "Failed to query old data after failover"
274+
275+
# Test that we can find new data
276+
new_data_query = collection.query(
277+
query_embeddings=[embeddings[-1]],
278+
n_results=3,
279+
include=["embeddings"]
280+
)
281+
assert len(new_data_query["ids"][0]) == 3, "Failed to query new data after failover"
282+
283+
# Verify all data is still accessible by ID
284+
for i in range(len(ids)):
285+
result = collection.get(ids=[str(i)], include=["embeddings"])
286+
assert len(result["embeddings"]) > 0, f"Missing result for ID {i} after failover with new data"
287+
assert all([math.fabs(x - y) < 0.001 for (x, y) in zip(result["embeddings"][0], embeddings[i])]), f"Embedding mismatch for ID {i}"

0 commit comments

Comments
 (0)