This directory contains implementations of the Base-MLP, ResNet-18 and Logistic Regression base models used in evaluation in the paper. Model files containing the trained parameters for each base model used in the paper's evaluation are under the base_model_trained_files directory.
If you'd like to add a new base model for evaluation, you need to perform the following steps:
1. Implement the base model in PyTorch. Your base model must inherit from `torch.nn.Module` and must implement the `__init__` and `forward` methods. The `forward` method must take just one parameter, a batch of samples over which a forward pass is performed.
2. Train your base model and save the base model's state dictionary to a file. This may be done using `torch.save(my_base_model.state_dict(), "my_file.t7")`.
3. Create a new configuration file that specifically changes the following parameters:
   - `BaseModel`: Specify the classpath of your base model in the "class" field and any arguments required for the `__init__` function in the "args" field.
   - `base_model_file`: Specify the path to the state dictionary saved in step (2).
   - `base_model_input_size`: Specify the input dimensions expected of inputs to the `forward` method of your base model.
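
As a rough illustration of steps (1) and (2), the sketch below defines a small base model and saves its state dictionary. The class name `MyBaseModel`, its layer sizes, and its constructor arguments are hypothetical and are not part of this repository. Only `torch.nn.Module`, `torch.save`, and the file name `my_file.t7` come from the steps above.

```python
import torch
import torch.nn as nn


class MyBaseModel(nn.Module):
    """Hypothetical two-layer classifier, used only for illustration."""

    def __init__(self, input_dim=784, hidden_dim=64, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        # The batch of samples is the only parameter.
        # Flatten each sample before the linear layers.
        return self.net(x.reshape(x.size(0), -1))


my_base_model = MyBaseModel()
# ... train my_base_model here (omitted) ...

# Step (2): save only the state dictionary, not the whole module.
torch.save(my_base_model.state_dict(), "my_file.t7")

# Sanity check: reload the parameters into a fresh instance and run a batch.
restored = MyBaseModel()
restored.load_state_dict(torch.load("my_file.t7"))
restored.eval()
batch = torch.randn(4, 784)  # assumed input shape: 4 samples of 784 features
with torch.no_grad():
    out = restored(batch)
```

Reloading the file into a fresh instance before pointing a configuration at it is a cheap way to confirm that the saved keys match the arguments given to `__init__`.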
There are more details on other configuration parameters in the conf directory.
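
For orientation, the three parameters from step (3) might be grouped as in the sketch below, written here as a Python dictionary. Only the parameter names `BaseModel`, `base_model_file`, and `base_model_input_size` and the "class" and "args" fields come from the steps above. The nesting, the classpath, the argument names, and the shape format are assumptions, so check the existing files in the conf directory for the actual layout.

```python
# Hypothetical overrides for a new base model; values are placeholders.
config_overrides = {
    "BaseModel": {
        # Assumed classpath of the new model (hypothetical module path).
        "class": "my_package.my_module.MyBaseModel",
        # Keyword arguments forwarded to __init__ (hypothetical names).
        "args": {"input_dim": 784, "num_classes": 10},
    },
    # Path to the state dictionary saved in step (2).
    "base_model_file": "my_file.t7",
    # Input dimensions expected by forward; whether a batch dimension
    # belongs here is an assumption to verify against the conf directory.
    "base_model_input_size": [784],
}
```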