diff --git a/.github/workflows/rust-publish.yml b/.github/workflows/rust-publish.yml index 8e1c5b1a..e876c45d 100644 --- a/.github/workflows/rust-publish.yml +++ b/.github/workflows/rust-publish.yml @@ -157,8 +157,7 @@ jobs: secrets: inherit publish-wasm: - if: ${{ always() && inputs.wasm != false && needs.publish-all-crates.result == 'success' }} - needs: publish-all-crates + if: ${{ always() && inputs.wasm != false }} runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.gitignore b/.gitignore index 674ed035..ab517cf3 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,8 @@ bin/ tests/data/out/region_scoring_count.csv.gz /gtars-refget/tests/store_test/rgstore.json /gtars-refget/tests/store_test/sequences.rgsi +libgtars.dylib.dSYM/ +gtars-bm25/tests/demo_cache/ # Large benchmark data and validation files tests/data/interval_ranges_benchmark/ diff --git a/Cargo.toml b/Cargo.toml index 43f08259..68b2c1c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ members = [ "gtars-python", "gtars-genomicdist", "gtars-wasm", - "gtars-r/src/rust", + # "gtars-r/src/rust", "gtars-core", "gtars-refget", "gtars-uniwig", @@ -17,7 +17,8 @@ members = [ "gtars-lola", "gtars-fragsplit", "gtars-scoring", - "gtars", + "gtars-bm25", + "gtars", ] [workspace.dependencies] diff --git a/gtars-bm25/Cargo.toml b/gtars-bm25/Cargo.toml new file mode 100644 index 00000000..df14df47 --- /dev/null +++ b/gtars-bm25/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "gtars-bm25" +version = "0.1.0" +edition = "2024" +description = "BM25 sparse embedding implementation for genomic intervals and information retrieval" + +[dependencies] +gtars-core = { path = "../gtars-core", version="0.5.2" } +gtars-tokenizers = { path = "../gtars-tokenizers" } + +[dev-dependencies] +rstest = "0.26.1" \ No newline at end of file diff --git a/gtars-bm25/README.md b/gtars-bm25/README.md new file mode 100644 index 00000000..15f4726f --- /dev/null +++ b/gtars-bm25/README.md @@ -0,0 +1,120 @@ +# gtars-bm25 +This crate implements a BM25 sparse embedding for genomic intervals, motivated by Qdrants own [bm25 implementation](https://github.com/qdrant/fastembed/blob/main/fastembed/sparse/bm25.py) within fastembed. The implementation is actually BM25-_like_, in that it assumes a constant, prior-known average document length. This enables us to compute the BM25 scores for a query interval without needing to know teh distribution of document lengths in the corpus. Moreover, sparse BM25 embeddings for documents need not be recomputed as the corpus grows. + +This method is designed to be used in conjunction with one of our dense embedding models, such as Atacformer or Region2Vec to enable hybrid search. The sparse BM25 embedding can be used perform "key-word" search (look for specific regions), while the dense embedding can be used to perform "semantic search" (look for similar biology). By combining the two with a fusion strategy, we can achieve better recall and precision than either method alone. + +## Example usage +Here is an example usage of the BM25 embedding: + +```python +from gtars.bm25 import Bm25 +from gtars.models import RegionSet + +model = Bm25( + tokenizer="/path/to/vocab.bed", + k=1.5, + b=0.75, + avg_doc_length=1_000 +) + +query = RegionSet("path/to/query.bed") +embedding = model.embed(query) + +print(embedding.indices) # [1, 5, 10] +print(embedding.values) # [0.5, 1.0, 0.75] +``` + +## Use with Atacformer and Qdrant +BM25 can be used with dense embedding models like Atacformer to perform hybrid search in Qdrant. + + +First, we need to create a Qdrant collection with both dense and sparse vector configurations: +```python +from geniml.atacformer import AtacformerForCellClustering +from gtars.bm25 import Bm25 +from gtars.models import RegionSet +from gtars.tokenizers import Tokenizer + +from qdrant_client import models as qdrant_models +from qdrant_client import QdrantClient + +# instantiate the qdrant collection +client = QdrantClient("http://localhost:6333") +client.recreate_collection( + collection_name="bedbase", + # atacformer embeddings + vectors_config={ + "dense": qdrant_models.VectorParams( + size=384, + distance=qdrant_models.Distance.COSINE + ), + }, + # bm25 sparse embeddings + sparse_vectors_config={ + "sparse": qdrant_models.SparseVectorsConfig( + modifier=qdrant_models.Modifier.IDF + ) + } +) +``` + +Then we can instantiate our Atacformer and BM25 models, and insert some data into the collection: +```python +# instantiate the models +tokenizer = Tokenizer.from_pretrained("databio/atacformer-ctft-hg38") +atacformer = AtacformerForCellClustering.from_pretrained("databio/atacformer-ctft-hg38") +bm25 = Bm25( + tokenizer=tokenizer, + k=1.5, + b=0.75, + avg_doc_length=1_000 # bed files are usually very large +) + +documents = [ + RegionSet("path/to/document1.bed"), + RegionSet("path/to/document2.bed"), + RegionSet("path/to/document3.bed"), + RegionSet("path/to/document4.bed"), + RegionSet("path/to/document5.bed"), +] + +for i, document in enumerate(documents): + dense_embedding = atacformer.embed(document) + sparse_embedding = bm25.embed(document) + + client.upsert( + collection_name="bedbase", + points=[ + qdrant_models.PointStruct( + id=i, + vector=dense_embedding, + sparse_vector=sparse_embedding + ) + ] + ) +``` + +Finally, we can perform a hybrid search using both the dense and sparse embeddings: +```python +query = RegionSet("path/to/query.bed") +dense_query_embedding = atacformer.embed(query) +sparse_query_embedding = bm25.embed(query) + +response = client.query_points( + collection_name="bedbase", + prefetch=[ + qdrant_models.Prefetch( + query=sparse_query_embedding, + using="sparse", + limit=3, + ), + qdrant_models.Prefetch( + query=dense_query_embedding, + using="dense", + limit=3, + ) + ], + query=qdrant_models.FusionQuery(fusion=qdrant_models.Fusion.RRF), + limit=3, +) +``` \ No newline at end of file diff --git a/gtars-bm25/src/bm25.rs b/gtars-bm25/src/bm25.rs new file mode 100644 index 00000000..9bde4aac --- /dev/null +++ b/gtars-bm25/src/bm25.rs @@ -0,0 +1,332 @@ +use std::path::Path; +use std::collections::HashMap; + +use gtars_core::models::region::Region; +use gtars_tokenizers::Tokenizer; + +use crate::sparse_vector::SparseVector; + +pub struct Bm25 { + avg_doc_length: f32, + b: f32, + k: f32, + tokenizer: Tokenizer, +} + +pub struct Bm25Builder { + b: f32, + k: f32, + avg_doc_length: f32, + tokenizer: Option, +} + +impl Bm25Builder { + /// Build the BM25 model from the builder. + /// + /// # Panics + /// Panics if no tokenizer/vocabulary has been provided. + pub fn build(self) -> Bm25 { + let tokenizer = self + .tokenizer + .expect("A tokenizer or vocabulary must be provided via with_tokenizer() or with_vocab()"); + + Bm25 { + avg_doc_length: self.avg_doc_length, + b: self.b, + k: self.k, + tokenizer, + } + } + + pub fn with_k(mut self, k: f32) -> Self { + self.k = k; + self + } + + pub fn with_b(mut self, b: f32) -> Self { + self.b = b; + self + } + + pub fn with_avg_doc_length(mut self, avg_doc_length: f32) -> Self { + self.avg_doc_length = avg_doc_length; + self + } + + /// Set the tokenizer directly from an existing `Tokenizer` instance. + pub fn with_tokenizer(mut self, tokenizer: Tokenizer) -> Self { + self.tokenizer = Some(tokenizer); + self + } + + /// Load a vocabulary from a file path. + /// + /// The path can be: + /// 1. A BED file on disk (optionally gzipped) + /// 2. A TOML config file pointing to a universe + /// + /// The file type is auto-detected. + pub fn with_vocab>(mut self, vocab: P) -> Self { + let tokenizer = Tokenizer::from_auto(vocab) + .expect("Failed to load vocabulary. Ensure the path points to a valid .bed, .bed.gz, or .toml file."); + self.tokenizer = Some(tokenizer); + self + } +} + +impl Default for Bm25Builder { + fn default() -> Self { + Self { + b: 0.75, + k: 1.0, + avg_doc_length: 1_000.0, + tokenizer: None, + } + } +} + +impl Bm25 { + /// Create a new Bm25Builder. + pub fn builder() -> Bm25Builder { + Bm25Builder::default() + } + + /// Tokenize a set of regions into token IDs. + /// + /// This performs the overlap query against the vocabulary and returns + /// the token IDs of all overlapping vocabulary regions. + pub fn tokenize(&self, regions: &[Region]) -> Vec { + let unk_id = self.tokenizer.get_unk_token_id(); + self.tokenizer + .encode(regions) + .unwrap_or_default() + // filter OOV tokens (those that don't overlap any vocab region) by removing the unk_id + .into_iter().filter(|&id| id != unk_id) + .collect() + } + + /// Returns a reference to the internal tokenizer. + pub fn tokenizer(&self) -> &Tokenizer { + &self.tokenizer + } + + /// Returns the vocabulary size (number of regions in the vocabulary). + pub fn vocab_size(&self) -> usize { + self.tokenizer.get_vocab_size() + } + + /// Returns the `k` parameter (term frequency saturation). + pub fn k(&self) -> f32 { + self.k + } + + /// Returns the `b` parameter (document length normalization). + pub fn b(&self) -> f32 { + self.b + } + + /// Returns the assumed average document length. + pub fn avg_doc_length(&self) -> f32 { + self.avg_doc_length + } + + /// Encode a set of regions into a BM25 sparse vector. + /// + /// This tokenizes the regions and computes BM25-like term frequency + /// scores for each token. The indices are token IDs and the values + /// are the BM25 term-frequency component scores. + /// Formula from: https://en.wikipedia.org/wiki/Okapi_BM25 + pub fn embed(&self, regions: &[Region]) -> SparseVector { + // tokenize to get token IDs + let token_ids = self.tokenize(regions); + + if token_ids.is_empty() { + return SparseVector::empty(); + } + + // count term frequencies + let mut tf_map: HashMap = HashMap::new(); + for id in &token_ids { + *tf_map.entry(*id).or_insert(0) += 1; + } + + let doc_length = token_ids.len() as f32; + + // compute bm25 term-frequency scores + let mut indices: Vec = Vec::with_capacity(tf_map.len()); + let mut values: Vec = Vec::with_capacity(tf_map.len()); + + for (token_id, raw_tf) in tf_map { + let tf = raw_tf as f32; + let tf_score = (tf * (self.k + 1.0)) + / (tf + self.k * (1.0 - self.b + self.b * (doc_length / self.avg_doc_length))); + + indices.push(token_id); + values.push(tf_score); + } + + SparseVector::new(indices, values) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::*; + + #[fixture] + fn peaks_path() -> String { + "../tests/data/tokenizers/peaks.bed".to_string() + } + + #[fixture] + fn bm25(peaks_path: String) -> Bm25 { + Bm25::builder() + .with_vocab(&peaks_path) + .with_k(1.5) + .with_b(0.75) + .with_avg_doc_length(1_000.0) + .build() + } + + #[rstest] + fn test_builder_defaults() { + let builder = Bm25Builder::default(); + assert_eq!(builder.k, 1.0); + assert_eq!(builder.b, 0.75); + assert_eq!(builder.avg_doc_length, 1_000.0); + } + + #[rstest] + fn test_builder_custom_params(peaks_path: String) { + let model = Bm25::builder() + .with_vocab(&peaks_path) + .with_k(2.0) + .with_b(0.5) + .with_avg_doc_length(500.0) + .build(); + + assert_eq!(model.k(), 2.0); + assert_eq!(model.b(), 0.5); + assert_eq!(model.avg_doc_length(), 500.0); + } + + #[rstest] + fn test_vocab_size(bm25: Bm25) { + // peaks.bed has 25 regions + 7 special tokens + assert!(bm25.vocab_size() > 25); + } + + #[rstest] + fn test_tokenize_overlapping_regions(bm25: Bm25) { + // query a region that overlaps the first entry in peaks.bed: chr17 7915738 7915777 + let regions = vec![Region { + chr: "chr17".to_string(), + start: 7915700, + end: 7915800, + rest: None, + }]; + + let token_ids = bm25.tokenize(®ions); + assert!(!token_ids.is_empty(), "should find overlapping tokens"); + } + + #[rstest] + fn test_tokenize_no_overlap(bm25: Bm25) { + // query a region on a chromosome with no vocab entries + let regions = vec![Region { + chr: "chrZ".to_string(), + start: 0, + end: 100, + rest: None, + }]; + + let token_ids = bm25.tokenize(®ions); + assert!(token_ids.is_empty(), "should find no tokens for non-existent chrom"); + } + + #[rstest] + fn test_embed_produces_sparse_vector(bm25: Bm25) { + let regions = vec![Region { + chr: "chr17".to_string(), + start: 7915700, + end: 7915800, + rest: None, + }]; + + let sv = bm25.embed(®ions); + assert!(!sv.is_empty()); + assert_eq!(sv.indices.len(), sv.values.len()); + } + + #[rstest] + fn test_embed_empty_input(bm25: Bm25) { + let regions: Vec = vec![]; + let sv = bm25.embed(®ions); + assert!(sv.is_empty()); + } + + #[rstest] + fn test_embed_no_overlap_returns_empty(bm25: Bm25) { + let regions = vec![Region { + chr: "chrZ".to_string(), + start: 0, + end: 100, + rest: None, + }]; + + let sv = bm25.embed(®ions); + assert!(sv.is_empty()); + } + + #[rstest] + fn test_embed_values_are_positive(bm25: Bm25) { + // query multiple overlapping regions + let regions = vec![ + Region { chr: "chr17".to_string(), start: 7915700, end: 7915800, rest: None }, + Region { chr: "chr6".to_string(), start: 157381091, end: 157381200, rest: None }, + ]; + + let sv = bm25.embed(®ions); + for val in &sv.values { + assert!(*val > 0.0, "Bm25 TF scores should be positive"); + } + } + + #[rstest] + fn test_embed_repeated_regions_increase_tf(bm25: Bm25) { + let region = Region { + chr: "chr17".to_string(), + start: 7915700, + end: 7915800, + rest: None, + }; + + let sv_single = bm25.embed(&[region.clone()]); + let sv_repeated = bm25.embed(&[region.clone(), region.clone(), region]); + + // with repeated regions, the tf score should be higher (but sublinear due to saturation) + assert!(!sv_single.is_empty()); + assert!(!sv_repeated.is_empty()); + + let val_single = sv_single.values[0]; + let val_repeated = sv_repeated.values[0]; + assert!(val_repeated > val_single, "repeated terms should have higher TF score"); + } + + #[rstest] + fn test_with_tokenizer(peaks_path: String) { + let tokenizer = Tokenizer::from_bed(&peaks_path).unwrap(); + let model = Bm25::builder() + .with_tokenizer(tokenizer) + .build(); + + assert!(model.vocab_size() > 0); + } + + #[rstest] + #[should_panic(expected = "A tokenizer or vocabulary must be provided")] + fn test_builder_panics_without_tokenizer() { + Bm25Builder::default().build(); + } +} \ No newline at end of file diff --git a/gtars-bm25/src/lib.rs b/gtars-bm25/src/lib.rs new file mode 100644 index 00000000..7af0851f --- /dev/null +++ b/gtars-bm25/src/lib.rs @@ -0,0 +1,10 @@ +//! +//! BM25 sparse embedding implementation for genomic intervals and information retrieval. +//! +//! This crate enables powerful hybrid search when paired with a dense embedding model like Atacformer. +//! +pub mod bm25; +pub mod sparse_vector; + +pub use bm25::{Bm25, Bm25Builder}; +pub use sparse_vector::SparseVector; \ No newline at end of file diff --git a/gtars-bm25/src/sparse_vector.rs b/gtars-bm25/src/sparse_vector.rs new file mode 100644 index 00000000..431ab1e1 --- /dev/null +++ b/gtars-bm25/src/sparse_vector.rs @@ -0,0 +1,41 @@ +/// A sparse vector representation for BM25 embeddings. +/// +/// This is designed to be compatible with Qdrant's sparse vector format. +/// Each non-zero dimension is represented by an index-value pair. +pub struct SparseVector { + pub indices: Vec, + pub values: Vec, +} + +impl SparseVector { + /// Create a new sparse vector from indices and values. + /// + /// # Panics + /// Panics if `indices` and `values` have different lengths. + pub fn new(indices: Vec, values: Vec) -> Self { + assert_eq!( + indices.len(), + values.len(), + "indices and values must have the same length" + ); + SparseVector { indices, values } + } + + /// Create an empty sparse vector. + pub fn empty() -> Self { + SparseVector { + indices: Vec::new(), + values: Vec::new(), + } + } + + /// Returns the number of non-zero entries. + pub fn len(&self) -> usize { + self.indices.len() + } + + /// Returns true if the vector has no entries. + pub fn is_empty(&self) -> bool { + self.indices.is_empty() + } +} \ No newline at end of file diff --git a/gtars-bm25/tests/demo_bm25_enrichment.py b/gtars-bm25/tests/demo_bm25_enrichment.py new file mode 100644 index 00000000..e22b2cf2 --- /dev/null +++ b/gtars-bm25/tests/demo_bm25_enrichment.py @@ -0,0 +1,435 @@ +""" +BM25 vs LOLA Enrichment Demo + +Compares enrichment analysis using: + 1. BM25 sparse embeddings (new approach) + 2. Traditional LOLA Fisher's exact test (gtars-lola) + +Uses the lola_hg38_ucsc_features bedset from BEDbase as the reference database. + +Usage: + python demo_bm25_enrichment.py + python demo_bm25_enrichment.py # uses a sample BED from BEDbase +""" + +import sys +import gzip +import json +import time +import urllib.request +from pathlib import Path +from collections import defaultdict + +from gtars.bm25 import Bm25 +from gtars.lola import RegionDB, run_lola + +BEDSET_ID = "lola_hg38_ucsc_features" +API_BASE = "https://api.bedbase.org/v1" +CACHE_DIR = Path(__file__).parent / "demo_cache" +RESULTS_PATH = Path(__file__).parent / "demo_cache" / "results.json" + + +# ── Helpers ────────────────────────────────────────────────────────────── + +def _request(url: str) -> urllib.request.Request: + return urllib.request.Request(url, headers={"User-Agent": "gtars-bm25-demo/0.1"}) + + +def fetch_json(url: str) -> dict: + with urllib.request.urlopen(_request(url)) as r: + return json.loads(r.read()) + + +def download_file(url: str, dest: Path): + if dest.exists(): + return + print(f" {dest.name}") + with urllib.request.urlopen(_request(url)) as r, open(dest, "wb") as f: + f.write(r.read()) + + +def fetch_bedset_files() -> list[dict]: + """Fetch metadata for all BED files in the LOLA UCSC bedset.""" + data = fetch_json(f"{API_BASE}/bedset/{BEDSET_ID}/bedfiles") + return data["results"] + + +def download_bed(bed_id: str, name: str, cache: Path) -> Path: + """Download a BED file from BEDbase, return path to decompressed file.""" + gz_path = cache / f"{name}.bed.gz" + bed_path = cache / f"{name}.bed" + + if bed_path.exists(): + return bed_path + + files_meta = fetch_json(f"{API_BASE}/bed/{bed_id}/metadata/files") + url = files_meta["bed_file"]["access_methods"][0]["access_url"]["url"] + download_file(url, gz_path) + + with gzip.open(gz_path, "rb") as f_in, open(bed_path, "wb") as f_out: + f_out.write(f_in.read()) + + return bed_path + + +def read_bed_tuples(path: Path) -> list[tuple[str, int, int]]: + """Read a BED file into a list of (chr, start, end) tuples.""" + regions = [] + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#") or line.startswith("track"): + continue + parts = line.split("\t") + if len(parts) < 3: + continue + regions.append((parts[0], int(parts[1]), int(parts[2]))) + return regions + + +def get_sample_query(cache: Path) -> str: + """Download a sample BED file to use as query.""" + sample_path = cache / "sample_query.bed" + if sample_path.exists(): + return str(sample_path) + + print(" No query provided — fetching a sample from BEDbase...") + beds = fetch_json(f"{API_BASE}/bed/list?genome=hg38&limit=1&offset=50") + bed = beds["results"][0] + bed_id = bed["id"] + name = bed.get("name", bed_id) + print(f" Sample: {name} ({bed_id})") + + files_meta = fetch_json(f"{API_BASE}/bed/{bed_id}/metadata/files") + url = files_meta["bed_file"]["access_methods"][0]["access_url"]["url"] + download_file(url, cache / "sample_query.bed.gz") + + with gzip.open(cache / "sample_query.bed.gz", "rb") as f_in: + with open(sample_path, "wb") as f_out: + f_out.write(f_in.read()) + + return str(sample_path) + + +# ── Data setup ─────────────────────────────────────────────────────────── + +def download_reference_beds(bed_files: list[dict], cache: Path) -> dict[str, Path]: + """Download all reference BED files. Returns {name: path}.""" + ref_beds = {} + for meta in bed_files: + path = download_bed(meta["id"], meta["name"], cache) + ref_beds[meta["name"]] = path + return ref_beds + + +def build_vocab_and_mapping( + bed_files: list[dict], ref_beds: dict[str, Path], cache: Path +) -> tuple[Path, dict[int, str], dict[str, str]]: + """Concatenate reference BEDs into one vocab, tracking annotation per line.""" + vocab_path = cache / "vocab.bed" + token_to_anno: dict[int, str] = {} + anno_descriptions: dict[str, str] = {} + anno_region_counts: dict[str, int] = {} + line_num = 0 + + with open(vocab_path, "w") as vocab_f: + for meta in bed_files: + name = meta["name"] + anno_descriptions[name] = meta["description"] + count = 0 + + for chr, start, end in read_bed_tuples(ref_beds[name]): + vocab_f.write(f"{chr}\t{start}\t{end}\n") + token_to_anno[line_num] = name + line_num += 1 + count += 1 + + anno_region_counts[name] = count + + return vocab_path, token_to_anno, anno_descriptions + + +# ── IDF computation ────────────────────────────────────────────────────── + +def compute_idf( + bm25: Bm25, + ref_beds: dict[str, Path], +) -> dict[int, float]: + """Compute IDF for each token across the reference corpus. + + IDF(token) = ln(1 + (N - df + 0.5) / (df + 0.5)) + where N = number of documents, df = documents containing token. + """ + import math + + # Embed each reference file, collect which tokens appear in each + doc_freq: dict[int, int] = defaultdict(int) + n_docs = len(ref_beds) + + for name, path in ref_beds.items(): + sv = bm25.embed(str(path)) + seen_tokens = set(sv.indices) + for token_id in seen_tokens: + doc_freq[token_id] += 1 + + # Compute IDF + idf = {} + for token_id, df in doc_freq.items(): + idf[token_id] = math.log(1 + (n_docs - df + 0.5) / (df + 0.5)) + + return idf + + +# ── BM25 enrichment ───────────────────────────────────────────────────── + +def run_bm25_enrichment( + query_path: str, + bm25: Bm25, + token_to_anno: dict[int, str], + idf: dict[int, float], +) -> tuple[dict[str, dict], float]: + """Run BM25 enrichment with IDF. Returns ({name: {score, hits}}, elapsed_seconds).""" + t0 = time.perf_counter() + sv = bm25.embed(query_path) + elapsed = time.perf_counter() - t0 + + if len(sv) == 0: + return {}, elapsed + + anno_scores: dict[str, float] = defaultdict(float) + anno_hits: dict[str, int] = defaultdict(int) + + for idx, val in zip(sv.indices, sv.values): + anno = token_to_anno.get(idx) + if anno is None: + for offset in range(1, 10): + anno = token_to_anno.get(idx - offset) + if anno: + break + if anno is None: + continue + # Apply IDF: full BM25 = IDF * TF_score + token_idf = idf.get(idx, 0.0) + anno_scores[anno] += val * token_idf + anno_hits[anno] += 1 + + results = {} + for name in anno_scores: + results[name] = { + "score": round(anno_scores[name], 4), + "hits": anno_hits[name], + } + + return results, elapsed + + +# ── LOLA enrichment ────────────────────────────────────────────────────── + +def run_lola_enrichment( + query_path: str, + ref_beds: dict[str, Path], +) -> tuple[dict[str, dict], float]: + """Run traditional LOLA. Returns ({name: {support, fraction, ...}}, elapsed).""" + bed_paths = list(ref_beds.values()) + bed_names = list(ref_beds.keys()) + + query_tuples = read_bed_tuples(Path(query_path)) + n_query = len(query_tuples) + + # Universe = union of query + all reference regions + universe_tuples = list(query_tuples) + for path in bed_paths: + universe_tuples.extend(read_bed_tuples(path)) + + t0 = time.perf_counter() + region_db = RegionDB.from_bed_files( + [str(p) for p in bed_paths], + filenames=bed_names, + ) + + raw = run_lola( + user_sets=[query_tuples], + universe=universe_tuples, + region_db=region_db, + min_overlap=1, + direction="enrichment", + ) + elapsed = time.perf_counter() - t0 + + results = {} + n = len(raw["filename"]) + for i in range(n): + name = raw["filename"][i] + support = raw["support"][i] + results[name] = { + "support": support, + "fraction": round(support / n_query, 4) if n_query > 0 else 0, + } + + return results, elapsed + + +# ── Display ────────────────────────────────────────────────────────────── + +def print_results( + bm25_results: dict[str, dict], + bm25_time: float, + lola_results: dict[str, dict], + lola_time: float, + anno_descriptions: dict[str, str], +): + """Pretty-print side-by-side comparison.""" + + # Rank by BM25 score and by LOLA support + bm25_ranked = sorted(bm25_results.keys(), key=lambda n: bm25_results[n]["score"], reverse=True) + sup_ranked = sorted(lola_results.keys(), key=lambda n: lola_results[n]["support"], reverse=True) + + bm25_rank = {name: i + 1 for i, name in enumerate(bm25_ranked)} + sup_rank = {name: i + 1 for i, name in enumerate(sup_ranked)} + + all_annos = list(dict.fromkeys(bm25_ranked + sup_ranked)) + + # Header + print() + print("┌────────────────────────────────────────────────────────────────────────────────────────────────┐") + print("│ BM25 vs Overlap Enrichment Results │") + print("├──────────────────────┬────────┬──────────┬────────┬──────────┬──────────┬─────────────────────┤") + print("│ Annotation │ BM25 # │ BM25 Scr │ Sup # │ Support │ Fraction │ Description │") + print("├──────────────────────┼────────┼──────────┼────────┼──────────┼──────────┼─────────────────────┤") + + for anno in sorted(all_annos, key=lambda a: bm25_rank.get(a, 99)): + br = str(bm25_rank.get(anno, "-")) + bs = bm25_results.get(anno, {}).get("score", 0) + sr = str(sup_rank.get(anno, "-")) + sup = lola_results.get(anno, {}).get("support", 0) + frac = lola_results.get(anno, {}).get("fraction", 0) + desc = anno_descriptions.get(anno, "")[:19] + + print( + f"│ {anno:<20} │ {br:>6} │ {bs:>8.2f} │ {sr:>6} │ {sup:>8} │ {frac:>8.1%} │ {desc:<19} │" + ) + + print("└──────────────────────┴────────┴──────────┴────────┴──────────┴──────────┴─────────────────────┘") + + # Timing + print(f"\n BM25 time: {bm25_time:.3f}s") + print(f" LOLA time: {lola_time:.3f}s") + + # Rank correlation (BM25 vs support rank) + shared = [a for a in all_annos if a in bm25_rank and a in sup_rank] + if len(shared) >= 2: + bm25_r = [bm25_rank[a] for a in shared] + sup_r = [sup_rank[a] for a in shared] + n = len(shared) + d_sq = sum((b - s) ** 2 for b, s in zip(bm25_r, sup_r)) + rho = 1 - (6 * d_sq) / (n * (n ** 2 - 1)) + print(f" Rank corr: {rho:.3f} (Spearman, BM25 vs support, n={n})") + + +def save_json( + query_path: str, + bm25_results: dict[str, dict], + bm25_time: float, + lola_results: dict[str, dict], + lola_time: float, + anno_descriptions: dict[str, str], + output_path: Path, +): + """Save comprehensive JSON output.""" + # Add ranks + bm25_ranked = sorted(bm25_results.keys(), key=lambda n: bm25_results[n]["score"], reverse=True) + lola_ranked = sorted(lola_results.keys(), key=lambda n: lola_results[n]["support"], reverse=True) + + combined = [] + all_annos = list(dict.fromkeys(bm25_ranked + lola_ranked)) + + for anno in all_annos: + entry = { + "annotation": anno, + "description": anno_descriptions.get(anno, ""), + "bm25": { + "rank": bm25_ranked.index(anno) + 1 if anno in bm25_ranked else None, + **(bm25_results.get(anno, {})), + }, + "lola": { + "rank": lola_ranked.index(anno) + 1 if anno in lola_ranked else None, + **(lola_results.get(anno, {})), + }, + } + combined.append(entry) + + output = { + "query": str(query_path), + "reference_bedset": BEDSET_ID, + "num_annotations": len(all_annos), + "timing": { + "bm25_seconds": round(bm25_time, 4), + "lola_seconds": round(lola_time, 4), + }, + "results": combined, + } + + output_path.parent.mkdir(exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2) + + print(f"\n JSON saved: {output_path}") + + +# ── Main ───────────────────────────────────────────────────────────────── + +def main(): + cache = CACHE_DIR + cache.mkdir(exist_ok=True) + + print() + print(" BM25 vs LOLA Enrichment Comparison") + print(f" Reference: {BEDSET_ID}") + print() + + # Fetch and download + print(" Fetching reference files...") + bed_files = fetch_bedset_files() + print(f" {len(bed_files)} annotations found\n") + + print(" Downloading BED files...") + ref_beds = download_reference_beds(bed_files, cache) + + # Build BM25 vocab + print("\n Building BM25 vocabulary...") + vocab_path, token_to_anno, anno_descriptions = build_vocab_and_mapping( + bed_files, ref_beds, cache + ) + bm25 = Bm25(tokenizer=str(vocab_path), k=1.5, b=0.75, avg_doc_length=1000.0) + print(f" Vocab: {bm25.vocab_size} tokens") + + # Compute IDF across reference corpus + print("\n Computing IDF across reference files...") + idf = compute_idf(bm25, ref_beds) + print(f" {len(idf)} tokens with IDF scores") + + # Query + if len(sys.argv) > 1: + query_path = sys.argv[1] + else: + query_path = get_sample_query(cache) + + query_regions = read_bed_tuples(Path(query_path)) + print(f"\n Query: {Path(query_path).name} ({len(query_regions)} regions)") + + # Run both + print("\n Running BM25 (TF * IDF)...") + bm25_results, bm25_time = run_bm25_enrichment(query_path, bm25, token_to_anno, idf) + print(f" done ({bm25_time:.3f}s)") + + print(" Running LOLA...") + lola_results, lola_time = run_lola_enrichment(query_path, ref_beds) + print(f" done ({lola_time:.3f}s)") + + # Output + print_results(bm25_results, bm25_time, lola_results, lola_time, anno_descriptions) + save_json(query_path, bm25_results, bm25_time, lola_results, lola_time, anno_descriptions, RESULTS_PATH) + + +if __name__ == "__main__": + main() diff --git a/gtars-bm25/tests/demo_qdrant_enrichment.py b/gtars-bm25/tests/demo_qdrant_enrichment.py new file mode 100644 index 00000000..974817f1 --- /dev/null +++ b/gtars-bm25/tests/demo_qdrant_enrichment.py @@ -0,0 +1,460 @@ +""" +BM25 Enrichment via Qdrant Sparse Vector Search + +Indexes a LOLA reference database into Qdrant (in-memory) as sparse vectors, +then searches with a query BED file. Qdrant applies IDF at search time. + +This demonstrates the production-ready approach: embed once, search many times, +IDF computed automatically from the corpus. + +Usage: + python demo_qdrant_enrichment.py [bedset_id] [query.bed] + python demo_qdrant_enrichment.py # UCSC features, sample query + python demo_qdrant_enrichment.py lola_hg38_encode_tfbs # ENCODE TFBS, sample query + python demo_qdrant_enrichment.py lola_hg38_encode_tfbs query.bed # ENCODE TFBS, custom query +""" + +import sys +import gzip +import json +import time +import urllib.request +from pathlib import Path +from collections import defaultdict + +from gtars.bm25 import Bm25 +from gtars.lola import RegionDB, run_lola +from qdrant_client import QdrantClient, models + +API_BASE = "https://api.bedbase.org/v1" +CACHE_DIR = Path(__file__).parent / "demo_cache" + + +# ── Helpers ────────────────────────────────────────────────────────────── + +def _request(url: str) -> urllib.request.Request: + return urllib.request.Request(url, headers={"User-Agent": "gtars-bm25-demo/0.1"}) + + +def fetch_json(url: str) -> dict: + with urllib.request.urlopen(_request(url)) as r: + return json.loads(r.read()) + + +def download_file(url: str, dest: Path): + if dest.exists(): + return + with urllib.request.urlopen(_request(url)) as r, open(dest, "wb") as f: + f.write(r.read()) + + +def fetch_bedset_files(bedset_id: str) -> list[dict]: + """Fetch all BED file metadata from a bedset, handling pagination.""" + results = [] + data = fetch_json(f"{API_BASE}/bedset/{bedset_id}/bedfiles") + results.extend(data["results"]) + return results + + +def download_bed(bed_id: str, name: str, cache: Path) -> Path: + """Download and decompress a BED file from BEDbase.""" + bed_path = cache / f"{name}.bed" + if bed_path.exists(): + return bed_path + + gz_path = cache / f"{name}.bed.gz" + files_meta = fetch_json(f"{API_BASE}/bed/{bed_id}/metadata/files") + url = files_meta["bed_file"]["access_methods"][0]["access_url"]["url"] + download_file(url, gz_path) + + with gzip.open(gz_path, "rb") as f_in, open(bed_path, "wb") as f_out: + f_out.write(f_in.read()) + + return bed_path + + +def read_bed_tuples(path: Path) -> list[tuple[str, int, int]]: + regions = [] + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#") or line.startswith("track"): + continue + parts = line.split("\t") + if len(parts) < 3: + continue + regions.append((parts[0], int(parts[1]), int(parts[2]))) + return regions + + +def get_sample_query(cache: Path) -> str: + sample_path = cache / "sample_query.bed" + if sample_path.exists(): + return str(sample_path) + + print(" No query provided — fetching a sample from BEDbase...") + beds = fetch_json(f"{API_BASE}/bed/list?genome=hg38&limit=1&offset=50") + bed = beds["results"][0] + bed_id = bed["id"] + name = bed.get("name", bed_id) + print(f" Sample: {name} ({bed_id})") + + files_meta = fetch_json(f"{API_BASE}/bed/{bed_id}/metadata/files") + url = files_meta["bed_file"]["access_methods"][0]["access_url"]["url"] + download_file(url, cache / "sample_query.bed.gz") + + with gzip.open(cache / "sample_query.bed.gz", "rb") as f_in: + with open(sample_path, "wb") as f_out: + f_out.write(f_in.read()) + + return str(sample_path) + + +# ── Qdrant setup ───────────────────────────────────────────────────────── + +def create_collection(client: QdrantClient, name: str): + """Create a Qdrant collection with sparse vectors + IDF modifier.""" + client.create_collection( + collection_name=name, + vectors_config={}, + sparse_vectors_config={ + "bm25": models.SparseVectorParams( + modifier=models.Modifier.IDF, + ) + }, + ) + + +def index_reference_files( + client: QdrantClient, + collection: str, + bm25: Bm25, + bed_files: list[dict], + cache: Path, +) -> dict[str, str]: + """Download, embed, and index all reference BED files into Qdrant. + + Returns {name: description} mapping. + """ + anno_descriptions = {} + points = [] + skipped = 0 + + for i, meta in enumerate(bed_files): + name = meta["name"] + desc = meta.get("description", "") + bed_id = meta["id"] + anno_descriptions[name] = desc + + try: + bed_path = download_bed(bed_id, name, cache) + except Exception as e: + print(f" SKIP {name}: {e}") + skipped += 1 + continue + + sv = bm25.embed(str(bed_path)) + + if len(sv) == 0: + skipped += 1 + continue + + points.append( + models.PointStruct( + id=i, + vector={ + "bm25": models.SparseVector( + indices=list(sv.indices), + values=list(sv.values), + ) + }, + payload={ + "name": name, + "description": desc, + "bed_id": bed_id, + "annotation": meta.get("annotation", {}), + }, + ) + ) + + if len(points) % 50 == 0: + # Batch upsert + client.upsert(collection_name=collection, points=points) + print(f" indexed {i + 1}/{len(bed_files)}...") + points = [] + + # Final batch + if points: + client.upsert(collection_name=collection, points=points) + + print(f" indexed {len(bed_files) - skipped}/{len(bed_files)} files ({skipped} skipped)") + return anno_descriptions + + +def search_enrichment( + client: QdrantClient, + collection: str, + bm25: Bm25, + query_path: str, + limit: int = 20, +) -> tuple[list[dict], float]: + """Embed query and search Qdrant. Returns (results, elapsed).""" + t0 = time.perf_counter() + + sv = bm25.embed(query_path) + if len(sv) == 0: + return [], time.perf_counter() - t0 + + results = client.query_points( + collection_name=collection, + query=models.SparseVector( + indices=list(sv.indices), + values=list(sv.values), + ), + using="bm25", + limit=limit, + with_payload=True, + ) + + elapsed = time.perf_counter() - t0 + + ranked = [] + for point in results.points: + ranked.append({ + "name": point.payload["name"], + "description": point.payload["description"], + "bed_id": point.payload["bed_id"], + "score": round(point.score, 4), + "annotation": point.payload.get("annotation", {}), + }) + + return ranked, elapsed + + +# ── LOLA support ───────────────────────────────────────────────────────── + +def compute_lola_support( + query_path: str, + ref_beds: dict[str, Path], +) -> tuple[dict[str, dict], float]: + """Compute LOLA support/fraction for comparison. Returns ({name: {support, fraction}}, elapsed).""" + bed_paths = list(ref_beds.values()) + bed_names = list(ref_beds.keys()) + + query_tuples = read_bed_tuples(Path(query_path)) + n_query = len(query_tuples) + + # Universe = union of query + all reference regions + universe_tuples = list(query_tuples) + for path in bed_paths: + universe_tuples.extend(read_bed_tuples(path)) + + t0 = time.perf_counter() + region_db = RegionDB.from_bed_files( + [str(p) for p in bed_paths], + filenames=bed_names, + ) + + raw = run_lola( + user_sets=[query_tuples], + universe=universe_tuples, + region_db=region_db, + min_overlap=1, + direction="enrichment", + ) + elapsed = time.perf_counter() - t0 + + results = {} + for i in range(len(raw["filename"])): + name = raw["filename"][i] + support = raw["support"][i] + results[name] = { + "support": support, + "fraction": round(support / n_query, 4) if n_query > 0 else 0, + } + + return results, elapsed + + +# ── Display ────────────────────────────────────────────────────────────── + +def print_results( + qdrant_results: list[dict], + search_time: float, + lola_results: dict[str, dict], + lola_time: float, + bedset_id: str, +): + # Build LOLA support rank lookup + sup_ranked = sorted(lola_results.keys(), key=lambda n: lola_results[n]["support"], reverse=True) + sup_rank = {name: i + 1 for i, name in enumerate(sup_ranked)} + + print() + print(f"┌──────────────────────────────────────────────────────────────────────────────────────────────────────────────┐") + print(f"│ Qdrant BM25+IDF vs LOLA Support Bedset: {bedset_id:<40} │") + print(f"├──────┬──────────────────────────────┬──────────┬────────┬──────────┬──────────┬────────────────────────────┤") + print(f"│ BM25 │ Name │ BM25 Scr │ Sup # │ Support │ Fraction │ Description │") + print(f"├──────┼──────────────────────────────┼──────────┼────────┼──────────┼──────────┼────────────────────────────┤") + + for i, r in enumerate(qdrant_results, 1): + name = r["name"] + score = r["score"] + desc = r["description"][:26] + sr = str(sup_rank.get(name, "-")) + sup = lola_results.get(name, {}).get("support", 0) + frac = lola_results.get(name, {}).get("fraction", 0) + + print( + f"│ {i:>4} │ {name[:28]:<28} │ {score:>8.2f} │ {sr:>6} │ {sup:>8} │ {frac:>8.1%} │ {desc:<26} │" + ) + + print(f"└──────┴──────────────────────────────┴──────────┴────────┴──────────┴──────────┴────────────────────────────┘") + + print(f"\n BM25 search: {search_time:.3f}s (embed + Qdrant)") + print(f" LOLA time: {lola_time:.3f}s") + + # Rank correlation + shared = [r["name"] for r in qdrant_results if r["name"] in sup_rank] + if len(shared) >= 2: + bm25_r = [i + 1 for i, r in enumerate(qdrant_results) if r["name"] in sup_rank] + sup_r = [sup_rank[r["name"]] for r in qdrant_results if r["name"] in sup_rank] + n = len(shared) + d_sq = sum((b - s) ** 2 for b, s in zip(bm25_r, sup_r)) + rho = 1 - (6 * d_sq) / (n * (n ** 2 - 1)) + print(f" Rank corr: {rho:.3f} (Spearman, BM25 vs support, n={n})") + + +def save_json( + query_path: str, + bedset_id: str, + qdrant_results: list[dict], + lola_results: dict[str, dict], + search_time: float, + lola_time: float, + index_time: float, + n_indexed: int, + output_path: Path, +): + # Merge BM25 and LOLA results + combined = [] + for i, r in enumerate(qdrant_results, 1): + name = r["name"] + lola = lola_results.get(name, {}) + combined.append({ + "bm25_rank": i, + "name": name, + "description": r["description"], + "bm25_score": r["score"], + "lola_support": lola.get("support"), + "lola_fraction": lola.get("fraction"), + }) + + output = { + "query": str(query_path), + "reference_bedset": bedset_id, + "num_indexed": n_indexed, + "timing": { + "index_seconds": round(index_time, 4), + "bm25_search_seconds": round(search_time, 4), + "lola_seconds": round(lola_time, 4), + }, + "results": combined, + } + + output_path.parent.mkdir(exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2) + + print(f" JSON saved: {output_path}") + + +# ── Main ───────────────────────────────────────────────────────────────── + +def main(): + # Parse args + bedset_id = sys.argv[1] if len(sys.argv) > 1 else "lola_hg38_ucsc_features" + query_arg = sys.argv[2] if len(sys.argv) > 2 else None + + cache = CACHE_DIR / bedset_id + cache.mkdir(parents=True, exist_ok=True) + collection = f"lola_{bedset_id}" + results_path = CACHE_DIR / f"qdrant_results_{bedset_id}.json" + + print() + print(" Qdrant Sparse Vector Enrichment Demo") + print(f" Bedset: {bedset_id}") + print() + + # Fetch bedset metadata + print(" Fetching bedset metadata...") + bed_files = fetch_bedset_files(bedset_id) + print(f" {len(bed_files)} files in bedset") + + # Build vocab from all reference files + print("\n Downloading reference files & building vocabulary...") + t0 = time.perf_counter() + + # Download all first + ref_beds = {} + for meta in bed_files: + try: + path = download_bed(meta["id"], meta["name"], cache) + ref_beds[meta["name"]] = path + except Exception as e: + print(f" SKIP {meta['name']}: {e}") + + # Build concatenated vocab + vocab_path = cache / "vocab.bed" + with open(vocab_path, "w") as vocab_f: + for name, path in ref_beds.items(): + for chr, start, end in read_bed_tuples(path): + vocab_f.write(f"{chr}\t{start}\t{end}\n") + + download_time = time.perf_counter() - t0 + print(f" Downloaded {len(ref_beds)} files in {download_time:.1f}s") + + # Build BM25 model + print("\n Building BM25 model...") + bm25 = Bm25(tokenizer=str(vocab_path), k=1.5, b=0.75, avg_doc_length=1000.0) + print(f" Vocab: {bm25.vocab_size} tokens") + + # Create Qdrant collection and index + print("\n Indexing into Qdrant (in-memory)...") + client = QdrantClient(":memory:") + create_collection(client, collection) + + t0 = time.perf_counter() + anno_descriptions = index_reference_files(client, collection, bm25, bed_files, cache) + index_time = time.perf_counter() - t0 + print(f" Index time: {index_time:.1f}s") + + # Get collection info + info = client.get_collection(collection) + print(f" Points in collection: {info.points_count}") + + # Query + if query_arg: + query_path = query_arg + else: + query_path = get_sample_query(CACHE_DIR) + + query_regions = read_bed_tuples(Path(query_path)) + print(f"\n Query: {Path(query_path).name} ({len(query_regions)} regions)") + + # BM25 search + print("\n Searching Qdrant (BM25+IDF)...") + qdrant_results, search_time = search_enrichment(client, collection, bm25, query_path, limit=20) + + # LOLA support + print(" Computing LOLA support...") + lola_results, lola_time = compute_lola_support(query_path, ref_beds) + print(f" done ({lola_time:.1f}s)") + + # Output + n_indexed = info.points_count + print_results(qdrant_results, search_time, lola_results, lola_time, bedset_id) + save_json(query_path, bedset_id, qdrant_results, lola_results, search_time, lola_time, index_time, n_indexed, results_path) + + +if __name__ == "__main__": + main() diff --git a/gtars-cli/Cargo.toml b/gtars-cli/Cargo.toml index 0bb4b8e2..ccbde2b9 100644 --- a/gtars-cli/Cargo.toml +++ b/gtars-cli/Cargo.toml @@ -25,7 +25,7 @@ gtars-igd = { path = "../gtars-igd", optional=true, version="0.5.1" } gtars-uniwig = { path = "../gtars-uniwig", optional=true, version="0.8.0" } gtars-overlaprs = { path = "../gtars-overlaprs", optional = true, version="0.5.1" } gtars-bbcache = { path = "../gtars-bbcache", optional=true, version="0.5.3" } -gtars-genomicdist = { path = "../gtars-genomicdist", optional=true, version="0.6.0" } +gtars-genomicdist = { path = "../gtars-genomicdist", optional=true, version="0.7.0" } gtars-core = { path = "../gtars-core", version="0.5.5", features=["bigbed", "http"] } # serialization diff --git a/gtars-genomicdist/Cargo.toml b/gtars-genomicdist/Cargo.toml index 93e7f383..2c62ead4 100644 --- a/gtars-genomicdist/Cargo.toml +++ b/gtars-genomicdist/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gtars-genomicdist" -version = "0.6.0" +version = "0.7.0" edition = "2024" description = "Rust port of GenomicDistributions: tools for computing statistics for genomic interval sets" license = "MIT" diff --git a/gtars-genomicdist/src/interval_ranges.rs b/gtars-genomicdist/src/interval_ranges.rs index aa2b952b..343f71be 100644 --- a/gtars-genomicdist/src/interval_ranges.rs +++ b/gtars-genomicdist/src/interval_ranges.rs @@ -1140,6 +1140,140 @@ pub fn pairwise_jaccard(sets: &[RegionSet]) -> Vec { matrix } +// --- Indexed operations on RegionSetList --- +// +// These let callers operate on pairs by index without cloning full RegionSets +// across an FFI boundary (wasm, Python, R). + +use gtars_core::models::RegionSetList; + +/// Indexed pair operations on a RegionSetList. +pub trait RegionSetListOps { + fn pintersect_at(&self, i: usize, j: usize) -> Option; + fn pintersect_count(&self, i: usize, j: usize) -> Option; + fn jaccard_at(&self, i: usize, j: usize) -> Option; + fn union_at(&self, i: usize, j: usize) -> Option; + fn setdiff_at(&self, i: usize, j: usize) -> Option; + fn region_count(&self, i: usize) -> Option; + fn union_except(&self, skip: usize) -> Option; + /// Compute union-of-all and all N union-except results in O(n) unions + /// using prefix/suffix arrays. Returns (full_union, vec_of_union_except). + fn bulk_union_except(&self) -> Option<(RegionSet, Vec)>; + /// Fold all sets into a single union. + fn union_all(&self) -> Option; + /// Fold all sets into a single intersection. + fn intersect_all(&self) -> Option; +} + +impl RegionSetListOps for RegionSetList { + fn pintersect_at(&self, i: usize, j: usize) -> Option { + let a = self.get(i)?; + let b = self.get(j)?; + Some(a.pintersect(b)) + } + + fn pintersect_count(&self, i: usize, j: usize) -> Option { + self.pintersect_at(i, j).map(|rs| rs.len() as u32) + } + + fn jaccard_at(&self, i: usize, j: usize) -> Option { + let a = self.get(i)?; + let b = self.get(j)?; + Some(a.jaccard(b)) + } + + fn union_at(&self, i: usize, j: usize) -> Option { + let a = self.get(i)?; + let b = self.get(j)?; + Some(a.union(b)) + } + + fn setdiff_at(&self, i: usize, j: usize) -> Option { + let a = self.get(i)?; + let b = self.get(j)?; + Some(a.setdiff(b)) + } + + fn region_count(&self, i: usize) -> Option { + self.get(i).map(|rs| rs.len() as u32) + } + + fn union_except(&self, skip: usize) -> Option { + let n = self.len(); + if n < 2 || skip >= n { return None; } + let first = if skip == 0 { 1 } else { 0 }; + let mut acc = self.get(first)?.clone(); + for k in (first + 1)..n { + if k == skip { continue; } + if let Some(other) = self.get(k) { + acc = acc.union(other); + } + } + Some(acc) + } + + fn bulk_union_except(&self) -> Option<(RegionSet, Vec)> { + let n = self.len(); + if n < 2 { return None; } + + // prefix[i] = union(set[0]..=set[i]) + let mut prefix = Vec::with_capacity(n); + prefix.push(self.get(0)?.clone()); + for i in 1..n { + let prev = &prefix[i - 1]; + prefix.push(prev.union(self.get(i)?)); + } + + // suffix[i] = union(set[i]..=set[n-1]), built incrementally from right + let mut suffix = vec![None; n]; + suffix[n - 1] = Some(self.get(n - 1)?.clone()); + for i in (0..n - 1).rev() { + suffix[i] = Some(self.get(i)?.union(suffix[i + 1].as_ref().unwrap())); + } + let suffix: Vec = suffix.into_iter().map(|s| s.unwrap()).collect(); + + let full_union = prefix[n - 1].clone(); + + // union_except[i] = union(prefix[i-1], suffix[i+1]) + let mut results = Vec::with_capacity(n); + for i in 0..n { + let except = match (i > 0, i < n - 1) { + (false, true) => suffix[1].clone(), + (true, false) => prefix[i - 1].clone(), + (true, true) => prefix[i - 1].union(&suffix[i + 1]), + (false, false) => unreachable!(), // n >= 2 + }; + results.push(except); + } + + Some((full_union, results)) + } + + fn union_all(&self) -> Option { + let n = self.len(); + if n == 0 { return None; } + let mut acc = self.get(0)?.clone(); + for i in 1..n { + if let Some(other) = self.get(i) { + acc = acc.union(other); + } + } + Some(acc) + } + + fn intersect_all(&self) -> Option { + let n = self.len(); + if n == 0 { return None; } + let mut acc = self.get(0)?.clone(); + for i in 1..n { + if let Some(other) = self.get(i) { + acc = acc.intersect(other); + } + } + Some(acc) + } +} + #[cfg(test)] mod tests { use super::*; @@ -2420,4 +2554,283 @@ mod tests { // overlap with [0,10): [5,10)=5bp, overlap with [20,30): [20,25)=5bp assert_eq!(super::merge_intersection_bp(&a, &b), 10); } + + // ── RegionSetListOps tests ───────────────────────────────────────── + + fn make_rsl(sets: Vec) -> RegionSetList { + RegionSetList::from(sets) + } + + #[rstest] + fn test_rsl_pintersect_count() { + let a = make_regionset(vec![("chr1", 0, 100), ("chr1", 200, 300)]); + let b = make_regionset(vec![("chr1", 50, 150), ("chr1", 250, 350)]); + let rsl = make_rsl(vec![a, b]); + // a has 2 regions, b has 2 regions, both overlap each other + let count = rsl.pintersect_count(0, 1).unwrap(); + assert_eq!(count, 2); // both pairs overlap + } + + #[rstest] + fn test_rsl_pintersect_count_no_overlap() { + let a = make_regionset(vec![("chr1", 0, 10)]); + let b = make_regionset(vec![("chr1", 100, 200)]); + let rsl = make_rsl(vec![a, b]); + // paired by index: chr1:0-10 vs chr1:100-200 → no genomic overlap, + // but pintersect produces a zero-width region (start=end) per pair + let count = rsl.pintersect_count(0, 1).unwrap(); + assert_eq!(count, 1); // zero-width region still counted + } + + #[rstest] + fn test_rsl_pintersect_count_out_of_bounds() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let rsl = make_rsl(vec![a]); + assert!(rsl.pintersect_count(0, 5).is_none()); + } + + #[rstest] + fn test_rsl_jaccard_at() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 0, 100)]); + let rsl = make_rsl(vec![a, b]); + let j = rsl.jaccard_at(0, 1).unwrap(); + assert!((j - 1.0).abs() < 1e-9, "identical sets should have jaccard=1.0"); + } + + #[rstest] + fn test_rsl_jaccard_at_disjoint() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 200, 300)]); + let rsl = make_rsl(vec![a, b]); + let j = rsl.jaccard_at(0, 1).unwrap(); + assert!((j - 0.0).abs() < 1e-9, "disjoint sets should have jaccard=0.0"); + } + + #[rstest] + fn test_rsl_union_at() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 50, 150)]); + let rsl = make_rsl(vec![a, b]); + let u = rsl.union_at(0, 1).unwrap(); + assert_eq!(u.regions.len(), 1); + assert_eq!(u.regions[0].start, 0); + assert_eq!(u.regions[0].end, 150); + } + + #[rstest] + fn test_rsl_setdiff_at() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 50, 150)]); + let rsl = make_rsl(vec![a, b]); + let diff = rsl.setdiff_at(0, 1).unwrap(); + // a minus b: chr1:0-50 + assert_eq!(diff.regions.len(), 1); + assert_eq!(diff.regions[0].start, 0); + assert_eq!(diff.regions[0].end, 50); + } + + #[rstest] + fn test_rsl_region_count() { + let a = make_regionset(vec![("chr1", 0, 100), ("chr1", 200, 300)]); + let b = make_regionset(vec![("chr1", 50, 150)]); + let rsl = make_rsl(vec![a, b]); + assert_eq!(rsl.region_count(0).unwrap(), 2); + assert_eq!(rsl.region_count(1).unwrap(), 1); + assert!(rsl.region_count(5).is_none()); + } + + #[rstest] + fn test_rsl_union_all() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 50, 200)]); + let c = make_regionset(vec![("chr1", 150, 300)]); + let rsl = make_rsl(vec![a, b, c]); + let u = rsl.union_all().unwrap(); + assert_eq!(u.regions.len(), 1); + assert_eq!(u.regions[0].start, 0); + assert_eq!(u.regions[0].end, 300); + } + + #[rstest] + fn test_rsl_union_all_empty() { + let rsl = make_rsl(vec![]); + assert!(rsl.union_all().is_none()); + } + + #[rstest] + fn test_rsl_union_all_single() { + let a = make_regionset(vec![("chr1", 10, 50)]); + let rsl = make_rsl(vec![a]); + let u = rsl.union_all().unwrap(); + assert_eq!(u.regions.len(), 1); + assert_eq!(u.regions[0].start, 10); + assert_eq!(u.regions[0].end, 50); + } + + #[rstest] + fn test_rsl_intersect_all() { + // Three overlapping sets — intersection is the region shared by all three + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 30, 200)]); + let c = make_regionset(vec![("chr1", 60, 150)]); + let rsl = make_rsl(vec![a, b, c]); + let inter = rsl.intersect_all().unwrap(); + assert_eq!(inter.regions.len(), 1); + assert_eq!(inter.regions[0].start, 60); + assert_eq!(inter.regions[0].end, 100); + } + + #[rstest] + fn test_rsl_intersect_all_disjoint() { + let a = make_regionset(vec![("chr1", 0, 50)]); + let b = make_regionset(vec![("chr1", 100, 200)]); + let rsl = make_rsl(vec![a, b]); + let inter = rsl.intersect_all().unwrap(); + assert_eq!(inter.regions.len(), 0); + } + + #[rstest] + fn test_rsl_intersect_all_empty() { + let rsl = make_rsl(vec![]); + assert!(rsl.intersect_all().is_none()); + } + + #[rstest] + fn test_rsl_intersect_all_different_sizes() { + // This is the case where pintersect would give wrong results: + // sets have different numbers of regions, but share genomic coverage + let a = make_regionset(vec![("chr1", 0, 100), ("chr1", 200, 300)]); + let b = make_regionset(vec![("chr1", 50, 250)]); + let rsl = make_rsl(vec![a, b]); + let inter = rsl.intersect_all().unwrap(); + // Shared coverage: [50,100) and [200,250) + assert_eq!(inter.regions.len(), 2); + assert_eq!(inter.regions[0].start, 50); + assert_eq!(inter.regions[0].end, 100); + assert_eq!(inter.regions[1].start, 200); + assert_eq!(inter.regions[1].end, 250); + } + + #[rstest] + fn test_rsl_union_except() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 200, 300)]); + let c = make_regionset(vec![("chr1", 400, 500)]); + let rsl = make_rsl(vec![a, b, c]); + // union_except(1) = union of sets 0 and 2 (skip set 1) + let ue = rsl.union_except(1).unwrap(); + assert_eq!(ue.regions.len(), 2); + assert_eq!(ue.regions[0].start, 0); + assert_eq!(ue.regions[0].end, 100); + assert_eq!(ue.regions[1].start, 400); + assert_eq!(ue.regions[1].end, 500); + } + + #[rstest] + fn test_rsl_union_except_too_small() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let rsl = make_rsl(vec![a]); + assert!(rsl.union_except(0).is_none()); + } + + #[rstest] + fn test_rsl_bulk_union_except_n2() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 200, 300)]); + let rsl = make_rsl(vec![a, b]); + let (full_union, excepts) = rsl.bulk_union_except().unwrap(); + + // Full union covers both regions + assert_eq!(full_union.regions.len(), 2); + + // except[0] = union of everything except set 0 = set 1 + assert_eq!(excepts.len(), 2); + assert_eq!(excepts[0].regions.len(), 1); + assert_eq!(excepts[0].regions[0].start, 200); + assert_eq!(excepts[0].regions[0].end, 300); + + // except[1] = union of everything except set 1 = set 0 + assert_eq!(excepts[1].regions.len(), 1); + assert_eq!(excepts[1].regions[0].start, 0); + assert_eq!(excepts[1].regions[0].end, 100); + } + + #[rstest] + fn test_rsl_bulk_union_except_n3() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let b = make_regionset(vec![("chr1", 200, 300)]); + let c = make_regionset(vec![("chr1", 400, 500)]); + let rsl = make_rsl(vec![a, b, c]); + let (full_union, excepts) = rsl.bulk_union_except().unwrap(); + + assert_eq!(full_union.regions.len(), 3); + assert_eq!(excepts.len(), 3); + + // except[0] = union(b, c) + assert_eq!(excepts[0].regions.len(), 2); + assert_eq!(excepts[0].regions[0].start, 200); + assert_eq!(excepts[0].regions[1].start, 400); + + // except[1] = union(a, c) + assert_eq!(excepts[1].regions.len(), 2); + assert_eq!(excepts[1].regions[0].start, 0); + assert_eq!(excepts[1].regions[1].start, 400); + + // except[2] = union(a, b) + assert_eq!(excepts[2].regions.len(), 2); + assert_eq!(excepts[2].regions[0].start, 0); + assert_eq!(excepts[2].regions[1].start, 200); + } + + #[rstest] + fn test_rsl_bulk_union_except_too_small() { + let a = make_regionset(vec![("chr1", 0, 100)]); + let rsl = make_rsl(vec![a]); + assert!(rsl.bulk_union_except().is_none()); + + let rsl_empty = make_rsl(vec![]); + assert!(rsl_empty.bulk_union_except().is_none()); + } + + #[rstest] + fn test_rsl_bulk_union_except_matches_union_except() { + // Verify bulk algorithm produces same results as individual union_except calls + let a = make_regionset(vec![("chr1", 0, 100), ("chr2", 50, 200)]); + let b = make_regionset(vec![("chr1", 80, 180), ("chr2", 100, 300)]); + let c = make_regionset(vec![("chr1", 150, 250)]); + let d = make_regionset(vec![("chr2", 0, 150)]); + let rsl = make_rsl(vec![a, b, c, d]); + + let (_, bulk_excepts) = rsl.bulk_union_except().unwrap(); + + for i in 0..4 { + let individual = rsl.union_except(i).unwrap(); + assert_eq!( + bulk_excepts[i].regions.len(), + individual.regions.len(), + "region count mismatch at index {}", + i + ); + for (j, (bulk_r, indiv_r)) in bulk_excepts[i] + .regions + .iter() + .zip(individual.regions.iter()) + .enumerate() + { + assert_eq!( + bulk_r.chr, indiv_r.chr, + "chr mismatch at except[{}][{}]", i, j + ); + assert_eq!( + bulk_r.start, indiv_r.start, + "start mismatch at except[{}][{}]", i, j + ); + assert_eq!( + bulk_r.end, indiv_r.end, + "end mismatch at except[{}][{}]", i, j + ); + } + } + } } diff --git a/gtars-genomicdist/src/lib.rs b/gtars-genomicdist/src/lib.rs index f935cc26..3422204b 100644 --- a/gtars-genomicdist/src/lib.rs +++ b/gtars-genomicdist/src/lib.rs @@ -41,6 +41,7 @@ pub use gtars_core::models::CoordinateMode; pub use consensus::{ConsensusRegion, consensus}; pub use interval_ranges::IntervalRanges; pub use interval_ranges::pairwise_jaccard; +pub use interval_ranges::RegionSetListOps; pub use partitions::{ calc_expected_partitions, calc_partitions, genome_partition_list, ExpectedPartitionResult, ExpectedPartitionRow, GeneModel, PartitionList, PartitionResult, diff --git a/gtars-lola/Cargo.toml b/gtars-lola/Cargo.toml index 48cbcf70..bdf0ffef 100644 --- a/gtars-lola/Cargo.toml +++ b/gtars-lola/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gtars-lola" -version = "0.1.0" +version = "0.2.0" edition = "2024" description = "LOLA (Locus Overlap Analysis) for genomic region enrichment testing" license = "MIT" diff --git a/gtars-lola/src/database.rs b/gtars-lola/src/database.rs index d2d5737f..0b6ba5d7 100644 --- a/gtars-lola/src/database.rs +++ b/gtars-lola/src/database.rs @@ -24,14 +24,14 @@ pub struct CollectionAnno { #[derive(Debug, Clone, Default)] pub struct RegionSetAnno { pub filename: String, - pub cell_type: String, - pub description: String, - pub tissue: String, - pub data_source: String, - pub antibody: String, - pub treatment: String, + pub cell_type: Option, + pub description: Option, + pub tissue: Option, + pub data_source: Option, + pub antibody: Option, + pub treatment: Option, /// Which collection this file belongs to. - pub collection: String, + pub collection: Option, } /// A LOLA region database: IGD index + original region sets + annotations. @@ -145,12 +145,12 @@ impl RegionDB { // Fall back to collection description if file-level is empty. let mut anno = anno_map.get(fname).cloned().unwrap_or(RegionSetAnno { filename: fname.clone(), - collection: coll_name.clone(), + collection: Some(coll_name.clone()), ..Default::default() }); - if anno.description.is_empty() { + if anno.description.is_none() { // R LOLA uses the collection folder name as fallback - anno.description = coll_name.clone(); + anno.description = Some(coll_name.clone()); } all_region_anno.push(anno); files_loaded += 1; @@ -248,7 +248,7 @@ impl RegionDB { .iter() .filter(|a| { if let Some(filter) = collections { - filter.iter().any(|f| *f == a.collection) + a.collection.as_deref().map_or(false, |c| filter.contains(&c)) } else { true } @@ -269,7 +269,7 @@ impl RegionDB { .filter(|(_, a)| { let name_match = filenames.iter().any(|f| *f == a.filename); let coll_match = if let Some(filter) = collections { - filter.iter().any(|f| *f == a.collection) + a.collection.as_deref().map_or(false, |c| filter.contains(&c)) } else { true }; @@ -429,23 +429,30 @@ fn parse_index_txt(path: &Path, collection_name: &str) -> Vec { let fields: Vec<&str> = line.split(sep).collect(); - let get = |key: &str| -> String { + let get_required = |key: &str| -> String { col_map .get(key) .and_then(|&i| fields.get(i)) .map(|s| s.trim().to_string()) .unwrap_or_default() }; + let get_optional = |key: &str| -> Option { + col_map + .get(key) + .and_then(|&i| fields.get(i)) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + }; annos.push(RegionSetAnno { - filename: get("filename"), - cell_type: get("cellType"), - description: get("description"), - tissue: get("tissue"), - data_source: get("dataSource"), - antibody: get("antibody"), - treatment: get("treatment"), - collection: collection_name.to_string(), + filename: get_required("filename"), + cell_type: get_optional("cellType"), + description: get_optional("description"), + tissue: get_optional("tissue"), + data_source: get_optional("dataSource"), + antibody: get_optional("antibody"), + treatment: get_optional("treatment"), + collection: Some(collection_name.to_string()), }); } @@ -504,8 +511,8 @@ mod tests { assert_eq!(db.collection_anno.len(), 1); assert_eq!(db.collection_anno[0].collector, "John"); assert_eq!(db.region_anno.len(), 2); - assert_eq!(db.region_anno[0].cell_type, "K562"); - assert_eq!(db.region_anno[1].cell_type, "HeLa"); + assert_eq!(db.region_anno[0].cell_type.as_deref(), Some("K562")); + assert_eq!(db.region_anno[1].cell_type.as_deref(), Some("HeLa")); // Region sets should have correct counts assert_eq!(db.region_sets[0].regions.len(), 3); // file1: 3 regions @@ -599,7 +606,7 @@ mod tests { let anno = RegionSetAnno { filename: "test.bed".to_string(), - cell_type: "K562".to_string(), + cell_type: Some("K562".to_string()), ..Default::default() }; @@ -665,12 +672,12 @@ mod tests { assert_eq!(db.num_region_sets(), 1); // R LOLA uses the collection folder name as fallback, not collection.txt description assert_eq!( - db.region_anno[0].description, "fallback_coll", + db.region_anno[0].description.as_deref(), Some("fallback_coll"), "Description should fall back to collection folder name when index.txt description is empty" ); // Other fields from index.txt should still be present - assert_eq!(db.region_anno[0].cell_type, "K562"); - assert_eq!(db.region_anno[0].tissue, "blood"); + assert_eq!(db.region_anno[0].cell_type.as_deref(), Some("K562")); + assert_eq!(db.region_anno[0].tissue.as_deref(), Some("blood")); } #[test] @@ -728,7 +735,7 @@ mod tests { .region_anno .iter() .enumerate() - .filter(|(_, a)| a.collection == "coll1") + .filter(|(_, a)| a.collection.as_deref() == Some("coll1")) .map(|(i, _)| i) .collect(); diff --git a/gtars-lola/src/enrichment.rs b/gtars-lola/src/enrichment.rs index d3c60c7e..705c0a3d 100644 --- a/gtars-lola/src/enrichment.rs +++ b/gtars-lola/src/enrichment.rs @@ -264,13 +264,13 @@ pub fn run_lola( d, q_value: None, filename, - collection: String::new(), - description: String::new(), - cell_type: String::new(), - tissue: String::new(), - antibody: String::new(), - treatment: String::new(), - data_source: String::new(), + collection: None, + description: None, + cell_type: None, + tissue: None, + antibody: None, + treatment: None, + data_source: None, db_set_size: 0, }); } diff --git a/gtars-lola/src/models.rs b/gtars-lola/src/models.rs index 259c13b1..fc177aef 100644 --- a/gtars-lola/src/models.rs +++ b/gtars-lola/src/models.rs @@ -83,19 +83,19 @@ pub struct LolaResult { /// DB set filename (from IGD file_info). pub filename: String, /// Collection name this DB set belongs to. - pub collection: String, + pub collection: Option, /// Description from index.txt. - pub description: String, + pub description: Option, /// Cell type annotation. - pub cell_type: String, + pub cell_type: Option, /// Tissue annotation. - pub tissue: String, + pub tissue: Option, /// Antibody annotation. - pub antibody: String, + pub antibody: Option, /// Treatment annotation. - pub treatment: String, + pub treatment: Option, /// Data source annotation. - pub data_source: String, + pub data_source: Option, /// Number of regions in the DB set. pub db_set_size: u64, } diff --git a/gtars-lola/src/output.rs b/gtars-lola/src/output.rs index 64e34edb..aee22290 100644 --- a/gtars-lola/src/output.rs +++ b/gtars-lola/src/output.rs @@ -1,7 +1,6 @@ //! Output formatting, FDR correction, and annotation. use std::io::Write; -use std::path::Path; use crate::database::RegionDB; use crate::models::LolaResult; @@ -16,7 +15,7 @@ pub fn annotate_results(results: &mut [LolaResult], db: &RegionDB) { let anno = &db.region_anno[r.db_set]; r.collection = anno.collection.clone(); // Truncate description to 80 chars (matches R LOLA behavior) - r.description = anno.description.chars().take(80).collect(); + r.description = anno.description.as_ref().map(|d| d.chars().take(80).collect()); r.cell_type = anno.cell_type.clone(); r.tissue = anno.tissue.clone(); r.antibody = anno.antibody.clone(); @@ -105,6 +104,94 @@ pub fn apply_fdr_correction(results: &mut [LolaResult]) { } } +/// Column-oriented representation of LOLA results. +/// +/// Each field is a parallel Vec — row `i` across all fields describes one result. +/// Bindings should convert this to their native columnar type (JS object, PyDict, +/// R data.frame) rather than reimplementing the row→column pivot. +#[derive(Debug, Clone)] +pub struct LolaColumnar { + pub user_set: Vec, + pub db_set: Vec, + pub p_value_log: Vec, + pub odds_ratio: Vec, + pub support: Vec, + pub rnk_pv: Vec, + pub rnk_or: Vec, + pub rnk_sup: Vec, + pub max_rnk: Vec, + pub mean_rnk: Vec, + pub b: Vec, + pub c: Vec, + pub d: Vec, + pub q_value: Vec>, + pub filename: Vec, + pub collection: Vec>, + pub description: Vec>, + pub cell_type: Vec>, + pub tissue: Vec>, + pub antibody: Vec>, + pub treatment: Vec>, + pub data_source: Vec>, + pub db_set_size: Vec, +} + +/// Convert a slice of LolaResults into column-oriented vectors. +pub fn results_to_columns(results: &[LolaResult]) -> LolaColumnar { + let n = results.len(); + let mut c = LolaColumnar { + user_set: Vec::with_capacity(n), + db_set: Vec::with_capacity(n), + p_value_log: Vec::with_capacity(n), + odds_ratio: Vec::with_capacity(n), + support: Vec::with_capacity(n), + rnk_pv: Vec::with_capacity(n), + rnk_or: Vec::with_capacity(n), + rnk_sup: Vec::with_capacity(n), + max_rnk: Vec::with_capacity(n), + mean_rnk: Vec::with_capacity(n), + b: Vec::with_capacity(n), + c: Vec::with_capacity(n), + d: Vec::with_capacity(n), + q_value: Vec::with_capacity(n), + filename: Vec::with_capacity(n), + collection: Vec::with_capacity(n), + description: Vec::with_capacity(n), + cell_type: Vec::with_capacity(n), + tissue: Vec::with_capacity(n), + antibody: Vec::with_capacity(n), + treatment: Vec::with_capacity(n), + data_source: Vec::with_capacity(n), + db_set_size: Vec::with_capacity(n), + }; + for r in results { + c.user_set.push(r.user_set); + c.db_set.push(r.db_set); + c.p_value_log.push(r.p_value_log); + c.odds_ratio.push(r.odds_ratio); + c.support.push(r.support); + c.rnk_pv.push(r.rnk_pv); + c.rnk_or.push(r.rnk_or); + c.rnk_sup.push(r.rnk_sup); + c.max_rnk.push(r.max_rnk); + c.mean_rnk.push(r.mean_rnk); + c.b.push(r.b); + c.c.push(r.c); + c.d.push(r.d); + c.q_value.push(r.q_value); + c.filename.push(r.filename.clone()); + c.collection.push(r.collection.clone()); + c.description.push(r.description.clone()); + c.cell_type.push(r.cell_type.clone()); + c.tissue.push(r.tissue.clone()); + c.antibody.push(r.antibody.clone()); + c.treatment.push(r.treatment.clone()); + c.data_source.push(r.data_source.clone()); + c.db_set_size.push(r.db_set_size); + } + c +} + /// Write LOLA results as TSV matching R LOLA's `writeCombinedEnrichment` format. pub fn write_results_tsv( writer: &mut W, @@ -124,14 +211,13 @@ pub fn write_results_tsv( .q_value .map(|q| format!("{:.6e}", q)) .unwrap_or_else(|| "NA".to_string()); - writeln!( writer, "{}\t{}\t{}\t{:.4}\t{:.4}\t{}\t{}\t{}\t{}\t{}\t{:.2}\t{}\t{}\t{}\t\ {}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}", r.user_set + 1, // 1-based for R compatibility r.db_set + 1, - r.collection, + r.collection.as_deref().unwrap_or(""), r.p_value_log, r.odds_ratio, r.support, @@ -143,12 +229,12 @@ pub fn write_results_tsv( r.b, r.c, r.d, - r.description, - r.cell_type, - r.tissue, - r.antibody, - r.treatment, - r.data_source, + r.description.as_deref().unwrap_or(""), + r.cell_type.as_deref().unwrap_or(""), + r.tissue.as_deref().unwrap_or(""), + r.antibody.as_deref().unwrap_or(""), + r.treatment.as_deref().unwrap_or(""), + r.data_source.as_deref().unwrap_or(""), r.filename, qv, r.db_set_size, @@ -158,15 +244,6 @@ pub fn write_results_tsv( Ok(()) } -/// Write results to a TSV file on disk. -pub fn write_results_to_file( - path: &Path, - results: &[LolaResult], -) -> std::io::Result<()> { - let mut file = std::fs::File::create(path)?; - write_results_tsv(&mut file, results) -} - #[cfg(test)] mod tests { use super::*; @@ -189,13 +266,13 @@ mod tests { d: 100, q_value: None, filename: format!("file{}.bed", db_set), - collection: String::new(), - description: String::new(), - cell_type: String::new(), - tissue: String::new(), - antibody: String::new(), - treatment: String::new(), - data_source: String::new(), + collection: None, + description: None, + cell_type: None, + tissue: None, + antibody: None, + treatment: None, + data_source: None, db_set_size: 0, } } @@ -400,4 +477,55 @@ mod tests { let output = String::from_utf8(buf).unwrap(); assert!(output.contains("NA")); // q_value should be NA } + + #[test] + fn test_results_to_columns_basic() { + let results = vec![ + make_result(0, 0, 3.0), + make_result(1, 2, 5.0), + ]; + let c = results_to_columns(&results); + + assert_eq!(c.user_set.len(), 2); + assert_eq!(c.user_set, vec![0, 1]); + assert_eq!(c.db_set, vec![0, 2]); + assert_eq!(c.p_value_log, vec![3.0, 5.0]); + assert_eq!(c.odds_ratio, vec![1.0, 1.0]); + assert_eq!(c.support, vec![10, 10]); + assert_eq!(c.b, vec![5, 5]); + assert_eq!(c.c, vec![5, 5]); + assert_eq!(c.d, vec![100, 100]); + assert_eq!(c.filename, vec!["file0.bed", "file2.bed"]); + assert_eq!(c.db_set_size, vec![0, 0]); + // empty strings → None + assert_eq!(c.collection, vec![None, None]); + assert_eq!(c.description, vec![None, None]); + assert_eq!(c.cell_type, vec![None, None]); + assert_eq!(c.tissue, vec![None, None]); + assert_eq!(c.antibody, vec![None, None]); + assert_eq!(c.treatment, vec![None, None]); + assert_eq!(c.data_source, vec![None, None]); + } + + #[test] + fn test_results_to_columns_empty() { + let c = results_to_columns(&[]); + assert!(c.user_set.is_empty()); + assert!(c.filename.is_empty()); + } + + #[test] + fn test_results_to_columns_with_metadata() { + let mut r = make_result(0, 0, 1.0); + r.collection = Some("ENCODE".to_string()); + r.cell_type = Some("K562".to_string()); + r.tissue = None; // stays None + r.q_value = Some(0.05); + + let c = results_to_columns(&[r]); + assert_eq!(c.collection, vec![Some("ENCODE".to_string())]); + assert_eq!(c.cell_type, vec![Some("K562".to_string())]); + assert_eq!(c.tissue, vec![None]); + assert_eq!(c.q_value, vec![Some(0.05)]); + } } diff --git a/gtars-python/Cargo.toml b/gtars-python/Cargo.toml index b2bf21ad..66d9fded 100644 --- a/gtars-python/Cargo.toml +++ b/gtars-python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gtars-py" -version = "0.8.0" +version = "0.8.1" edition = "2024" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -9,13 +9,14 @@ name = "gtars" crate-type = ["cdylib"] [features] -default = ["refget", "tokenizers", "genomic_distributions", "models", "utils", "lola"] +default = ["refget", "tokenizers", "genomic_distributions", "models", "utils", "lola", "bm25"] refget = ["dep:gtars-refget"] tokenizers = ["dep:gtars-tokenizers", "dep:gtars-core"] genomic_distributions = ["dep:gtars-genomicdist"] models = ["dep:gtars-core", "dep:gtars-genomicdist", "dep:gtars-overlaprs"] utils = ["dep:gtars-core", "dep:gtars-io"] lola = ["dep:gtars-lola", "dep:gtars-igd", "dep:gtars-core"] +bm25 = ["dep:gtars-bm25"] [dependencies] anyhow = { workspace = true } @@ -24,6 +25,7 @@ pyo3 = { version = "0.27.1", features=["anyhow", "extension-module"] } openssl = { version = "0.10", features = ["vendored"] } # our code (optional, behind feature flags) +gtars-bm25 = { path = "../gtars-bm25", optional = true } gtars-core = { path = "../gtars-core", features=["bigbed", "http"], optional = true } gtars-genomicdist = { path = "../gtars-genomicdist", optional = true } gtars-refget = { path = "../gtars-refget", optional = true } diff --git a/gtars-python/py_src/gtars/bm25/__init__.py b/gtars-python/py_src/gtars/bm25/__init__.py new file mode 100644 index 00000000..d5a58c9c --- /dev/null +++ b/gtars-python/py_src/gtars/bm25/__init__.py @@ -0,0 +1 @@ +from .gtars.bm25 import * # noqa: F403 \ No newline at end of file diff --git a/gtars-python/py_src/gtars/bm25/__init__.pyi b/gtars-python/py_src/gtars/bm25/__init__.pyi new file mode 100644 index 00000000..959694bc --- /dev/null +++ b/gtars-python/py_src/gtars/bm25/__init__.pyi @@ -0,0 +1,96 @@ +from typing import Any, List, Union + +from gtars.tokenizers import Tokenizer + +class SparseVector: + """ + A sparse vector with indices and values, compatible with Qdrant's sparse vector format. + """ + + @property + def indices(self) -> List[int]: + """ + The non-zero dimension indices. + """ + + @property + def values(self) -> List[float]: + """ + The values at each non-zero dimension. + """ + + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + +class Bm25: + """ + BM25 sparse embedding model for genomic intervals. + + Computes BM25-like term frequency scores over a vocabulary of genomic regions. + Designed for hybrid search with dense models like Atacformer via Qdrant. + """ + + def __init__( + self, + tokenizer: Union[str, Tokenizer], + k: float = 1.0, + b: float = 0.75, + avg_doc_length: float = 1000.0, + ) -> None: + """ + Create a new BM25 model. + + Args: + tokenizer: A path to a BED/BED.GZ/TOML vocabulary file, or an existing Tokenizer object. + k: Term frequency saturation parameter. Higher values increase the impact of term frequency. + b: Document length normalization parameter (0 to 1). Higher values penalize longer documents more. + avg_doc_length: Assumed average document length (constant across the corpus). + """ + + def embed(self, regions: Any) -> SparseVector: + """ + Embed a set of regions into a BM25 sparse vector. + + Args: + regions: A file path to a BED file, or an iterable of objects with chr, start, end attributes. + + Returns: + SparseVector: Sparse vector with token ID indices and BM25 term-frequency scores as values. + """ + + def tokenize(self, regions: Any) -> List[int]: + """ + Tokenize regions into token IDs without computing BM25 scores. + + Args: + regions: A file path to a BED file, or an iterable of objects with chr, start, end attributes. + + Returns: + List[int]: Token IDs of overlapping vocabulary regions. + """ + + @property + def vocab_size(self) -> int: + """ + The number of regions in the vocabulary. + """ + + @property + def k(self) -> float: + """ + The term frequency saturation parameter. + """ + + @property + def b(self) -> float: + """ + The document length normalization parameter. + """ + + @property + def avg_doc_length(self) -> float: + """ + The assumed average document length. + """ + + def __repr__(self) -> str: ... \ No newline at end of file diff --git a/gtars-python/src/bm25/mod.rs b/gtars-python/src/bm25/mod.rs new file mode 100644 index 00000000..d86c82e8 --- /dev/null +++ b/gtars-python/src/bm25/mod.rs @@ -0,0 +1,128 @@ +use pyo3::prelude::*; + +use anyhow::Result; +use gtars_bm25::{Bm25Builder, SparseVector}; +use gtars_tokenizers::{Tokenizer, create_tokenize_core_from_universe, config::TokenizerType}; + +use crate::tokenizers::py_tokenizers::PyTokenizer; +use crate::utils::extract_regions_from_py_any; + +#[pyclass(name = "SparseVector", module = "gtars.bm25")] +pub struct PySparseVector { + inner: SparseVector, +} + +#[pymethods] +impl PySparseVector { + #[getter] + fn indices(&self) -> Vec { + self.inner.indices.clone() + } + + #[getter] + fn values(&self) -> Vec { + self.inner.values.clone() + } + + fn __len__(&self) -> usize { + self.inner.len() + } + + fn __repr__(&self) -> String { + format!( + "SparseVector(len={}, indices={:?}, values={:?})", + self.inner.len(), + self.inner.indices, + self.inner.values + ) + } +} + +#[pyclass(name = "Bm25", module = "gtars.bm25", subclass)] +pub struct PyBm25 { + inner: gtars_bm25::Bm25, +} + +#[pymethods] +impl PyBm25 { + #[new] + #[pyo3(signature = (tokenizer, k=1.0, b=0.75, avg_doc_length=1000.0))] + fn new(tokenizer: &Bound<'_, PyAny>, k: f32, b: f32, avg_doc_length: f32) -> Result { + let mut builder = Bm25Builder::default() + .with_k(k) + .with_b(b) + .with_avg_doc_length(avg_doc_length); + + if let Ok(path) = tokenizer.extract::() { + // Accept a string path to a BED/BED.GZ/TOML file + builder = builder.with_vocab(&path); + } else if let Ok(py_tok) = tokenizer.cast::() { + // Accept an existing Tokenizer object — rebuild from its universe + let borrowed = py_tok.borrow(); + let inner_tok = borrowed.inner(); + let universe = inner_tok.get_universe().clone(); + let special_tokens = inner_tok.get_special_tokens().clone(); + let core = create_tokenize_core_from_universe(&universe, TokenizerType::AIList); + let tok = Tokenizer::new(core, universe, special_tokens); + builder = builder.with_tokenizer(tok); + } else { + return Err(anyhow::anyhow!( + "tokenizer must be a string path (BED, BED.GZ, TOML) or a Tokenizer object" + )); + } + + Ok(PyBm25 { + inner: builder.build(), + }) + } + + /// Embed a set of regions into a BM25 sparse vector. + fn embed(&self, regions: &Bound<'_, PyAny>) -> Result { + let rs = extract_regions_from_py_any(regions)?; + let sv = self.inner.embed(&rs.regions); + Ok(PySparseVector { inner: sv }) + } + + /// Tokenize regions into token IDs without computing BM25 scores. + fn tokenize(&self, regions: &Bound<'_, PyAny>) -> Result> { + let rs = extract_regions_from_py_any(regions)?; + Ok(self.inner.tokenize(&rs.regions)) + } + + #[getter] + fn vocab_size(&self) -> usize { + self.inner.vocab_size() + } + + #[getter] + fn k(&self) -> f32 { + self.inner.k() + } + + #[getter] + fn b(&self) -> f32 { + self.inner.b() + } + + #[getter] + fn avg_doc_length(&self) -> f32 { + self.inner.avg_doc_length() + } + + fn __repr__(&self) -> String { + format!( + "Bm25(vocab_size={}, k={}, b={}, avg_doc_length={})", + self.inner.vocab_size(), + self.inner.k(), + self.inner.b(), + self.inner.avg_doc_length() + ) + } +} + +#[pymodule] +pub fn bm25(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + Ok(()) +} \ No newline at end of file diff --git a/gtars-python/src/lib.rs b/gtars-python/src/lib.rs index bc849a2a..8ab17b87 100644 --- a/gtars-python/src/lib.rs +++ b/gtars-python/src/lib.rs @@ -5,6 +5,8 @@ static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; use pyo3::prelude::*; use pyo3::types::PyDict; +#[cfg(feature = "bm25")] +mod bm25; #[cfg(feature = "genomic_distributions")] mod genomic_distributions; #[cfg(feature = "models")] @@ -26,6 +28,13 @@ fn gtars(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let binding = sys.getattr("modules")?; let sys_modules: &Bound<'_, PyDict> = binding.cast()?; + #[cfg(feature = "bm25")] + { + let bm25_module = pyo3::wrap_pymodule!(bm25::bm25); + m.add_wrapped(bm25_module)?; + sys_modules.set_item("gtars.bm25", m.getattr("bm25")?)?; + } + #[cfg(feature = "refget")] { let refget_module = pyo3::wrap_pymodule!(refget::refget); diff --git a/gtars-python/src/lola/mod.rs b/gtars-python/src/lola/mod.rs index 6d042360..63e4b9a0 100644 --- a/gtars-python/src/lola/mod.rs +++ b/gtars-python/src/lola/mod.rs @@ -125,13 +125,13 @@ impl PyRegionDB { for a in &self.inner.region_anno { let d = PyDict::new(py); d.set_item("filename", &a.filename)?; - d.set_item("cellType", &a.cell_type)?; - d.set_item("description", &a.description)?; - d.set_item("tissue", &a.tissue)?; - d.set_item("dataSource", &a.data_source)?; - d.set_item("antibody", &a.antibody)?; - d.set_item("treatment", &a.treatment)?; - d.set_item("collection", &a.collection)?; + d.set_item("cellType", a.cell_type.as_deref())?; + d.set_item("description", a.description.as_deref())?; + d.set_item("tissue", a.tissue.as_deref())?; + d.set_item("dataSource", a.data_source.as_deref())?; + d.set_item("antibody", a.antibody.as_deref())?; + d.set_item("treatment", a.treatment.as_deref())?; + d.set_item("collection", a.collection.as_deref())?; result.push(d); } Ok(result) @@ -242,82 +242,33 @@ fn results_to_dict<'py>( py: Python<'py>, results: &[LolaResult], ) -> PyResult> { - let n = results.len(); - let mut user_set = Vec::with_capacity(n); - let mut db_set = Vec::with_capacity(n); - let mut p_value_log = Vec::with_capacity(n); - let mut odds_ratio = Vec::with_capacity(n); - let mut support = Vec::with_capacity(n); - let mut rnk_pv = Vec::with_capacity(n); - let mut rnk_or = Vec::with_capacity(n); - let mut rnk_sup = Vec::with_capacity(n); - let mut max_rnk = Vec::with_capacity(n); - let mut mean_rnk = Vec::with_capacity(n); - let mut b_vec = Vec::with_capacity(n); - let mut c_vec = Vec::with_capacity(n); - let mut d_vec = Vec::with_capacity(n); - let mut collection: Vec> = Vec::with_capacity(n); - let mut description: Vec> = Vec::with_capacity(n); - let mut cell_type: Vec> = Vec::with_capacity(n); - let mut tissue: Vec> = Vec::with_capacity(n); - let mut antibody: Vec> = Vec::with_capacity(n); - let mut treatment: Vec> = Vec::with_capacity(n); - let mut data_source: Vec> = Vec::with_capacity(n); - let mut filename = Vec::with_capacity(n); - let mut db_set_size = Vec::with_capacity(n); - - for r in results { - user_set.push(r.user_set); - db_set.push(r.db_set); - p_value_log.push(r.p_value_log); - odds_ratio.push(r.odds_ratio); - support.push(r.support); - rnk_pv.push(r.rnk_pv); - rnk_or.push(r.rnk_or); - rnk_sup.push(r.rnk_sup); - max_rnk.push(r.max_rnk); - mean_rnk.push(r.mean_rnk); - b_vec.push(r.b); - c_vec.push(r.c); - d_vec.push(r.d); - collection.push(empty_to_none(&r.collection)); - description.push(empty_to_none(&r.description)); - cell_type.push(empty_to_none(&r.cell_type)); - tissue.push(empty_to_none(&r.tissue)); - antibody.push(empty_to_none(&r.antibody)); - treatment.push(empty_to_none(&r.treatment)); - data_source.push(empty_to_none(&r.data_source)); - filename.push(r.filename.clone()); - db_set_size.push(r.db_set_size); - } - - let q_value: Vec> = results.iter().map(|r| r.q_value).collect(); + use gtars_lola::output::results_to_columns; + let c = results_to_columns(results); let dict = PyDict::new(py); - dict.set_item("userSet", user_set)?; - dict.set_item("dbSet", db_set)?; - dict.set_item("collection", collection)?; - dict.set_item("pValueLog", p_value_log)?; - dict.set_item("oddsRatio", odds_ratio)?; - dict.set_item("support", support)?; - dict.set_item("rnkPV", rnk_pv)?; - dict.set_item("rnkOR", rnk_or)?; - dict.set_item("rnkSup", rnk_sup)?; - dict.set_item("maxRnk", max_rnk)?; - dict.set_item("meanRnk", mean_rnk)?; - dict.set_item("b", b_vec)?; - dict.set_item("c", c_vec)?; - dict.set_item("d", d_vec)?; - dict.set_item("description", description)?; - dict.set_item("cellType", cell_type)?; - dict.set_item("tissue", tissue)?; - dict.set_item("antibody", antibody)?; - dict.set_item("treatment", treatment)?; - dict.set_item("dataSource", data_source)?; - dict.set_item("filename", filename)?; - dict.set_item("qValue", q_value)?; - dict.set_item("size", db_set_size)?; - + dict.set_item("userSet", c.user_set)?; + dict.set_item("dbSet", c.db_set)?; + dict.set_item("collection", c.collection)?; + dict.set_item("pValueLog", c.p_value_log)?; + dict.set_item("oddsRatio", c.odds_ratio)?; + dict.set_item("support", c.support)?; + dict.set_item("rnkPV", c.rnk_pv)?; + dict.set_item("rnkOR", c.rnk_or)?; + dict.set_item("rnkSup", c.rnk_sup)?; + dict.set_item("maxRnk", c.max_rnk)?; + dict.set_item("meanRnk", c.mean_rnk)?; + dict.set_item("b", c.b)?; + dict.set_item("c", c.c)?; + dict.set_item("d", c.d)?; + dict.set_item("description", c.description)?; + dict.set_item("cellType", c.cell_type)?; + dict.set_item("tissue", c.tissue)?; + dict.set_item("antibody", c.antibody)?; + dict.set_item("treatment", c.treatment)?; + dict.set_item("dataSource", c.data_source)?; + dict.set_item("filename", c.filename)?; + dict.set_item("qValue", c.q_value)?; + dict.set_item("size", c.db_set_size)?; Ok(dict) } @@ -408,14 +359,6 @@ fn py_build_restricted_universe( // Helpers // ========================================================================= -fn empty_to_none(s: &str) -> Option { - if s.is_empty() { - None - } else { - Some(s.to_string()) - } -} - fn tuples_to_regionset(tuples: Vec<(String, u32, u32)>) -> RegionSet { RegionSet::from( tuples diff --git a/gtars-python/src/tokenizers/mod.rs b/gtars-python/src/tokenizers/mod.rs index 9c87881f..a464b040 100644 --- a/gtars-python/src/tokenizers/mod.rs +++ b/gtars-python/src/tokenizers/mod.rs @@ -1,5 +1,5 @@ mod encoding; -mod py_tokenizers; +pub(crate) mod py_tokenizers; mod universe; mod utils; diff --git a/gtars-r/DESCRIPTION b/gtars-r/DESCRIPTION index 76f270dc..7cde776d 100644 --- a/gtars-r/DESCRIPTION +++ b/gtars-r/DESCRIPTION @@ -1,6 +1,6 @@ Package: gtars Title: Performance critical genomic interval analysis using Rust, in R -Version: 0.8.0 +Version: 0.8.1 Authors@R: person("Nathan", "LeRoy", , "nleroy917@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-7354-7213")) diff --git a/gtars-r/src/rust/Cargo.toml b/gtars-r/src/rust/Cargo.toml index d148db6b..18330fc8 100644 --- a/gtars-r/src/rust/Cargo.toml +++ b/gtars-r/src/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gtars-r" -version = "0.8.0" +version = "0.8.1" edition = "2021" [lib] diff --git a/gtars-r/src/rust/src/lola.rs b/gtars-r/src/rust/src/lola.rs index a7c217d4..fa69cb00 100644 --- a/gtars-r/src/rust/src/lola.rs +++ b/gtars-r/src/rust/src/lola.rs @@ -55,79 +55,29 @@ fn extract_region_sets(user_sets: List) -> extendr_api::Result> { Ok(sets) } -/// Convert an empty string to None (becomes NA in R). -fn empty_to_na(s: &str) -> Option { - if s.is_empty() { - None - } else { - Some(s.to_string()) - } -} - /// Convert LOLA results to an R list (data.frame-like structure). fn results_to_list(results: &[gtars_lola::models::LolaResult]) -> List { - let n = results.len(); - let mut user_set = Vec::with_capacity(n); - let mut db_set = Vec::with_capacity(n); - let mut p_value_log = Vec::with_capacity(n); - let mut odds_ratio = Vec::with_capacity(n); - let mut support: Vec = Vec::with_capacity(n); - let mut rnk_pv: Vec = Vec::with_capacity(n); - let mut rnk_or: Vec = Vec::with_capacity(n); - let mut rnk_sup: Vec = Vec::with_capacity(n); - let mut max_rnk: Vec = Vec::with_capacity(n); - let mut mean_rnk = Vec::with_capacity(n); - let mut b_vec: Vec = Vec::with_capacity(n); - let mut c_vec: Vec = Vec::with_capacity(n); - let mut d_vec: Vec = Vec::with_capacity(n); - let mut collection: Vec> = Vec::with_capacity(n); - let mut description: Vec> = Vec::with_capacity(n); - let mut cell_type: Vec> = Vec::with_capacity(n); - let mut tissue: Vec> = Vec::with_capacity(n); - let mut antibody: Vec> = Vec::with_capacity(n); - let mut treatment: Vec> = Vec::with_capacity(n); - let mut data_source: Vec> = Vec::with_capacity(n); - let mut filename = Vec::with_capacity(n); - let mut db_set_size: Vec = Vec::with_capacity(n); - - let mut q_value: Vec> = Vec::with_capacity(n); - - for r in results { - user_set.push((r.user_set + 1) as i32); // 1-based for R - db_set.push((r.db_set + 1) as i32); - p_value_log.push(r.p_value_log); - odds_ratio.push(r.odds_ratio); - support.push(r.support as i32); - rnk_pv.push(r.rnk_pv as i32); - rnk_or.push(r.rnk_or as i32); - rnk_sup.push(r.rnk_sup as i32); - max_rnk.push(r.max_rnk as i32); - mean_rnk.push(r.mean_rnk); - b_vec.push(r.b as i32); - c_vec.push(r.c as i32); - d_vec.push(r.d as i32); - collection.push(empty_to_na(&r.collection)); - description.push(empty_to_na(&r.description)); - cell_type.push(empty_to_na(&r.cell_type)); - tissue.push(empty_to_na(&r.tissue)); - antibody.push(empty_to_na(&r.antibody)); - treatment.push(empty_to_na(&r.treatment)); - data_source.push(empty_to_na(&r.data_source)); - filename.push(r.filename.clone()); - db_set_size.push(r.db_set_size as i32); - q_value.push(r.q_value); - } - - // Convert Option to Rfloat (NA for None) - let q_value_r: Vec = q_value - .iter() - .map(|q| match q { - Some(v) => Rfloat::from(*v), - None => Rfloat::na(), - }) + use gtars_lola::output::results_to_columns; + + let c = results_to_columns(results); + + // R-specific conversions: 1-based indices, i32 casts, NA types + let user_set: Vec = c.user_set.iter().map(|&v| (v + 1) as i32).collect(); + let db_set: Vec = c.db_set.iter().map(|&v| (v + 1) as i32).collect(); + let support: Vec = c.support.iter().map(|&v| v as i32).collect(); + let rnk_pv: Vec = c.rnk_pv.iter().map(|&v| v as i32).collect(); + let rnk_or: Vec = c.rnk_or.iter().map(|&v| v as i32).collect(); + let rnk_sup: Vec = c.rnk_sup.iter().map(|&v| v as i32).collect(); + let max_rnk: Vec = c.max_rnk.iter().map(|&v| v as i32).collect(); + let b_vec: Vec = c.b.iter().map(|&v| v as i32).collect(); + let c_vec: Vec = c.c.iter().map(|&v| v as i32).collect(); + let d_vec: Vec = c.d.iter().map(|&v| v as i32).collect(); + let db_set_size: Vec = c.db_set_size.iter().map(|&v| v as i32).collect(); + + let q_value_r: Vec = c.q_value.iter() + .map(|q| match q { Some(v) => Rfloat::from(*v), None => Rfloat::na() }) .collect(); - // Convert Option to Rstr (NA for None) let to_rstr = |v: &[Option]| -> Vec { v.iter() .map(|s| match s { @@ -140,25 +90,25 @@ fn results_to_list(results: &[gtars_lola::models::LolaResult]) -> List { list!( userSet = user_set, dbSet = db_set, - collection = to_rstr(&collection), - pValueLog = p_value_log, - oddsRatio = odds_ratio, + collection = to_rstr(&c.collection), + pValueLog = c.p_value_log, + oddsRatio = c.odds_ratio, support = support, rnkPV = rnk_pv, rnkOR = rnk_or, rnkSup = rnk_sup, maxRnk = max_rnk, - meanRnk = mean_rnk, + meanRnk = c.mean_rnk, b = b_vec, c = c_vec, d = d_vec, - description = to_rstr(&description), - cellType = to_rstr(&cell_type), - tissue = to_rstr(&tissue), - antibody = to_rstr(&antibody), - treatment = to_rstr(&treatment), - dataSource = to_rstr(&data_source), - filename = filename, + description = to_rstr(&c.description), + cellType = to_rstr(&c.cell_type), + tissue = to_rstr(&c.tissue), + antibody = to_rstr(&c.antibody), + treatment = to_rstr(&c.treatment), + dataSource = to_rstr(&c.data_source), + filename = c.filename, qValue = q_value_r, size = db_set_size ) @@ -432,13 +382,13 @@ pub fn r_regiondb_anno(db: Robj) -> extendr_api::Result { for (i, a) in db_ref.region_anno.iter().enumerate() { filename.push(a.filename.clone()); - cell_type.push(empty_to_na(&a.cell_type)); - description.push(empty_to_na(&a.description)); - tissue.push(empty_to_na(&a.tissue)); - data_source.push(empty_to_na(&a.data_source)); - antibody.push(empty_to_na(&a.antibody)); - treatment.push(empty_to_na(&a.treatment)); - collection.push(empty_to_na(&a.collection)); + cell_type.push(a.cell_type.clone()); + description.push(a.description.clone()); + tissue.push(a.tissue.clone()); + data_source.push(a.data_source.clone()); + antibody.push(a.antibody.clone()); + treatment.push(a.treatment.clone()); + collection.push(a.collection.clone()); size.push(db_ref.region_sets[i].len() as i32); } diff --git a/gtars-refget/src/store/persistence.rs b/gtars-refget/src/store/persistence.rs index a5abd482..0a4115f7 100644 --- a/gtars-refget/src/store/persistence.rs +++ b/gtars-refget/src/store/persistence.rs @@ -188,7 +188,11 @@ impl ReadonlyRefgetStore { "#name\tlength\talphabet\tsha512t24u\tmd5\tdescription" )?; - for result_sr in self.sequence_store.values() { + // Sort by sha512t24u digest for deterministic output (sequence_store is a HashMap). + let mut entries: Vec<&SequenceRecord> = self.sequence_store.values().collect(); + entries.sort_by(|a, b| a.metadata().sha512t24u.cmp(&b.metadata().sha512t24u)); + + for result_sr in entries { let result = result_sr.metadata(); let description = result.description.as_deref().unwrap_or(""); writeln!( diff --git a/gtars-refget/src/store/readonly.rs b/gtars-refget/src/store/readonly.rs index 096ea74e..feb183fb 100644 --- a/gtars-refget/src/store/readonly.rs +++ b/gtars-refget/src/store/readonly.rs @@ -426,11 +426,17 @@ impl ReadonlyRefgetStore { .map(|name_map| { name_map .iter() - .filter_map(|(name, seq_key)| { - let record = self.sequence_store.get(seq_key)?; + .map(|(name, seq_key)| { + let record = self.sequence_store.get(seq_key).ok_or_else(|| { + anyhow!( + "Sequence {} not found in store for collection {}", + key_to_digest_string(seq_key), + collection_digest, + ) + })?; let mut meta = record.metadata().clone(); meta.name = name.clone(); - Some(match record.sequence() { + Ok(match record.sequence() { Some(seq) => SequenceRecord::Full { metadata: meta, sequence: seq.to_vec(), @@ -438,8 +444,9 @@ impl ReadonlyRefgetStore { None => SequenceRecord::Stub(meta), }) }) - .collect() + .collect::>>() }) + .transpose()? .unwrap_or_default(); Ok(crate::digest::SequenceCollection { @@ -539,11 +546,23 @@ impl ReadonlyRefgetStore { /// Import a single collection (with all its sequences, aliases, and FHR /// metadata) from another store into this store. /// + /// Both stores must be disk-backed with matching storage modes. /// The source store must have the collection loaded (call /// `load_collection()` or `load_all_collections()` first). pub fn import_collection(&mut self, source: &ReadonlyRefgetStore, digest: &str) -> Result<()> { - let collection = source.get_collection(digest)?; - self.add_sequence_collection(collection)?; + // Both stores must be disk-backed with same mode + if source.local_path.is_none() || self.local_path.is_none() || !self.persist_to_disk { + return Err(anyhow!("import_collection requires both stores to be disk-backed")); + } + if source.mode != self.mode { + return Err(anyhow!( + "import_collection requires matching storage modes (source={:?}, dest={:?})", + source.mode, + self.mode, + )); + } + + self.import_collection_file_copy(source, digest)?; // Copy sequence aliases for every sequence in the imported collection let coll_key = digest.to_key(); @@ -569,6 +588,126 @@ impl ReadonlyRefgetStore { Ok(()) } + /// File-copy based import: copies RGSI and .seq files directly from + /// source to destination, then registers the collection in memory. + fn import_collection_file_copy( + &mut self, + source: &ReadonlyRefgetStore, + digest: &str, + ) -> Result<()> { + use crate::collection::read_rgsi_file; + + // 1. Read the source RGSI file to get collection metadata + let src_rgsi_path = source + .collection_file_path(digest) + .ok_or_else(|| anyhow!("Source store has no local path for collection {}", digest))?; + let collection = read_rgsi_file(&src_rgsi_path) + .with_context(|| format!("Failed to read source RGSI file: {}", src_rgsi_path.display()))?; + + let mut metadata = collection.metadata.clone(); + + // 2. Handle ancillary digests and copy/write the RGSI file + let dst_rgsi_path = self + .collection_file_path(digest) + .ok_or_else(|| anyhow!("Dest store has no local path for collection {}", digest))?; + if let Some(parent) = dst_rgsi_path.parent() { + create_dir_all(parent)?; + } + + let needs_ancillary_rewrite = self.ancillary_digests + && metadata.name_length_pairs_digest.is_none(); + + if needs_ancillary_rewrite { + // Source lacks ancillary digests but destination wants them -- + // compute them and write a new RGSI file. + metadata.compute_ancillary_digests(&collection.sequences); + use crate::collection::SequenceCollectionRecordExt; + let record = SequenceCollectionRecord::Full { + metadata: metadata.clone(), + sequences: collection.sequences.iter() + .map(|s| SequenceRecord::Stub(s.metadata().clone())) + .collect(), + }; + record.write_collection_rgsi(&dst_rgsi_path)?; + } else { + // Byte-for-byte copy of the RGSI file + fs::copy(&src_rgsi_path, &dst_rgsi_path) + .with_context(|| format!( + "Failed to copy RGSI {} -> {}", + src_rgsi_path.display(), + dst_rgsi_path.display(), + ))?; + } + + // 3. Copy sequence data files + for seq_record in &collection.sequences { + let seq_meta = seq_record.metadata(); + let seq_digest = &seq_meta.sha512t24u; + + let src_seq_path = source + .sequence_file_path(seq_digest) + .ok_or_else(|| anyhow!("Source has no path for sequence {}", seq_digest))?; + let dst_seq_path = self + .sequence_file_path(seq_digest) + .ok_or_else(|| anyhow!("Dest has no path for sequence {}", seq_digest))?; + + // Skip if destination already has this sequence (dedup across collections) + if dst_seq_path.exists() { + // Still need to register in memory below + } else { + if let Some(parent) = dst_seq_path.parent() { + create_dir_all(parent)?; + } + fs::copy(&src_seq_path, &dst_seq_path).with_context(|| { + format!( + "Failed to copy sequence {} -> {}", + src_seq_path.display(), + dst_seq_path.display(), + ) + })?; + } + } + + // 4. Register in memory + let coll_key = digest.to_key(); + + // Build the collection record with stub sequences + let stub_sequences: Vec = collection + .sequences + .iter() + .map(|s| SequenceRecord::Stub(s.metadata().clone())) + .collect(); + + let record = SequenceCollectionRecord::Full { + metadata: metadata.clone(), + sequences: stub_sequences, + }; + self.collections.insert(coll_key, record); + + // Register sequences and populate name_lookup + let mut name_map = IndexMap::new(); + for seq_record in &collection.sequences { + let seq_meta = seq_record.metadata(); + let seq_key = seq_meta.sha512t24u.to_key(); + + name_map.insert(seq_meta.name.clone(), seq_key); + + // Insert stub into sequence_store (skip if already present -- dedup) + if !self.sequence_store.contains_key(&seq_key) { + self.sequence_store + .insert(seq_key, SequenceRecord::Stub(seq_meta.clone())); + self.md5_lookup + .insert(seq_meta.md5.to_key(), seq_key); + } + } + self.name_lookup.insert(coll_key, name_map); + + // 5. Update index files + self.write_index_files()?; + + Ok(()) + } + // ========================================================================= // Sequence API // ========================================================================= @@ -825,6 +964,23 @@ impl ReadonlyRefgetStore { PathBuf::from(path_str) } + /// Return the full filesystem path to a sequence `.seq` file for the given digest. + /// + /// Returns `None` if the store has no local path or no seqdata path template. + pub fn sequence_file_path(&self, seq_digest: &str) -> Option { + let local_path = self.local_path.as_ref()?; + let template = self.seqdata_path_template.as_ref()?; + Some(local_path.join(Self::expand_template(seq_digest, template))) + } + + /// Return the full filesystem path to a collection RGSI file for the given digest. + /// + /// Returns `None` if the store has no local path. + pub fn collection_file_path(&self, coll_digest: &str) -> Option { + let local_path = self.local_path.as_ref()?; + Some(local_path.join(format!("collections/{}.rgsi", coll_digest))) + } + /// Validate a relative path to prevent directory traversal attacks. pub(crate) fn sanitize_relative_path(path: &str) -> Result<()> { if path.starts_with('/') || path.starts_with('\\') { diff --git a/gtars-refget/src/store/tests.rs b/gtars-refget/src/store/tests.rs index 6015b96f..b9bb80f4 100644 --- a/gtars-refget/src/store/tests.rs +++ b/gtars-refget/src/store/tests.rs @@ -1526,6 +1526,92 @@ fn test_collection_order_preserved_after_roundtrip() { ); } +/// Test that multiple collections sharing sequences under different names and different orderings +/// all preserve their correct per-collection names and element orderings across a disk roundtrip. +/// +/// This covers the intersection of two previously-fixed bugs: +/// 1. HashMap ordering (fixed: inner map now IndexMap) +/// 2. Global name leakage (fixed: get_collection() overrides meta.name from name_lookup) +#[test] +fn test_shared_sequences_order_preserved_after_disk_roundtrip() { + // FASTA A: base ordering — chrX first, then chr1, then chr2 + let fasta_a = ">chrX\nTTGGGGAA\n>chr1\nGGAA\n>chr2\nGCGC\n"; + // FASTA B: different order — chr1 first, same sequences as A + let fasta_b = ">chr1\nGGAA\n>chr2\nGCGC\n>chrX\nTTGGGGAA\n"; + // FASTA C: name swap — chr2 has GGAA, chr1 has GCGC (opposite of A/B) + let fasta_c = ">chrX\nTTGGGGAA\n>chr2\nGGAA\n>chr1\nGCGC\n"; + + let dir = tempdir().unwrap(); + let fasta_a_path = dir.path().join("a.fa"); + let fasta_b_path = dir.path().join("b.fa"); + let fasta_c_path = dir.path().join("c.fa"); + fs::write(&fasta_a_path, fasta_a).unwrap(); + fs::write(&fasta_b_path, fasta_b).unwrap(); + fs::write(&fasta_c_path, fasta_c).unwrap(); + + let store_path = dir.path().join("store"); + let mut store = RefgetStore::on_disk(&store_path).unwrap(); + store.set_quiet(true); + + let (meta_a, _) = store.add_sequence_collection_from_fasta(&fasta_a_path, FastaImportOptions::new()).unwrap(); + let (meta_b, _) = store.add_sequence_collection_from_fasta(&fasta_b_path, FastaImportOptions::new()).unwrap(); + let (meta_c, _) = store.add_sequence_collection_from_fasta(&fasta_c_path, FastaImportOptions::new()).unwrap(); + + let digest_a = meta_a.digest.clone(); + let digest_b = meta_b.digest.clone(); + let digest_c = meta_c.digest.clone(); + + // Load collections before write and record level2 output + store.load_all_collections().unwrap(); + let pre_a = store.get_collection_level2(&digest_a).unwrap(); + let pre_b = store.get_collection_level2(&digest_b).unwrap(); + let pre_c = store.get_collection_level2(&digest_c).unwrap(); + + // Verify pre-write ordering for FASTA A: chrX, chr1, chr2 + assert_eq!(pre_a.names, vec!["chrX", "chr1", "chr2"], "A: names before roundtrip"); + // Verify pre-write ordering for FASTA B: chr1, chr2, chrX + assert_eq!(pre_b.names, vec!["chr1", "chr2", "chrX"], "B: names before roundtrip"); + // Verify pre-write ordering for FASTA C: chrX, chr2, chr1 (name swap) + assert_eq!(pre_c.names, vec!["chrX", "chr2", "chr1"], "C: names before roundtrip"); + + store.write().unwrap(); + + // Drop and reopen from disk + drop(store); + let mut reloaded = RefgetStore::open_local(&store_path).unwrap(); + reloaded.load_all_collections().unwrap(); + + let post_a = reloaded.get_collection_level2(&digest_a).unwrap(); + let post_b = reloaded.get_collection_level2(&digest_b).unwrap(); + let post_c = reloaded.get_collection_level2(&digest_c).unwrap(); + + // Names must match exactly (order-sensitive) after roundtrip + assert_eq!(post_a.names, pre_a.names, "A: names after roundtrip"); + assert_eq!(post_b.names, pre_b.names, "B: names after roundtrip"); + assert_eq!(post_c.names, pre_c.names, "C: names after roundtrip"); + + // Lengths must match exactly after roundtrip + assert_eq!(post_a.lengths, pre_a.lengths, "A: lengths after roundtrip"); + assert_eq!(post_b.lengths, pre_b.lengths, "B: lengths after roundtrip"); + assert_eq!(post_c.lengths, pre_c.lengths, "C: lengths after roundtrip"); + + // Sequence digests must match exactly after roundtrip + assert_eq!(post_a.sequences, pre_a.sequences, "A: sequences after roundtrip"); + assert_eq!(post_b.sequences, pre_b.sequences, "B: sequences after roundtrip"); + assert_eq!(post_c.sequences, pre_c.sequences, "C: sequences after roundtrip"); + + // Cross-check: FASTA C has chr2=GGAA and chr1=GCGC (opposite of A's chr1=GGAA, chr2=GCGC) + // The sequence digest for chr2 in C should equal chr1 in A + assert_eq!( + post_c.sequences[1], post_a.sequences[1], + "C.chr2 and A.chr1 share GGAA bytes, should have same sequence digest" + ); + assert_eq!( + post_c.sequences[2], post_a.sequences[2], + "C.chr1 and A.chr2 share GCGC bytes, should have same sequence digest" + ); +} + // ========================================================================= // Name source tests // ========================================================================= @@ -1591,10 +1677,30 @@ fn test_shared_sequence_different_names() { // Import collection tests // ========================================================================= +/// Helper: create a disk-backed store with one collection from a FASTA string. +fn disk_store_with_one_collection(fasta_content: &str) -> (RefgetStore, String, tempfile::TempDir, tempfile::TempDir) { + let store_dir = tempdir().unwrap(); + let fasta_dir = tempdir().unwrap(); + let fasta = fasta_dir.path().join("test.fa"); + fs::write(&fasta, fasta_content).unwrap(); + + let mut store = RefgetStore::on_disk(store_dir.path()).unwrap(); + store.set_quiet(true); + let (meta, _) = store + .add_sequence_collection_from_fasta(&fasta, FastaImportOptions::new()) + .unwrap(); + let digest = meta.digest.clone(); + // Load collection so name_lookup is populated + store.load_all_collections().unwrap(); + (store, digest, store_dir, fasta_dir) +} + #[test] fn test_import_collection_basic() { - let (mut source, digest) = store_with_one_collection(">chr1\nATGC\n>chr2\nGGGG\n"); - let mut target = RefgetStore::in_memory(); + let (mut source, digest, _src_dir, _fasta_dir) = + disk_store_with_one_collection(">chr1\nATGC\n>chr2\nGGGG\n"); + let target_dir = tempdir().unwrap(); + let mut target = RefgetStore::on_disk(target_dir.path()).unwrap(); target.import_collection(&mut source, &digest).unwrap(); @@ -1605,7 +1711,8 @@ fn test_import_collection_basic() { #[test] fn test_import_collection_copies_sequence_aliases() { - let (mut source, digest) = store_with_one_collection(">chr1\nATGC\n>chr2\nGGGG\n"); + let (mut source, digest, _src_dir, _fasta_dir) = + disk_store_with_one_collection(">chr1\nATGC\n>chr2\nGGGG\n"); // Add sequence aliases in source let coll = source.get_collection(&digest).unwrap(); @@ -1615,7 +1722,8 @@ fn test_import_collection_copies_sequence_aliases() { source.add_sequence_alias("ucsc", "chr1", &seq0_digest).unwrap(); source.add_sequence_alias("ncbi", "NC_000002.1", &seq1_digest).unwrap(); - let mut target = RefgetStore::in_memory(); + let target_dir = tempdir().unwrap(); + let mut target = RefgetStore::on_disk(target_dir.path()).unwrap(); target.import_collection(&mut source, &digest).unwrap(); // Target should have the sequence aliases @@ -1631,12 +1739,14 @@ fn test_import_collection_copies_sequence_aliases() { #[test] fn test_import_collection_copies_collection_aliases() { - let (mut source, digest) = store_with_one_collection(">chr1\nATGC\n>chr2\nGGGG\n"); + let (mut source, digest, _src_dir, _fasta_dir) = + disk_store_with_one_collection(">chr1\nATGC\n>chr2\nGGGG\n"); source.add_collection_alias("insdc", "GCA_000001.1", &digest).unwrap(); source.add_collection_alias("refseq", "GCF_000001.1", &digest).unwrap(); - let mut target = RefgetStore::in_memory(); + let target_dir = tempdir().unwrap(); + let mut target = RefgetStore::on_disk(target_dir.path()).unwrap(); target.import_collection(&mut source, &digest).unwrap(); let ns = target.list_collection_alias_namespaces(); @@ -1651,7 +1761,8 @@ fn test_import_collection_copies_collection_aliases() { fn test_import_collection_copies_fhr_metadata() { use super::fhr_metadata::FhrMetadata; - let (mut source, digest) = store_with_one_collection(">chr1\nATGC\n"); + let (mut source, digest, _src_dir, _fasta_dir) = + disk_store_with_one_collection(">chr1\nATGC\n"); let fhr = FhrMetadata { genome: Some("Homo sapiens".to_string()), @@ -1659,7 +1770,8 @@ fn test_import_collection_copies_fhr_metadata() { }; source.set_fhr_metadata(&digest, fhr).unwrap(); - let mut target = RefgetStore::in_memory(); + let target_dir = tempdir().unwrap(); + let mut target = RefgetStore::on_disk(target_dir.path()).unwrap(); target.import_collection(&mut source, &digest).unwrap(); let fhr = target.get_fhr_metadata(&digest); @@ -1682,6 +1794,7 @@ fn test_import_collection_disk_roundtrip_aliases() { let seq0_digest; { let mut source = RefgetStore::on_disk(source_dir.path()).unwrap(); + source.set_quiet(true); let (meta, _) = source .add_sequence_collection_from_fasta(&fasta, FastaImportOptions::new()) .unwrap(); @@ -1720,3 +1833,137 @@ fn test_import_collection_disk_roundtrip_aliases() { let coll_aliases = target.get_aliases_for_collection(&digest); assert_eq!(coll_aliases.len(), 1, "Expected 1 collection alias in target: {:?}", coll_aliases); } + +#[test] +fn test_import_collection_file_copy_roundtrip() { + // Verify RGSI and .seq files are byte-for-byte identical after import + // when ancillary digests match between source and dest. + let (mut source, digest, src_dir, _fasta_dir) = + disk_store_with_one_collection(">chr1\nATGC\n>chr2\nGGGG\n"); + let target_dir = tempdir().unwrap(); + let mut target = RefgetStore::on_disk(target_dir.path()).unwrap(); + + target.import_collection(&mut source, &digest).unwrap(); + + // Verify RGSI file is byte-for-byte identical + let src_rgsi = fs::read( + src_dir.path().join(format!("collections/{}.rgsi", digest)), + ).unwrap(); + let dst_rgsi = fs::read( + target_dir.path().join(format!("collections/{}.rgsi", digest)), + ).unwrap(); + assert_eq!(src_rgsi, dst_rgsi, "RGSI files should be byte-identical"); + + // Verify .seq files are byte-for-byte identical + let coll = target.get_collection(&digest).unwrap(); + for seq in &coll.sequences { + let seq_digest = &seq.metadata().sha512t24u; + let src_seq_path = source.sequence_file_path(seq_digest).unwrap(); + let dst_seq_path = target.sequence_file_path(seq_digest).unwrap(); + let src_data = fs::read(&src_seq_path).unwrap(); + let dst_data = fs::read(&dst_seq_path).unwrap(); + assert_eq!(src_data, dst_data, "Sequence file for {} should be byte-identical", seq_digest); + } + + // Verify in-memory metadata matches + let src_coll = source.get_collection(&digest).unwrap(); + assert_eq!(coll.metadata.digest, src_coll.metadata.digest); + assert_eq!(coll.sequences.len(), src_coll.sequences.len()); + for (src_seq, dst_seq) in src_coll.sequences.iter().zip(coll.sequences.iter()) { + assert_eq!(src_seq.metadata().sha512t24u, dst_seq.metadata().sha512t24u); + assert_eq!(src_seq.metadata().name, dst_seq.metadata().name); + assert_eq!(src_seq.metadata().length, dst_seq.metadata().length); + } +} + +#[test] +fn test_import_collection_ancillary_digest_enrichment() { + // Source store has ancillary_digests: false, destination has ancillary_digests: true. + // The destination RGSI should contain ancillary digest headers that the source lacks. + let source_dir = tempdir().unwrap(); + let fasta_dir = tempdir().unwrap(); + let fasta = fasta_dir.path().join("test.fa"); + fs::write(&fasta, ">chr1\nATGC\n>chr2\nGGGG\n").unwrap(); + + // Create source store with ancillary_digests: false + let mut source = RefgetStore::on_disk(source_dir.path()).unwrap(); + source.set_quiet(true); + source.disable_ancillary_digests(); + let (meta, _) = source + .add_sequence_collection_from_fasta(&fasta, FastaImportOptions::new()) + .unwrap(); + let digest = meta.digest.clone(); + source.load_all_collections().unwrap(); + + // Verify source RGSI lacks ancillary digests + let src_rgsi_content = fs::read_to_string( + source_dir.path().join(format!("collections/{}.rgsi", digest)), + ).unwrap(); + assert!( + !src_rgsi_content.contains("name_length_pairs_digest"), + "Source should NOT have ancillary digests", + ); + + // Create destination store with ancillary_digests: true (default) + let target_dir = tempdir().unwrap(); + let mut target = RefgetStore::on_disk(target_dir.path()).unwrap(); + target.enable_ancillary_digests(); + target.import_collection(&mut source, &digest).unwrap(); + + // Verify destination RGSI has ancillary digest headers + let dst_rgsi_content = fs::read_to_string( + target_dir.path().join(format!("collections/{}.rgsi", digest)), + ).unwrap(); + assert!( + dst_rgsi_content.contains("name_length_pairs_digest"), + "Destination should have ancillary digests. RGSI:\n{}", + dst_rgsi_content, + ); + assert!( + dst_rgsi_content.contains("sorted_name_length_pairs_digest"), + "Destination should have sorted_name_length_pairs_digest", + ); + assert!( + dst_rgsi_content.contains("sorted_sequences_digest"), + "Destination should have sorted_sequences_digest", + ); + + // Verify in-memory metadata has non-None ancillary fields + let coll_meta = target.get_collection_metadata(&digest).unwrap(); + assert!(coll_meta.name_length_pairs_digest.is_some(), "name_length_pairs_digest should be Some"); + assert!(coll_meta.sorted_name_length_pairs_digest.is_some(), "sorted_name_length_pairs_digest should be Some"); + assert!(coll_meta.sorted_sequences_digest.is_some(), "sorted_sequences_digest should be Some"); +} + +#[test] +fn test_import_collection_mode_mismatch_error() { + // Source with Raw mode, destination with Encoded mode should fail. + let source_dir = tempdir().unwrap(); + let fasta_dir = tempdir().unwrap(); + let fasta = fasta_dir.path().join("test.fa"); + fs::write(&fasta, ">chr1\nATGC\n").unwrap(); + + // Create source store in Raw mode + let mut source = RefgetStore::on_disk(source_dir.path()).unwrap(); + source.set_quiet(true); + source.disable_encoding(); + let (meta, _) = source + .add_sequence_collection_from_fasta(&fasta, FastaImportOptions::new()) + .unwrap(); + let digest = meta.digest.clone(); + source.load_all_collections().unwrap(); + + // Create destination store in Encoded mode (default) + let target_dir = tempdir().unwrap(); + let mut target = RefgetStore::on_disk(target_dir.path()).unwrap(); + // target uses Encoded mode by default + + let result = target.import_collection(&mut source, &digest); + assert!(result.is_err(), "Should fail with mode mismatch"); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("matching storage modes"), + "Error should mention storage modes: {}", + err_msg, + ); +} diff --git a/gtars-wasm/Cargo.toml b/gtars-wasm/Cargo.toml index 16375c40..004851de 100644 --- a/gtars-wasm/Cargo.toml +++ b/gtars-wasm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gtars-js" -version = "0.8.0" +version = "0.8.1" authors = ["Nathan LeRoy "] edition = "2024" @@ -15,6 +15,7 @@ wasm-bindgen = "0.2.93" serde = { version = "1.0", features = ["derive"] } serde-wasm-bindgen = "0.4" serde_json = "1.0" +flate2 = { workspace = true } getrandom = { version = "0.2.16", features = ["js"] } # our code diff --git a/gtars-wasm/src/bed_stream.rs b/gtars-wasm/src/bed_stream.rs new file mode 100644 index 00000000..d089f81f --- /dev/null +++ b/gtars-wasm/src/bed_stream.rs @@ -0,0 +1,287 @@ +use std::collections::HashMap; +use std::io::{Cursor, Read}; +use std::sync::Mutex; + +use flate2::read::GzDecoder; +use gtars_core::models::{Region, RegionSet}; +use serde::{Deserialize, Serialize}; +use wasm_bindgen::prelude::*; + +use crate::models::BedEntries; + +// ============================================================================ +// BedStreamParser — streaming BED file parser with gzip auto-detection +// ============================================================================ + +struct BedStreamParser { + regions: Vec, + compressed_buf: Vec, + text_buf: String, + is_gzipped: Option, + bytes_processed: usize, +} + +impl BedStreamParser { + fn new() -> Self { + Self { + regions: Vec::new(), + compressed_buf: Vec::new(), + text_buf: String::new(), + is_gzipped: None, + bytes_processed: 0, + } + } + + fn update(&mut self, chunk: &[u8]) -> Result<(), String> { + if chunk.is_empty() { + return Ok(()); + } + + self.bytes_processed += chunk.len(); + + if self.is_gzipped.is_none() { + self.is_gzipped = Some(chunk.len() >= 2 && chunk[0] == 0x1f && chunk[1] == 0x8b); + } + + if self.is_gzipped == Some(true) { + self.compressed_buf.extend_from_slice(chunk); + } else { + let text = String::from_utf8_lossy(chunk); + self.text_buf.push_str(&text); + self.parse_complete_lines(); + } + + Ok(()) + } + + fn finish(mut self) -> Result { + if self.is_gzipped == Some(true) { + let cursor = Cursor::new(&self.compressed_buf); + let mut decoder = GzDecoder::new(cursor); + let mut decompressed = String::new(); + decoder + .read_to_string(&mut decompressed) + .map_err(|e| format!("Gzip decompression failed: {}", e))?; + self.text_buf = decompressed; + self.parse_complete_lines(); + } + + if !self.text_buf.is_empty() { + let remaining = std::mem::take(&mut self.text_buf); + self.parse_line(&remaining); + } + + let mut rs = RegionSet::from(self.regions); + rs.sort(); + Ok(rs) + } + + fn region_count(&self) -> usize { + self.regions.len() + } + + fn bytes_processed(&self) -> usize { + self.bytes_processed + } + + fn parse_complete_lines(&mut self) { + let last_newline = match self.text_buf.rfind('\n') { + Some(pos) => pos, + None => return, + }; + + let complete = self.text_buf[..last_newline].to_string(); + let remainder = self.text_buf[last_newline + 1..].to_string(); + self.text_buf = remainder; + + for line in complete.split('\n') { + self.parse_line(line); + } + } + + fn parse_line(&mut self, line: &str) { + let trimmed = line.trim(); + if trimmed.is_empty() + || trimmed.starts_with('#') + || trimmed.starts_with("browser") + || trimmed.starts_with("track") + { + return; + } + + let parts: Vec<&str> = trimmed.split('\t').collect(); + if parts.len() < 3 { + eprintln!( + "BED parse warning: expected at least 3 tab-separated fields, got {} in line: {:?}", + parts.len(), + trimmed + ); + return; + } + + let start = match parts[1].parse::() { + Ok(v) => v, + Err(e) => { + eprintln!( + "BED parse warning: invalid start coordinate {:?} in line {:?}: {}", + parts[1], + trimmed, + e + ); + return; + } + }; + let end = match parts[2].parse::() { + Ok(v) => v, + Err(e) => { + eprintln!( + "BED parse warning: invalid end coordinate {:?} in line {:?}: {}", + parts[2], + trimmed, + e + ); + return; + } + }; + + let rest = if parts.len() > 3 { + Some(parts[3..].join("\t")) + } else { + None + }; + + self.regions.push(Region { + chr: parts[0].to_string(), + start, + end, + rest: rest.filter(|s| !s.is_empty()), + }); + } +} + +// ============================================================================ +// Global storage for streaming BED parser instances +// ============================================================================ + +static BED_PARSER_STORAGE: Mutex> = Mutex::new(None); + +struct BedParserStorage { + parsers: HashMap, + next_id: u32, +} + +impl BedParserStorage { + fn new() -> Self { + Self { + parsers: HashMap::new(), + next_id: 1, + } + } + + fn insert(&mut self, parser: BedStreamParser) -> u32 { + let mut id = self.next_id; + while self.parsers.contains_key(&id) || id == 0 { + id = id.wrapping_add(1); + if id == 0 { + id = 1; + } + } + self.next_id = id.wrapping_add(1); + if self.next_id == 0 { + self.next_id = 1; + } + self.parsers.insert(id, parser); + id + } + + fn get_mut(&mut self, id: u32) -> Option<&mut BedStreamParser> { + self.parsers.get_mut(&id) + } + + fn remove(&mut self, id: u32) -> Option { + self.parsers.remove(&id) + } +} + +fn with_bed_storage(f: F) -> R +where + F: FnOnce(&mut BedParserStorage) -> R, +{ + let mut guard = BED_PARSER_STORAGE + .lock() + .expect("BED_PARSER_STORAGE mutex poisoned"); + if guard.is_none() { + *guard = Some(BedParserStorage::new()); + } + f(guard.as_mut().unwrap()) +} + +// ============================================================================ +// WASM bindings +// ============================================================================ + +#[wasm_bindgen(js_name = "bedParserNew")] +pub fn bed_parser_new() -> u32 { + with_bed_storage(|storage| storage.insert(BedStreamParser::new())) +} + +#[wasm_bindgen(js_name = "bedParserUpdate")] +pub fn bed_parser_update(handle: u32, chunk: &[u8]) -> Result<(), JsError> { + with_bed_storage(|storage| { + if let Some(parser) = storage.get_mut(handle) { + parser + .update(chunk) + .map_err(|e| JsError::new(&format!("Failed to process chunk: {}", e))) + } else { + Err(JsError::new("Invalid parser handle")) + } + }) +} + +/// Finalize parsing and return BedEntries (array of [chr, start, end, rest] tuples). +/// The result can be passed directly to `new RegionSet(entries)`. +#[wasm_bindgen(js_name = "bedParserFinish")] +pub fn bed_parser_finish(handle: u32) -> Result { + let parser = with_bed_storage(|storage| storage.remove(handle)) + .ok_or_else(|| JsError::new("Invalid parser handle"))?; + + let region_set = parser + .finish() + .map_err(|e| JsError::new(&format!("Failed to finalize parser: {}", e)))?; + + let entries: Vec<(String, u32, u32, String)> = region_set + .regions + .into_iter() + .map(|r| (r.chr, r.start, r.end, r.rest.unwrap_or_default())) + .collect(); + + serde_wasm_bindgen::to_value(&BedEntries(entries)) + .map_err(|e| JsError::new(&format!("Serialization error: {}", e))) +} + +#[wasm_bindgen(js_name = "bedParserFree")] +pub fn bed_parser_free(handle: u32) -> bool { + with_bed_storage(|storage| storage.remove(handle).is_some()) +} + +#[wasm_bindgen(js_name = "bedParserProgress")] +pub fn bed_parser_progress(handle: u32) -> Result { + let progress = with_bed_storage(|storage| { + storage.get_mut(handle).map(|parser| BedParserProgress { + region_count: parser.region_count(), + bytes_processed: parser.bytes_processed(), + }) + }); + + match progress { + Some(p) => serde_wasm_bindgen::to_value(&p) + .map_err(|e| JsError::new(&format!("Serialization error: {}", e))), + None => Err(JsError::new("Invalid parser handle")), + } +} + +#[derive(Serialize, Deserialize)] +pub struct BedParserProgress { + pub region_count: usize, + pub bytes_processed: usize, +} diff --git a/gtars-wasm/src/lib.rs b/gtars-wasm/src/lib.rs index 10ccca90..51e220c1 100644 --- a/gtars-wasm/src/lib.rs +++ b/gtars-wasm/src/lib.rs @@ -1,3 +1,4 @@ +mod bed_stream; mod asset; mod lola; mod models; @@ -12,15 +13,6 @@ mod utils; use wasm_bindgen::prelude::*; -// Re-export refget functions at the top level +// Re-export functions at the top level +pub use bed_stream::*; pub use refget::*; - -#[wasm_bindgen] -pub fn greet(name: &str) { - alert(&format!("Hello, {}!", name)); -} - -#[wasm_bindgen] -extern "C" { - fn alert(s: &str); -} diff --git a/gtars-wasm/src/lola.rs b/gtars-wasm/src/lola.rs index 0b1bbdf1..232ff94c 100644 --- a/gtars-wasm/src/lola.rs +++ b/gtars-wasm/src/lola.rs @@ -257,67 +257,63 @@ struct UniverseCheckResult { warnings: Vec, } -#[derive(serde::Serialize)] -#[serde(rename_all = "camelCase")] -struct LolaResults { - user_set: Vec, - db_set: Vec, - collection: Vec>, - p_value_log: Vec, - odds_ratio: Vec, - support: Vec, - rnk_pv: Vec, - rnk_or: Vec, - rnk_sup: Vec, - max_rnk: Vec, - mean_rnk: Vec, - b: Vec, - c: Vec, - d: Vec, - description: Vec>, - cell_type: Vec>, - tissue: Vec>, - antibody: Vec>, - treatment: Vec>, - data_source: Vec>, - filename: Vec, - q_value: Vec>, - size: Vec, -} - -fn empty_to_none(s: &str) -> Option { - if s.is_empty() { - None - } else { - Some(s.to_string()) +fn results_to_js(results: &[LolaResult]) -> Result { + use gtars_lola::output::results_to_columns; + + let c = results_to_columns(results); + + #[derive(serde::Serialize)] + #[serde(rename_all = "camelCase")] + struct Out { + user_set: Vec, + db_set: Vec, + collection: Vec>, + p_value_log: Vec, + odds_ratio: Vec, + support: Vec, + rnk_pv: Vec, + rnk_or: Vec, + rnk_sup: Vec, + max_rnk: Vec, + mean_rnk: Vec, + b: Vec, + c: Vec, + d: Vec, + description: Vec>, + cell_type: Vec>, + tissue: Vec>, + antibody: Vec>, + treatment: Vec>, + data_source: Vec>, + filename: Vec, + q_value: Vec>, + size: Vec, } -} -fn results_to_js(results: &[LolaResult]) -> Result { - let out = LolaResults { - user_set: results.iter().map(|r| r.user_set).collect(), - db_set: results.iter().map(|r| r.db_set).collect(), - collection: results.iter().map(|r| empty_to_none(&r.collection)).collect(), - p_value_log: results.iter().map(|r| r.p_value_log).collect(), - odds_ratio: results.iter().map(|r| r.odds_ratio).collect(), - support: results.iter().map(|r| r.support).collect(), - rnk_pv: results.iter().map(|r| r.rnk_pv).collect(), - rnk_or: results.iter().map(|r| r.rnk_or).collect(), - rnk_sup: results.iter().map(|r| r.rnk_sup).collect(), - max_rnk: results.iter().map(|r| r.max_rnk).collect(), - mean_rnk: results.iter().map(|r| r.mean_rnk).collect(), - b: results.iter().map(|r| r.b).collect(), - c: results.iter().map(|r| r.c).collect(), - d: results.iter().map(|r| r.d).collect(), - description: results.iter().map(|r| empty_to_none(&r.description)).collect(), - cell_type: results.iter().map(|r| empty_to_none(&r.cell_type)).collect(), - tissue: results.iter().map(|r| empty_to_none(&r.tissue)).collect(), - antibody: results.iter().map(|r| empty_to_none(&r.antibody)).collect(), - treatment: results.iter().map(|r| empty_to_none(&r.treatment)).collect(), - data_source: results.iter().map(|r| empty_to_none(&r.data_source)).collect(), - filename: results.iter().map(|r| r.filename.clone()).collect(), - q_value: results.iter().map(|r| r.q_value).collect(), - size: results.iter().map(|r| r.db_set_size).collect(), + let out = Out { + user_set: c.user_set, + db_set: c.db_set, + collection: c.collection, + p_value_log: c.p_value_log, + odds_ratio: c.odds_ratio, + support: c.support, + rnk_pv: c.rnk_pv, + rnk_or: c.rnk_or, + rnk_sup: c.rnk_sup, + max_rnk: c.max_rnk, + mean_rnk: c.mean_rnk, + b: c.b, + c: c.c, + d: c.d, + description: c.description, + cell_type: c.cell_type, + tissue: c.tissue, + antibody: c.antibody, + treatment: c.treatment, + data_source: c.data_source, + filename: c.filename, + q_value: c.q_value, + size: c.db_set_size, }; serde_wasm_bindgen::to_value(&out).map_err(|e| JsValue::from_str(&format!("{}", e))) diff --git a/gtars-wasm/src/regionset.rs b/gtars-wasm/src/regionset.rs index 32e3ab02..f11128bd 100644 --- a/gtars-wasm/src/regionset.rs +++ b/gtars-wasm/src/regionset.rs @@ -4,7 +4,7 @@ use crate::models::BedEntries; use gtars_core::models::{Region, RegionSet, RegionSetList}; use gtars_genomicdist::bed_classifier::classify_bed; use gtars_genomicdist::consensus; -use gtars_genomicdist::interval_ranges::IntervalRanges; +use gtars_genomicdist::interval_ranges::{IntervalRanges, RegionSetListOps}; use gtars_genomicdist::models::RegionBin; use gtars_genomicdist::statistics::GenomicIntervalSetStatistics; use wasm_bindgen::prelude::*; @@ -345,6 +345,25 @@ pub struct JsRegionSetList { #[wasm_bindgen(js_class = "RegionSetList")] impl JsRegionSetList { + /// Create an empty RegionSetList. Use `add()` to populate it. + #[wasm_bindgen(constructor)] + pub fn new() -> Self { + JsRegionSetList { + inner: RegionSetList::new(Vec::new()), + } + } + + /// Add a RegionSet to this list, with an optional name. + pub fn add(&mut self, set: &JsRegionSet, name: Option) { + let idx = self.inner.region_sets.len(); + self.inner.region_sets.push(set.region_set.clone()); + let name_val = name.unwrap_or_else(|| format!("set_{}", idx)); + self.inner + .names + .get_or_insert_with(Vec::new) + .push(name_val); + } + /// Number of region sets in this list. #[wasm_bindgen(getter)] pub fn length(&self) -> usize { @@ -383,6 +402,139 @@ impl JsRegionSetList { } } + /// Build a RegionSetList directly from arrays of BED entries. + /// + /// @param entries - Array of arrays: [[[chr, start, end, rest], ...], ...] + /// @param names - Optional array of names, one per set + #[wasm_bindgen(js_name = "fromEntries")] + pub fn from_entries(entries: &JsValue, names: &JsValue) -> Result { + let all_entries: Vec = + serde_wasm_bindgen::from_value(entries.clone())?; + let names_vec: Option> = if names.is_null() || names.is_undefined() { + None + } else { + Some(serde_wasm_bindgen::from_value(names.clone())?) + }; + + let mut sets = Vec::with_capacity(all_entries.len()); + for bed_entries in all_entries { + let regions: Vec = bed_entries + .0 + .into_iter() + .map(|be| Region { + chr: be.0, + start: be.1, + end: be.2, + rest: Some(be.3), + }) + .collect(); + let mut rs = RegionSet::from(regions); + rs.sort(); + sets.push(rs); + } + + Ok(JsRegionSetList { + inner: RegionSetList { + region_sets: sets, + names: names_vec, + path: None, + }, + }) + } + + /// Number of regions in the set at the given index. + #[wasm_bindgen(js_name = "regionCount")] + pub fn region_count(&self, index: usize) -> Result { + self.inner.region_count(index) + .ok_or_else(|| JsValue::from_str("Index out of range")) + } + + /// Number of overlapping regions between two sets by index. + #[wasm_bindgen(js_name = "pintersectCount")] + pub fn pintersect_count(&self, i: usize, j: usize) -> Result { + self.inner.pintersect_count(i, j) + .ok_or_else(|| JsValue::from_str("Index out of range")) + } + + /// Jaccard similarity between two sets by index. + #[wasm_bindgen(js_name = "jaccardAt")] + pub fn jaccard_at(&self, i: usize, j: usize) -> Result { + self.inner.jaccard_at(i, j) + .ok_or_else(|| JsValue::from_str("Index out of range")) + } + + /// Union of two sets by index. + #[wasm_bindgen(js_name = "unionAt")] + pub fn union_at(&self, i: usize, j: usize) -> Result { + self.inner.union_at(i, j) + .map(|rs| JsRegionSet { region_set: rs }) + .ok_or_else(|| JsValue::from_str("Index out of range")) + } + + /// Setdiff of two sets by index (set[i] minus set[j]). + #[wasm_bindgen(js_name = "setdiffAt")] + pub fn setdiff_at(&self, i: usize, j: usize) -> Result { + self.inner.setdiff_at(i, j) + .map(|rs| JsRegionSet { region_set: rs }) + .ok_or_else(|| JsValue::from_str("Index out of range")) + } + + /// Union of all sets except the one at the given index. + #[wasm_bindgen(js_name = "unionExcept")] + pub fn union_except(&self, skip: usize) -> Result { + self.inner.union_except(skip) + .map(|rs| JsRegionSet { region_set: rs }) + .ok_or_else(|| JsValue::from_str("Index out of range or list too small")) + } + + /// Compute all N union-except results in O(n) via prefix/suffix. + /// Returns { union: RegionSet, excepts: RegionSet[] }. + #[wasm_bindgen(js_name = "bulkUnionExcept")] + pub fn bulk_union_except(&self) -> Result { + let (full_union, excepts) = self.inner.bulk_union_except() + .ok_or_else(|| JsValue::from_str("Need at least 2 sets"))?; + + #[derive(serde::Serialize)] + struct BulkResult { + union_regions: u32, + union_nucleotides: u32, + except_unique: Vec, + } + + // For each file, compute setdiff(file_i, union_except_i).len() + let mut except_unique = Vec::with_capacity(excepts.len()); + for (i, ue) in excepts.iter().enumerate() { + if let Some(rs) = self.inner.get(i) { + except_unique.push(rs.setdiff(ue).len() as u32); + } else { + except_unique.push(0); + } + } + + let result = BulkResult { + union_regions: full_union.len() as u32, + union_nucleotides: full_union.nucleotides_length() as u32, + except_unique, + }; + serde_wasm_bindgen::to_value(&result).map_err(|e| e.into()) + } + + /// Union of all sets. + #[wasm_bindgen(js_name = "unionAll")] + pub fn union_all(&self) -> Result { + self.inner.union_all() + .map(|rs| JsRegionSet { region_set: rs }) + .ok_or_else(|| JsValue::from_str("Empty list")) + } + + /// Intersection of all sets. + #[wasm_bindgen(js_name = "intersectAll")] + pub fn intersect_all(&self) -> Result { + self.inner.intersect_all() + .map(|rs| JsRegionSet { region_set: rs }) + .ok_or_else(|| JsValue::from_str("Empty list")) + } + /// Compute pairwise Jaccard similarity for all pairs of region sets. /// /// Returns { matrix: number[][], names: string[] | null }.