src/spikeinterface/metrics/quality/pca_metrics.py (97 changes: 33 additions & 64 deletions)

@@ -102,57 +102,27 @@ def _nn_one_unit(args):
     return unit_id, nn_hit_rate, nn_miss_rate
 
 
-def _nearest_neighbor_metric_function(sorting_analyzer, unit_ids, tmp_data, job_kwargs, **metric_params):
+def _nearest_neighbor_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
     nn_result = namedtuple("NearestNeighborResult", ["nn_hit_rate", "nn_miss_rate"])
 
     # Use pre-computed PCA data
     pca_data_per_unit = tmp_data["pca_data_per_unit"]
 
-    # Extract job parameters
-    n_jobs = job_kwargs.get("n_jobs", 1)
-    mp_context = job_kwargs.get("mp_context", None)
-
     nn_hit_rate_dict = {}
     nn_miss_rate_dict = {}
 
-    if n_jobs == 1:
-        # Sequential processing
-        units_loop = unit_ids
-
-        for unit_id in units_loop:
-            pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
-            labels = pca_data_per_unit[unit_id]["labels"]
-
-            try:
-                nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params)
-            except:
-                nn_hit_rate = np.nan
-                nn_miss_rate = np.nan
-
-            nn_hit_rate_dict[unit_id] = nn_hit_rate
-            nn_miss_rate_dict[unit_id] = nn_miss_rate
-    else:
-        if mp_context is not None and platform.system() == "Windows":
-            assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
-        elif mp_context == "fork" and platform.system() == "Darwin":
-            warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
-
-        # Prepare arguments - only pass pickle-able data
-        args_list = []
-        for unit_id in unit_ids:
-            pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
-            labels = pca_data_per_unit[unit_id]["labels"]
-            args_list.append((unit_id, pcs_flat, labels, metric_params))
+    for unit_id in unit_ids:
+        pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
+        labels = pca_data_per_unit[unit_id]["labels"]
 
-        with ProcessPoolExecutor(
-            max_workers=n_jobs,
-            mp_context=mp.get_context(mp_context) if mp_context else None,
-        ) as executor:
-            results = executor.map(_nn_one_unit, args_list)
+        try:
+            nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params)
+        except:
+            nn_hit_rate = np.nan
+            nn_miss_rate = np.nan
 
-            for unit_id, nn_hit_rate, nn_miss_rate in results:
-                nn_hit_rate_dict[unit_id] = nn_hit_rate
-                nn_miss_rate_dict[unit_id] = nn_miss_rate
+        nn_hit_rate_dict[unit_id] = nn_hit_rate
+        nn_miss_rate_dict[unit_id] = nn_miss_rate
 
     return nn_result(nn_hit_rate=nn_hit_rate_dict, nn_miss_rate=nn_miss_rate_dict)
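
The refactor above removes the job_kwargs / ProcessPoolExecutor branch, so the metric now always loops over units sequentially, pulling flattened PCs and spike labels for each unit out of tmp_data["pca_data_per_unit"]. As a rough, self-contained sketch of that loop shape (toy data; it assumes nearest_neighbors_metrics is importable from the module this diff touches, and the max_spikes / n_neighbors values are made up):

```python
import numpy as np

# Assumption: the helper is importable from the file changed in this diff.
from spikeinterface.metrics.quality.pca_metrics import nearest_neighbors_metrics

rng = np.random.default_rng(0)
unit_ids = [0, 1]

# Toy stand-in for tmp_data["pca_data_per_unit"]: flattened PCs plus spike labels per unit.
pca_data_per_unit = {
    unit_id: {
        "pcs_flat": rng.normal(size=(200, 10)),
        "labels": rng.integers(0, 2, size=200),
    }
    for unit_id in unit_ids
}

nn_hit_rate_dict, nn_miss_rate_dict = {}, {}
for unit_id in unit_ids:
    pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
    labels = pca_data_per_unit[unit_id]["labels"]
    try:
        hit, miss = nearest_neighbors_metrics(pcs_flat, labels, unit_id, max_spikes=1000, n_neighbors=5)
    except Exception:
        hit = miss = np.nan  # mirrors the NaN fallback used in the diff
    nn_hit_rate_dict[unit_id] = hit
    nn_miss_rate_dict[unit_id] = miss

print(nn_hit_rate_dict, nn_miss_rate_dict)
```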

@@ -168,7 +138,6 @@ class NearestNeighbor(BaseMetric):
     }
     depend_on = ["principal_components"]
     needs_tmp_data = True
-    needs_job_kwargs = True
 
 
 def _nn_advanced_one_unit(args):
@@ -392,10 +361,10 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):
     import scipy.stats
     import scipy.spatial.distance
 
-    pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
-    pcs_for_other_units = all_pcs[all_labels != this_unit_id, :]
+    pcs_for_this_unit = all_pcs[all_labels == this_unit_id]
+    pcs_for_other_units = all_pcs[all_labels != this_unit_id]
 
-    mean_value = np.expand_dims(np.mean(pcs_for_this_unit, 0), 0)
+    mean_value = np.mean(pcs_for_this_unit, 0, keepdims=True)
 
     try:
         VI = np.linalg.inv(np.cov(pcs_for_this_unit.T))
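
The keepdims=True form used above returns the same (1, n_features) row vector that the old expand_dims(np.mean(...), 0) produced, just in one call. A quick illustrative check on a toy array (not part of the PR):

```python
import numpy as np

pcs = np.random.default_rng(0).normal(size=(50, 3))  # (n_spikes, n_features)

old_style = np.expand_dims(np.mean(pcs, 0), 0)
new_style = np.mean(pcs, 0, keepdims=True)

assert old_style.shape == new_style.shape == (1, 3)
assert np.array_equal(old_style, new_style)
```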
@@ -405,14 +374,14 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):

     mahalanobis_other = np.sort(scipy.spatial.distance.cdist(mean_value, pcs_for_other_units, "mahalanobis", VI=VI)[0])
 
-    mahalanobis_self = np.sort(scipy.spatial.distance.cdist(mean_value, pcs_for_this_unit, "mahalanobis", VI=VI)[0])
-
     # number of spikes
-    n = np.min([pcs_for_this_unit.shape[0], pcs_for_other_units.shape[0]])
+    num_spikes_self = pcs_for_this_unit.shape[0]
+    num_spikes_other = pcs_for_other_units.shape[0]
+    n = min(num_spikes_self, num_spikes_other)
 
     if n >= 2:
         dof = pcs_for_this_unit.shape[1]  # number of features
-        l_ratio = np.sum(1 - scipy.stats.chi2.cdf(pow(mahalanobis_other, 2), dof)) / mahalanobis_self.shape[0]
+        l_ratio = np.sum(1 - scipy.stats.chi2.cdf(pow(mahalanobis_other, 2), dof)) / num_spikes_self
         isolation_distance = pow(mahalanobis_other[n - 1], 2)
         # if math.isnan(l_ratio):
         #     print("NaN detected", mahalanobis_other, VI)
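
For reference, with the variable names above (formulas read off the code, not stated in the PR): writing $D_j$ for the Mahalanobis distance from the unit's mean to the $j$-th closest spike of the other units, $d$ for the number of PC features (dof), and $n = \min(N_\mathrm{self}, N_\mathrm{other})$, this block computes

$$
\text{isolation distance} = D_{(n)}^{2},
\qquad
L\text{-ratio} = \frac{1}{N_\mathrm{self}} \sum_{j} \left( 1 - F_{\chi^2_d}\!\left(D_j^{2}\right) \right),
$$

where $F_{\chi^2_d}$ is the chi-squared CDF. Dividing by num_spikes_self instead of mahalanobis_self.shape[0] does not change the value, since both equal the number of spikes in the unit.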
@@ -449,18 +418,17 @@ def d_prime_metric(all_pcs, all_labels, this_unit_id) -> float:

     X = all_pcs
 
-    y = np.zeros((X.shape[0],), dtype="bool")
-    y[all_labels == this_unit_id] = True
+    y = all_labels == this_unit_id
 
     lda = LinearDiscriminantAnalysis(n_components=1)
 
     X_flda = lda.fit_transform(X, y)
 
-    flda_this_cluster = X_flda[np.where(y)[0]]
-    flda_other_cluster = X_flda[np.where(np.invert(y))[0]]
+    flda_this_cluster = X_flda[y]
+    flda_other_cluster = X_flda[~y]
 
     d_prime = (np.mean(flda_this_cluster) - np.mean(flda_other_cluster)) / np.sqrt(
-        0.5 * (np.std(flda_this_cluster) ** 2 + np.std(flda_other_cluster) ** 2)
+        (np.var(flda_this_cluster) + np.var(flda_other_cluster)) / 2
     )
 
     return d_prime
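
The new denominator is algebraically identical to the old one ($0.5\,(\sigma_1^2 + \sigma_2^2)$ versus $(\sigma_1^2 + \sigma_2^2)/2$, with np.std and np.var both using the default ddof=0), so the metric itself is unchanged: on the one-dimensional LDA projection it is the usual

$$
d' = \frac{\mu_{\text{this}} - \mu_{\text{other}}}{\sqrt{\tfrac{1}{2}\left(\sigma_{\text{this}}^{2} + \sigma_{\text{other}}^{2}\right)}},
$$

with means and variances taken over the projected spikes of this unit and of all other units, respectively.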
@@ -516,22 +484,23 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_neighbors):
         return 1.0, 0.0
 
     this_unit = all_labels == this_unit_id
-    this_unit_pcs = all_pcs[this_unit, :]
-    other_units_pcs = all_pcs[np.invert(this_unit), :]
+    this_unit_pcs = all_pcs[this_unit]
+    other_units_pcs = all_pcs[~this_unit]
     X = np.concatenate((this_unit_pcs, other_units_pcs), 0)
 
     num_obs_this_unit = np.sum(this_unit)
 
     if ratio < 1:
+        # Subsample spikes
         inds = np.arange(0, X.shape[0] - 1, 1 / ratio).astype("int")
         X = X[inds, :]
         num_obs_this_unit = int(num_obs_this_unit * ratio)
 
-    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="ball_tree").fit(X)
-    distances, indices = nbrs.kneighbors(X)
+    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(X)
+    indices = nbrs.kneighbors(return_distance=False)  # don't feed X so it won't return itself as neighbor
 
-    this_cluster_nearest = indices[:num_obs_this_unit, 1:].flatten()
-    other_cluster_nearest = indices[num_obs_this_unit:, 1:].flatten()
+    this_cluster_nearest = indices[:num_obs_this_unit].flatten()
+    other_cluster_nearest = indices[num_obs_this_unit:].flatten()
 
     hit_rate = np.mean(this_cluster_nearest < num_obs_this_unit)
     miss_rate = np.mean(other_cluster_nearest < num_obs_this_unit)
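
Dropping the query matrix from kneighbors() is what makes the old [:, 1:] slice unnecessary: when scikit-learn is asked for neighbors of the fitted points themselves (no X argument), each point is excluded from its own neighbor list. Removing algorithm="ball_tree" simply falls back to scikit-learn's default algorithm="auto". A minimal standalone check of the self-exclusion behavior (toy data, not from the PR):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.array([[0.0], [0.1], [10.0], [10.1]])
nbrs = NearestNeighbors(n_neighbors=1).fit(X)

# Querying with X: each point's nearest neighbor is itself.
print(nbrs.kneighbors(X, return_distance=False).ravel())  # -> [0 1 2 3]

# Querying without X: the point itself is skipped, as the new code relies on.
print(nbrs.kneighbors(return_distance=False).ravel())     # -> [1 0 3 2]
```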
@@ -968,17 +937,17 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id):
"""
     import scipy.spatial.distance
 
-    pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
-    centroid_for_this_unit = np.expand_dims(np.mean(pcs_for_this_unit, 0), 0)
+    pcs_for_this_unit = all_pcs[all_labels == this_unit_id]
+    centroid_for_this_unit = np.mean(pcs_for_this_unit, 0, keepdims=True)
     distances_for_this_unit = scipy.spatial.distance.cdist(centroid_for_this_unit, pcs_for_this_unit)
     distance = np.inf
 
     # find centroid of other cluster and measure distances to that rather than pairwise
     # if less than current minimum distance update
     for label in np.unique(all_labels):
         if label != this_unit_id:
-            pcs_for_other_cluster = all_pcs[all_labels == label, :]
-            centroid_for_other_cluster = np.expand_dims(np.mean(pcs_for_other_cluster, 0), 0)
+            pcs_for_other_cluster = all_pcs[all_labels == label]
+            centroid_for_other_cluster = np.mean(pcs_for_other_cluster, 0, keepdims=True)
             distances_for_other_cluster = scipy.spatial.distance.cdist(centroid_for_other_cluster, pcs_for_this_unit)
             mean_distance_for_other_cluster = np.mean(distances_for_other_cluster)
             if mean_distance_for_other_cluster < distance:
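
The tail of this function is collapsed in this view, but the quantities computed above are the usual ingredients of a simplified silhouette: with $a$ the mean distance from this unit's spikes to their own centroid and $b$ the smallest mean distance from those spikes to another unit's centroid (the running minimum tracked in the loop), the per-unit score is conventionally

$$
s = \frac{b - a}{\max(a, b)},
$$

which only needs centroid distances rather than all pairwise distances; whether the hidden lines follow exactly this form is not visible in this diff.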

@@ -54,7 +54,7 @@ def test_compute_pc_metrics_multi_processing(small_sorting_analyzer, tmp_path):


 if __name__ == "__main__":
-    from spikeinterface.metrics.tests.conftest import make_small_analyzer
+    from spikeinterface.metrics.conftest import make_small_analyzer
 
     small_sorting_analyzer = make_small_analyzer()
     test_compute_pc_metrics_multi_processing(small_sorting_analyzer)