Skip to content

Commit

Permalink
Merge pull request #1713 from PolicyEngine/nikhilwoodruff/issue1712
Browse files Browse the repository at this point in the history
Improve policy reproducibility block
  • Loading branch information
anth-volk authored May 14, 2024
2 parents 0d79d35 + cc549fa commit 4b234ca
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 155 deletions.
118 changes: 54 additions & 64 deletions src/__tests__/data/reformDefinitionCode.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,17 @@ import {
} from "../__setup__/sampleData";

let metadataUS = null;
let metadataUK = null;

beforeAll(async () => {
const res = await fetch("https://api.policyengine.org/us/metadata");
const metadataRaw = await res.json();
metadataUS = metadataRaw.result;
});

const numberedPolicyUS = {
baseline: {
data: {},
label: "Current law",
id: 2,
},
reform: {
data: {
"sample.reform.item.2": {
"2020.01.01": 15,
"2022.01.01": 20,
},
},
label: "Sample reform",
id: 0,
},
};
const resUK = await fetch("https://api.policyengine.org/us/metadata");
const metadataRawUK = await resUK.json();
metadataUK = metadataRawUK.result;
});

describe("Test getReproducibilityCodeBlock", () => {
test("Properly outputs array of values from functions it calls", () => {
Expand All @@ -66,71 +53,57 @@ describe("Test getHeaderCode", () => {
test("Properly format household with reform", () => {
const output = getHeaderCode("household", metadataUS, reformPolicyUS);
expect(output).toBeInstanceOf(Array);
expect(output.length).toBe(3);
expect(output.length).toBe(2);
});

test("Properly format standard policy", () => {
const output = getHeaderCode("policy", metadataUS, reformPolicyUS);
expect(output).toBeInstanceOf(Array);
expect(output.length).toBe(3);
expect(output.length).toBe(2);
});
});

describe("Test getBaselineCode", () => {
test("Output nothing for household type", () => {
const output = getBaselineCode("household", baselinePolicyUS, "us");
const output = getBaselineCode(baselinePolicyUS, metadataUS);
expect(output).toBeInstanceOf(Array);
expect(output.length).toBe(0);
});
test("Output nothing for non-US", () => {
const output = getBaselineCode("policy", baselinePolicyUK, "uk");
test("Output nothing for policies with current-law baseline", () => {
const output = getBaselineCode(baselinePolicyUK, metadataUK);
expect(output).toBeInstanceOf(Array);
expect(output.length).toBe(0);
});
test("Output baseline override for US policies", () => {
const output = getBaselineCode("policy", reformPolicyUS, "us");
test("Output baseline for policies with stated baseline", () => {
let testPolicy = JSON.parse(JSON.stringify(baselinePolicyUK));
testPolicy = {
...testPolicy,
baseline: {
data: {
"sample.reform.item": {
"2020.01.01": true,
"2022.01.01": true,
},
},
label: "dworkin",
id: 1,
},
};
const output = getBaselineCode(testPolicy, metadataUK);
expect(output).toBeInstanceOf(Array);
expect(output).toContain(
" parameters.simulation.reported_state_income_tax.update(",
);
expect(output.length).toBe(7);
});
});
describe("Test getReformCode", () => {
test("Output nothing if there's no reform", () => {
const output = getReformCode("household", baselinePolicyUS, "us");
const output = getReformCode(baselinePolicyUS, metadataUS);
expect(output).toBeInstanceOf(Array);
expect(output.length).toBe(0);
});
test("Ensure normal output for non-US policy", () => {
const output = getReformCode("policy", reformPolicyUK, "uk");
expect(output).toBeInstanceOf(Array);

const paramAccessor = Object.keys(reformPolicyUK.reform.data)[0];
expect(output).toContain(` parameters.${paramAccessor}.update(`);
expect(output).toContain(` value=True)`);
});
test("Ensure addition of use_reported_state_income_tax for US policy", () => {
const output = getReformCode("policy", reformPolicyUS, "us");
expect(output).toBeInstanceOf(Array);
expect(output).toContain(
" self.modify_parameters(use_reported_state_income_tax)",
);
const paramName = Object.keys(reformPolicyUS.reform.data)[0];
const paramAccessor = `parameters.${paramName}`;
expect(output).toContain(` ${paramAccessor}.update(`);
expect(output).toContain(` value=False)`);
});
test("Ensure proper formatting for policies with numbers", () => {
const output = getReformCode("policy", numberedPolicyUS, "us");
const output = getReformCode(reformPolicyUK, metadataUK);
expect(output).toBeInstanceOf(Array);
const paramName = Object.keys(numberedPolicyUS.reform.data)[0];
let nameParts = paramName.split(".");
let numPart = nameParts[nameParts.length - 1];
numPart = `children["${numPart}"]`;
nameParts[nameParts.length - 1] = numPart;
const sanitizedName = nameParts.join(".");

expect(output).toContain(` parameters.${sanitizedName}.update(`);
expect(output.length).toBe(7);
});
});
describe("Test getSituationCode", () => {
Expand Down Expand Up @@ -250,18 +223,35 @@ describe("Test getImplementationCode", () => {
expect(output).toBeInstanceOf(Array);
expect(output.length).toBe(0);
});
test("If not US, return lines without state tax overrides", () => {
const output = getImplementationCode("policy", "uk", 2024);
test("If policy, return lines", () => {
const output = getImplementationCode(
"policy",
"uk",
2024,
baselinePolicyUK,
);
expect(output).toBeInstanceOf(Array);
expect(output).not.toContain(
"baseline = Microsimulation(reform=baseline_reform)",
);
});
test("If US, return lines with state tax overrides", () => {
const output = getImplementationCode("policy", "us", 2024);
test("If set baseline, return lines with baseline", () => {
let testPolicy = JSON.parse(JSON.stringify(baselinePolicyUK));
testPolicy = {
...testPolicy,
baseline: {
data: {
"sample.reform.item": {
"2020.01.01": true,
"2022.01.01": true,
},
},
label: "dworkin",
id: 1,
},
};
const output = getImplementationCode("policy", "uk", 2024, testPolicy);
expect(output).toBeInstanceOf(Array);
expect(output).toContain(
"baseline = Microsimulation(reform=baseline_reform)",
);
expect(output).toContain("baseline = Microsimulation(reform=baseline)");
});
});
134 changes: 43 additions & 91 deletions src/data/reformDefinitionCode.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import { optimiseHousehold } from "../api/variables";
import { defaultYear } from "./constants";

const US_REGIONS = ["us", "enhanced_us"];

export function getReproducibilityCodeBlock(
type,
metadata,
Expand All @@ -17,8 +15,8 @@ export function getReproducibilityCodeBlock(

return [
...getHeaderCode(type, metadata, policy),
...getBaselineCode(type, policy, region),
...getReformCode(type, policy, region),
...getBaselineCode(policy, metadata),
...getReformCode(policy, metadata),
...getSituationCode(
type,
metadata,
Expand All @@ -27,7 +25,7 @@ export function getReproducibilityCodeBlock(
householdInput,
earningVariation,
),
...getImplementationCode(type, region, year),
...getImplementationCode(type, region, year, policy),
];
}

Expand All @@ -43,89 +41,36 @@ export function getHeaderCode(type, metadata, policy) {

// If there is a reform, add the following Python imports
if (Object.keys(policy.reform.data).length > 0) {
lines.push(
"from policyengine_core.reforms import Reform",
"from policyengine_core.periods import instant",
);
lines.push("from policyengine_core.reforms import Reform");
}

return lines;
}

export function getBaselineCode(type, policy, region) {
// Disregard baseline code for household code
// or non-US locales
if (type === "household" || !US_REGIONS.includes(region)) {
export function getBaselineCode(policy, metadata) {
if (
!policy?.baseline?.data ||
Object.keys(policy.baseline.data).length === 0
) {
return [];
}

// Calculate the earliest start date and latest end date for
// the policies included in the simulation
const { earliestStart, latestEnd } = getStartEndDates(policy);

return [
"",
"",
`"""`,
"In US nationwide simulations,",
"use reported state income tax liabilities",
`"""`,
"def use_reported_state_income_tax(parameters):",
" parameters.simulation.reported_state_income_tax.update(",
` start=instant("${earliestStart}"), stop=instant("${latestEnd}"),`,
" value=True)",
" return parameters",
"",
"",
"class baseline_reform(Reform):",
" def apply(self):",
" self.modify_parameters(use_reported_state_income_tax)",
];
let json_str = JSON.stringify(policy.baseline.data, null, 2);
let lines = [""].concat(json_str.split("\n"));
lines[1] = "baseline = Reform.from_dict({" + lines[0];
lines[lines.length - 1] =
lines[lines.length - 1] + ', country_id="' + metadata.countryId + '")';
return lines;
}

export function getReformCode(type, policy, region) {
// Return no reform code for households or for policies
// without reform parameters
if (Object.keys(policy.reform.data).length <= 0) {
export function getReformCode(policy, metadata) {
if (!policy?.baseline?.data || Object.keys(policy.reform.data).length === 0) {
return [];
}

let lines = ["", "", "def reform_parameters(parameters):"];

for (let [parameterName, parameter] of Object.entries(policy.reform.data)) {
for (let [instant, value] of Object.entries(parameter)) {
const [start, end] = instant.split(".");
if (value === false) {
value = "False";
} else if (value === true) {
value = "True";
}
// If param name contains number, transform into valid Python
if (doesParamNameContainNumber(parameterName)) {
parameterName = transformNumberedParamName(parameterName);
}
lines.push(
` parameters.${parameterName}.update(`,
` start=instant("${start}"), stop=instant("${end}"),`,
` value=${value})`,
);
}
}
lines.push(" return parameters");

lines = lines.concat([
"",
"",
"class reform(Reform):",
" def apply(self):",
" self.modify_parameters(reform_parameters)",
]);

// For US reforms, when calculated society-wide, add reported state income tax
if (type === "policy" && US_REGIONS.includes(region)) {
lines.push(" self.modify_parameters(use_reported_state_income_tax)");
}

let json_str = JSON.stringify(policy.reform.data, null, 2);
let lines = [""].concat(json_str.split("\n"));
lines[1] = "reform = Reform.from_dict({" + lines[0];
lines[lines.length - 1] =
lines[lines.length - 1] + ', country_id="' + metadata.countryId + '")';
return lines;
}

Expand Down Expand Up @@ -200,32 +145,39 @@ export function getSituationCode(
return lines;
}

export function getImplementationCode(type, region, timePeriod) {
export function getImplementationCode(type, region, timePeriod, policy) {
if (type !== "policy") {
return [];
}

const isCountryUS = US_REGIONS.includes(region);
const hasBaseline = Object.keys(policy?.baseline?.data).length > 0;
const hasReform = Object.keys(policy?.reform?.data).length > 0;
const hasDatasetSpecified = region === "enhanced_us";
const dataset = hasDatasetSpecified ? 'dataset="enhanced_cps_2022"' : "";

return [
"",
"",
`baseline = Microsimulation(${
isCountryUS
? region === "enhanced_us"
? `reform=baseline_reform, dataset="enhanced_cps_2022"`
: `reform=baseline_reform`
: ""
hasDatasetSpecified && hasBaseline
? `reform=baseline, dataset=${dataset}`
: hasBaseline
? `reform=baseline`
: hasDatasetSpecified
? `dataset=${dataset}`
: ""
})`,
`reformed = Microsimulation(${
region === "enhanced_us"
? `reform=reform, dataset="enhanced_cps_2022"`
: `reform=reform`
hasDatasetSpecified && hasReform
? `reform=reform, dataset=${dataset}`
: hasReform
? `reform=reform`
: hasDatasetSpecified
? `dataset=${dataset}`
: ""
})`,
`baseline_person = baseline.calc("household_net_income",
period=${timePeriod || defaultYear}, map_to="person")`,
`reformed_person = reformed.calc("household_net_income",
period=${timePeriod || defaultYear}, map_to="person")`,
`baseline_person = baseline.calculate("household_net_income", period=${timePeriod || defaultYear}, map_to="person")`,
`reformed_person = reformed.calculate("household_net_income", period=${timePeriod || defaultYear}, map_to="person")`,
"difference_person = reformed_person - baseline_person",
];
}
Expand Down

0 comments on commit 4b234ca

Please sign in to comment.