From ab4f816876cd4b550349633114205a3e8f5d2844 Mon Sep 17 00:00:00 2001 From: Johannes Rieke Date: Sun, 9 Mar 2025 21:18:03 +0100 Subject: [PATCH] Add screenshot generator --- python/api-screenshot-generator/.gitignore | 2 + python/api-screenshot-generator/README.md | 119 ++++ .../api-screenshot-generator/capture_app.py | 564 ++++++++++++++++++ python/api-screenshot-generator/config.yaml | 118 ++++ .../overlays/cursor.svg | 17 + .../api-screenshot-generator/preview_app.py | 104 ++++ .../api-screenshot-generator/requirements.txt | 10 + .../take_screenshots.py | 417 +++++++++++++ 8 files changed, 1351 insertions(+) create mode 100644 python/api-screenshot-generator/.gitignore create mode 100644 python/api-screenshot-generator/README.md create mode 100644 python/api-screenshot-generator/capture_app.py create mode 100644 python/api-screenshot-generator/config.yaml create mode 100644 python/api-screenshot-generator/overlays/cursor.svg create mode 100644 python/api-screenshot-generator/preview_app.py create mode 100644 python/api-screenshot-generator/requirements.txt create mode 100644 python/api-screenshot-generator/take_screenshots.py diff --git a/python/api-screenshot-generator/.gitignore b/python/api-screenshot-generator/.gitignore new file mode 100644 index 000000000..9a0bc2e97 --- /dev/null +++ b/python/api-screenshot-generator/.gitignore @@ -0,0 +1,2 @@ +.venv +screenshots diff --git a/python/api-screenshot-generator/README.md b/python/api-screenshot-generator/README.md new file mode 100644 index 000000000..e10b04f63 --- /dev/null +++ b/python/api-screenshot-generator/README.md @@ -0,0 +1,119 @@ +# Screenshot generator for the API reference + +This tool automatically generates standardized screenshots of Streamlit elements for +use in the API reference documentation. + +## Overview + +The screenshot generator consists of: + +- A Streamlit app (`capture_app.py`) that displays the Streamlit elements to be + captured. +- A screenshot capture script (`take_screenshots.py`) that automates browser + interactions and captures screenshots. +- A configuration file (`config.yaml`) that defines elements and settings. +- A preview app (`preview_app.py`) that displays all captured screenshots in a grid. + +## Installation + +```bash +pip install -r requirements.txt +playwright install +``` + +## Usage + +### Creating and previewing screenshots + +1. Run the screenshot capture script: + + ```bash + python take_screenshots.py + ``` + + This will start a Streamlit server displaying `capture_app.py`, access it with a headless + Playwright browser, take screenshots of all elements defined in `config.yaml`, and save + them to the `screenshots` directory. Note that all existing screenshots in the + `screenshots` directory will be deleted. + +2. Preview all screenshots: + + ```bash + streamlit run preview_app.py + ``` + + This will start a Streamlit app that displays all screenshots in a grid. The order + of the screenshots is the same as in `config.yaml`. + +### Command line options + +When running the screenshot script (`take_screenshots.py`), you can add the following +command line options: + +- `--headed`: Run the browser in headed mode (visible) instead of headless. This is great + for debugging. +- `--only element1 element2`: Only capture specific elements. + +Example: + +```bash +python take_screenshots.py --headed --only button text_input +``` + +### Configuration + +All settings are defined in `config.yaml`. It contains: + +- Some global settings at the top of the file, such as the dimensions and padding of the + screenshots. +- A list of elements to capture. Each element can have the following properties: + - `name`: The name of the element (required). This must map to the `key` property of + an `st.container` in `capture_app.py` (see below for details). + - `padding`: The padding around the element. If not given, the `default_padding` + setting will be used. You can use this to make smaller elements (e.g. buttons) not + take up the entire screenshot. + - `overlay`: The path to an overlay image to apply on top of the screenshot. Supports + SVG and PNG files. The overlay image must be the same size as the screenshot. The + overlay is only applied if the global setting `enable_overlays` is `true`. + +### Adding or changing elements + +All elements are displayed through the Streamlit app `capture_app.py`. For every element, +this app has an `st.container` with a `key` parameter equal to the element name defined +in `config.yaml`. The content of this container will be screenshotted. + +To edit an existing element: + +- Simply edit whatever is in the associated `st.container` in `capture_app.py`. You can + also add supporting code that should not be captured outside of the container (e.g. + CSS hacks). + +To add a new element: + +- Add the new element to `config.yaml`. +- Add a new `st.container` to `capture_app.py` with a `key` parameter equal to the + element name. + +## How It Works + +1. The script launches a Streamlit server running `capture_app.py` +2. It uses Playwright to automate a browser session +3. For each element defined in `config.yaml`: + - It locates the `st.container` with its `key` equal to the element name on the page + - Performs any special handling (e.g., clicking dropdowns) + - Takes a screenshot + - Processes the image (trims whitespace, applies padding, adds overlays) +4. All screenshots are saved to the `screenshots` directory + +## Special handling for elements + +Some elements are handled specially by `take_screenshots.py`. E.g. selectbox, +multiselect, date input, and color picker are clicked to show their dropdowns. Or data +editor is clicked multiple times to show its editing mode. For details, see the +`take_screenshots.py` code. + +## Troubleshooting + +- If elements aren't found, check the key attributes in `capture_app.py` +- For rendering issues, try running in headed mode with `--headed` +- If overlays don't appear, check that `enable_overlays` is `true` and paths are correct diff --git a/python/api-screenshot-generator/capture_app.py b/python/api-screenshot-generator/capture_app.py new file mode 100644 index 000000000..c6607400c --- /dev/null +++ b/python/api-screenshot-generator/capture_app.py @@ -0,0 +1,564 @@ +import datetime +import time + +import altair as alt +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import streamlit as st + +st.title("Capture app") +st.write( + """ + This app is being used to take screenshots for elements in the API reference. It + is started by the `take_screenshots.py` script. See `README.md` for how this works. + """ +) + +# Text Elements +st.header("Text Elements") + +with st.container(key="markdown"): + st.markdown(""" + ## st.markdown + + This component supports *italics*, **bold text**, and ***both together***. + + ## Features include: + * Bullet lists + * [Links to websites](https://streamlit.io) + * Code snippets with `inline code` + + ```python + # Or code blocks with syntax highlighting + def hello(): + return "Hello, Streamlit!" + ``` + """) + +with st.container(key="title"): + st.title("This is a title") + st.title("This is a title") + st.title("This is a title") + st.title("This is a title") + +with st.container(key="header"): + st.header("This is a header") + st.header("This is a header") + st.header("This is a header") + st.header("This is a header") + st.header("This is a header") + +with st.container(key="subheader"): + st.subheader("This is a subheader") + st.subheader("This is a subheader") + st.subheader("This is a subheader") + st.subheader("This is a subheader") + st.subheader("This is a subheader") + st.subheader("This is a subheader") + +with st.container(key="caption"): + st.caption("This is a caption") + st.caption("This is a caption") + st.caption("This is a caption") + st.caption("This is a caption") + st.caption("This is a caption") + st.caption("This is a caption") + st.caption("This is a caption") + +with st.container(key="code"): + st.code("""def hello_world(): + # A simple greeting function + return "Hello, Streamlit!" + +def square(x): + # Returns the square of a number + return x * x""", + language="python", + ) + +with st.container(key="latex"): + st.latex(r"e^{i\pi} + 1 = 0") + st.latex(r"\int_{a}^{b} f(x) \, dx = F(b) - F(a)") + st.latex(r"\quad \frac{d}{dx}[\sin(x)] = \cos(x)") + +with st.container(key="text"): + st.text("This is a text") + st.text("This is a text") + st.text("This is a text") + st.text("This is a text") + st.text("This is a text") + st.text("This is a text") + +with st.container(key="divider"): + st.divider() + +# Data Display Elements +st.header("Data Display Elements") + +with st.container(key="dataframe"): + df = pd.DataFrame( + { + "Name": ["Alice", "Bob", "Charlie"], + "Age": [24, 32, 28], + "City": ["New York", "Los Angeles", "Chicago"], + } + ) + st.dataframe(df, use_container_width=False) + + +# Hide the toolbar for the data editor; this is important because we click on it +# for the screenshot +st.markdown( + "", + unsafe_allow_html=True +) +with st.container(key="data_editor"): + df_edit = pd.DataFrame({"Name": ["Alice", "Bob", "Charlie"], "Age": [24, 32, 28]}) + st.data_editor(df_edit, use_container_width=False) + +with st.container(key="column_config"): + df = pd.DataFrame( + { + "Name": ["Alice", "Bob", "Charlie"], + "Age": [24, 32, 28], + "Rating": [4.5, 3.8, 4.9], + } + ) + st.dataframe( + df, + column_config={ + "Name": "Full Name", + "Rating": st.column_config.ProgressColumn( + "Rating", min_value=0, max_value=5 + ), + }, + use_container_width=False, + ) + +with st.container(key="table"): + st.table(df.head(3)) + +with st.container(key="metric"): + st.metric(label="Temperature", value="70 °F", delta="1.2 °F") + +with st.container(key="json"): + st.json( + { + "name": "John", + "age": 30, + "city": "New York", + "skills": ["Python", "SQL", "Streamlit"], + } + ) + +# Chart Elements +st.header("Chart Elements") + +with st.container(key="area_chart"): + np.random.seed(42) + chart_data = pd.DataFrame( + np.random.randn(20, 3).cumsum(axis=0), columns=["A", "B", "C"] + ) + st.area_chart(chart_data) + +with st.container(key="bar_chart"): + data = pd.DataFrame({"Category": ["A", "B", "C", "D"], "Values": [10, 25, 15, 30]}) + st.bar_chart(data, x="Category", y="Values") + +with st.container(key="line_chart"): + np.random.seed(42) + chart_data = pd.DataFrame( + np.random.randn(20, 3).cumsum(axis=0), columns=["X", "Y", "Z"] + ) + st.line_chart(chart_data) + +with st.container(key="scatter_chart"): + np.random.seed(42) + scatter_data = pd.DataFrame(np.random.randn(50, 3), columns=["X", "Y", "Size"]) + st.scatter_chart(scatter_data, x="X", y="Y", size="Size") + +with st.container(key="map"): + np.random.seed(42) + map_data = pd.DataFrame( + np.random.randn(1000, 2) / [50, 50] + [37.76, -122.4], columns=["lat", "lon"] + ) + st.map(map_data, height=250) + +with st.container(key="pyplot"): + np.random.seed(42) + fig, ax = plt.subplots() + x = np.linspace(0, 10, 100) + ax.plot(x, np.sin(x)) + ax.set_title("Sine Wave") + st.pyplot(fig) + +with st.container(key="altair_chart"): + # Create a more interesting dataset with multiple series + x = list(range(10)) + source = pd.DataFrame( + { + "x": x * 3, + "y": [i**2 for i in x] + [i**1.5 for i in x] + [i**0.8 * 15 for i in x], + "category": ["Quadratic"] * 10 + ["Root"] * 10 + ["Linear-ish"] * 10, + } + ) + + # Create a colorful and interactive chart with multiple elements + chart = ( + alt.Chart(source) + .mark_circle(size=100, opacity=0.7) + .encode( + x=alt.X("x", title="X Axis", scale=alt.Scale(domain=[0, 9])), + y=alt.Y("y", title="Y Axis"), + color=alt.Color( + "category:N", + scale=alt.Scale(range=["#FF5733", "#33FF57", "#3357FF"]), + legend=alt.Legend(title="Function Type"), + ), + tooltip=["x", "y", "category"], + size=alt.Size("y", legend=None, scale=alt.Scale(range=[100, 500])), + ) + .properties(title="Beautiful Mathematical Relationships") + ) + + # Add connecting lines between points of the same category + lines = ( + alt.Chart(source) + .mark_line(opacity=0.6, strokeWidth=3) + .encode(x="x", y="y", color="category:N", strokeDash="category:N") + ) + + # Combine the visualizations + final_chart = (chart + lines).interactive() + + st.altair_chart(final_chart, use_container_width=True) + +with st.container(key="vega_lite_chart"): + spec = { + "mark": {"type": "bar"}, + "encoding": { + "x": {"field": "a", "type": "ordinal"}, + "y": {"field": "b", "type": "quantitative"}, + }, + "data": { + "values": [ + {"a": "A", "b": 28}, + {"a": "B", "b": 55}, + {"a": "C", "b": 43}, + {"a": "D", "b": 91}, + ] + }, + } + st.vega_lite_chart(spec, use_container_width=True) + +with st.container(key="plotly_chart"): + import plotly.express as px + + df = px.data.gapminder().query("continent=='Oceania'") + fig = px.line(df, x="year", y="lifeExp", color="country") + st.plotly_chart(fig, use_container_width=True) + +with st.container(key="bokeh_chart"): + from bokeh.plotting import figure + + p = figure( + title="Bokeh Chart Example", + x_axis_label="X-Axis", + y_axis_label="Y-Axis", + height=250, + ) + + # Add a line renderer + x = np.linspace(0, 10, 100) + y = np.sin(x) * np.cos(x) + p.line(x, y, legend_label="sin(x) * cos(x)", line_width=2, line_color="navy") + p.circle(x, y, fill_color="white", size=8) + + st.bokeh_chart(p, use_container_width=True) + +with st.container(key="pydeck_chart"): + import pydeck as pdk + + # Sample data for a scatter plot of points + # Create a more interesting dataset with San Francisco landmarks + data = pd.DataFrame({ + "name": ["Golden Gate Bridge", "Fisherman's Wharf", "Alcatraz Island", "Lombard Street", "Painted Ladies"], + "lat": [37.8199, 37.8080, 37.8270, 37.8021, 37.7762], + "lon": [-122.4783, -122.4177, -122.4230, -122.4187, -122.4328], + "visitors": [10000000, 8000000, 1500000, 2000000, 3500000], + "category": ["Landmark", "Tourist", "Historic", "Street", "Architecture"] + }) + + # Color mapping based on category + color_map = { + "Landmark": [255, 0, 0], # Red + "Tourist": [255, 165, 0], # Orange + "Historic": [0, 0, 255], # Blue + "Street": [0, 128, 0], # Green + "Architecture": [128, 0, 128] # Purple + } + + # Create a column with color values + data["color"] = data["category"].apply(lambda x: color_map[x]) + + # Create source and target pairs for arcs + # Connect each point to the center of San Francisco + sf_center = {"lat": 37.7749, "lon": -122.4194} + + # Create arc data - connecting each landmark to the center + arc_data = [] + for _, row in data.iterrows(): + arc_data.append({ + "from_lat": sf_center["lat"], + "from_lon": sf_center["lon"], + "to_lat": row["lat"], + "to_lon": row["lon"], + "color": row["color"], + "name": row["name"], + "category": row["category"], + "visitors": row["visitors"] + }) + + arc_data = pd.DataFrame(arc_data) + + # Create an arc layer + arc_layer = pdk.Layer( + "ArcLayer", + data=arc_data, + get_source_position=["from_lon", "from_lat"], + get_target_position=["to_lon", "to_lat"], + get_width=3, + get_source_color=[64, 64, 64], + get_target_color="color", + pickable=True, + auto_highlight=True, + ) + + # Point layer for landmarks + point_layer = pdk.Layer( + "ScatterplotLayer", + data=data, + get_position=["lon", "lat"], + get_radius=250, + get_fill_color="color", + pickable=True, + ) + + # Text layer for landmark names + text_layer = pdk.Layer( + "TextLayer", + data=data, + get_position=["lon", "lat"], + get_text="name", + get_size=16, + get_color=[0, 0, 0], + get_text_anchor="middle", + get_alignment_baseline="center", + ) + + # Set view state with increased zoom and adjusted bearing + view_state = pdk.ViewState( + latitude=37.8049, + longitude=-122.4294, + zoom=11, + pitch=45, + bearing=60, + ) + + # Render the deck + deck = pdk.Deck( + layers=[arc_layer, point_layer, text_layer], + initial_view_state=view_state, + tooltip={ + "html": "{name}
{category}
{visitors:,} annual visitors", + "style": { + "backgroundColor": "#fff", + "color": "#333" + } + }, + map_style="mapbox://styles/mapbox/light-v10", + ) + + st.pydeck_chart(deck, height=250) + +with st.container(key="graphviz_chart"): + import graphviz + + # Create a graphviz graph + graph = graphviz.Digraph() + graph.edge("A", "B") + graph.edge("B", "C") + graph.edge("C", "D") + graph.edge("D", "A") + + st.graphviz_chart(graph) + + +# Input Elements +st.header("Input Elements") + +with st.container(key="button"): + st.button("Click me") + +with st.container(key="download_button"): + st.download_button( + label="Download file", + data="A,B,C\n1,2,3", + file_name="data.csv", + mime="text/csv", + ) + +with st.container(key="form_submit_button"): + with st.form("my_form"): + st.text_input("Enter text") + st.form_submit_button("Submit") + +with st.container(key="link_button"): + st.link_button("Go to gallery", "https://streamlit.io") + +with st.container(key="page_link"): + st.page_link("https://streamlit.io", label="Home", icon="🏠") + st.page_link("https://docs.streamlit.io", label="Page 1", icon="📄") + +with st.container(key="checkbox"): + st.checkbox("Rebuild model each time", True) + +with st.container(key="color_picker"): + st.color_picker("Pick a color", "#5083e8") + for _ in range(20): + st.write("") + +with st.container(key="feedback"): + st.feedback("faces") + st.feedback("stars") + st.feedback("thumbs") + +with st.container(key="multiselect"): + st.multiselect( + "Visible in image", + ["Milk", "Bananas", "Apples", "Potatoes"], + default=["Milk", "Bananas"], + ) + # Leave some space for the dropdown to appear + for _ in range(10): + st.write("") + +with st.container(key="pills"): + selected = st.pills("Tags", ["Sports", "AI", "Politics"], default=["AI"]) + +with st.container(key="radio"): + st.radio("Classify image", ["Dog", "Cat", "Goldfish"]) + +with st.container(key="segmented_control"): + st.segmented_control("Filter", ["Open", "Closed", "All"], default="All") + +with st.container(key="selectbox"): + st.selectbox("Pick one", ["Cats", "Dogs"], index=0) + # Leave some space for the dropdown to appear + for _ in range(10): + st.write("") + +with st.container(key="select_slider"): + st.select_slider( + "Rate", options=["Poor", "Average", "Good", "Excellent"], value="Good" + ) + +with st.container(key="toggle"): + st.toggle("Activate", True) + +with st.container(key="number_input"): + st.number_input("Number of days", min_value=0, max_value=100, value=28) + +with st.container(key="slider"): + st.slider("Pick a number", 0, 100, 42) + +with st.container(key="date_input"): + st.date_input("Initial date", "2019-07-06") + # Leave some space for the dropdown to appear + for _ in range(25): + st.write("") + +with st.container(key="time_input"): + st.time_input("Set an alarm for", datetime.time(8, 45)) + +with st.container(key="chat_input"): + st.chat_input("How can I help?", accept_file=True) + +with st.container(key="text_area"): + st.text_area( + "Text to analzye", + "It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness, it was the epoch of belief.", + ) + +with st.container(key="text_input"): + st.text_input("Movie title", "Life of Brian") + +with st.container(key="audio_input"): + st.audio_input("Record a voice message") + +with st.container(key="file_uploader"): + st.file_uploader("Choose a CSV file") + +# Media Elements +st.header("Media Elements") + +with st.container(key="image"): + st.image( + "https://images.unsplash.com/photo-1548407260-da850faa41e3?ixlib=rb-1.2.1&ixid=eyJhcHBfaWQiOjEyMDd9&auto=format&fit=crop&w=1487&q=80", + caption="Sunrise by the mountains", + ) + +with st.container(key="audio"): + st.audio( + "https://www.soundhelix.com/examples/mp3/SoundHelix-Song-1.mp3", start_time=200 + ) + +with st.container(key="video"): + st.video("https://static.streamlit.io/examples/star.mp4", start_time=2) + +# Chat Elements +st.header("Chat Elements") + +with st.container(key="chat_message"): + st.chat_message("user").write("Hello, how can I help you?") + st.chat_message("assistant").write("I'm here to assist with your questions!") + +with st.container(key="status"): + with st.status("Downloading data...", expanded=True) as status: + st.write("Searching for data...") + st.write("Found URL.") + st.write("Downloading data...") + status.update(label="Download complete!", state="complete") + +# Status Elements +st.header("Status Elements") + +with st.container(key="progress"): + st.progress(0.4, "Initializing...") + st.progress(0.7, "Downloading...") + st.progress(1.0, "Complete!") + +show_spinner = st.button("Show spinner", key="show_spinner") +with st.container(key="spinner"): + if show_spinner: + with st.spinner("Please wait..."): + time.sleep(20) + +with st.container(key="success"): + st.success("Data processing completed successfully! You can now proceed.", icon=":material/check_circle:") + +with st.container(key="info"): + st.info("This visualization uses data from the last 30 days and all categories.", icon=":material/info:") + +with st.container(key="warning"): + st.warning("This action will permanently delete your data. Proceed with caution.", icon=":material/warning:") + +with st.container(key="error"): + st.error("Unable to connect to database. Please check your credentials.", icon=":material/cancel:") + +with st.container(key="exception"): + try: + 1 / 0 + except Exception as e: + st.exception(e) diff --git a/python/api-screenshot-generator/config.yaml b/python/api-screenshot-generator/config.yaml new file mode 100644 index 000000000..ac9d193d6 --- /dev/null +++ b/python/api-screenshot-generator/config.yaml @@ -0,0 +1,118 @@ +screenshot_width: 800 +screenshot_height: 600 +default_padding: 50 +browser_window_width: 400 +enable_overlays: false + +elements: + # Text elements + - name: markdown + - name: title + - name: header + - name: subheader + - name: caption + - name: code + - name: latex + - name: text + - name: divider + + # Data elements + - name: dataframe + - name: data_editor + overlay: overlays/cursor.svg + - name: column_config + - name: table + - name: metric + padding: 150 + - name: json + + # Chart elements + - name: area_chart + - name: bar_chart + - name: line_chart + - name: scatter_chart + - name: map + - name: pyplot + - name: altair_chart + - name: vega_lite_chart + - name: plotly_chart + - name: bokeh_chart + - name: pydeck_chart + - name: graphviz_chart + + # Input widgets + - name: button + padding: 200 + overlay: overlays/cursor.svg + - name: download_button + padding: 200 + overlay: overlays/cursor.svg + - name: form_submit_button + - name: link_button + padding: 200 + overlay: overlays/cursor.svg + - name: page_link + padding: 200 + overlay: overlays/cursor.svg + - name: checkbox + padding: 200 + overlay: overlays/cursor.svg + - name: color_picker + overlay: overlays/cursor.svg + - name: feedback + padding: 100 + overlay: overlays/cursor.svg + - name: multiselect + overlay: overlays/cursor.svg + - name: pills + padding: 100 + overlay: overlays/cursor.svg + - name: radio + padding: 150 + overlay: overlays/cursor.svg + - name: segmented_control + padding: 100 + overlay: overlays/cursor.svg + - name: selectbox + overlay: overlays/cursor.svg + - name: select_slider + overlay: overlays/cursor.svg + - name: toggle + padding: 200 + overlay: overlays/cursor.svg + - name: number_input + overlay: overlays/cursor.svg + - name: slider + overlay: overlays/cursor.svg + - name: date_input + overlay: overlays/cursor.svg + - name: time_input + overlay: overlays/cursor.svg + - name: chat_input + overlay: overlays/cursor.svg + - name: text_area + overlay: overlays/cursor.svg + - name: text_input + overlay: overlays/cursor.svg + - name: audio_input + overlay: overlays/cursor.svg + - name: file_uploader + overlay: overlays/cursor.svg + + # Media elements + - name: image + - name: audio + - name: video + + # Chat elements + - name: chat_message + - name: status + + # Status elements + - name: progress + #- name: spinner + - name: success + - name: info + - name: warning + - name: error + - name: exception \ No newline at end of file diff --git a/python/api-screenshot-generator/overlays/cursor.svg b/python/api-screenshot-generator/overlays/cursor.svg new file mode 100644 index 000000000..d2706febf --- /dev/null +++ b/python/api-screenshot-generator/overlays/cursor.svg @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/python/api-screenshot-generator/preview_app.py b/python/api-screenshot-generator/preview_app.py new file mode 100644 index 000000000..2687ef064 --- /dev/null +++ b/python/api-screenshot-generator/preview_app.py @@ -0,0 +1,104 @@ +import streamlit as st +import os +import yaml +from pathlib import Path + +st.set_page_config( + page_title="Screenshot preview", + page_icon="🖼️", + layout="wide", +) + +# Add CSS hack to adjust image positioning and width +st.markdown(""" + +""", unsafe_allow_html=True) + +st.title("Screenshot preview") +st.markdown("This page shows previews of all screenshots in a similar layout to the Streamlit docs.") + +auto_refresh = st.checkbox("Auto refresh while screenshots are generated", value=True) + +# Wrap this in a fragment to auto refresh while screenshots are generated +@st.fragment(run_every=2 if auto_refresh else None) +def show_gallery(): + + # Path to screenshots directory + SCREENSHOTS_DIR = Path(__file__).parent / "screenshots" + CONFIG_FILE = Path(__file__).parent / "config.yaml" + + # Get the order of elements from config.yaml + def get_elements_order(): + with open(CONFIG_FILE, "r") as f: + yaml_data = yaml.safe_load(f) + elements = yaml_data.get("elements", []) + # Extract just the names in order + return [element["name"] for element in elements] + + # Get the ordered list of elements + try: + elements_order = get_elements_order() + except Exception as e: + st.error(f"Error reading config.yaml: {e}") + elements_order = [] + + # Get all screenshot files + all_screenshot_files = [f for f in os.listdir(SCREENSHOTS_DIR) if f.endswith(".png")] + + # Sort screenshot files according to config.yaml order + def get_element_name(filename): + return filename.replace(".png", "") + + # Create a dictionary of available screenshots + available_screenshots = {get_element_name(f): f for f in all_screenshot_files} + + # Create ordered list of screenshots based on config.yaml order + screenshot_files = [] + for element in elements_order: + if element in available_screenshots: + screenshot_files.append(available_screenshots[element]) + + # Add any screenshots that weren't in config.yaml at the end + for filename in all_screenshot_files: + element_name = get_element_name(filename) + if element_name not in elements_order: + screenshot_files.append(filename) + + # Display screenshots in rows of 4 + for i in range(0, len(screenshot_files), 4): + # Create a new row of 4 columns for every 4 elements with borders + cols = st.columns(4, border=True) + + # Process the next 4 screenshots (or fewer if we're at the end) + for j in range(4): + if i + j < len(screenshot_files): + screenshot_file = screenshot_files[i + j] + + # Get the element name without extension + element_name = screenshot_file.replace(".png", "") + + # Use the column directly with border + with cols[j]: + st.image(str(SCREENSHOTS_DIR / screenshot_file), use_container_width=True) + st.write(f"##### st.{element_name}") + st.write("This is an example text.") + + + # If no screenshots found + if not screenshot_files: + st.warning("No screenshots found in the screenshots directory. Run take_screenshots.py first to generate screenshots.") + st.info(f"Looking for screenshots in: {SCREENSHOTS_DIR}") + +show_gallery() \ No newline at end of file diff --git a/python/api-screenshot-generator/requirements.txt b/python/api-screenshot-generator/requirements.txt new file mode 100644 index 000000000..d7ec0e6fc --- /dev/null +++ b/python/api-screenshot-generator/requirements.txt @@ -0,0 +1,10 @@ +streamlit +matplotlib +pytest-playwright +requests +plotly +bokeh==2.4.3 +numpy<2 +graphviz +pyyaml +cairosvg \ No newline at end of file diff --git a/python/api-screenshot-generator/take_screenshots.py b/python/api-screenshot-generator/take_screenshots.py new file mode 100644 index 000000000..e38a4e37d --- /dev/null +++ b/python/api-screenshot-generator/take_screenshots.py @@ -0,0 +1,417 @@ +import argparse +import os +import subprocess +import sys +import time +from pathlib import Path + +import requests +import yaml # Add yaml for YAML parsing +from PIL import Image +from playwright.sync_api import sync_playwright + +# Image size configuration will be loaded from config.yaml + +# Parse command line arguments +parser = argparse.ArgumentParser(description="Take screenshots of Streamlit elements") +parser.add_argument( + "--headed", action="store_true", help="Run browser in headed mode (not headless)" +) +parser.add_argument( + "--only", + nargs="+", + help="Only take screenshots of specific elements (e.g. --only st.multiselect st.selectbox)", +) +args = parser.parse_args() + +# Create screenshots directory if it doesn't exist +SCREENSHOTS_DIR = "screenshots" +os.makedirs(SCREENSHOTS_DIR, exist_ok=True) + +# URL of the Streamlit app (default local development server) +STREAMLIT_URL = "http://localhost:9000" + + +def is_streamlit_running(): + """Check if Streamlit is already running on the default port.""" + try: + response = requests.get(STREAMLIT_URL, timeout=2) + return response.status_code == 200 + except: + return False + + +def start_streamlit_server(): + """Start the Streamlit server as a subprocess.""" + print("Starting Streamlit server...") + + # Get the path to the capture_app.py file + app_path = Path(__file__).parent / "capture_app.py" + + # Start the Streamlit server as a subprocess + process = subprocess.Popen( + [ + sys.executable, + "-m", + "streamlit", + "run", + str(app_path), + "--server.headless", + "true", + "--server.port", + # Start on port 9000 so there's no conflict if the app is running on 8501 + "9000", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Wait for the server to start + max_wait = 30 # Maximum wait time in seconds + start_time = time.time() + + while not is_streamlit_running(): + if time.time() - start_time > max_wait: + print("Error: Streamlit server failed to start within the timeout period.") + process.terminate() + sys.exit(1) + + time.sleep(1) + + print("Streamlit server started successfully!") + return process + + +def get_elements_from_file(): + """Read the list of elements from config.yaml file.""" + config_file = Path(__file__).parent / "config.yaml" + if not config_file.exists(): + print(f"Error: {config_file} not found.") + sys.exit(1) + + with open(config_file, "r") as f: + # Parse YAML file + yaml_data = yaml.safe_load(f) + elements_data = yaml_data.get("elements", []) + + # Get screenshot dimensions, padding and browser window width from yaml file + width = yaml_data.get("screenshot_width", 800) # Default to 800 if not specified + height = yaml_data.get("screenshot_height", 600) # Default to 600 if not specified + padding = yaml_data.get("default_padding", 50) # Default to 50 if not specified + window_width = yaml_data.get("browser_window_width", 1280) # Default to 1280 if not specified + enable_overlays = yaml_data.get("enable_overlays", True) # Default to True if not specified + + print(f"Using screenshot width: {width}, screenshot height: {height}, default padding: {padding}, browser window width: {window_width}") + + # Extract element names and any custom properties + elements = [] + element_properties = {} + + for element_item in elements_data: + element_name = element_item.get("name") + if element_name: + elements.append(element_name) + # Store all properties for this element + element_properties[element_name] = element_item + + # If --only argument is provided, filter the elements list + if args.only: + filtered_elements = [e for e in elements if e in args.only] + if not filtered_elements: + print(f"Error: None of the specified elements {args.only} were found in config.yaml") + sys.exit(1) + print(f"Filtering to only capture: {', '.join(filtered_elements)}") + return filtered_elements, element_properties, width, height, padding, window_width, enable_overlays + + return elements, element_properties, width, height, padding, window_width, enable_overlays + + +def process_screenshot(image_path, element_name=None, element_properties={}, width=800, height=600, padding=50, enable_overlays=False): + """ + Process a screenshot image: + 1. Trim whitespace around the image + 2. Resize to fit with exact padding + 3. Create a rectangular image of width x height + 4. Apply an overlay if specified in element_properties and enable_overlays is True + """ + # Open the image + img = Image.open(image_path) + + # Convert to RGB if not already + if img.mode != "RGB": + img = img.convert("RGB") + + # Get the background color (assuming white background) + bg_color = (255, 255, 255) + + # Get the image data + img_width, img_height = img.size + pixels = img.load() + + # Find the bounding box of non-white pixels + left = img_width + top = img_height + right = 0 + bottom = 0 + + # Scan the image to find the bounds of non-white pixels + for y in range(img_height): + for x in range(img_width): + if pixels[x, y] != bg_color: + left = min(left, x) + top = min(top, y) + right = max(right, x) + bottom = max(bottom, y) + + # If we found non-white pixels, crop the image with a 10px margin + if left < right and top < bottom: + # Add a 1px margin on all sides to not cut off borders + left = max(0, left - 1) + top = max(0, top - 1) + right = min(img_width - 1, right + 1) + bottom = min(img_height - 1, bottom + 1) + + # Crop the image + img = img.crop((left, top, right + 1, bottom + 1)) + + # Check if this element has a custom padding + custom_padding = None + if element_name and element_name in element_properties: + custom_padding = element_properties.get(element_name, {}).get("padding") + + # Use custom padding if available, otherwise use default + padding_value = custom_padding if custom_padding is not None else padding + + # Calculate the available space for the content (dimensions minus padding on both sides) + available_width = width - (2 * padding_value) + available_height = height - (2 * padding_value) + + # Calculate the scale factor to fit the image within the available space + img_width, img_height = img.size + width_scale = available_width / img_width + height_scale = available_height / img_height + + # Use the smaller scale factor to ensure the image fits within the available space + scale_factor = min(width_scale, height_scale) + + # Calculate new dimensions + new_width = int(img_width * scale_factor) + new_height = int(img_height * scale_factor) + + # Resize the image + img = img.resize((new_width, new_height), Image.LANCZOS) + + # Create a new white image of width x height + final_img = Image.new("RGB", (width, height), bg_color) + + # Calculate position to paste the resized image (centered) + paste_x = (width - new_width) // 2 + paste_y = (height - new_height) // 2 + + # Paste the resized image onto the white background + final_img.paste(img, (paste_x, paste_y)) + + # Check if overlays are enabled and if an overlay is specified for this element + if enable_overlays and element_name and element_name in element_properties: + overlay_path = element_properties.get(element_name, {}).get("overlay") + if overlay_path: + try: + # Get the full path to the overlay file + overlay_full_path = Path(__file__).parent / overlay_path + + # Check if the overlay file exists + if not overlay_full_path.exists(): + print(f"Error: Overlay file {overlay_path} not found for element {element_name}") + else: + # Open the overlay image + if overlay_path.lower().endswith('.svg'): + # For SVG files, use cairosvg to convert to PNG first + import cairosvg + import io + + # Convert SVG to PNG in memory + png_data = cairosvg.svg2png(url=str(overlay_full_path), output_width=width, output_height=height) + overlay_img = Image.open(io.BytesIO(png_data)) + else: + # For PNG and other image formats + overlay_img = Image.open(overlay_full_path) + + # Convert to RGBA if not already + if overlay_img.mode != "RGBA": + overlay_img = overlay_img.convert("RGBA") + + # Check if overlay dimensions match the final image + if overlay_img.size != (width, height): + print(f"Error: Overlay dimensions {overlay_img.size} do not match screenshot dimensions ({width}, {height}) for element {element_name}") + else: + # Composite the overlay onto the final image + if overlay_img.mode == 'RGBA': + # If overlay has transparency, use alpha_composite + # Convert final_img to RGBA + final_img_rgba = final_img.convert("RGBA") + final_img = Image.alpha_composite(final_img_rgba, overlay_img) + else: + # If no transparency, just paste + final_img.paste(overlay_img, (0, 0), overlay_img if 'A' in overlay_img.mode else None) + except Exception as e: + print(f"Error applying overlay for element {element_name}: {e}") + + # Save the final image + final_img.save(image_path) + + +def take_screenshots(): + # Get the list of elements from config.yaml + elements, element_properties, width, height, padding, window_width, enable_overlays = get_elements_from_file() + print(f"Found {len(elements)} elements to capture.") + + # Clear the screenshots directory before generating new screenshots + for file in Path(SCREENSHOTS_DIR).glob("*.png"): + file.unlink() + print(f"Cleared existing screenshots from {SCREENSHOTS_DIR}") + + # Start the Streamlit server if it's not already running + streamlit_process = None + if not is_streamlit_running(): + streamlit_process = start_streamlit_server() + + # Launch Playwright + with sync_playwright() as playwright: + # Launch the browser + browser = playwright.chromium.launch(headless=not args.headed) + + # Create a new browser context and page + context = browser.new_context( + viewport={"width": window_width, "height": 720}, + device_scale_factor=4.0 + ) + page = context.new_page() + + # Navigate to the Streamlit app + page.goto(STREAMLIT_URL) + + # Wait for the app to load + page.wait_for_selector("h1:has-text('Streamlit Elements Gallery')") + print("Streamlit app loaded successfully") + + # Take screenshots of each element + for element in elements: + try: + # Create a key for the element (used in the HTML) + element_key = element.replace(".", "_") + + # Find the container using the key-based CSS selector + selector = f".st-key-{element_key}" + container = page.locator(selector).first + + # Scroll to the container + container.scroll_into_view_if_needed() + + # Wait a moment for any animations to complete + time.sleep(0.5) + + if container.count() == 0: + print(f"Warning: Element '{element}' not found on the page") + continue + + # Special handling for certain elements + + # Special handling for selectbox and multiselect to show dropdown + if element == "selectbox" or element == "multiselect": + # Find the select element within the container and click it + container.locator("div[data-baseweb='select']").first.click() + # Wait for dropdown to appear + time.sleep(1) + + # Special handling for date input to show dropdown + elif element == "date_input": + # Find the select element within the container and click it + container.locator("div[data-baseweb='input']").first.click() + # Wait for dropdown to appear + time.sleep(1) + + # Special handling for color picker to show the color picker panel + elif element == "color_picker": + # Find the color picker input and click it + container.locator(".stColorPicker > div").first.click() + # Wait for color picker to appear + time.sleep(1) + + # Special handling for data_editor to click on it first + elif element == "data_editor": + # Find the editor element within the container + editor_box = container.locator(".stDataFrame").bounding_box() + # Click on the editor to view cell details + page.mouse.dblclick(editor_box["x"] + 100, editor_box["y"] + 90) + time.sleep(1) + # Click again to get into editing mode + page.mouse.click(editor_box["x"] + 100, editor_box["y"] + 90) + # Wait a moment for any UI response + time.sleep(0.5) + + # Special handling for page link to hover over it + elif element == "page_link": + # Find the link element within the container + link_element = container.locator("a").last + # Hover over the link + link_element.hover() + # Wait a moment for hover effect to appear + time.sleep(0.5) + + # Special handling for spinner to show the spinner animation + elif element == "spinner": + # Find and click the "Show spinner" button using its key + show_spinner_button = page.locator(".st-key-show_spinner button").first + time.sleep(0.5) + show_spinner_button.click() + # Wait for the spinner to appear + page.wait_for_selector(".stSpinner", state="visible", timeout=5000) + + # Get the bounding box + box = container.bounding_box() + + # Take a screenshot with padding to ensure we don't cut off any content + screenshot_path = f"{SCREENSHOTS_DIR}/{element}.png" + page.screenshot( + path=screenshot_path, + clip={ + "x": box["x"] - 10, + "y": box["y"] - 10, + "width": box["width"] + 20, + "height": box["height"] + 20, + }, + ) + + # For select elements, click elsewhere to close the dropdown after screenshot + if element in [ + "selectbox", + "multiselect", + "date_input", + "color_picker", + "data_editor", + ]: + # Click on the page header to close the dropdown + page.locator( + "h1:has-text('Streamlit Elements Gallery')" + ).click() + time.sleep(0.5) + + # Trim any excess whitespace + process_screenshot(screenshot_path, element, element_properties, width, height, padding, enable_overlays) + + print(f"Saved screenshot of {element} to {screenshot_path}") + except Exception as e: + print(f"Error capturing {element}: {e}") + + # Close the browser + browser.close() + + # If we started the Streamlit server, terminate it + if streamlit_process: + streamlit_process.terminate() + print("Streamlit server terminated") + + +if __name__ == "__main__": + take_screenshots()