Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sending messages with keys #70

Merged
merged 6 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- name: Check out the code
Expand Down
5 changes: 5 additions & 0 deletions adc/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class SASLAuth(object):
ssl_ca_location : `str`, optional
If using SSL via a self-signed cert, a path/location
to the certificate.
ssl_endpoint_identification_algorithm : `str`, optional
If using SSL, the algorithm used to verify that the certificate is valid for the endpoint.
token_endpoint : `str`, optional
The OpenID Connect token endpoint URL.
Required for OAUTHBEARER / OpenID Connect, otherwise ignored.
Expand All @@ -63,6 +65,9 @@ def __init__(self, user, password, ssl=True, method=None, token_endpoint=None, *
"security.protocol": "SASL_SSL",
"ssl.ca.location": ssl_cert,
}
if "ssl_endpoint_identification_algorithm" in kwargs:
self._config["ssl.endpoint.identification.algorithm"] = \
kwargs["ssl_endpoint_identification_algorithm"]
else:
self._config = {"security.protocol": "SASL_PLAINTEXT"}

Expand Down
7 changes: 4 additions & 3 deletions adc/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def __init__(self, conf: 'ConsumerConfig') -> None:
self.logger = logging.getLogger("adc-streaming.consumer")
self.conf = conf
self._consumer = confluent_kafka.Consumer(conf._to_confluent_kafka())
# Workaround for https://github.com/confluentinc/librdkafka/issues/3753#issuecomment-1058272987.
# Workaround for
# https://github.com/confluentinc/librdkafka/issues/3753#issuecomment-1058272987.
# FIXME: Remove once fixed upstream, or on removal of oauth_cb.
self._consumer.poll(0)
self._stop_event = threading.Event()
Expand Down Expand Up @@ -196,7 +197,7 @@ def _stream_forever(self,
self.mark_done(m, asynchronous=True)
yield m
else:
raise(confluent_kafka.KafkaException(err))
raise (confluent_kafka.KafkaException(err))
finally:
if autocommit:
self._consumer.commit(asynchronous=True)
Expand Down Expand Up @@ -242,7 +243,7 @@ def _stream_until_eof(self,
# Done with all partitions for the topic, remove it
del active_partitions[m.topic()]
else:
raise(confluent_kafka.KafkaException(err))
raise (confluent_kafka.KafkaException(err))
finally:
if autocommit:
self._consumer.commit(asynchronous=True)
Expand Down
10 changes: 6 additions & 4 deletions adc/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@ def __init__(self, conf: 'ProducerConfig') -> None:
self.conf = conf
self.logger.debug(f"connecting to producer with config {conf._to_confluent_kafka()}")
self._producer = confluent_kafka.Producer(conf._to_confluent_kafka())
# Workaround for https://github.com/confluentinc/librdkafka/issues/3753#issuecomment-1058272987.
# Workaround for
# https://github.com/confluentinc/librdkafka/issues/3753#issuecomment-1058272987.
# FIXME: Remove once fixed upstream, or on removal of oauth_cb.
self._producer.poll(0)

def write(self,
msg: Union[bytes, 'Serializable'],
headers: Optional[Union[dict, list]] = None,
delivery_callback: Optional[DeliveryCallback] = log_delivery_errors,
topic: Optional[str] = None) -> None:
topic: Optional[str] = None,
key: Optional[Union[str, bytes]] = None) -> None:
if isinstance(msg, Serializable):
msg = msg.serialize()
if topic is None:
Expand All @@ -47,10 +49,10 @@ def write(self,
"or specify the topic argument to write()")
self.logger.debug("writing message to %s", topic)
if delivery_callback is not None:
self._producer.produce(topic, msg, headers=headers,
self._producer.produce(topic, msg, headers=headers, key=key,
on_delivery=delivery_callback)
else:
self._producer.produce(topic, msg, headers=headers)
self._producer.produce(topic, msg, headers=headers, key=key,)

def flush(self, timeout: timedelta = timedelta(seconds=10)) -> int:
"""Attempt to flush enqueued messages. Return the number of messages still
Expand Down
2 changes: 1 addition & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from adc.auth import SASLAuth, SASLMethod
from adc.auth import SASLAuth


@pytest.mark.parametrize('auth,expected_config', [
Expand Down
49 changes: 39 additions & 10 deletions tests/test_kafka_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import tempfile
import time
import unittest
from datetime import datetime, timedelta, timezone
from typing import List
from datetime import datetime, timedelta
from typing import List, Optional

import docker
import pytest
Expand Down Expand Up @@ -59,6 +59,31 @@ def test_round_trip(self):
self.assertEqual(msg.topic(), topic)
self.assertEqual(msg.value(), b"can you hear me?")

def test_message_with_key(self):
    """Write a keyed message into the Kafka broker and pull it back out,
    verifying that the topic, the value, and the message key all
    round-trip intact.

    """
    topic = "test_message_with_key"
    # Push one keyed message in...
    simple_write_msg(self.kafka, topic, "can you hear me?", key="test_msg")
    # ... and pull it back out.
    consumer = adc.consumer.Consumer(adc.consumer.ConsumerConfig(
        broker_urls=[self.kafka.address],
        group_id="test_consumer",
        auth=self.kafka.auth,
    ))
    consumer.subscribe(topic)
    stream = consumer.stream()

    msg = next(stream)
    if msg.error() is not None:
        raise Exception(msg.error())

    self.assertEqual(msg.topic(), topic)
    self.assertEqual(msg.value(), b"can you hear me?")
    # The key comes back as bytes regardless of the str passed to write().
    self.assertEqual(msg.key(), b"test_msg")

def test_reset_to_end(self):
# Write a few messages.
topic = "test_reset_to_end"
Expand Down Expand Up @@ -283,7 +308,7 @@ def test_consume_from_datetime(self):
"message 1",
"message 2",
"message 3",
])
])
# Wait a while, write, and wait some more
time.sleep(2)
client_middle_time = datetime.now()
Expand Down Expand Up @@ -390,8 +415,8 @@ def test_multi_topic_handling(self):
topic=None,
auth=self.kafka.auth,
))
for i in range(0,8):
producer.write(str(i), topic=topics[i%2])
for i in range(0, 8):
producer.write(str(i), topic=topics[i % 2])
producer.flush()
logger.info("messages sent")

Expand All @@ -403,12 +428,12 @@ def test_multi_topic_handling(self):
))
consumer.subscribe(topics)
stream = consumer.stream()
total_messages = 0;
total_messages = 0
for msg in stream:
if msg.error() is not None:
raise Exception(msg.error())
idx = int(msg.value())
self.assertEqual(msg.topic(), topics[idx%2])
self.assertEqual(msg.topic(), topics[idx % 2])
total_messages += 1
if total_messages == 8:
break
Expand Down Expand Up @@ -451,6 +476,10 @@ def __init__(self):
self.auth = adc.auth.SASLAuth(
user="test", password="test-pass",
ssl_ca_location=self.certfile.name,
# disable endpoint verification because the docker service generates a certificate with
# a useless subject (the container ID) which can never match the hostname used to
# connect (typically 0.0.0.0)
ssl_endpoint_identification_algorithm="none",
)

def poll_for_kafka_broker_address(self, maxiter=20, sleep=timedelta(milliseconds=500)):
Expand Down Expand Up @@ -550,7 +579,7 @@ def get_or_create_container(self):
# these tests cannot run if there is already an instance of Kafka running on
# the same host.
ports={"9092/tcp": 9092},
command=["/root/runServer","--advertisedListener","SASL_SSL://localhost:9092"],
command=["/root/runServer", "--advertisedListener", "SASL_SSL://localhost:9092"],
)

def get_or_create_docker_network(self):
Expand All @@ -565,13 +594,13 @@ def get_or_create_docker_network(self):
return self.docker_client.networks.create(name="adc-integration-test")


def simple_write_msg(conn: KafkaDockerConnection, topic: str, msg: str, key: Optional[str] = None):
    """Write a single message to *topic* on the broker described by *conn*.

    Creates a short-lived Producer, writes *msg* (optionally with the
    partition *key*), and blocks until the message has been flushed.

    Parameters
    ----------
    conn : KafkaDockerConnection
        Connection info (broker address and auth) for the test broker.
    topic : str
        The topic to write to.
    msg : str
        The message payload.
    key : str, optional
        If given, the message key to attach to the record.
    """
    producer = adc.producer.Producer(adc.producer.ProducerConfig(
        broker_urls=[conn.address],
        topic=topic,
        auth=conn.auth,
    ))
    producer.write(msg, key=key)
    producer.flush()


Expand Down