Skip to content

Commit

Permalink
239 plot trips and routes real (#251)
Browse files Browse the repository at this point in the history
* feat: add code from other branch

* feat: add improvements from mirror branch

* fix: CSS left in readme from merge to main

* 239 hotfix refactor (#252)

* refactor: Reduce tests to common func

* refactor: Adjust plot_trips & plot_routes -> plot_service

* fix: Failing test in plot core updated params

* fix: re-organise plotting functions

---------

Co-authored-by: Browning <[email protected]>

---------

Co-authored-by: r-leyshon <[email protected]>
Co-authored-by: Richard Leyshon <[email protected]>
  • Loading branch information
3 people authored Feb 20, 2024
1 parent 2b0444e commit 7bb4863
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 0 deletions.
165 changes: 165 additions & 0 deletions src/transport_performance/gtfs/multi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import pandas as pd
import folium
from folium.plugins import FastMarkerCluster
import plotly.express as px
import plotly.graph_objs as go

from transport_performance.gtfs.validation import GtfsInstance
from transport_performance.utils.defence import (
Expand Down Expand Up @@ -73,6 +75,8 @@ class MultiGtfsInstance:
creates a calendar table from calendar_times.
get_dates()
Get the range of dates that the gtfs(s) span.
plot_service()
Plot a timeseries of route or trip counts.
Raises
------
Expand Down Expand Up @@ -747,3 +751,164 @@ def get_dates(self, return_range: bool = True) -> list:
if return_range:
return [min(sorted_dates), max(sorted_dates)]
return sorted_dates

def _reformat_col_names(self, col_name: str, cap_all: bool = True):
"""Convert a column name to a more readable format."""
parts = col_name.split("_")
for i, part in enumerate(parts):
part = list(part)
part[0] = part[0].upper()
parts[i] = "".join(part)
if not cap_all:
break
return " ".join(parts)

def _plot_core(
self,
df: pd.DataFrame,
count_col: str = "routes",
width: int = 1000,
height: int = 550,
title: str = None,
kwargs: dict = {},
rolling_average: Union[int, None] = None,
line_date: Union[str, None] = None,
):
"""Plot a timeseries for trip/route count."""
# defences
_type_defence(df, "df", pd.DataFrame)
_type_defence(count_col, "count_col", str)
_type_defence(width, "width", int)
_type_defence(height, "height", int)
_type_defence(title, "title", (str, type(None)))
_type_defence(kwargs, "kwargs", dict)
_type_defence(rolling_average, "rolling_average", (int, type(None)))
_type_defence(line_date, "line_date", (str, type(None)))
# preparation
LABEL_FORMAT = {
count_col: self._reformat_col_names(count_col),
"date": "Date",
}
PLOT_TITLE = {
"text": f"{self._reformat_col_names(count_col, False)} over time",
"x": 0.5,
"xanchor": "center",
}
if "route_type" in df.columns:
kwargs["color"] = "route_type"
LABEL_FORMAT["route_type"] = self._reformat_col_names("route_type")
PLOT_TITLE["text"] = PLOT_TITLE["text"] + " by route type"
kwargs["width"] = width
kwargs["height"] = height
if title:
PLOT_TITLE["text"] = title

if rolling_average:
new_count_col = f"{rolling_average} Day Rolling Average"
temp_dfs = []
# impute route type if there is none
if "route_type" not in df.columns:
df["route_type"] = 15000
for rt in df.route_type.unique():
temp = df[df.route_type == rt].copy()
# resample to account for missing dates
temp = temp.set_index("date").resample("1D").sum()
# add correct route type
temp["route_type"] = rt
# calculate rolling average over [x] days
temp[new_count_col] = (
temp[count_col]
.rolling(window=rolling_average, center=True)
.mean()
)
temp_dfs.append(temp)
df = pd.concat(temp_dfs).sort_values("date").reset_index()
count_col = new_count_col
# plotting
fig = px.line(df, x="date", y=count_col, labels=LABEL_FORMAT, **kwargs)
fig.update_layout(title=PLOT_TITLE)
if line_date:
fig.add_vline(x=line_date, line_dash="dash")

return fig

def plot_service(
self,
service_type: str = "routes",
route_type: bool = True,
width: int = 1000,
height: int = 550,
title: str = None,
plotly_kwargs: dict = None,
rolling_average: Union[int, None] = None,
line_date: Union[str, None] = None,
) -> go.Figure:
"""Create a line plot of route or trip counts over time.
Parameters
----------
service_type: str, optional
Whether to plot 'routes' or 'trips'. By default 'routes'.
route_type : bool, optional
Whether or not to draw a line for each modality, by default True
width : int, optional
Plot width, by default 1000
height : int, optional
Plot height, by default 550
title : str, optional
Plot title, by default None
plotly_kwargs : dict, optional
Kwargs to pass to plotly.express.line, by default None
rolling_average : Union[int, None], optional
How many days to calculate the rolling average over. When left as
None, rolling average is not used.
The rolling average is calculated from the center, meaning if ra=3,
the average will be calculated from the current date, previous date
and following date. Missing dates are imputed and treated as having
values of 0.
line_date : Union[str, None], optional
A data to draw a dashed vertical line on. Date should be in format:
YYYY-MM-DD, by default None
Returns
-------
go.Figure
The timerseries plot
"""
# defences
_type_defence(service_type, "service_type", str)
_type_defence(route_type, "route_type", bool)
_type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None)))
if not plotly_kwargs:
plotly_kwargs = {}
SERVICE_TYPES = ["routes", "trips"]
if service_type not in SERVICE_TYPES:
raise ValueError(
"`service_type` expects 'routes' or 'trips',"
f" found: {service_type}"
)
if service_type == "routes":
data = self.summarise_routes().copy()
count_col = "route_count"
else:
data = self.summarise_trips().copy()
count_col = "trip_count"
if not route_type:
data = (
data.drop("route_type", axis=1)
.groupby("date")
.sum()
.reset_index()
)
figure = self._plot_core(
data,
count_col=count_col,
width=width,
height=height,
title=title,
kwargs=plotly_kwargs,
rolling_average=rolling_average,
line_date=line_date,
)
return figure
73 changes: 73 additions & 0 deletions tests/gtfs/test_multi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pandas as pd
import folium
from pyprojroot import here
import plotly.graph_objs as go

