diff --git a/main/como/combine_distributions.py b/main/como/combine_distributions.py index 5bb39d24..24f8889a 100644 --- a/main/como/combine_distributions.py +++ b/main/como/combine_distributions.py @@ -18,8 +18,7 @@ _OutputCombinedSourceFilepath, _SourceWeights, ) -from como.pipelines.identifier import convert -from como.utils import LogLevel, get_missing_gene_data, log_and_raise_error, num_columns +from como.pipelines.identifier import get_remaining_identifiers from como.utils import num_columns @@ -287,7 +286,7 @@ async def _begin_combining_distributions( matrix_subset = matrix_subset.set_index(keys=[GeneIdentifier.entrez_gene_id.value], drop=True) matrix_subset = matrix_subset.drop(columns=["gene_symbol", "ensembl_gene_id"], errors="ignore") elif isinstance(matrix, sc.AnnData): - conversion = convert(ids=matrix.var_names.tolist(), taxon=taxon) + conversion = get_remaining_identifiers(ids=matrix.var_names.tolist(), taxon=taxon) conversion.reset_index(drop=False, inplace=True) matrix: pd.DataFrame = matrix.to_df().T matrix.reset_index(inplace=True, drop=False, names=["gene_symbol"]) diff --git a/main/como/merge_xomics.py b/main/como/merge_xomics.py index 2cfdd1a1..564c2025 100644 --- a/main/como/merge_xomics.py +++ b/main/como/merge_xomics.py @@ -22,7 +22,7 @@ _SourceWeights, ) from como.project import Config -from como.utils import get_missing_gene_data, log_and_raise_error, read_file, return_placeholder_data, set_up_logging +from como.utils import get_missing_gene_data, read_file, return_placeholder_data, set_up_logging class _MergedHeaderNames: diff --git a/main/como/pipelines/identifier.py b/main/como/pipelines/identifier.py index 59fe7444..e09bd9dc 100644 --- a/main/como/pipelines/identifier.py +++ b/main/como/pipelines/identifier.py @@ -1,37 +1,56 @@ -from collections.abc import Sequence -from typing import Literal +from collections.abc import Iterable, Iterator, Sequence +from typing import Any, Literal, overload import pandas as pd from 
bioservices.mygeneinfo import MyGeneInfo +from tqdm import tqdm __all__ = [ - "convert", + "build_gene_info", + "contains_identical_gene_types", "determine_gene_type", + "get_remaining_identifiers", ] +T_IDS = int | str | Iterable[int] | Iterable[str] | Iterable[int | str] T_MG_SCOPE = Literal["entrezgene", "ensembl.gene", "symbol"] T_MG_TRANSLATE = Literal["entrez_gene_id", "ensembl_gene_id", "gene_symbol"] T_MG_RETURN = list[dict[T_MG_TRANSLATE, str]] -def _get_conversion(info: MyGeneInfo, values: list[str], scope: T_MG_SCOPE, fields: str, taxon: str) -> T_MG_RETURN: - value_str = ",".join(map(str, values)) - results = info.get_queries(query=value_str, dotfield=True, scopes=scope, fields=fields, species=taxon) - if not isinstance(results, list): - raise TypeError(f"Expected results to be a list, but got {type(results)}") - if not isinstance(results[0], dict): - raise TypeError(f"Expected each result to be a dict, but got {type(results[0])}") - - data: T_MG_RETURN = [] - for result in results: - ensembl = result.get("query" if scope == "ensembl.gene" else "ensembl.gene") - entrez = result.get("query" if scope == "entrezgene" else "entrezgene") - symbol = result.get("query" if scope == "symbol" else "symbol") - data.append({"ensembl_gene_id": ensembl, "entrez_gene_id": entrez, "gene_symbol": symbol}) +def _get_conversion(info: MyGeneInfo, values: T_IDS, taxon: str | int) -> list[dict[str, Any]]: + value_list = sorted(map(str, [values] if isinstance(values, (int, str)) else values)) + data_type = determine_gene_type(value_list) + if not all(v == data_type[value_list[0]] for v in data_type.values()): + raise ValueError("All items in ids must be of the same type (Entrez, Ensembl, or symbols).") + + chunk_size = 1000 + taxon_str = str(taxon) + scope: T_MG_SCOPE = next(iter(data_type.values())) + data = [] + chunks = range(0, len(value_list), chunk_size) + + for i in tqdm(chunks, desc=f"Getting info for '{scope}'"): + result = info.get_queries( + query=",".join(map(str, 
value_list[i : i + chunk_size])), + dotfield=True, + scopes=scope, + fields="ensembl.gene,entrezgene,symbol,genomic_pos.start,genomic_pos.end,taxid,notfound", + species=taxon_str, + ) + if isinstance(result, int) and result == 414: + raise ValueError( + f"Query too long. Reduce the number of IDs in each query chunk (current chunk size: {chunk_size})." + ) + if not isinstance(result, list): + raise TypeError(f"Expected results to be a list, but got {type(result)}") + if not isinstance(result[0], dict): + raise TypeError(f"Expected each result to be a dict, but got {type(result[0])}") + data.extend(result) return data -def convert(ids: int | str | Sequence[int] | Sequence[str] | Sequence[int | str], taxon: int | str, cache: bool = True): +def get_remaining_identifiers(ids: T_IDS, taxon: int | str, cache: bool = True): """Convert between genomic identifiers. This function will convert between the following components: @@ -46,33 +65,111 @@ def convert(ids: int | str | Sequence[int] | Sequence[str] | Sequence[int | str] :return: DataFrame with columns "entrez_gene_id", "ensembl_gene_id", and "gene_symbol" """ my_geneinfo = MyGeneInfo(cache=cache) - chunk_size = 1000 - id_list = list(map(str, [ids] if isinstance(ids, (int, str)) else ids)) - chunks = list(range(0, len(id_list), chunk_size)) + gene_data = _get_conversion(info=my_geneinfo, values=ids, taxon=taxon) + df = ( + pd.json_normalize(gene_data) + .rename( + columns={ + "ensembl.gene": "ensembl_gene_id", + "entrezgene": "entrez_gene_id", + "symbol": "gene_symbol", + "taxid": "taxon_id", + } + ) + .drop( + columns=["query", "_id", "_score", "genomic_pos.end", "genomic_pos.start"], + errors="ignore", + ) + ) + df = df[df["taxon_id"] == taxon] + df["taxon_id"] = df["taxon_id"].astype(int, copy=True) + + # BUG: For an unknown reason, some Ensembl IDs are actually Entrez IDs + # To filter these, two approaches can be done: + # 1) Remove rows where Ensembl IDs are integers + # 2) Remove rows where Ensembl IDs equal 
Entrez IDs + # We are selecting option 1 because it goes for the root cause: Ensembl IDs are not pure integers + mask = df["ensembl_gene_id"].astype(str).str.fullmatch(r"\d+").fillna(False) + df = df[ + (df["ensembl_gene_id"].astype("string").notna()) # remove NA values + & (~df["ensembl_gene_id"].astype("string").str.fullmatch(r"\d+")) # remove Entrez IDs + ] + return df + + +def _to_scalar(val) -> int: + """Calculate the distance between end (e) and start (s).""" + if isinstance(val, list): + return int(sum(val) / len(val)) if val else 0 # `if val` checks that the list contains items + if pd.isna(val): + return 0 + return int(val) + + +def build_gene_info(ids: T_IDS, taxon: int | str, cache: bool = True): + """Get genomic information from a given set of IDs. + + The input should be of the same type, otherwise this function will fail. + Expected types are: + - Ensembl Gene ID + - Entrez Gene ID + - Gene Symbol - data_type = determine_gene_type(id_list) - if not all(v == data_type[id_list[0]] for v in data_type.values()): - raise ValueError("All items in ids must be of the same type (Entrez, Ensembl, or symbols).") + The returned data frame will have the following columns: + - ensembl_gene_id + - entrez_gene_id + - gene_symbol + - size (distance between genomic end and start) - scope = next(iter(data_type.values())) - fields = ",".join({"ensembl.gene", "entrezgene", "symbol"} - {scope}) - taxon_str = str(taxon) - return pd.DataFrame( - [ - row - for i in chunks - for row in _get_conversion( - info=my_geneinfo, - values=id_list[i : i + chunk_size], - scope=scope, - fields=fields, - taxon=taxon_str, - ) - ] + :param ids: IDs to be converted + :param taxon: Taxonomic identifier + :param cache: Should local caching be used for queries + :return: pandas.DataFrame + """ + my_geneinfo = MyGeneInfo(cache=cache) + gene_data = _get_conversion(info=my_geneinfo, values=ids, taxon=taxon) + df = pd.json_normalize(gene_data).rename(columns={"taxid": "taxon_id"}) + df = 
def _classify_identifier(item: object) -> tuple[str, T_MG_SCOPE]:
    """Return the identifier with any ".suffix" removed, together with its detected kind."""
    # Everything after the first "." is dropped (e.g. an Ensembl version or a float's ".0").
    # NOTE(review): this also shortens gene symbols that contain a "." - confirm that is intended.
    stem = str(item).partition(".")[0]

    if stem.startswith("ENS") and len(stem) > 11 and stem[-11:].isdigit():
        return stem, "ensembl.gene"
    if stem.isdigit():
        return stem, "entrezgene"
    return stem, "symbol"


