Skip to content

Commit

Permalink
feat(core): Add assign shifts evenly preference
Browse files Browse the repository at this point in the history
The `diff` and `L2` usages are directly copied from the 2023/08/20 POC.
  • Loading branch information
j3soon committed Jul 26, 2024
1 parent 43faf1e commit 1481e61
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 42 deletions.
7 changes: 6 additions & 1 deletion core/nurse_scheduling/context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from ortools.sat.python import cp_model


class Context:
def __init__(self) -> None:
self.startdate = None
Expand All @@ -9,10 +12,12 @@ def __init__(self) -> None:
self.n_days = None
self.n_requirements = None
self.n_people = None
self.model = None
self.model: cp_model.CpModel = None
self.model_vars = None
self.shifts = None
self.map_dr_p = None
self.map_dp_r = None
self.map_d_rp = None
self.map_r_dp = None
self.map_p_dr = None
self.objective = None
30 changes: 28 additions & 2 deletions core/nurse_scheduling/preference_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from . import utils
from .context import Context

def all_requirements_fulfilled(ctx, args):
def all_requirements_fulfilled(ctx: Context, args, preference_id):
# Hard constraint
# For all shifts, the requirements (# of people) must be fulfilled.
# Note that a shift is represented as (d, r)
Expand All @@ -10,7 +11,7 @@ def all_requirements_fulfilled(ctx, args):
required_n_people = utils.required_n_people(ctx.requirements[r])
ctx.model.Add(actual_n_people == required_n_people)

def all_people_work_at_most_one_shift_per_day(ctx, args):
def all_people_work_at_most_one_shift_per_day(ctx: Context, args, preference_id):
# Hard constraint
# For all people, for all days, only work at most one shift.
# Note that a shift in day `d` can be represented as `r` instead of (d, r).
Expand All @@ -20,7 +21,32 @@ def all_people_work_at_most_one_shift_per_day(ctx, args):
maximum_n_shifts = 1
ctx.model.Add(actual_n_shifts <= maximum_n_shifts)

def assign_shifts_evenly(ctx: Context, args, preference_id):
# Soft constraint
# For all people, spread the shifts evenly.
# Note that a shift is represented as (d, r)
# i.e., max(weight * (actual_n_shifts - target_n_shifts) ** 2), for all p,
# where actual_n_shifts = sum_{(d, r)}(shifts[(d, r, p)])
for p in range(ctx.n_people):
actual_n_shifts = sum(ctx.shifts[(d, r, p)] for d, r in ctx.map_p_dr[p])
target_n_shifts = round(ctx.n_days * sum(requirement.required_people for requirement in ctx.requirements) / ctx.n_people)
unique_var_prefix = f"pref_{preference_id}_p_{p}_"

# Construct: L2 = actual_n_shifts - target_n_shifts) ** 2
L, U = -100, 100 # TODO: Calculate the actual bounds
diff_var_name = f"{unique_var_prefix}_diff"
ctx.model_vars[diff_var_name] = diff = ctx.model.NewIntVar(L, U, diff_var_name)
ctx.model.Add(diff == (actual_n_shifts - target_n_shifts))
L2_var_name = f"{unique_var_prefix}_L2"
ctx.model_vars[L2_var_name] = L2 = ctx.model.NewIntVar(0, max(L**2, U**2), L2_var_name)
ctx.model.AddMultiplicationEquality(L2, diff, diff)

# Add the objective
weight = -1
ctx.objective += weight * L2

PREFERENCE_TYPES_TO_FUNC = {
"all requirements fulfilled": all_requirements_fulfilled,
"all people work at most one shift per day": all_people_work_at_most_one_shift_per_day,
"assign shifts evenly": assign_shifts_evenly,
}
87 changes: 48 additions & 39 deletions core/nurse_scheduling/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,64 +16,69 @@ def schedule(filepath: str, validate=True, deterministic=False):
logging.info("Extracting scenario data...")
if scenario.apiVersion != "alpha":
raise NotImplementedError(f"Unsupported API version: {scenario.apiVersion}")
startdate = scenario.startdate
enddate = scenario.enddate
requirements = scenario.requirements
people = scenario.people
preferences = scenario.preferences
ctx = Context()
ctx.startdate = scenario.startdate
ctx.enddate = scenario.enddate
ctx.requirements = scenario.requirements
ctx.people = scenario.people
ctx.preferences = scenario.preferences
del scenario
n_days = (enddate - startdate).days + 1
n_requirements = len(requirements)
n_people = len(people)
dates = [startdate + timedelta(days=d) for d in range(n_days)]
ctx.n_days = (ctx.enddate - ctx.startdate).days + 1
ctx.n_requirements = len(ctx.requirements)
ctx.n_people = len(ctx.people)
ctx.dates = [ctx.startdate + timedelta(days=d) for d in range(ctx.n_days)]

logging.info("Initializing solver model...")
model = cp_model.CpModel()
shifts = {}
ctx.model = cp_model.CpModel()
ctx.model_vars = {}
ctx.shifts = {}
"""A set of indicator variables that are 1 if and only if
a person (p) is assigned to a shift (d, r)."""

logging.info("Creating shift variables...")
# Ref: https://developers.google.com/optimization/scheduling/employee_scheduling
for d in range(n_days):
for r in range(n_requirements):
for d in range(ctx.n_days):
for r in range(ctx.n_requirements):
# TODO(Optimize): Skip if no people is required in that day
for p in range(n_people):
for p in range(ctx.n_people):
# TODO(Optimize): Skip if the person does not qualify for the requirement
shifts[(d, r, p)] = model.NewBoolVar(f"shift_d{d}_r{r}_p{p}")
var_name = f"shift_d{d}_r{r}_p{p}"
ctx.model_vars[var_name] = ctx.shifts[(d, r, p)] = ctx.model.NewBoolVar(var_name)

