Skip to content

Commit

Permalink
Merge pull request #206 from pph-collective/scott-ssp-rr
Browse files Browse the repository at this point in the history
Scott ssp rr
  • Loading branch information
s-bessey authored Dec 20, 2021
2 parents 55ab07c + 58773f7 commit 2b29c86
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 124 deletions.
2 changes: 2 additions & 0 deletions settings/scott/model.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ syringe_services:
num_slots_start: 0
num_slots_stop: 237
risk: 0.02
dx_scalar: 1.195
ssp_on:
start_time: 4
stop_time: 121
num_slots_start: 237
num_slots_stop: 237
risk: 0.02
dx_scalar: 1.195

agent_zero:
num_partners: 4
Expand Down
1 change: 1 addition & 0 deletions tests/params/basic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,7 @@ syringe_services:
num_slots_start: 100
num_slots_stop: 100
risk: 0.02
dx_scalar: 1.0

agent_zero:
bond_type: Inj
Expand Down
1 change: 1 addition & 0 deletions tests/params/simple_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ syringe_services:
num_slots_start: 500
num_slots_stop: 500
risk: 0.00
dx_scalar: 1.0

partner_tracing:
prob: 1
Expand Down
9 changes: 5 additions & 4 deletions titan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self.new_prep = AgentSet("new_prep")

self.ssp_enrolled_risk = 0.0
self.ssp_dx = 1.0

self.time = -1 * self.params.model.time.burn_steps # burn is negative time
self.id = nanoid.generate(size=8)
Expand Down Expand Up @@ -113,10 +114,6 @@ def print_stats(self, stat: Dict[str, Dict[str, int]], outdir: str):
), "Graph must be enabled to print network reports"

network_outdir = os.path.join(outdir, "network")
if self.params.outputs.network.draw_figures:
self.network_utils.visualize_network(
network_outdir, curtime=self.time, label=f"{self.id}"
)

if self.params.outputs.network.calc_component_stats:
ao.print_components(
Expand Down Expand Up @@ -899,6 +896,7 @@ def update_syringe_services(self):
for item in self.params.syringe_services.timeline.values():
if item.start_time <= self.time < item.stop_time:
self.ssp_enrolled_risk = item.risk
self.ssp_dx = item.dx_scalar

ssp_num_slots = (item.num_slots_stop - item.num_slots_start) / (
item.stop_time - item.start_time
Expand Down Expand Up @@ -1102,6 +1100,9 @@ def diagnose(
sex_type
].hiv.dx.prob

if agent.ssp:
test_prob *= self.ssp_dx

# Rescale based on calibration param
test_prob *= self.calibration.test_frequency

Expand Down
120 changes: 0 additions & 120 deletions titan/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

import networkx as nx # type: ignore
from networkx.drawing.nx_agraph import graphviz_layout # type: ignore
import matplotlib.pyplot as plt # type: ignore
import matplotlib.patches as patches # type: ignore


class NetworkGraphUtils:
Expand Down Expand Up @@ -163,121 +161,3 @@ def get_network_color(self, coloring) -> List[str]:
)

return node_color

def visualize_network(
self,
outdir: str,
coloring: str = "sex_type",
pos=None,
return_layout: bool = False,
node_size: Optional[float] = None,
curtime: int = 0,
infection_label: int = 0,
label: str = "Network",
):
"""
Visualize the network using the spring layout (default).
args:
outdir: directory the figure should be saved to
coloring: what attribute to color the nodes by
pos: a graphviz_layout
return_layout: whether to return the layout (if `False`, nothing is returned)
node_size: size of the nodes in the graph
curtime: the current timestep of the model
infection_label: number of infections to list in figure's label
label: identifier for this network
"""
if node_size is None:
node_size = 5000.0 / self.G.number_of_nodes()

print(("\tPlotting {} colored by {}...").format(label, coloring))
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
fig.clf()

# build a rectangle in axes coords
left, width = 0.0, 1.0
bottom, height = 0.0, 1.0
right = left + width
top = bottom + height

fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])

# axes coordinates are 0,0 is bottom left and 1,1 is upper right
p = patches.Rectangle(
(left, bottom),
width,
height,
fill=False,
transform=ax.transAxes,
clip_on=False,
)

ax.add_patch(p)

if not pos:
pos = graphviz_layout(self.G, prog="neato", args="")

edge_color = "k"
node_shape = "o"

# node color to by type
node_color = self.get_network_color(coloring)

# node size indicating node degree
NodeSize = []
if node_size:
for v in self.G:
NodeSize.append(node_size)
else:
for v in self.G:
NodeSize.append((10 * self.G.degree(v)) ** (1.0))

# draw:
nx.draw(
self.G,
pos,
node_size=NodeSize,
node_color=node_color,
node_shape=node_shape,
edge_color=edge_color,
with_labels=False,
linewidths=0.5,
width=0.5,
)

textstr = "\n".join(
(
r"N infection={:.2f}".format(
infection_label,
),
r"Time={:.2f}".format(
curtime,
),
)
)

# these are matplotlib.patch.Patch properties
props = dict(boxstyle="round", facecolor="wheat", alpha=0.9)

# place a text box in upper right in axes coords
ax.text(
right - 0.025,
top - 0.025,
textstr,
horizontalalignment="right",
verticalalignment="top",
transform=ax.transAxes,
bbox=props,
)

filename = os.path.join(
outdir, f"{label}_{self.G.number_of_nodes()}_{coloring}_{curtime}.png"
)

fig.savefig(filename)

if return_layout:
return pos
5 changes: 5 additions & 0 deletions titan/params/syringe_services.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@ syringe_services:
description: "Risk of unsafe sharing for agents enrolled in the SSP"
min: 0.0
max: 1.0
dx_scalar:
type: float
description: "Diagnosis scalar for HIV+ agents enrolled in ssp"
min: 0.0
default:
ssp_default:
start_time: 1
stop_time: 2
num_slots_start: 0
num_slots_stop: 0
risk: 0.02
dx_scalar: 1.0

0 comments on commit 2b29c86

Please sign in to comment.