Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions flyvis/network/ensemble_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,10 @@ def flash_response_index(
responses = self.flash_responses()
fris = flash_response_index(responses, radius=6)
if cell_types is not None:
requested_cell_types = cell_types
fris = fris.custom.where(cell_type=cell_types)
cell_types = fris.cell_type.values
kwargs.setdefault("sorted_type_list", requested_cell_types)
else:
cell_types = fris.cell_type.values
task_error = self.task_error()
Expand Down
54 changes: 54 additions & 0 deletions tests/test_ensemble_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from types import SimpleNamespace

import numpy as np

from flyvis.network import ensemble_view as ensemble_view_module
from flyvis.network.ensemble_view import EnsembleView


class _FakeFRI:
def __init__(self, values, cell_types, filtered_cell_types=None):
self.values = values
self.cell_type = SimpleNamespace(values=np.array(cell_types))
self.filtered_cell_types = filtered_cell_types
self.custom = SimpleNamespace(where=self._where)

def _where(self, cell_type):
if self.filtered_cell_types is not None:
return _FakeFRI(self.values, self.filtered_cell_types)
return _FakeFRI(self.values, list(cell_type))


class _DummyView:
def flash_responses(self):
return "responses"

def task_error(self):
return SimpleNamespace(values=np.array([0.5, 0.1, 0.2]))


def test_flash_response_index_aligns_labels_with_filtered_data(monkeypatch):
requested = ["Mi1", "Tm3", "CT1(M10)"]
filtered = ["CT1(M10)", "Mi1", "Tm3"]
fake_fris = _FakeFRI(np.ones((3, 2, 1)), requested, filtered_cell_types=filtered)

monkeypatch.setattr(
ensemble_view_module, "flash_response_index", lambda *_args, **_kwargs: fake_fris
)

captured = SimpleNamespace(fris=None, cell_types=None, sorted_type_list=None)

def _fake_plot_fris(fris, cell_types, **kwargs):
captured.fris = fris
captured.cell_types = list(cell_types)
captured.sorted_type_list = kwargs.get("sorted_type_list")
return "fig", "ax"

monkeypatch.setattr(ensemble_view_module, "plot_fris", _fake_plot_fris)

fig, ax = EnsembleView.flash_response_index(_DummyView(), cell_types=requested)

assert (fig, ax) == ("fig", "ax")
assert captured.fris.shape == (3, 2, 1)
assert captured.cell_types == filtered
assert captured.sorted_type_list == requested
Loading