Skip to content

Commit

Permalink
sensai.util: Minimise required dependencies for all modules in this p…
Browse files Browse the repository at this point in the history
…ackage

  in preparation of the release of sensAI-utils
  • Loading branch information
opcode81 committed Aug 10, 2024
1 parent 2691c10 commit b55df4f
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Improvements/Changes

* `util`
* Minimise required dependencies for all modules in this package in preparation of the release of *sensAI-utils*
* `util.logging`:
* Fix type annotations of `run_main` and `run_cli`

Expand Down
9 changes: 5 additions & 4 deletions src/sensai/util/datastruct.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Sequence, Optional, TypeVar, Generic, Tuple, Dict, Any

import pandas as pd
from typing import Sequence, Optional, TypeVar, Generic, Tuple, Dict, Any, TYPE_CHECKING

from . import sequences as array_util
from .string import ToStringMixin, dict_string

if TYPE_CHECKING:
import pandas as pd

TKey = TypeVar("TKey")
TValue = TypeVar("TValue")
TSortedKeyValueStructure = TypeVar("TSortedKeyValueStructure", bound="SortedKeyValueStructure")
Expand Down Expand Up @@ -303,7 +304,7 @@ def __len__(self):
return len(self.keys)

@classmethod
def from_series(cls, s: pd.Series):
def from_series(cls, s: "pd.Series"):
"""
Creates an instance from a pandas Series, using the series' index as the keys and its values as the values
Expand Down
17 changes: 9 additions & 8 deletions src/sensai/util/io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import io
import logging
import os
from typing import Sequence, Optional, Tuple, List, Any
from typing import Sequence, Optional, Tuple, List, Any, TYPE_CHECKING

import matplotlib.figure
from matplotlib import pyplot as plt
import pandas as pd
if TYPE_CHECKING:
from matplotlib import pyplot as plt
import pandas as pd

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,28 +81,29 @@ def write_text_file_lines(self, filename_suffix: str, lines: List[str]):
write_text_file_lines(lines, p)
return p

def write_data_frame_text_file(self, filename_suffix: str, df: pd.DataFrame):
def write_data_frame_text_file(self, filename_suffix: str, df: "pd.DataFrame"):
p = self.path(filename_suffix, extension_to_add="df.txt", valid_other_extensions="txt")
if self.enabled:
self.log.info(f"Saving data frame text file {p}")
with open(p, "w") as f:
f.write(df.to_string())
return p

def write_data_frame_csv_file(self, filename_suffix: str, df: pd.DataFrame, index=True, header=True):
def write_data_frame_csv_file(self, filename_suffix: str, df: "pd.DataFrame", index=True, header=True):
p = self.path(filename_suffix, extension_to_add="csv")
if self.enabled:
self.log.info(f"Saving data frame CSV file {p}")
df.to_csv(p, index=index, header=header)
return p

def write_figure(self, filename_suffix: str, fig: plt.Figure, close_figure: Optional[bool] = None):
def write_figure(self, filename_suffix: str, fig: "plt.Figure", close_figure: Optional[bool] = None):
"""
:param filename_suffix: the filename suffix, which may or may not include a file extension, valid extensions being {"png", "jpg"}
:param fig: the figure to save
:param close_figure: whether to close the figure after having saved it; if None, use default passed at construction
:return: the path to the file that was written (or would have been written if the writer was enabled)
"""
from matplotlib import pyplot as plt
p = self.path(filename_suffix, extension_to_add="png", valid_other_extensions=("jpg",))
if self.enabled:
self.log.info(f"Saving figure {p}")
Expand All @@ -112,7 +113,7 @@ def write_figure(self, filename_suffix: str, fig: plt.Figure, close_figure: Opti
plt.close(fig)
return p

def write_figures(self, figures: Sequence[Tuple[str, matplotlib.figure.Figure]], close_figures=False):
def write_figures(self, figures: Sequence[Tuple[str, "plt.Figure"]], close_figures=False):
for name, fig in figures:
self.write_figure(name, fig, close_figure=close_figures)

Expand Down
22 changes: 17 additions & 5 deletions src/sensai/util/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
from datetime import datetime
from io import StringIO
from logging import *
from typing import List, Callable, Any, Optional, TypeVar
from typing import List, Callable, Optional, TypeVar, TYPE_CHECKING

from .time import format_duration

if TYPE_CHECKING:
import pandas as pd

import pandas as pd

log = getLogger(__name__)

Expand Down Expand Up @@ -77,7 +81,11 @@ def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
getLogger("matplotlib").setLevel(lg.INFO)
getLogger("urllib3").setLevel(lg.INFO)
getLogger("msal").setLevel(lg.INFO)
pd.set_option('display.max_colwidth', 255)
try:
import pandas as pd
pd.set_option('display.max_colwidth', 255)
except ImportError:
pass
for callback in _configureCallbacks:
callback()

Expand Down Expand Up @@ -230,7 +238,7 @@ def get_elapsed_time_secs(self) -> float:
"""
return self._elapsed_secs + self._get_elapsed_time_since_last_start()