@overload
def determine_gene_type(items: str, /) -> T_MG_SCOPE: ...


@overload
def determine_gene_type(items: Sequence[str], /) -> dict[str, T_MG_SCOPE]: ...


def determine_gene_type(items: str | Sequence[str], /) -> T_MG_SCOPE | dict[str, T_MG_SCOPE]:
    """Determine the genomic data type of one or more gene identifiers.

    Each identifier is classified after removing everything from its first "." onwards:

    - "ensembl.gene": starts with "ENS", is longer than 11 characters, and ends in 11 digits
    - "entrezgene": consists only of digits
    - "symbol": anything else

    :param items: A single identifier, or a sequence of identifiers
    :return: The detected type when a single string is given; otherwise a dictionary mapping each
        identifier (with its ".suffix" removed) to its detected type
    """
    if isinstance(items, str):
        # Classify directly; looking the original string up in a dictionary keyed by the
        # suffix-stripped identifier fails for inputs such as "ENSG00000139618.5".
        return _classify_identifier(items)[1]
    return dict(_classify_identifier(item) for item in items)


def contains_identical_gene_types(values: dict[str, T_MG_SCOPE] | Iterable[T_MG_SCOPE]) -> bool:
    """Check if all gene types in the input are identical.

    :param values: A dictionary mapping gene identifiers to their types, or any iterable of gene types
    :return: True if all types are identical (or there are none), False otherwise
    """
    remaining = iter(values.values() if isinstance(values, dict) else values)
    # Consume the first element instead of indexing so non-indexable iterables are accepted too
    first = next(remaining, None)
    return all(value == first for value in remaining)
