From 561051882c02f90d87fc459937d4cc7ff344642d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jan 2026 16:58:35 +0100 Subject: [PATCH 1/9] Refactor aftifacts detection. --- src/spikeinterface/preprocessing/__init__.py | 1 + .../preprocessing/detect_artifacts.py | 202 ++++++++++++ .../preprocessing/preprocessing_classes.py | 4 +- .../preprocessing/silence_artifacts.py | 298 +++++------------- .../tests/test_detect_artifacts.py | 13 + .../tests/test_silence_artifacts.py | 16 +- 6 files changed, 309 insertions(+), 225 deletions(-) create mode 100644 src/spikeinterface/preprocessing/detect_artifacts.py create mode 100644 src/spikeinterface/preprocessing/tests/test_detect_artifacts.py diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index de25944bd2..d2d8674168 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -20,6 +20,7 @@ PreprocessingPipeline, ) +from .detect_artifacts import detect_artifact_periods, detect_period_artifacts_by_envelope # for snippets from .align_snippets import AlignSnippets from warnings import warn diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py new file mode 100644 index 0000000000..6cb22ac49f --- /dev/null +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import numpy as np + +from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording +from spikeinterface.preprocessing.rectify import RectifyRecording +from spikeinterface.preprocessing.common_reference import CommonReferenceRecording +from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording +from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype +import numpy as np + + +artifact_dtype = [ + ("start_index", "int64"), + ("stop_index", "int64"), + ("segment_index", "int64"), +] + +extended_artifact_dtype = artifact_dtype + [ + # TODO +] + + +_internal_dtype = [ + ("sample_index", "int64"), + ("segment_index", "int64"), + ("front", "bool") +] + + +def detect_artifact_periods( + recording, + method="envelope", + method_kwargs=None, + job_kwargs=None, +): + """ + + """ + + if method_kwargs is None: + method_kwargs = dict() + + if method == "envelope": + artifacts, envelope = detect_period_artifacts_by_envelope(recording, **method_kwargs, job_kwargs=job_kwargs) + elif method == "saturation": + raise NotImplementedError("Soon") + + else: + raise ValueError("") + + return artifacts + + + +## detect_period_artifacts_saturation Zone + + + + +## detect_period_artifacts_by_envelope Zone + +class DetectThresholdCrossing(PeakDetector): + + name = "threshold_crossings" + preferred_mp_context = None + + def __init__( + self, + recording, + detect_threshold=5, + noise_levels=None, + seed=None, + noise_levels_kwargs=dict(), + ): + PeakDetector.__init__(self, recording, return_output=True) + if noise_levels is None: + random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + self.abs_thresholds = noise_levels * detect_threshold + self._dtype = np.dtype(_internal_dtype) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return self._dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + z = np.median(traces / self.abs_thresholds, 1) + threshold_mask = np.diff((z > 1) != 0, axis=0) + indices = np.flatnonzero(threshold_mask) + threshold_crossings = np.zeros(indices.size, dtype=self._dtype) + threshold_crossings["sample_index"] = indices + threshold_crossings["segment_index"] = segment_index + threshold_crossings["front"][::2] = True + threshold_crossings["front"][1::2] = False + return (threshold_crossings,) + + +def detect_period_artifacts_by_envelope( + recording, + detect_threshold=5, + # min_duration_ms=50, + freq_max=20.0, + seed=None, + job_kwargs=None, + random_slices_kwargs=None, +): + """ + Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of + a global envelope of the channels. + + Parameters + ---------- + recording : RecordingExtractor + The recording extractor to detect putative artifacts + detect_threshold : float, default: 5 + The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` + freq_max : float, default: 20 + The maximum frequency for the low pass filter used + seed : int | None, default: None + Random seed for `get_noise_levels`. + If none, `get_noise_levels` uses `seed=0`. + **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function + + """ + + envelope = RectifyRecording(recording) + envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) + envelope = CommonReferenceRecording(envelope) + + from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ) + + # _, job_kwargs = split_job_kwargs(noise_levels_kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) + if random_slices_kwargs is None: + random_slices_kwargs = {} + else: + random_slices_kwargs = random_slices_kwargs.copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + + node0 = DetectThresholdCrossing( + recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, + ) + + threshold_crossings = run_node_pipeline( + envelope, + [node0], + job_kwargs, + job_name="detect threshold crossings", + ) + + order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) + threshold_crossings = threshold_crossings[order] + + artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) + + + return artifacts, envelope + + +# tools + +def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): + + num_seg = recording.get_num_segments() + + final_artifacts = [] + for seg_index in range(num_seg): + mask = artifacts["segment_index"] == seg_index + sub_thr = artifacts[mask] + if len(sub_thr) > 0: + if not sub_thr["front"][0]: + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr["sample_index"] = 0 + local_thr["front"] = True + sub_thr = np.hstack((local_thr, sub_thr)) + if sub_thr["front"][-1]: + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr["sample_index"] = recording.get_num_samples(seg_index) + local_thr["front"] = False + sub_thr = np.hstack((sub_thr, local_thr)) + + local_artifact = np.zeros(sub_thr.size/2, dtype=artifact_dtype) + local_artifact["start_index"] = sub_thr["sample_index"][::2] + local_artifact["stop_index"] = sub_thr["sample_index"][1::2] + local_artifact["segment_index"] = seg_index + final_artifacts.append(local_artifact) + + if len(final_artifacts) > 0: + final_artifacts = np.concatenate(final_artifacts) + else: + final_artifacts = np.zeros(0, dtype=artifact_dtype) + return final_artifacts \ No newline at end of file diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index fe9d95c506..47839db7a0 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -50,7 +50,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed -from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts +# from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts _all_preprocesser_dict = { # filter stuff @@ -90,7 +90,7 @@ DirectionalDerivativeRecording: directional_derivative, AstypeRecording: astype, UnsignedToSignedRecording: unsigned_to_signed, - SilencedArtifactsRecording: silence_artifacts, + # SilencedArtifactsRecording: silence_artifacts, } # we control import in the preprocessing init by setting an __all__ diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index b1ae00b64c..8006342847 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -4,221 +4,89 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording -from spikeinterface.preprocessing.rectify import RectifyRecording -from spikeinterface.preprocessing.common_reference import CommonReferenceRecording -from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording -from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype import numpy as np -class DetectThresholdCrossing(PeakDetector): - - name = "threshold_crossings" - preferred_mp_context = None - - def __init__( - self, - recording, - detect_threshold=5, - noise_levels=None, - seed=None, - noise_levels_kwargs=dict(), - ): - PeakDetector.__init__(self, recording, return_output=True) - if noise_levels is None: - random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() - random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) - self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("front", "bool")]) - - def get_trace_margin(self): - return 0 - - def get_dtype(self): - return self._dtype - - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = np.median(traces / self.abs_thresholds, 1) - threshold_mask = np.diff((z > 1) != 0, axis=0) - indices = np.flatnonzero(threshold_mask) - threshold_crossings = np.zeros(indices.size, dtype=self._dtype) - threshold_crossings["sample_index"] = indices - threshold_crossings["front"][::2] = True - threshold_crossings["front"][1::2] = False - return (threshold_crossings,) - - -def detect_period_artifacts_by_envelope( - recording, - detect_threshold=5, - min_duration_ms=50, - freq_max=20.0, - seed=None, - noise_levels=None, - **noise_levels_kwargs, -): - """ - Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of - a global envelope of the channels. - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to detect putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels`. - If none, `get_noise_levels` uses `seed=0`. - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - """ - - envelope = RectifyRecording(recording) - envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) - envelope = CommonReferenceRecording(envelope) - - from spikeinterface.core.node_pipeline import ( - run_node_pipeline, - ) - - _, job_kwargs = split_job_kwargs(noise_levels_kwargs) - job_kwargs = fix_job_kwargs(job_kwargs) - - node0 = DetectThresholdCrossing( - recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - threshold_crossings = run_node_pipeline( - recording, - [node0], - job_kwargs, - job_name="detect threshold crossings", - ) - - order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) - threshold_crossings = threshold_crossings[order] - - periods = [] - fs = recording.sampling_frequency - max_duration_samples = int(min_duration_ms * fs / 1000) - num_seg = recording.get_num_segments() - - for seg_index in range(num_seg): - sub_periods = [] - mask = threshold_crossings["segment_index"] == seg_index - sub_thr = threshold_crossings[mask] - if len(sub_thr) > 0: - local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) - if not sub_thr["front"][0]: - local_thr["sample_index"] = 0 - local_thr["front"] = True - sub_thr = np.hstack((local_thr, sub_thr)) - if sub_thr["front"][-1]: - local_thr["sample_index"] = recording.get_num_samples(seg_index) - local_thr["front"] = False - sub_thr = np.hstack((sub_thr, local_thr)) - - indices = np.flatnonzero(np.diff(sub_thr["front"])) - for i, j in zip(indices[:-1], indices[1:]): - if sub_thr["front"][i]: - start = sub_thr["sample_index"][i] - end = sub_thr["sample_index"][j] - if end - start > max_duration_samples: - sub_periods.append((start, end)) - - periods.append(sub_periods) - - return periods, envelope - - -class SilencedArtifactsRecording(SilencedPeriodsRecording): - """ - Silence user-defined periods from recording extractor traces. The code will construct - an enveloppe of the recording (as a low pass filtered version of the traces) and detect - threshold crossings to identify the periods to silence. The periods are then silenced either - on a per channel basis or across all channels by replacing the values by zeros or by - adding gaussian noise with the same variance as the one in the recordings - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to silence putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise", default: "zeros" - Determines what periods are replaced by. Can be one of the following: - - - "zeros": Artifacts are replaced by zeros. - - - "noise": The periods are filled with a gaussion noise that has the - same variance that the one in the recordings, on a per channel - basis - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - Returns - ------- - silenced_recording : SilencedArtifactsRecording - The recording extractor after silencing detected artifacts - """ - - _precomputable_kwarg_names = ["list_periods"] - - def __init__( - self, - recording, - detect_threshold=5, - verbose=False, - freq_max=20.0, - min_duration_ms=50, - mode="zeros", - noise_levels=None, - seed=None, - list_periods=None, - **noise_levels_kwargs, - ): - - if list_periods is None: - list_periods, _ = detect_period_artifacts_by_envelope( - recording, - detect_threshold=detect_threshold, - min_duration_ms=min_duration_ms, - freq_max=freq_max, - seed=seed, - noise_levels=noise_levels, - **noise_levels_kwargs, - ) - - if verbose: - for i, periods in enumerate(list_periods): - total_time = np.sum([end - start for start, end in periods]) - percentage = 100 * total_time / recording.get_num_samples(i) - print(f"{percentage}% of segment {i} has been flagged as artifactual") - - SilencedPeriodsRecording.__init__( - self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - -# function for API -silence_artifacts = define_function_handling_dict_from_class( - source_class=SilencedArtifactsRecording, name="silence_artifacts" -) +# class SilencedArtifactsRecording(SilencedPeriodsRecording): +# """ +# Silence user-defined periods from recording extractor traces. The code will construct +# an enveloppe of the recording (as a low pass filtered version of the traces) and detect +# threshold crossings to identify the periods to silence. The periods are then silenced either +# on a per channel basis or across all channels by replacing the values by zeros or by +# adding gaussian noise with the same variance as the one in the recordings + +# Parameters +# ---------- +# recording : RecordingExtractor +# The recording extractor to silence putative artifacts +# artifacts : np.array, None +# The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` +# freq_max : float, default: 20 +# The maximum frequency for the low pass filter used +# min_duration_ms : float, default: 50 +# The minimum duration for a threshold crossing to be considered as an artefact. +# noise_levels : array +# Noise levels if already computed +# seed : int | None, default: None +# Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. +# If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. +# mode : "zeros" | "noise", default: "zeros" +# Determines what periods are replaced by. Can be one of the following: + +# - "zeros": Artifacts are replaced by zeros. + +# - "noise": The periods are filled with a gaussion noise that has the +# same variance that the one in the recordings, on a per channel +# basis +# **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function + +# Returns +# ------- +# silenced_recording : SilencedArtifactsRecording +# The recording extractor after silencing detected artifacts +# """ + +# _precomputable_kwarg_names = ["artifacts"] + +# def __init__( +# self, +# recording, +# artifacts=None, +# detect_threshold=5, +# verbose=False, +# freq_max=20.0, +# min_duration_ms=50, +# mode="zeros", +# noise_levels=None, +# seed=None, +# list_periods=None, +# **noise_levels_kwargs, +# ): + +# if artifacts is None: +# from spikeinterface.preprocessing import detect_artifacts +# artifacts = detect_artifact_periods( +# recording, +# detect_threshold=detect_threshold, +# min_duration_ms=min_duration_ms, +# freq_max=freq_max, +# seed=seed, +# noise_levels=noise_levels, +# **noise_levels_kwargs, +# ) + +# if verbose: +# for i, periods in enumerate(artifacts): +# total_time = np.sum([end - start for start, end in periods]) +# percentage = 100 * total_time / recording.get_num_samples(i) +# print(f"{percentage}% of segment {i} has been flagged as artifactual") + +# SilencedPeriodsRecording.__init__( +# self, recording, artifacts, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs +# ) + + +# # function for API +# silence_artifacts = define_function_handling_dict_from_class( +# source_class=SilencedArtifactsRecording, name="silence_artifacts" +# ) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py new file mode 100644 index 0000000000..52e8d927f9 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -0,0 +1,13 @@ +from spikeinterface.core import generate_recording +from spikeinterface.preprocessing import detect_artifact_periods + + +def test_detect_artifact_periods(): + # one segment only + rec = generate_recording(durations=[10.0, 10]) + artifacts = detect_artifact_periods(rec, method="envelope", + method_kwargs=dict(detect_threshold=5, freq_max=5.0), + ) + +if __name__ == "__main__": + test_detect_artifact_periods() diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py index 2baa4bf1b3..ad70540f40 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py @@ -2,15 +2,15 @@ import numpy as np -from spikeinterface.core import generate_recording -from spikeinterface.preprocessing import silence_artifacts +# from spikeinterface.core import generate_recording +# from spikeinterface.preprocessing import silence_artifacts -def test_silence_artifacts(): - # one segment only - rec = generate_recording(durations=[10.0, 10]) - new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) +# def test_silence_artifacts(): +# # one segment only +# rec = generate_recording(durations=[10.0, 10]) +# new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) -if __name__ == "__main__": - test_silence_artifacts() +# if __name__ == "__main__": +# test_silence_artifacts() From a71dade8b1e3a852fc8ca35df97960125dd7080c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 21 Jan 2026 14:11:49 +0100 Subject: [PATCH 2/9] Progagate new periods dtype to SilencedPeriodsRecording with backward compatibility --- src/spikeinterface/core/node_pipeline.py | 4 +- .../preprocessing/detect_artifacts.py | 79 ++++----- .../preprocessing/silence_artifacts.py | 96 ----------- .../preprocessing/silence_periods.py | 151 ++++++++++++++---- .../tests/test_detect_artifacts.py | 3 +- .../tests/test_silence_artifacts.py | 16 -- ...est_silence.py => test_silence_periods.py} | 20 ++- 7 files changed, 166 insertions(+), 203 deletions(-) delete mode 100644 src/spikeinterface/preprocessing/silence_artifacts.py delete mode 100644 src/spikeinterface/preprocessing/tests/test_silence_artifacts.py rename src/spikeinterface/preprocessing/tests/{test_silence.py => test_silence_periods.py} (76%) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index f1efe7a035..10e4885606 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,7 +489,9 @@ def check_graph(nodes): Check that node list is orderd in a good (parents are before children) """ - node0 = nodes[0] + # Do not remove this, this is to remenber that in previous version the first node needed to be + # a detectot but not anymore + # node0 = nodes[0] # if not isinstance(node0, PeakSource): # raise ValueError( # "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 2ff3f8a78f..2a1ea069f0 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -14,22 +14,17 @@ import numpy as np -# artifact_dtype = [ -# ("start_index", "int64"), -# ("stop_index", "int64"), -# ("segment_index", "int64"), -# ] artifact_dtype = base_period_dtype +# this will be extend with channel boundaries if needed # extended_artifact_dtype = artifact_dtype + [ # # TODO # ] - def detect_artifact_periods( recording, method="envelope", @@ -37,7 +32,11 @@ def detect_artifact_periods( job_kwargs=None, ): """ - + Detect artifacts with several possible methods: + * 'saturation' using detect_artifact_periods_by_envelope() + * 'envelope' using detect_saturation_periods() + + See sub methods for more information on parameters. """ if method_kwargs is None: @@ -48,7 +47,7 @@ def detect_artifact_periods( elif method == "saturation": artifact_periods = detect_saturation_periods(recording, **method_kwargs, job_kwargs=job_kwargs) else: - raise ValueError("") + raise ValueError(f"detect_artifact_periods() method='{method}' is not valid") return artifact_periods @@ -56,13 +55,10 @@ def detect_artifact_periods( ## detect_period_artifacts_saturation Zone - def _collapse_events(events): """ If events are detected at a chunk edge, they will be split in two. - This detects such cases and collapses them in a single record instead - :param events: - :return: + This detects such cases and collapses them in a single record instead. """ order = np.lexsort((events["start_sample_index"], events["segment_index"])) events = events[order] @@ -87,21 +83,24 @@ class _DetectSaturation(PipelineNode): def __init__( self, recording, - saturation_threshold_uV, # 1200 uV - voltage_per_sec_threshold, # 1e-8 V.s-1 + saturation_threshold_uV, + voltage_per_sec_threshold, proportion, - mute_window_samples, ): PipelineNode.__init__(self, recording, return_output=True) - self.gains = recording.get_channel_gains() - self.offsets = recording.get_channel_offsets() + gains = recording.get_channel_gains() + offsets = recording.get_channel_offsets() + num_chans = recording.get_num_channels() self.voltage_per_sec_threshold = voltage_per_sec_threshold - self.saturation_threshold_uV = saturation_threshold_uV + thresh = np.full((num_chans, ), saturation_threshold_uV) + # 0.98 is empirically determined as the true saturating point is + # slightly lower than the documented saturation point of the probe + self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 + self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion - self.mute_window_samples = mute_window_samples self._dtype = np.dtype(artifact_dtype) self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() @@ -114,16 +113,7 @@ def get_dtype(self): def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # @olivier @joe we can avoid this by making - traces = traces * self.gains[np.newaxis, :] + self.offsets[np.newaxis, :] - - - # first computes the saturated samples - max_voltage = np.atleast_1d(self.saturation_threshold_uV)[:, np.newaxis] - - # 0.98 is empirically determined as the true saturating point is - # slightly lower than the documented saturation point of the probe - saturation = np.mean(np.abs(traces) > max_voltage * 0.98, axis=1) + saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) if self.voltage_per_sec_threshold is not None: fs = self.sampling_frequency @@ -138,7 +128,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): else: saturation = saturation > self.proportion - intervals = np.where(np.diff(saturation, prepend=False, append=False))[0] + intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) n_events = len(intervals) // 2 # Number of saturation periods events = np.zeros(n_events, dtype=artifact_dtype) @@ -146,7 +136,6 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): events[i]["start_sample_index"] = start + start_frame events[i]["end_sample_index"] = stop + start_frame events[i]["segment_index"] = segment_index - # events[i]["method_id"] = "saturation_detection" return (events, ) @@ -154,9 +143,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def detect_saturation_periods( recording, saturation_threshold_uV, # 1200 uV - voltage_per_sec_threshold, # 1e-8 V.s-1 + voltage_per_sec_threshold=None, # 1e-8 V.s-1 proportion=0.5, - mute_window_samples=7, job_kwargs=None, ): """ @@ -174,7 +162,7 @@ def detect_saturation_periods( The recording on which to detect the saturation events. saturation_threshold_uV : float The voltage saturation threshold in volts. This will depend on the recording - probe and amplifier gain settings. For NP1 the value of 1200 * 1e-6 is recommended (IBL). + probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). Note that NP2 probes are more difficult to saturate than NP1. voltage_per_sec_threshold : None | float The first-derivative threshold in volts per second. Periods of the data over which the change @@ -207,10 +195,9 @@ def detect_saturation_periods( saturation_threshold_uV=saturation_threshold_uV, voltage_per_sec_threshold=voltage_per_sec_threshold, proportion=proportion, - mute_window_samples=mute_window_samples, ) - saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation events") + saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts") return _collapse_events(saturation_periods) @@ -218,13 +205,7 @@ def detect_saturation_periods( ## detect_artifact_periods_by_envelope Zone -# _internal_dtype = [ -# ("sample_index", "int64"), -# ("segment_index", "int64"), -# ("front", "bool") -# ] - -class DetectThresholdCrossing(PeakDetector): +class _DetectThresholdCrossing(PeakDetector): name = "threshold_crossings" preferred_mp_context = None @@ -243,6 +224,7 @@ def __init__( random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) self.abs_thresholds = noise_levels * detect_threshold + # internal dtype self._dtype = np.dtype([ ("sample_index", "int64"), ("segment_index", "int64"), @@ -278,7 +260,7 @@ def detect_artifact_periods_by_envelope( random_slices_kwargs=None, ): """ - Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of + Function to detect putative artifact periods as threshold crossings of a global envelope of the channels. Parameters @@ -300,8 +282,6 @@ def detect_artifact_periods_by_envelope( envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) envelope = CommonReferenceRecording(envelope) - - # _, job_kwargs = split_job_kwargs(noise_levels_kwargs) job_kwargs = fix_job_kwargs(job_kwargs) if random_slices_kwargs is None: random_slices_kwargs = {} @@ -310,7 +290,7 @@ def detect_artifact_periods_by_envelope( random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) - node0 = DetectThresholdCrossing( + node0 = _DetectThresholdCrossing( recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, ) @@ -318,7 +298,7 @@ def detect_artifact_periods_by_envelope( envelope, [node0], job_kwargs, - job_name="detect threshold crossings", + job_name="detect artifact on envelope", ) order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) @@ -326,12 +306,9 @@ def detect_artifact_periods_by_envelope( artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) - return artifacts, envelope -# tools - def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): num_seg = recording.get_num_segments() diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py deleted file mode 100644 index 241fe0f915..0000000000 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from spikeinterface.core.base import base_peak_dtype -from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.node_pipeline import PeakDetector -from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording -import numpy as np - - -# class SilencedArtifactsRecording(SilencedPeriodsRecording): -# """ -# Silence user-defined periods from recording extractor traces. The code will construct -# an enveloppe of the recording (as a low pass filtered version of the traces) and detect -# threshold crossings to identify the periods to silence. The periods are then silenced either -# on a per channel basis or across all channels by replacing the values by zeros or by -# adding gaussian noise with the same variance as the one in the recordings - -# Parameters -# ---------- -# recording : RecordingExtractor -# The recording extractor to silence putative artifacts -# artifacts : np.array, None -# The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` -# freq_max : float, default: 20 -# The maximum frequency for the low pass filter used -# min_duration_ms : float, default: 50 -# The minimum duration for a threshold crossing to be considered as an artefact. -# noise_levels : array -# Noise levels if already computed -# seed : int | None, default: None -# Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. -# If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. -# mode : "zeros" | "noise", default: "zeros" -# Determines what periods are replaced by. Can be one of the following: - -# - "zeros": Artifacts are replaced by zeros. - -# - "noise": The periods are filled with a gaussion noise that has the -# same variance that the one in the recordings, on a per channel -# basis -# **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - -# Returns -# ------- -# silenced_recording : SilencedArtifactsRecording -# The recording extractor after silencing detected artifacts -# """ - -# _precomputable_kwarg_names = ["artifacts"] - -# def __init__( -# self, -# recording, -# artifacts=None, -# detect_threshold=5, -# verbose=False, -# freq_max=20.0, -# min_duration_ms=50, -# mode="zeros", -# noise_levels=None, -# seed=None, -# list_periods=None, -# **noise_levels_kwargs, -# ): - -# if artifacts is None: -# from spikeinterface.preprocessing import detect_artifacts -# artifacts = detect_artifact_periods( -# recording, -# detect_threshold=detect_threshold, -# min_duration_ms=min_duration_ms, -# freq_max=freq_max, -# seed=seed, -# noise_levels=noise_levels, -# **noise_levels_kwargs, -# ) - -# if verbose: -# for i, periods in enumerate(artifacts): -# total_time = np.sum([end - start for start, end in periods]) -# percentage = 100 * total_time / recording.get_num_samples(i) -# print(f"{percentage}% of segment {i} has been flagged as artifactual") - -# SilencedPeriodsRecording.__init__( -# self, recording, artifacts, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs -# ) - - -# # function for API -# silence_artifacts = define_function_handling_dict_from_class( -# source_class=SilencedArtifactsRecording, name="silence_artifacts" -# ) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c9b6e2abe4..040e1275be 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -8,6 +8,8 @@ from spikeinterface.core import get_noise_levels from spikeinterface.core.generate import NoiseGeneratorRecording from spikeinterface.core.job_tools import split_job_kwargs +from spikeinterface.core.base import base_period_dtype + class SilencedPeriodsRecording(BasePreprocessor): @@ -48,7 +50,9 @@ class SilencedPeriodsRecording(BasePreprocessor): def __init__( self, recording, - list_periods, + periods=None, + # this is keep for backward compatibility + list_periods=None, mode="zeros", noise_levels=None, seed=None, @@ -56,25 +60,27 @@ def __init__( ): available_modes = ("zeros", "noise") num_seg = recording.get_num_segments() - if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: - # when unique segment accept list instead of list of list/arrays - list_periods = [list_periods] - # some checks - assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" - assert isinstance(list_periods, list), "'list_periods' must be a list (one per segment)" - assert len(list_periods) == num_seg, "'list_periods' must have the same length as the number of segments" - assert all( - isinstance(list_periods[i], (list, np.ndarray)) for i in range(num_seg) - ), "Each element of 'list_periods' must be array-like" + # handle backward compatibility with previous version + if list_periods is not None: + assert periods is None + periods = _all_period_list_to_periods_vec(list_periods, num_seg) + else: + assert list_periods is None + if not isinstance(periods, np.ndarray): + raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") + + if periods.dtype.fields is None: + # this is the old format : list[list[int]] + periods = _all_period_list_to_periods_vec(periods, num_seg) - for periods in list_periods: - if len(periods) > 0: - assert np.all(np.diff(np.array(periods), axis=1) > 0), "t_stops should be larger than t_starts" - assert np.all( - periods[i][1] < periods[i + 1][0] for i in np.arange(len(periods) - 1) - ), "Intervals should not overlap" + # force order + order = np.lexsort((periods["start_sample_index"], periods["segment_index"])) + periods = periods[order] + _check_periods(periods, num_seg) + + # some checks + assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" if mode in ["noise"]: if noise_levels is None: @@ -98,18 +104,57 @@ def __init__( noise_generator = None BasePreprocessor.__init__(self, recording) + + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) for seg_index, parent_segment in enumerate(recording._recording_segments): - periods = list_periods[seg_index] - periods = np.asarray(periods, dtype="int64") - periods = np.sort(periods, axis=0) - rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) + i0 = seg_limits[seg_index] + i1 = seg_limits[seg_index+1] + periods_in_seg = periods[i0:i1] + rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods_in_seg, mode, noise_generator, seg_index) self.add_recording_segment(rec_segment) self._kwargs = dict( - recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels + recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels ) +def _all_period_list_to_periods_vec(list_periods, num_seg): + if num_seg == 1: + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: + # when unique segment accept list instead of list of list/arrays + list_periods = [list_periods] + size = sum(len(p) for p in list_periods) + periods = np.zeros(size, dtype=base_period_dtype) + start = 0 + for i in range(num_seg): + periods_in_seg = list_periods[i] + stop = start + periods_in_seg.shape[0] + periods[start:stop]["segment_index"] = i + periods[start:stop]["start_sample_index"] = periods_in_seg[:, 0] + periods[start:stop]["end_sample_index"] = periods_in_seg[:, 1] + start = stop + return periods + +def _check_periods(periods, num_seg): + # check dtype + if any(col not in np.dtype(base_period_dtype).fields for col in periods.dtype.fields): + raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") + + # check non overlap and non negative + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) + for i in range(num_seg): + i0 = seg_limits[i] + i1 = seg_limits[i+1] + periods_in_seg = periods[i0:i1] + if periods_in_seg.size == 0: + continue + if len(periods) > 0: + if np.any(periods_in_seg["start_sample_index"] > periods_in_seg["end_sample_index"]): + raise ValueError("end_sample_index should be larger than start_sample_index") + if np.any(periods_in_seg["start_sample_index"][1:] < periods_in_seg["end_sample_index"][:-1]): + raise ValueError("Intervals should not overlap") + + class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -120,18 +165,20 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) - traces = traces.copy() + if self.periods.size > 0: new_interval = np.array([start_frame, end_frame]) - lower_index = np.searchsorted(self.periods[:, 1], new_interval[0]) - upper_index = np.searchsorted(self.periods[:, 0], new_interval[1]) + + lower_index = np.searchsorted(self.periods["end_sample_index"], new_interval[0]) + upper_index = np.searchsorted(self.periods["start_sample_index"], new_interval[1]) if upper_index > lower_index: - periods_in_interval = self.periods[lower_index:upper_index] + traces = traces.copy() + periods_in_interval = self.periods[lower_index:upper_index] for period in periods_in_interval: - onset = max(0, period[0] - start_frame) - offset = min(period[1] - start_frame, end_frame) + onset = max(0, period["start_sample_index"] - start_frame) + offset = min(period["end_sample_index"] - start_frame, end_frame) if self.mode == "zeros": traces[onset:offset, :] = 0 @@ -143,8 +190,52 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces - # function for API silence_periods = define_function_handling_dict_from_class( source_class=SilencedPeriodsRecording, name="silence_periods" ) + + + +class DetectArtifactAndSilentPeriodsRecording(SilencedPeriodsRecording): + """ + Class doing artifact detection and lient at the same time. + + See SilencedPeriodsRecording and detect_artifact_periods for details. + """ + + _precomputable_kwarg_names = ["artifacts"] + + def __init__( + self, + recording, + detect_artifact_method="envelope", + detect_artifact_kwargs=dict(), + periods=None, + mode="zeros", + noise_levels=None, + seed=None, + **noise_levels_kwargs, + ): + + if artifacts is None: + from spikeinterface.preprocessing import detect_artifact_periods + artifacts = detect_artifact_periods( + recording, + method=detect_artifact_method, + method_kwargs=detect_artifact_kwargs, + job_kwargs=None, + ) + + SilencedPeriodsRecording.__init__( + self, recording, periods=artifacts, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs + ) + # note self._kwargs["periods"] is done by SilencedPeriodsRecording and so the computaion is done once + + + +# function for API +detect_artifacts_and_silent_periods = define_function_handling_dict_from_class( + source_class=DetectArtifactAndSilentPeriodsRecording, name="silence_artifacts" +) + diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index cfb32254f1..50003487d0 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -26,7 +26,8 @@ def test_detect_saturation_periods(): # cross a chunk boundary. Do not change without changing the below. sat_value = 1200 - data = np.random.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 * 1e-6 + rng = np.random.default_rng() + data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 * 1e-6 # Design the Butterworth filter sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos") diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py deleted file mode 100644 index ad70540f40..0000000000 --- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest - -import numpy as np - -# from spikeinterface.core import generate_recording -# from spikeinterface.preprocessing import silence_artifacts - - -# def test_silence_artifacts(): -# # one segment only -# rec = generate_recording(durations=[10.0, 10]) -# new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) - - -# if __name__ == "__main__": -# test_silence_artifacts() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py similarity index 76% rename from src/spikeinterface/preprocessing/tests/test_silence.py rename to src/spikeinterface/preprocessing/tests/test_silence_periods.py index e7aee1a84d..ffba9059a0 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -1,11 +1,12 @@ import pytest from spikeinterface.core import generate_recording - +from spikeinterface.core import get_noise_levels +from spikeinterface.core.base import base_period_dtype from spikeinterface.preprocessing import silence_periods -from spikeinterface.core import get_noise_levels + import numpy as np @@ -18,17 +19,20 @@ def test_silence(create_cache_folder): rec = generate_recording() - rec0 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="zeros", seed=2308) - rec0.save(verbose=False) + periods = np.array([(0, 0, 1000), (0, 5000, 6000)], dtype=base_period_dtype) + rec0 = silence_periods(rec, periods=periods, mode="zeros", seed=2308) + rec0.save(format="memory", verbose=False) traces_in0 = rec0.get_traces(segment_index=0, start_frame=0, end_frame=1000) - traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000) - traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000) assert np.all(traces_in0 == 0) + traces_half0 = rec0.get_traces(segment_index=0, start_frame=900, end_frame=1100) + assert np.all(traces_half0[:100] == 0) + traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000) assert np.all(traces_in1 == 0) + traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000) assert not np.all(traces_out0 == 0) - rec1 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="noise", seed=2308) - rec1 = rec1.save(folder=cache_folder / "rec_w_noise", verbose=False, overwrite=True) + rec1 = silence_periods(rec, periods=periods, mode="noise", seed=2308) + rec1 = rec1.save(format="memory", verbose=False, overwrite=True) noise_levels = get_noise_levels(rec, return_in_uV=False) traces_in0 = rec1.get_traces(segment_index=0, start_frame=0, end_frame=1000) traces_in1 = rec1.get_traces(segment_index=0, start_frame=5000, end_frame=6000) From 0d709e69b09b4b9e02315e5e55ee79623d8df14d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 21 Jan 2026 14:16:15 +0100 Subject: [PATCH 3/9] oups --- .../preprocessing/tests/test_detect_artifacts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 50003487d0..b5d9a18a9b 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -1,7 +1,7 @@ from spikeinterface.core import generate_recording, NumpyRecording from spikeinterface.preprocessing import detect_artifact_periods, detect_saturation_periods import numpy as np -import scipy.signal + def test_detect_artifact_periods(): # one segment only @@ -13,6 +13,9 @@ def test_detect_artifact_periods(): def test_detect_saturation_periods(): + + import scipy.signal + """ TODO: NOTE: we have one sample before the saturation starts as we take the forward derivative for the velocity we have an extra sample after due to taking the diff on the final saturation mask From 7ab75c033e67710c85147d13112240aab0001f98 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 21 Jan 2026 19:21:42 +0100 Subject: [PATCH 4/9] oups --- src/spikeinterface/preprocessing/detect_artifacts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 2a1ea069f0..3e42facdc5 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -291,7 +291,7 @@ def detect_artifact_periods_by_envelope( noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) node0 = _DetectThresholdCrossing( - recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, + envelope, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, ) threshold_crossings = run_node_pipeline( From 7c7446ea076dd227f6888136ebb9f5ab3bd5af8f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 29 Jan 2026 15:40:42 +0000 Subject: [PATCH 5/9] Applying extras from other PR, adding voltage_per_sec_threshold. --- .../preprocessing/detect_artifacts.py | 34 +++++++++----- .../tests/test_detect_artifacts.py | 47 ++++++++++--------- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 3e42facdc5..8720df03d4 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -75,7 +75,12 @@ def _collapse_events(events): class _DetectSaturation(PipelineNode): + """ + A recording node for parallelising saturation detection. + Run with `run_node_pipeline`, this computes saturation events + for a given chunk. See `detect_saturation()` for details. + """ name = "detect_saturation" preferred_mp_context = None _compute_has_extended_signature = True @@ -98,7 +103,8 @@ def __init__( # 0.98 is empirically determined as the true saturating point is # slightly lower than the documented saturation point of the probe self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 - + self.voltage_per_sec_threshold = (voltage_per_sec_threshold - offsets) / gains + self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion self._dtype = np.dtype(artifact_dtype) @@ -112,7 +118,10 @@ def get_dtype(self): return self._dtype def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - + """ + Compute saturation events for a given chunk of data. + See `detect_saturation()` for details. + """ saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) if self.voltage_per_sec_threshold is not None: @@ -144,7 +153,7 @@ def detect_saturation_periods( recording, saturation_threshold_uV, # 1200 uV voltage_per_sec_threshold=None, # 1e-8 V.s-1 - proportion=0.5, + proportion=0.2, job_kwargs=None, ): """ @@ -170,24 +179,25 @@ def detect_saturation_periods( skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be empirically determined (IBL use 1e-8 V.s-1) for NP1 probes. - proportion : - mute_window_samples : - job_kwargs : + proportion : float + 0 < proportion <1 of channels above threshold to consider the sample as saturated + mute_window_samples : int + TODO: should we scale this based on the fs? + job_kwargs: dict + The classical job_kwargs most useful for NP1 can use ratio as a intuition for the value but dont do it in code Returns - ------- - +------- + collapsed_events : np.recarray + A numpy recarray holding information on each saturation event. Has the fields: + "start_sample_index", "stop_sample_index", "segment_index", "method_id" """ if job_kwargs: job_kwargs = {} - # if saturation_threshold_uV < 0.1: - # raise ValueError(f"The `saturation_threshold_uV` should be in microvolts. " - # f"Your value: {saturation_threshold_uV} is almost certainly in volts.") - job_kwargs = fix_job_kwargs(job_kwargs) node0 = _DetectSaturation( diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index b5d9a18a9b..9c923a5a7b 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -13,24 +13,32 @@ def test_detect_artifact_periods(): def test_detect_saturation_periods(): - - import scipy.signal - """ - TODO: NOTE: we have one sample before the saturation starts as we take the forward derivative for the velocity - we have an extra sample after due to taking the diff on the final saturation mask - this means we always take one sample before and one sample after the diff period, which is fine. + This tests the saturation detection method. First a mock recording is created with + saturation events. Events may be single-sample or a multi-sample period. We create a multi-segment + recording with the stop-sample of each event offset by one, so the segments are distinguishable. + + Saturation detection is performed on chunked data (we set to 30k sample chunks) and so injected + events are hard-coded in order to cross a chunk boundary to test this case. + + The saturation detection function tests both a) saturation threshold exceeded + and b) first derivative (velocity) threshold exceeded. Because the forward + derivative is taken, the sample before the first saturated sample is also flagged. + Also, because of the way the mask is computed in the function, the sample after the + last saturated sample is flagged. """ - # num_chans = 384 + import scipy.signal + num_chans = 32 sample_frequency = 30000 chunk_size = 30000 # This value is critical to ensure hard-coded start / stops below job_kwargs = {"chunk_size": chunk_size} - # cross a chunk boundary. Do not change without changing the below. + # Generate some data in uV sat_value = 1200 + voltage_per_sec_threshold = 12 / sample_frequency rng = np.random.default_rng() - data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 * 1e-6 + data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 # Design the Butterworth filter sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos") @@ -57,7 +65,7 @@ def test_detect_saturation_periods(): # this center the int16 around 0 and saturate on positive max_ = np.max(np.r_[data_seg_1.flatten(), data_seg_2.flatten()]) gain = max_ / 2**15 - offset = 0 + offset = 50 seg_1_int16 = np.clip( np.rint((data_seg_1 - offset) / gain), @@ -68,17 +76,12 @@ def test_detect_saturation_periods(): -32768, 32767 ).astype(np.int16) - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # ax.plot(seg_1_int16[:, 0]) - # plt.show() - recording = NumpyRecording([seg_1_int16, seg_2_int16], sample_frequency) recording.set_channel_gains(gain) recording.set_channel_offsets([offset] * num_chans) periods = detect_saturation_periods( - recording, saturation_threshold_uV=sat_value * 0.98, voltage_per_sec_threshold=1e-8, job_kwargs=job_kwargs + recording, saturation_threshold_uV=sat_value * 0.98, voltage_per_sec_threshold=voltage_per_sec_threshold, job_kwargs=job_kwargs ) seg_1_periods = periods[np.where(periods["segment_index"] == 0)] @@ -102,23 +105,23 @@ def test_detect_saturation_periods(): periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * (1 / 0.98), - voltage_per_sec_threshold=1e-8, + voltage_per_sec_threshold=voltage_per_sec_threshold, job_kwargs=job_kwargs, ) assert periods["start_sample_index"][0] == 1000 assert periods["end_sample_index"][0] == 1001 - periods = detect_artifact_periods( + periods_entry_function = detect_artifact_periods( recording, method="saturation", method_kwargs=dict( saturation_threshold_uV=sat_value * (1 / 0.98), - voltage_per_sec_threshold=1e-8, + voltage_per_sec_threshold=voltage_per_sec_threshold, ), - job_kwargs=job_kwargs, - ) - + job_kwargs=job_kwargs, + ) + assert np.array_equal(periods, periods_entry_function) if __name__ == "__main__": test_detect_artifact_periods() From 86e6924cd37c44327895bd16c8a9232937cc918b Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 29 Jan 2026 15:41:16 +0000 Subject: [PATCH 6/9] Rename uV_per_sec_threshold. --- .../preprocessing/detect_artifacts.py | 16 ++++++++-------- .../preprocessing/tests/test_detect_artifacts.py | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 8720df03d4..6acad37901 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -89,7 +89,7 @@ def __init__( self, recording, saturation_threshold_uV, - voltage_per_sec_threshold, + uV_per_sec_threshold, proportion, ): PipelineNode.__init__(self, recording, return_output=True) @@ -98,12 +98,12 @@ def __init__( offsets = recording.get_channel_offsets() num_chans = recording.get_num_channels() - self.voltage_per_sec_threshold = voltage_per_sec_threshold + self.uV_per_sec_threshold = uV_per_sec_threshold thresh = np.full((num_chans, ), saturation_threshold_uV) # 0.98 is empirically determined as the true saturating point is # slightly lower than the documented saturation point of the probe self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 - self.voltage_per_sec_threshold = (voltage_per_sec_threshold - offsets) / gains + self.uV_per_sec_threshold = (uV_per_sec_threshold - offsets) / gains self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion @@ -124,10 +124,10 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): """ saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) - if self.voltage_per_sec_threshold is not None: + if self.uV_per_sec_threshold is not None: fs = self.sampling_frequency # then compute the derivative of the voltage saturation - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.voltage_per_sec_threshold, axis=1) + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.uV_per_sec_threshold, axis=1) # Note this means the velocity is not checked for the last sample in the # check because we are taking the forward derivative n_diff_saturated = np.r_[n_diff_saturated, 0] @@ -152,7 +152,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def detect_saturation_periods( recording, saturation_threshold_uV, # 1200 uV - voltage_per_sec_threshold=None, # 1e-8 V.s-1 + uV_per_sec_threshold=None, # 1e-8 V.s-1 proportion=0.2, job_kwargs=None, ): @@ -173,7 +173,7 @@ def detect_saturation_periods( The voltage saturation threshold in volts. This will depend on the recording probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). Note that NP2 probes are more difficult to saturate than NP1. - voltage_per_sec_threshold : None | float + uV_per_sec_threshold : None | float The first-derivative threshold in volts per second. Periods of the data over which the change in velocity is greater than this threshold will be detected as saturation events. Use `None` to skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be @@ -203,7 +203,7 @@ def detect_saturation_periods( node0 = _DetectSaturation( recording, saturation_threshold_uV=saturation_threshold_uV, - voltage_per_sec_threshold=voltage_per_sec_threshold, + uV_per_sec_threshold=uV_per_sec_threshold, proportion=proportion, ) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 9c923a5a7b..d968382421 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -36,7 +36,7 @@ def test_detect_saturation_periods(): # Generate some data in uV sat_value = 1200 - voltage_per_sec_threshold = 12 / sample_frequency + uV_per_sec_threshold = 12 / sample_frequency rng = np.random.default_rng() data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 @@ -81,7 +81,7 @@ def test_detect_saturation_periods(): recording.set_channel_offsets([offset] * num_chans) periods = detect_saturation_periods( - recording, saturation_threshold_uV=sat_value * 0.98, voltage_per_sec_threshold=voltage_per_sec_threshold, job_kwargs=job_kwargs + recording, saturation_threshold_uV=sat_value * 0.98, uV_per_sec_threshold=uV_per_sec_threshold, job_kwargs=job_kwargs ) seg_1_periods = periods[np.where(periods["segment_index"] == 0)] @@ -105,7 +105,7 @@ def test_detect_saturation_periods(): periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * (1 / 0.98), - voltage_per_sec_threshold=voltage_per_sec_threshold, + uV_per_sec_threshold=uV_per_sec_threshold, job_kwargs=job_kwargs, ) assert periods["start_sample_index"][0] == 1000 @@ -116,7 +116,7 @@ def test_detect_saturation_periods(): method="saturation", method_kwargs=dict( saturation_threshold_uV=sat_value * (1 / 0.98), - voltage_per_sec_threshold=voltage_per_sec_threshold, + uV_per_sec_threshold=uV_per_sec_threshold, ), job_kwargs=job_kwargs, ) From 8b17890113c86336edba0b1d9ab175b0a5b4c204 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:42:17 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/node_pipeline.py | 2 +- src/spikeinterface/preprocessing/__init__.py | 6 +- .../preprocessing/detect_artifacts.py | 123 +++++++++--------- .../preprocessing/preprocessing_classes.py | 1 + .../preprocessing/silence_periods.py | 27 ++-- .../tests/test_detect_artifacts.py | 39 +++--- .../tests/test_silence_periods.py | 2 - 7 files changed, 97 insertions(+), 103 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 10e4885606..a78082bc74 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,7 +489,7 @@ def check_graph(nodes): Check that node list is orderd in a good (parents are before children) """ - # Do not remove this, this is to remenber that in previous version the first node needed to be + # Do not remove this, this is to remenber that in previous version the first node needed to be # a detectot but not anymore # node0 = nodes[0] # if not isinstance(node0, PeakSource): diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index ab1adb6942..fd8d8fd787 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -20,11 +20,7 @@ PreprocessingPipeline, ) -from .detect_artifacts import ( - detect_artifact_periods, - detect_artifact_periods_by_envelope, - detect_saturation_periods -) +from .detect_artifacts import detect_artifact_periods, detect_artifact_periods_by_envelope, detect_saturation_periods # for snippets from .align_snippets import AlignSnippets diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 6acad37901..323a73f734 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -3,6 +3,7 @@ import numpy as np from spikeinterface.core.base import base_period_dtype + # from spikeinterface.core.core_tools import define_function_handling_dict_from_class # from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording from spikeinterface.preprocessing.rectify import RectifyRecording @@ -13,8 +14,6 @@ from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode import numpy as np - - artifact_dtype = base_period_dtype @@ -24,7 +23,6 @@ # ] - def detect_artifact_periods( recording, method="envelope", @@ -35,7 +33,7 @@ def detect_artifact_periods( Detect artifacts with several possible methods: * 'saturation' using detect_artifact_periods_by_envelope() * 'envelope' using detect_saturation_periods() - + See sub methods for more information on parameters. """ @@ -43,18 +41,20 @@ def detect_artifact_periods( method_kwargs = dict() if method == "envelope": - artifact_periods, envelope = detect_artifact_periods_by_envelope(recording, **method_kwargs, job_kwargs=job_kwargs) + artifact_periods, envelope = detect_artifact_periods_by_envelope( + recording, **method_kwargs, job_kwargs=job_kwargs + ) elif method == "saturation": artifact_periods = detect_saturation_periods(recording, **method_kwargs, job_kwargs=job_kwargs) else: raise ValueError(f"detect_artifact_periods() method='{method}' is not valid") - - return artifact_periods + return artifact_periods ## detect_period_artifacts_saturation Zone + def _collapse_events(events): """ If events are detected at a chunk edge, they will be split in two. @@ -81,6 +81,7 @@ class _DetectSaturation(PipelineNode): Run with `run_node_pipeline`, this computes saturation events for a given chunk. See `detect_saturation()` for details. """ + name = "detect_saturation" preferred_mp_context = None _compute_has_extended_signature = True @@ -99,16 +100,16 @@ def __init__( num_chans = recording.get_num_channels() self.uV_per_sec_threshold = uV_per_sec_threshold - thresh = np.full((num_chans, ), saturation_threshold_uV) + thresh = np.full((num_chans,), saturation_threshold_uV) # 0.98 is empirically determined as the true saturating point is - # slightly lower than the documented saturation point of the probe + # slightly lower than the documented saturation point of the probe self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 self.uV_per_sec_threshold = (uV_per_sec_threshold - offsets) / gains self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion self._dtype = np.dtype(artifact_dtype) - self.gain = recording.get_channel_gains() + self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() def get_trace_margin(self): @@ -146,7 +147,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): events[i]["end_sample_index"] = stop + start_frame events[i]["segment_index"] = segment_index - return (events, ) + return (events,) def detect_saturation_periods( @@ -157,43 +158,43 @@ def detect_saturation_periods( job_kwargs=None, ): """ - Detect amplifier saturation events (either single sample or multi-sample periods) in the data. - Saturation detection with this function should be applied to the raw data, before preprocessing. - However, saturation periods detected should be zeroed out after preprocessing has been performed. - - Saturation is detected by a voltage threshold, and optionally a derivative threshold that - flags periods of high velocity changes in the voltage. See _DetectSaturation.compute() - for details on the algorithm. - - Parameters - ---------- - recording : BaseRecording - The recording on which to detect the saturation events. - saturation_threshold_uV : float - The voltage saturation threshold in volts. This will depend on the recording - probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). - Note that NP2 probes are more difficult to saturate than NP1. - uV_per_sec_threshold : None | float - The first-derivative threshold in volts per second. Periods of the data over which the change - in velocity is greater than this threshold will be detected as saturation events. Use `None` to - skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be - empirically determined (IBL use 1e-8 V.s-1) for NP1 probes. - - proportion : float - 0 < proportion <1 of channels above threshold to consider the sample as saturated - mute_window_samples : int - TODO: should we scale this based on the fs? - job_kwargs: dict - The classical job_kwargs - - most useful for NP1 - can use ratio as a intuition for the value but dont do it in code - - Returns -------- - collapsed_events : np.recarray - A numpy recarray holding information on each saturation event. Has the fields: - "start_sample_index", "stop_sample_index", "segment_index", "method_id" + Detect amplifier saturation events (either single sample or multi-sample periods) in the data. + Saturation detection with this function should be applied to the raw data, before preprocessing. + However, saturation periods detected should be zeroed out after preprocessing has been performed. + + Saturation is detected by a voltage threshold, and optionally a derivative threshold that + flags periods of high velocity changes in the voltage. See _DetectSaturation.compute() + for details on the algorithm. + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect the saturation events. + saturation_threshold_uV : float + The voltage saturation threshold in volts. This will depend on the recording + probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). + Note that NP2 probes are more difficult to saturate than NP1. + uV_per_sec_threshold : None | float + The first-derivative threshold in volts per second. Periods of the data over which the change + in velocity is greater than this threshold will be detected as saturation events. Use `None` to + skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be + empirically determined (IBL use 1e-8 V.s-1) for NP1 probes. + + proportion : float + 0 < proportion <1 of channels above threshold to consider the sample as saturated + mute_window_samples : int + TODO: should we scale this based on the fs? + job_kwargs: dict + The classical job_kwargs + + most useful for NP1 + can use ratio as a intuition for the value but dont do it in code + + Returns + ------- + collapsed_events : np.recarray + A numpy recarray holding information on each saturation event. Has the fields: + "start_sample_index", "stop_sample_index", "segment_index", "method_id" """ if job_kwargs: job_kwargs = {} @@ -207,14 +208,16 @@ def detect_saturation_periods( proportion=proportion, ) - saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts") + saturation_periods = run_node_pipeline( + recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts" + ) return _collapse_events(saturation_periods) - ## detect_artifact_periods_by_envelope Zone + class _DetectThresholdCrossing(PeakDetector): name = "threshold_crossings" @@ -235,12 +238,7 @@ def __init__( noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) self.abs_thresholds = noise_levels * detect_threshold # internal dtype - self._dtype = np.dtype([ - ("sample_index", "int64"), - ("segment_index", "int64"), - ("front", "bool") - ] - ) + self._dtype = np.dtype([("sample_index", "int64"), ("segment_index", "int64"), ("front", "bool")]) def get_trace_margin(self): return 0 @@ -301,7 +299,10 @@ def detect_artifact_periods_by_envelope( noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) node0 = _DetectThresholdCrossing( - envelope, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, + envelope, + detect_threshold=detect_threshold, + noise_levels=noise_levels, + seed=seed, ) threshold_crossings = run_node_pipeline( @@ -338,15 +339,15 @@ def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): local_thr["sample_index"] = recording.get_num_samples(seg_index) local_thr["front"] = False sub_thr = np.hstack((sub_thr, local_thr)) - - local_artifact = np.zeros(sub_thr.size/2, dtype=artifact_dtype) + + local_artifact = np.zeros(sub_thr.size / 2, dtype=artifact_dtype) local_artifact["start_index"] = sub_thr["sample_index"][::2] local_artifact["stop_index"] = sub_thr["sample_index"][1::2] local_artifact["segment_index"] = seg_index final_artifacts.append(local_artifact) - + if len(final_artifacts) > 0: final_artifacts = np.concatenate(final_artifacts) else: final_artifacts = np.zeros(0, dtype=artifact_dtype) - return final_artifacts \ No newline at end of file + return final_artifacts diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index 47839db7a0..ff07b5b3c6 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -50,6 +50,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed + # from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts _all_preprocesser_dict = { diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 040e1275be..a9a1ac06d3 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -11,7 +11,6 @@ from spikeinterface.core.base import base_period_dtype - class SilencedPeriodsRecording(BasePreprocessor): """ Silence user-defined periods from recording extractor traces. By default, @@ -104,18 +103,18 @@ def __init__( noise_generator = None BasePreprocessor.__init__(self, recording) - + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) for seg_index, parent_segment in enumerate(recording._recording_segments): i0 = seg_limits[seg_index] - i1 = seg_limits[seg_index+1] + i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] - rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods_in_seg, mode, noise_generator, seg_index) + rec_segment = SilencedPeriodsRecordingSegment( + parent_segment, periods_in_seg, mode, noise_generator, seg_index + ) self.add_recording_segment(rec_segment) - self._kwargs = dict( - recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels - ) + self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels) def _all_period_list_to_periods_vec(list_periods, num_seg): @@ -135,6 +134,7 @@ def _all_period_list_to_periods_vec(list_periods, num_seg): start = stop return periods + def _check_periods(periods, num_seg): # check dtype if any(col not in np.dtype(base_period_dtype).fields for col in periods.dtype.fields): @@ -144,14 +144,14 @@ def _check_periods(periods, num_seg): seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) for i in range(num_seg): i0 = seg_limits[i] - i1 = seg_limits[i+1] + i1 = seg_limits[i + 1] periods_in_seg = periods[i0:i1] if periods_in_seg.size == 0: continue if len(periods) > 0: if np.any(periods_in_seg["start_sample_index"] > periods_in_seg["end_sample_index"]): raise ValueError("end_sample_index should be larger than start_sample_index") - if np.any(periods_in_seg["start_sample_index"][1:] < periods_in_seg["end_sample_index"][:-1]): + if np.any(periods_in_seg["start_sample_index"][1:] < periods_in_seg["end_sample_index"][:-1]): raise ValueError("Intervals should not overlap") @@ -165,10 +165,10 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) - + if self.periods.size > 0: new_interval = np.array([start_frame, end_frame]) - + lower_index = np.searchsorted(self.periods["end_sample_index"], new_interval[0]) upper_index = np.searchsorted(self.periods["start_sample_index"], new_interval[1]) @@ -190,13 +190,13 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces + # function for API silence_periods = define_function_handling_dict_from_class( source_class=SilencedPeriodsRecording, name="silence_periods" ) - class DetectArtifactAndSilentPeriodsRecording(SilencedPeriodsRecording): """ Class doing artifact detection and lient at the same time. @@ -220,6 +220,7 @@ def __init__( if artifacts is None: from spikeinterface.preprocessing import detect_artifact_periods + artifacts = detect_artifact_periods( recording, method=detect_artifact_method, @@ -233,9 +234,7 @@ def __init__( # note self._kwargs["periods"] is done by SilencedPeriodsRecording and so the computaion is done once - # function for API detect_artifacts_and_silent_periods = define_function_handling_dict_from_class( source_class=DetectArtifactAndSilentPeriodsRecording, name="silence_artifacts" ) - diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index d968382421..c812f6a83c 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -6,10 +6,11 @@ def test_detect_artifact_periods(): # one segment only rec = generate_recording(durations=[10.0, 10]) - artifacts = detect_artifact_periods(rec, method="envelope", - method_kwargs=dict(detect_threshold=5, freq_max=5.0), - ) - + artifacts = detect_artifact_periods( + rec, + method="envelope", + method_kwargs=dict(detect_threshold=5, freq_max=5.0), + ) def test_detect_saturation_periods(): @@ -53,35 +54,32 @@ def test_detect_saturation_periods(): # exactly on the border, as it makes testing complex # This was checked manually and any future breaking change # on this function would be extremely unlikely only to break this case. - all_starts = np.array([0, 29950, 45123, 90005, 149500]) - all_stops = np.array([1001, 30011, 45126, 90006, 149999]) + all_starts = np.array([0, 29950, 45123, 90005, 149500]) + all_stops = np.array([1001, 30011, 45126, 90006, 149999]) second_seg_offset = 1 for start, stop in zip(all_starts, all_stops): - data_seg_1[start : stop, :] = sat_value + data_seg_1[start:stop, :] = sat_value # differentiate the second segment for testing purposes data_seg_2[start : stop + second_seg_offset, :] = sat_value # this center the int16 around 0 and saturate on positive max_ = np.max(np.r_[data_seg_1.flatten(), data_seg_2.flatten()]) - gain = max_ / 2**15 + gain = max_ / 2**15 offset = 50 - seg_1_int16 = np.clip( - np.rint((data_seg_1 - offset) / gain), - -32768, 32767 - ).astype(np.int16) - seg_2_int16 = np.clip( - np.rint((data_seg_2 - offset) / gain), - -32768, 32767 - ).astype(np.int16) + seg_1_int16 = np.clip(np.rint((data_seg_1 - offset) / gain), -32768, 32767).astype(np.int16) + seg_2_int16 = np.clip(np.rint((data_seg_2 - offset) / gain), -32768, 32767).astype(np.int16) recording = NumpyRecording([seg_1_int16, seg_2_int16], sample_frequency) recording.set_channel_gains(gain) recording.set_channel_offsets([offset] * num_chans) periods = detect_saturation_periods( - recording, saturation_threshold_uV=sat_value * 0.98, uV_per_sec_threshold=uV_per_sec_threshold, job_kwargs=job_kwargs + recording, + saturation_threshold_uV=sat_value * 0.98, + uV_per_sec_threshold=uV_per_sec_threshold, + job_kwargs=job_kwargs, ) seg_1_periods = periods[np.where(periods["segment_index"] == 0)] @@ -115,14 +113,15 @@ def test_detect_saturation_periods(): recording, method="saturation", method_kwargs=dict( - saturation_threshold_uV=sat_value * (1 / 0.98), - uV_per_sec_threshold=uV_per_sec_threshold, - ), + saturation_threshold_uV=sat_value * (1 / 0.98), + uV_per_sec_threshold=uV_per_sec_threshold, + ), job_kwargs=job_kwargs, ) assert np.array_equal(periods, periods_entry_function) + if __name__ == "__main__": test_detect_artifact_periods() test_detect_saturation_periods() diff --git a/src/spikeinterface/preprocessing/tests/test_silence_periods.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py index ffba9059a0..44bd205f1b 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_periods.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -6,8 +6,6 @@ from spikeinterface.preprocessing import silence_periods - - import numpy as np from pathlib import Path From 4d507adef33859a7fd2d8d06003764f0f1121821 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 6 Feb 2026 16:01:52 +0000 Subject: [PATCH 8/9] Updates testing converting to int. --- .../preprocessing/detect_artifacts.py | 26 +++++++------ .../tests/test_detect_artifacts.py | 37 +++++++++++++++---- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 323a73f734..695932d27c 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -90,28 +90,28 @@ def __init__( self, recording, saturation_threshold_uV, - uV_per_sec_threshold, + uV_per_ms_threshold, proportion, ): PipelineNode.__init__(self, recording, return_output=True) - gains = recording.get_channel_gains() - offsets = recording.get_channel_offsets() num_chans = recording.get_num_channels() - self.uV_per_sec_threshold = uV_per_sec_threshold + self.uV_per_ms_threshold = uV_per_ms_threshold thresh = np.full((num_chans,), saturation_threshold_uV) # 0.98 is empirically determined as the true saturating point is # slightly lower than the documented saturation point of the probe - self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 - self.uV_per_sec_threshold = (uV_per_sec_threshold - offsets) / gains - self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion self._dtype = np.dtype(artifact_dtype) self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() + self.saturation_threshold_unscaled = (thresh - self.offset) / self.gain * 0.98 + + # do not apply offset when dealing with the derivative + self.uV_per_ms_threshold = (uV_per_ms_threshold * self.sampling_frequency / 1e3) / self.gain + def get_trace_margin(self): return 0 @@ -125,10 +125,12 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): """ saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) - if self.uV_per_sec_threshold is not None: + if self.uV_per_ms_threshold is not None: fs = self.sampling_frequency # then compute the derivative of the voltage saturation - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.uV_per_sec_threshold, axis=1) + + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= self.uV_per_ms_threshold, axis=1) + # Note this means the velocity is not checked for the last sample in the # check because we are taking the forward derivative n_diff_saturated = np.r_[n_diff_saturated, 0] @@ -153,7 +155,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def detect_saturation_periods( recording, saturation_threshold_uV, # 1200 uV - uV_per_sec_threshold=None, # 1e-8 V.s-1 + uV_per_ms_threshold=None, # 1e-8 V.s-1 proportion=0.2, job_kwargs=None, ): @@ -174,7 +176,7 @@ def detect_saturation_periods( The voltage saturation threshold in volts. This will depend on the recording probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). Note that NP2 probes are more difficult to saturate than NP1. - uV_per_sec_threshold : None | float + uV_per_ms_threshold : None | float The first-derivative threshold in volts per second. Periods of the data over which the change in velocity is greater than this threshold will be detected as saturation events. Use `None` to skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be @@ -204,7 +206,7 @@ def detect_saturation_periods( node0 = _DetectSaturation( recording, saturation_threshold_uV=saturation_threshold_uV, - uV_per_sec_threshold=uV_per_sec_threshold, + uV_per_ms_threshold=uV_per_ms_threshold, proportion=proportion, ) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index c812f6a83c..aa4e8876e9 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -36,8 +36,8 @@ def test_detect_saturation_periods(): job_kwargs = {"chunk_size": chunk_size} # Generate some data in uV - sat_value = 1200 - uV_per_sec_threshold = 12 / sample_frequency + sat_value = 12 + uV_per_ms_threshold = 12 / sample_frequency / 1e3 rng = np.random.default_rng() data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 @@ -64,13 +64,34 @@ def test_detect_saturation_periods(): data_seg_2[start : stop + second_seg_offset, :] = sat_value # this center the int16 around 0 and saturate on positive - max_ = np.max(np.r_[data_seg_1.flatten(), data_seg_2.flatten()]) - gain = max_ / 2**15 - offset = 50 + combined = np.r_[data_seg_1.flatten(), data_seg_2.flatten()] + max_ = np.max(combined) + # min_ = np.min(combined) + gain = max_ / 2**15 # (max_ - min_) / 65535 + offset = 0 # min_ + 32768 * gain + + PLOT = True + if PLOT: + import matplotlib + import matplotlib.pyplot as plt + plt.plot(data_seg_1) + plt.title("data float") + plt.show() + plt.plot(np.diff(data_seg_1, axis=0)) + plt.title("diff float") + plt.show() seg_1_int16 = np.clip(np.rint((data_seg_1 - offset) / gain), -32768, 32767).astype(np.int16) seg_2_int16 = np.clip(np.rint((data_seg_2 - offset) / gain), -32768, 32767).astype(np.int16) + if PLOT: + plt.plot(seg_1_int16) + plt.title("data int") + plt.show() + plt.plot(np.diff(seg_1_int16, axis=0)) + plt.title("diff int") + plt.show() + recording = NumpyRecording([seg_1_int16, seg_2_int16], sample_frequency) recording.set_channel_gains(gain) recording.set_channel_offsets([offset] * num_chans) @@ -78,7 +99,7 @@ def test_detect_saturation_periods(): periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * 0.98, - uV_per_sec_threshold=uV_per_sec_threshold, + uV_per_ms_threshold=uV_per_ms_threshold, job_kwargs=job_kwargs, ) @@ -103,7 +124,7 @@ def test_detect_saturation_periods(): periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * (1 / 0.98), - uV_per_sec_threshold=uV_per_sec_threshold, + uV_per_ms_threshold=uV_per_ms_threshold, job_kwargs=job_kwargs, ) assert periods["start_sample_index"][0] == 1000 @@ -114,7 +135,7 @@ def test_detect_saturation_periods(): method="saturation", method_kwargs=dict( saturation_threshold_uV=sat_value * (1 / 0.98), - uV_per_sec_threshold=uV_per_sec_threshold, + uV_per_ms_threshold=uV_per_ms_threshold, ), job_kwargs=job_kwargs, ) From 654616661b00dc69db96f5d305aa8289145df278 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Feb 2026 16:02:26 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/tests/test_detect_artifacts.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index aa4e8876e9..72350096b1 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -37,7 +37,7 @@ def test_detect_saturation_periods(): # Generate some data in uV sat_value = 12 - uV_per_ms_threshold = 12 / sample_frequency / 1e3 + uV_per_ms_threshold = 12 / sample_frequency / 1e3 rng = np.random.default_rng() data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 @@ -67,13 +67,14 @@ def test_detect_saturation_periods(): combined = np.r_[data_seg_1.flatten(), data_seg_2.flatten()] max_ = np.max(combined) # min_ = np.min(combined) - gain = max_ / 2**15 # (max_ - min_) / 65535 - offset = 0 # min_ + 32768 * gain + gain = max_ / 2**15 # (max_ - min_) / 65535 + offset = 0 # min_ + 32768 * gain PLOT = True if PLOT: import matplotlib import matplotlib.pyplot as plt + plt.plot(data_seg_1) plt.title("data float") plt.show()