Skip to content

Commit

Permalink
Follow on to #86 (#90)
Browse files Browse the repository at this point in the history
* Follow on to #86

* correct input files

* add back the input processing code for air travel

* Fixed assignment of teachers to always give the same results for the same seed when agent.fast=false

* Fixed bug in age selection for workers.
Use round for calculating teacher counts

* cast rounded values to int

---------

Co-authored-by: Tan Nguyen <[email protected]>
Co-authored-by: shofmeyr <[email protected]>
  • Loading branch information
3 people authored Dec 10, 2024
1 parent f91526e commit d2aa681
Show file tree
Hide file tree
Showing 16 changed files with 153 additions and 101 deletions.
7 changes: 5 additions & 2 deletions examples/inputs.bay
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
agent.ic_type = "census"
agent.census_filename = "../../data/CensusData/BayArea.dat"
agent.workerflow_filename = "../../data/CensusData/BayArea-wf.bin"
agent.case_filename = "../../data/CaseData/July4.cases"

disease.initial_case_type = "random"
disease.num_initial_cases = 5

agent.nsteps = 10
agent.plot_int = -1
Expand All @@ -16,8 +18,9 @@ contact.pWO = 0.5
contact.pFA = 1.0
contact.pBAR = -1.

disease.case_filename = "../../data/CaseData/July4.cases"
disease.nstrain = 2
disease.p_trans = 0.20 0.30
disease.p_asymp = 0.40 0.40
disease.reduced_inf = 0.75 0.75
disease.reinfect_prob = 0.0
disease.reinfect_prob = 0.0
6 changes: 3 additions & 3 deletions examples/inputs.ca
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
agent.ic_type = "census"
agent.census_filename = "../../data/CensusData/CA.dat"
agent.workerflow_filename = "../../data/CensusData/CA-wf.bin"
agent.initial_case_type = "file"
agent.case_filename = "../../data/CaseData/July4.cases"
agent.airports_filename="../../data/CensusData/CA_airports.dat"
agent.air_traffic_filename= "../../data/CA_CY23AirTraffic.dat"

Expand All @@ -21,9 +19,11 @@ contact.pWO = 0.5
contact.pFA = 1.0
contact.pBAR = -1.

disease.initial_case_type = "file"
disease.case_filename = "../../data/CaseData/July4.cases"
disease.nstrain = 1
disease.p_trans = 0.20
disease.p_asymp = 0.40
disease.reduced_inf = 0.75

disease.incubation_length_mean = 3.0
#disease.incubation_length_mean = 3.0
9 changes: 4 additions & 5 deletions examples/inputs.ma
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@ agent.ic_type = "census"
agent.census_filename = "../../data/CensusData/MA.dat"
agent.workerflow_filename = "../../data/CensusData/MA-wf.dat"

agent.initial_case_type = "file"
agent.case_filename = "../../data/CaseData/July4.cases"

#agent.initial_case_type = "random"
#agent.num_initial_cases = 5

agent.nsteps = 120
agent.plot_int = 10
Expand All @@ -28,6 +23,10 @@ contact.pWO = 0.5
contact.pFA = 1.0
contact.pBAR = -1.

#disease.initial_case_type = "random"
#disease.num_initial_cases = 5
disease.initial_case_type = "file"
disease.case_filename = "../../data/CaseData/July4.cases"
disease.nstrain = 1
disease.p_trans = 0.20
disease.p_asymp = 0.40
Expand Down
5 changes: 3 additions & 2 deletions examples/inputs.small_cases
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
agent.ic_type = "census"
agent.census_filename = "../../data/CensusData/CA.dat"
agent.workerflow_filename = "../../data/CensusData/CA-wf.bin"
agent.case_filename = "../../data/CaseData/small.cases"

agent.nsteps = 60
agent.plot_int = 10
Expand All @@ -15,7 +14,9 @@ contact.pWO = 0.5
contact.pFA = 1.0
contact.pBAR = -1.

disease.initial_case_type = "file"
disease.case_filename = "../../data/CaseData/small.cases"
disease.nstrain = 1
disease.p_trans = 0.1
disease.p_asymp = 0.4
disease.reduced_inf = 0.75
disease.reduced_inf = 0.75
6 changes: 3 additions & 3 deletions examples/inputs_random_ca
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
agent.ic_type = "census"
agent.census_filename = "../../data/CensusData/CA.dat"
agent.workerflow_filename = "../../data/CensusData/CA-wf.bin"
agent.initial_case_type = "random"
agent.num_initial_cases = 5

