39 commits
3835dc2
Update README.md
Nishanth-nishu Jun 17, 2025
bbc0e91
Update model.py
Nishanth-nishu Jul 24, 2025
5e22019
Update utils.py
Nishanth-nishu Jul 24, 2025
ebf6577
Update molecular_graph.py
Nishanth-nishu Jul 24, 2025
aa02b7b
Update main.py
Nishanth-nishu Jul 24, 2025
b9d46de
Update train.py
Nishanth-nishu Jul 24, 2025
5411125
Update model.py
Nishanth-nishu Jul 24, 2025
78f19cc
Update model.py
Nishanth-nishu Jul 24, 2025
921f74d
added evaluation function
Nishanth-nishu Jul 24, 2025
3427ac7
Update main.py
Nishanth-nishu Jul 24, 2025
92dd891
Update main.py
Nishanth-nishu Jul 24, 2025
8e8f609
changed from np to tensor
Nishanth-nishu Jul 24, 2025
836adba
Update train.py
Nishanth-nishu Jul 24, 2025
1097a3b
Update model.py
Nishanth-nishu Jul 24, 2025
0c19302
Update train.py
Nishanth-nishu Jul 24, 2025
57f4373
Update model.py
Nishanth-nishu Jul 24, 2025
2ee917b
Update train.py
Nishanth-nishu Jul 24, 2025
a72cba2
Update train.py
Nishanth-nishu Jul 24, 2025
32230e5
Update main.py
Nishanth-nishu Jul 24, 2025
6f1a596
logger added
Nishanth-nishu Jul 24, 2025
f9631ef
changing it to testing phase
Nishanth-nishu Jul 31, 2025
8c2f453
testing phase
Nishanth-nishu Jul 31, 2025
8f03e38
testing phase
Nishanth-nishu Jul 31, 2025
c108b50
Create run_model.py
Nishanth-nishu Jul 31, 2025
346bc00
Create train.py
Nishanth-nishu Jul 31, 2025
f17b147
Create main.py
Nishanth-nishu Jul 31, 2025
14734a1
Update main.py
Nishanth-nishu Jul 31, 2025
a6295a6
Update main.py
Nishanth-nishu Aug 1, 2025
0d89dda
Update train.py
Nishanth-nishu Aug 1, 2025
42853df
Update model.py
Nishanth-nishu Aug 1, 2025
8bbe6f9
Update molecular_graph.py
Nishanth-nishu Aug 1, 2025
b66416a
dgl api changed
Nishanth-nishu Aug 1, 2025
b5cd407
Update train.py
Nishanth-nishu Aug 1, 2025
703b4d2
Update main.py
Nishanth-nishu Aug 1, 2025
8b0dcad
Update main.py
Nishanth-nishu Aug 1, 2025
58bdd80
Update main.py
Nishanth-nishu Aug 1, 2025
d8b3190
Update molecular_graph.py
Nishanth-nishu Aug 1, 2025
f1406af
Update molecular_graph.py
Nishanth-nishu Aug 1, 2025
27f97b5
removed normalization
Nishanth-nishu Aug 4, 2025
2 changes: 1 addition & 1 deletion CIGIN_V2/README.md
@@ -7,7 +7,7 @@
`$ conda install -c rdkit rdkit==2019.03.1`
* Installing other dependencies:\
`$ conda install -c pytorch pytorch `\
`$ pip install dgl` (Please check [here](https://docs.dgl.ai/en/0.4.x/install/) for
`$ pip install dgl` (Please check [here](https://www.dgl.ai/pages/start.html) for
installing for different cuda builds)\
`$ pip install numpy`\
`$ pip install pandas`
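The README change only swaps the stale DGL 0.4.x install link for the current DGL start page. As a quick post-install sanity check (a minimal sketch, not part of this PR), something like the following confirms the PyTorch/DGL/CUDA combination is usable:

```python
import torch
import dgl

# Verify the installed builds and whether a CUDA device is visible.
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("dgl", dgl.__version__)
```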
39 changes: 27 additions & 12 deletions CIGIN_V2/main.py
@@ -3,6 +3,7 @@
import warnings
import os
import argparse
from sklearn.model_selection import train_test_split

# rdkit imports
from rdkit import RDLogger
@@ -14,12 +15,12 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch

#dgl imports
# dgl imports
import dgl

# local imports
from model import CIGINModel
from train import train
from train import train, get_metrics
from molecular_graph import get_graph_from_smile
from utils import *

@@ -53,8 +54,8 @@ def collate(samples):
solute_graphs, solvent_graphs, labels = map(list, zip(*samples))
solute_graphs = dgl.batch(solute_graphs)
solvent_graphs = dgl.batch(solvent_graphs)
solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes)
solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes)
solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes().tolist())
solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes().tolist())
return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels


@@ -66,40 +67,54 @@ def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):

solute = self.dataset.loc[idx]['SoluteSMILES']
solute = self.dataset.iloc[idx]['SoluteSMILES']
mol = Chem.MolFromSmiles(solute)
mol = Chem.AddHs(mol)
solute = Chem.MolToSmiles(mol)
solute_graph = get_graph_from_smile(solute)

solvent = self.dataset.loc[idx]['SolventSMILES']
solvent = self.dataset.iloc[idx]['SolventSMILES']
mol = Chem.MolFromSmiles(solvent)
mol = Chem.AddHs(mol)
solvent = Chem.MolToSmiles(mol)

solvent_graph = get_graph_from_smile(solvent)
delta_g = self.dataset.loc[idx]['DeltaGsolv']

delta_g = self.dataset.iloc[idx]['delGsolv']
# delta_g is returned unnormalized (normalization was removed)
return [solute_graph, solvent_graph, [delta_g]]


def main():
train_df = pd.read_csv('data/train.csv', sep=";")
valid_df = pd.read_csv('data/valid.csv', sep=";")
# Load and split data
df = pd.read_csv('https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/refs/heads/master/CIGIN_V2/data/whole_data.csv')
df.columns = df.columns.str.strip()

train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42)

train_dataset = Dataclass(train_df)
valid_dataset = Dataclass(valid_df)
test_dataset = Dataclass(test_df)

train_loader = DataLoader(train_dataset, collate_fn=collate, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, collate_fn=collate, batch_size=128)
test_loader = DataLoader(test_dataset, collate_fn=collate, batch_size=128)

# Initialize model
model = CIGINModel(interaction=interaction)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True)

# Train model
train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name)

# Evaluate on test data
model.eval()
loss, mae_loss = get_metrics(model, test_loader)
print(f"Model performance on the testing data: Loss: {loss}, MAE_Loss: {mae_loss}")


if __name__ == '__main__':
main()
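For context on the new split logic in main(): holding out 10% of the data first and then taking 0.111 of the remainder yields roughly an 80/10/10 train/valid/test split. A minimal sketch of the arithmetic on dummy data (not the PR's code):

```python
import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.DataFrame({"x": range(1000)})

# 10% of the full dataset becomes the test set.
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)

# 0.111 of the remaining 90% is ~10% of the original data,
# leaving an ~80/10/10 train/valid/test split.
train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42)

print(len(train_df), len(valid_df), len(test_df))  # 800 / 100 / 100
```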
46 changes: 23 additions & 23 deletions CIGIN_V2/molecular_graph.py
@@ -1,9 +1,9 @@
import numpy as np
from dgl import DGLGraph
import dgl
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors as rdDesc
from utils import one_of_k_encoding_unk, one_of_k_encoding

import torch

def get_atom_features(atom, stereo, features, explicit_H=False):
"""
@@ -24,28 +24,21 @@ def get_atom_features(atom, stereo, features, explicit_H=False):
Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D])
atom_features += [int(i) for i in list("{0:06b}".format(features))]

if not explicit_H:
atom_features += one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])

try:
atom_features += one_of_k_encoding_unk(stereo, ['R', 'S'])
atom_features += [atom.HasProp('_ChiralityPossible')]
except Exception as e:

atom_features += [False, False
] + [atom.HasProp('_ChiralityPossible')]

atom_features += [False, False] + [atom.HasProp('_ChiralityPossible')]
return np.array(atom_features)


def get_bond_features(bond):
"""
Method that computes bond level features from rdkit bond object
:param bond: rdkit bond object
:return: bond features, 1d numpy array
"""

bond_type = bond.GetBondType()
bond_feats = [
bond_type == Chem.rdchem.BondType.SINGLE, bond_type == Chem.rdchem.BondType.DOUBLE,
@@ -54,43 +47,50 @@ def get_bond_features(bond):
bond.IsInRing()
]
bond_feats += one_of_k_encoding_unk(str(bond.GetStereo()), ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"])

return np.array(bond_feats)


def get_graph_from_smile(molecule_smile):
"""
Method that constructs a molecular graph with nodes being the atoms
and bonds being the edges.
:param molecule_smile: SMILE sequence
:return: DGL graph object, Node features and Edge features
"""

G = DGLGraph()
molecule = Chem.MolFromSmiles(molecule_smile)
features = rdDesc.GetFeatureInvariants(molecule)

stereo = Chem.FindMolChiralCenters(molecule)
chiral_centers = [0] * molecule.GetNumAtoms()
for i in stereo:
chiral_centers[i[0]] = i[1]

G.add_nodes(molecule.GetNumAtoms())

# Create graph with modern DGL API
G = dgl.graph([], num_nodes=molecule.GetNumAtoms())
node_features = []
edge_features = []
edges_src = []
edges_dst = []

for i in range(molecule.GetNumAtoms()):

atom_i = molecule.GetAtomWithIdx(i)
atom_i_features = get_atom_features(atom_i, chiral_centers[i], features[i])
node_features.append(atom_i_features)

for j in range(molecule.GetNumAtoms()):
bond_ij = molecule.GetBondBetweenAtoms(i, j)
if bond_ij is not None:
G.add_edge(i, j)
edges_src.append(i)
edges_dst.append(j)
bond_features_ij = get_bond_features(bond_ij)
edge_features.append(bond_features_ij)

G.ndata['x'] = np.array(node_features)
G.edata['w'] = np.array(edge_features)

# Add edges to graph
if edges_src:
G.add_edges(edges_src, edges_dst)

# MINIMAL PERFORMANCE FIX: Convert to numpy array first, then to tensor
G.ndata['x'] = torch.tensor(np.array(node_features))
if edge_features: # Only if edges exist
G.edata['w'] = torch.tensor(np.array(edge_features))
else:
G.edata['w'] = torch.tensor([]) # Empty tensor for molecules with no bonds

return G
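The rewritten get_graph_from_smile drops the removed mutable DGLGraph() constructor in favor of dgl.graph(...). A minimal sketch of the same pattern on a hypothetical 3-atom molecule (feature sizes and values are illustrative, not taken from the PR):

```python
import dgl
import torch

# Hypothetical 3-atom molecule with bonds 0-1 and 1-2, each stored in both
# directions so the graph behaves as undirected during message passing.
num_atoms = 3
src = [0, 1, 1, 2]
dst = [1, 0, 2, 1]

g = dgl.graph((src, dst), num_nodes=num_atoms)
g.ndata['x'] = torch.randn(num_atoms, 8)   # per-atom feature vectors
g.edata['w'] = torch.randn(len(src), 4)    # per-bond feature vectors

# In recent DGL, batch_num_nodes is a method, hence the .tolist() change in collate().
print(g.batch_num_nodes())  # tensor([3])
```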
34 changes: 26 additions & 8 deletions CIGIN_V2/train.py
@@ -4,32 +4,45 @@

loss_fn = torch.nn.MSELoss()
mae_loss_fn = torch.nn.L1Loss()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

def evaluate_model(model, dataloader):
model.eval()
preds, targets = [], []
with torch.no_grad():
for samples in dataloader:
outputs, _ = model([samples[0].to(device), samples[1].to(device),
samples[2].to(device), samples[3].to(device)])
preds.extend(outputs.cpu().numpy())
targets.extend(samples[4])
preds, targets = np.array(preds), np.array(targets)
rmse = np.sqrt(np.mean((preds - targets) ** 2))
return rmse

def get_metrics(model, data_loader):
valid_outputs = []
valid_labels = []
valid_loss = []
valid_mae_loss = []
for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in tqdm(data_loader):
for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in data_loader:
outputs, i_map = model(
[solute_graphs.to(device), solvent_graphs.to(device), torch.tensor(solute_lens).to(device),
torch.tensor(solvent_lens).to(device)])
loss = loss_fn(outputs, torch.tensor(labels).to(device).float())
mae_loss = mae_loss_fn(outputs, torch.tensor(labels).to(device).float())

# MINIMAL FIX: Convert targets to proper tensor shape
targets_tensor = torch.tensor(labels).to(device).float().view(-1, 1)
loss = loss_fn(outputs, targets_tensor)
mae_loss = mae_loss_fn(outputs, targets_tensor)

valid_outputs += outputs.cpu().detach().numpy().tolist()
valid_loss.append(loss.cpu().detach().numpy())
valid_mae_loss.append(mae_loss.cpu().detach().numpy())
valid_labels += labels

loss = np.mean(np.array(valid_loss).flatten())
mae_loss = np.mean(np.array(valid_mae_loss).flatten())
return loss, mae_loss


def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name):
best_val_loss = 100
for epoch in range(max_epochs):
@@ -43,7 +56,11 @@ def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name):
[samples[0].to(device), samples[1].to(device), torch.tensor(samples[2]).to(device),
torch.tensor(samples[3]).to(device)])
l1_norm = torch.norm(interaction_map, p=2) * 1e-4
loss = loss_fn(outputs, torch.tensor(samples[4]).to(device).float()) + l1_norm

# MINIMAL FIX: Convert targets to proper tensor shape
targets_tensor = torch.tensor(samples[4]).to(device).float().view(-1, 1)
loss = loss_fn(outputs, targets_tensor) + l1_norm

loss.backward()
optimizer.step()
loss = loss - l1_norm
@@ -57,4 +74,5 @@ def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name):
val_loss) + " MAE Val_loss " + str(mae_loss))
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), "./runs/run-" + str(project_name) + "/models/best_model.tar")
# MINIMAL FIX: Fixed path format
torch.save(model.state_dict(), "./runs/run_" + str(project_name) + "/models/best_model.tar")
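The .view(-1, 1) fixes address a shape mismatch: the model returns predictions of shape (batch, 1) while the labels arrive as a flat Python list. A minimal sketch of why that matters for MSELoss (illustrative values only):

```python
import torch

loss_fn = torch.nn.MSELoss()

outputs = torch.randn(4, 1)        # model predictions, shape (batch, 1)
labels = [0.1, -0.3, 0.7, 1.2]     # targets arrive as a plain Python list

# (4,) vs (4, 1) broadcasts to (4, 4): every prediction is compared against
# every target, and PyTorch only warns about it.
bad = loss_fn(outputs, torch.tensor(labels))

# Reshaping the targets to (batch, 1) gives the intended element-wise loss.
good = loss_fn(outputs, torch.tensor(labels).view(-1, 1))

print(bad.item(), good.item())
```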
2 changes: 2 additions & 0 deletions CIGIN_V2/utils.py
@@ -1,5 +1,7 @@
import numpy as np



def one_of_k_encoding(x, allowable_set):
if x not in allowable_set:
raise Exception("input {0} not in allowable set{1}:".format(
79 changes: 79 additions & 0 deletions scripts/main.py
@@ -0,0 +1,79 @@
import pandas as pd
import numpy as np
from train import run_kfold_cv
import warnings
warnings.filterwarnings("ignore")

def load_and_preprocess_data(csv_path):
"""Load and preprocess the dataset"""
# Load the CSV file
df = pd.read_csv(csv_path)

# Strip whitespace from column names as requested
df.columns = df.columns.str.strip()

# Strip whitespace from string columns
for col in df.columns:
if df[col].dtype == 'object':
df[col] = df[col].str.strip()

# Remove any rows with missing values
df = df.dropna()

# Filter out problematic SMILES if any
df = df[df['SoluteSMILES'].str.len() > 0]
df = df[df['SolventSMILES'].str.len() > 0]

print(f"Dataset loaded: {len(df)} samples")
print(f"Unique solutes: {df['SoluteSMILES'].nunique()}")
print(f"Unique solvents: {df['SolventSMILES'].nunique()}")
print(f"Solvation free energy range: {df['delGsolv'].min():.2f} to {df['delGsolv'].max():.2f} kcal/mol")

return df

def main():
"""Main training function following CIGIN paper methodology"""
print("CIGIN Model Training")
print("=" * 50)

# Check GPU availability
import torch
if torch.cuda.is_available():
print(f"GPU Available: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
print("GPU Not Available - Using CPU")
print("=" * 50)

# Load dataset - replace with your actual CSV path
csv_path = "https://github.com/adithyamauryakr/CIGIN-DevaLab/raw/master/CIGIN_V2/data/whole_data.csv"

try:
# Try to load from URL first
df = pd.read_csv(csv_path)
except:
# If URL fails, try local file
print("Loading from URL failed, trying local file...")
df = pd.read_csv("whole_data.csv")

# Preprocess data
data = load_and_preprocess_data(csv_path if 'df' in locals() else "whole_data.csv")
df = data

# Run k-fold cross validation as described in the paper
# Paper mentions: "10-fold cross validation scheme was used to assess the model"
# "We made 5 such 10 cross validation splits and trained our model independently"
print("\nStarting 10-fold cross validation (5 independent runs)...")

mean_rmse, std_rmse = run_kfold_cv(df, k=10, n_runs=5)

print("\n" + "=" * 50)
print("FINAL RESULTS")
print("=" * 50)
print(f"CIGIN Model Performance:")
print(f"RMSE: {mean_rmse:.2f} ± {std_rmse:.2f} kcal/mol")
print("\nPaper reported RMSE: 0.57 ± 0.10 kcal/mol")
print("=" * 50)

if __name__ == "__main__":
main()
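scripts/train.py, which defines run_kfold_cv, is not shown in this diff, so its implementation is not visible here. A rough sketch of what a helper with that signature could look like, assuming it splits the dataframe with sklearn's KFold and aggregates RMSE over repeated runs (a mean-of-training-labels baseline stands in for the CIGIN model so the skeleton is runnable):

```python
import numpy as np
from sklearn.model_selection import KFold

def run_kfold_cv(df, k=10, n_runs=5):
    """Hypothetical sketch: returns (mean_rmse, std_rmse) over n_runs x k folds."""
    rmses = []
    for run in range(n_runs):
        kf = KFold(n_splits=k, shuffle=True, random_state=run)
        for train_idx, test_idx in kf.split(df):
            train_y = df.iloc[train_idx]['delGsolv'].to_numpy()
            test_y = df.iloc[test_idx]['delGsolv'].to_numpy()
            preds = np.full_like(test_y, train_y.mean())  # placeholder "model"
            rmses.append(np.sqrt(np.mean((preds - test_y) ** 2)))
    return float(np.mean(rmses)), float(np.std(rmses))
```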
17 changes: 17 additions & 0 deletions scripts/run_model.py
@@ -0,0 +1,17 @@
import torch
from models import Cigin
from molecular_graph import ConstructMolecularGraph

# Sample SMILES strings (you can change these)
solute = "CCO" # Ethanol
solvent = "O=C=O" # Carbon dioxide

# Load model
model = Cigin().to("cuda" if torch.cuda.is_available() else "cpu")

# Run model forward pass
with torch.no_grad():
prediction, interaction_map = model(solute, solvent)

print("Prediction (Solubility):", prediction.item())
print("Interaction Map Shape:", interaction_map.shape)