62 changes: 57 additions & 5 deletions tfold/deploy/psp_featurizer.py
@@ -53,20 +53,70 @@ def restore(cls, path):
        logging.info('model weights restored from %s', path)
        return model

    def _normalize_msa(self, msa):
        """
        Normalizes the MSA so that all sequences have the same length as the query (first sequence).
        If a sequence is shorter, it is padded with gap characters ('-').
        If it is longer, it is truncated.
        """
        target_length = len(msa[0].strip())
        normalized = []
        for seq in msa:
            s = seq.strip()
            if len(s) < target_length:
                s = s + "-" * (target_length - len(s))
            elif len(s) > target_length:
                s = s[:target_length]
            normalized.append(s)
        return normalized
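
    # Illustrative only (not part of the diff): given a ragged alignment, the
    # helper pads or truncates every row to the query length, e.g.
    #   self._normalize_msa(['ACDEF', 'ACD', 'ACDEFGH']) -> ['ACDEF', 'ACD--', 'ACDEF']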

    def _normalize_deletion_matrix(self, deletion_matrix, msa):
        """
        Normalizes the deletion matrix so that each row has the same length as the query
        sequence (len(msa[0])). String rows are converted to lists of integers
        (digits are cast to int, any other character becomes 0); each row is then
        zero-padded or truncated to the target length.
        """
        normalized = []
        target_length = len(msa[0])
        for d, seq in zip(deletion_matrix, msa):
            if isinstance(d, str):
                # Convert each character: if digit, cast to int; otherwise 0.
                row = [int(ch) if ch.isdigit() else 0 for ch in d]
            elif isinstance(d, (list, tuple)):
                try:
                    row = [int(x) for x in d]
                except (TypeError, ValueError):
                    # Fall back to an all-zero row if the entries are not numeric.
                    row = [0] * len(seq)
            else:
                row = [0] * len(seq)
            # Pad or truncate to target_length.
            if len(row) < target_length:
                row.extend([0] * (target_length - len(row)))
            elif len(row) > target_length:
                row = row[:target_length]
            normalized.append(row)
        return normalized
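
    # Illustrative only (not part of the diff): rows are coerced to ints and
    # aligned to the query length, e.g. for msa = ['ACDE', 'ACDE']:
    #   self._normalize_deletion_matrix(['0102', [0, 1]], msa) -> [[0, 1, 0, 2], [0, 1, 0, 0]]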

    def _get_feature_dict(self, msa_path, idx_resd_beg=None, idx_resd_end=None):
        if os.path.exists(msa_path):
            with open(msa_path) as f:
                MSA, deletion_matrix = parse_a3m(f.read())
        else:
            raise ValueError(f'<msa_path> {msa_path} is not existed')
            raise ValueError(f'<msa_path> {msa_path} does not exist')

        feature_dict = {}
        query_seq = MSA[0]
        # Assume the first sequence is the query.
        query_seq = MSA[0].strip()

        if idx_resd_beg is not None and idx_resd_end is not None:
            query_seq = query_seq[idx_resd_beg:idx_resd_end]
            MSA = [s[idx_resd_beg:idx_resd_end] for s in MSA]
            deletion_matrix = [d[idx_resd_beg:idx_resd_end] for d in deletion_matrix]

        # Normalize both the MSA and the deletion matrix.
        MSA = self._normalize_msa(MSA)
        deletion_matrix = self._normalize_deletion_matrix(deletion_matrix, MSA)

        feature_dict.update(data_pipeline.make_sequence_features(query_seq, 'test', len(query_seq)))
        feature_dict.update(data_pipeline.make_msa_features([MSA], [deletion_matrix]))
        feature_dict.update(template_feats_placeholder())
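
For context, a minimal sketch of how the featurizer might be invoked end to end (the class name PSPFeaturizer and the file paths are assumptions for illustration; only restore, _get_feature_dict, and the argument names appear in this diff):

model = PSPFeaturizer.restore('/path/to/checkpoint.pth')
# Build features for residues [0, 128) of the query; ragged MSA rows are
# repaired by the new normalization helpers before feature construction.
feature_dict = model._get_feature_dict('/path/to/query.a3m', idx_resd_beg=0, idx_resd_end=128)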
@@ -85,7 +135,8 @@ def forward(self,
        try:
            process_feature_dict = feature_processor.process_features(feature_dict, mode='predict')
        except IndexError as e:
            logging.error(f'Fail to parse idx_resd_beg: {idx_resd_beg}, idx_resd_end: {idx_resd_end}...')
            logging.error(
                f'Failed to parse features with idx_resd_beg: {idx_resd_beg}, idx_resd_end: {idx_resd_end}...')
            raise e
        process_feature_dict = {
            k: torch.as_tensor(v, device=self.device)
@@ -127,8 +178,9 @@ def _forward_impl(self, batch):
            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = cycle_no == (num_iters - 1)
            with torch.no_grad():
                # Run the next iteration of the model
                outputs, m_1_prev, z_prev, x_prev = self.model.iteration(feats, prevs, _recycle=(num_iters > 1))
                outputs, m_1_prev, z_prev, x_prev = self.model.iteration(
                    feats, prevs, _recycle=(num_iters > 1)
                )
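                # Illustrative only (not part of the diff): in OpenFold-style recycling,
                # downstream code typically re-packs the recycled tensors for the next
                # pass, e.g. prevs = [m_1_prev, z_prev, x_prev].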
            feat_tns = [
                outputs['msa'].clone().cpu(),
                outputs['pair'].clone().cpu(),