Skip to content

Commit

Permalink
auto-add ruff-inferrable types
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Apr 8, 2024
1 parent 04d368a commit e23f70d
Show file tree
Hide file tree
Showing 66 changed files with 444 additions and 440 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: ^docs

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.4
rev: v0.3.5
hooks:
- id: ruff
args: [--fix, --ignore, D]
Expand Down
2 changes: 1 addition & 1 deletion docs_rst/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,5 +309,5 @@ def skip(app, what, name, obj, skip, options):


# AJ: a hack found online to get __init__ to show up in docs
def setup(app):
def setup(app) -> None:
app.connect("autodoc-skip-member", skip)
44 changes: 22 additions & 22 deletions fireworks/core/firework.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from typing import Any, Iterator, Sequence
from typing import Any, Iterator, NoReturn, Sequence

from monty.io import reverse_readline, zopen
from monty.os.path import zpath
Expand Down Expand Up @@ -49,7 +49,7 @@ class FiretaskBase(defaultdict, FWSerializable, abc.ABC):
# if set to a list of str, only required and optional kwargs are allowed; consistency checked upon init
optional_params = None

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
dict.__init__(self, *args, **kwargs)

required_params = self.required_params or []
Expand All @@ -68,7 +68,7 @@ def __init__(self, *args, **kwargs):
)

@abc.abstractmethod
def run_task(self, fw_spec):
def run_task(self, fw_spec) -> NoReturn:
"""
This method gets called when the Firetask is run. It can take in a
Firework spec, perform some task using that data, and then return an
Expand Down Expand Up @@ -101,7 +101,7 @@ def to_dict(self):
def from_dict(cls, m_dict):
return cls(m_dict)

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.fw_name}>:{dict(self)}"

# not strictly needed here for pickle/unpickle, but complements __setstate__
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(
defuse_children=False,
defuse_workflow=False,
propagate=False,
):
) -> None:
"""
Args:
stored_data (dict): data to store from the run. Does not affect the operation of FireWorks.
Expand Down Expand Up @@ -209,7 +209,7 @@ def skip_remaining_tasks(self):
"""
return self.exit or self.detours or self.additions or self.defuse_children or self.defuse_workflow

def __str__(self):
def __str__(self) -> str:
return "FWAction\n" + pprint.pformat(self.to_dict())


Expand Down Expand Up @@ -241,7 +241,7 @@ def __init__(
fw_id=None,
parents=None,
updated_on=None,
):
) -> None:
"""
Args:
tasks (Firetask or [Firetask]): a list of Firetasks to run in sequence.
Expand Down Expand Up @@ -286,7 +286,7 @@ def state(self):
return self._state

@state.setter
def state(self, state):
def state(self, state) -> None:
"""
Setter for the FW state, which triggers updated_on change.
Expand Down Expand Up @@ -318,7 +318,7 @@ def to_dict(self):

return m_dict

def _rerun(self):
def _rerun(self) -> None:
"""
Moves all Launches to archived Launches and resets the state to 'WAITING'. The Firework
can thus be re-run even if it was Launched in the past. This method should be called by
Expand Down Expand Up @@ -367,7 +367,7 @@ def from_dict(cls, m_dict):
tasks, m_dict["spec"], name, launches, archived_launches, state, created_on, fw_id, updated_on=updated_on
)

def __str__(self):
def __str__(self) -> str:
return f"Firework object: (id: {int(self.fw_id)} , name: {self.fw_name})"

def __iter__(self) -> Iterator[FiretaskBase]:
Expand All @@ -385,7 +385,7 @@ class Tracker(FWSerializable):

MAX_TRACKER_LINES = 1000

def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False):
def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False) -> None:
"""
Args:
filename (str)
Expand Down Expand Up @@ -437,7 +437,7 @@ def from_dict(cls, m_dict):
m_dict["filename"], m_dict["nlines"], m_dict.get("content", ""), m_dict.get("allow_zipped", False)
)

def __str__(self):
def __str__(self) -> str:
return f"### Filename: {self.filename}\n{self.content}"


Expand All @@ -456,7 +456,7 @@ def __init__(
state_history=None,
launch_id=None,
fw_id=None,
):
) -> None:
"""
Args:
state (str): the state of the Launch (e.g. RUNNING, COMPLETED)
Expand All @@ -483,7 +483,7 @@ def __init__(
self.launch_id = launch_id
self.fw_id = fw_id

def touch_history(self, update_time=None, checkpoint=None):
def touch_history(self, update_time=None, checkpoint=None) -> None:
"""
Updates the update_on field of the state history of a Launch. Used to ping that a Launch
is still alive.
Expand All @@ -496,7 +496,7 @@ def touch_history(self, update_time=None, checkpoint=None):
self.state_history[-1]["checkpoint"] = checkpoint
self.state_history[-1]["updated_on"] = update_time

def set_reservation_id(self, reservation_id):
def set_reservation_id(self, reservation_id) -> None:
"""
Adds the job_id to the reservation.
Expand All @@ -517,7 +517,7 @@ def state(self):
return self._state

@state.setter
def state(self, state):
def state(self, state) -> None:
"""
Setter for the Launch's state. Automatically triggers an update to state_history.
Expand Down Expand Up @@ -627,7 +627,7 @@ def from_dict(cls, m_dict):
m_dict["fw_id"],
)

def _update_state_history(self, state):
def _update_state_history(self, state) -> None:
"""
Internal method to update the state history whenever the Launch state is modified.
Expand Down Expand Up @@ -675,7 +675,7 @@ class Workflow(FWSerializable):
class Links(dict, FWSerializable):
"""An inner class for storing the DAG links between FireWorks."""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

for k, v in list(self.items()):
Expand Down Expand Up @@ -906,7 +906,7 @@ def apply_action(self, action: FWAction, fw_id: int) -> list[int]:
# Traverse whole sub-workflow down to leaves.
visited_cfid = set() # avoid double-updating for diamond deps

def recursive_update_spec(fw_id):
def recursive_update_spec(fw_id) -> None:
for cfid in self.links[fw_id]:
if cfid not in visited_cfid:
visited_cfid.add(cfid)
Expand All @@ -926,7 +926,7 @@ def recursive_update_spec(fw_id):
if action.mod_spec and action.propagate:
visited_cfid = set()

def recursive_mod_spec(fw_id):
def recursive_mod_spec(fw_id) -> None:
for cfid in self.links[fw_id]:
if cfid not in visited_cfid:
visited_cfid.add(cfid)
Expand Down Expand Up @@ -1349,10 +1349,10 @@ def from_Firework(cls, fw: Firework, name: str | None = None, metadata=None) ->
name = name if name else fw.name
return Workflow([fw], None, name=name, metadata=metadata, created_on=fw.created_on, updated_on=fw.updated_on)

def __str__(self):
def __str__(self) -> str:
return f"Workflow object: (fw_ids: {[*self.id_fw]} , name: {self.name})"

def remove_fws(self, fw_ids):
def remove_fws(self, fw_ids) -> None:
"""
Remove the fireworks corresponding to the input firework ids and update the workflow i.e the
parents of the removed fireworks become the parents of the children fireworks (only if the
Expand Down
2 changes: 1 addition & 1 deletion fireworks/core/fworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class FWorker(FWSerializable):
def __init__(self, name="Automatically generated Worker", category="", query=None, env=None):
def __init__(self, name="Automatically generated Worker", category="", query=None, env=None) -> None:
"""
Args:
name (str): the name of the resource, should be unique
Expand Down
Loading

0 comments on commit e23f70d

Please sign in to comment.