diff --git a/app/database.py b/app/database.py
new file mode 100644
index 0000000..74e3817
--- /dev/null
+++ b/app/database.py
@@ -0,0 +1,79 @@
+import sqlite3
+import os
+from datetime import datetime
+from pathlib import Path
+from contextlib import contextmanager
+
+# Resolve paths against the project root (the parent of this file's directory)
+# so the database lands in <project root>/data/ regardless of the working directory
+BASE_DIR = Path(__file__).resolve().parent.parent
+DB_FILE = BASE_DIR / "data" / "chord_fingerprints.db"
+
+# Ensure data directory exists
+DB_FILE.parent.mkdir(parents=True, exist_ok=True)
+
+def get_db_connection():
+    conn = sqlite3.connect(str(DB_FILE))
+    conn.row_factory = sqlite3.Row
+    return conn
+
+@contextmanager
+def db_cursor():
+    conn = get_db_connection()
+    try:
+        yield conn, conn.cursor()
+        conn.commit()
+    finally:
+        conn.close()
+
+def init_db():
+    with db_cursor() as (conn, c):
+        # Table for segment-level fingerprints
+        c.execute('''
+            CREATE TABLE IF NOT EXISTS chord_fingerprints (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                phash TEXT NOT NULL,
+                chord_symbol TEXT NOT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            )
+        ''')
+
+        # Table for file-level caching
+        c.execute('''
+            CREATE TABLE IF NOT EXISTS file_cache (
+                phash TEXT PRIMARY KEY,
+                progression_data TEXT NOT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            )
+        ''')
+
+def save_phash_chord_pair(phash: str, chord_symbol: str, cursor=None):
+    """
+    Save a pHash and chord symbol pair.
+    If a cursor is provided, use it (for batch operations).
+    Otherwise, open a new connection.
+    """
+    if cursor:
+        cursor.execute('INSERT INTO chord_fingerprints (phash, chord_symbol) VALUES (?, ?)',
+                       (phash, chord_symbol))
+    else:
+        with db_cursor() as (conn, c):
+            c.execute('INSERT INTO chord_fingerprints (phash, chord_symbol) VALUES (?, ?)',
+                      (phash, chord_symbol))
+
+def get_cached_progression(phash: str):
+    """Retrieve the cached chord progression for a file-level pHash."""
+    with db_cursor() as (conn, c):
+        c.execute("SELECT progression_data FROM file_cache WHERE phash=?", (phash,))
+        row = c.fetchone()
+        if row:
+            return row['progression_data']
+        return None
+
+def save_cached_progression(phash: str, progression_data: str):
+    """Save a chord progression to the file-level cache."""
+    with db_cursor() as (conn, c):
+        c.execute('''
+            INSERT OR REPLACE INTO file_cache (phash, progression_data)
+            VALUES (?, ?)
+        ''', (phash, progression_data))
diff --git a/app/schemas.py b/app/schemas.py
index cfcdac4..b4c0b07 100644
--- a/app/schemas.py
+++ b/app/schemas.py
@@ -86,8 +86,8 @@ class E2EBaseRequest(BaseModel):
 
 class E2EBaseResult(BaseModel):
     jobId: str
-    transcriptionUrl: str
-    separatedAudioUrl: str
+    transcriptionUrl: Optional[str] = None
+    separatedAudioUrl: Optional[str] = None
     chordProgressionUrl: str
     format: ChartFormat = ChartFormat.JSON
 
diff --git a/app/tasks/chord_tasks.py b/app/tasks/chord_tasks.py
index b7c004a..efc91f1 100644
--- a/app/tasks/chord_tasks.py
+++ b/app/tasks/chord_tasks.py
@@ -284,8 +284,16 @@ def e2e_base_ready_task(self, audio_file_path: str, instrument: str):
     from demucs.apply import apply_model
     import torch
     import torchaudio
+    import librosa
+    import numpy as np
+    from PIL import Image
+    import imagehash
     from halmoni import MIDIAnalyzer, ChordDetector, KeyDetector, ChordProgression
     import json
+    from app.database import (
+        save_phash_chord_pair, db_cursor, init_db,
+        get_cached_progression, save_cached_progression
+    )
 
     self.update_progress(0, 100, "Starting E2E pipeline")
 
@@ -293,6 +301,40 @@ def e2e_base_ready_task(self, audio_file_path: str, instrument: str):
     output_dir = Path(f"./outputs/{job_id}")
     output_dir.mkdir(parents=True, exist_ok=True)
 
