From a9f524bf635a4699b345da0cff589efe5c01c7e5 Mon Sep 17 00:00:00 2001 From: Larisa Markeeva Date: Thu, 18 Jul 2024 10:16:03 -0700 Subject: [PATCH] Add `huggingface_generators` to `__init__.py`. PiperOrigin-RevId: 653667715 --- .github/workflows/pypi-publish.yml | 15 ++++++++------- clrs/__init__.py | 2 ++ clrs/_src/clrs_text/__init__.py | 1 + clrs/_src/clrs_text/huggingface_generators.py | 4 ++-- .../_src/clrs_text/huggingface_generators_test.py | 9 ++++++--- 5 files changed, 19 insertions(+), 12 deletions(-) diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml index 79f27c61..43eeff8b 100644 --- a/.github/workflows/pypi-publish.yml +++ b/.github/workflows/pypi-publish.yml @@ -2,21 +2,18 @@ name: pypi on: release: - types: [created, published] + types: [published] branches: [main, master] + jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v2 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v1 with: python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install setuptools wheel twine - name: Check consistency between the package version and release tag run: | RELEASE_VER=${GITHUB_REF#refs/*/} @@ -25,6 +22,10 @@ jobs: then echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1 fi + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} diff --git a/clrs/__init__.py b/clrs/__init__.py index 088eadd3..0da217a9 100644 --- a/clrs/__init__.py +++ b/clrs/__init__.py @@ -21,6 +21,7 @@ from clrs._src import clrs_text from clrs._src import decoders from clrs._src import processors +from clrs._src import specs from clrs._src.dataset import chunkify from clrs._src.dataset import CLRSDataset @@ -77,6 +78,7 @@ "process_permutations", "process_pred_as_input", "process_random_pos", + "specs", "evaluate", "evaluate_hints", "Features", diff --git a/clrs/_src/clrs_text/__init__.py b/clrs/_src/clrs_text/__init__.py index 6ab5403f..fe0b8819 100644 --- a/clrs/_src/clrs_text/__init__.py +++ b/clrs/_src/clrs_text/__init__.py @@ -15,3 +15,4 @@ """The CLRS Text Algorithmic Reasoning Benchmark.""" from clrs._src.clrs_text import clrs_utils +from clrs._src.clrs_text import huggingface_generators diff --git a/clrs/_src/clrs_text/huggingface_generators.py b/clrs/_src/clrs_text/huggingface_generators.py index fb41f178..14937cf3 100644 --- a/clrs/_src/clrs_text/huggingface_generators.py +++ b/clrs/_src/clrs_text/huggingface_generators.py @@ -17,7 +17,7 @@ import random from typing import Dict, List, Optional -import clrs +from clrs._src import samplers from clrs._src.clrs_text import clrs_utils @@ -86,7 +86,7 @@ def clrs_generator( # make all of the possible generators. for algo_name, lengths in algos_and_lengths.items(): for length in lengths: - sampler, _ = clrs.build_sampler( + sampler, _ = samplers.build_sampler( algo_name, seed=seed, num_samples=-1, diff --git a/clrs/_src/clrs_text/huggingface_generators_test.py b/clrs/_src/clrs_text/huggingface_generators_test.py index da2d901c..e7c547ff 100644 --- a/clrs/_src/clrs_text/huggingface_generators_test.py +++ b/clrs/_src/clrs_text/huggingface_generators_test.py @@ -18,9 +18,12 @@ from absl.testing import absltest from absl.testing import parameterized -import clrs + from clrs._src.clrs_text import clrs_utils from clrs._src.clrs_text import huggingface_generators + +import clrs._src.specs as clrs_spec + import datasets @@ -28,7 +31,7 @@ class TestCLRSGenerator(parameterized.TestCase): """Check that the generator output matches the expected format.""" @parameterized.product( - algo_name=list(clrs.CLRS_30_ALGS_SETTINGS.keys()), + algo_name=list(clrs_spec.CLRS_30_ALGS_SETTINGS.keys()), lengths=[[4, 8]], use_hints=[True, False], dataset_from_generator_and_num_samples=[ @@ -89,7 +92,7 @@ def test_generator_output_format( def test_auxiliary_fields(self, lengths, use_hints): """Test that the auxiliary fields are set correctly.""" algos_and_lengths = { - algo_name: lengths for algo_name in clrs.CLRS_30_ALGS_SETTINGS + algo_name: lengths for algo_name in clrs_spec.CLRS_30_ALGS_SETTINGS } clrs_ds = datasets.Dataset.from_generator( huggingface_generators.clrs_generator,