@@ -40,11 +40,12 @@ class KafkaDatasetTest(test.TestCase):
40
40
# To set up the Kafka server:
41
41
# $ bash kafka_test.sh start kafka
42
42
#
43
- # To team down the Kafka server:
43
+ # To tear down the Kafka server:
44
44
# $ bash kafka_test.sh stop kafka
45
45
46
46
def test_kafka_dataset (self ):
47
- """Tests for KafkaDataset."""
47
+ """Tests for KafkaDataset when reading non-keyed messages
48
+ from a single-partitioned topic"""
48
49
topics = tf .compat .v1 .placeholder (dtypes .string , shape = [None ])
49
50
num_epochs = tf .compat .v1 .placeholder (dtypes .int64 , shape = [])
50
51
batch_size = tf .compat .v1 .placeholder (dtypes .int64 , shape = [])
@@ -60,21 +61,21 @@ def test_kafka_dataset(self):
60
61
get_next = iterator .get_next ()
61
62
62
63
with self .cached_session () as sess :
63
- # Basic test: read from topic 0 .
64
+ # Basic test: read a limited number of messages from the topic .
64
65
sess .run (init_op , feed_dict = {topics : ["test:0:0:4" ], num_epochs : 1 })
65
66
for i in range (5 ):
66
67
self .assertEqual (("D" + str (i )).encode (), sess .run (get_next ))
67
68
with self .assertRaises (errors .OutOfRangeError ):
68
69
sess .run (get_next )
69
70
70
- # Basic test: read from topic 1 .
71
+ # Basic test: read all the messages from the topic from offset 5 .
71
72
sess .run (init_op , feed_dict = {topics : ["test:0:5:-1" ], num_epochs : 1 })
72
73
for i in range (5 ):
73
74
self .assertEqual (("D" + str (i + 5 )).encode (), sess .run (get_next ))
74
75
with self .assertRaises (errors .OutOfRangeError ):
75
76
sess .run (get_next )
76
77
77
- # Basic test: read from both topics .
78
+ # Basic test: read from different subscriptions of the same topic .
78
79
sess .run (
79
80
init_op ,
80
81
feed_dict = {topics : ["test:0:0:4" , "test:0:5:-1" ], num_epochs : 1 },
@@ -87,7 +88,7 @@ def test_kafka_dataset(self):
87
88
with self .assertRaises (errors .OutOfRangeError ):
88
89
sess .run (get_next )
89
90
90
- # Test repeated iteration through both files .
91
+ # Test repeated iteration through both subscriptions .
91
92
sess .run (
92
93
init_op ,
93
94
feed_dict = {topics : ["test:0:0:4" , "test:0:5:-1" ], num_epochs : 10 },
@@ -101,7 +102,7 @@ def test_kafka_dataset(self):
101
102
with self .assertRaises (errors .OutOfRangeError ):
102
103
sess .run (get_next )
103
104
104
- # Test batched and repeated iteration through both files .
105
+ # Test batched and repeated iteration through both subscriptions .
105
106
sess .run (
106
107
init_batch_op ,
107
108
feed_dict = {
@@ -276,7 +277,8 @@ def test_write_kafka(self):
276
277
sess .run (get_next )
277
278
278
279
def test_kafka_dataset_with_key (self ):
279
- """Tests for KafkaDataset."""
280
+ """Tests for KafkaDataset when reading keyed-messages
281
+ from a single-partitioned topic"""
280
282
topics = tf .compat .v1 .placeholder (dtypes .string , shape = [None ])
281
283
num_epochs = tf .compat .v1 .placeholder (dtypes .int64 , shape = [])
282
284
batch_size = tf .compat .v1 .placeholder (dtypes .int64 , shape = [])
@@ -288,10 +290,11 @@ def test_kafka_dataset_with_key(self):
288
290
289
291
iterator = data .Iterator .from_structure (batch_dataset .output_types )
290
292
init_op = iterator .make_initializer (repeat_dataset )
293
+ init_batch_op = iterator .make_initializer (batch_dataset )
291
294
get_next = iterator .get_next ()
292
295
293
296
with self .cached_session () as sess :
294
- # Basic test: read from topic 0 .
297
+ # Basic test: read a limited number of keyed messages from the topic .
295
298
sess .run (init_op , feed_dict = {topics : ["key-test:0:0:4" ], num_epochs : 1 })
296
299
for i in range (5 ):
297
300
self .assertEqual (
@@ -301,6 +304,181 @@ def test_kafka_dataset_with_key(self):
301
304
with self .assertRaises (errors .OutOfRangeError ):
302
305
sess .run (get_next )
303
306
307
+ # Basic test: read all the keyed messages from the topic from offset 5.
308
+ sess .run (init_op , feed_dict = {topics : ["key-test:0:5:-1" ], num_epochs : 1 })
309
+ for i in range (5 ):
310
+ self .assertEqual (
311
+ (("D" + str (i + 5 )).encode (), ("K" + str ((i + 5 ) % 2 )).encode ()),
312
+ sess .run (get_next ),
313
+ )
314
+ with self .assertRaises (errors .OutOfRangeError ):
315
+ sess .run (get_next )
316
+
317
+ # Basic test: read from different subscriptions of the same topic.
318
+ sess .run (
319
+ init_op ,
320
+ feed_dict = {
321
+ topics : ["key-test:0:0:4" , "key-test:0:5:-1" ],
322
+ num_epochs : 1 ,
323
+ },
324
+ )
325
+ for j in range (2 ):
326
+ for i in range (5 ):
327
+ self .assertEqual (
328
+ (
329
+ ("D" + str (i + j * 5 )).encode (),
330
+ ("K" + str ((i + j * 5 ) % 2 )).encode (),
331
+ ),
332
+ sess .run (get_next ),
333
+ )
334
+ with self .assertRaises (errors .OutOfRangeError ):
335
+ sess .run (get_next )
336
+
337
+ # Test repeated iteration through both subscriptions.
338
+ sess .run (
339
+ init_op ,
340
+ feed_dict = {
341
+ topics : ["key-test:0:0:4" , "key-test:0:5:-1" ],
342
+ num_epochs : 10 ,
343
+ },
344
+ )
345
+ for _ in range (10 ):
346
+ for j in range (2 ):
347
+ for i in range (5 ):
348
+ self .assertEqual (
349
+ (
350
+ ("D" + str (i + j * 5 )).encode (),
351
+ ("K" + str ((i + j * 5 ) % 2 )).encode (),
352
+ ),
353
+ sess .run (get_next ),
354
+ )
355
+ with self .assertRaises (errors .OutOfRangeError ):
356
+ sess .run (get_next )
357
+
358
+ # Test batched and repeated iteration through both subscriptions.
359
+ sess .run (
360
+ init_batch_op ,
361
+ feed_dict = {
362
+ topics : ["key-test:0:0:4" , "key-test:0:5:-1" ],
363
+ num_epochs : 10 ,
364
+ batch_size : 5 ,
365
+ },
366
+ )
367
+ for _ in range (10 ):
368
+ self .assertAllEqual (
369
+ [
370
+ [("D" + str (i )).encode () for i in range (5 )],
371
+ [("K" + str (i % 2 )).encode () for i in range (5 )],
372
+ ],
373
+ sess .run (get_next ),
374
+ )
375
+ self .assertAllEqual (
376
+ [
377
+ [("D" + str (i + 5 )).encode () for i in range (5 )],
378
+ [("K" + str ((i + 5 ) % 2 )).encode () for i in range (5 )],
379
+ ],
380
+ sess .run (get_next ),
381
+ )
382
+
383
def test_kafka_dataset_with_partitioned_key(self):
    """Read keyed messages back from a multi-partitioned topic.

    Exercises, in order: a read from each partition on its own, both
    subscriptions fed together, repeated epochs over both subscriptions,
    and batched output over repeated epochs.

    The expected values assume that partition 0 yields the even-numbered
    "D<n>" messages with key "K0" and partition 1 yields the odd-numbered
    ones with key "K1". That mapping depends on the order in which the
    sample data is stored in Kafka; see kafka_test.sh for the sample data.
    """
    subscriptions_ph = tf.compat.v1.placeholder(dtypes.string, shape=[None])
    epochs_ph = tf.compat.v1.placeholder(dtypes.int64, shape=[])
    batch_ph = tf.compat.v1.placeholder(dtypes.int64, shape=[])

    keyed_dataset = kafka_io.KafkaDataset(
        subscriptions_ph, group="test", eof=True, message_key=True
    )
    repeated = keyed_dataset.repeat(epochs_ph)
    batched = repeated.batch(batch_ph)

    # One re-initializable iterator serves both the plain and batched views.
    iterator = data.Iterator.from_structure(batched.output_types)
    init_repeated = iterator.make_initializer(repeated)
    init_batched = iterator.make_initializer(batched)
    next_element = iterator.get_next()

    # NOTE(review): subscription strings appear to follow
    # "topic:partition:offset:limit" -- confirm against KafkaDataset docs.
    both_partitions = ["key-partition-test:0:0:5", "key-partition-test:1:0:5"]
    per_partition = 5

    def expected_record(partition, index):
        """Build the (message, key) pair expected at `index` of `partition`."""
        number = index * 2 + partition
        return (("D" + str(number)).encode(), ("K" + str(partition)).encode())

    with self.cached_session() as sess:

        def check_partition(partition):
            """Assert the next `per_partition` records belong to `partition`."""
            for index in range(per_partition):
                self.assertEqual(
                    expected_record(partition, index), sess.run(next_element)
                )

        def check_exhausted():
            """Assert the iterator has nothing left to yield."""
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

        # Basic tests: read the first 5 messages from the first partition,
        # then from the second partition, each as its own subscription.
        for partition, subscription in enumerate(both_partitions):
            sess.run(
                init_repeated,
                feed_dict={subscriptions_ph: [subscription], epochs_ph: 1},
            )
            check_partition(partition)
            check_exhausted()

        # Basic test: read from different subscriptions to the same topic.
        sess.run(
            init_repeated,
            feed_dict={subscriptions_ph: both_partitions, epochs_ph: 1},
        )
        for partition in range(2):
            check_partition(partition)
        check_exhausted()

        # Test repeated iteration through both subscriptions.
        sess.run(
            init_repeated,
            feed_dict={subscriptions_ph: both_partitions, epochs_ph: 10},
        )
        for _ in range(10):
            for partition in range(2):
                check_partition(partition)
        check_exhausted()

        # Test batched and repeated iteration through both subscriptions.
        sess.run(
            init_batched,
            feed_dict={
                subscriptions_ph: both_partitions,
                epochs_ph: 10,
                batch_ph: 5,
            },
        )
        for _ in range(10):
            for partition in range(2):
                messages = [
                    ("D" + str(index * 2 + partition)).encode()
                    for index in range(5)
                ]
                keys = [("K" + str(partition)).encode() for _ in range(5)]
                self.assertAllEqual([messages, keys], sess.run(next_element))
481
+
304
482
305
483
# Run this module's tests via test.main() when executed directly as a script.
if __name__ == "__main__" :
306
484
test .main ()
0 commit comments