diff --git a/tfold/deploy/psp_featurizer.py b/tfold/deploy/psp_featurizer.py
index 1809187..91f91e3 100644
--- a/tfold/deploy/psp_featurizer.py
+++ b/tfold/deploy/psp_featurizer.py
@@ -53,20 +53,70 @@ def restore(cls, path):
         logging.info('model weights restored from %s', path)
         return model
 
+    def _normalize_msa(self, msa):
+        """
+        Normalizes the MSA so that every sequence has the same length as the query (the first sequence).
+        Sequences shorter than the query are padded with gap characters ('-');
+        sequences longer than the query are truncated.
+        """
+        target_length = len(msa[0].strip())
+        normalized = []
+        for seq in msa:
+            s = seq.strip()
+            if len(s) < target_length:
+                s = s + "-" * (target_length - len(s))
+            elif len(s) > target_length:
+                s = s[:target_length]
+            normalized.append(s)
+        return normalized
+
+    def _normalize_deletion_matrix(self, deletion_matrix, msa):
+        """
+        Normalizes the deletion matrix so that each row has the same length as the query sequence
+        (len(msa[0]), which equals every sequence length once the MSA has been normalized).
+        Each row is converted to a list of integers (non-digit characters become 0), then padded or trimmed.
+        """
+        normalized = []
+        target_length = len(msa[0])
+        for d, seq in zip(deletion_matrix, msa):
+            if isinstance(d, str):
+                # Convert each character: if digit, cast to int; otherwise 0.
+                row = [int(ch) if ch.isdigit() else 0 for ch in d]
+            elif isinstance(d, (list, tuple)):
+                try:
+                    row = [int(x) for x in d]
+                except Exception:
+                    row = [0] * len(seq)
+            else:
+                row = [0] * len(seq)
+            # Pad or truncate to target_length.
+            if len(row) < target_length:
+                row.extend([0] * (target_length - len(row)))
+            elif len(row) > target_length:
+                row = row[:target_length]
+            normalized.append(row)
+        return normalized
+
     def _get_feature_dict(self, msa_path, idx_resd_beg=None, idx_resd_end=None):
         if os.path.exists(msa_path):
             with open(msa_path) as f:
                 MSA, deletion_matrix = parse_a3m(f.read())
         else:
-            raise ValueError(f' {msa_path} is not existed')
+            raise ValueError(f'{msa_path} does not exist')
         feature_dict = {}
-        query_seq = MSA[0]
+        # Assume the first sequence is the query.
+        query_seq = MSA[0].strip()
+
         if idx_resd_beg is not None and idx_resd_end is not None:
             query_seq = query_seq[idx_resd_beg:idx_resd_end]
             MSA = [s[idx_resd_beg:idx_resd_end] for s in MSA]
             deletion_matrix = [d[idx_resd_beg:idx_resd_end] for d in deletion_matrix]
+        # Normalize both the MSA and the deletion matrix.
+        MSA = self._normalize_msa(MSA)
+        deletion_matrix = self._normalize_deletion_matrix(deletion_matrix, MSA)
+
         feature_dict.update(data_pipeline.make_sequence_features(query_seq, 'test', len(query_seq)))
         feature_dict.update(data_pipeline.make_msa_features([MSA], [deletion_matrix]))
         feature_dict.update(template_feats_placeholder())
@@ -85,7 +135,8 @@ def forward(self,
         try:
             process_feature_dict = feature_processor.process_features(feature_dict, mode='predict')
         except IndexError as e:
-            logging.error(f'Fail to parse idx_resd_beg: {idx_resd_beg}, idx_resd_end: {idx_resd_end}...')
+            logging.error(
+                f'Failed to parse features with idx_resd_beg: {idx_resd_beg}, idx_resd_end: {idx_resd_end}...')
             raise e
         process_feature_dict = {
             k: torch.as_tensor(v, device=self.device)
@@ -127,8 +178,9 @@ def _forward_impl(self, batch):
             # Enable grad iff we're training and it's the final recycling layer
             is_final_iter = cycle_no == (num_iters - 1)
             with torch.no_grad():
-                # Run the next iteration of the model
-                outputs, m_1_prev, z_prev, x_prev = self.model.iteration(feats, prevs, _recycle=(num_iters > 1))
+                outputs, m_1_prev, z_prev, x_prev = self.model.iteration(
+                    feats, prevs, _recycle=(num_iters > 1)
+                )
             feat_tns = [
                 outputs['msa'].clone().cpu(),
                 outputs['pair'].clone().cpu(),
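
For reviewers who want to sanity-check the normalization outside the class, below is a minimal standalone sketch of the behaviour the two new helpers implement. The helpers are copied out as plain functions (with the defensive try/except around the list branch trimmed), and the toy MSA and deletion rows are hypothetical inputs, not taken from the patch or from tfold data.

def normalize_msa(msa):
    # Pad shorter sequences with '-' and truncate longer ones to the query length.
    target_length = len(msa[0].strip())
    normalized = []
    for seq in msa:
        s = seq.strip()
        if len(s) < target_length:
            s = s + "-" * (target_length - len(s))
        elif len(s) > target_length:
            s = s[:target_length]
        normalized.append(s)
    return normalized

def normalize_deletion_matrix(deletion_matrix, msa):
    # Coerce each row to a list of ints, then pad/truncate to the query length.
    normalized = []
    target_length = len(msa[0])
    for d, seq in zip(deletion_matrix, msa):
        if isinstance(d, str):
            row = [int(ch) if ch.isdigit() else 0 for ch in d]
        elif isinstance(d, (list, tuple)):
            row = [int(x) for x in d]
        else:
            row = [0] * len(seq)
        if len(row) < target_length:
            row.extend([0] * (target_length - len(row)))
        elif len(row) > target_length:
            row = row[:target_length]
        normalized.append(row)
    return normalized

msa = ["MKTAYI", "MKTA", "MKTAYIAKQR"]        # query, a short hit, a long hit
deletions = [[0, 0, 0, 0, 0, 0], [0, 2], "0001020304"]

msa = normalize_msa(msa)
deletions = normalize_deletion_matrix(deletions, msa)
print(msa)        # ['MKTAYI', 'MKTA--', 'MKTAYI']
print(deletions)  # [[0, 0, 0, 0, 0, 0], [0, 2, 0, 0, 0, 0], [0, 0, 0, 1, 0, 2]]

With a well-formed .a3m, parse_a3m already yields query-length sequences and integer deletion rows, so these helpers are effectively no-ops there; they guard against truncated or hand-edited alignment files that would otherwise make make_msa_features fail on ragged inputs.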