agent.nsteps = 180
agent.plot_int = 10
Expand All @@ -18,9 +16,11 @@ contact.pWO = 0.5
contact.pFA = 1.0
contact.pBAR = -1.

disease.initial_case_type = "random"
disease.num_initial_cases = 5
disease.nstrain = 1
disease.p_trans = 0.20
disease.p_asymp = 0.40
disease.reduced_inf = 0.75

disease.incubation_length_mean = 3.0
disease.incubation_length_mean = 3.0
2 changes: 1 addition & 1 deletion src/AgentContainer.H
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ public:

amrex::iMultiFab m_student_counts;
/// Used only for Census data. A ratio for each school type: none, college, high, middle, elem, daycare
int m_student_teacher_ratio[SchoolType::total] = {0, 15, 15, 15, 15, 15};
amrex::GpuArray<int, SchoolType::total> m_student_teacher_ratio = {0, 15, 15, 15, 15, 15};

int m_num_diseases; /*!< Number of diseases */

Expand Down
9 changes: 8 additions & 1 deletion src/AgentContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,15 @@ AgentContainer::AgentContainer (const amrex::Geometry & a_geom, /*!<
amrex::ParmParse pp("agent");
pp.query("shelter_compliance", m_shelter_compliance);
pp.query("symptomatic_withdraw_compliance", m_symptomatic_withdraw_compliance);
queryArray(pp, "student_teacher_ratio", m_student_teacher_ratio, SchoolType::total);
int stratio[SchoolType::total];
for (unsigned int i = 0; i < SchoolType::total; i++) {
stratio[i] = m_student_teacher_ratio[i];
}

queryArray(pp, "student_teacher_ratio", stratio, SchoolType::total);
for (unsigned int i = 0; i < SchoolType::total; ++i) {
m_student_teacher_ratio[i] = stratio[i];
}
}

{
Expand Down
177 changes: 105 additions & 72 deletions src/CensusData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ void CensusData::read_workerflow (AgentContainer& pc, /*!< Agent conta

int age_group = age_group_ptr[ip];
/* Check working-age population */
if ((age_group == AgeGroups::a18to29) || (age_group == AgeGroups::a30to49)) {
if (age_group >= AgeGroups::a18to29 && age_group <= AgeGroups::a50to64) {
unsigned int irnd = Random_int(nwork, engine);
int to = 0;
int comm_to = 0;
Expand Down Expand Up @@ -683,18 +683,20 @@ void CensusData::read_workerflow (AgentContainer& pc, /*!< Agent conta
assignTeachersAndWorkgroup(pc, workgroup_size);
}


void CensusData::assignTeachersAndWorkgroup (AgentContainer& pc, const int workgroup_size) {
const Box& domain = pc.Geom(0).Domain();

auto Ncommunity = demo.Ncommunity;
const int num_school_types = SchoolCensusIDType::total - 1;
Gpu::DeviceVector<int> teacher_counts_array[num_school_types];
int* teacher_counts_ptr[num_school_types];
for (int i = 0; i < num_school_types; i++) {
teacher_counts_array[i].resize(Ncommunity, 0);
teacher_counts_ptr[i] = teacher_counts_array[i].data();
}

Gpu::HostVector<int> high_teachers_array(Ncommunity, 0);
Gpu::HostVector<int> middle_teachers_array(Ncommunity, 0);
Gpu::HostVector<int> elem3_teachers_array(Ncommunity, 0);
Gpu::HostVector<int> elem4_teachers_array(Ncommunity, 0);
Gpu::HostVector<int> daycare_teachers_array(Ncommunity, 0);
auto high_teachers_ptr = high_teachers_array.data();
auto middle_teachers_ptr = middle_teachers_array.data();
auto elem3_teachers_ptr = elem3_teachers_array.data();
auto elem4_teachers_ptr = elem4_teachers_array.data();
auto daycare_teachers_ptr = daycare_teachers_array.data();
auto student_teacher_ratio = pc.m_student_teacher_ratio;

#ifdef AMREX_USE_OMP
Expand All @@ -706,16 +708,21 @@ void CensusData::assignTeachersAndWorkgroup (AgentContainer& pc, const int workg
auto bx = mfi.tilebox();
ParallelFor(bx, [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept {
int comm = comm_arr(i, j, k);
for (int s = 0; s < num_school_types; s++) {
teacher_counts_ptr[s][comm] = student_counts_arr(i, j, k, s);
}
teacher_counts_ptr[SchoolCensusIDType::high_1 - 1][comm] /= student_teacher_ratio[SchoolType::high];
teacher_counts_ptr[SchoolCensusIDType::middle_2 - 1][comm] /= student_teacher_ratio[SchoolType::middle];
teacher_counts_ptr[SchoolCensusIDType::elem_3 - 1][comm] /= student_teacher_ratio[SchoolType::elem];
teacher_counts_ptr[SchoolCensusIDType::elem_4 - 1][comm] /= student_teacher_ratio[SchoolType::elem];
teacher_counts_ptr[SchoolCensusIDType::daycare_5 - 1][comm] /= student_teacher_ratio[SchoolType::daycare];
if (comm >= Ncommunity || comm < 0) return;
high_teachers_ptr[comm] = int(std::round(
double(student_counts_arr(i, j, k, SchoolCensusIDType::high_1 - 1)) / student_teacher_ratio[SchoolType::high]));
middle_teachers_ptr[comm] = int(std::round(
(double)student_counts_arr(i, j, k, SchoolCensusIDType::middle_2 - 1) / student_teacher_ratio[SchoolType::middle]));
elem3_teachers_ptr[comm] = int(std::round(
(double)student_counts_arr(i, j, k, SchoolCensusIDType::elem_3 - 1) / student_teacher_ratio[SchoolType::elem]));
elem4_teachers_ptr[comm] = int(std::round(
(double)student_counts_arr(i, j, k, SchoolCensusIDType::elem_4 - 1) / student_teacher_ratio[SchoolType::elem]));
daycare_teachers_ptr[comm] = int(std::round(
(double)student_counts_arr(i, j, k, SchoolCensusIDType::daycare_5 - 1) / student_teacher_ratio[SchoolType::daycare]));
});
Gpu::synchronize();
}

#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
Expand All @@ -725,63 +732,89 @@ void CensusData::assignTeachersAndWorkgroup (AgentContainer& pc, const int workg

auto np = soa.numParticles();

auto age_group_ptr = soa.GetIntData(IntIdx::age_group).data();
auto workgroup_ptr = soa.GetIntData(IntIdx::workgroup).data();
auto work_i_ptr = soa.GetIntData(IntIdx::work_i).data();
auto work_j_ptr = soa.GetIntData(IntIdx::work_j).data();
auto school_grade_ptr = soa.GetIntData(IntIdx::school_grade).data();
auto school_id_ptr = soa.GetIntData(IntIdx::school_id).data();
auto work_nborhood_ptr = soa.GetIntData(IntIdx::work_nborhood).data();

ParallelForRNG (np, [=] AMREX_GPU_DEVICE (int ip, RandomEngine const& engine) noexcept {

int comm_to = (int)domain.index(IntVect(AMREX_D_DECL(work_i_ptr[ip], work_j_ptr[ip], 0)));
if (comm_to >= Ncommunity) return;
Gpu::HostVector<int> age_group_h(np);
Gpu::HostVector<int> workgroup_h(np);
Gpu::HostVector<int> work_i_h(np);
Gpu::HostVector<int> work_j_h(np);
Gpu::HostVector<int> school_grade_h(np);
Gpu::HostVector<int> school_id_h(np);
Gpu::HostVector<int> work_nborhood_h(np);

Gpu::copy(Gpu::deviceToHost, soa.GetIntData(IntIdx::age_group).begin(),
soa.GetIntData(IntIdx::age_group).end(), age_group_h.begin());
Gpu::copy(Gpu::deviceToHost, soa.GetIntData(IntIdx::workgroup).begin(),
soa.GetIntData(IntIdx::workgroup).end(), workgroup_h.begin());
Gpu::copy(Gpu::deviceToHost, soa.GetIntData(IntIdx::work_i).begin(),
soa.GetIntData(IntIdx::work_i).end(), work_i_h.begin());
Gpu::copy(Gpu::deviceToHost, soa.GetIntData(IntIdx::work_j).begin(),
soa.GetIntData(IntIdx::work_j).end(), work_j_h.begin());
Gpu::copy(Gpu::deviceToHost, soa.GetIntData(IntIdx::school_grade).begin(),
soa.GetIntData(IntIdx::school_grade).end(), school_grade_h.begin());
Gpu::copy(Gpu::deviceToHost, soa.GetIntData(IntIdx::school_id).begin(),
soa.GetIntData(IntIdx::school_id).end(), school_id_h.begin());
Gpu::copy(Gpu::deviceToHost, soa.GetIntData(IntIdx::work_nborhood).begin(),
soa.GetIntData(IntIdx::work_nborhood).end(), work_nborhood_h.begin());

auto age_group_ptr = age_group_h.data();
auto workgroup_ptr = workgroup_h.data();
auto work_i_ptr = work_i_h.data();
auto work_j_ptr = work_j_h.data();
auto school_grade_ptr = school_grade_h.data();
auto school_id_ptr = school_id_h.data();
auto work_nborhood_ptr = work_nborhood_h.data();

for (int ip = 0; ip < np; ++ip) {
int comm = (int) domain.index(IntVect(AMREX_D_DECL(work_i_ptr[ip], work_j_ptr[ip], 0)));
if (comm >= Ncommunity || comm < 0) continue;
// skip non-working age
if (age_group_ptr[ip] < AgeGroups::a18to29 || age_group_ptr[ip] > AgeGroups::a50to64) return;
if (age_group_ptr[ip] < AgeGroups::a18to29 || age_group_ptr[ip] > AgeGroups::a50to64) continue;
// skip non-workers
if (workgroup_ptr[ip] == 0) return;

int choice = Random_int(num_school_types, engine);

for (int k = 0; k < num_school_types; k++) {
int pos = (k + choice) % num_school_types;
AMREX_ALWAYS_ASSERT(pos < num_school_types);
int count = Gpu::Atomic::Add(&(teacher_counts_ptr[pos][comm_to]), -1);
if (count > 0) {
// school_id of 0 is reserved for no school
school_id_ptr[ip] = pos + 1;
// workgroup is the whole school, i.e. adults interact with all other adults in the school
workgroup_ptr[ip] = school_id_ptr[ip];
// teachers are assigned the grade they teach
switch (school_id_ptr[ip]) {
case SchoolCensusIDType::high_1:
school_grade_ptr[ip] = 12; // 10th grade - generic for high school
work_nborhood_ptr[ip] = 3; // assuming the high school is located in Neighbordhood 4
break;
case SchoolCensusIDType::middle_2:
school_grade_ptr[ip] = 9; // 7th grade - generic for middle
work_nborhood_ptr[ip] = 1; // assuming the middle school is located in Neighbordhood 2
break;
case SchoolCensusIDType::elem_3:
school_grade_ptr[ip] = 5; // 3rd grade - generic for elementary
work_nborhood_ptr[ip] = 0; // assuming the first elementary school is located in Neighbordhood 1
break;
case SchoolCensusIDType::elem_4:
school_grade_ptr[ip] = 5; // 3rd grade - generic for elementary
work_nborhood_ptr[ip] = 2; // assuming the first elementary school is located in Neighbordhood 3
break;
case SchoolCensusIDType::daycare_5:
school_grade_ptr[ip] = 0; // generic for daycare
work_nborhood_ptr[ip] = Random_int(4, engine); // randomly select nborhood
school_id_ptr[ip] += work_nborhood_ptr[ip];
break;
}
return;
if (workgroup_ptr[ip] == 0) continue;

int high_teachers = high_teachers_ptr[comm];
int middle_teachers = middle_teachers_ptr[comm];
int elem3_teachers = elem3_teachers_ptr[comm];
int elem4_teachers = elem4_teachers_ptr[comm];
int daycare_teachers = daycare_teachers_ptr[comm];
int total_teachers = high_teachers + middle_teachers + elem3_teachers + elem4_teachers + daycare_teachers;
if (total_teachers > 0) {
int choice = Random_int(total_teachers);
if (choice < high_teachers) {
school_grade_ptr[ip] = 12; // 10th grade - generic for high school
school_id_ptr[ip] = SchoolCensusIDType::high_1;
work_nborhood_ptr[ip] = 3; // assuming the high school is located in Neighbordhood 3
workgroup_ptr[ip] = 1;
high_teachers_ptr[comm]--;
} else if (choice < high_teachers + middle_teachers) {
school_grade_ptr[ip] = 9; // 7th grade - generic for middle
school_id_ptr[ip] = SchoolCensusIDType::middle_2;
work_nborhood_ptr[ip] = 1; // assuming the middle school is located in Neighbordhood 2
workgroup_ptr[ip] = 2;
middle_teachers_ptr[comm]--;
} else if (choice < high_teachers + middle_teachers + elem3_teachers) {
school_grade_ptr[ip] = 5; // 3rd grade - generic for elementary
school_id_ptr[ip] = SchoolCensusIDType::elem_3;
work_nborhood_ptr[ip] = 0; // assuming the first elementary school is located in Neighbordhood 1
workgroup_ptr[ip] = 3;
elem3_teachers_ptr[comm]--;
} else if (choice < high_teachers + middle_teachers + elem3_teachers + elem4_teachers) {
school_grade_ptr[ip] = 5; // 3rd grade - generic for elementary
school_id_ptr[ip] = SchoolCensusIDType::elem_4;
work_nborhood_ptr[ip] = 2; // assuming the first elementary school is located in Neighbordhood 3
workgroup_ptr[ip] = 4;
elem4_teachers_ptr[comm]--;
} else {
school_grade_ptr[ip] = 0; // generic for daycare
work_nborhood_ptr[ip] = Random_int(4); // randomly select nborhood
school_id_ptr[ip] = SchoolCensusIDType::daycare_5 + work_nborhood_ptr[ip];
workgroup_ptr[ip] = 5;
daycare_teachers_ptr[comm]--;
}
}
});
Gpu::synchronize();
}
Gpu::copy(Gpu::hostToDevice, school_grade_h.begin(), school_grade_h.end(), soa.GetIntData(IntIdx::school_grade).begin());
Gpu::copy(Gpu::hostToDevice, school_id_h.begin(), school_id_h.end(), soa.GetIntData(IntIdx::school_id).begin());
Gpu::copy(Gpu::hostToDevice, workgroup_h.begin(), workgroup_h.end(), soa.GetIntData(IntIdx::workgroup).begin());
Gpu::copy(Gpu::hostToDevice, work_nborhood_h.begin(), work_nborhood_h.end(), soa.GetIntData(IntIdx::work_nborhood).begin());
}
}

5 changes: 4 additions & 1 deletion src/IO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ void writePlotFile (const AgentContainer& pc, /*!< Agent (particle) container */
int_varnames.push_back ("trav_i"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("trav_j"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("nborhood"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("school"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("school_grade"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("school_id"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("school_closed"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("naics"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("workgroup"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("work_nborhood"); write_int_comp.push_back(static_cast<int>(step==0));
int_varnames.push_back ("withdrawn"); write_int_comp.push_back(1);
Expand Down
6 changes: 3 additions & 3 deletions src/InteractionModHome.H
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,12 @@ void InteractionModHome<PCType, PTDType, PType>::fastInteractHome (PCType& agent
auto community = getCommunityIndex(ptd, i);
AMREX_ALWAYS_ASSERT(community <= max_communities);
int family_i = community * max_family + family_ptr[i];
Gpu::Atomic::Add(&infected_family_d_ptr[family_i], 1);
Gpu::Atomic::AddNoRet(&infected_family_d_ptr[family_i], 1);
if (!ptd.m_idata[IntIdx::withdrawn][i]) {
Gpu::Atomic::Add(&infected_family_not_withdrawn_d_ptr[family_i], 1);
Gpu::Atomic::AddNoRet(&infected_family_not_withdrawn_d_ptr[family_i], 1);
int cluster = family_ptr[i] / FAMILIES_PER_CLUSTER;
int nc = (community * max_nborhood + nborhood_ptr[i]) * num_ncs + cluster;
Gpu::Atomic::Add(&infected_nc_d_ptr[nc], 1);
Gpu::Atomic::AddNoRet(&infected_nc_d_ptr[nc], 1);
}
}
});
Expand Down
Loading

0 comments on commit d2aa681

Please sign in to comment.