Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu committed Sep 27, 2024
1 parent 8d6d2df commit e78cff9
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions viscy/representation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def compute_displacement(

return displacement_per_tau


def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id):
filtered_data = embedding_dataset.where(
(embedding_dataset["fov_name"] == fov_name)
Expand All @@ -619,31 +620,35 @@ def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, tr

euclidean_distances = []
for i in range(len(time_points)):
distance = np.linalg.norm(first_time_point_embedding - normalized_features[i].reshape(1, -1))
distance = np.linalg.norm(
first_time_point_embedding - normalized_features[i].reshape(1, -1)
)
euclidean_distances.append(distance)

return time_points, euclidean_distances

def compute_displacement_mean_std_full(
embedding_dataset, max_tau=10
):

def compute_displacement_mean_std_full(embedding_dataset, max_tau=10):
fov_names = embedding_dataset["fov_name"].values
track_ids = embedding_dataset["track_id"].values
timepoints = embedding_dataset["t"].values
embeddings = embedding_dataset["features"].values

cell_identifiers = np.array(list(zip(fov_names, track_ids)), dtype=[('fov_name', 'O'), ('track_id', 'int64')])
cell_identifiers = np.array(
list(zip(fov_names, track_ids)),
dtype=[("fov_name", "O"), ("track_id", "int64")],
)

unique_cells = np.unique(cell_identifiers)

displacement_per_tau = defaultdict(list)

for cell in unique_cells:
fov_name = cell['fov_name']
track_id = cell['track_id']
fov_name = cell["fov_name"]
track_id = cell["track_id"]

indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0]

cell_timepoints = timepoints[indices]
cell_embeddings = embeddings[indices]

Expand All @@ -664,10 +669,12 @@ def compute_displacement_mean_std_full(

if len(future_index) >= 1:
future_embedding = cell_embeddings[future_index[0]]
future_embedding = future_embedding / np.linalg.norm(future_embedding)
future_embedding = future_embedding / np.linalg.norm(
future_embedding
)

distance = np.linalg.norm(current_embedding - future_embedding)

displacement_per_tau[tau].append(distance)

mean_displacement_per_tau = {
Expand All @@ -680,4 +687,3 @@ def compute_displacement_mean_std_full(
}

return mean_displacement_per_tau, std_displacement_per_tau

0 comments on commit e78cff9

Please sign in to comment.