-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
323 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,77 +1,279 @@ | ||
import yaml | ||
import sys | ||
import base64 | ||
import re | ||
import logging | ||
from typing import Dict, Any | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class ThemeManagerError(Exception): | ||
"""Raised when loading the theme fails.""" | ||
"""Raised when loading the theme fails due to invalid style strings or configuration.""" | ||
pass | ||
|
||
class ThemeManager: | ||
""" | ||
Loads and merges theme configuration from a YAML file. | ||
ThemeManager is responsible for: | ||
- Loading a theme configuration from a YAML file. | ||
- Validating style strings for nodes and links. | ||
- Optionally modifying embedded SVG images by injecting custom CSS overrides. | ||
CSS overrides can be specified in the theme YAML and allow changing properties | ||
of classes defined in the embedded SVG <style> block. | ||
""" | ||
|
||
def __init__(self, config_path: str): | ||
""" | ||
:param config_path: Path to the theme configuration file. | ||
Initialize the ThemeManager with a path to a theme configuration file. | ||
:param config_path: Path to the YAML theme configuration file. | ||
""" | ||
self.config_path = config_path | ||
|
||
def load_theme(self) -> dict: | ||
""" | ||
Load the theme configuration and return a dictionary of styles. | ||
Load and process the theme configuration. | ||
:return: Dictionary containing styles and configuration parameters. | ||
This method: | ||
- Reads the YAML file. | ||
- Validates the 'base_style' and 'custom_styles' strings. | ||
- Applies CSS overrides to embedded SVG images in custom styles if defined. | ||
:return: A dictionary representing the fully processed theme configuration. | ||
:raises SystemExit: If the file does not exist or another IO error occurs. | ||
:raises ThemeManagerError: If invalid style strings are encountered. | ||
""" | ||
logger.debug(f"Loading theme from: {self.config_path}") | ||
try: | ||
with open(self.config_path, "r") as file: | ||
config = yaml.safe_load(file) | ||
except FileNotFoundError: | ||
error_message = ( | ||
f"Error: The specified config file '{self.config_path}' does not exist." | ||
) | ||
error_message = f"Error: The specified config file '{self.config_path}' does not exist." | ||
logger.error(error_message) | ||
sys.exit(1) | ||
raise SystemExit(error_message) | ||
except Exception as e: | ||
error_message = f"An error occurred while loading the config: {e}" | ||
logger.error(error_message) | ||
sys.exit(1) | ||
|
||
base_style_dict = { | ||
item.split("=")[0]: item.split("=")[1] | ||
for item in config.get("base_style", "").split(";") | ||
if item | ||
} | ||
|
||
styles = { | ||
"background": config.get("background", "#FFFFFF"), | ||
"shadow": config.get("shadow", "1"), | ||
"grid": config.get("grid", "1"), | ||
"pagew": config.get("pagew", "827"), | ||
"pageh": config.get("pageh", "1169"), | ||
"base_style": config.get("base_style", ""), | ||
"link_style": config.get("link_style", ""), | ||
"src_label_style": config.get("src_label_style", ""), | ||
"trgt_label_style": config.get("trgt_label_style", ""), | ||
"port_style": config.get("port_style", ""), | ||
"connector_style": config.get("connector_style", ""), | ||
"icon_to_group_mapping": config.get("icon_to_group_mapping", {}), | ||
"custom_styles": {}, | ||
} | ||
|
||
for key, custom_style in config.get("custom_styles", {}).items(): | ||
custom_style_dict = { | ||
item.split("=")[0]: item.split("=")[1] | ||
for item in custom_style.split(";") | ||
if item | ||
} | ||
merged_style_dict = {**base_style_dict, **custom_style_dict} | ||
merged_style = ";".join(f"{k}={v}" for k, v in merged_style_dict.items()) | ||
styles["custom_styles"][key] = merged_style | ||
|
||
for key, value in config.items(): | ||
if key not in styles: | ||
styles[key] = value | ||
|
||
return styles | ||
raise SystemExit(error_message) | ||
|
||
# Validate base_style | ||
base_style = config.get("base_style", "") | ||
self._validate_style_string(base_style) | ||
|
||
# Validate custom_styles | ||
custom_styles = config.get("custom_styles", {}) | ||
for name, style_str in custom_styles.items(): | ||
self._validate_style_string(style_str) | ||
|
||
# Load CSS overrides if any | ||
css_overrides = config.get("css_overrides", {}) | ||
|
||
# Apply CSS overrides to embedded SVGs | ||
for style_name, style_str in custom_styles.items(): | ||
updated_style_str = self._maybe_modify_svg_css(style_name, style_str, css_overrides) | ||
custom_styles[style_name] = updated_style_str | ||
|
||
config["custom_styles"] = custom_styles | ||
|
||
logger.debug("Theme loaded and processed successfully.") | ||
return config | ||
|
||
def _maybe_modify_svg_css(self, style_name: str, style_str: str, css_overrides: Dict[str, Dict[str, str]]) -> str: | ||
""" | ||
Check if the given style string references an SVG image. If so, and if CSS overrides | ||
exist for this style, decode the SVG, modify its <style> block, and re-encode it. | ||
:param style_name: Name of the style (e.g. 'default', 'leaf', 'spine'). | ||
:param style_str: The style string which may contain 'image=data:...' referencing an SVG. | ||
:param css_overrides: A dictionary of CSS overrides for styles. | ||
:return: The updated style string if modifications were applied; else the original. | ||
""" | ||
image_match = re.search(r'image=data:([^;]+)', style_str) | ||
if not image_match: | ||
return style_str | ||
|
||
image_data = image_match.group(1) | ||
if not image_data.startswith("image/svg+xml,"): | ||
return style_str | ||
|
||
base64_part = image_data[len("image/svg+xml,"):] | ||
try: | ||
svg_binary = base64.b64decode(base64_part) | ||
except Exception as e: | ||
logger.warning(f"Failed to decode base64 SVG for style '{style_name}': {e}") | ||
return style_str | ||
|
||
svg_str = svg_binary.decode('utf-8', errors='replace') | ||
style_overrides_for_style = css_overrides.get(style_name, {}) | ||
|
||
if not style_overrides_for_style: | ||
# No overrides for this style | ||
return style_str | ||
|
||
logger.debug(f"Applying CSS overrides to style '{style_name}'.") | ||
new_svg_str = self._modify_svg_style_block(svg_str, style_overrides_for_style) | ||
if new_svg_str == svg_str: | ||
# No changes were made | ||
return style_str | ||
|
||
# Re-encode SVG | ||
new_base64 = base64.b64encode(new_svg_str.encode('utf-8')).decode('utf-8') | ||
new_image_data = "image/svg+xml," + new_base64 | ||
new_style_str = style_str.replace(image_data, new_image_data, 1) | ||
|
||
logger.debug(f"CSS overrides applied successfully to style '{style_name}'.") | ||
return new_style_str | ||
|
||
def _modify_svg_style_block(self, svg_data: str, style_overrides: Dict[str, str]) -> str: | ||
""" | ||
Modify the <style> block of the SVG by applying given CSS overrides. | ||
If no <style> block exists, one is created before the closing </svg> tag. | ||
:param svg_data: The full SVG as a string. | ||
:param style_overrides: A dictionary mapping 'classname_property' to 'value'. | ||
For example: {'st0_fill': '#FF0000'} | ||
:return: The updated SVG string. | ||
""" | ||
style_start = svg_data.find("<style") | ||
style_end = -1 | ||
style_content = "" | ||
|
||
if style_start != -1: | ||
style_close = svg_data.find("</style>", style_start) | ||
if style_close != -1: | ||
start_tag_end = svg_data.find('>', style_start) | ||
if start_tag_end != -1 and start_tag_end < style_close: | ||
style_end = style_close + len("</style>") | ||
style_content = svg_data[start_tag_end+1:style_close] | ||
|
||
# Split by '
' to preserve formatting of original style lines | ||
style_lines = style_content.split("
") if style_content else [] | ||
|
||
# Parse existing classes from the style block | ||
class_rules = {} | ||
class_line_map = {} | ||
for i, line in enumerate(style_lines): | ||
m = re.match(r'(\s*)(\.[A-Za-z0-9_-]+)\{([^}]*)\}', line.strip()) | ||
if m: | ||
indentation = m.group(1) or "" | ||
full_cls = m.group(2) | ||
cls_name = full_cls.lstrip('.') | ||
props_str = m.group(3) | ||
props = self._parse_properties(props_str) | ||
class_rules[cls_name] = props | ||
class_line_map[cls_name] = (i, indentation) | ||
|
||
# Apply overrides | ||
changed_classes = set() | ||
for key, val in style_overrides.items(): | ||
parts = key.split('_', 1) | ||
if len(parts) != 2: | ||
logger.debug(f"Skipping invalid override key '{key}'. Expected '<class>_<property>'.") | ||
continue | ||
class_name, prop_name = parts | ||
if class_name not in class_rules: | ||
# If class doesn't exist, create it | ||
class_rules[class_name] = {} | ||
class_line_map[class_name] = (None, " ") | ||
class_rules[class_name][prop_name] = val | ||
changed_classes.add(class_name) | ||
|
||
# Rebuild changed or newly added class lines | ||
for cls_n in changed_classes: | ||
i, indent = class_line_map[cls_n] | ||
new_line = self._build_class_line(indent, cls_n, class_rules[cls_n]) | ||
if i is not None: | ||
# Modify existing line | ||
style_lines[i] = new_line | ||
else: | ||
# Append a new class line | ||
style_lines.append(new_line) | ||
|
||
new_style_content = "
".join(style_lines) | ||
if new_style_content and not new_style_content.endswith("
"): | ||
new_style_content += "
" | ||
|
||
# If there was no style block, create one before </svg> | ||
if style_start == -1: | ||
insert_pos = svg_data.rfind("</svg>") | ||
if insert_pos == -1: | ||
# No closing svg? Just append the style at the end. | ||
return svg_data + "<style>" + new_style_content + "</style>" | ||
else: | ||
return svg_data[:insert_pos] + "<style>" + new_style_content + "</style>" + svg_data[insert_pos:] | ||
else: | ||
# Replace existing style content | ||
return self._replace_style_block(svg_data, style_start, style_end, new_style_content) | ||
|
||
def _replace_style_block(self, svg_data: str, style_start: int, style_end: int, new_content: str) -> str: | ||
""" | ||
Replace the content of the existing <style> block with new_content. | ||
:param svg_data: The full SVG string. | ||
:param style_start: Index of the start of the <style> tag. | ||
:param style_end: Index of the end of the </style> tag. | ||
:param new_content: The new CSS content to insert. | ||
:return: The updated SVG string with replaced style content. | ||
""" | ||
start_tag_end = svg_data.find('>', style_start) | ||
if start_tag_end == -1 or style_end == -1: | ||
logger.debug("Could not properly find the style block boundaries; returning unchanged SVG.") | ||
return svg_data | ||
|
||
style_open_tag = svg_data[style_start:start_tag_end+1] | ||
return svg_data[:style_start] + style_open_tag + new_content + "</style>" + svg_data[style_end:] | ||
|
||
def _parse_properties(self, props_str: str) -> Dict[str, str]: | ||
""" | ||
Parse CSS properties from a string like "fill:#001135;stroke:#FFF". | ||
:param props_str: The CSS properties string inside a single class definition. | ||
:return: A dict of {property_name: property_value}. | ||
""" | ||
props = {} | ||
segments = props_str.split(';') | ||
for seg in segments: | ||
seg = seg.strip() | ||
if '=' in seg: # skip invalid or unexpected segments | ||
continue | ||
if seg: | ||
kv = seg.split(':',1) | ||
if len(kv) == 2: | ||
prop = kv[0].strip() | ||
val = kv[1].strip() | ||
props[prop] = val | ||
return props | ||
|
||
def _build_class_line(self, indent: str, cls_name: str, props: Dict[str,str]) -> str: | ||
""" | ||
Rebuild a single CSS class line for the style block. | ||
:param indent: The indentation originally used for this line. | ||
:param cls_name: The class name (e.g. "st0"). | ||
:param props: A dict of CSS properties to apply to this class. | ||
:return: A string like " .st0{fill:#FF0000;stroke:#FFFFFF;}" | ||
""" | ||
prop_segs = [f"{p}:{v}" for p, v in props.items()] | ||
prop_str = ";".join(prop_segs) + ";" if prop_segs else "" | ||
return f"{indent}.{cls_name}{{{prop_str}}}" | ||
|
||
def _validate_style_string(self, style_str: str): | ||
""" | ||
Validate that the style string follows "key=value" pairs separated by semicolons. | ||
Known exception: 'points=[]' patterns are allowed. | ||
:param style_str: The style string to validate. | ||
:raises ThemeManagerError: If invalid segments are found. | ||
""" | ||
if style_str.strip() == "": | ||
return | ||
segments = style_str.split(';') | ||
segments = [seg for seg in segments if seg.strip() != ""] | ||
|
||
for seg in segments: | ||
if '=' not in seg: | ||
if 'points=[' in seg: | ||
continue | ||
raise ThemeManagerError(f"Invalid style segment '{seg}' in style string.") | ||
parts = seg.split('=', 1) | ||
if len(parts) != 2: | ||
raise ThemeManagerError(f"Invalid style segment '{seg}' in style string.") |
Oops, something went wrong.