Skip to content

Commit

Permalink
testing prior
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Aug 23, 2024
1 parent 45fc3d8 commit c156136
Showing 1 changed file with 67 additions and 11 deletions.
78 changes: 67 additions & 11 deletions gamma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,59 @@ def convert_picks_csv(picks, stations, config):
)


def hierarchical_dbscan_clustering(data, phase_loc, phase_type, phase_weight, vel, eps=15, min_samples=3):

def dbscan2(t, xy, w, ph, vel, eps, min_samples, ratio=1.1):

data = np.hstack([t * ratio, xy / vel["p"]]) # time, x, y
db_ = DBSCAN(eps=eps, min_samples=min_samples).fit(data, sample_weight=np.squeeze(w, axis=-1))

return db_.labels_

db = DBSCAN(eps=eps, min_samples=min_samples).fit(
np.hstack([data[:, 0:1], phase_loc[:, :2] / vel["p"]]), # time, x, y
sample_weight=np.squeeze(phase_weight, axis=-1),
)

labels = db.labels_
unique_labels = np.unique(labels)
current_label = labels.max() + 1
ratio = 1

for _ in range(20):

ratio *= 1.5
unique_labels = np.unique(labels)
current_label = unique_labels.max() + 1
keep_split = False

for label in unique_labels:

if label == -1:
continue

idx = labels == label
t = data[idx, 0:1]
if (t.max() - t.min()) < 800: # s
continue

xy = phase_loc[idx, :2]
w = phase_weight[idx]
ph = phase_type[idx]

labels_ = dbscan2(t, xy, w, ph, vel, eps=eps, min_samples=min_samples, ratio=ratio)
labels_ = np.where(labels_ == -1, -1, labels_ + current_label)
labels[idx] = labels_

current_label += labels_.max() + 1
keep_split = True

if not keep_split:
break

return labels


def association(picks, stations, config, event_idx0=0, method="BGMM", **kwargs):
data, locs, phase_type, phase_weight, pick_idx, pick_station_id, timestamp0 = convert_picks_csv(
picks, stations, config
Expand All @@ -84,17 +137,20 @@ def association(picks, stations, config, event_idx0=0, method="BGMM", **kwargs):
config["eikonal"] = initialize_eikonal(config["eikonal"])

if ("use_dbscan" in config) and config["use_dbscan"]:
# db = DBSCAN(eps=config["dbscan_eps"], min_samples=config["dbscan_min_samples"]).fit(data[:, 0:1])
db = DBSCAN(eps=config["dbscan_eps"], min_samples=config["dbscan_min_samples"]).fit(
np.hstack([data[:, 0:1], locs[:, :2] / np.average(vel["p"])]),
sample_weight=np.squeeze(phase_weight),
)
# db = DBSCAN(eps=config["dbscan_eps"], min_samples=config["dbscan_min_samples"]).fit(
# np.hstack([data[:, 0:1], locs[:, :2] / np.average(vel["p"]) / np.array([3.0, 1.0])]),
# np.hstack([data[:, 0:1], locs[:, :2] / np.average(vel["p"])]),
# sample_weight=np.squeeze(phase_weight),
# )

labels = db.labels_
# labels = db.labels_
labels = hierarchical_dbscan_clustering(
data,
locs,
phase_type,
phase_weight,
vel,
eps=config["dbscan_eps"],
min_samples=config["dbscan_min_samples"],
)
unique_labels = set(labels)
unique_labels = unique_labels.difference([-1])
else:
Expand Down Expand Up @@ -248,7 +304,7 @@ def associate(
## option 4
rstd = np.sqrt(x_std**2 + y_std**2)
# scaler = max(10.0, (rstd / 6.0) * (rstd / 60.0)) # 6.0 km/s, 60 km
scaler = max(1.0, (rstd / 6.0) * (rstd / 60.0)) # 6.0 km/s, 60 km
scaler = max(1.0, (rstd / 6.0) * (rstd / 30.0)) # 6.0 km/s, 30 km
if config["use_amplitude"]:
# covariance_prior_pre = [time_range * 10.0, amp_range * 10.0]
covariance_prior_pre = [scaler, scaler]
Expand Down Expand Up @@ -502,8 +558,8 @@ def init_centers(config, data_, locs_, time_range, max_num_event=1):

index = np.argsort(data_[:, 0])[:: max(len(data_) // num_t_init, 1)][:num_t_init]
t_init = data_[index, 0]
x_init = locs_[:, 0][index]
y_init = locs_[:, 1][index]
x_init = locs_[:, 0][index] # + np.random.uniform(low=-1, high=1, size=num_t_init) * np.std(locs_[:, 0])
y_init = locs_[:, 1][index] # + np.random.uniform(low=-1, high=1, size=num_t_init) * np.std(locs_[:, 1])
# x_init, y_init = np.mean(locs_[:, 0]), np.mean(locs_[:, 1])
# x_init = np.broadcast_to(x_init, (num_t_init)).reshape(-1)
# y_init = np.broadcast_to(y_init, (num_t_init)).reshape(-1)
Expand Down

0 comments on commit c156136

Please sign in to comment.