logging.info("Creating maps for faster lookup...")
map_dr_p = {
(d, r): {p for p in range(n_people) if (d, r, p) in shifts}
for (d, r) in itertools.product(range(n_days), range(n_requirements))
ctx.map_dr_p = {
(d, r): {p for p in range(ctx.n_people) if (d, r, p) in ctx.shifts}
for (d, r) in itertools.product(range(ctx.n_days), range(ctx.n_requirements))
}
map_dp_r = {
(d, p): {r for r in range(n_requirements) if (d, r, p) in shifts}
for (d, p) in itertools.product(range(n_days), range(n_people))
ctx.map_dp_r = {
(d, p): {r for r in range(ctx.n_requirements) if (d, r, p) in ctx.shifts}
for (d, p) in itertools.product(range(ctx.n_days), range(ctx.n_people))
}
map_d_rp = {
d: {(r, p) for (r, p) in itertools.product(range(n_requirements), range(n_people)) if (d, r, p) in shifts}
for d in range(n_days)
ctx.map_d_rp = {
d: {(r, p) for (r, p) in itertools.product(range(ctx.n_requirements), range(ctx.n_people)) if (d, r, p) in ctx.shifts}
for d in range(ctx.n_days)
}
map_r_dp = {
r: {(d, p) for (d, p) in itertools.product(range(n_days), range(n_people)) if (d, r, p) in shifts}
for r in range(n_requirements)
ctx.map_r_dp = {
r: {(d, p) for (d, p) in itertools.product(range(ctx.n_days), range(ctx.n_people)) if (d, r, p) in ctx.shifts}
for r in range(ctx.n_requirements)
}
map_p_dr = {
p: {(d, r) for (d, r) in itertools.product(range(n_days), range(n_requirements)) if (d, r, p) in shifts}
for p in range(n_people)
ctx.map_p_dr = {
p: {(d, r) for (d, r) in itertools.product(range(ctx.n_days), range(ctx.n_requirements)) if (d, r, p) in ctx.shifts}
for p in range(ctx.n_people)
}

ctx = Context()
for k in vars(ctx):
setattr(ctx, k, locals()[k])
ctx.objective = 0

logging.info("Adding preferences (including constraints)...")
# TODO: Check no duplicated preferences
# TODO: Check no overlapping preferences
# TODO: Check all required preferences are present
for preference in preferences:
preference_types.PREFERENCE_TYPES_TO_FUNC[preference.type](ctx, preference.args)
for i, preference in enumerate(ctx.preferences):
preference_types.PREFERENCE_TYPES_TO_FUNC[preference.type](ctx, preference.args, i)

# Define objective (i.e., soft constraints)
print(ctx.objective)
ctx.model.Maximize(ctx.objective)

logging.info("Initializing solver...")
solver = cp_model.CpSolver()
Expand All @@ -95,7 +100,7 @@ def on_solution_callback(self):
solution_printer = PartialSolutionPrinter()

logging.info("Solving and showing partial results...")
status = solver.Solve(model, solution_printer)
status = solver.Solve(ctx.model, solution_printer)

logging.info(f"Status: {solver.StatusName(status)}")

Expand All @@ -110,7 +115,7 @@ def on_solution_callback(self):
elif status == cp_model.MODEL_INVALID:
logging.info("Model invalid!")
logging.info("Validation Info:")
logging.info(model.Validate())
logging.info(ctx.model.Validate())
else:
logging.info("No solution found!")

Expand All @@ -119,13 +124,17 @@ def on_solution_callback(self):
logging.info(f" - branches : {solver.NumBranches()}")
logging.info(f" - wall time: {solver.WallTime()}s")

logging.info("Variables:")
for k, v in ctx.model_vars.items():
logging.info(f" - {k}: {solver.Value(v)}")

logging.info(f"Done.")

if not found:
return None

df = export.get_people_versus_date_dataframe(
dates, people, requirements,
shifts, solver,
ctx.dates, ctx.people, ctx.requirements,
ctx.shifts, solver,
)
return df
13 changes: 13 additions & 0 deletions core/tests/test_or_tools_example_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import nurse_scheduling

def test_example_1():
filepath = "tests/testcases/or_tools_example_1.yaml"
df = nurse_scheduling.schedule(filepath, validate=False, deterministic=True)
assert df.values.tolist() == [
['', 18, 19, 20],
['', 'Fri', 'Sat', 'Sun'],
['Nurse 0', '', 'N', 'D'],
['Nurse 1', 'E', 'E', 'E'],
['Nurse 2', 'N', '', 'N'],
['Nurse 3', 'D', 'D', '']
]
27 changes: 27 additions & 0 deletions core/tests/testcases/or_tools_example_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
apiVersion: alpha
description: OR-Tools Example 1. From <https://developers.google.com/optimization/scheduling/employee_scheduling>.
startdate: 2023-08-18
enddate: 2023-08-20
people:
- id: 0
description: Nurse 0
- id: 1
description: Nurse 1
- id: 2
description: Nurse 2
- id: 3
description: Nurse 3
requirements:
- id: D
description: Day shift requirement
required_people: 1
- id: E
description: Evening shift requirement
required_people: 1
- id: N
description: Night shift requirement
required_people: 1
preferences:
- type: all requirements fulfilled
- type: all people work at most one shift per day
- type: assign shifts evenly

0 comments on commit 1481e61

Please sign in to comment.