ESM2 Finetuning refactor #574
base: main
Conversation
Signed-off-by: Farhad Ramezanghorbani <[email protected]>
Ported over the changes from #546
/build-ci
@@ -180,12 +180,12 @@ def get_loss_reduction_class(self) -> Type[RegressorLossReduction]:
        return RegressorLossReduction


-class InMemorySingleValueDataset(Dataset):
+class InMemorySingleValueDataset(InMemoryCSVDataset):
Should we move this into dataset.py, or somewhere under data?
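As a rough illustration of the refactor being reviewed, a minimal sketch of what a shared CSV-backed base class and its single-value subclass might look like. Everything beyond the two class names shown in the diff is an assumption: the column names `sequences`/`labels`, the constructor signature, and the `transform_label` hook are hypothetical, and the real base class in the PR behaves like a `torch.utils.data.Dataset`.

```python
import csv
import io
from typing import List, Tuple


class InMemoryCSVDataset:
    """Hypothetical sketch of the shared base class. The real one in the PR
    ultimately acts as a torch.utils.data.Dataset; column names are assumed."""

    def __init__(self, csv_file):
        # Accept a filesystem path or an already-open file-like object.
        f = open(csv_file, newline="") if isinstance(csv_file, str) else csv_file
        self._rows: List[Tuple[str, str]] = [
            (row["sequences"], row["labels"]) for row in csv.DictReader(f)
        ]

    def __len__(self) -> int:
        return len(self._rows)

    def __getitem__(self, idx: int):
        sequence, label = self._rows[idx]
        return sequence, self.transform_label(label)

    def transform_label(self, label: str):
        # Subclasses decide how the raw CSV label string is interpreted.
        raise NotImplementedError


class InMemorySingleValueDataset(InMemoryCSVDataset):
    """Regression fine-tuning: one float label per sequence."""

    def transform_label(self, label: str) -> float:
        return float(label)


demo = InMemorySingleValueDataset(
    io.StringIO("sequences,labels\nMKTAYIA,0.5\nGSHMLE,1.2\n")
)
```

The point of the refactor is that CSV loading lives in one place and each fine-tuning task only overrides the label transform.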
@@ -205,12 +190,12 @@ def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
        return ClassifierLossReduction


-class InMemoryPerTokenValueDataset(Dataset):
+class InMemoryPerTokenValueDataset(InMemoryCSVDataset):
Similarly, should we move this under dataset.py, e.g. esm2/data/finetune/dataset.py or similar?
dataset_class_options: Dict[str, Type[InMemoryCSVDataset]] = SUPPORTED_DATASETS


def dataset_class_type(desc: str) -> Type[InMemoryCSVDataset]:
    try:
        return dataset_class_options[desc]
    except KeyError:
        raise argparse.ArgumentTypeError(
            f"Do not recognize key {desc}, valid options are: {dataset_class_options.keys()}"
        )
@pstjohn had a similar approach: inheriting from both str and Enum to streamline argument parsing.
https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py#L47
https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py#L586
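For reference, a minimal sketch of the str+Enum pattern the links describe. The enum name and values here are hypothetical, not taken from the PR; the idea is that mixing in str lets the enum class itself serve as the argparse `type` converter, and `choices` gets its valid values from the enum members.

```python
import argparse
from enum import Enum


class DatasetChoice(str, Enum):
    # Hypothetical member names/values for illustration. Mixing in str means
    # each member compares equal to its raw string value.
    SINGLE_VALUE = "single_value"
    PER_TOKEN_VALUE = "per_token_value"

    def __str__(self) -> str:
        # So --help and error messages show the raw values, not "DatasetChoice.X".
        return self.value


parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset-class",
    type=DatasetChoice,            # value lookup: "single_value" -> DatasetChoice.SINGLE_VALUE
    choices=list(DatasetChoice),   # argparse reports valid options on a bad value
    default=DatasetChoice.SINGLE_VALUE,
)

args = parser.parse_args(["--dataset-class", "per_token_value"])
```

Compared with a hand-rolled `dataset_class_type` converter, this gets the "valid options are ..." error message and `--help` listing from argparse for free.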
    config_class: Type[BioBertConfig] = ESM2FineTuneSeqConfig,
    metric_tracker: Callback | None = None,
    overlap_grad_reduce: bool = True,
    overlap_param_gather: bool = False,  # TODO waiting for a NeMo fix
    dataset_class=args.dataset_class,
    config_class=args.config_class,
    overlap_grad_reduce=not args.no_overlap_grad_reduce,
    overlap_param_gather=args.overlap_param_gather,
parser.add_argument(
    "--overlap-param-gather",
    action="store_true",
    default=False,
)  # TODO waiting for a NeMo fix
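A small self-contained sketch of how the two flags in this diff wire together, assuming the negated-flag pattern the call site shows (`not args.no_overlap_grad_reduce`): gradient-reduce overlap is on by default and opted out of, while param-gather overlap is off by default pending the NeMo fix noted in the TODO.

```python
import argparse

# Flag names taken from the diff; help strings are assumptions.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-overlap-grad-reduce",
    action="store_true",
    default=False,
    help="Disable overlapping gradient reduction (on by default).",
)
parser.add_argument(
    "--overlap-param-gather",
    action="store_true",
    default=False,  # off by default; diff notes a pending NeMo fix
)

args = parser.parse_args([])
# The call site inverts the negated flag to get the positive setting.
overlap_grad_reduce = not args.no_overlap_grad_reduce
```

With no flags passed, `overlap_grad_reduce` comes out True and `overlap_param_gather` False, matching the function-signature defaults shown earlier in the diff.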
Description
Type of changes
CI Pipeline Configuration
Configure CI behavior by checking relevant boxes below. This will automatically apply labels.
Note
By default, the notebooks validation tests are skipped unless explicitly enabled.
Usage
Pre-submit Checklist