-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX[app]: csv save diatnce | DOC[app]: results | WIP
- Loading branch information
Showing
23 changed files
with
483 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
install: | ||
install-local: | ||
pip install . | ||
|
||
upgrade: | ||
upgrade-local: | ||
pip install --upgrade . | ||
|
||
uninstall: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,8 @@ | ||
Utils Module | ||
============ | ||
|
||
.. autofunction:: safe_pfl_distance.utils.model_loader.load_model | ||
.. autofunction:: safe_pfl_distance.utils.model_loader.load_model | ||
.. autofunction:: safe_pfl_distance.utils.cosine.cosine_distance | ||
.. autofunction:: safe_pfl_distance.utils.euclidean.euclidean_distance | ||
.. autofunction:: safe_pfl_distance.utils.jensen_shannon.jensen_shannon_distance | ||
.. autofunction:: safe_pfl_distance.utils.wasserstein.wasserstein_distance |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import sys | ||
import torch | ||
import numpy as np | ||
import matplotlib | ||
|
||
from safe_pfl_distance.utils.results import generate_csv | ||
|
||
# Use a non-interactive backend | ||
matplotlib.use("Agg") | ||
|
||
from safe_pfl_distance.utils.model_loader import load_models | ||
from safe_pfl_distance.utils.coordinate_based import coordinate_based_distance | ||
from safe_pfl_distance.utils.jensen_shannon import jensen_shannon_distance | ||
from safe_pfl_distance.utils.wasserstein import wasserstein_distance | ||
from safe_pfl_distance.utils.euclidean import euclidean_distance | ||
from safe_pfl_distance.utils.cosine import cosine_distance | ||
|
||
|
||
class ModelDistancesCalculator: | ||
def __init__(self, model_type, sensitivity_parameter=0.01, model_path_prefix:str = "./models", precision=5): | ||
self.model_type = model_type.lower() | ||
self.model_path_prefix = model_path_prefix | ||
|
||
self.precision = precision if precision > 0 else 5 | ||
|
||
# Validate model_type | ||
if self.model_type not in ["cnn", "resnet", "google", "alexnet", "vgg"]: | ||
raise ValueError( | ||
"Invalid model_type. Please choose from 'cnn', 'resnet', 'google', 'alexnet', or 'vgg'." | ||
) | ||
else: | ||
print(f"Processing model_type: {self.model_type}") | ||
|
||
self.client_ids = list(range(10)) | ||
self.models = [] | ||
self.model_weights = [] | ||
self.model_top_weight_indices = [] | ||
self.P = sensitivity_parameter # Percentage for top weights (1%) | ||
|
||
self.models = load_models(self.client_ids, self.model_type, self.model_path_prefix) | ||
|
||
def extract_model_weights(self): | ||
for idx, model in enumerate(self.models): | ||
if isinstance(model, dict): | ||
state_dict = model | ||
elif isinstance(model, torch.nn.Module): | ||
state_dict = model.state_dict() | ||
print("Model state_dict extracted.") | ||
else: | ||
print(f"Unrecognized model format at index {idx}. Skipping...") | ||
continue | ||
|
||
weights = [] | ||
for key, value in state_dict.items(): | ||
# Include all parameters, including biases | ||
weights.append(value.cpu().numpy().flatten()) | ||
|
||
if weights: | ||
weights_vector = np.concatenate(weights) | ||
self.model_weights.append(weights_vector) | ||
else: | ||
print(f"No weights found for model at index {idx}. Skipping...") | ||
|
||
if len(self.model_weights) < 2: | ||
print("Not enough models to compute distances.") | ||
sys.exit(1) | ||
else: | ||
self.prepare_top_weight_indices() | ||
|
||
def prepare_top_weight_indices(self): | ||
N = len(self.model_weights[0]) # Total number of weights | ||
p = int(self.P * N) | ||
if p == 0: | ||
p = 1 # Ensure at least one weight is selected | ||
|
||
self.p = p | ||
self.model_top_weight_indices = [] | ||
for weights in self.model_weights: | ||
importance_scores = np.abs(weights) | ||
top_indices = np.argpartition(-importance_scores, self.p - 1)[: self.p] | ||
self.model_top_weight_indices.append(set(top_indices)) | ||
|
||
def compute_distance_matrix(self, distance_func, is_indices=False): | ||
num_models = len(self.model_weights) | ||
distance_matrix = np.zeros((num_models, num_models)) | ||
for i in range(num_models): | ||
for j in range(num_models): | ||
if is_indices: | ||
distance = distance_func( | ||
i, j, self.model_top_weight_indices, self.P | ||
) | ||
else: | ||
distance = distance_func( | ||
self.model_weights[i], self.model_weights[j] | ||
) | ||
distance_matrix[i, j] = distance | ||
return distance_matrix | ||
|
||
def compute_distance_matrices(self): | ||
distance_functions = { | ||
"Euclidean": euclidean_distance, | ||
"Cosine": cosine_distance, | ||
"coordinate-based": coordinate_based_distance, | ||
"Jensen-Shannon": jensen_shannon_distance, | ||
"Wasserstein": wasserstein_distance, | ||
} | ||
|
||
for metric_name, distance_func in distance_functions.items(): | ||
print(f"Computing {metric_name} distance matrix...") | ||
|
||
if metric_name == "coordinate-based": | ||
distance_matrix = self.compute_distance_matrix( | ||
distance_func, is_indices=True | ||
) | ||
else: | ||
distance_matrix = self.compute_distance_matrix( | ||
distance_func, is_indices=False | ||
) | ||
|
||
# Proceed to cluster and generate LaTeX code | ||
generate_csv(distance_matrix, metric_name, self.model_type, self.precision) |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import torch.nn as nn | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self, num_classes=10): | ||
super(Net, self).__init__() | ||
self.features = nn.Sequential( | ||
nn.Conv2d( | ||
3, 64, kernel_size=3, stride=1, padding=1 | ||
), # Input: 32x32 -> Output: 32x32 | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=2, stride=2), # Output: 16x16 | ||
nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1), # Output: 16x16 | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=2, stride=2), # Output: 8x8 | ||
nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1), # Output: 8x8 | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1), # Output: 8x8 | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), # Output: 8x8 | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=2, stride=2), # Output: 4x4 | ||
) | ||
self.classifier = nn.Sequential( | ||
nn.Dropout(), | ||
nn.Linear(256 * 4 * 4, 4096), | ||
nn.ReLU(inplace=True), | ||
nn.Dropout(), | ||
nn.Linear(4096, 4096), | ||
nn.ReLU(inplace=True), | ||
nn.Linear(4096, num_classes), | ||
) | ||
|
||
def forward(self, x): | ||
x = self.features(x) | ||
x = x.view(x.size(0), -1) | ||
x = self.classifier(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class Net( | ||
nn.Module, | ||
): | ||
def __init__(self): | ||
super(Net, self).__init__() | ||
|
||
self.conv1 = nn.Conv2d(3, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | ||
|
||
def forward(self, x): | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(x.size(0), 16 * 5 * 5) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
x = self.fc3(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch.nn as nn | ||
import torchvision.models as models | ||
|
||
|
||
class Net( | ||
nn.Module, | ||
): | ||
def __init__(self, num_classes=10): | ||
super(Net, self).__init__() | ||
self.googlenet = models.googlenet_v2(pretrained=False) | ||
self.googlenet.classifier[1] = nn.Linear( | ||
self.googlenet.last_channel, num_classes | ||
) | ||
|
||
def forward(self, x): | ||
out = self.googlenet(x) | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch.nn as nn | ||
import torchvision.models as models | ||
|
||
|
||
class Net( | ||
nn.Module, | ||
): | ||
def __init__(self): | ||
super(Net, self).__init__() | ||
|
||
self.resnet18 = models.resnet18(pretrained=False) | ||
|
||
def forward(self, x): | ||
out = self.resnet18(x) | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch.nn as nn | ||
from torchvision.models import vgg16 | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self, num_classes=10): | ||
super(Net, self).__init__() | ||
self.features = vgg16(pretrained=False).features | ||
self.classifier = nn.Sequential( | ||
nn.Linear(in_features=4096, out_features=4096), | ||
nn.ReLU(inplace=True), | ||
nn.Linear(in_features=4096, out_features=num_classes), | ||
) | ||
|
||
def forward(self, x): | ||
x = self.features(x) | ||
x = x.view(-1, 4096) | ||
x = self.classifier(x) | ||
return x |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
def coordinate_based_distance(u, v, model_weight_indices, sensitivity_parameter): | ||
# try: | ||
# if not isinstance(model_weight_indices, dict): | ||
# raise TypeError("model_weight_indices must be a dictionary.") | ||
# if not isinstance(sensitivity_parameter, (int, float)): | ||
# raise TypeError("sensitivity_parameter must be a number.") | ||
# if sensitivity_parameter <= 0: | ||
# raise ValueError("sensitivity_parameter must be positive.") | ||
|
||
# if u not in model_weight_indices or v not in model_weight_indices: | ||
# return 1.0 | ||
|
||
indices_u = model_weight_indices[u] | ||
indices_v = model_weight_indices[v] | ||
|
||
intersection = len(indices_u & indices_v) | ||
# if intersection >= sensitivity_parameter: | ||
# return 0.0 | ||
similarity = intersection / sensitivity_parameter | ||
distance = 1 - similarity | ||
return distance | ||
|
||
# except TypeError as e: | ||
# raise | ||
# except ValueError as e: | ||
# raise | ||
# except Exception as e: | ||
# raise Exception(f"An unexpected error occurred: {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from scipy.spatial.distance import cosine | ||
|
||
|
||
def cosine_distance(u, v): | ||
""" | ||
Calculates the cosine distance between two vectors. | ||
Args: | ||
u (array-like): First input vector. | ||
v (array-like): Second input vector. | ||
Returns: | ||
str: The cosine distance between u and v, formatted to two decimal places. | ||
Returns "1.00" if either input vector is all zeros, as cosine distance is undefined in this case. | ||
Raises: | ||
TypeError: If inputs are not array-like. | ||
ValueError: If input vectors have different lengths. | ||
Exception: if any other exception occurs during cosine calculation. | ||
Example: | ||
>>> cosine_distance([1, 0, 0], [0, 1, 0]) | ||
'1.00' | ||
>>> cosine_distance([1, 0, 0], [1, 0, 0]) | ||
'0.00' | ||
>>> cosine_distance([1, 2, 3], [4, 5, 6]) | ||
'0.02' | ||
""" | ||
|
||
try: | ||
distance = cosine(u, v) | ||
if distance is not None: | ||
return distance | ||
else: | ||
return "1.00" | ||
except TypeError as e: | ||
raise TypeError(f"Input vectors must be array-like: {e}") | ||
except ValueError as e: | ||
raise ValueError(f"Input vectors must have the same length: {e}") | ||
except Exception as e: | ||
raise Exception(f"An error occurred during cosine distance calculation: {e}") |
Oops, something went wrong.