Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorboard Time #112

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,13 @@ def log_metrics(self, stats: Dict[str, Any], iteration: int) -> None:
imgs=stats["images"], key="sample_images", iteration=iteration
)

#config["num_users"]
#Check if the file neighbors_{iteration}.csv exists in logs/csv
dir = os.path.dirname(self.log_utils.log_dir) + "/csv"
if os.path.exists(f"{dir}/graph_{iteration}.adjlist"):
self.plot_utils.log_nx_graph_image(iteration, f"{dir}/neighbors_{iteration}.csv")
self.plot_utils.combine_graphs_with_edge_frequency(dir, iteration)

@abstractmethod
def receive_and_aggregate(self):
"""Add docstring here"""
Expand Down
172 changes: 171 additions & 1 deletion src/utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
import pandas as pd
from utils.types import ConfigType
import json

import networkx as nx
from networkx import Graph
import matplotlib.pyplot as plt
import imageio

def deprocess(img: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -103,6 +106,7 @@ class LogUtils:
"""
Utility class for logging and saving experiment data.
"""
# nx_layout = None

def __init__(self, config: ConfigType) -> None:
log_dir = config["log_path"]
Expand All @@ -123,7 +127,110 @@ def __init__(self, config: ConfigType) -> None:
self.init_npy()
self.init_summary()
self.init_csv()
self.nx_layout = None
self.init_nx_graph(config)

def init_nx_graph(self, config: ConfigType):
"""
Initialize the networkx graph for the topology.

Args:
config (ConfigType): Configuration dictionary.
rank (int): Rank of the current node.
"""
if "topology" in config:
self.topology = config["topology"]
self.num_users = config["num_users"]
self.graph = nx.DiGraph()



def log_nx_graph(self, graph: Graph, iteration: int, directory: str|None = None):
"""
Log the networkx graph to a file.
"""
# print(graph)
if directory:
nx.write_adjlist(graph, f"{directory}/graph_{iteration}.adjlist", comments='#', delimiter=' ', encoding='utf-8') # type: ignore
else:
nx.write_adjlist(graph, f"{self.log_dir}/graph_{iteration}.adjlist", comments='#', delimiter=' ', encoding='utf-8') # type: ignore


def log_nx_graph_image(self, graph: Graph, iteration: int, directory: str | None = None):
"""
Log the networkx directed graph as an image with non-overlapping edges.
"""
# Generate a layout with more spacing
if self.nx_layout is None:
self.nx_layout = nx.shell_layout(graph)

# Draw nodes with larger size and smaller font
nx.draw_networkx_nodes(graph, self.nx_layout, node_size=700, node_color="skyblue")

# Draw each edge with curved lines for side-by-side display
edges = list(graph.edges())
for i, (u, v) in enumerate(edges):
rad = 0.2 if i % 2 == 0 else -0.2 # Increase rad for more curvature
nx.draw_networkx_edges(
graph,
self.nx_layout,
edgelist=[(u, v)],
connectionstyle=f"arc3,rad={rad}",
arrows=True,
# arrowstyle="-|>", # Customize arrow style for better separation
arrowsize=20 # Increase arrow size for visibility
)

# Draw labels with smaller font
nx.draw_networkx_labels(graph, self.nx_layout, font_size=8, font_weight="bold")

# Save the plot as an image
if directory:
plt.savefig(f"{directory}/graph_{iteration}.png", format="png")
else:
plt.savefig(f"{self.log_dir}/graph_{iteration}.png", format="png")

# Close the plot to free up memory
plt.close()



def log_nx_graph_edge_weights(self, graph: Graph, iteration: int, directory: str | None = None):
"""
Log the directed graph with edge weights as an image with non-overlapping edges.
"""

if self.nx_layout is None:
self.nx_layout = nx.shell_layout(graph)
pos = self.nx_layout
edge_weights = nx.get_edge_attributes(graph, "weight")

for i, (u, v) in enumerate(graph.edges()):
rad = 0.2 if i % 2 == 0 else -0.2
nx.draw_networkx_edges(
graph,
pos,
edgelist=[(u, v)],
connectionstyle=f"arc3,rad={rad}",
arrows=True,
width=edge_weights.get((u, v), 1.0), # Set width based on weight
arrowsize=20
)

# Draw nodes and labels
nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue")
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")

if directory:
plt.savefig(f"{directory}/weighted_graph_{iteration}.png", format="png")
else:
plt.savefig(f"{self.log_dir}/weighted_graph_{iteration}.png", format="png")

plt.close()




def log_config(self, config: ConfigType):
"""
Log the configuration to a json file.
Expand Down Expand Up @@ -165,6 +272,11 @@ def init_csv(self):
if not os.path.exists(csv_path) or not os.path.isdir(csv_path):
os.makedirs(csv_path)

parent = os.path.dirname(self.log_dir) + "/csv" # type: ignore
if not os.path.exists(parent) or not os.path.isdir(parent): # type: ignore
os.makedirs(parent) # type: ignore


def log_summary(self, text: str):
"""
Add summary text to the summary file for logging.
Expand Down Expand Up @@ -239,6 +351,64 @@ def log_csv(self, key: str, value: Any, iteration: int):
# Append the metrics to the CSV file
df.to_csv(log_file, mode='a', header=not file_exists, index=False)

#make a global file to store all the neighbors of each round
if key == "neighbors":
self.log_global_csv(iteration, key, value)

def log_global_csv(self, iteration: int, key: str, value: Any):
"""
Log a value to a CSV file.
"""
parent = os.path.dirname(self.log_dir) # type: ignore
log_file = f"{parent}/csv/neighbors_{iteration}.csv"
node = self.log_dir.split("_")[-1] # type: ignore
row = {"iteration": iteration, "node": node , key: value}
df = pd.DataFrame([row])
file_exists = os.path.isfile(log_file)
df.to_csv(log_file, mode='a', header=not file_exists, index=False)

if len(pd.read_csv(log_file)) == self.num_users:
adjacency_list = self.create_adjacency_list(log_file)
graph = nx.DiGraph(adjacency_list)
self.log_nx_graph(graph, iteration, f"{parent}/csv")




def create_adjacency_list(self, file_path: str) -> Dict[str, list]: # type: ignore
# Load the CSV file
"""
Load the CSV file, populate the adjacency list and return it.

Parameters
----------
file_path : str
The path to the CSV file

Returns
-------
adjacency_list : Dict[str, list]
The adjacency list
"""
data = pd.read_csv(str(file_path)) # type: ignore

# Initialize the adjacency list
adjacency_list : Dict[str, list] = {} # type: ignore

# Populate the adjacency list
for _, row in data.iterrows(): # type: ignore
node = row["node"] # type: ignore
# Convert string representation of list to actual list
neighbors = eval(row["neighbors"]) # type: ignore

if node not in adjacency_list:
adjacency_list[node] = neighbors
else:
adjacency_list[node].extend(neighbors) # type: ignore

return adjacency_list # type: ignore



def log_max_stats_per_client(
self, stats_per_client: np.ndarray, round_step: int, metric: str
Expand Down
Loading
Loading