diff --git a/src/spikeinterface/metrics/quality/pca_metrics.py b/src/spikeinterface/metrics/quality/pca_metrics.py
index 79bb2e4196..e7e312aa4f 100644
--- a/src/spikeinterface/metrics/quality/pca_metrics.py
+++ b/src/spikeinterface/metrics/quality/pca_metrics.py
@@ -102,57 +102,27 @@ def _nn_one_unit(args):
     return unit_id, nn_hit_rate, nn_miss_rate
 
 
-def _nearest_neighbor_metric_function(sorting_analyzer, unit_ids, tmp_data, job_kwargs, **metric_params):
+def _nearest_neighbor_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
     nn_result = namedtuple("NearestNeighborResult", ["nn_hit_rate", "nn_miss_rate"])
 
     # Use pre-computed PCA data
     pca_data_per_unit = tmp_data["pca_data_per_unit"]
 
-    # Extract job parameters
-    n_jobs = job_kwargs.get("n_jobs", 1)
-    mp_context = job_kwargs.get("mp_context", None)
-
     nn_hit_rate_dict = {}
     nn_miss_rate_dict = {}
 
-    if n_jobs == 1:
-        # Sequential processing
-        units_loop = unit_ids
-
-        for unit_id in units_loop:
-            pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
-            labels = pca_data_per_unit[unit_id]["labels"]
-
-            try:
-                nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params)
-            except:
-                nn_hit_rate = np.nan
-                nn_miss_rate = np.nan
-
-            nn_hit_rate_dict[unit_id] = nn_hit_rate
-            nn_miss_rate_dict[unit_id] = nn_miss_rate
-    else:
-        if mp_context is not None and platform.system() == "Windows":
-            assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
-        elif mp_context == "fork" and platform.system() == "Darwin":
-            warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
-
-        # Prepare arguments - only pass pickle-able data
-        args_list = []
-        for unit_id in unit_ids:
-            pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
-            labels = pca_data_per_unit[unit_id]["labels"]
-            args_list.append((unit_id, pcs_flat, labels, metric_params))
+    for unit_id in unit_ids:
+        pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
+        labels = pca_data_per_unit[unit_id]["labels"]
 
-        with ProcessPoolExecutor(
-            max_workers=n_jobs,
-            mp_context=mp.get_context(mp_context) if mp_context else None,
-        ) as executor:
-            results = executor.map(_nn_one_unit, args_list)
+        try:
+            nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params)
+        except:
+            nn_hit_rate = np.nan
+            nn_miss_rate = np.nan
 
-            for unit_id, nn_hit_rate, nn_miss_rate in results:
-                nn_hit_rate_dict[unit_id] = nn_hit_rate
-                nn_miss_rate_dict[unit_id] = nn_miss_rate
+        nn_hit_rate_dict[unit_id] = nn_hit_rate
+        nn_miss_rate_dict[unit_id] = nn_miss_rate
 
     return nn_result(nn_hit_rate=nn_hit_rate_dict, nn_miss_rate=nn_miss_rate_dict)
 
@@ -168,7 +138,6 @@ class NearestNeighbor(BaseMetric):
     }
     depend_on = ["principal_components"]
     needs_tmp_data = True
-    needs_job_kwargs = True
 
 
 def _nn_advanced_one_unit(args):
@@ -392,10 +361,10 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):
     import scipy.stats
     import scipy.spatial.distance
 
-    pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
-    pcs_for_other_units = all_pcs[all_labels != this_unit_id, :]
+    pcs_for_this_unit = all_pcs[all_labels == this_unit_id]
+    pcs_for_other_units = all_pcs[all_labels != this_unit_id]
 
-    mean_value = np.expand_dims(np.mean(pcs_for_this_unit, 0), 0)
+    mean_value = np.mean(pcs_for_this_unit, 0, keepdims=True)
 
     try:
         VI = np.linalg.inv(np.cov(pcs_for_this_unit.T))
@@ -405,14 +374,14 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):
 
     mahalanobis_other = np.sort(scipy.spatial.distance.cdist(mean_value, pcs_for_other_units, "mahalanobis", VI=VI)[0])
 
-    mahalanobis_self = np.sort(scipy.spatial.distance.cdist(mean_value, pcs_for_this_unit, "mahalanobis", VI=VI)[0])
-
-    # number of spikes
-    n = np.min([pcs_for_this_unit.shape[0], pcs_for_other_units.shape[0]])
+    num_spikes_self = pcs_for_this_unit.shape[0]
+    num_spikes_other = pcs_for_other_units.shape[0]
+    n = min(num_spikes_self, num_spikes_other)
 
     if n >= 2:
         dof = pcs_for_this_unit.shape[1]  # number of features
-        l_ratio = np.sum(1 - scipy.stats.chi2.cdf(pow(mahalanobis_other, 2), dof)) / mahalanobis_self.shape[0]
+        l_ratio = np.sum(1 - scipy.stats.chi2.cdf(pow(mahalanobis_other, 2), dof)) / num_spikes_self
         isolation_distance = pow(mahalanobis_other[n - 1], 2)
         # if math.isnan(l_ratio):
         #     print("NaN detected", mahalanobis_other, VI)
@@ -449,18 +418,17 @@ def d_prime_metric(all_pcs, all_labels, this_unit_id) -> float:
 
     X = all_pcs
 
-    y = np.zeros((X.shape[0],), dtype="bool")
-    y[all_labels == this_unit_id] = True
+    y = all_labels == this_unit_id
 
     lda = LinearDiscriminantAnalysis(n_components=1)
 
     X_flda = lda.fit_transform(X, y)
 
-    flda_this_cluster = X_flda[np.where(y)[0]]
-    flda_other_cluster = X_flda[np.where(np.invert(y))[0]]
+    flda_this_cluster = X_flda[y]
+    flda_other_cluster = X_flda[~y]
 
     d_prime = (np.mean(flda_this_cluster) - np.mean(flda_other_cluster)) / np.sqrt(
-        0.5 * (np.std(flda_this_cluster) ** 2 + np.std(flda_other_cluster) ** 2)
+        (np.var(flda_this_cluster) + np.var(flda_other_cluster)) / 2
     )
 
     return d_prime
@@ -516,22 +484,23 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_n
         return 1.0, 0.0
 
     this_unit = all_labels == this_unit_id
-    this_unit_pcs = all_pcs[this_unit, :]
-    other_units_pcs = all_pcs[np.invert(this_unit), :]
+    this_unit_pcs = all_pcs[this_unit]
+    other_units_pcs = all_pcs[~this_unit]
     X = np.concatenate((this_unit_pcs, other_units_pcs), 0)
 
     num_obs_this_unit = np.sum(this_unit)
 
     if ratio < 1:
+        # Subsample spikes
         inds = np.arange(0, X.shape[0] - 1, 1 / ratio).astype("int")
         X = X[inds, :]
         num_obs_this_unit = int(num_obs_this_unit * ratio)
 
-    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="ball_tree").fit(X)
-    distances, indices = nbrs.kneighbors(X)
+    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(X)
+    indices = nbrs.kneighbors(return_distance=False)  # don't feed X so it won't return itself as neighbor
 
-    this_cluster_nearest = indices[:num_obs_this_unit, 1:].flatten()
-    other_cluster_nearest = indices[num_obs_this_unit:, 1:].flatten()
+    this_cluster_nearest = indices[:num_obs_this_unit].flatten()
+    other_cluster_nearest = indices[num_obs_this_unit:].flatten()
 
     hit_rate = np.mean(this_cluster_nearest < num_obs_this_unit)
     miss_rate = np.mean(other_cluster_nearest < num_obs_this_unit)
@@ -968,8 +937,8 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id):
     """
     import scipy.spatial.distance
 
-    pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
-    centroid_for_this_unit = np.expand_dims(np.mean(pcs_for_this_unit, 0), 0)
+    pcs_for_this_unit = all_pcs[all_labels == this_unit_id]
+    centroid_for_this_unit = np.mean(pcs_for_this_unit, 0, keepdims=True)
     distances_for_this_unit = scipy.spatial.distance.cdist(centroid_for_this_unit, pcs_for_this_unit)
 
     distance = np.inf
@@ -977,8 +946,8 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id):
     # if less than current minimum distance update
     for label in np.unique(all_labels):
         if label != this_unit_id:
-            pcs_for_other_cluster = all_pcs[all_labels == label, :]
-            centroid_for_other_cluster = np.expand_dims(np.mean(pcs_for_other_cluster, 0), 0)
+            pcs_for_other_cluster = all_pcs[all_labels == label]
+            centroid_for_other_cluster = np.mean(pcs_for_other_cluster, 0, keepdims=True)
             distances_for_other_cluster = scipy.spatial.distance.cdist(centroid_for_other_cluster, pcs_for_this_unit)
             mean_distance_for_other_cluster = np.mean(distances_for_other_cluster)
             if mean_distance_for_other_cluster < distance:
diff --git a/src/spikeinterface/metrics/quality/tests/test_pca_metrics.py b/src/spikeinterface/metrics/quality/tests/test_pca_metrics.py
index 8227ad5156..ffc77266fa 100644
--- a/src/spikeinterface/metrics/quality/tests/test_pca_metrics.py
+++ b/src/spikeinterface/metrics/quality/tests/test_pca_metrics.py
@@ -54,7 +54,7 @@ def test_compute_pc_metrics_multi_processing(small_sorting_analyzer, tmp_path):
 
 
 if __name__ == "__main__":
-    from spikeinterface.metrics.tests.conftest import make_small_analyzer
+    from spikeinterface.metrics.conftest import make_small_analyzer
 
     small_sorting_analyzer = make_small_analyzer()
     test_compute_pc_metrics_multi_processing(small_sorting_analyzer)
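Note (illustrative sketch, not part of the patch): the switch from nbrs.kneighbors(X) with explicit [:, 1:] slicing to nbrs.kneighbors(return_distance=False) relies on scikit-learn's documented behavior that, when no query array is passed, each fitted point is excluded from its own neighbor list. A minimal standalone example with made-up toy data:

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    # Toy data: two tight pairs of points (hypothetical values, for illustration only).
    X = np.array([[0.0], [0.1], [5.0], [5.1]])
    nbrs = NearestNeighbors(n_neighbors=1).fit(X)

    # Querying with the fitted data returns each point as its own nearest neighbor,
    # which is why the old code dropped the first column with [:, 1:].
    print(nbrs.kneighbors(X, return_distance=False).ravel())  # [0 1 2 3]

    # With no query array, each point is excluded from its own neighbors,
    # so no slicing is needed.
    print(nbrs.kneighbors(return_distance=False).ravel())  # [1 0 3 2]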