def get_elapsed_timedelta(self) -> pd.Timedelta:
def get_elapsed_timedelta(self) -> "pd.Timedelta":
"""
:return: the elapsed time as a pandas.Timedelta object
"""
Expand All @@ -244,7 +252,11 @@ def get_elapsed_time_string(self) -> str:
if secs < 60:
return f"{secs:.3f} seconds"
else:
return str(pd.Timedelta(secs, unit="s"))
try:
import pandas as pd
return str(pd.Timedelta(secs, unit="s"))
except ImportError:
return format_duration(secs)


class StopWatchManager:
Expand Down
3 changes: 1 addition & 2 deletions src/sensai/util/math.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import math
from typing import List

import scipy.stats

from .string import object_repr, ToStringMixin


Expand All @@ -13,6 +11,7 @@ def __init__(self, mean=0, std=1, unit_max=False):
:param std: the standard deviation
:param unit_max: if True, scales the distribution's pdf such that its maximum value becomes 1
"""
import scipy.stats
self.unitMax = unit_max
self.mean = mean
self.std = std
Expand Down
4 changes: 2 additions & 2 deletions src/sensai/util/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Union

import joblib

from .io import S3Object, is_s3_path

log = logging.getLogger(__name__)
Expand All @@ -29,6 +27,7 @@ def _load_with_error_log(loader: Callable):
import cloudpickle
return _load_with_error_log(cloudpickle.load)
elif backend == "joblib":
import joblib
return joblib.load(f)
else:
raise ValueError(f"Unknown backend '{backend}'. Supported backends are 'pickle', 'joblib' and 'cloudpickle'")
Expand Down Expand Up @@ -60,6 +59,7 @@ def open_file():
failing_paths = PickleFailureDebugger.debug_failure(obj)
raise AttributeError(f"Cannot pickle paths {failing_paths} of {obj}: {str(e)}")
elif backend == "joblib":
import joblib
joblib.dump(obj, f, protocol=protocol)
elif backend == "cloudpickle":
import cloudpickle
Expand Down
30 changes: 22 additions & 8 deletions src/sensai/util/time.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from datetime import time
from typing import TYPE_CHECKING

import pandas as pd
if TYPE_CHECKING:
import pandas as pd


def ts_next_month(ts: pd.Timestamp) -> pd.Timestamp:
def ts_next_month(ts: "pd.Timestamp") -> "pd.Timestamp":
m = ts.month
if m == 12:
return ts.replace(year=ts.year+1, month=1)
else:
return ts.replace(month=m+1)


def time_of_day(ts: pd.Timestamp) -> float:
def time_of_day(ts: "pd.Timestamp") -> float:
"""
:param ts: the timestamp
:return: the time of day as a floating point number in [0, 24)
Expand All @@ -20,11 +22,11 @@ def time_of_day(ts: pd.Timestamp) -> float:


class TimeInterval:
def __init__(self, start: pd.Timestamp, end: pd.Timestamp):
def __init__(self, start: "pd.Timestamp", end: "pd.Timestamp"):
self.start = start
self.end = end

def contains(self, t: pd.Timestamp):
def contains(self, t: "pd.Timestamp"):
return self.start <= t <= self.end

def contains_time(self, t: time):
Expand All @@ -45,9 +47,21 @@ def overlaps_with(self, other: "TimeInterval") -> bool:
def intersection(self, other: "TimeInterval") -> "TimeInterval":
return TimeInterval(max(self.start, other.start), min(self.end, other.end))

def time_delta(self) -> pd.Timedelta:
def time_delta(self) -> "pd.Timedelta":
return self.end - self.start

def mid_timestamp(self) -> pd.Timestamp:
def mid_timestamp(self) -> "pd.Timestamp":
midTime: pd.Timestamp = self.start + 0.5 * self.time_delta()
return midTime
return midTime


def format_duration(seconds: float):
if seconds < 60:
return f"{seconds:.1f} seconds"
elif seconds < 3600:
minutes, secs = divmod(seconds, 60)
return f"{int(minutes)} minutes, {secs:.1f} seconds"
else:
hours, remainder = divmod(seconds, 3600)
minutes, secs = divmod(remainder, 60)
return f"{int(hours)} hours, {int(minutes)} minutes, {secs:.1f} seconds"

0 comments on commit b55df4f

Please sign in to comment.