generated from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrn_val_tst.yaml
48 lines (39 loc) · 1.84 KB
/
trn_val_tst.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# @package _global_
# Global settings
seed: 1337  # Seed used for random number generation
target: 'Age'  # Name of the target column
# Cross-validation params
# Booleans are written in canonical lowercase form (`true`/`false`).
cv_is_split: true  # Perform cross-validation?
cv_n_splits: 5  # Number of splits in cross-validation
cv_n_repeats: 1  # Number of repeats in cross-validation
# Data dimensions
in_dim: 10  # Count of input features
out_dim: 1  # Dimension of the output
embed_dim: 16  # Embedding dimension used by default
# Metric optimization settings
optimized_metric: 'mean_absolute_error'  # Any metric listed in src.tasks.metrics
optimized_part: 'val'  # Data partition to optimize. Options: ["val", "tst_ctrl"]
direction: 'min'  # Direction in which the metric is optimized. Options: ["min", "max"]
# Run params
max_epochs: 1000  # Maximum number of epochs
patience: 50  # Number of early stopping epochs
# Feature importance method.
# Options: [none, shap_deep, shap_kernel, shap_sampling, shap_tree, native]
# The value is the literal string "none": YAML does not treat a bare `none` as
# null, so it is quoted here to make the string type explicit.
feature_importance: "none"
# Info params
# Booleans are written in canonical lowercase form (`true`/`false`).
debug: false  # Is Debug?
print_config: false  # Print config?
print_model: false  # Print model info?
ignore_warnings: true  # Ignore warnings?
test_after_training: true  # Test after training?
# Directories and files params
# Every value below is a `${...}` interpolation; all are quoted consistently.
# `model.name` is not defined in the lines shown here, so it has to come from
# another part of the composed config.
project_name: "${model.name}_trn_val_tst"
# Reads the PROJECT_ROOT environment variable; no default value is supplied.
# NOTE(review): resolution presumably fails when PROJECT_ROOT is unset - confirm.
base_dir: "${oc.env:PROJECT_ROOT}/data"
data_dir: "${base_dir}"  # Same location as base_dir
work_dir: "${base_dir}/models/${project_name}"  # Per-project directory under base_dir/models
# SHAP values params
# Booleans are written in canonical lowercase form (`true`/`false`).
is_shap: false  # Calculate SHAP values?
is_shap_save: true  # Save SHAP values?
# Type of explainer. Options: ["Tree", "Kernel", "Deep", "Sampling"]
# NOTE(review): "Sampling" is the configured value but was absent from the
# documented options; confirm the full list against the code reading this key.
shap_explainer: "Sampling"
shap_bkgrd: "trn"  # Type of background data. Options: ["trn"]
# Plotting settings
num_top_features: 10  # How many of the most important features to plot
num_examples: 10  # How many samples to plot in some SHAP figures