all_gene_symbols: list[str] = ["-"] * n - all_entrez_ids: list[str | int] = ["-"] * n - all_ensembl_ids: list[str] = ["-"] * n - all_sizes: list[int] = [-1] * n - - def _avg_pos(value: int | list[int] | None) -> int: - if value is None: - return 0 - if isinstance(value, list): - return int(sum(value) / len(value)) if value else 0 - return int(value) - - for i, data in enumerate(gene_data): - data: dict[str, str | int | list[str] | list[int] | None] - if "genomic_pos.start" not in data: - log_and_raise_error( - message="Unexpectedly missing key 'genomic_pos.start'", error=KeyError, level=LogLevel.WARNING - ) - if "genomic_pos.end" not in data: - log_and_raise_error( - message="Unexpectedly missing key 'genomic_pos.end'", error=KeyError, level=LogLevel.WARNING - ) - if "ensembl.gene" not in data: - log_and_raise_error( - message="Unexpectedly missing key 'ensembl.gene'", error=KeyError, level=LogLevel.WARNING - ) - - start = data["genomic_pos.start"] - end = data["genomic_pos.end"] - ensembl_id = data["ensembl.gene"] - - if not isinstance(start, int): - log_and_raise_error( - message=f"Unexpected type for 'genomic_pos.start': expected int, got {type(start)}", - error=TypeError, - level=LogLevel.WARNING, - ) - if not isinstance(end, int): - log_and_raise_error( - message=f"Unexpected type for 'genomic_pos.end': expected int, got {type(start)}", - error=TypeError, - level=LogLevel.WARNING, - ) - if not isinstance(ensembl_id, str): - log_and_raise_error( - message=f"Unexpected type for 'ensembl.gene': expected str, got {type(ensembl_id)}", - error=ValueError, - level=LogLevel.WARNING, - ) - - size = end - start - all_ensembl_ids[i] = ",".join(map(str, ensembl_id)) if isinstance(ensembl_id, list) else ensembl_id - all_gene_symbols[i] = str(data.get("symbol", "-")) - all_entrez_ids[i] = str(data.get("entrezgene", "-")) - all_sizes[i] = max(size, -1) # use `size` otherwise -1 - - gene_info: pd.DataFrame = pd.DataFrame( - { - "ensembl_gene_id": all_ensembl_ids, - 
"gene_symbol": all_gene_symbols, - "entrez_gene_id": all_entrez_ids, - "size": all_sizes, - } - ) - - # remove rows where every gene size value is -1 (not available) - gene_info = gene_info[~(gene_info == -1).all(axis=1)] - - gene_info["ensembl_gene_id"] = gene_info["ensembl_gene_id"].str.split(",") # extend lists into multiple rows - gene_info = gene_info.explode(column=["ensembl_gene_id"]) - # we would set `entrez_gene_id` to int here as well, but not all ensembl ids are mapped to entrez ids, - # and as a result, there are still "-" values in the entrez id column that cannot be converted to an integer - - gene_info = gene_info.sort_values(by="ensembl_gene_id") + ensembl_ids: set[str] = set(chain.from_iterable(read_ensembl_gene_ids(f) for f in counts_matrix_filepaths)) + gene_info = build_gene_info(ids=ensembl_ids, taxon=taxon, cache=cache) output_filepath.parent.mkdir(parents=True, exist_ok=True) gene_info.to_csv(output_filepath, index=False) logger.success(f"Gene Info file written at '{output_filepath}'") diff --git a/main/como/utils.py b/main/como/utils.py index 98daca16..11a34f16 100644 --- a/main/como/utils.py +++ b/main/como/utils.py @@ -13,12 +13,11 @@ from loguru import logger from como.data_types import LOG_FORMAT, Algorithm, LogLevel -from como.pipelines.identifier import convert +from como.pipelines.identifier import get_remaining_identifiers T = TypeVar("T") __all__ = [ "get_missing_gene_data", - "log_and_raise_error", "num_columns", "num_rows", "read_file", @@ -135,7 +134,7 @@ def get_missing_gene_data(values: Sequence[str] | pd.DataFrame | sc.AnnData, tax # second isinstance required for static type check to be happy # if isinstance(values, list) and not isinstance(values, pd.DataFrame): if isinstance(values, list): - return convert(ids=values, taxon=taxon_id) + return get_remaining_identifiers(ids=values, taxon=taxon_id) elif isinstance(values, pd.DataFrame): # raise error if duplicate column names exist if 
any(values.columns.duplicated(keep=False)): diff --git a/pyproject.toml b/pyproject.toml index ede17cb1..7a277114 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ requires-python = ">=3.11,<3.14" dependencies = [ "aioftp>=0.23.1", "anndata>=0.12.0", - "bioservices>=1.12.1", + "bioservices>=1.13.0", "cobamp@git+https://github.com/JoshLoecker/cobamp@master", "cobra>=0.28.0", "joypy>=0.2.6", @@ -28,7 +28,7 @@ dependencies = [ "statsmodels>=0.13.0; python_version < '3.12'", "statsmodels>=0.14.0; python_version >= '3.12'", "troppo@git+https://github.com/JoshLoecker/troppo@master", - "zfpkm>=1.0.3", + "zfpkm>=1.1.0", ] [project.optional-dependencies] diff --git a/ruff.toml b/ruff.toml index 691022d2..edc19396 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,3 +1,4 @@ +src = ["main/como"] line-length = 120 extend-include = ["docs/**/*.py", "tests/**/*.py", "**/*.ipynb"] @@ -6,7 +7,11 @@ quote-style = "double" docstring-code-format = true [lint] -extend-fixable = ["I001"] +extend-fixable = [ + "I001", + "RUF022", + "W293", +] # Linting rules: https://docs.astral.sh/ruff/rules/ unfixable = [ "F401", # warn about, but do not remove, unused imports @@ -45,6 +50,9 @@ ignore = [ "TRY003", # allow exception messages outside the `Exception` class ] +[lint.isort] +known-first-party = ["como"] + [lint.pydocstyle] convention = "google" diff --git a/uv.lock b/uv.lock index 9139aea2..09d79c73 100644 --- a/uv.lock +++ b/uv.lock @@ -181,7 +181,7 @@ wheels = [ [[package]] name = "bioservices" -version = "1.12.1" +version = "1.13.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "appdirs" }, @@ -201,9 +201,9 @@ dependencies = [ { name = "wrapt" }, { name = "xmltodict" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5e/51/1a8dc87ffc6c27e0a896b982da3ff1b483f5ef85a47ae1ffafbc6a479bb4/bioservices-1.12.1.tar.gz", hash = "sha256:0f31782ae50930d4ab82b43f98d1ca2cc9befaa699f3a7a6cba512a8f9e2cab0", size = 218752, upload-time = 
"2025-02-27T22:38:58.628Z" } +sdist = { url = "https://files.pythonhosted.org/packages/99/8c/b9a20b6656446daaea5f3085d5465aad8f536c95fd14ffd6fcdcdf410031/bioservices-1.13.0.tar.gz", hash = "sha256:24d30be650d37c1377b4fd452abead9ec5ee891dbdaa96a21543ab508d401386", size = 220384, upload-time = "2026-03-02T14:49:11.002Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/c9/656854139abbdc0a09d544bacc6fa3132245ebf9847dd45da9a8e416c5a5/bioservices-1.12.1-py3-none-any.whl", hash = "sha256:898033fe0158a0e31ea5ad93ff02f5000aa67ff9d06e9dc6477583d3f60ace80", size = 258032, upload-time = "2025-02-27T22:38:56.406Z" }, + { url = "https://files.pythonhosted.org/packages/27/49/93a413b3b70a96e525f11212d0ca7992c4c7b35c47da78bf227a52282e85/bioservices-1.13.0-py3-none-any.whl", hash = "sha256:aaafb5ee246bccd9323a21de12ea6a35163e01c1ed4777e45acbda565804fc32", size = 259727, upload-time = "2026-03-02T14:49:09.879Z" }, ] [[package]] @@ -504,7 +504,7 @@ dev = [ requires-dist = [ { name = "aioftp", specifier = ">=0.23.1" }, { name = "anndata", specifier = ">=0.12.0" }, - { name = "bioservices", specifier = ">=1.12.1" }, + { name = "bioservices", specifier = ">=1.13.0" }, { name = "cobamp", git = "https://github.com/JoshLoecker/cobamp?rev=master" }, { name = "cobra", specifier = ">=0.28.0" }, { name = "gurobipy", marker = "extra == 'gurobi'", specifier = "<14" }, @@ -526,7 +526,7 @@ requires-dist = [ { name = "statsmodels", marker = "python_full_version < '3.12'", specifier = ">=0.13.0" }, { name = "statsmodels", marker = "python_full_version >= '3.12'", specifier = ">=0.14.0" }, { name = "troppo", git = "https://github.com/JoshLoecker/troppo?rev=master" }, - { name = "zfpkm", specifier = ">=1.0.3" }, + { name = "zfpkm", specifier = ">=1.1.0" }, ] provides-extras = ["gurobi", "interactive"] @@ -3697,7 +3697,7 @@ wheels = [ [[package]] name = "zfpkm" -version = "1.0.3" +version = "1.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "loguru" }, 
@@ -3705,9 +3705,9 @@ dependencies = [ { name = "numpy" }, { name = "pandas" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e3/7f/ff714f85601cd66439f2beed0d740772509b32b8be5a8b01a53652248714/zfpkm-1.0.3.tar.gz", hash = "sha256:58830ea61e6adc0c75f28d5304885bd03a33a6e9e56aa693856cbb37e30a5046", size = 15410, upload-time = "2025-11-10T16:47:45.614Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/bf/43471c26e47fc38b5f748723e3bfa104f77ee0352f5b6a242f8b01786ae4/zfpkm-1.1.0.tar.gz", hash = "sha256:c159f78703d9d853c5f8cbbc590fb9381f097329487947d4eaf95c72ab309870", size = 15623, upload-time = "2026-03-06T16:22:29.929Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/11/f8/ef2baeaf2d15682d5d663c3f5165b63abad114fc0c4cd90b67b1ed0a6456/zfpkm-1.0.3-py3-none-any.whl", hash = "sha256:085007f97e75e50d686677ee28e3fceba5fc19958b35e9fbad3756ca2302a219", size = 17841, upload-time = "2025-11-10T16:47:44.805Z" }, + { url = "https://files.pythonhosted.org/packages/85/a1/10cf3c5268131d37a4a968a1f5f935a07e81cb78dbb83e8bedc4a835c048/zfpkm-1.1.0-py3-none-any.whl", hash = "sha256:765d1785a22729adeb89732da01ba47abcbc9597c17c587e42176398a758b7a3", size = 18103, upload-time = "2026-03-06T16:22:29.104Z" }, ] [[package]]