
Visualizations (Viz)

valency-anndata methods

valency_anndata.viz.schematic_diagram

schematic_diagram(
    adata: Optional[AnnData] = None,
    *,
    diff_from: Optional[AnnData] | Literal[False] = False,
    filename: Optional[str] = None,
)

Render a schematic diagram of an AnnData object, optionally highlighting structural differences relative to a snapshot.

This function supports two usage modes: render mode and context-manager mode.

1. Render mode

Render a diagram of adata immediately.

Examples

val.viz.schematic_diagram(adata)

val.viz.schematic_diagram(adata, diff_from=None)

adata_snapshot = adata.copy()
val.tools.some_transformation(adata, inplace=True)

val.viz.schematic_diagram(adata, diff_from=adata_snapshot)

Behavior

  • Visualizes adata structure (X, obs, var, layers, obsm).
  • If diff_from is provided:
    • Highlights additions and removals relative to diff_from.
  • If diff_from is None:
    • Highlights all entries as additions (diff from empty AnnData).
  • If diff_from is False:
    • No diff highlighting is applied.
  • The diagram is displayed inline (notebooks) or in a browser (script).
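The three diff_from cases above reduce to choosing a baseline and then comparing keys slot by slot. Here is a minimal sketch of that logic. It assumes each AnnData is modelled as a plain dict of key sets and leaves X out. The names SLOTS, EMPTY, resolve_baseline and structural_diff are hypothetical and are not part of valency-anndata.

```python
# Hypothetical model: an "AnnData" is reduced to a dict mapping each slot
# name to the set of keys it holds (X is left out for brevity).
SLOTS = ("obs", "var", "layers", "obsm")
EMPTY = {slot: frozenset() for slot in SLOTS}


def resolve_baseline(diff_from):
    """Map the documented diff_from values onto a baseline (or None)."""
    if diff_from is False:
        return None  # False: no diff highlighting at all
    if diff_from is None:
        return EMPTY  # None: diff against an empty object, so every key is an addition
    return diff_from  # otherwise: diff against the given snapshot


def structural_diff(current, diff_from=False):
    """Return {slot: (added, removed)}, or None when diffing is disabled."""
    base = resolve_baseline(diff_from)
    if base is None:
        return None
    diff = {}
    for slot in SLOTS:
        now = set(current.get(slot, ()))
        before = set(base.get(slot, ()))
        diff[slot] = (sorted(now - before), sorted(before - now))
    return diff


snapshot = {"obs": {"n_votes"}, "var": {"content"}, "layers": set(), "obsm": set()}
current = {"obs": {"n_votes", "kmeans_polis"}, "var": set(), "layers": set(), "obsm": {"X_pca_polis"}}
print(structural_diff(current, diff_from=snapshot))
```

The real function compares actual AnnData objects and draws the result as an SVG. This sketch only illustrates the set arithmetic behind "additions and removals".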
2. Context-manager mode

Capture a snapshot on entering a with block, rendering a diff on exit.

Examples

with val.viz.schematic_diagram(diff_from=adata):
    val.tools.some_transformation(adata, inplace=True)

Behavior

  • diff_from must be provided; adata must be omitted.
  • On entry, a snapshot of diff_from is recorded.
  • On exit, a diff diagram between the snapshot and current adata is rendered.
  • Exceptions inside the with block prevent rendering.
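The context-manager behavior above follows Python's standard __enter__/__exit__ protocol. The sketch below shows that protocol; it is not the library's _SchematicDiagramContext. SnapshotDiffContext and its render callback are hypothetical, and copy.deepcopy on a plain dict stands in for taking a snapshot such as adata.copy().

```python
import copy


class SnapshotDiffContext:
    """Hypothetical stand-in for _SchematicDiagramContext (not the library's code)."""

    def __init__(self, obj, render):
        self.obj = obj        # the object being watched (an AnnData in the real API)
        self.render = render  # hypothetical callback: render(current, snapshot)
        self.snapshot = None

    def __enter__(self):
        # Deep-copy on entry so in-place edits inside the block don't touch the snapshot.
        self.snapshot = copy.deepcopy(self.obj)
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is None:
            # Only a block that completes normally gets rendered.
            self.render(self.obj, self.snapshot)
        return False  # propagate any exception raised inside the block


rendered = []
data = {"obs": ["n_votes"]}
with SnapshotDiffContext(data, lambda cur, snap: rendered.append((list(cur["obs"]), list(snap["obs"])))):
    data["obs"].append("kmeans_polis")  # an in-place change, like inplace=True
```

Returning False from __exit__ lets the exception propagate. This matches the documented rule that exceptions inside the with block prevent rendering.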

Parameters:

  • adata (Optional[AnnData], default: None): The AnnData object to visualize. Required in render mode; must be omitted in context-manager mode.
  • diff_from (Optional[AnnData] | Literal[False], default: False): Determines the snapshot to diff against. Must be an AnnData in context-manager mode.
    • AnnData instance: highlights differences from the snapshot.
    • None: highlights all entries as additions (diff from empty).
    • False: disables diff highlighting.
  • filename (Optional[str], default: None): Optional filename hint used in cell output metadata. It is used when generating the documentation website, where nbconvert extracts images from notebooks.

Returns:

  • None: in render mode, the diagram is displayed and nothing is returned.
  • _SchematicDiagramContext: in context-manager mode, a context manager for automatic diff rendering.

Notes
  • Explicit diff rendering always takes precedence over context-manager snapshots.
  • Snapshots are stored internally to allow nested diff scopes.
  • This function does not mutate adata.
Source code in src/valency_anndata/viz/schematic_diagram/__init__.py
def schematic_diagram(
    adata: Optional[AnnData] = None,
    *,
    diff_from: Optional[AnnData] | Literal[False] = False,
    filename: Optional[str] = None,
):
    """
    Render a schematic diagram of an AnnData object, optionally highlighting
    structural differences relative to a snapshot.

    This function supports two usage modes: **render mode** and **context-manager mode**.

    1\\. Render mode
    --------------
    Render a diagram of `adata` immediately.

    **Examples**

    ```py
    val.viz.schematic_diagram(adata)
    ```
    <img src="../../notebooks-autogenerated/notebook-assets/viz--schematic-diagrams-diff-simple.svg" width="50%">

    ```py
    val.viz.schematic_diagram(adata, diff_from=None)
    ```
    <img src="../../notebooks-autogenerated/notebook-assets/viz--schematic-diagrams-diff-new.svg" width="50%">

    ```py
    adata_snapshot = adata.copy()
    val.tools.some_transformation(adata, inplace=True)

    val.viz.schematic_diagram(adata, diff_from=adata_snapshot)
    ```
    <img src="../../notebooks-autogenerated/notebook-assets/viz--schematic-diagrams-diff-from.svg" width="50%">

    **Behavior**

    - Visualizes `adata` structure (`X`, `obs`, `var`, `layers`, `obsm`).
    - If `diff_from` is provided:
        - Highlights additions and removals relative to `diff_from`.
    - If `diff_from` is `None`:
        - Highlights all entries as additions (diff from empty AnnData).
    - If `diff_from` is `False`:
        - No diff highlighting is applied.
    - The diagram is displayed inline (notebooks) or in a browser (script).

    2\\. Context-manager mode
    -----------------------
    Capture a snapshot on entering a `with` block, rendering a diff on exit.

    **Examples**

    ```py
    with val.viz.schematic_diagram(diff_from=adata):
        val.tools.some_transformation(adata, inplace=True)
    ```
    <img src="../../notebooks-autogenerated/notebook-assets/viz--schematic-diagrams-diff-context.svg" width="50%">

    **Behavior**

    - `diff_from` must be provided; `adata` must be omitted.
    - On entry, a snapshot of `diff_from` is recorded.
    - On exit, a diff diagram between the snapshot and current `adata` is rendered.
    - Exceptions inside the `with` block prevent rendering.

    Parameters
    ----------
    adata :
        The AnnData object to visualize (required in render mode, must be omitted in
        context-manager mode).
    diff_from :
        Determines the snapshot to diff against: (must be AnnData in context-manager mode)
        - `AnnData` instance: highlights differences from the snapshot.
        - `None`: highlights all entries as additions (diff from empty).
        - `False`: disables diff highlighting.
    filename :
        Optional filename hint used in cell output metadata. It is used when
        generating the documentation website, where nbconvert extracts
        images from notebooks.

    Returns
    -------
    None
        In render mode, the diagram is displayed; nothing is returned.
    _SchematicDiagramContext
        In context-manager mode, a context manager for automatic diff rendering.

    Notes
    -----
    - Explicit diff rendering always takes precedence over context-manager snapshots.
    - Snapshots are stored internally to allow nested diff scopes.
    - This function does not mutate `adata`.
    """
    if adata is None:
        # ------------------
        # Context mode
        # ------------------
        if isinstance(diff_from, AnnData):
            return _SchematicDiagramContext(diff_from, filename=filename)
    else:
        # ------------------
        # Render mode (explicit or implicit diff)
        # ------------------
        if diff_from is False:
            base = None
        elif diff_from is None:
            base = AnnData()
        else:
            base = diff_from

        dwg = adata_structure_svg(adata, diff_from=base)
        _show_svg(dwg, filename=filename)
        return None

    raise TypeError("Invalid schematic_diagram() call")

valency_anndata.viz.voter_vignette_browser

voter_vignette_browser(adata: AnnData) -> None

Interactive browser for quickly surveying many voting timelines of random participants alongside statements they authored.

Parameters:

  • adata (AnnData, required): An AnnData object loaded from a Polis conversation (see Assumptions below).

Assumptions
  • Votes are stored in adata.uns["votes"] with columns:

    • voter-id
    • vote (-1, 0, 1)
    • timestamp (seconds since epoch)
  • Statements are stored in adata.var with columns:

    • participant_id_authored
    • created_date (milliseconds since epoch)
    • content
    • moderation_state (optional, -1/0/1)
Behavior
  • Renders a dropdown to select a user, with buttons for random voter or commenter.
  • Plots votes over time, colored by vote (red = disagree, gold = pass, green = agree).
  • Draws vertical bars for authored statements with moderation-state coloring.
  • Displays statements below the plot in submission order.
  • Warns if vote or statement timestamps appear out of expected ranges.
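The Assumptions above mix two time units: vote timestamps are in seconds and statement created_date values are in milliseconds. The source below warns when the median value looks wrong for the expected unit. Here is a stdlib-only sketch of the same checks. The helper names are hypothetical; the thresholds 1e11 and 1e14 are taken from the source.

```python
import statistics
from datetime import datetime, timezone

# Thresholds copied from the heuristics in the source below.
SECONDS_MAX_MEDIAN = 1e11
MILLIS_RANGE = (1e11, 1e14)


def votes_look_like_seconds(timestamps):
    """Vote timestamps should be seconds since epoch (median at most 1e11)."""
    return statistics.median(timestamps) <= SECONDS_MAX_MEDIAN


def statements_look_like_millis(created_dates):
    """Statement created_date values should be milliseconds since epoch."""
    lo, hi = MILLIS_RANGE
    return lo <= statistics.median(created_dates) <= hi


def to_utc(value, unit):
    """Convert an epoch value in 's' or 'ms' to a timezone-aware UTC datetime."""
    divisor = {"s": 1, "ms": 1000}[unit]
    return datetime.fromtimestamp(value / divisor, tz=timezone.utc)
```

The pd.to_datetime calls in the source produce naive timestamps. This sketch returns timezone-aware UTC datetimes instead.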

Examples:

adata = val.datasets.polis.load("https://pol.is/report/r29kkytnipymd3exbynkd", translate_to="en")

val.viz.voter_vignette_browser(adata)

Source code in src/valency_anndata/viz/_voter_vignette.py
def voter_vignette_browser(adata: AnnData) -> None:
    """
    Interactive browser for quickly surveying many voting timelines of random
    participants alongside statements they authored.

    Parameters
    ----------
    adata:
        An AnnData object loaded from a Polis conversation.<br/>
        (See Assumptions below)

    Assumptions
    -----------

    - Votes are stored in `adata.uns["votes"]` with columns:
        - `voter-id`
        - `vote` (-1, 0, 1)
        - `timestamp` (seconds since epoch)

    - Statements are stored in `adata.var` with columns:
        - `participant_id_authored`
        - `created_date` (milliseconds since epoch)
        - `content`
        - `moderation_state` (optional, -1/0/1)

    Behavior
    --------

    - Renders a dropdown to select a user, with buttons for random voter or commenter.
    - Plots votes over time, colored by vote (red = disagree, gold = pass, green = agree).
    - Draws vertical bars for authored statements with moderation-state coloring.
    - Displays statements below the plot in submission order.
    - Warns if vote or statement timestamps appear out of expected ranges.

    Examples
    --------

    ```py
    adata = val.datasets.polis.load("https://pol.is/report/r29kkytnipymd3exbynkd", translate_to="en")

    val.viz.voter_vignette_browser(adata)
    ```
    <img src="../../assets/documentation-examples/viz--voter-vignette-browser.png">
    """
    import random
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from ipywidgets import widgets
    from IPython.display import display, Markdown
    import warnings

    # -----------------------------
    # Prepare votes dataframe
    # -----------------------------
    votes_df = adata.uns["votes"].copy()
    votes_df["voter-id"] = votes_df["voter-id"].astype(str)

    # Heuristic check: votes should be seconds, median ~1e9–1e10
    votes_median = votes_df["timestamp"].median()
    if votes_median > 1e11:  # looks too large for seconds
        warnings.warn(
            f"Median timestamp in votes is {votes_median}, which seems too large. "
            "Expected seconds. If these are milliseconds, divide by 1000."
        )

    votes_df["timestamp"] = pd.to_datetime(votes_df["timestamp"], unit="s")

    # -----------------------------
    # Core plotting function
    # -----------------------------
    def plot_user_activity(user_id: str):
        user_id = str(user_id)

        # --- Votes ---
        user_votes = votes_df[votes_df["voter-id"] == user_id]
        n_votes = len(user_votes)

        if user_votes.empty:
            first_vote = last_vote = None
            delta = pd.Timedelta(0)
        else:
            first_vote = user_votes["timestamp"].min()
            last_vote = user_votes["timestamp"].max()
            delta = last_vote - first_vote

        # Adaptive duration string
        if delta < pd.Timedelta(hours=1):
            duration_str = f"{delta.total_seconds()/60:.1f} minutes"
        elif delta < pd.Timedelta(days=1):
            duration_str = f"{delta.total_seconds()/3600:.1f} hours"
        else:
            duration_str = f"{delta.days} days"

        plt.figure(figsize=(12, 4))

        if not user_votes.empty:
            vote_colors = {-1: "red", 0: "gold", 1: "green"}
            colors = user_votes["vote"].map(vote_colors)
            plt.scatter(user_votes["timestamp"], user_votes["vote"], c=colors, s=50)

        plt.yticks([-1, 0, 1], ["Disagree", "Pass", "Agree"])
        plt.xlabel("Time")
        plt.ylabel("Vote")

        # --- Statements ---
        statements = adata.var
        user_statements = statements[
            statements["participant_id_authored"].astype(str) == user_id
        ]
        n_statements = len(user_statements)

        if not user_statements.empty:
            created_ms = pd.to_numeric(
                user_statements["created_date"], errors="coerce"
            )

            # Heuristic check: statements in milliseconds
            statements_median = created_ms.median()
            if statements_median < 1e11 or statements_median > 1e14:
                warnings.warn(
                    f"Median created_date in statements is {statements_median}. "
                    "Expected milliseconds."
                )

            statement_times = pd.to_datetime(created_ms, unit="ms")

            # Map moderation_state for plotting vertical bars
            mod_colors = {1: "green", 0: "gray", -1: "red"}
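            if "moderation_state" in user_statements:
                moderation_states = (
                    user_statements["moderation_state"].fillna(0).astype(int)
                )
            else:
                # The column is optional; treat missing states as 0 (gray bars).
                moderation_states = pd.Series(0, index=user_statements.index)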

            for t, mod in zip(statement_times, moderation_states):
                plt.axvline(x=t, color=mod_colors.get(mod, "gray"), lw=2, alpha=0.7)

        # --- Legend proxies ---
        vote_proxy = plt.Line2D(
            [0], [0],
            marker="o",
            color="black",
            markersize=8,
            linestyle="None",
            label=f"Votes ({n_votes})"
        )
        statement_proxy = plt.Line2D(
            [0], [0],
            color="black",
            lw=2,
            label=f"Statements ({n_statements})"
        )
        plt.legend(handles=[vote_proxy, statement_proxy], loc="center left", bbox_to_anchor=(1, 0.5))

        # --- Title ---
        if not user_votes.empty:
            plt.title(
                f"User {user_id} activity | {first_vote} → {last_vote} ({duration_str})"
            )
        else:
            plt.title(f"User {user_id} activity | No votes")

        plt.ylim(-1.5, 1.5)
        plt.tight_layout()
        plt.show()

        # --- Statements text ---
        if not user_statements.empty:
            user_statements_sorted = (
                user_statements.assign(
                    created_dt=pd.to_datetime(
                        pd.to_numeric(user_statements["created_date"], errors="coerce"),
                        unit="ms"
                    )
                )
                .sort_values("created_dt")
            )

            md_text = f"**Statements by {user_id} in submission order:**\n\n"
            for t, s in zip(user_statements_sorted["created_dt"], user_statements_sorted["content"]):
                md_text += f"- {t}: {s}\n"
            display(Markdown(md_text))

    # -----------------------------
    # User selection widgets
    # -----------------------------
    all_voters = votes_df["voter-id"].unique()
    all_commenters = adata.var["participant_id_authored"].dropna().astype(str).unique()
    all_users = np.unique(np.concatenate([all_voters, all_commenters]))
    initial_user = random.choice(all_commenters.tolist())

    user_dropdown = widgets.Dropdown(
        options=sorted(all_users),
        value=initial_user,
        description="User ID:"
    )

    random_voter_btn = widgets.Button(description="Random voter")
    random_commenter_btn = widgets.Button(description="Random commenter")

    def pick_random_voter(_):
        user_dropdown.value = random.choice(all_voters)

    def pick_random_commenter(_):
        user_dropdown.value = random.choice(all_commenters)

    random_voter_btn.on_click(pick_random_voter)
    random_commenter_btn.on_click(pick_random_commenter)

    display(
        widgets.VBox([
            widgets.HBox([user_dropdown, random_voter_btn, random_commenter_btn]),
            widgets.interactive_output(plot_user_activity, {"user_id": user_dropdown})
        ])
    )

valency_anndata.viz.jscatter

jscatter(
    adata: AnnData,
    use_reps: list[str] = [],
    color: str | Iterable[str] | None = None,
    height: int = 640,
    dark_mode: bool = True,
    nrows: Optional[int] = None,
    ncols: Optional[int] = None,
    return_objs: bool = False,
) -> list[Scatter] | None

Interactive Jupyter-Scatter view showing one or more embeddings. [Lekschas et al., 2024]

A button is created for each projected representation; clicking it animates the points into that projection.

Passing multiple color keys displays multiple linked views, one per key.
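Two small steps in the source below determine what the widget shows. First, each .obsm key is shortened to a prefix that labels its toggle button and names the x/y columns. Second, color is normalized into a list, with one linked view per key. Here is a hypothetical stdlib-only sketch of those steps. The function names are illustrative, and the "<prefix>1"/"<prefix>2" column naming is inferred from the x=/y= arguments in the source.

```python
from typing import Iterable, Union


def projection_prefixes(use_reps: list[str]) -> list[tuple[str, str]]:
    """Pair each .obsm key with the short prefix used for its toggle button."""
    if not use_reps:
        raise ValueError("use_reps must contain at least one .obsm key")
    # "X_pca_polis" -> "pca", "X_localmap" -> "localmap"
    return [(key, key.removeprefix("X_").split("_")[0]) for key in use_reps]


def xy_columns(prefix: str) -> tuple[str, str]:
    """Column names the scatter reads x/y from, e.g. ("pca1", "pca2")."""
    return f"{prefix}1", f"{prefix}2"


def normalize_colors(color: Union[str, Iterable[str], None]) -> list[str]:
    """One entry per linked view; a bare string is one key, not a sequence of characters."""
    if color is None:
        return []
    if isinstance(color, str):
        return [color]
    return list(color)
```

Because the prefix is the first underscore-separated token, keys such as X_pca and X_pca_polis would both map to pca and collide.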

Parameters:

  • adata (AnnData, required): An AnnData object with some projected representations stored in .obsm.
  • use_reps (list[str], default: []): One or more keys for projected representations of the data stored in .obsm. At least one key is required.
  • color (str | Iterable[str] | None, default: None): One or more keys in .obs for coloring each participant. Categorical values use a discrete color map (okabeito); anything else uses a continuous gradient (viridis).
  • height (int, default: 640): Pixel height of the scatter widget in the output cell.
  • dark_mode (bool, default: True): Whether to set the plot background dark.
  • nrows (Optional[int], default: None): Number of rows to display the scatter plots in.
  • ncols (Optional[int], default: None): Number of columns to display the scatter plots in.
  • return_objs (bool, default: False): Whether to return the Scatter object(s).

Returns:

  • scatters (list[Scatter] | None): A list of Scatter instances if return_objs is True; otherwise None.

Examples:

Plotting multiple representations in one view, colored with discrete categorical values.

val.viz.jscatter(
    adata,
    use_reps=["X_pca_polis", "X_localmap"],
    color="kmeans_polis",
)

Plotting multiple .obs keys across multiple views, colored with continuous values.

val.viz.jscatter(
    adata,
    use_reps=["X_pca_polis", "X_pacmap"],
    color=["n_votes", "pct_agree", "pct_pass", "pct_disagree"],
    height=320,
)

Source code in src/valency_anndata/viz/_jupyter_scatter.py
def jscatter(
    adata: AnnData,
    use_reps: list[str] = [],
    color: str | Iterable[str] | None = None,
    height: int = 640,
    dark_mode: bool = True,
    nrows: Optional[int] = None,
    ncols: Optional[int] = None,
    return_objs: bool = False,
) -> list[JScatter] | None:
    """
    Interactive Jupyter-Scatter view showing one or more embeddings. [[Lekschas _et al._, 2024](https://doi.org/10.21105/joss.07059)]

    A button is created for each projected representation; clicking it
    animates the points into that projection.

    Passing multiple color keys displays multiple linked views, one per key.

    Parameters
    ----------

    adata :
        An AnnData object with some projected representations stored in
        [`.obsm`][anndata.AnnData.obsm].
    use_reps :
        One or more keys for projected representations of the data stored in
        [`.obsm`][anndata.AnnData.obsm].
    color :
        One or more keys in [`.obs`][anndata.AnnData.obs] for coloring each participant.
        Categorical values will use a discrete color map
        ([`okabeito`](https://cmap-docs.readthedocs.io/en/latest/catalog/qualitative/okabeito:okabeito/)),
        and anything else will use a continuous gradient
        ([`viridis`](https://cmap-docs.readthedocs.io/en/latest/catalog/sequential/bids:viridis/)).
    height :
        Pixel height of the scatter widget in output cell.
    dark_mode :
        Whether to set the plot background dark.
    nrows :
        Number of rows to display the scatter plots in.
    ncols :
        Number of columns to display the scatter plots in.
    return_objs :
        Whether to return the Scatter object(s).

    Returns
    -------

    scatters :
        A list of [`Scatter`](https://jupyter-scatter.dev/api#scatter) instances
        if `return_objs` is True; otherwise None.

    Examples
    --------

    Plotting multiple representations in one view, colored with discrete categorical values.

    ```py
    val.viz.jscatter(
        adata,
        use_reps=["X_pca_polis", "X_localmap"],
        color="kmeans_polis",
    )
    ```

    <img src="../../assets/documentation-examples/viz--jscatter--single.png">

    Plotting multiple `.obs` keys across multiple views, colored with continuous values.

    ```py
    val.viz.jscatter(
        adata,
        use_reps=["X_pca_polis", "X_pacmap"],
        color=["n_votes", "pct_agree", "pct_pass", "pct_disagree"],
        height=320,
    )
    ```

    <img src="../../assets/documentation-examples/viz--jscatter--multi.png">
    """
    background = "#1E1E20" if dark_mode else None

    # ---- prepare projections ----
    if not use_reps:
        raise ValueError("use_reps must name at least one projection in adata.obsm")

    projections = [
        (key, key.removeprefix("X_").split("_")[0])
        for key in use_reps
    ]

    if color is None:
        colors = []
    elif isinstance(color, str):
        colors = [color]
    else:
        colors = list(color)

    obs_cols = colors if colors else None

    df = obsm_to_df(
        adata,
        projections=projections,
        obs_cols=obs_cols,
    )

    # ---- create scatter(s) ----
    default_prefix = projections[0][1]

    scatters = []

    for c in colors or [None]:
        scatter = JScatter(
            data=df,
            x=f"{default_prefix}1",
            y=f"{default_prefix}2",
            height=height,
            zoom_on_selection=True,
        )

        if c is not None:
            scatter.color(by=c)

        scatter.background(background)

        scatters.append(scatter)

    # ---- projection toggle ----
    toggle = widgets.ToggleButtons(
        options=[
            (prefix.upper(), prefix)
            for _, prefix in projections
        ],
        value=default_prefix,
        description="Projection:",
    )

    def on_toggle(change):
        prefix = change["new"]
        for s in scatters:
            s.xy(f"{prefix}1", f"{prefix}2")

    toggle.observe(on_toggle, names="value")

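    # Title each linked view with its color key. With no color keys, zip()
    # would produce an empty list, so pass the single scatter on its own.
    views = list(zip(scatters, colors)) if colors else scatters
    grid = compose(
        views,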
        sync_view=True,
        sync_selection=True,
        sync_hover=True,
        cols=ncols,
        rows=nrows,
        row_height=height,
    )

    display(toggle)
    display(grid)

    return scatters if return_objs else None

valency_anndata.viz.langevitour

langevitour(
    adata: AnnData,
    *,
    use_reps: Optional[Sequence[str]] = None,
    color: Optional[str] = None,
    scale: Optional[str] = None,
    initial_axes: Optional[list[str]] = None,
    point_size: int = 2,
    **kwargs,
)

Interactive Langevitour visualization over one or more representations. [Harrison, 2022]

Parameters:

  • adata (AnnData, required): AnnData object.
  • use_reps (Optional[Sequence[str]], default: None): Representations to include. X_foo includes every dimension of that representation; X_bar[:10] includes a subset (here, the first 10 dimensions). For example, ["X_pca[:10]", "X_umap"].
  • color (Optional[str], default: None): .obs column for grouping/coloring.
  • scale (Optional[str], default: None): .obs column for point scaling.
  • initial_axes (Optional[list[str]], default: None): Up to 3 axes, initially locked in place along the X, Y and Z axes (they can be moved afterwards). Each must be specified with an exact index, not a range. For example, ["X_umap[0]", "X_umap[1]"] or ["X_pca[0]", "X_pca[1]", "X_pca[2]"].
  • point_size (int, default: 2): Base point size.
  • **kwargs (default: {}): Passed through to Langevitour. See the R docs: https://logarithmic.net/langevitour/reference/langevitour.html

Examples:

val.viz.langevitour(
    adata,
    use_reps=["X_umap", "X_pca[:10]"],
    color="leiden",
    initial_axes=["X_umap[0]", "X_umap[1]"],
)
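use_reps and initial_axes share a small indexing syntax: a bare key, an exact index such as X_umap[0], or a start:stop slice such as X_pca[:10]. The source delegates parsing to a helper, parse_rep, whose code is not shown here. The regex-based parser below illustrates the syntax. parse_rep_string is a hypothetical name, not the library's function, and step slices are not supported.

```python
import re
from typing import Union

# "<key>" or "<key>[<selector>]", e.g. "X_umap", "X_umap[0]", "X_pca[:10]"
_REP_PATTERN = re.compile(r"^(?P<key>[^\[\]]+)(?:\[(?P<sel>[^\[\]]*)\])?$")


def parse_rep_string(rep: str) -> tuple[str, Union[int, slice, None]]:
    """Split a representation string into its .obsm key and a 0-based selector."""
    match = _REP_PATTERN.match(rep)
    if match is None:
        raise ValueError(f"Cannot parse representation string: {rep!r}")
    key, sel = match.group("key"), match.group("sel")
    if sel is None:
        return key, None  # the whole representation
    if ":" in sel:
        start, _, stop = sel.partition(":")  # only start:stop, no step
        return key, slice(int(start) if start else None, int(stop) if stop else None)
    return key, int(sel)  # an exact index, as initial_axes requires
```

Indices in this syntax are 0-based. The source converts them to 1-based column labels with format_rep_column(key, dim + 1).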

Source code in src/valency_anndata/viz/_langevitour.py
def langevitour(
    adata: AnnData,
    *,
    use_reps: Optional[Sequence[str]] = None,
    color: Optional[str] = None,
    scale: Optional[str] = None,
    initial_axes: Optional[list[str]] = None,
    point_size: int = 2,
    **kwargs,
):
    """
    Interactive Langevitour visualization over one or more representations. [[Harrison, 2022](https://doi.org/10.32614/RJ-2023-046)]

    Parameters
    ----------
    adata
        AnnData object.
    use_reps
        Representations to include: `X_foo` includes every dimension, and `X_bar[:10]` includes a subset (here, the first 10).

        e.g. `["X_pca[:10]", "X_umap"]`.
    color
        obs column for grouping / coloring.
    scale
        obs column for point scaling.
    initial_axes
        Up to 3 axes, initially locked in place along the X, Y and Z axes (they can be moved afterwards). Each must be specified with an exact index, not a range.

        e.g. `["X_umap[0]", "X_umap[1]"]` or `["X_pca[0]", "X_pca[1]", "X_pca[2]"]`
    point_size
        Base point size.
    **kwargs
        Passed through to `Langevitour`.
        See R docs: [https://logarithmic.net/langevitour/reference/langevitour.html](https://logarithmic.net/langevitour/reference/langevitour.html#arguments)

    Examples
    --------

    ```py
    val.viz.langevitour(
        adata,
        use_reps=["X_umap", "X_pca[:10]"],
        color="leiden",
        initial_axes=["X_umap[0]", "X_umap[1]"],
    )
    ```
    <img src="../../assets/documentation-examples/viz--langevitour--axis-gradient.png">
    """
    import warnings

    with warnings.catch_warnings():
        # Prevent setuptools from showing a warning about
        # Langevitour using `import pkg_resources`.
        warnings.filterwarnings(
            "ignore",
            message="pkg_resources is deprecated as an API",
            category=UserWarning,
        )
        from langevitour import Langevitour

    X_df = resolve_use_reps(adata, use_reps)

    group = adata.obs[color].tolist() if color is not None else None

    state = {}
    if initial_axes:
        # default label positions for the pseudo-X, pseudo-Y and pseudo-Z axes
        default_positions = [
            [0.85, 0],  # pseudo-X axis
            [0, 0.95],  # pseudo-Y axis
            [0.6, 0.6], # pseudo-Z axis
        ]

        labelPos = {}
        for i, rep_str in enumerate(initial_axes):
            if i >= 3:
                break  # only support 3 initial axes
            key, dim, _ = parse_rep(rep_str)
            if dim is None:
                dim = 0  # default to first dimension if not specified
            col_name = format_rep_column(key, dim + 1)
            labelPos[col_name] = default_positions[i]

        state["labelPos"] = labelPos

    if scale is None:
        s = X_df.std() * 4
        scale_factors = [s] if isinstance(s, (float, int)) else s.tolist()
    else:
        scale_factors = scale

    return Langevitour(
        X_df,
        group=group,
        scale=scale_factors,
        point_size=point_size,
        state=state,
        **kwargs,
    )

valency_anndata.viz.highly_variable_statements

highly_variable_statements(
    adata: AnnData,
    *,
    key: str = "highly_variable",
    log: bool = False,
    show: bool | None = None,
    save: str | None = None,
) -> None

Plot normalized and raw dispersions for statements identified as highly variable.

Analogous to scanpy.pl.highly_variable_genes for single-cell data. Creates a two-panel scatter plot showing normalized dispersion (left) and raw dispersion (right) against the binning variable used in val.preprocessing.highly_variable_statements. Highly variable statements are highlighted in black, others in grey.
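The left panel plots dispersions_norm, which the source describes as a z-score of each statement's dispersion within its bin of the bin_by variable (coverage by default). The sketch below shows within-bin z-scoring, assuming equal-width bins and the sample standard deviation. The source does not show how val.preprocessing.highly_variable_statements computes dispersions or chooses bins, so treat the bin count and edge-case handling here as assumptions.

```python
from statistics import mean, stdev


def normalize_within_bins(bin_values, dispersions, n_bins=10):
    """Z-score each dispersion against the other statements in its bin."""
    lo, hi = min(bin_values), max(bin_values)
    width = (hi - lo) / n_bins or 1.0  # avoid dividing by zero when all values are equal
    bins = [min(int((v - lo) / width), n_bins - 1) for v in bin_values]

    normalized = [0.0] * len(dispersions)
    for b in set(bins):
        members = [i for i, bb in enumerate(bins) if bb == b]
        values = [dispersions[i] for i in members]
        mu = mean(values)
        sd = stdev(values) if len(values) > 1 else 0.0
        for i in members:
            # Singleton bins and zero-spread bins get 0.0 in this sketch.
            normalized[i] = (dispersions[i] - mu) / sd if sd > 0 else 0.0
    return normalized


coverage = [0.1, 0.2, 0.9, 1.0]
dispersions = [1.0, 3.0, 10.0, 20.0]
print(normalize_within_bins(coverage, dispersions, n_bins=2))
```

Each bin is centered and scaled separately, so the normalized value compares a statement only with statements of similar coverage.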

Parameters:

  • adata (AnnData, required): AnnData object that has been processed with val.preprocessing.highly_variable_statements.
  • key (str, default: 'highly_variable'): Key in adata.var and adata.uns where the highly variable results are stored. Must match the key_added parameter used in preprocessing.
  • log (bool, default: False): If True, use a log scale for both axes.
  • show (bool | None, default: None): If True, display the plot. If None, defaults to Scanpy's settings.autoshow.
  • save (str | None, default: None): If given, save the figure via Scanpy's savefig_or_show. As with other Scanpy plotting functions, the string is appended to the default file name ("highly_variable_statements") and the file is written to settings.figdir.

Examples:

import valency_anndata as val
adata = val.datasets.aufstehen()
val.preprocessing.highly_variable_statements(adata, n_top_statements=50)
val.viz.highly_variable_statements(adata)

Use log scale for better visibility:

val.viz.highly_variable_statements(adata, log=True)

Plot results from a custom key:

val.preprocessing.highly_variable_statements(
    adata,
    n_top_statements=100,
    key_added="highly_variable_top100"
)
val.viz.highly_variable_statements(adata, key="highly_variable_top100")
Source code in src/valency_anndata/viz/_highly_variable_statements.py
def highly_variable_statements(
    adata: AnnData,
    *,
    key: str = "highly_variable",
    log: bool = False,
    show: bool | None = None,
    save: str | None = None,
) -> None:
    """
    Plot normalized and raw dispersions for statements identified as highly variable.

    Analogous to [scanpy.pl.highly_variable_genes][] for single-cell data. Creates a
    two-panel scatter plot showing normalized dispersion (left) and raw dispersion (right)
    against the binning variable used in `val.preprocessing.highly_variable_statements`.
    Highly variable statements are highlighted in black, others in grey.

    Parameters
    ----------
    adata
        AnnData object that has been processed with
        `val.preprocessing.highly_variable_statements`.
    key
        Key in `adata.var` and `adata.uns` where highly variable results are stored.
        Must match the `key_added` parameter used in preprocessing. Default is "highly_variable".
    log
        If True, use log scale for both axes. Default is False.
    show
        If True, display the plot. If None, defaults to Scanpy's `settings.autoshow`.
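    save
        If given, save the figure via Scanpy's `savefig_or_show`. As with other
        Scanpy plotting functions, the string is appended to the default file
        name ("highly_variable_statements") and written to `settings.figdir`.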

    Examples
    --------
    ```py
    import valency_anndata as val
    adata = val.datasets.aufstehen()
    val.preprocessing.highly_variable_statements(adata, n_top_statements=50)
    val.viz.highly_variable_statements(adata)
    ```

    Use log scale for better visibility:

    ```py
    val.viz.highly_variable_statements(adata, log=True)
    ```

    Plot results from a custom key:

    ```py
    val.preprocessing.highly_variable_statements(
        adata,
        n_top_statements=100,
        key_added="highly_variable_top100"
    )
    val.viz.highly_variable_statements(adata, key="highly_variable_top100")
    ```
    """

    if key not in adata.uns:
        raise ValueError(
            f"No highly variable statement metadata found under key '{key}'. "
            f"Run `val.preprocessing.highly_variable_statements(adata, key_added='{key}')` first."
        )

    result = adata.var
    hv_meta = adata.uns[key]

    # Which statements are marked highly variable
    statement_subset = result[key].values

    # x-axis values: the column used for binning (`bin_by`, default "coverage")
    means = result[hv_meta.get("bin_by", "coverage")].values

    # dispersions (raw) and dispersions_norm (z-score within bins)
    dispersions = result["dispersions"].values
    dispersions_norm = result["dispersions_norm"].values

    # Setup figure
    size = rcParams["figure.figsize"]
    plt.figure(figsize=(2 * size[0], size[1]))
    plt.subplots_adjust(wspace=0.3)

    for idx, d in enumerate([dispersions_norm, dispersions]):
        plt.subplot(1, 2, idx + 1)
        for label, color, mask in zip(
            ["highly variable statements", "other statements"],
            ["black", "grey"],
            [statement_subset, ~statement_subset],
        ):
            x = means[mask]
            y = d[mask]
            plt.scatter(x, y, label=label, c=color, s=5)

        if log:
            plt.xscale("log")
            plt.yscale("log")
            y_min = np.nanmin(d)
            y_min = 0.95 * y_min if y_min > 0 else 1e-1
            plt.xlim(0.95 * np.nanmin(means), 1.05 * np.nanmax(means))
            plt.ylim(y_min, 1.05 * np.nanmax(d))

        if idx == 0:
            plt.legend()
        plt.xlabel(f"{hv_meta.get('bin_by', 'coverage')}")
        plt.ylabel(f"{'normalized ' if idx == 0 else ''}dispersion")

    # determine whether to show (default Scanpy behavior)
    show = settings.autoshow if show is None else show
    savefig_or_show("highly_variable_statements", show=show, save=save)

scanpy methods (inherited)

Note

These methods are quick convenience wrappers around methods in scanpy, a toolkit for single-cell gene expression analysis. Their documentation uses terms like "cells", "genes" and "counts"; read these as "participants", "statements" and "votes".

See scanpy.pl for more methods you can experiment with via the val.scanpy.pl namespace.

valency_anndata.viz.pca

pca(
    adata: AnnData,
    *,
    annotate_var_explained: bool = False,
    show: bool | None = None,
    return_fig: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> Figure | Axes | list[Axes] | None

Scatter plot in PCA coordinates.

Use the parameter annotate_var_explained to annotate the explained variance.
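With annotate_var_explained=True, the source below rewrites each bare "PCn" axis label to include that component's share of variance, read from adata.uns["pca"]["variance_ratio"]. Here is a minimal, self-contained sketch of the same relabeling on a plain matplotlib figure. The helper names are hypothetical, and a plain list stands in for the stored variance ratios:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def pc_axis_labels(variance_ratio):
    """Map 'PC1', 'PC2', ... to labels that include the variance percentage."""
    return {
        f"PC{i + 1}": f"PC{i + 1} ({round(v * 100, 2)}%)"
        for i, v in enumerate(variance_ratio)
    }


def relabel_pc_axes(fig, variance_ratio):
    """Relabel every axis of `fig` whose label is a bare 'PCn'."""
    labels = pc_axis_labels(variance_ratio)
    for ax in fig.axes:
        # Only replace labels we recognise; leave anything else untouched.
        if new := labels.get(ax.get_xlabel()):
            ax.set_xlabel(new)
        if new := labels.get(ax.get_ylabel()):
            ax.set_ylabel(new)
    return fig


fig, ax = plt.subplots()
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
relabel_pc_axes(fig, [0.5, 0.125])
print(ax.get_xlabel(), "|", ax.get_ylabel())  # PC1 (50.0%) | PC2 (12.5%)
```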

Parameters:

Name Type Description Default
annotate_var_explained bool

Whether to append each principal component's percentage of explained variance to its axis label.

False

Returns:

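Type Description
Figure | Axes | list[Axes] | None

A matplotlib Figure if return_fig=True. Otherwise, None when the figure is shown (show=True, or show=None with settings.autoshow enabled), or the Axes (a list of Axes for multi-panel plots) when show=False.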

Examples:

import scanpy as sc
adata = sc.datasets.pbmc3k_processed()
sc.pl.pca(adata)

Colour points by a discrete variable (Louvain clusters).

sc.pl.pca(adata, color="louvain")

Colour points by gene expression.

sc.pl.pca(adata, color="CST3")

See Also

scanpy.pp.pca
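The percentages added by annotate_var_explained come from adata.uns["pca"]["variance_ratio"], which scanpy's pp.pca stores. Roughly, each ratio is the variance captured by one principal component divided by the total variance of the data. The NumPy sketch below computes that quantity, assuming plain PCA on mean-centred data via SVD. scanpy may use other solvers, sparse handling, or scaling, so treat this as illustrative only:

```python
import numpy as np


def explained_variance_ratio(X, n_comps=None):
    """Fraction of total variance captured by each principal component."""
    X = np.asarray(X, dtype=float)
    centred = X - X.mean(axis=0)
    # Singular values of the centred matrix: s**2 is proportional to the
    # variance along each principal axis, and their sum equals the total
    # sum of squares, so the full set of ratios sums to 1.
    s = np.linalg.svd(centred, compute_uv=False)
    ratio = s**2 / np.sum(centred**2)
    return ratio if n_comps is None else ratio[:n_comps]


X = np.array([[2.0, 0.0], [-2.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
print(explained_variance_ratio(X))  # approximately [0.8, 0.2]
```

Keeping the full total in the denominator means that truncating to n_comps components yields ratios that sum to at most 1.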

Source code in .venv/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py
@_wraps_plot_scatter
@_doc_params(
    adata_color_etc=doc_adata_color_etc,
    scatter_bulk=doc_scatter_embedding,
    show_save_ax=doc_show_save_ax,
)
def pca(
    adata: AnnData,
    *,
    annotate_var_explained: bool = False,
    show: bool | None = None,
    return_fig: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> Figure | Axes | list[Axes] | None:
    """Scatter plot in PCA coordinates.

    Use the parameter `annotate_var_explained` to annotate the explained variance.

    Parameters
    ----------
    {adata_color_etc}
    annotate_var_explained
    {scatter_bulk}
    {show_save_ax}

    Returns
    -------
    If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Examples
    --------

    .. plot::
        :context: close-figs

        import scanpy as sc
        adata = sc.datasets.pbmc3k_processed()
        sc.pl.pca(adata)

    Colour points by discrete variable (Louvain clusters).

    .. plot::
        :context: close-figs

        sc.pl.pca(adata, color="louvain")

    Colour points by gene expression.

    .. plot::
        :context: close-figs

        sc.pl.pca(adata, color="CST3")

    .. currentmodule:: scanpy

    See Also
    --------
    pp.pca

    """
    if not annotate_var_explained:
        return embedding(
            adata, "pca", show=show, return_fig=return_fig, save=save, **kwargs
        )
    if "pca" not in adata.obsm and "X_pca" not in adata.obsm:
        msg = (
            f"Could not find entry in `obsm` for 'pca'.\n"
            f"Available keys are: {list(adata.obsm.keys())}."
        )
        raise KeyError(msg)

    label_dict = {
        f"PC{i + 1}": f"PC{i + 1} ({round(v * 100, 2)}%)"
        for i, v in enumerate(adata.uns["pca"]["variance_ratio"])
    }

    if return_fig is True:
        # edit axis labels in returned figure
        fig = embedding(adata, "pca", return_fig=return_fig, **kwargs)
        for ax in fig.axes:
            if xlabel := label_dict.get(ax.xaxis.get_label().get_text()):
                ax.set_xlabel(xlabel)
            if ylabel := label_dict.get(ax.yaxis.get_label().get_text()):
                ax.set_ylabel(ylabel)
        return fig

    # get the axs, edit the labels and apply show and save from user
    axs = embedding(adata, "pca", show=False, save=False, **kwargs)
    if isinstance(axs, list):
        for ax in axs:
            ax.set_xlabel(label_dict[ax.xaxis.get_label().get_text()])
            ax.set_ylabel(label_dict[ax.yaxis.get_label().get_text()])
    else:
        axs.set_xlabel(label_dict[axs.xaxis.get_label().get_text()])
        axs.set_ylabel(label_dict[axs.yaxis.get_label().get_text()])
    _utils.savefig_or_show("pca", show=show, save=save)
    show = settings.autoshow if show is None else show
    if show:
        return None
    return axs
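The KeyError check above accepts either "pca" or "X_pca" as the obsm key. embedding, which pca calls, appears to resolve its basis argument the same way through scanpy's internal basis lookup. The sketch below shows that lookup rule, with a plain dict standing in for adata.obsm. get_basis is a hypothetical helper, not scanpy's implementation:

```python
import numpy as np


def get_basis(obsm, basis):
    """Return obsm[basis], falling back to obsm['X_' + basis]."""
    if basis in obsm:
        return obsm[basis]
    if f"X_{basis}" in obsm:
        return obsm[f"X_{basis}"]
    msg = f"Could not find {basis!r} or 'X_{basis}' in obsm; available keys: {list(obsm)}"
    raise KeyError(msg)


obsm = {"X_pca": np.zeros((3, 2)), "umap": np.ones((3, 2))}
print(get_basis(obsm, "pca").shape)   # (3, 2), found via the "X_" prefix
print(get_basis(obsm, "umap").sum())  # 6.0, found under the bare key
```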

valency_anndata.viz.umap

umap(
    adata: AnnData, **kwargs
) -> Figure | Axes | list[Axes] | None

Scatter plot in UMAP basis.

Returns:

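Type Description
Figure | Axes | list[Axes] | None

A matplotlib Figure if return_fig=True. Otherwise, None when the figure is shown (show=True, or show=None with settings.autoshow enabled), or the Axes (a list of Axes for multi-panel plots) when show=False.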

Examples:

import scanpy as sc
adata = sc.datasets.pbmc68k_reduced()
sc.pl.umap(adata)

Colour points by a discrete variable (Louvain clusters).

sc.pl.umap(adata, color="louvain")

Colour points by gene expression.

sc.pl.umap(adata, color="HES4")

Plot multiple UMAPs, one for each gene's expression.

sc.pl.umap(adata, color=["HES4", "TNFRSF4"])

See Also

scanpy.tl.umap
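Passing a list to color, as in the last example, draws one panel per entry. embedding arranges these panels in rows of at most ncols (default 4). The sketch below reproduces that layout with plain matplotlib. grid_shape and panel_axes are hypothetical helpers; scanpy's internal panel grid builds a GridSpec with its own hspace/wspace spacing rather than calling plt.subplots:

```python
import math

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def grid_shape(n_panels, ncols=4):
    """Rows and columns for n_panels laid out at most ncols wide."""
    n_cols = min(ncols, n_panels)
    n_rows = math.ceil(n_panels / n_cols)
    return n_rows, n_cols


def panel_axes(colors, ncols=4):
    """Create one titled Axes per colour key and hide any unused grid cells."""
    n_rows, n_cols = grid_shape(len(colors), ncols)
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    flat = axes.ravel()
    for ax, name in zip(flat, colors):
        ax.set_title(name)
    for ax in flat[len(colors):]:
        ax.axis("off")  # trailing cells in the last row stay empty
    return fig, list(flat[: len(colors)])


print(grid_shape(2), grid_shape(5))  # (1, 2) (2, 4)
fig, axs = panel_axes(["HES4", "TNFRSF4"])
```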

Source code in .venv/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py
@_wraps_plot_scatter
@_doc_params(
    adata_color_etc=doc_adata_color_etc,
    edges_arrows=doc_edges_arrows,
    scatter_bulk=doc_scatter_embedding,
    show_save_ax=doc_show_save_ax,
)
def umap(adata: AnnData, **kwargs) -> Figure | Axes | list[Axes] | None:
    """Scatter plot in UMAP basis.

    Parameters
    ----------
    {adata_color_etc}
    {edges_arrows}
    {scatter_bulk}
    {show_save_ax}

    Returns
    -------
    If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Examples
    --------

    .. plot::
        :context: close-figs

        import scanpy as sc
        adata = sc.datasets.pbmc68k_reduced()
        sc.pl.umap(adata)

    Colour points by discrete variable (Louvain clusters).

    .. plot::
        :context: close-figs

        sc.pl.umap(adata, color="louvain")

    Colour points by gene expression.

    .. plot::
        :context: close-figs

        sc.pl.umap(adata, color="HES4")

    Plot muliple umaps for different gene expressions.

    .. plot::
        :context: close-figs

        sc.pl.umap(adata, color=["HES4", "TNFRSF4"])

    .. currentmodule:: scanpy

    See Also
    --------
    tl.umap

    """
    return embedding(adata, "umap", **kwargs)

valency_anndata.viz.embedding

embedding(
    adata: AnnData,
    basis: str,
    *,
    color: str | Sequence[str] | None = None,
    mask_obs: NDArray[bool_] | str | None = None,
    gene_symbols: str | None = None,
    use_raw: bool | None = None,
    sort_order: bool = True,
    edges: bool = False,
    edges_width: float = 0.1,
    edges_color: str
    | Sequence[float]
    | Sequence[str] = "grey",
    neighbors_key: str | None = None,
    arrows: bool = False,
    arrows_kwds: Mapping[str, Any] | None = None,
    groups: str | Sequence[str] | None = None,
    components: str | Sequence[str] | None = None,
    dimensions: tuple[int, int]
    | Sequence[tuple[int, int]]
    | None = None,
    layer: str | None = None,
    projection: Literal["2d", "3d"] = "2d",
    scale_factor: float | None = None,
    color_map: Colormap | str | None = None,
    cmap: Colormap | str | None = None,
    palette: str | Sequence[str] | Cycler | None = None,
    na_color: ColorLike = "lightgray",
    na_in_legend: bool = True,
    size: float | Sequence[float] | None = None,
    frameon: bool | None = None,
    legend_fontsize: float | _FontSize | None = None,
    legend_fontweight: int | _FontWeight = "bold",
    legend_loc: _LegendLoc | None = "right margin",
    legend_fontoutline: int | None = None,
    colorbar_loc: str | None = "right",
    vmax: VBound | Sequence[VBound] | None = None,
    vmin: VBound | Sequence[VBound] | None = None,
    vcenter: VBound | Sequence[VBound] | None = None,
    norm: Normalize | Sequence[Normalize] | None = None,
    add_outline: bool | None = False,
    outline_width: tuple[float, float] = (0.3, 0.05),
    outline_color: tuple[str, str] = ("black", "white"),
    ncols: int = 4,
    hspace: float = 0.25,
    wspace: float | None = None,
    title: str | Sequence[str] | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    ax: Axes | None = None,
    return_fig: bool | None = None,
    marker: str | Sequence[str] = ".",
    **kwargs,
) -> Figure | Axes | list[Axes] | None

Scatter plot for a user-specified embedding basis (e.g. UMAP or PCA).

Parameters:

Name Type Description Default
basis str

Name of the obsm basis to use.

required

Returns:

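Type Description
Figure | Axes | list[Axes] | None

A matplotlib Figure if return_fig=True. Otherwise, None when the figure is shown (show=True, or show=None with settings.autoshow enabled), or the Axes (a list of Axes for multi-panel plots) when show=False.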
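At its core, embedding selects two columns of the basis array and scatters them. It then hides the ticks and labels the axes with a basis-derived prefix plus the 1-based component number, for example PC1 and PC2. The sketch below strips that down to the core. It assumes a continuous colour vector and omits the legend, panel, and outline handling. embedding_core is a hypothetical helper; the default point size of 120000 / n_obs follows the source below:

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def embedding_core(basis_values, dims=(0, 1), prefix="PC", color=None, ax=None):
    """Scatter two zero-based components of an (n_obs, n_comps) array."""
    coords = np.asarray(basis_values, dtype=float)[:, list(dims)]
    if ax is None:
        _, ax = plt.subplots()
    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        c=color,
        s=120000 / coords.shape[0],  # default size shrinks as n_obs grows
        edgecolor="none",
    )
    ax.set_xticks([])
    ax.set_yticks([])
    # Axis labels are 1-based: zero-based dim 1 is labelled "PC2".
    ax.set_xlabel(f"{prefix}{dims[0] + 1}")
    ax.set_ylabel(f"{prefix}{dims[1] + 1}")
    return ax


rng = np.random.default_rng(0)
values = rng.normal(size=(50, 3))
ax = embedding_core(values, dims=(1, 2), color=values[:, 0])
print(ax.get_xlabel(), ax.get_ylabel())  # PC2 PC3
```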
Source code in .venv/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py
@_doc_params(
    adata_color_etc=doc_adata_color_etc,
    edges_arrows=doc_edges_arrows,
    scatter_bulk=doc_scatter_embedding,
    show_save_ax=doc_show_save_ax,
)
def embedding(  # noqa: PLR0912, PLR0913, PLR0915
    adata: AnnData,
    basis: str,
    *,
    color: str | Sequence[str] | None = None,
    mask_obs: NDArray[np.bool_] | str | None = None,
    gene_symbols: str | None = None,
    use_raw: bool | None = None,
    sort_order: bool = True,
    edges: bool = False,
    edges_width: float = 0.1,
    edges_color: str | Sequence[float] | Sequence[str] = "grey",
    neighbors_key: str | None = None,
    arrows: bool = False,
    arrows_kwds: Mapping[str, Any] | None = None,
    groups: str | Sequence[str] | None = None,
    components: str | Sequence[str] | None = None,
    dimensions: tuple[int, int] | Sequence[tuple[int, int]] | None = None,
    layer: str | None = None,
    projection: Literal["2d", "3d"] = "2d",
    scale_factor: float | None = None,
    color_map: Colormap | str | None = None,
    cmap: Colormap | str | None = None,
    palette: str | Sequence[str] | Cycler | None = None,
    na_color: ColorLike = "lightgray",
    na_in_legend: bool = True,
    size: float | Sequence[float] | None = None,
    frameon: bool | None = None,
    legend_fontsize: float | _FontSize | None = None,
    legend_fontweight: int | _FontWeight = "bold",
    legend_loc: _LegendLoc | None = "right margin",
    legend_fontoutline: int | None = None,
    colorbar_loc: str | None = "right",
    vmax: VBound | Sequence[VBound] | None = None,
    vmin: VBound | Sequence[VBound] | None = None,
    vcenter: VBound | Sequence[VBound] | None = None,
    norm: Normalize | Sequence[Normalize] | None = None,
    add_outline: bool | None = False,
    outline_width: tuple[float, float] = (0.3, 0.05),
    outline_color: tuple[str, str] = ("black", "white"),
    ncols: int = 4,
    hspace: float = 0.25,
    wspace: float | None = None,
    title: str | Sequence[str] | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    ax: Axes | None = None,
    return_fig: bool | None = None,
    marker: str | Sequence[str] = ".",
    **kwargs,
) -> Figure | Axes | list[Axes] | None:
    """Scatter plot for user specified embedding basis (e.g. umap, pca, etc).

    Parameters
    ----------
    basis
        Name of the `obsm` basis to use.
    {adata_color_etc}
    {edges_arrows}
    {scatter_bulk}
    {show_save_ax}

    Returns
    -------
    If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    """
    #####################
    # Argument handling #
    #####################

    check_projection(projection)
    sanitize_anndata(adata)

    basis_values = _get_basis(adata, basis)
    dimensions = _components_to_dimensions(
        components, dimensions, projection=projection, total_dims=basis_values.shape[1]
    )
    args_3d = dict(projection="3d") if projection == "3d" else {}

    # Checking the mask format and if used together with groups
    if groups is not None and mask_obs is not None:
        msg = "Groups and mask arguments are incompatible."
        raise ValueError(msg)
    mask_obs = _check_mask(adata, mask_obs, "obs")

    # Figure out if we're using raw
    if use_raw is None:
        # check if adata.raw is set
        use_raw = layer is None and adata.raw is not None
    if use_raw and layer is not None:
        msg = (
            "Cannot use both a layer and the raw representation. "
            f"Was passed: {use_raw=!r}, {layer=!r}."
        )
        raise ValueError(msg)
    if use_raw and adata.raw is None:
        msg = (
            "`use_raw` is set to True but AnnData object does not have raw. "
            "Please check."
        )
        raise ValueError(msg)

    if isinstance(groups, str):
        groups = [groups]

    # Color map
    if color_map is not None:
        if cmap is not None:
            msg = "Cannot specify both `color_map` and `cmap`."
            raise ValueError(msg)
        else:
            cmap = color_map
    cmap = copy(colormaps.get_cmap(cmap))
    cmap.set_bad(na_color)
    # Prevents warnings during legend creation
    na_color = colors.to_hex(na_color, keep_alpha=True)

    # by default turn off edge color. Otherwise, for
    # very small sizes the edge will not reduce its size
    # (https://github.com/scverse/scanpy/issues/293)
    kwargs.setdefault("edgecolor", "none")

    # Vectorized arguments

    # turn color into a python list
    color = [color] if isinstance(color, str) or color is None else list(color)

    # turn marker into a python list
    marker = [marker] if isinstance(marker, str) else list(marker)

    if title is not None:
        # turn title into a python list if not None
        title = [title] if isinstance(title, str) else list(title)

    # turn vmax and vmin into a sequence
    if isinstance(vmax, str) or not isinstance(vmax, Sequence):
        vmax = [vmax]
    if isinstance(vmin, str) or not isinstance(vmin, Sequence):
        vmin = [vmin]
    if isinstance(vcenter, str) or not isinstance(vcenter, Sequence):
        vcenter = [vcenter]
    if isinstance(norm, Normalize) or not isinstance(norm, Sequence):
        norm = [norm]

    # Size
    if "s" in kwargs and size is None:
        size = kwargs.pop("s")
    if size is not None:
        # check if size is any type of sequence, and if so
        # set as ndarray
        if (
            size is not None
            and isinstance(size, Sequence | pd.Series | np.ndarray)
            and len(size) == adata.shape[0]
        ):
            size = np.array(size, dtype=float)
    else:
        size = 120000 / adata.shape[0]

    ##########
    # Layout #
    ##########
    # Most of the code is for the case when multiple plots are required

    if wspace is None:
        #  try to set a wspace that is not too large or too small given the
        #  current figure size
        wspace = 0.75 / rcParams["figure.figsize"][0] + 0.02

    if components is not None:
        color, dimensions = list(zip(*product(color, dimensions), strict=True))

    color, dimensions, marker = _broadcast_args(color, dimensions, marker)

    # 'color' is a list of names that want to be plotted.
    # Eg. ['Gene1', 'louvain', 'Gene2'].
    # component_list is a list of components [[0,1], [1,2]]
    if (
        not isinstance(color, str) and isinstance(color, Sequence) and len(color) > 1
    ) or len(dimensions) > 1:
        if ax is not None:
            msg = (
                "Cannot specify `ax` when plotting multiple panels "
                "(each for a given value of 'color')."
            )
            raise ValueError(msg)

        # each plot needs to be its own panel
        fig, grid = _panel_grid(hspace, wspace, ncols, len(color))
    else:
        grid = None
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111, **args_3d)

    ############
    # Plotting #
    ############
    axs = []

    # use itertools.product to make a plot for each color and for each component
    # For example if color=[gene1, gene2] and components=['1,2, '2,3'].
    # The plots are: [
    #     color=gene1, components=[1,2], color=gene1, components=[2,3],
    #     color=gene2, components = [1, 2], color=gene2, components=[2,3],
    # ]
    for count, (value_to_plot, dims) in enumerate(zip(color, dimensions, strict=True)):
        kwargs_scatter = kwargs.copy()  # is potentially mutated for each plot
        color_source_vector = _get_color_source_vector(
            adata,
            value_to_plot,
            layer=layer,
            mask_obs=mask_obs,
            use_raw=use_raw,
            gene_symbols=gene_symbols,
            groups=groups,
        )
        color_vector, color_type = _color_vector(
            adata,
            value_to_plot,
            values=color_source_vector,
            palette=palette,
            na_color=na_color,
        )

        # Order points
        order = slice(None)
        if sort_order and value_to_plot is not None and color_type == "cont":
            # Higher values plotted on top, null values on bottom
            order = np.argsort(-color_vector, kind="stable")[::-1]
        elif sort_order and color_type == "cat":
            # Null points go on bottom
            order = np.argsort(~pd.isnull(color_source_vector), kind="stable")
        # Set orders
        if isinstance(size, np.ndarray):
            size = np.array(size)[order]
        color_source_vector = color_source_vector[order]
        color_vector = color_vector[order]
        coords = basis_values[:, dims][order, :]

        # if plotting multiple panels, get the ax from the grid spec
        # else use the ax value (either user given or created previously)
        if grid:
            ax = plt.subplot(grid[count], **args_3d)
            axs.append(ax)
        if not (settings._frameon if frameon is None else frameon):
            ax.axis("off")
        if title is None:
            if value_to_plot is not None:
                ax.set_title(value_to_plot)
            else:
                ax.set_title("")
        else:
            try:
                ax.set_title(title[count])
            except IndexError:
                logg.warning(
                    "The title list is shorter than the number of panels. "
                    "Using 'color' value instead for some plots."
                )
                ax.set_title(value_to_plot)

        if color_type == "cont":
            vmin_float, vmax_float, vcenter_float, norm_obj = _get_vboundnorm(
                vmin, vmax, vcenter, norm=norm, index=count, colors=color_vector
            )
            kwargs_scatter["norm"] = check_colornorm(
                vmin_float,
                vmax_float,
                vcenter_float,
                norm_obj,
            )
            kwargs_scatter["cmap"] = cmap

        # make the scatter plot
        if projection == "3d":
            cax = ax.scatter(
                coords[:, 0],
                coords[:, 1],
                coords[:, 2],
                c=color_vector,
                rasterized=settings._vector_friendly,
                marker=marker[count],
                **kwargs_scatter,
            )
        else:
            scatter = (
                partial(ax.scatter, s=size, plotnonfinite=True)
                if scale_factor is None
                else partial(
                    circles, s=size, ax=ax, scale_factor=scale_factor
                )  # size in circles is radius
            )

            if add_outline:
                # the default outline is a black edge followed by a
                # thin white edged added around connected clusters.
                # To add an outline
                # three overlapping scatter plots are drawn:
                # First black dots with slightly larger size,
                # then, white dots a bit smaller, but still larger
                # than the final dots. Then the final dots are drawn
                # with some transparency.

                bg_width, gap_width = outline_width
                point = np.sqrt(size)
                gap_size = (point + (point * gap_width) * 2) ** 2
                bg_size = (np.sqrt(gap_size) + (point * bg_width) * 2) ** 2
                # the default black and white colors can be changes using
                # the contour_config parameter
                bg_color, gap_color = outline_color

                # remove edge from kwargs if present
                # because edge needs to be set to None
                kwargs_scatter["edgecolor"] = "none"
                # For points, if user did not set alpha, set alpha to 0.7
                kwargs_scatter.setdefault("alpha", 0.7)

                # remove alpha and color mapping for outline
                kwargs_outline = {
                    k: v
                    for k, v in kwargs.items()
                    if k not in {"alpha", "cmap", "norm"}
                }

                for s, c in [(bg_size, bg_color), (gap_size, gap_color)]:
                    ax.scatter(
                        coords[:, 0],
                        coords[:, 1],
                        s=s,
                        c=c,
                        rasterized=settings._vector_friendly,
                        marker=marker[count],
                        **kwargs_outline,
                    )

            cax = scatter(
                coords[:, 0],
                coords[:, 1],
                c=color_vector,
                rasterized=settings._vector_friendly,
                marker=marker[count],
                **kwargs_scatter,
            )

        # remove y and x ticks
        ax.set_yticks([])
        ax.set_xticks([])
        if projection == "3d":
            ax.set_zticks([])

        # set default axis_labels
        name = _basis2name(basis)
        axis_labels = [name + str(d + 1) for d in dims]

        ax.set_xlabel(axis_labels[0])
        ax.set_ylabel(axis_labels[1])
        if projection == "3d":
            # shift the label closer to the axis
            ax.set_zlabel(axis_labels[2], labelpad=-7)
        ax.autoscale_view()

        if edges:
            _utils.plot_edges(
                ax, adata, basis, edges_width, edges_color, neighbors_key=neighbors_key
            )
        if arrows:
            _utils.plot_arrows(ax, adata, basis, arrows_kwds)

        if value_to_plot is None:
            # if only dots were plotted without an associated value
            # there is not need to plot a legend or a colorbar
            continue

        if legend_fontoutline is not None:
            path_effect = [
                patheffects.withStroke(linewidth=legend_fontoutline, foreground="w")
            ]
        else:
            path_effect = None

        # Adding legends
        if color_type == "cat":
            _add_categorical_legend(
                ax,
                color_source_vector,
                palette=_get_palette(adata, value_to_plot),
                scatter_array=coords,
                legend_loc=legend_loc,
                legend_fontweight=legend_fontweight,
                legend_fontsize=legend_fontsize,
                legend_fontoutline=path_effect,
                na_color=na_color,
                na_in_legend=na_in_legend,
                multi_panel=bool(grid),
            )
        elif colorbar_loc is not None:
            plt.colorbar(
                cax, ax=ax, pad=0.01, fraction=0.08, aspect=30, location=colorbar_loc
            )

    if return_fig is True:
        return fig
    axs = axs if grid else ax
    _utils.savefig_or_show(basis, show=show, save=save)
    show = settings.autoshow if show is None else show
    if show:
        return None
    return axs
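With sort_order=True, the loop above reorders points before drawing. For continuous colours, the highest values land on top and missing values underneath. For categorical colours, missing values are drawn first. The standalone sketch below reproduces those two orderings with the same argsort expressions as the source. The function names are hypothetical:

```python
import numpy as np
import pandas as pd


def continuous_draw_order(values):
    """Indices that draw NaNs first and the highest values last (on top)."""
    values = np.asarray(values, dtype=float)
    # argsort of the negated values puts the largest first and NaNs last;
    # reversing flips that so NaNs are drawn first and the largest last.
    return np.argsort(-values, kind="stable")[::-1]


def categorical_draw_order(categories):
    """Indices that draw missing categories first, keeping the rest in order."""
    # False (missing) sorts before True (present); "stable" preserves the
    # original order within each group.
    return np.argsort(~pd.isnull(categories), kind="stable")


print(continuous_draw_order([2.0, np.nan, 5.0, 1.0]))      # [1 3 0 2]
print(categorical_draw_order(["a", None, "b", np.nan]))   # [1 3 0 2]
```

Later points are painted over earlier ones, so sorting by value is enough to keep high values visible without changing the plotted data.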