---
title: Working with SpatialExperiment
package: "`r BiocStyle::pkg_ver('poem')`"
author: 
- name: Siyuan Luo
  email: roseluosy@gmail.com
  affiliation:
    - &IMLS Institute for Molecular Life Sciences, University of Zurich, Zurich, Switzerland
    - &HEST Department of Health Sciences and Technology, ETH Zurich, Zurich, Switzerland
- name: Pierre-Luc Germain
  email: 
  affiliation:
    - *IMLS
    - *HEST
date: "`r Sys.Date()`"
output:
  BiocStyle::html_document
vignette: |
  %\VignetteIndexEntry{3_SpatialExperiment}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---
<style type="text/css">
.smaller {
  font-size: 10px
}
</style>

```{r, echo=FALSE, message=FALSE, warning = FALSE}
knitr::opts_chunk$set(error=FALSE, message=FALSE, 
                      warning=FALSE, collapse = TRUE)
library(BiocStyle)
```

```{r, message=FALSE, warning = FALSE}
library(poem)
library(ggplot2)
library(cowplot)
library(SpatialExperiment)
library(STexampleData)
library(dplyr)
library(tidyr)
```

# Prepare dataset

```{r}
my_cols <- c("#D55E00", "#CC79A7", "#E69F00", "#0072B2", "#009E73", "#F0E442",
             "#56B4E9", "#000000")
names(my_cols) <- as.character(seq_along(my_cols))
```

Our package `poem` can be easily integrated into a workflow with 
`r Biocpkg("SpatialExperiment")` objects. Here we use the `Visium_humanDLPFC` 
dataset from package `r Biocpkg("STexampleData")` for illustration. Load it:

```{r}
spe <- Visium_humanDLPFC()
spe <- spe[, !is.na(colData(spe)$reference)]
spe
```


From this `SpatialExperiment` object, we take the location information 
(accessible via `spatialCoords`) and the manual annotation in `colData`, and 
store them as a data frame:

```{r}
data <- data.frame(spatialCoords(spe))
data$reference <- colData(spe)$reference
data <- na.omit(data)
data$reference <- factor(data$reference, levels=c("WM", "Layer6", "Layer5", 
                                                  "Layer4", "Layer3", "Layer2", 
                                                  "Layer1"))
```

The manual annotation looks like this:

```{r, fig.height = 4, fig.width = 4, fig.small = TRUE}
p1 <- ggplot(data) +
  geom_point(aes(x = pxl_col_in_fullres, y = -pxl_row_in_fullres, 
                 color = reference), size=0.3) +
  labs(x = "", y = "", color="", title="Manual annotation") +
  theme_minimal() +
  scale_color_manual(values = unname(my_cols)) +
  theme(
  legend.box.background = element_rect(fill = "grey90", color = "black", 
                                       size = 0.1),
  legend.box.margin = margin(-1, -1, -1, -1),
  axis.title.x=element_blank(),
  legend.position = "bottom",
  legend.box.spacing = margin(0),
  axis.text.x=element_blank(),
  axis.ticks.x=element_blank(),
  axis.text.y=element_blank(),
  axis.ticks.y=element_blank(),
  panel.spacing.x = unit(-0.5, "cm"),
  panel.grid.major = element_blank(), 
  panel.grid.minor = element_blank(),
  plot.title = element_text(hjust = 0.5, size=12, 
                            margin = margin(b = 5, t = 15))) +
  guides(color = guide_legend(keywidth = 1, keyheight = 0.8, 
                              override.aes = list(size = 3)))

p1
```

We then generate some hypothetical domain detection predictions by perturbing 
the manual annotation: adding random label noise, and merging or splitting domains.

```{r}
set.seed(123) # For reproducibility

# Given a factor vector representing clustering results, simulate clustering
# variations by merging two clusters and/or adding random label noise.
simulate_clustering_variation <- function(clusters, merge_clusters = NULL,
                                          noise_level = 0.1) {
  # Convert cluster labels to numeric indices for easier manipulation
  merge_clusters <- which(levels(clusters) %in% merge_clusters)
  clusters <- as.numeric(clusters)

  # 1. Merge the specified clusters by renaming them to the same label
  if (length(merge_clusters) > 0) {
    clusters[clusters %in% merge_clusters] <- merge_clusters[1]
  }
  
  # 2. Add random label noise
  n <- length(clusters)
  n_noise <- round(n * noise_level) # Number of elements to relabel
  if (n_noise > 0) {
    noise_indices <- sample(seq_len(n), n_noise) # Random positions to relabel
    existing_levels <- unique(clusters)
    # Replace the selected elements with randomly drawn existing labels
    clusters[noise_indices] <- sample(existing_levels, n_noise, replace = TRUE)
  }
  
  # Convert back to factor and return
  factor(clusters)
}
```

Below we simulate some prediction results with random noise as well as merging 
or splitting of domains:

```{r}
# P1: add random noise
data$P1 <- simulate_clustering_variation(
  data$reference,
  noise_level = 0.1
)

# P2: split Layer 3 into 2 domains, add random noise
data$P2 <- as.numeric(data$reference)
data$P2[data$reference=="Layer3" & data$pxl_col_in_fullres < 8000] <- 8
data$P2 <- factor(as.numeric(factor(data$P2)))

data$P2 <- simulate_clustering_variation(
  data$P2,
  noise_level = 0.1
)

# P3: merge Layer 4 and Layer 5, add random noise
data$P3 <- simulate_clustering_variation(
  data$reference,
  merge_clusters = c("Layer4", "Layer5"),
  noise_level = 0.1
)

```
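
As a quick sanity check, we can count how many factor levels (i.e. distinct 
domains) each simulated prediction ends up with. This is an illustrative 
one-liner using the columns created above and is not required for the rest of 
the workflow:

```{r}
# Number of factor levels (distinct domains) in each simulated prediction
sapply(data[, c("P1", "P2", "P3")], nlevels)
```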

If we visualize them:

```{r, fig.height = 7, fig.width = 7}
p2 <- data %>% pivot_longer(cols=-c("pxl_col_in_fullres","pxl_row_in_fullres"), 
                      names_to="prediction", values_to="domain") %>% 
  dplyr::filter(prediction != "reference") %>%
  ggplot() +
  geom_point(aes(x = pxl_col_in_fullres, y = -pxl_row_in_fullres, 
                 color = domain), size=0.4) +
  facet_wrap(~prediction, nrow=2) +
  labs(x = "", y = "", color="", title="") +
  theme_minimal() +
  scale_color_manual(values = unname(my_cols)) +
  theme(
  legend.box.background = element_rect(fill = "grey90", 
                                       color = "black", size = 0.1),
  legend.box.margin = margin(-1, -1, -1, -1),
  axis.title.x=element_blank(),
  legend.position = "bottom",
  legend.justification=c(0, 0),
  legend.box.spacing = margin(0),
  axis.text.x=element_blank(),
  axis.ticks.x=element_blank(),
  axis.text.y=element_blank(),
  axis.ticks.y=element_blank(),
  panel.spacing.x = unit(-0.5, "cm"),
  panel.grid.major = element_blank(), 
  panel.grid.minor = element_blank(),
  plot.title = element_text(hjust = 0.5, size=10))  +
  guides(color = guide_legend(keywidth = 1, keyheight = 0.8, 
                              override.aes = list(size = 3)))

ggdraw() +
  draw_plot(p2 + theme(plot.margin = margin(0, 2, 2, 2))) +  # Main plot
  draw_plot(p1, x = 0.5, y = -0.01, width = 0.5, height = 0.56)  # Inset plot
```

# Calculate external spatial metrics

We can compare P1-P3 to the manual annotation using external spatial metrics. 
The `getSpatialExternalMetrics()` function can accept either a 
`SpatialExperiment` object directly as input, or separate inputs for `true`, 
`pred` and `location`. For example:

```{r}
colData(spe) <- cbind(colData(spe), data[, c("P1","P2","P3")])
getSpatialExternalMetrics(object=spe, true="reference", pred="P3", k=6)
```

is equivalent to:

```{r}
getSpatialExternalMetrics(true=colData(spe)$reference, pred=colData(spe)$P3,
                          location=spatialCoords(spe), k=6)
```

## Dataset level

Let's first calculate two dataset-level metrics, nsARI and SpatialARI:

```{r, fig.height = 2.5, fig.width = 5}
res3 <- getSpatialExternalMetrics(object=spe, true="reference", pred="P3", 
                                  level="dataset", k=6,
                                  metrics=c("nsARI","SpatialARI"))

res2 <- getSpatialExternalMetrics(object=spe, true="reference", pred="P2",
                                  level="dataset", k=6,
                                  metrics=c("nsARI","SpatialARI"))

res1 <- getSpatialExternalMetrics(object=spe, true="reference", pred="P1",
                                  level="dataset", k=6,
                                  metrics=c("nsARI","SpatialARI"))

cbind(bind_rows(list(res1, res2, res3), .id="P")) %>% 
  pivot_longer(cols=c("nsARI", "SpatialARI"), 
               names_to="metric", values_to="value") %>%
  ggplot(aes(x=P, y=value, group=metric)) +
  geom_point(size=3, aes(color=P)) +
  facet_wrap(~metric, scales = "free") +
  theme_bw() + labs(x="Prediction")
```

## Class/cluster level

We can further calculate the class/cluster-level metrics, nsAWH and 
nsAWC, to get more insight into the errors our predictions make:

```{r}
res3 <- getSpatialExternalMetrics(object=spe, true="reference", pred="P3",
                                  level="class", k=6,
                                  metrics=c("nsAWH","nsAWC"))

res2 <- getSpatialExternalMetrics(object=spe, true="reference", pred="P2",
                                  level="class", k=6,
                                  metrics=c("nsAWH","nsAWC"))

res1 <- getSpatialExternalMetrics(object=spe, true="reference", pred="P1",
                                  level="class", k=6,
                                  metrics=c("nsAWH","nsAWC"))
res1
```

Note that the indices in the "class" and "cluster" columns correspond to the 
levels of the original factors passed to `true` and `pred`. Below, we map them 
back to the original factor labels and then plot them as heatmaps.
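
For instance, the index-to-label mapping for P3 can be made explicit with a 
small sketch like the following (purely illustrative, using the objects defined 
above):

```{r}
# Illustrative: cluster indices versus the corresponding factor levels of P3
data.frame(cluster = seq_along(levels(data$P3)),
           label = levels(data$P3))
```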

```{r}
awh1 <- na.omit(res1[,c("nsAWH", "cluster")]) %>% 
  mutate(cluster = levels(data$P1)[cluster])
awh2 <- na.omit(res2[,c("nsAWH", "cluster")]) %>% 
  mutate(cluster = levels(data$P2)[cluster])
awh3 <- na.omit(res3[,c("nsAWH", "cluster")]) %>% 
  mutate(cluster = levels(data$P3)[cluster])

awh <- cbind(bind_rows(list(awh1, awh2, awh3), .id="P")) %>% 
  pivot_wider(names_from = cluster, values_from = nsAWH) %>%
  subset(select = -c(P))
awh <- as.matrix(awh)
rownames(awh) <- c("P1", "P2", "P3")
awh <- awh[,c("1", "2", "3", "4", "5", "6", "7", "8")]

awh <- data.frame(awh)
colnames(awh) <- seq_len(8)
awh$prediction <- rownames(awh)


p4 <- awh %>% pivot_longer(cols=-c("prediction"), names_to="cluster", 
                           values_to = "AWH") %>%
  mutate(prediction = factor(prediction), cluster=factor(cluster)) %>%
  ggplot(aes(cluster, prediction, fill=AWH)) +
  geom_tile() +
  scale_fill_distiller(palette = "RdPu") +
  labs(x="Predicted domain", y="")

```

```{r}
awc1 <- na.omit(res1[,c("nsAWC", "class")]) %>% 
  mutate(class = levels(data$reference)[class])
awc2 <- na.omit(res2[,c("nsAWC", "class")]) %>% 
  mutate(class = levels(data$reference)[class])
awc3 <- na.omit(res3[,c("nsAWC", "class")]) %>% 
  mutate(class = levels(data$reference)[class])

awc <- cbind(bind_rows(list(awc1, awc2, awc3), .id="P")) %>% 
  pivot_wider(names_from = class, values_from = nsAWC) %>%
  subset(select = -c(P))
awc <- as.matrix(awc)
rownames(awc) <- c("P1", "P2", "P3")


awc <- data.frame(awc)
awc$prediction <- rownames(awc)


p5 <- awc %>% pivot_longer(cols=-c("prediction"), names_to="class", 
                           values_to = "AWC") %>%
  mutate(prediction = factor(prediction), class=factor(class)) %>%
  ggplot(aes(class, prediction, fill=AWC)) +
  geom_tile() +
  scale_fill_distiller(palette = "RdPu") +
  labs(x="Annotated domain", y="")

```

```{r, fig.height = 2, fig.width = 9.5}
plot_grid(p4, p5, rel_widths=c(1,1), scale=c(1, 1))
```

The class-level AWC highlights that, in P2, Layer3 has low completeness. This 
aligns with our simulation, in which Layer3 was split into two clusters in P2. 
Similarly, the cluster-level AWH highlights that, in P3, cluster 3 has low 
homogeneity, consistent with the merging of Layer4 and Layer5.
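
To double-check this interpretation, one can cross-tabulate the reference 
annotation against a prediction; in P3, spots from Layer4 and Layer5 should 
mostly end up in the same predicted cluster. This is an illustrative check, not 
part of the metric computation:

```{r}
# Illustrative: cross-tabulate the reference layers against the P3 clusters
table(reference = data$reference, P3 = data$P3)
```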


## Element level

One can also calculate element-level metrics, such as the neighborhood-smoothed 
spot-wise pair concordance (nsSPC), for visualization.

```{r, fig.height = 2.5, fig.width = 8}
res1 <- cbind(getSpatialExternalMetrics(object=spe, true="reference", pred="P1",
                                  level="element",
                                  metrics=c("nsSPC"), k=6,
                                  useNegatives = FALSE), 
              data[,c("pxl_col_in_fullres", "pxl_row_in_fullres")])

res2 <- cbind(getSpatialExternalMetrics(object=spe, true="reference", pred="P2",
                                  level="element",
                                  metrics=c("nsSPC"), k=6,
                                  useNegatives = FALSE),
               data[,c("pxl_col_in_fullres", "pxl_row_in_fullres")])

res3 <- cbind(getSpatialExternalMetrics(object=spe, true="reference", pred="P3",
                                  level="element",
                                  metrics=c("nsSPC"), k=6,
                                  useNegatives = FALSE),
               data[,c("pxl_col_in_fullres", "pxl_row_in_fullres")])


cbind(bind_rows(list(res1, res2, res3), .id="P")) %>% 
  pivot_longer(cols=c("nsSPC"), 
               names_to="metric", values_to="value") %>%
  ggplot(aes(x = pxl_col_in_fullres, y = - pxl_row_in_fullres, color = value)) +
  scale_colour_gradient(high="white", low ="deeppink4") +
  geom_point(size=0.3) +
  facet_wrap(~P, scales = "free") +
  theme_bw() + labs(x="Prediction", y="", color="nsSPC")
```

This clearly highlights the low-concordance regions in each prediction, as 
expected.
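
If a single number per prediction is preferred, the element-level scores can 
also be summarised, for instance by averaging nsSPC across spots (a minimal 
sketch based on the objects computed above):

```{r}
# Average the spot-wise nsSPC per prediction (illustrative summary)
bind_rows(list(P1 = res1, P2 = res2, P3 = res3), .id = "P") %>%
  group_by(P) %>%
  summarise(mean_nsSPC = mean(nsSPC, na.rm = TRUE))
```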

# Calculate internal spatial metrics

When a manual annotation is not available, one can use internal metrics such as 
CHAOS, ELSA and PAS to assess the domain continuity and local homogeneity of a 
domain detection result. To illustrate this, we simulate two additional 
predictions, P4 and P5, with 20% and 30% random noise, respectively.

```{r}
# P4: add 20% random noise 
data$P4 <- simulate_clustering_variation(
  data$reference,
  noise_level = 0.2
)

# P5: add 30% random noise
data$P5 <- simulate_clustering_variation(
  data$reference,
  noise_level = 0.3
)
```

We calculate the internal spatial metrics for the manual annotation and P1-P5. 
As with the external metrics, the `getSpatialInternalMetrics()` function can 
accept either a `SpatialExperiment` object directly as input, or separate 
inputs for `labels` and `location`. This means that

```{r}
getSpatialInternalMetrics(object=spe, labels="reference", k=6)
```

is equivalent to:

```{r}
getSpatialInternalMetrics(labels=colData(spe)$reference, 
                          location = spatialCoords(spe), k=6)
```

To calculate the dataset-level metrics:

```{r}
colData(spe) <- cbind(colData(spe), data[, c("P4","P5")])
internal <- lapply(
  setNames(c("reference","P1","P2","P3","P4","P5"),
           c("reference","P1","P2","P3","P4","P5")),
  function(x) {
    getSpatialInternalMetrics(object=spe, labels=x, k=6, level="dataset",
                              metrics=c("PAS", "ELSA", "CHAOS"))
  })
internal <- bind_rows(internal, .id = "prediction")
```

```{r, fig.height = 2.5, fig.width = 5}
internal %>% 
  pivot_longer(cols=-c("prediction"), 
               names_to="metric", values_to="value") %>%
  filter(metric %in% c("ELSA", "PAS", "CHAOS")) %>%
  ggplot(aes(x=prediction, y=value, group=metric)) +
  geom_point(size=3, aes(color=prediction)) +
  facet_wrap(~metric, scales = "free") +
  theme_bw() + labs(x="", color="") +
  theme(legend.position="None",
        axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))
```

The lower the scores, the smoother the predictions. As expected, the smoothness 
decreases from P3 to P5 as the noise level increases.
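
The same ranking can be read off numerically, for instance by sorting the table 
of dataset-level scores (an illustrative one-liner):

```{r}
# Sort the dataset-level internal metrics by CHAOS (illustrative)
internal %>% arrange(CHAOS)
```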

The internal metrics can also be calculated at the element level. For example, 
we can compute the element-wise ELSA score, which quantifies local diversity 
and can be regarded as an edge detector:

```{r, fig.height = 4, fig.width = 8}
internal <- lapply(
  setNames(c("reference","P1","P2","P3","P4","P5"),
           c("reference","P1","P2","P3","P4","P5")),
  function(x) {
    cbind(getSpatialInternalMetrics(object=spe, labels=x, k=6, level="element",
                                    metrics=c("ELSA")),
          data[, c("pxl_col_in_fullres", "pxl_row_in_fullres")])
  })
internal <- bind_rows(internal, .id = "prediction")

internal %>%
  ggplot(aes(x = pxl_col_in_fullres, y = - pxl_row_in_fullres, color = ELSA)) +
  scale_colour_gradient(low="white", high="deeppink4") +
  geom_point(size=0.4) +
  facet_wrap(~prediction, scales = "free") +
  theme_bw() + labs(x="", y="", color="ELSA")
```

# Session info
```{r}
sessionInfo()
```