Skip to content

Commit

Permalink
test: update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-nelson-uiuc committed Nov 28, 2024
1 parent c34252e commit f0f9605
Showing 1 changed file with 84 additions and 38 deletions.
122 changes: 84 additions & 38 deletions tests/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,70 +61,116 @@ def sample_data(spark_fixture):

class TestColumnSelector:
def test_string(self, sample_data):
selected_schema = sample_data.select(cs.string())
assert selected_schema.columns == ["name", "instrument"]
selected_columns = sample_data.select(cs.string()).columns
expected_columns = ["name", "instrument"]
assert set(selected_columns).difference(expected_columns) == set()

def test_numeric(self, sample_data):
selected_schema = sample_data.select(cs.numeric())
assert selected_schema.columns == ["seasons"]
selected_columns = sample_data.select(cs.numeric()).columns
expected_columns = ["seasons"]
assert set(selected_columns).difference(expected_columns) == set()

def test_temporal(self, sample_data):
selected_schema = sample_data.select(cs.temporal())
assert selected_schema.columns == ["birth_date", "original_air_date"]
selected_columns = sample_data.select(cs.temporal())
expected_columns = ["seasons"]
assert set(selected_columns).difference(expected_columns) == set()

def test_date(self, sample_data):
selected_schema = sample_data.select(cs.date())
assert selected_schema.columns == ["birth_date"]
selected_columns = sample_data.select(cs.date())
expected_columns = ["birth_date"]
assert set(selected_columns).difference(expected_columns) == set()

def test_time(self, sample_data):
selected_columns = sample_data.select(cs.time())
expected_columns = ["original_air_date"]
assert set(selected_columns).difference(expected_columns) == set()

def test_interval(self, sample_data):
selected_schema = sample_data.select(cs.interval())
assert selected_schema.columns == []
selected_columns = sample_data.select(cs.interval())
expected_columns = []
assert set(selected_columns).difference(expected_columns) == set()

def test_complex(self, sample_data):
selected_schema = sample_data.select(cs.complex())
assert selected_schema.columns == []
selected_columns = sample_data.select(cs.complex())
expected_columns = []
assert set(selected_columns).difference(expected_columns) == set()

def test_required(self, sample_data):
selected_schema = sample_data.select(cs.required())
assert selected_schema.columns == [
selected_columns = sample_data.select(cs.required())
expected_columns = [
"name",
"birth_date",
"original_air_date",
"seasons",
"seasons"
]
assert set(selected_columns).difference(expected_columns) == set()

def test_matches(self, sample_data):
selected_schema = sample_data.select(cs.matches("_"))
assert selected_schema.columns == ["birth_date", "original_air_date"]
selected_schema = sample_data.select(cs.matches("date$"))
assert selected_schema.columns == ["birth_date", "original_air_date"]
selected_schema = sample_data.select(cs.matches("*"))
assert selected_schema.columns == [
selected_columns = sample_data.select(cs.matches("_"))
expected_columns = [
"birth_date",
"original_air_date",
]
assert set(selected_columns).difference(expected_columns) == set()

selected_columns = sample_data.select(cs.matches("date$"))
expected_columns = [
"birth_date",
"original_air_date",
]
assert set(selected_columns).difference(expected_columns) == set()

selected_columns = sample_data.select(cs.matches("*"))
expected_columns = [
"name",
"birth_date",
"original_air_date",
"seasons",
"instrument",
"instrument"
]
assert set(selected_columns).difference(expected_columns) == set()

def test_contains(self, sample_data):
selected_schema = sample_data.select(cs.contains("_"))
assert selected_schema.columns == ["birth_date", "original_air_date"]
selected_schema = sample_data.select(cs.contains("me"))
assert selected_schema.columns == ["name", "instrument"]
selected_schema = sample_data.select(cs.contains("krusty"))
assert selected_schema.columns == []
selected_columns = sample_data.select(cs.contains("_"))
expected_columns = [
"birth_date",
"original_air_date",
]
assert set(selected_columns).difference(expected_columns) == set()

selected_columns = sample_data.select(cs.contains("me"))
expected_columns = [
"name",
"instrument"
]
assert set(selected_columns).difference(expected_columns) == set()

selected_columns = sample_data.select(cs.contains("krusty"))
expected_columns = []
assert set(selected_columns).difference(expected_columns) == set()

def test_starts_with(self, sample_data):
selected_schema = sample_data.select(cs.starts_with("o"))
assert selected_schema.columns == ["original_air_date"]
selected_schema = sample_data.select(cs.starts_with("z"))
assert selected_schema.columns == []
selected_columns = sample_data.select(cs.starts_with("o"))
expected_columns = ["original_air_date"]
assert set(selected_columns).difference(expected_columns) == set()

selected_columns = sample_data.select(cs.starts_with("z"))
expected_columns = []
assert set(selected_columns).difference(expected_columns) == set()

def test_ends_with(self, sample_data):
selected_schema = sample_data.select(cs.ends_with("e"))
assert selected_schema.columns == ["name", "birth_date", "original_air_date"]
selected_schema = sample_data.select(cs.ends_with("date"))
assert selected_schema.columns == ["birth_date", "original_air_date"]
selected_schema = sample_data.select(cs.ends_with("z"))
assert selected_schema.columns == []
selected_columns = sample_data.select(cs.ends_with("e"))
expected_columns = [
"name",
"birth_date",
"original_air_date",
]
assert set(selected_columns).difference(expected_columns) == set()

selected_columns = sample_data.select(cs.ends_with("date"))
expected_columns = ["birth_date", "original_air_date",]
assert set(selected_columns).difference(expected_columns) == set()

selected_columns = sample_data.select(cs.ends_with("z"))
expected_columns = []
assert set(selected_columns).difference(expected_columns) == set()

0 comments on commit f0f9605

Please sign in to comment.