Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions finetune_csv/configs/config_us_5min.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Kronos fine-tuning config for US equities, 5-minute bars
# Generated for use with finetune_tokenizer.py + finetune_base_model.py

data:
# Path to the CSV produced by fetch_us_data.py
# Override with --data_path or set here directly.
data_path: "data/us_5min.csv"

# How many past bars the model sees as context (max 512)
lookback_window: 200

# How many future bars are in each training sample
predict_window: 20

# Must match Kronos max_context (don't change)
max_context: 512

# Clip normalised values beyond this range (same as pretrain)
clip: 5.0

# Chronological split — no shuffling across time boundaries
train_ratio: 0.80
val_ratio: 0.10
test_ratio: 0.10

training:
# Tokenizer: fine-tune the VQ-VAE encoder/decoder on US price distributions
tokenizer_epochs: 20

# Predictor: fine-tune the autoregressive Transformer
basemodel_epochs: 15

# Reduce batch_size if you hit OOM on CPU/MPS
batch_size: 16

log_interval: 50
num_workers: 2
seed: 42

# Tokenizer LR — slightly lower than default to preserve pretrained codebook
tokenizer_learning_rate: 0.0001

# Predictor LR — keep small to avoid catastrophic forgetting
predictor_learning_rate: 0.000004

adam_beta1: 0.9
adam_beta2: 0.95
adam_weight_decay: 0.1

accumulation_steps: 2 # effective batch = batch_size * accumulation_steps

model_paths:
# HuggingFace Hub IDs (downloaded and cached on first run)
# or local absolute paths if you've already downloaded them
pretrained_tokenizer: "NeoQuasar/Kronos-Tokenizer-base"
pretrained_predictor: "NeoQuasar/Kronos-small"

exp_name: "us_5min_finetune"
base_path: "finetuned"

# Leave empty — auto-generated as {base_path}/{exp_name}/...
base_save_path: ""
finetuned_tokenizer: ""

tokenizer_save_name: "tokenizer"
basemodel_save_name: "basemodel"

experiment:
name: "kronos_us_5min"
description: "Fine-tune Kronos on US equity 5-min OHLCV data"
use_comet: false

# Set either to false to skip that stage (e.g. if tokenizer already trained)
train_tokenizer: true
train_basemodel: true
skip_existing: false

# Start from pretrained weights (recommended — keeps global knowledge)
pre_trained_tokenizer: true
pre_trained_predictor: true

device:
use_cuda: true
device_id: 0
87 changes: 87 additions & 0 deletions finetune_csv/fetch_us_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Fetch 5-minute OHLCV data for one or more US tickers via yfinance
and save in the format expected by CustomKlineDataset.

Required columns: timestamps, open, high, low, close, volume, amount

Usage:
python fetch_us_data.py --tickers AAPL MSFT TSLA --output data/us_5min.csv
python fetch_us_data.py --tickers AAPL --period 60d --output data/AAPL_5min.csv
"""

import argparse
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import yfinance as yf


def fetch_ticker(ticker: str, period: str) -> pd.DataFrame:
raw = yf.download(ticker, period=period, interval="5m", progress=False, auto_adjust=True)
if raw.empty:
print(f" WARNING: no data returned for {ticker}, skipping.")
return pd.DataFrame()

# Flatten MultiIndex columns if present
if isinstance(raw.columns, pd.MultiIndex):
raw.columns = raw.columns.get_level_values(0)

raw = raw.rename(columns={
"Open": "open", "High": "high", "Low": "low",
"Close": "close", "Volume": "volume",
})

# Keep only regular market hours (09:30–16:00 ET)
if raw.index.tz is not None:
raw.index = raw.index.tz_convert("America/New_York")
raw = raw.between_time("09:30", "16:00")

raw = raw[["open", "high", "low", "close", "volume"]].dropna()

# Derive 'amount' = volume * avg(ohlc) — a reasonable proxy
raw["amount"] = raw["volume"] * (raw["open"] + raw["high"] + raw["low"] + raw["close"]) / 4

raw = raw.reset_index().rename(columns={"Datetime": "timestamps", "index": "timestamps"})
raw["timestamps"] = pd.to_datetime(raw["timestamps"]).dt.tz_localize(None) # strip tz for CSV

raw = raw[["timestamps", "open", "high", "low", "close", "volume", "amount"]]
raw = raw.sort_values("timestamps").reset_index(drop=True)

print(f" {ticker}: {len(raw)} bars [{raw['timestamps'].iloc[0]} → {raw['timestamps'].iloc[-1]}]")
return raw


def main():
parser = argparse.ArgumentParser(description="Fetch US 5-min data for Kronos fine-tuning")
parser.add_argument("--tickers", nargs="+", required=True, help="One or more ticker symbols, e.g. AAPL MSFT")
parser.add_argument("--period", type=str, default="60d",
help="yfinance period string (max 60d for 5m). Default: 60d")
parser.add_argument("--output", type=str, default="data/us_5min.csv",
help="Output CSV path. Default: data/us_5min.csv")
args = parser.parse_args()

all_frames = []
print(f"Fetching {len(args.tickers)} ticker(s) ...")
for ticker in args.tickers:
df = fetch_ticker(ticker.upper(), args.period)
if not df.empty:
all_frames.append(df)

if not all_frames:
raise RuntimeError("No data fetched — check your ticker symbols and internet connection.")

combined = pd.concat(all_frames, ignore_index=True)
combined = combined.sort_values("timestamps").reset_index(drop=True)

import os
os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
combined.to_csv(args.output, index=False)

print(f"\nSaved {len(combined):,} rows → {args.output}")
print(f"Date range : {combined['timestamps'].min()} → {combined['timestamps'].max()}")
print(f"Columns : {list(combined.columns)}")


if __name__ == "__main__":
main()
132 changes: 132 additions & 0 deletions finetune_csv/run_us_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
End-to-end fine-tuning runner for US equity 5-min data.

Steps:
1. Fetch data from yfinance (skipped if CSV already exists)
2. Fine-tune KronosTokenizer
3. Fine-tune Kronos predictor

Usage:
# Full pipeline — fetch AAPL + MSFT + TSLA then train
python run_us_finetune.py --tickers AAPL MSFT TSLA

# Use an existing CSV, custom config
python run_us_finetune.py --data_path data/us_5min.csv --config configs/config_us_5min.yaml

# Skip tokenizer training (already done), only train predictor
python run_us_finetune.py --tickers AAPL --skip_tokenizer
"""

