Skip to content

Commit

Permalink
Update log_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aymann121 authored Nov 11, 2024
1 parent 5c0e73e commit 18d076c
Showing 1 changed file with 0 additions and 78 deletions.
78 changes: 0 additions & 78 deletions src/utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import json
import networkx as nx
from networkx import Graph
import matplotlib.pyplot as plt
import imageio

def deprocess(img: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -154,82 +152,6 @@ def log_nx_graph(self, graph: Graph, iteration: int, directory: str|None = None)
nx.write_adjlist(graph, f"{directory}/graph_{iteration}.adjlist", comments='#', delimiter=' ', encoding='utf-8') # type: ignore
else:
nx.write_adjlist(graph, f"{self.log_dir}/graph_{iteration}.adjlist", comments='#', delimiter=' ', encoding='utf-8') # type: ignore


def log_nx_graph_image(self, graph: Graph, iteration: int, directory: str | None = None):
"""
Log the networkx directed graph as an image with non-overlapping edges.
"""
# Generate a layout with more spacing
if self.nx_layout is None:
self.nx_layout = nx.shell_layout(graph)

# Draw nodes with larger size and smaller font
nx.draw_networkx_nodes(graph, self.nx_layout, node_size=700, node_color="skyblue")

# Draw each edge with curved lines for side-by-side display
edges = list(graph.edges())
for i, (u, v) in enumerate(edges):
rad = 0.2 if i % 2 == 0 else -0.2 # Increase rad for more curvature
nx.draw_networkx_edges(
graph,
self.nx_layout,
edgelist=[(u, v)],
connectionstyle=f"arc3,rad={rad}",
arrows=True,
# arrowstyle="-|>", # Customize arrow style for better separation
arrowsize=20 # Increase arrow size for visibility
)

# Draw labels with smaller font
nx.draw_networkx_labels(graph, self.nx_layout, font_size=8, font_weight="bold")

# Save the plot as an image
if directory:
plt.savefig(f"{directory}/graph_{iteration}.png", format="png")
else:
plt.savefig(f"{self.log_dir}/graph_{iteration}.png", format="png")

# Close the plot to free up memory
plt.close()



def log_nx_graph_edge_weights(self, graph: Graph, iteration: int, directory: str | None = None):
"""
Log the directed graph with edge weights as an image with non-overlapping edges.
"""

if self.nx_layout is None:
self.nx_layout = nx.shell_layout(graph)
pos = self.nx_layout
edge_weights = nx.get_edge_attributes(graph, "weight")

for i, (u, v) in enumerate(graph.edges()):
rad = 0.2 if i % 2 == 0 else -0.2
nx.draw_networkx_edges(
graph,
pos,
edgelist=[(u, v)],
connectionstyle=f"arc3,rad={rad}",
arrows=True,
width=edge_weights.get((u, v), 1.0), # Set width based on weight
arrowsize=20
)

# Draw nodes and labels
nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue")
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")

if directory:
plt.savefig(f"{directory}/weighted_graph_{iteration}.png", format="png")
else:
plt.savefig(f"{self.log_dir}/weighted_graph_{iteration}.png", format="png")

plt.close()




def log_config(self, config: ConfigType):
"""
Expand Down

0 comments on commit 18d076c

Please sign in to comment.