diff --git a/flyvis/network/ensemble_view.py b/flyvis/network/ensemble_view.py index 80c1ef3..16cf0e3 100644 --- a/flyvis/network/ensemble_view.py +++ b/flyvis/network/ensemble_view.py @@ -237,7 +237,10 @@ def flash_response_index( responses = self.flash_responses() fris = flash_response_index(responses, radius=6) if cell_types is not None: + requested_cell_types = cell_types fris = fris.custom.where(cell_type=cell_types) + cell_types = fris.cell_type.values + kwargs.setdefault("sorted_type_list", requested_cell_types) else: cell_types = fris.cell_type.values task_error = self.task_error() diff --git a/tests/test_ensemble_view.py b/tests/test_ensemble_view.py new file mode 100644 index 0000000..08c3386 --- /dev/null +++ b/tests/test_ensemble_view.py @@ -0,0 +1,54 @@ +from types import SimpleNamespace + +import numpy as np + +from flyvis.network import ensemble_view as ensemble_view_module +from flyvis.network.ensemble_view import EnsembleView + + +class _FakeFRI: + def __init__(self, values, cell_types, filtered_cell_types=None): + self.values = values + self.cell_type = SimpleNamespace(values=np.array(cell_types)) + self.filtered_cell_types = filtered_cell_types + self.custom = SimpleNamespace(where=self._where) + + def _where(self, cell_type): + if self.filtered_cell_types is not None: + return _FakeFRI(self.values, self.filtered_cell_types) + return _FakeFRI(self.values, list(cell_type)) + + +class _DummyView: + def flash_responses(self): + return "responses" + + def task_error(self): + return SimpleNamespace(values=np.array([0.5, 0.1, 0.2])) + + +def test_flash_response_index_aligns_labels_with_filtered_data(monkeypatch): + requested = ["Mi1", "Tm3", "CT1(M10)"] + filtered = ["CT1(M10)", "Mi1", "Tm3"] + fake_fris = _FakeFRI(np.ones((3, 2, 1)), requested, filtered_cell_types=filtered) + + monkeypatch.setattr( + ensemble_view_module, "flash_response_index", lambda *_args, **_kwargs: fake_fris + ) + + captured = SimpleNamespace(fris=None, cell_types=None, sorted_type_list=None) + + def _fake_plot_fris(fris, cell_types, **kwargs): + captured.fris = fris + captured.cell_types = list(cell_types) + captured.sorted_type_list = kwargs.get("sorted_type_list") + return "fig", "ax" + + monkeypatch.setattr(ensemble_view_module, "plot_fris", _fake_plot_fris) + + fig, ax = EnsembleView.flash_response_index(_DummyView(), cell_types=requested) + + assert (fig, ax) == ("fig", "ax") + assert captured.fris.shape == (3, 2, 1) + assert captured.cell_types == filtered + assert captured.sorted_type_list == requested