import argparse
import os
import subprocess
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))


def run(cmd: list, desc: str):
print(f"\n{'='*60}")
print(f" {desc}")
print(f"{'='*60}")
print(f" CMD: {' '.join(cmd)}\n")
result = subprocess.run(cmd, check=True)
return result


def patch_config_data_path(config_path: str, data_path: str, out_path: str):
"""Write a copy of config_path with data.data_path replaced."""
import yaml
with open(config_path) as f:
cfg = yaml.safe_load(f)
cfg["data"]["data_path"] = os.path.abspath(data_path)
with open(out_path, "w") as f:
yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True, indent=2)


def main():
parser = argparse.ArgumentParser(description="Kronos US stock fine-tuning pipeline")
parser.add_argument("--tickers", nargs="+", default=[],
help="Ticker symbols to fetch (e.g. AAPL MSFT TSLA)")
parser.add_argument("--period", default="60d",
help="yfinance period for data fetch (max 60d for 5m). Default: 60d")
parser.add_argument("--data_path", default="data/us_5min.csv",
help="Path to existing or target CSV. Default: data/us_5min.csv")
parser.add_argument("--config", default="configs/config_us_5min.yaml",
help="YAML config path. Default: configs/config_us_5min.yaml")
parser.add_argument("--skip_tokenizer", action="store_true",
help="Skip tokenizer training (use if already fine-tuned)")
parser.add_argument("--skip_predictor", action="store_true",
help="Skip predictor training")
args = parser.parse_args()

script_dir = os.path.dirname(os.path.abspath(__file__))
data_path = args.data_path if os.path.isabs(args.data_path) \
else os.path.join(script_dir, args.data_path)

# ------------------------------------------------------------------ #
# 1. Fetch data
# ------------------------------------------------------------------ #
if args.tickers:
run(
[sys.executable, os.path.join(script_dir, "fetch_us_data.py"),
"--tickers"] + args.tickers + [
"--period", args.period,
"--output", data_path],
f"Fetching 5-min data for: {', '.join(args.tickers)}"
)
else:
if not os.path.exists(data_path):
raise FileNotFoundError(
f"No tickers provided and data file not found: {data_path}\n"
"Pass --tickers AAPL ... to fetch data first."
)
print(f"Using existing data file: {data_path}")

# ------------------------------------------------------------------ #
# 2. Patch config with the resolved data path
# ------------------------------------------------------------------ #
config_src = args.config if os.path.isabs(args.config) \
else os.path.join(script_dir, args.config)
config_run = os.path.join(script_dir, "configs", "_run_config.yaml")
os.makedirs(os.path.dirname(config_run), exist_ok=True)
patch_config_data_path(config_src, data_path, config_run)
print(f"Active config written to: {config_run}")

# ------------------------------------------------------------------ #
# 3. Fine-tune tokenizer
# ------------------------------------------------------------------ #
if not args.skip_tokenizer:
run(
[sys.executable, os.path.join(script_dir, "finetune_tokenizer.py"),
"--config", config_run],
"Stage 1 / 2 — Fine-tuning KronosTokenizer"
)
else:
print("\nSkipping tokenizer training (--skip_tokenizer set).")

# ------------------------------------------------------------------ #
# 4. Fine-tune predictor
# ------------------------------------------------------------------ #
if not args.skip_predictor:
run(
[sys.executable, os.path.join(script_dir, "finetune_base_model.py"),
"--config", config_run],
"Stage 2 / 2 — Fine-tuning Kronos predictor"
)
else:
print("\nSkipping predictor training (--skip_predictor set).")

print("\n" + "="*60)
print(" Fine-tuning complete!")
print(f" Fine-tuned models saved under: finetuned/us_5min_finetune/")
print(" To predict with the fine-tuned model, pass:")
print(" --tokenizer finetuned/us_5min_finetune/tokenizer/best_model")
print(" --model finetuned/us_5min_finetune/basemodel/best_model")
print(" to predict_us_stock.py")
print("="*60 + "\n")


if __name__ == "__main__":
main()
Loading