Skip to content

Commit

Permalink
Commitin
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Dec 4, 2023
1 parent b2d1b2c commit f3a809f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
52 changes: 52 additions & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: Benchmark Adaptive Localization
on:
push:
branches:
- main

permissions:
# deployments permission to deploy GitHub pages website
deployments: write
# contents permission to update benchmark contents in gh-pages branch
contents: write

jobs:
benchmark:
name: Run pytest-benchmark benchmark example
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
lfs: true

- uses: actions/setup-python@v4
id: setup_python
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: |
setup.py
pyproject.toml
- name: Get wheels
uses: actions/download-artifact@v3

- name: Install wheel
run: |
find . -name "*.whl" -exec pip install "{}[dev]" \;
- name: Run benchmark
run: |
pytest tests/unit_tests/analysis/test_adaptive_localization.py::test_benchmark --benchmark-json output.json
- name: Store benchmark result
uses: benchmark-action/github-action-benchmark@v1
with:
name: Python Benchmark with pytest-benchmark
tool: 'pytest'
output-file-path: output.json
github-token: ${{ secrets.GITHUB_TOKEN }}
auto-push: true
32 changes: 31 additions & 1 deletion tests/unit_tests/analysis/test_adaptive_localization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import uuid
from argparse import ArgumentParser
from functools import partial
from textwrap import dedent

import numpy as np
Expand All @@ -16,7 +18,7 @@
def run_cli_ES_with_case(poly_config):
config_name = poly_config.split(".")[0]
prior_sample_name = "prior_sample" + "_" + config_name
posterior_sample_name = "posterior_sample" + "_" + config_name
posterior_sample_name = str(uuid.uuid1())
parser = ArgumentParser(prog="test_main")
parsed = ert_parser(
parser,
Expand Down Expand Up @@ -101,6 +103,34 @@ def test_that_adaptive_localization_with_cutoff_0_equals_ESupdate(copy_poly_case
assert np.allclose(posterior_sample_loc0, posterior_sample_noloc)


def test_benchmark(copy_poly_case, benchmark):
# rng = np.random.default_rng(42)
# cutoff1 = rng.uniform(0, 1)
cutoff1 = 0.5

set_adaptive_localization_cutoff1 = dedent(
f"""
ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True
ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD {cutoff1}
"""
)

with open("poly.ert", "r+", encoding="utf-8") as f:
lines = f.readlines()
for i, line in enumerate(lines):
if "NUM_REALIZATIONS 100" in line:
lines[i] = "NUM_REALIZATIONS 200\n"
break
lines.insert(2, random_seed_line)
lines.insert(9, set_adaptive_localization_cutoff1)

with open("poly_localization_cutoff1.ert", "w", encoding="utf-8") as f:
f.writelines(lines)

run_with_cutoff1 = partial(run_cli_ES_with_case, "poly_localization_cutoff1.ert")
benchmark(run_with_cutoff1)


@pytest.mark.integration_test
def test_that_posterior_generalized_variance_increases_in_cutoff(copy_poly_case):
rng = np.random.default_rng(42)
Expand Down

0 comments on commit f3a809f

Please sign in to comment.