Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions flyvis/analysis/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def umap_embedding(
metric: str = "correlation",
n_epochs: int = 1500,
**kwargs,
) -> Tuple[np.ndarray, np.ndarray, UMAP]:
) -> Tuple[np.ndarray, np.ndarray, Optional[UMAP]]:
"""
Perform UMAP embedding on input data.

Expand All @@ -456,9 +456,14 @@ def umap_embedding(

Returns:
A tuple containing:
- embedding: The UMAP embedding.
- mask: Boolean mask for valid samples.
- reducer: The fitted UMAP object.
- embedding: The UMAP embedding (n_samples, n_components). May be NaN
if insufficient data.
- mask: Boolean mask (length n_samples). When reducer is not None,
True indicates rows with nonzero variance that were also connected
in the UMAP graph. When reducer is None (insufficient data), True
indicates only rows with nonzero variance.
- reducer: The fitted UMAP object or None if fewer than 2 rows had
nonzero variance.

Raises:
ValueError: If n_components is too large relative to sample size.
Expand All @@ -481,10 +486,16 @@ def umap_embedding(
X = X.reshape(X.shape[0], -1)
logging.info("reshaped X from %s to %s", shape, X.shape)

embedding = np.ones([X.shape[0], n_components]) * np.nan
# umap doesn't like contant rows
n_samples = X.shape[0]
embedding = np.ones([n_samples, n_components]) * np.nan
# umap doesn't like constant rows
mask = ~np.isclose(X.std(axis=1), 0)
X = X[mask]
X_nonconst = X[mask]

# If fewer than 2 rows remain, skip UMAP and return embedding of NaNs.
if X_nonconst.shape[0] < 2:
return embedding, mask, None

reducer = UMAP(
n_neighbors=n_neighbors,
min_dist=min_dist,
Expand All @@ -495,7 +506,7 @@ def umap_embedding(
n_epochs=n_epochs,
**kwargs,
)
_embedding = reducer.fit_transform(X)
_embedding = reducer.fit_transform(X_nonconst)

# gaussian mixture doesn't like nans through disconnected vertices in umap
connected_vertices_mask = ~disconnected_vertices(reducer)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np

from flyvis.analysis.clustering import umap_embedding


def test_umap_embedding_single_nonzero_variance_row():
"""Test that umap_embedding handles the edge case where only one row has
nonzero variance (all others are constant). UMAP should not be fitted and
the function should return NaN embedding with None reducer."""
rng = np.random.default_rng(0)
# One row with variance, four constant rows
X = np.zeros((5, 10))
X[2] = rng.random(10)

embedding, mask, reducer = umap_embedding(X)

assert reducer is None
assert np.all(np.isnan(embedding))
# Only the one non-constant row should be True in the mask
expected_mask = np.array([False, False, True, False, False])
np.testing.assert_array_equal(mask, expected_mask)


def test_umap_embedding_all_zero_variance_rows():
"""Test that umap_embedding handles all-constant rows gracefully."""
X = np.ones((5, 10))

embedding, mask, reducer = umap_embedding(X)

assert reducer is None
assert np.all(np.isnan(embedding))
assert not np.any(mask)


def test_umap_embedding_returns_none_reducer_when_insufficient_data():
"""Test that reducer is None when fewer than 2 rows have nonzero variance."""
X = np.zeros((4, 8))
# Only one non-constant row
X[0] = np.arange(8, dtype=float)

embedding, mask, reducer = umap_embedding(X)

assert reducer is None
assert embedding.shape == (4, 2)
assert np.all(np.isnan(embedding))