Skip to content

Commit

Permalink
Add huggingface_generators to __init__.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653667715
  • Loading branch information
RerRayne authored and copybara-github committed Jul 18, 2024
1 parent 9489aef commit a9f524b
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 12 deletions.
15 changes: 8 additions & 7 deletions .github/workflows/pypi-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,18 @@ name: pypi

on:
release:
types: [created, published]
types: [published]
branches: [main, master]

jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v1
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Check consistency between the package version and release tag
run: |
RELEASE_VER=${GITHUB_REF#refs/*/}
Expand All @@ -25,6 +22,10 @@ jobs:
then
echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1
fi
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
Expand Down
2 changes: 2 additions & 0 deletions clrs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from clrs._src import clrs_text
from clrs._src import decoders
from clrs._src import processors
from clrs._src import specs

from clrs._src.dataset import chunkify
from clrs._src.dataset import CLRSDataset
Expand Down Expand Up @@ -77,6 +78,7 @@
"process_permutations",
"process_pred_as_input",
"process_random_pos",
"specs",
"evaluate",
"evaluate_hints",
"Features",
Expand Down
1 change: 1 addition & 0 deletions clrs/_src/clrs_text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
"""The CLRS Text Algorithmic Reasoning Benchmark."""

from clrs._src.clrs_text import clrs_utils
from clrs._src.clrs_text import huggingface_generators
4 changes: 2 additions & 2 deletions clrs/_src/clrs_text/huggingface_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import random
from typing import Dict, List, Optional

import clrs
from clrs._src import samplers
from clrs._src.clrs_text import clrs_utils


Expand Down Expand Up @@ -86,7 +86,7 @@ def clrs_generator(
# make all of the possible generators.
for algo_name, lengths in algos_and_lengths.items():
for length in lengths:
sampler, _ = clrs.build_sampler(
sampler, _ = samplers.build_sampler(
algo_name,
seed=seed,
num_samples=-1,
Expand Down
9 changes: 6 additions & 3 deletions clrs/_src/clrs_text/huggingface_generators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@

from absl.testing import absltest
from absl.testing import parameterized
import clrs

from clrs._src.clrs_text import clrs_utils
from clrs._src.clrs_text import huggingface_generators

import clrs._src.specs as clrs_spec

import datasets


class TestCLRSGenerator(parameterized.TestCase):
"""Check that the generator output matches the expected format."""

@parameterized.product(
algo_name=list(clrs.CLRS_30_ALGS_SETTINGS.keys()),
algo_name=list(clrs_spec.CLRS_30_ALGS_SETTINGS.keys()),
lengths=[[4, 8]],
use_hints=[True, False],
dataset_from_generator_and_num_samples=[
Expand Down Expand Up @@ -89,7 +92,7 @@ def test_generator_output_format(
def test_auxiliary_fields(self, lengths, use_hints):
"""Test that the auxiliary fields are set correctly."""
algos_and_lengths = {
algo_name: lengths for algo_name in clrs.CLRS_30_ALGS_SETTINGS
algo_name: lengths for algo_name in clrs_spec.CLRS_30_ALGS_SETTINGS
}
clrs_ds = datasets.Dataset.from_generator(
huggingface_generators.clrs_generator,
Expand Down

0 comments on commit a9f524b

Please sign in to comment.