+    # Initialize DB once
+    init_db()
+
+    # Check cache
+    self.update_progress(5, 100, "Checking cache")
+    file_phash = None  # Ensure this is defined even if the audio is empty or hashing fails
+    try:
+        # Compute file pHash
+        y, sr = librosa.load(audio_file_path, sr=22050)  # Use 22050 Hz for pHash consistency
+        if len(y) > 0:
+            mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
+            mfcc_norm = (mfcc - mfcc.min()) / (mfcc.max() - mfcc.min() + 1e-6) * 255
+            mfcc_img = Image.fromarray(mfcc_norm.astype(np.uint8))
+            file_phash = str(imagehash.phash(mfcc_img))
+
+            cached_data = get_cached_progression(file_phash)
+            if cached_data:
+                self.update_progress(100, 100, "Found cached result")
+                # Save cached data to output file
+                chord_output_path = output_dir / "chord_progression.json"
+                with open(chord_output_path, 'w') as f:
+                    f.write(cached_data)
+
+                return {
+                    'jobId': job_id,
+                    'transcriptionUrl': None,
+                    'separatedAudioUrl': None,
+                    'chordProgressionUrl': f'/outputs/{job_id}/chord_progression.json',
+                    'format': 'json'
+                }
+    except Exception as e:
+        print(f"Warning: Cache check failed: {e}")
+        file_phash = None
+
     # Step 1: Audio separation
     self.update_progress(10, 100, "Separating audio")
 
@@ -347,12 +389,43 @@ def e2e_base_ready_task(self, audio_file_path: str, instrument: str):
     time_windows = analyzer.get_time_windows(all_notes_flat, window_size=1.0)
 
     chords = []
-    for window_start, window_notes in time_windows:
-        notes = analyzer.group_simultaneous_notes(window_notes)
-        if notes:
-            chord = detector.detect_chord_from_midi_notes(notes[0])
-            if chord:
-                chords.append(chord)
+
+    # Load audio for MFCC extraction
+    y, sr = librosa.load(str(separated_audio_path), sr=sample_rate)
+
+    # Use a single connection for all inserts
+    with db_cursor() as (conn, cursor):
+        for window_start, window_notes in time_windows:
+            notes = analyzer.group_simultaneous_notes(window_notes)
+            if notes:
+                chord = detector.detect_chord_from_midi_notes(notes[0])
+                if chord:
+                    chords.append(chord)
+
+                    # Extract MFCC and pHash
+                    try:
+                        # Extract audio segment (window_size=1.0)
+                        start_sample = int(window_start * sr)
+                        end_sample = int((window_start + 1.0) * sr)
+
+                        if start_sample < len(y):
+                            segment = y[start_sample:min(end_sample, len(y))]
+
+                            if len(segment) > 0:
+                                # Compute MFCC
+                                mfcc = librosa.feature.mfcc(y=segment, sr=sr, n_mfcc=20)
+
+                                # Normalize MFCC to 0-255 for image conversion
+                                mfcc_norm = (mfcc - mfcc.min()) / (mfcc.max() - mfcc.min() + 1e-6) * 255
+                                mfcc_img = Image.fromarray(mfcc_norm.astype(np.uint8))
+
+                                # Compute pHash
+                                phash = str(imagehash.phash(mfcc_img))
+
+                                # Save to DB
+                                save_phash_chord_pair(phash, str(chord), cursor=cursor)
+                    except Exception as e:
+                        print(f"Warning: Failed to generate pHash for chord {chord}: {e}")
 
     key_detector = KeyDetector()
     all_notes = all_notes_flat
@@ -365,8 +438,16 @@ def e2e_base_ready_task(self, audio_file_path: str, instrument: str):
         'key': str(key) if key else None,
         'chords': [{'symbol': str(chord), 'duration': 1.0} for chord in chords]
     }
+    json_str = json.dumps(progression_data, indent=2)
     with open(chord_output_path, 'w') as f:
-        json.dump(progression_data, f, indent=2)
+        f.write(json_str)
+
+    # Save to cache if we have a file pHash
+    if file_phash:
+        try:
+            save_cached_progression(file_phash, json_str)
+        except Exception as e:
+            print(f"Warning: Failed to save to cache: {e}")
 
     chord_progression_url = f'/outputs/{job_id}/chord_progression.json'
 
diff --git a/requirements.txt b/requirements.txt
index 1a62543..50e38f1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -42,3 +42,5 @@ jams
 python-dotenv
 deprecated
 onnx>=1.19.0
+ImageHash
+Pillow