Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update post_hoc_plot_utils.py #147

Merged
merged 1 commit into from
Dec 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions src/utils/post_hoc_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from typing import List, Dict, Tuple, Optional
import matplotlib.pyplot as plt
import json
import networkx as nx
import imageio
from glob import glob

# Load Logs
def load_logs(node_id: str, metric_type: str, logs_dir: str) -> pd.DataFrame:
Expand Down Expand Up @@ -407,6 +410,148 @@ def plot_metric_per_realtime(metric_df: pd.DataFrame, time_ticks: np.ndarray, me
plt.savefig(f'{output_dir}{metric_name}_per_time.png')
plt.close()



def create_weighted_images(neighbors, output_dir: str, pos):
"""
Create the images for the network visualization.

Parameters:
- neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors
"""
#create a network x graph and visualize it for each round

freq = np.zeros((neighbors.shape[1], neighbors.shape[1]))
for round in range(neighbors.shape[0]):

for node in range(neighbors.shape[1]):
for neighbor in neighbors[round][node]:
freq[node][neighbor-1] += 1

# Create the directed graph
graph = nx.DiGraph()
#add edges based on which edges in freq are greater than 0 and use that as the weight
for i in range(neighbors.shape[1]):
for j in range(neighbors.shape[1]):
if freq[i][j] > 0:
graph.add_edge(i + 1, j + 1, weight=3 * freq[i][j])




# draw nodes
nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue")
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")

#make opposite edges not overlap by adding curvature and make edges thicker based on frequency
curvatureDict = {}
for _, (u, v) in enumerate(graph.edges()):
# make sure u v and v u always have different curvature
if (u,v) not in curvatureDict:
curvatureDict[(u,v)] = 0.1
curvatureDict[(v,u)] = 0.1

rad = curvatureDict[(u,v)]
nx.draw_networkx_edges(
graph,
pos,
edgelist=[(u, v)],
connectionstyle=f"arc3,rad={rad}",
width=freq[u-1][v-1]/3,
arrows=True,
arrowsize=20
)

# Create the image
plt.title(f"Round {round + 1}")
plt.savefig(f"{output_dir}/weighted_graph_{round + 1}.png")

plt.close()

def create_images(neighbors, output_dir: str, pos):
"""
Create the images for the network visualization.

Parameters:
- neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors
"""
#create a network x graph and visualize it for each round
for round in range(neighbors.shape[0]):

# Create the directed graph
graph = nx.DiGraph()
for node in range(neighbors.shape[1]):
for neighbor in neighbors[round][node]:
graph.add_edge(node + 1, neighbor)

# draw nodes
nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue")
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")


#make opposite edges not overlap by adding curvature
curvatureDict = {}
for i, (u, v) in enumerate(graph.edges()):
# make sure u v and v u always have different curvature
if (u,v) not in curvatureDict:
curvatureDict[(u,v)] = 0.1
curvatureDict[(v,u)] = 0.1

rad = curvatureDict[(u,v)]
nx.draw_networkx_edges(
graph,
pos,
edgelist=[(u, v)],
connectionstyle=f"arc3,rad={rad}",
arrows=True,
arrowsize=20
)

# Create the image
plt.title(f"Round {round + 1}")
plt.savefig(f"{output_dir}/graph_{round + 1}.png")
plt.close()

def create_video(output_dir: str, image_name: str):
"""Create a gif from the images."""
images = []
for filename in sorted(glob(f"{output_dir}/{image_name}_*.png")):
images.append(imageio.imread(filename))
imageio.mimsave(f"{output_dir}/{image_name}_video.gif", images, fps = 1, loop = 0)



def create_heatmap(neighbors, output_dir: str):
"""
Create a heatmap of the edge frequency.

Parameters:
- neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors
"""

# Initialize the edge frequency matrix
edge_frequency_matrix = np.zeros((neighbors.shape[1]+1, neighbors.shape[1]+1))
# Iterate over all the rounds
for round in range(neighbors.shape[0]):
# Iterate over all the nodes
for node in range(neighbors.shape[1]):
# Iterate over all the
for neighbor in neighbors[round][node]:
edge_frequency_matrix[node+1][neighbor] += 1

edge_frequency_matrix = np.log(edge_frequency_matrix + 1) # Log scale for better visualization
# Create the heatmap
plt.figure(figsize=(10, 6))
plt.imshow(edge_frequency_matrix, cmap="hot", interpolation="nearest")
plt.title("Edge Frequency Matrix")
plt.colorbar(label="Frequency of Communication")
plt.xlabel("Node")
plt.ylabel("Node")
plt.xticks(range(1,neighbors.shape[1]+1))
plt.yticks(range(1,neighbors.shape[1]+1))
plt.savefig(f"{output_dir}/edge_frequency_heatmap.png")
plt.close()

def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = True, metrics_map: Optional[Dict[str, str]] = None, plot_avg_only: bool=False, **kwargs) -> None:
"""Generates plots for all metrics over rounds with aggregation."""
if metrics_map is None:
Expand Down Expand Up @@ -449,9 +594,34 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru
plot_avg_only=plot_avg_only,
**kwargs
)

neighbors = aggregate_neighbors_across_users(logs_dir)
# create_heatmap(neighbors, f'{os.path.dirname(logs_dir)}/plots/')
pos = nx.spring_layout(nx.DiGraph({i+1: [] for i in range(neighbors.shape[1])}))
create_images(neighbors, f'{os.path.dirname(logs_dir)}/plots/', pos)
create_weighted_images(neighbors, f'{os.path.dirname(logs_dir)}/plots/', pos)
create_video(f'{os.path.dirname(logs_dir)}/plots/', 'graph')
create_video(f'{os.path.dirname(logs_dir)}/plots/', 'weighted_graph')
create_heatmap(neighbors, f'{os.path.dirname(logs_dir)}/plots/')


print("Plots saved as PNG files.")

def aggregate_neighbors_across_users(logs_dir: str) -> np.ndarray:
"""Aggregate the neighbors of each node across all users."""
nodes = get_all_nodes(logs_dir)
nodes.sort() # Sort the nodes to ensure consistent order

all_users_neighbors = []

for node in nodes:
node_id = node.split('_')[-1]
neighbors_file = os.path.join(logs_dir, f'node_{node_id}/csv/neighbors.csv')
neighbors = pd.read_csv(neighbors_file)
np.array(all_users_neighbors.append(neighbors['neighbors'].apply(json.loads).values))

return np.array(all_users_neighbors).T

# Use if you a specific experiment folder
# if __name__ == "__main__":
# # Define the path where your experiment logs are saved
Expand Down
Loading