"""Differential expression plots."""
from __future__ import annotations
from pathlib import Path
from typing import Union
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# [docs]  (looks like Sphinx "viewcode" link residue; commented out so it is not evaluated as code)
def plot_volcano(
    de_results: pd.DataFrame,
    *,
    highlight_genes: list[str] | None = None,
    highlight_label: str = "Highlighted",
    highlight_color: str = "gold",
    padj_threshold: float = 0.05,
    log2fc_threshold: float = 0.0,
    label_top_n: int = 15,
    output: Union[str, Path, None] = None,
    title: str | None = None,
) -> plt.Figure:
    """Generate a volcano plot from DE results.

    Each gene is drawn at ``(log2FC, -log10(pvalue))``. Genes with
    ``padj < padj_threshold`` are drawn larger with a black edge and are
    colored red when ``log2FC > log2fc_threshold``, blue when
    ``log2FC < -log2fc_threshold`` and grey otherwise. All remaining genes
    are drawn as small grey points underneath.

    An optional set of genes of interest can be highlighted with colored
    markers on top, useful for visualizing pathway genes (e.g. IFN/JAK-STAT)
    or custom gene lists.

    Parameters
    ----------
    de_results : pd.DataFrame
        Must contain columns: ``gene``, ``log2FC``, ``pvalue``, ``padj``.
        The rows do not need to be pre-sorted.
    highlight_genes : list[str], optional
        Genes to highlight with colored markers. Can be any gene list of
        interest (pathway genes, custom markers, etc.). Highlighted genes
        with ``pvalue < 0.1`` are also labeled by name.
    highlight_label : str
        Legend label for highlighted genes.
    highlight_color : str
        Color for highlighted gene markers.
    padj_threshold : float
        Significance threshold for coloring.
    log2fc_threshold : float
        Minimum absolute log2FC to color significant genes (default 0).
    label_top_n : int
        Number of genes with the smallest p-values to label. Genes already
        labeled as highlighted genes are not labeled a second time.
    output : str or Path, optional
        Save figure to this path.
    title : str, optional
        Figure title. When omitted, a title reporting the number of genes
        with ``padj < padj_threshold`` is generated.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    KeyError
        If ``de_results`` lacks any of the required columns.

    Notes
    -----
    P-values are floored at ``1e-10`` before the log transform, so the
    y-axis is capped at 10. The dashed horizontal guide marks a raw
    p-value of 0.05, independent of ``padj_threshold``.
    """
    # Validate before creating the figure so a bad input does not leave an
    # orphaned figure registered with pyplot.
    required = ("gene", "log2FC", "pvalue", "padj")
    missing = [col for col in required if col not in de_results.columns]
    if missing:
        raise KeyError(
            f"de_results is missing required column(s): {', '.join(missing)}"
        )

    p_floor = 1e-10  # avoids log10(0) and caps the y-axis at 10

    fig, ax = plt.subplots(figsize=(10, 8))

    log2fc = de_results["log2FC"]
    neg_log10p = -np.log10(de_results["pvalue"].clip(lower=p_floor))

    is_sig = de_results["padj"] < padj_threshold
    is_up = log2fc > log2fc_threshold
    is_down = log2fc < -log2fc_threshold
    colors = np.where(
        is_sig & is_up, "#d62728",
        np.where(is_sig & is_down, "#1f77b4", "#888888"),
    )

    # Plot non-significant first, then significant on top
    non_sig = ~is_sig
    if non_sig.any():
        ax.scatter(
            log2fc[non_sig], neg_log10p[non_sig],
            c="#888888", s=10, alpha=0.4, zorder=2,
        )
    if is_sig.any():
        ax.scatter(
            log2fc[is_sig], neg_log10p[is_sig],
            # `colors` is a plain ndarray, so index it positionally.
            c=colors[is_sig.to_numpy()], s=25, alpha=0.8, zorder=3,
            edgecolors="black", linewidth=0.3,
        )

    # Highlight gene set of interest
    labeled: set = set()
    if highlight_genes:
        hl = de_results[de_results["gene"].isin(highlight_genes)]
        if not hl.empty:
            ax.scatter(
                hl["log2FC"], -np.log10(hl["pvalue"].clip(lower=p_floor)),
                c=highlight_color, s=45,
                edgecolors="black", linewidth=0.5, zorder=5,
                label=f"{highlight_label} ({len(hl)} genes)",
            )
            # Label highlighted genes that are near-significant or significant
            near_sig = hl[hl["pvalue"] < 0.1]
            for gene, x, p in zip(
                near_sig["gene"], near_sig["log2FC"], near_sig["pvalue"]
            ):
                ax.annotate(
                    gene,
                    (x, -np.log10(max(p, p_floor))),
                    fontsize=7, alpha=0.9, color="darkgoldenrod",
                    fontweight="bold",
                )
            # Remember these so the top-N pass below does not label them twice.
            labeled = set(near_sig["gene"])

    # Label top genes by p-value. Sort explicitly instead of trusting the
    # caller's row order; the stable sort keeps the original order among ties,
    # so already-sorted input is labeled exactly as before. Rows without a
    # finite position cannot be placed on the plot and are skipped.
    top = (
        de_results.dropna(subset=["pvalue", "log2FC"])
        .sort_values("pvalue", kind="mergesort")
        .head(label_top_n)
    )
    for gene, x, p in zip(top["gene"], top["log2FC"], top["pvalue"]):
        if gene not in labeled:
            ax.annotate(
                gene,
                (x, -np.log10(max(p, p_floor))),
                fontsize=7, alpha=0.8,
            )

    ax.axhline(-np.log10(0.05), color="grey", linestyle="--", alpha=0.5,
               label="p = 0.05")
    ax.axvline(0, color="grey", linestyle="-", alpha=0.3)
    ax.set_xlabel("log2 Fold Change")
    ax.set_ylabel("-log10(p-value)")
    if title:
        ax.set_title(title)
    else:
        n_sig = is_sig.sum()
        ax.set_title(f"Volcano plot ({n_sig} significant at padj < {padj_threshold})")
    ax.legend(fontsize=9)

    # Lay out this figure specifically rather than whichever one pyplot
    # currently considers active.
    fig.tight_layout()
    if output is not None:
        fig.savefig(output)
    return fig