Skip to content

Commit

Permalink
Use extracted exchange_cluster_information in distributed_training_ud…
Browse files Browse the repository at this point in the history
…f.py
  • Loading branch information
tkilias committed Dec 10, 2023
1 parent 149d6fa commit 10209bc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1,52 +1,18 @@
import contextlib
import socket
from typing import Iterator
from typing import List

import structlog
import tensorflow as tf
from pydantic import BaseModel
from structlog.typing import FilteringBoundLogger
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver, ClusterResolver
from tensorflow.python.keras import activations
from tensorflow.python.training.server_lib import ClusterSpec

from exasol_advanced_analytics_framework.udf_communication.distributed_udf import DistributedUDF, \
exchange_cluster_information, UDFCommunicatorFactory
from exasol_advanced_analytics_framework.udf_communication.ip_address import Port, IPAddress
UDFCommunicatorFactory
from exasol_advanced_analytics_framework.udf_communication.exchange_cluster_information import \
exchange_cluster_information, reserve_port, WorkerAddress, ClusterInformation

LOGGER: FilteringBoundLogger = structlog.get_logger()


class WorkerAddress(BaseModel):
ip_address: IPAddress
port: Port


class ClusterInformation(BaseModel):
workers: List[WorkerAddress]


@contextlib.contextmanager
def reserve_port(ip: IPAddress) -> Iterator[Port]:
def new_socket():
return socket.socket(socket.AF_INET, socket.SOCK_STREAM)

def bind(sock: socket.socket, ip: IPAddress, port: int):
sock.bind((ip.ip_address, port))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

def acquire_port_number(sock: socket.socket, ip: IPAddress) -> int:
bind(sock, ip, 0)
return sock.getsockname()[1]

with new_socket() as sock:
port_number = acquire_port_number(sock, ip)
port = Port(port=port_number)
LOGGER.info("reserve_port", ip=ip, port=port)
yield port


class DistributedTrainingUDF(DistributedUDF):

def run(self, ctx, exa, udf_communicator_factory: UDFCommunicatorFactory):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
processors=[
structlog.contextvars.merge_contextvars,
#ConditionalMethodDropper(method_name="debug"),
#ConditionalMethodDropper(method_name="info"),
ConditionalMethodDropper(method_name="debug"),
ConditionalMethodDropper(method_name="info"),
structlog.processors.add_log_level,
structlog.processors.TimeStamper(),
structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
Expand Down

0 comments on commit 10209bc

Please sign in to comment.