from transport_performance.gtfs.multi_validation import (
MultiGtfsInstance,
Expand Down Expand Up @@ -691,3 +692,75 @@ def test_get_dates(self, multi_gtfs_fixture, multi_gtfs_altered_fixture):
len(multi_gtfs_altered_fixture.get_dates(return_range=False)) == 6
), "Unexpected number of dates"
pass

def test__plot_core(self, multi_gtfs_fixture):
"""General tests for _plot_core()."""
# route summary
data = multi_gtfs_fixture.summarise_routes()
route_fig = multi_gtfs_fixture._plot_core(data, "route_count")
assert isinstance(route_fig, go.Figure), "Route counts not plotted"
# trip summary
data = multi_gtfs_fixture.summarise_trips()
trip_fig = multi_gtfs_fixture._plot_core(data, "trip_count")
assert isinstance(trip_fig, go.Figure), "Trip counts not plotted"
# trip summary with custom title
trip_fig = multi_gtfs_fixture._plot_core(
data, "trip_count", title="test"
)
found_title = trip_fig.layout["title"]["text"]
assert found_title == "test", "Title not as expected"
# trip summary with custom dimensions
trip_fig = multi_gtfs_fixture._plot_core(
data, "trip_count", height=100, width=150
)
found_height = trip_fig.layout["height"]
found_width = trip_fig.layout["width"]
assert found_height == 100, "Height not as expected"
assert found_width == 150, "Width not as expected"
# custom kwargs
trip_fig = multi_gtfs_fixture._plot_core(
data, "trip_count", kwargs={"markers": True}
)
assert trip_fig.data[0]["mode"] in [
"markers+lines",
"lines+markers",
], "Markers not plotted"
# rolling average
avg_fig = multi_gtfs_fixture._plot_core(
data, "trip_count", rolling_average=7
)
found_ylabel = avg_fig.layout["yaxis"]["title"]["text"]
assert (
found_ylabel == "7 Day Rolling Average"
), "Rolling average not plotted"
# draw a line on a date
avg_fig = multi_gtfs_fixture._plot_core(
data, "trip_count", rolling_average=7, line_date="2023-12-01"
)
found_line = avg_fig.layout["shapes"][0]["line"]["dash"]
assert found_line == "dash", "Date line not plotted"

def test_plot_service(self, multi_gtfs_fixture):
"""General tests for .plot_service()."""
# plot route_type
fig = multi_gtfs_fixture.plot_service(service_type="routes")
assert len(fig.data) == 2, "Not plotted by modality"
# plot without route type
fig = multi_gtfs_fixture.plot_service(
service_type="routes", route_type=False
)
assert len(fig.data) == 1, "Plot not as expected"
# rolling average + no route type
avg_fig = multi_gtfs_fixture.plot_service(
service_type="routes", rolling_average=7, route_type=False
)
leg_status = avg_fig.data[0]["showlegend"]
assert not leg_status, "Multiple route types found"
# plot trips
fig = multi_gtfs_fixture.plot_service(service_type="trips")
assert len(fig.data) == 2, "Not plotted by modality"
# plot without route type
fig = multi_gtfs_fixture.plot_service(
service_type="trips", route_type=False
)
assert len(fig.data) == 1, "Plot not as expected"

0 comments on commit 7bb4863

Please sign in to comment.