Skip to content

Commit

Permalink
feat(sampler): fix component value
Browse files Browse the repository at this point in the history
  • Loading branch information
breakthewall committed Jul 12, 2024
1 parent 827bd6c commit 59ba0e0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
34 changes: 29 additions & 5 deletions icfree/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
import numpy as np
import random
from pyDOE2 import lhs
import ast

def generate_lhs_samples(input_file, num_samples, step, seed=None):
def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed=None):
"""
Generates Latin Hypercube Samples for components based on discrete ranges.
Parameters:
- input_file: Path to the input file containing components and their max values.
- num_samples: Number of samples to generate.
- step: Step size for creating discrete ranges.
- fixed_values: Dictionary of components with fixed values (optional).
- seed: Random seed for reproducibility.
Returns:
Expand All @@ -30,7 +32,12 @@ def generate_lhs_samples(input_file, num_samples, step, seed=None):

# Generate discrete ranges for each component
for index, row in components_df.iterrows():
component_range = np.arange(0, row['maxValue'] + step, step)
component_name = row['Component']
if fixed_values and component_name in fixed_values:
# If the component has a fixed value, use a single-element array
component_range = np.array([fixed_values[component_name]])
else:
component_range = np.arange(0, row['maxValue'] + step, step)
discrete_ranges.append(component_range)

# Determine the number of components
Expand All @@ -48,7 +55,7 @@ def generate_lhs_samples(input_file, num_samples, step, seed=None):
samples_df = pd.DataFrame(samples, columns=components_df['Component'])
return samples_df

def main(input_file, output_file, num_samples, step=2.5, seed=None):
def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed=None):
"""
Main function to generate LHS samples and save them to a CSV file.
Expand All @@ -57,10 +64,23 @@ def main(input_file, output_file, num_samples, step=2.5, seed=None):
- output_file: Path to the output CSV file where samples will be written.
- num_samples: Number of samples to generate.
- step: Step size for creating discrete ranges (default: 2.5).
- fixed_values: Dictionary of components with fixed values (optional).
- seed: Random seed for reproducibility (optional).
"""
# Read the input file
components_df = pd.read_csv(input_file, sep='\t')

# Get the list of components from the input file
component_names = components_df['Component'].tolist()

# Check for fixed values that are not in the list of components
if fixed_values:
for component in fixed_values.keys():
if component not in component_names:
print(f"Warning: Component '{component}' not found in the input file.")

# Generate LHS samples
samples_df = generate_lhs_samples(input_file, num_samples, step, seed)
samples_df = generate_lhs_samples(input_file, num_samples, step, fixed_values, seed)

# Write the samples to a CSV file
samples_df.to_csv(output_file, index=False)
Expand All @@ -73,10 +93,14 @@ def main(input_file, output_file, num_samples, step=2.5, seed=None):
parser.add_argument('output_file', type=str, help='Output CSV file path for the samples.')
parser.add_argument('num_samples', type=int, help='Number of samples to generate.')
parser.add_argument('--step', type=float, default=2.5, help='Step size for creating discrete ranges (default: 2.5).')
parser.add_argument('--fixed_values', type=str, default=None, help='Fixed values for components as a dictionary (e.g., \'{"Component1": 10, "Component2": 20}\')')
parser.add_argument('--seed', type=int, default=None, help='Seed for random number generation for reproducibility (optional).')

# Parse arguments
args = parser.parse_args()

# Convert fixed_values argument from string to dictionary if provided
fixed_values = ast.literal_eval(args.fixed_values) if args.fixed_values else None

# Run the main function with the parsed arguments
main(args.input_file, args.output_file, args.num_samples, args.step, args.seed)
main(args.input_file, args.output_file, args.num_samples, args.step, fixed_values, args.seed)
19 changes: 14 additions & 5 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self):
def test_generate_lhs_samples_normal(self, mock_read_csv):
mock_read_csv.return_value = self.components_df

result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, self.seed)
result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, self.seed)

self.assertEqual(result.shape, (self.num_samples, 3))
self.assertListEqual(list(result.columns), ['A', 'B', 'C'])
Expand All @@ -28,7 +28,7 @@ def test_generate_lhs_samples_normal(self, mock_read_csv):
def test_generate_lhs_samples_no_seed(self, mock_read_csv):
mock_read_csv.return_value = self.components_df

result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None)
result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, None)

self.assertEqual(result.shape, (self.num_samples, 3))
self.assertListEqual(list(result.columns), ['A', 'B', 'C'])
Expand All @@ -39,7 +39,7 @@ def test_generate_lhs_samples_edge_case_zero_maxValue(self, mock_read_csv):
edge_case_df.loc[0, 'maxValue'] = 0 # Set maxValue of component 'A' to 0
mock_read_csv.return_value = edge_case_df

result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, self.seed)
result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, self.seed)

self.assertEqual(result.shape, (self.num_samples, 3))
self.assertTrue((result['A'] == 0).all()) # All values in column 'A' should be zero
Expand All @@ -49,14 +49,23 @@ def test_generate_lhs_samples_invalid_step(self, mock_read_csv):
mock_read_csv.return_value = self.components_df

with self.assertRaises(IndexError):
generate_lhs_samples("fake_path.csv", self.num_samples, -2.5, self.seed) # Negative step size should raise an error
generate_lhs_samples("fake_path.csv", self.num_samples, -2.5, None, self.seed) # Negative step size should raise an error

@patch("icfree.sampler.pd.read_csv")
def test_generate_lhs_samples_invalid_input_file(self, mock_read_csv):
mock_read_csv.side_effect = FileNotFoundError

with self.assertRaises(FileNotFoundError):
generate_lhs_samples("invalid_path.csv", self.num_samples, self.step, self.seed)
generate_lhs_samples("invalid_path.csv", self.num_samples, self.step, None, self.seed)

@patch("icfree.sampler.pd.read_csv")
def test_generate_lhs_samples_fix_component_value(self, mock_read_csv):
mock_read_csv.return_value = self.components_df

result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, {'A': 5}, self.seed)

self.assertEqual(result.shape, (self.num_samples, 3))
self.assertTrue((result['A'] == 5).all())

if __name__ == "__main__":
unittest.main()

0 comments on commit 59ba0e0

Please sign in to comment.