From cc5bc08d31ca7bbd1731144f96e44647ace78f82 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 29 Aug 2024 21:07:45 +0200 Subject: [PATCH 01/46] move previous tests to integration dir --- tests/integration/__init__.py | 3 +++ tests/{ => integration}/testChebiData.py | 0 .../{ => integration}/testChebiDynamicDataSplits.py | 0 .../testCustomBalancedAccuracyMetric.py | 0 tests/{ => integration}/testCustomMacroF1Metric.py | 0 tests/{ => integration}/testPubChemData.py | 0 tests/{ => integration}/testTox21MolNetData.py | 0 .../test_data/ChEBIOver100_test/labels000.pt | Bin .../test_data/ChEBIOver100_test/labels001.pt | Bin .../test_data/ChEBIOver100_test/labels002.pt | Bin .../test_data/ChEBIOver100_test/labels003.pt | Bin .../test_data/ChEBIOver100_test/labels004.pt | Bin .../test_data/ChEBIOver100_test/labels005.pt | Bin .../test_data/ChEBIOver100_test/labels006.pt | Bin .../test_data/ChEBIOver100_test/labels007.pt | Bin .../test_data/ChEBIOver100_test/labels008.pt | Bin .../test_data/ChEBIOver100_test/labels009.pt | Bin .../test_data/ChEBIOver100_test/labels010.pt | Bin .../test_data/ChEBIOver100_test/labels011.pt | Bin .../test_data/ChEBIOver100_test/labels012.pt | Bin .../test_data/ChEBIOver100_test/labels013.pt | Bin .../test_data/ChEBIOver100_test/labels014.pt | Bin .../test_data/ChEBIOver100_test/labels015.pt | Bin .../test_data/ChEBIOver100_test/labels016.pt | Bin .../test_data/ChEBIOver100_test/labels017.pt | Bin .../test_data/ChEBIOver100_test/labels018.pt | Bin .../test_data/ChEBIOver100_test/labels019.pt | Bin .../test_data/ChEBIOver100_test/preds000.pt | Bin .../test_data/ChEBIOver100_test/preds001.pt | Bin .../test_data/ChEBIOver100_test/preds002.pt | Bin .../test_data/ChEBIOver100_test/preds003.pt | Bin .../test_data/ChEBIOver100_test/preds004.pt | Bin .../test_data/ChEBIOver100_test/preds005.pt | Bin .../test_data/ChEBIOver100_test/preds006.pt | Bin .../test_data/ChEBIOver100_test/preds007.pt | Bin .../test_data/ChEBIOver100_test/preds008.pt | Bin .../test_data/ChEBIOver100_test/preds009.pt | Bin .../test_data/ChEBIOver100_test/preds010.pt | Bin .../test_data/ChEBIOver100_test/preds011.pt | Bin .../test_data/ChEBIOver100_test/preds012.pt | Bin .../test_data/ChEBIOver100_test/preds013.pt | Bin .../test_data/ChEBIOver100_test/preds014.pt | Bin .../test_data/ChEBIOver100_test/preds015.pt | Bin .../test_data/ChEBIOver100_test/preds016.pt | Bin .../test_data/ChEBIOver100_test/preds017.pt | Bin .../test_data/ChEBIOver100_test/preds018.pt | Bin .../test_data/ChEBIOver100_test/preds019.pt | Bin 47 files changed, 3 insertions(+) create mode 100644 tests/integration/__init__.py rename tests/{ => integration}/testChebiData.py (100%) rename tests/{ => integration}/testChebiDynamicDataSplits.py (100%) rename tests/{ => integration}/testCustomBalancedAccuracyMetric.py (100%) rename tests/{ => integration}/testCustomMacroF1Metric.py (100%) rename tests/{ => integration}/testPubChemData.py (100%) rename tests/{ => integration}/testTox21MolNetData.py (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels000.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels001.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels002.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels003.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels004.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels005.pt (100%) rename tests/{ => 
integration}/test_data/ChEBIOver100_test/labels006.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels007.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels008.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels009.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels010.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels011.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels012.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels013.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels014.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels015.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels016.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels017.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels018.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/labels019.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds000.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds001.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds002.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds003.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds004.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds005.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds006.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds007.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds008.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds009.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds010.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds011.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds012.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds013.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds014.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds015.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds016.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds017.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds018.pt (100%) rename tests/{ => integration}/test_data/ChEBIOver100_test/preds019.pt (100%) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..caa8759f --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,3 @@ +""" +This directory contains integration tests that cover the overall behavior of the data preprocessing tool. 
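+These tests can be run with standard unittest discovery, e.g. python -m unittest discover -s tests/integration.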
+""" diff --git a/tests/testChebiData.py b/tests/integration/testChebiData.py similarity index 100% rename from tests/testChebiData.py rename to tests/integration/testChebiData.py diff --git a/tests/testChebiDynamicDataSplits.py b/tests/integration/testChebiDynamicDataSplits.py similarity index 100% rename from tests/testChebiDynamicDataSplits.py rename to tests/integration/testChebiDynamicDataSplits.py diff --git a/tests/testCustomBalancedAccuracyMetric.py b/tests/integration/testCustomBalancedAccuracyMetric.py similarity index 100% rename from tests/testCustomBalancedAccuracyMetric.py rename to tests/integration/testCustomBalancedAccuracyMetric.py diff --git a/tests/testCustomMacroF1Metric.py b/tests/integration/testCustomMacroF1Metric.py similarity index 100% rename from tests/testCustomMacroF1Metric.py rename to tests/integration/testCustomMacroF1Metric.py diff --git a/tests/testPubChemData.py b/tests/integration/testPubChemData.py similarity index 100% rename from tests/testPubChemData.py rename to tests/integration/testPubChemData.py diff --git a/tests/testTox21MolNetData.py b/tests/integration/testTox21MolNetData.py similarity index 100% rename from tests/testTox21MolNetData.py rename to tests/integration/testTox21MolNetData.py diff --git a/tests/test_data/ChEBIOver100_test/labels000.pt b/tests/integration/test_data/ChEBIOver100_test/labels000.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels000.pt rename to tests/integration/test_data/ChEBIOver100_test/labels000.pt diff --git a/tests/test_data/ChEBIOver100_test/labels001.pt b/tests/integration/test_data/ChEBIOver100_test/labels001.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels001.pt rename to tests/integration/test_data/ChEBIOver100_test/labels001.pt diff --git a/tests/test_data/ChEBIOver100_test/labels002.pt b/tests/integration/test_data/ChEBIOver100_test/labels002.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels002.pt rename to tests/integration/test_data/ChEBIOver100_test/labels002.pt diff --git a/tests/test_data/ChEBIOver100_test/labels003.pt b/tests/integration/test_data/ChEBIOver100_test/labels003.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels003.pt rename to tests/integration/test_data/ChEBIOver100_test/labels003.pt diff --git a/tests/test_data/ChEBIOver100_test/labels004.pt b/tests/integration/test_data/ChEBIOver100_test/labels004.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels004.pt rename to tests/integration/test_data/ChEBIOver100_test/labels004.pt diff --git a/tests/test_data/ChEBIOver100_test/labels005.pt b/tests/integration/test_data/ChEBIOver100_test/labels005.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels005.pt rename to tests/integration/test_data/ChEBIOver100_test/labels005.pt diff --git a/tests/test_data/ChEBIOver100_test/labels006.pt b/tests/integration/test_data/ChEBIOver100_test/labels006.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels006.pt rename to tests/integration/test_data/ChEBIOver100_test/labels006.pt diff --git a/tests/test_data/ChEBIOver100_test/labels007.pt b/tests/integration/test_data/ChEBIOver100_test/labels007.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels007.pt rename to tests/integration/test_data/ChEBIOver100_test/labels007.pt diff --git a/tests/test_data/ChEBIOver100_test/labels008.pt b/tests/integration/test_data/ChEBIOver100_test/labels008.pt similarity 
index 100% rename from tests/test_data/ChEBIOver100_test/labels008.pt rename to tests/integration/test_data/ChEBIOver100_test/labels008.pt diff --git a/tests/test_data/ChEBIOver100_test/labels009.pt b/tests/integration/test_data/ChEBIOver100_test/labels009.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels009.pt rename to tests/integration/test_data/ChEBIOver100_test/labels009.pt diff --git a/tests/test_data/ChEBIOver100_test/labels010.pt b/tests/integration/test_data/ChEBIOver100_test/labels010.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels010.pt rename to tests/integration/test_data/ChEBIOver100_test/labels010.pt diff --git a/tests/test_data/ChEBIOver100_test/labels011.pt b/tests/integration/test_data/ChEBIOver100_test/labels011.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels011.pt rename to tests/integration/test_data/ChEBIOver100_test/labels011.pt diff --git a/tests/test_data/ChEBIOver100_test/labels012.pt b/tests/integration/test_data/ChEBIOver100_test/labels012.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels012.pt rename to tests/integration/test_data/ChEBIOver100_test/labels012.pt diff --git a/tests/test_data/ChEBIOver100_test/labels013.pt b/tests/integration/test_data/ChEBIOver100_test/labels013.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels013.pt rename to tests/integration/test_data/ChEBIOver100_test/labels013.pt diff --git a/tests/test_data/ChEBIOver100_test/labels014.pt b/tests/integration/test_data/ChEBIOver100_test/labels014.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels014.pt rename to tests/integration/test_data/ChEBIOver100_test/labels014.pt diff --git a/tests/test_data/ChEBIOver100_test/labels015.pt b/tests/integration/test_data/ChEBIOver100_test/labels015.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels015.pt rename to tests/integration/test_data/ChEBIOver100_test/labels015.pt diff --git a/tests/test_data/ChEBIOver100_test/labels016.pt b/tests/integration/test_data/ChEBIOver100_test/labels016.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels016.pt rename to tests/integration/test_data/ChEBIOver100_test/labels016.pt diff --git a/tests/test_data/ChEBIOver100_test/labels017.pt b/tests/integration/test_data/ChEBIOver100_test/labels017.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels017.pt rename to tests/integration/test_data/ChEBIOver100_test/labels017.pt diff --git a/tests/test_data/ChEBIOver100_test/labels018.pt b/tests/integration/test_data/ChEBIOver100_test/labels018.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels018.pt rename to tests/integration/test_data/ChEBIOver100_test/labels018.pt diff --git a/tests/test_data/ChEBIOver100_test/labels019.pt b/tests/integration/test_data/ChEBIOver100_test/labels019.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels019.pt rename to tests/integration/test_data/ChEBIOver100_test/labels019.pt diff --git a/tests/test_data/ChEBIOver100_test/preds000.pt b/tests/integration/test_data/ChEBIOver100_test/preds000.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds000.pt rename to tests/integration/test_data/ChEBIOver100_test/preds000.pt diff --git a/tests/test_data/ChEBIOver100_test/preds001.pt b/tests/integration/test_data/ChEBIOver100_test/preds001.pt similarity index 100% rename from 
tests/test_data/ChEBIOver100_test/preds001.pt rename to tests/integration/test_data/ChEBIOver100_test/preds001.pt diff --git a/tests/test_data/ChEBIOver100_test/preds002.pt b/tests/integration/test_data/ChEBIOver100_test/preds002.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds002.pt rename to tests/integration/test_data/ChEBIOver100_test/preds002.pt diff --git a/tests/test_data/ChEBIOver100_test/preds003.pt b/tests/integration/test_data/ChEBIOver100_test/preds003.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds003.pt rename to tests/integration/test_data/ChEBIOver100_test/preds003.pt diff --git a/tests/test_data/ChEBIOver100_test/preds004.pt b/tests/integration/test_data/ChEBIOver100_test/preds004.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds004.pt rename to tests/integration/test_data/ChEBIOver100_test/preds004.pt diff --git a/tests/test_data/ChEBIOver100_test/preds005.pt b/tests/integration/test_data/ChEBIOver100_test/preds005.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds005.pt rename to tests/integration/test_data/ChEBIOver100_test/preds005.pt diff --git a/tests/test_data/ChEBIOver100_test/preds006.pt b/tests/integration/test_data/ChEBIOver100_test/preds006.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds006.pt rename to tests/integration/test_data/ChEBIOver100_test/preds006.pt diff --git a/tests/test_data/ChEBIOver100_test/preds007.pt b/tests/integration/test_data/ChEBIOver100_test/preds007.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds007.pt rename to tests/integration/test_data/ChEBIOver100_test/preds007.pt diff --git a/tests/test_data/ChEBIOver100_test/preds008.pt b/tests/integration/test_data/ChEBIOver100_test/preds008.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds008.pt rename to tests/integration/test_data/ChEBIOver100_test/preds008.pt diff --git a/tests/test_data/ChEBIOver100_test/preds009.pt b/tests/integration/test_data/ChEBIOver100_test/preds009.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds009.pt rename to tests/integration/test_data/ChEBIOver100_test/preds009.pt diff --git a/tests/test_data/ChEBIOver100_test/preds010.pt b/tests/integration/test_data/ChEBIOver100_test/preds010.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds010.pt rename to tests/integration/test_data/ChEBIOver100_test/preds010.pt diff --git a/tests/test_data/ChEBIOver100_test/preds011.pt b/tests/integration/test_data/ChEBIOver100_test/preds011.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds011.pt rename to tests/integration/test_data/ChEBIOver100_test/preds011.pt diff --git a/tests/test_data/ChEBIOver100_test/preds012.pt b/tests/integration/test_data/ChEBIOver100_test/preds012.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds012.pt rename to tests/integration/test_data/ChEBIOver100_test/preds012.pt diff --git a/tests/test_data/ChEBIOver100_test/preds013.pt b/tests/integration/test_data/ChEBIOver100_test/preds013.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds013.pt rename to tests/integration/test_data/ChEBIOver100_test/preds013.pt diff --git a/tests/test_data/ChEBIOver100_test/preds014.pt b/tests/integration/test_data/ChEBIOver100_test/preds014.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds014.pt rename to 
tests/integration/test_data/ChEBIOver100_test/preds014.pt diff --git a/tests/test_data/ChEBIOver100_test/preds015.pt b/tests/integration/test_data/ChEBIOver100_test/preds015.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds015.pt rename to tests/integration/test_data/ChEBIOver100_test/preds015.pt diff --git a/tests/test_data/ChEBIOver100_test/preds016.pt b/tests/integration/test_data/ChEBIOver100_test/preds016.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds016.pt rename to tests/integration/test_data/ChEBIOver100_test/preds016.pt diff --git a/tests/test_data/ChEBIOver100_test/preds017.pt b/tests/integration/test_data/ChEBIOver100_test/preds017.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds017.pt rename to tests/integration/test_data/ChEBIOver100_test/preds017.pt diff --git a/tests/test_data/ChEBIOver100_test/preds018.pt b/tests/integration/test_data/ChEBIOver100_test/preds018.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds018.pt rename to tests/integration/test_data/ChEBIOver100_test/preds018.pt diff --git a/tests/test_data/ChEBIOver100_test/preds019.pt b/tests/integration/test_data/ChEBIOver100_test/preds019.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds019.pt rename to tests/integration/test_data/ChEBIOver100_test/preds019.pt From 5af03512863cb7b68193eb0698c899b762de721b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 29 Aug 2024 21:13:07 +0200 Subject: [PATCH 02/46] unit dir + test for ChemDataReader --- tests/unit/__init__.py | 4 ++ tests/unit/collators/__init__.py | 0 tests/unit/data_readers/__init__.py | 0 tests/unit/data_readers/testChemDataReader.py | 71 +++++++++++++++++++ tests/unit/dataset_classes/__init__.py | 0 5 files changed, 75 insertions(+) create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/collators/__init__.py create mode 100644 tests/unit/data_readers/__init__.py create mode 100644 tests/unit/data_readers/testChemDataReader.py create mode 100644 tests/unit/dataset_classes/__init__.py diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..6640a696 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,4 @@ +""" +This directory contains unit tests, which focus on individual functions and methods, ensuring they work as +expected in isolation. +""" diff --git a/tests/unit/collators/__init__.py b/tests/unit/collators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/data_readers/__init__.py b/tests/unit/data_readers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/data_readers/testChemDataReader.py b/tests/unit/data_readers/testChemDataReader.py new file mode 100644 index 00000000..bf3dea6e --- /dev/null +++ b/tests/unit/data_readers/testChemDataReader.py @@ -0,0 +1,71 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +from chebai.preprocessing.reader import EMBEDDING_OFFSET, ChemDataReader + + +class TestChemDataReader(unittest.TestCase): + """ + Unit tests for the ChemDataReader class. + """ + + @patch( + "chebai.preprocessing.reader.open", + new_callable=mock_open, + read_data="C\nO\nN\n=\n1\n(", + ) + def setUp(self, mock_file: mock_open) -> None: + """ + Set up the test environment by initializing a ChemDataReader instance with a mocked token file. + + Args: + mock_file: Mock object for file operations. 
+ """ + self.reader = ChemDataReader(token_path="/mock/path") + # After initializing, self.reader.cache should now be set to ['C', 'O', 'N', '=', '1', '('] + self.assertEqual(self.reader.cache, ["C", "O", "N", "=", "1", "("]) + + def test_read_data(self) -> None: + """ + Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string. + """ + raw_data = "CC(=O)NC1" + # Expected output as per the tokens already in the cache, and ")" getting added to it. + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 5, # = + EMBEDDING_OFFSET + 3, # O + EMBEDDING_OFFSET + 1, # N + EMBEDDING_OFFSET + 6, # ( + EMBEDDING_OFFSET + 2, # C + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 4, # 1 + ] + result = self.reader._read_data(raw_data) + self.assertEqual(result, expected_output) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a SMILES string that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "[H-]" + + # Note: test methods within a TestCase class are not guaranteed to be executed in any specific order. + # Determine the index for the new token based on the current size of the cache. + index_for_last_token = len(self.reader.cache) + expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] + + result = self.reader._read_data(raw_data) + self.assertEqual(result, expected_output) + + # Verify that '[H-]' was added to the cache + self.assertIn("[H-]", self.reader.cache) + # Ensure it's at the correct index + self.assertEqual(self.reader.cache.index("[H-]"), index_for_last_token) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/__init__.py b/tests/unit/dataset_classes/__init__.py new file mode 100644 index 00000000..e69de29b From a0810a233dd319c7fcb18bb3684eacd3047796ef Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 29 Aug 2024 21:15:49 +0200 Subject: [PATCH 03/46] Test for DataReader --- tests/unit/data_readers/testDataReader.py | 51 +++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 tests/unit/data_readers/testDataReader.py diff --git a/tests/unit/data_readers/testDataReader.py b/tests/unit/data_readers/testDataReader.py new file mode 100644 index 00000000..1a511b26 --- /dev/null +++ b/tests/unit/data_readers/testDataReader.py @@ -0,0 +1,51 @@ +import unittest +from typing import Any, Dict, List + +from chebai.preprocessing.reader import DataReader + + +class TestDataReader(unittest.TestCase): + """ + Unit tests for the DataReader class. + """ + + def setUp(self) -> None: + """ + Set up the test environment by initializing a DataReader instance. + """ + self.reader = DataReader() + + def test_to_data(self) -> None: + """ + Test the to_data method to ensure it correctly processes the input row + and formats it according to the expected output. + + This method tests the conversion of raw data into a processed format, + including extracting features, labels, ident, group, and additional + keyword arguments. 
+ """ + features_list: List[int] = [10, 20, 30] + labels_list: List[bool] = [True, False, True] + ident_no: int = 123 + + row: Dict[str, Any] = { + "features": features_list, + "labels": labels_list, + "ident": ident_no, + "group": "group_data", + "additional_kwargs": {"extra_key": "extra_value"}, + } + + expected: Dict[str, Any] = { + "features": features_list, + "labels": labels_list, + "ident": ident_no, + "group": "group_data", + "extra_key": "extra_value", + } + + self.assertEqual(self.reader.to_data(row), expected) + + +if __name__ == "__main__": + unittest.main() From 1b3836d5c103a1455f41245b757a94acc0b3d5f5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 29 Aug 2024 23:09:53 +0200 Subject: [PATCH 04/46] tests for DeepChemReader --- .../{data_readers => readers}/__init__.py | 0 .../testChemDataReader.py | 8 +- .../testDataReader.py | 0 tests/unit/readers/testDeepChemDataReader.py | 80 +++++++++++++++++++ 4 files changed, 85 insertions(+), 3 deletions(-) rename tests/unit/{data_readers => readers}/__init__.py (100%) rename tests/unit/{data_readers => readers}/testChemDataReader.py (90%) rename tests/unit/{data_readers => readers}/testDataReader.py (100%) create mode 100644 tests/unit/readers/testDeepChemDataReader.py diff --git a/tests/unit/data_readers/__init__.py b/tests/unit/readers/__init__.py similarity index 100% rename from tests/unit/data_readers/__init__.py rename to tests/unit/readers/__init__.py diff --git a/tests/unit/data_readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py similarity index 90% rename from tests/unit/data_readers/testChemDataReader.py rename to tests/unit/readers/testChemDataReader.py index bf3dea6e..2bc525e1 100644 --- a/tests/unit/data_readers/testChemDataReader.py +++ b/tests/unit/readers/testChemDataReader.py @@ -8,6 +8,8 @@ class TestChemDataReader(unittest.TestCase): """ Unit tests for the ChemDataReader class. + + Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order. """ @patch( @@ -30,7 +32,7 @@ def test_read_data(self) -> None: """ Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string. """ - raw_data = "CC(=O)NC1" + raw_data = "CC(=O)NC1[Mg-2]" # Expected output as per the tokens already in the cache, and ")" getting added to it. expected_output: List[int] = [ EMBEDDING_OFFSET + 0, # C @@ -38,10 +40,11 @@ def test_read_data(self) -> None: EMBEDDING_OFFSET + 5, # = EMBEDDING_OFFSET + 3, # O EMBEDDING_OFFSET + 1, # N - EMBEDDING_OFFSET + 6, # ( + EMBEDDING_OFFSET + len(self.reader.cache), # ( EMBEDDING_OFFSET + 2, # C EMBEDDING_OFFSET + 0, # C EMBEDDING_OFFSET + 4, # 1 + EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2] ] result = self.reader._read_data(raw_data) self.assertEqual(result, expected_output) @@ -53,7 +56,6 @@ def test_read_data_with_new_token(self) -> None: """ raw_data = "[H-]" - # Note: test methods within a TestCase class are not guaranteed to be executed in any specific order. # Determine the index for the new token based on the current size of the cache. 
index_for_last_token = len(self.reader.cache) expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] diff --git a/tests/unit/data_readers/testDataReader.py b/tests/unit/readers/testDataReader.py similarity index 100% rename from tests/unit/data_readers/testDataReader.py rename to tests/unit/readers/testDataReader.py diff --git a/tests/unit/readers/testDeepChemDataReader.py b/tests/unit/readers/testDeepChemDataReader.py new file mode 100644 index 00000000..c93e2592 --- /dev/null +++ b/tests/unit/readers/testDeepChemDataReader.py @@ -0,0 +1,80 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +from chebai.preprocessing.reader import EMBEDDING_OFFSET, DeepChemDataReader + + +class TestDeepChemDataReader(unittest.TestCase): + """ + Unit tests for the DeepChemDataReader class. + + Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order. + """ + + @patch( + "chebai.preprocessing.reader.open", + new_callable=mock_open, + read_data="C\nO\nc\n)", + ) + def setUp(self, mock_file: mock_open) -> None: + """ + Set up the test environment by initializing a DeepChemDataReader instance with a mocked token file. + + Args: + mock_file: Mock object for file operations. + """ + self.reader = DeepChemDataReader(token_path="/mock/path") + # After initializing, self.reader.cache should now be set to ['C', 'O', 'c', ')'] + self.assertEqual(self.reader.cache, ["C", "O", "c", ")"]) + + def test_read_data(self) -> None: + """ + Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string. + """ + raw_data = "c1ccccc1C(Br)(OC)I[Ni-2]" + + # Expected output as per the tokens already in the cache, and new tokens getting added to it. + expected_output: List[int] = [ + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + len(self.reader.cache), # 6 (new token) + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + len(self.reader.cache) + 1, # Br (new token) + EMBEDDING_OFFSET + 3, # ) + EMBEDDING_OFFSET + 1, # O + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 3, # ) + EMBEDDING_OFFSET + 3, # ) + EMBEDDING_OFFSET + len(self.reader.cache) + 2, # I (new token) + EMBEDDING_OFFSET + len(self.reader.cache) + 3, # [Ni-2] (new token) + ] + result = self.reader._read_data(raw_data) + self.assertEqual(result, expected_output) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a SMILES string that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "[H-]" + + # Determine the index for the new token based on the current size of the cache. 
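+        # (Id scheme assumed throughout these reader tests: token_id = EMBEDDING_OFFSET + cache_index,
+        # where EMBEDDING_OFFSET presumably reserves the lowest ids for special embeddings such as
+        # padding; unseen tokens are appended to the cache, so the next new token lands at len(cache).)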
+        index_for_last_token = len(self.reader.cache)
+        expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token]
+
+        result = self.reader._read_data(raw_data)
+        self.assertEqual(result, expected_output)
+
+        # Verify that '[H-]' was added to the cache
+        self.assertIn("[H-]", self.reader.cache)
+        # Ensure it's at the correct index
+        self.assertEqual(self.reader.cache.index("[H-]"), index_for_last_token)
+
+
+if __name__ == "__main__":
+    unittest.main()

From aa467c6fde67a9545b23c79132c128d0a837b69e Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 30 Aug 2024 00:06:39 +0200
Subject: [PATCH 05/46] Test for SelfiesReader

---
 tests/unit/readers/testDeepChemDataReader.py | 3 +
 tests/unit/readers/testSelfiesReader.py | 106 +++++++++++++++++++
 2 files changed, 109 insertions(+)
 create mode 100644 tests/unit/readers/testSelfiesReader.py

diff --git a/tests/unit/readers/testDeepChemDataReader.py b/tests/unit/readers/testDeepChemDataReader.py
index c93e2592..ac1a50b7 100644
--- a/tests/unit/readers/testDeepChemDataReader.py
+++ b/tests/unit/readers/testDeepChemDataReader.py
@@ -34,6 +34,9 @@ def test_read_data(self) -> None:
         """
         raw_data = "c1ccccc1C(Br)(OC)I[Ni-2]"
 
+        # Benzene is c1ccccc1 in SMILES but cccccc6 in DeepSMILES.
+        # The SMILES C(Br)(OC)I can be converted to the DeepSMILES CBr)OC))I.
+        # Resulting string: "cccccc6CBr)OC))I[Ni-2]"
         # Expected output as per the tokens already in the cache, and new tokens getting added to it.
         expected_output: List[int] = [
             EMBEDDING_OFFSET + 2,  # c
diff --git a/tests/unit/readers/testSelfiesReader.py b/tests/unit/readers/testSelfiesReader.py
new file mode 100644
index 00000000..41202757
--- /dev/null
+++ b/tests/unit/readers/testSelfiesReader.py
@@ -0,0 +1,106 @@
+import unittest
+from typing import List
+from unittest.mock import mock_open, patch
+
+from chebai.preprocessing.reader import EMBEDDING_OFFSET, SelfiesReader
+
+
+class TestSelfiesReader(unittest.TestCase):
+    """
+    Unit tests for the SelfiesReader class.
+
+    Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order.
+    """
+
+    @patch(
+        "chebai.preprocessing.reader.open",
+        new_callable=mock_open,
+        read_data="[C]\n[O]\n[=C]",
+    )
+    def setUp(self, mock_file: mock_open) -> None:
+        """
+        Set up the test environment by initializing a SelfiesReader instance with a mocked token file.
+
+        Args:
+            mock_file: Mock object for file operations.
+        """
+        self.reader = SelfiesReader(token_path="/mock/path")
+        # After initializing, self.reader.cache should now be set to ['[C]', '[O]', '[=C]']
+        self.assertEqual(
+            self.reader.cache,
+            [
+                "[C]",
+                "[O]",
+                "[=C]",
+            ],
+        )
+
+    def test_read_data(self) -> None:
+        """
+        Test the _read_data method with a SELFIES string to ensure it correctly tokenizes the string.
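+        The raw SMILES input is first translated to SELFIES; the expected translation is spelled out in the comments below.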
+ """ + raw_data = "c1ccccc1C(Br)(OC)I[Ni-2]" + + # benzene is "c1ccccc1" in SMILES is translated to "[C][=C][C][=C][C][=C][Ring1][=Branch1]" in SELFIES + # SELFIES translation of SMILES "c1ccccc1C(Br)(OC)I[Ni-2]": + # "[C][=C][C][=C][C][=C][Ring1][=Branch1][C][Branch1][C][Br][Branch1][Ring1][O][C][I][Ni-2]" + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + 2, # [=C] (already in cache) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + 2, # [=C] (already in cache) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + 2, # [=C] (already in cache) + EMBEDDING_OFFSET + len(self.reader.cache), # [Ring1] (new token) + EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [=Branch1] (new token) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + len(self.reader.cache) + 2, # [Branch1] (new token) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + len(self.reader.cache) + 3, # [Br] (new token) + EMBEDDING_OFFSET + + len(self.reader.cache) + + 2, # [Branch1] (reused new token) + EMBEDDING_OFFSET + len(self.reader.cache), # [Ring1] (reused new token) + EMBEDDING_OFFSET + 1, # [O] (already in cache) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + len(self.reader.cache) + 4, # [I] (new token) + EMBEDDING_OFFSET + len(self.reader.cache) + 5, # [Ni-2] (new token) + ] + + result = self.reader._read_data(raw_data) + self.assertEqual(result, expected_output) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a SELFIES string that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "[H-]" + + # Determine the index for the new token based on the current size of the cache. + index_for_last_token = len(self.reader.cache) + expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] + + result = self.reader._read_data(raw_data) + self.assertEqual(result, expected_output) + + # Verify that '[H-1]' was added to the cache, "[H-]" translated to "[H-1]" in SELFIES + self.assertIn("[H-1]", self.reader.cache) + # Ensure it's at the correct index + self.assertEqual(self.reader.cache.index("[H-1]"), index_for_last_token) + + def test_read_data_with_invalid_selfies(self) -> None: + """ + Test the _read_data method with an invalid SELFIES string to ensure error handling works. + """ + raw_data = "[C][O][INVALID][N]" + + result = self.reader._read_data(raw_data) + self.assertIsNone(result) + + # Verify that the error count was incremented + self.assertEqual(self.reader.error_count, 1) + + +if __name__ == "__main__": + unittest.main() From b6f5e5162d22359a67fa212c288b13715fd51356 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 30 Aug 2024 23:14:49 +0200 Subject: [PATCH 06/46] test for ProteinDataReader --- tests/unit/readers/testProteinDataReader.py | 105 ++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 tests/unit/readers/testProteinDataReader.py diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py new file mode 100644 index 00000000..5f828e75 --- /dev/null +++ b/tests/unit/readers/testProteinDataReader.py @@ -0,0 +1,105 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +from chebai.preprocessing.reader import EMBEDDING_OFFSET, ProteinDataReader + + +class TestProteinDataReader(unittest.TestCase): + """ + Unit tests for the ProteinDataReader class. 
+ """ + + @patch( + "chebai.preprocessing.reader.open", + new_callable=mock_open, + read_data="M\nK\nT\nF\nR\nN", + ) + def setUp(self, mock_file: mock_open) -> None: + """ + Set up the test environment by initializing a ProteinDataReader instance with a mocked token file. + + Args: + mock_file: Mock object for file operations. + """ + self.reader = ProteinDataReader(token_path="/mock/path") + # After initializing, self.reader.cache should now be set to ['M', 'K', 'T', 'F', 'R', 'N'] + self.assertEqual(self.reader.cache, ["M", "K", "T", "F", "R", "N"]) + + def test_read_data(self) -> None: + """ + Test the _read_data method with a protein sequence to ensure it correctly tokenizes the sequence. + """ + raw_data = "MKTFFRN" + + # Expected output based on the cached tokens + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # M + EMBEDDING_OFFSET + 1, # K + EMBEDDING_OFFSET + 2, # T + EMBEDDING_OFFSET + 3, # F + EMBEDDING_OFFSET + 3, # F (repeated token) + EMBEDDING_OFFSET + 4, # R + EMBEDDING_OFFSET + 5, # N + ] + result = self.reader._read_data(raw_data) + self.assertEqual(result, expected_output) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a protein sequence that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "MKTFY" + + # 'Y' is not in the initial cache and should be added. + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # M + EMBEDDING_OFFSET + 1, # K + EMBEDDING_OFFSET + 2, # T + EMBEDDING_OFFSET + 3, # F + EMBEDDING_OFFSET + len(self.reader.cache), # Y (new token) + ] + + result = self.reader._read_data(raw_data) + self.assertEqual(result, expected_output) + + # Verify that 'Y' was added to the cache + self.assertIn("Y", self.reader.cache) + # Ensure it's at the correct index + self.assertEqual(self.reader.cache.index("Y"), len(self.reader.cache) - 1) + + def test_read_data_with_invalid_token(self) -> None: + """ + Test the _read_data method with an invalid amino acid token to ensure it raises a KeyError. + """ + raw_data = "MKTFZ" # 'Z' is not a valid amino acid token + + with self.assertRaises(KeyError) as context: + self.reader._read_data(raw_data) + + self.assertIn("Invalid token 'Z' encountered", str(context.exception)) + + def test_read_data_with_empty_sequence(self) -> None: + """ + Test the _read_data method with an empty protein sequence to ensure it returns an empty list. + """ + raw_data = "" + + result = self.reader._read_data(raw_data) + self.assertEqual(result, []) + + def test_read_data_with_repeated_tokens(self) -> None: + """ + Test the _read_data method with repeated amino acid tokens to ensure it handles them correctly. 
+ """ + raw_data = "MMMMM" + + expected_output: List[int] = [EMBEDDING_OFFSET + 0] * 5 # All tokens are 'M' + + result = self.reader._read_data(raw_data) + self.assertEqual(result, expected_output) + + +if __name__ == "__main__": + unittest.main() From 73f05c01f81c90107eccb61c638529755b05df15 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 31 Aug 2024 00:03:21 +0200 Subject: [PATCH 07/46] test for DefaultCollator --- tests/unit/collators/testDefaultCollator.py | 52 +++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 tests/unit/collators/testDefaultCollator.py diff --git a/tests/unit/collators/testDefaultCollator.py b/tests/unit/collators/testDefaultCollator.py new file mode 100644 index 00000000..6362d0a6 --- /dev/null +++ b/tests/unit/collators/testDefaultCollator.py @@ -0,0 +1,52 @@ +import unittest +from typing import Dict, List + +from chebai.preprocessing.collate import DefaultCollator +from chebai.preprocessing.structures import XYData + + +class TestDefaultCollator(unittest.TestCase): + """ + Unit tests for the DefaultCollator class. + """ + + def setUp(self) -> None: + """ + Set up the test environment by initializing a DefaultCollator instance. + """ + self.collator = DefaultCollator() + + def test_call_with_valid_data(self) -> None: + """ + Test the __call__ method with valid data to ensure features and labels are correctly extracted. + """ + data: List[Dict] = [ + {"features": [1.0, 2.0], "labels": 0}, + {"features": [3.0, 4.0], "labels": 1}, + ] + + result: XYData = self.collator(data) + self.assertIsInstance(result, XYData) + + expected_x = ([1.0, 2.0], [3.0, 4.0]) + expected_y = (0, 1) + + self.assertEqual(result.x, expected_x) + self.assertEqual(result.y, expected_y) + + def test_call_with_empty_data(self) -> None: + """ + Test the __call__ method with an empty list to ensure it handles the edge case correctly. + """ + data: List[Dict] = [] + + with self.assertRaises(ValueError) as context: + self.collator(data) + + self.assertEqual( + str(context.exception), "not enough values to unpack (expected 2, got 0)" + ) + + +if __name__ == "__main__": + unittest.main() From 8007f37f7622168fa3db1837e5b7fafcb8307a5e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 31 Aug 2024 22:05:16 +0200 Subject: [PATCH 08/46] test for RaggedColllator --- tests/unit/collators/testRaggedCollator.py | 150 +++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 tests/unit/collators/testRaggedCollator.py diff --git a/tests/unit/collators/testRaggedCollator.py b/tests/unit/collators/testRaggedCollator.py new file mode 100644 index 00000000..97e1c08f --- /dev/null +++ b/tests/unit/collators/testRaggedCollator.py @@ -0,0 +1,150 @@ +import unittest +from typing import Dict, List, Tuple + +import torch + +from chebai.preprocessing.collate import RaggedCollator +from chebai.preprocessing.structures import XYData + + +class TestRaggedCollator(unittest.TestCase): + """ + Unit tests for the RaggedCollator class. + """ + + def setUp(self) -> None: + """ + Set up the test environment by initializing a RaggedCollator instance. + """ + self.collator = RaggedCollator() + + def test_call_with_valid_data(self) -> None: + """ + Test the __call__ method with valid ragged data to ensure features, labels, and masks are correctly handled. 
+ """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [1, 0], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": [0, 1, 1], "ident": "sample2"}, + {"features": [6], "labels": [1], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) + expected_y = torch.tensor([[1, 0, 0], [0, 1, 1], [1, 0, 0]]) + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) + + self.assertTrue(torch.equal(result.x, expected_x)) + self.assertTrue(torch.equal(result.y, expected_y)) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ) + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ) + ) + self.assertEqual( + result.additional_fields["idents"], ("sample1", "sample2", "sample3") + ) + + def test_call_with_missing_entire_labels(self) -> None: + """ + Test the __call__ method with data where some samples are missing labels. + """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [1, 0], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": None, "ident": "sample2"}, + {"features": [6], "labels": [1], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + expected_x = torch.tensor([[1, 2], [6, 0]]) + expected_y = torch.tensor([[1, 0], [1, 0]]) + expected_mask_for_x = torch.tensor([[True, True], [True, False]]) + expected_lens_for_x = torch.tensor([2, 1]) + + self.assertTrue(torch.equal(result.x, expected_x)) + self.assertTrue(torch.equal(result.y, expected_y)) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ) + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ) + ) + self.assertEqual( + result.additional_fields["loss_kwargs"]["non_null_labels"], [0, 2] + ) + self.assertEqual( + result.additional_fields["idents"], ("sample1", "sample2", "sample3") + ) + + def test_call_with_none_in_labels(self) -> None: + """ + Test the __call__ method with data where one of the elements in the labels is None. + """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [None, 1], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": [1, 0], "ident": "sample2"}, + {"features": [6], "labels": [1], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) + expected_y = torch.tensor([[0, 1], [1, 0], [1, 0]]) # None is replaced by 0 + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) + + self.assertTrue(torch.equal(result.x, expected_x)) + self.assertTrue(torch.equal(result.y, expected_y)) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ) + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ) + ) + self.assertEqual( + result.additional_fields["idents"], ("sample1", "sample2", "sample3") + ) + + def test_call_with_empty_data(self) -> None: + """ + Test the __call__ method with an empty list to ensure it raises an error. 
+ """ + data: List[Dict] = [] + + with self.assertRaises(Exception): + self.collator(data) + + def test_process_label_rows(self) -> None: + """ + Test the process_label_rows method to ensure it pads label sequences correctly. + """ + labels: Tuple = ([1, 0], [0, 1, 1], [1]) + + result: torch.Tensor = self.collator.process_label_rows(labels) + + expected_output = torch.tensor([[1, 0, 0], [0, 1, 1], [1, 0, 0]]) + + self.assertTrue(torch.equal(result, expected_output)) + + +if __name__ == "__main__": + unittest.main() From 248eaa7034ac2aa204d578c2c249096ee07dbd83 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 31 Aug 2024 23:55:52 +0200 Subject: [PATCH 09/46] modify tests to use `setUpClass` class method instead of `setUp` instance method --- tests/unit/collators/testDefaultCollator.py | 5 +++-- tests/unit/collators/testRaggedCollator.py | 5 +++-- tests/unit/readers/testChemDataReader.py | 9 +++++---- tests/unit/readers/testDataReader.py | 5 +++-- tests/unit/readers/testDeepChemDataReader.py | 9 +++++---- tests/unit/readers/testProteinDataReader.py | 9 +++++---- tests/unit/readers/testSelfiesReader.py | 16 +++++----------- 7 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tests/unit/collators/testDefaultCollator.py b/tests/unit/collators/testDefaultCollator.py index 6362d0a6..287cadcd 100644 --- a/tests/unit/collators/testDefaultCollator.py +++ b/tests/unit/collators/testDefaultCollator.py @@ -10,11 +10,12 @@ class TestDefaultCollator(unittest.TestCase): Unit tests for the DefaultCollator class. """ - def setUp(self) -> None: + @classmethod + def setUpClass(cls) -> None: """ Set up the test environment by initializing a DefaultCollator instance. """ - self.collator = DefaultCollator() + cls.collator = DefaultCollator() def test_call_with_valid_data(self) -> None: """ diff --git a/tests/unit/collators/testRaggedCollator.py b/tests/unit/collators/testRaggedCollator.py index 97e1c08f..a3126314 100644 --- a/tests/unit/collators/testRaggedCollator.py +++ b/tests/unit/collators/testRaggedCollator.py @@ -12,11 +12,12 @@ class TestRaggedCollator(unittest.TestCase): Unit tests for the RaggedCollator class. """ - def setUp(self) -> None: + @classmethod + def setUpClass(cls) -> None: """ Set up the test environment by initializing a RaggedCollator instance. """ - self.collator = RaggedCollator() + cls.collator = RaggedCollator() def test_call_with_valid_data(self) -> None: """ diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py index 2bc525e1..3d7b5e6f 100644 --- a/tests/unit/readers/testChemDataReader.py +++ b/tests/unit/readers/testChemDataReader.py @@ -12,21 +12,22 @@ class TestChemDataReader(unittest.TestCase): Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order. """ + @classmethod @patch( "chebai.preprocessing.reader.open", new_callable=mock_open, read_data="C\nO\nN\n=\n1\n(", ) - def setUp(self, mock_file: mock_open) -> None: + def setUpClass(cls, mock_file: mock_open) -> None: """ Set up the test environment by initializing a ChemDataReader instance with a mocked token file. Args: mock_file: Mock object for file operations. 
""" - self.reader = ChemDataReader(token_path="/mock/path") - # After initializing, self.reader.cache should now be set to ['C', 'O', 'N', '=', '1', '('] - self.assertEqual(self.reader.cache, ["C", "O", "N", "=", "1", "("]) + cls.reader = ChemDataReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['C', 'O', 'N', '=', '1', '('] + assert cls.reader.cache == ["C", "O", "N", "=", "1", "("] def test_read_data(self) -> None: """ diff --git a/tests/unit/readers/testDataReader.py b/tests/unit/readers/testDataReader.py index 1a511b26..8a8af053 100644 --- a/tests/unit/readers/testDataReader.py +++ b/tests/unit/readers/testDataReader.py @@ -9,11 +9,12 @@ class TestDataReader(unittest.TestCase): Unit tests for the DataReader class. """ - def setUp(self) -> None: + @classmethod + def setUpClass(cls) -> None: """ Set up the test environment by initializing a DataReader instance. """ - self.reader = DataReader() + cls.reader = DataReader() def test_to_data(self) -> None: """ diff --git a/tests/unit/readers/testDeepChemDataReader.py b/tests/unit/readers/testDeepChemDataReader.py index ac1a50b7..23ac35d5 100644 --- a/tests/unit/readers/testDeepChemDataReader.py +++ b/tests/unit/readers/testDeepChemDataReader.py @@ -12,21 +12,22 @@ class TestDeepChemDataReader(unittest.TestCase): Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order. """ + @classmethod @patch( "chebai.preprocessing.reader.open", new_callable=mock_open, read_data="C\nO\nc\n)", ) - def setUp(self, mock_file: mock_open) -> None: + def setUpClass(cls, mock_file: mock_open) -> None: """ Set up the test environment by initializing a DeepChemDataReader instance with a mocked token file. Args: mock_file: Mock object for file operations. """ - self.reader = DeepChemDataReader(token_path="/mock/path") - # After initializing, self.reader.cache should now be set to ['C', 'O', 'c', ')'] - self.assertEqual(self.reader.cache, ["C", "O", "c", ")"]) + cls.reader = DeepChemDataReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['C', 'O', 'c', ')'] + assert cls.reader.cache == ["C", "O", "c", ")"] def test_read_data(self) -> None: """ diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py index 5f828e75..6e5f325c 100644 --- a/tests/unit/readers/testProteinDataReader.py +++ b/tests/unit/readers/testProteinDataReader.py @@ -10,21 +10,22 @@ class TestProteinDataReader(unittest.TestCase): Unit tests for the ProteinDataReader class. """ + @classmethod @patch( "chebai.preprocessing.reader.open", new_callable=mock_open, read_data="M\nK\nT\nF\nR\nN", ) - def setUp(self, mock_file: mock_open) -> None: + def setUpClass(cls, mock_file: mock_open) -> None: """ Set up the test environment by initializing a ProteinDataReader instance with a mocked token file. Args: mock_file: Mock object for file operations. 
""" - self.reader = ProteinDataReader(token_path="/mock/path") - # After initializing, self.reader.cache should now be set to ['M', 'K', 'T', 'F', 'R', 'N'] - self.assertEqual(self.reader.cache, ["M", "K", "T", "F", "R", "N"]) + cls.reader = ProteinDataReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['M', 'K', 'T', 'F', 'R', 'N'] + assert cls.reader.cache == ["M", "K", "T", "F", "R", "N"] def test_read_data(self) -> None: """ diff --git a/tests/unit/readers/testSelfiesReader.py b/tests/unit/readers/testSelfiesReader.py index 41202757..019a0f59 100644 --- a/tests/unit/readers/testSelfiesReader.py +++ b/tests/unit/readers/testSelfiesReader.py @@ -12,28 +12,22 @@ class TestSelfiesReader(unittest.TestCase): Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order. """ + @classmethod @patch( "chebai.preprocessing.reader.open", new_callable=mock_open, read_data="[C]\n[O]\n[=C]", ) - def setUp(self, mock_file: mock_open) -> None: + def setUpClass(cls, mock_file: mock_open) -> None: """ Set up the test environment by initializing a SelfiesReader instance with a mocked token file. Args: mock_file: Mock object for file operations. """ - self.reader = SelfiesReader(token_path="/mock/path") - # After initializing, self.reader.cache should now be set to ['[C]', '[O]', '[N]', '[=]', '[1]', '[('] - self.assertEqual( - self.reader.cache, - [ - "[C]", - "[O]", - "[=C]", - ], - ) + cls.reader = SelfiesReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['[C]', '[O]', '[N]', '[=]', '[1]', '[('] + assert cls.reader.cache == ["[C]", "[O]", "[=C]"] def test_read_data(self) -> None: """ From 3e57d78420ec8b1076b5e5842c535b03b212da8a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 1 Sep 2024 13:33:18 +0200 Subject: [PATCH 10/46] bool labels instead of numeric, for realistic data --- tests/unit/collators/testDefaultCollator.py | 6 ++-- tests/unit/collators/testRaggedCollator.py | 34 +++++++++++++-------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/tests/unit/collators/testDefaultCollator.py b/tests/unit/collators/testDefaultCollator.py index 287cadcd..29b1cc91 100644 --- a/tests/unit/collators/testDefaultCollator.py +++ b/tests/unit/collators/testDefaultCollator.py @@ -22,15 +22,15 @@ def test_call_with_valid_data(self) -> None: Test the __call__ method with valid data to ensure features and labels are correctly extracted. """ data: List[Dict] = [ - {"features": [1.0, 2.0], "labels": 0}, - {"features": [3.0, 4.0], "labels": 1}, + {"features": [1.0, 2.0], "labels": [True, False, True]}, + {"features": [3.0, 4.0], "labels": [False, False, True]}, ] result: XYData = self.collator(data) self.assertIsInstance(result, XYData) expected_x = ([1.0, 2.0], [3.0, 4.0]) - expected_y = (0, 1) + expected_y = ([True, False, True], [False, False, True]) self.assertEqual(result.x, expected_x) self.assertEqual(result.y, expected_y) diff --git a/tests/unit/collators/testRaggedCollator.py b/tests/unit/collators/testRaggedCollator.py index a3126314..81947b47 100644 --- a/tests/unit/collators/testRaggedCollator.py +++ b/tests/unit/collators/testRaggedCollator.py @@ -24,15 +24,17 @@ def test_call_with_valid_data(self) -> None: Test the __call__ method with valid ragged data to ensure features, labels, and masks are correctly handled. 
""" data: List[Dict] = [ - {"features": [1, 2], "labels": [1, 0], "ident": "sample1"}, - {"features": [3, 4, 5], "labels": [0, 1, 1], "ident": "sample2"}, - {"features": [6], "labels": [1], "ident": "sample3"}, + {"features": [1, 2], "labels": [True, False], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": [False, True, True], "ident": "sample2"}, + {"features": [6], "labels": [True], "ident": "sample3"}, ] result: XYData = self.collator(data) expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) - expected_y = torch.tensor([[1, 0, 0], [0, 1, 1], [1, 0, 0]]) + expected_y = torch.tensor( + [[True, False, False], [False, True, True], [True, False, False]] + ) expected_mask_for_x = torch.tensor( [[True, True, False], [True, True, True], [True, False, False]] ) @@ -59,15 +61,17 @@ def test_call_with_missing_entire_labels(self) -> None: Test the __call__ method with data where some samples are missing labels. """ data: List[Dict] = [ - {"features": [1, 2], "labels": [1, 0], "ident": "sample1"}, + {"features": [1, 2], "labels": [True, False], "ident": "sample1"}, {"features": [3, 4, 5], "labels": None, "ident": "sample2"}, - {"features": [6], "labels": [1], "ident": "sample3"}, + {"features": [6], "labels": [True], "ident": "sample3"}, ] result: XYData = self.collator(data) expected_x = torch.tensor([[1, 2], [6, 0]]) - expected_y = torch.tensor([[1, 0], [1, 0]]) + expected_y = torch.tensor( + [[True, False], [True, False]] + ) # True -> 1, False -> 0 expected_mask_for_x = torch.tensor([[True, True], [True, False]]) expected_lens_for_x = torch.tensor([2, 1]) @@ -95,15 +99,17 @@ def test_call_with_none_in_labels(self) -> None: Test the __call__ method with data where one of the elements in the labels is None. """ data: List[Dict] = [ - {"features": [1, 2], "labels": [None, 1], "ident": "sample1"}, - {"features": [3, 4, 5], "labels": [1, 0], "ident": "sample2"}, - {"features": [6], "labels": [1], "ident": "sample3"}, + {"features": [1, 2], "labels": [None, True], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": [True, False], "ident": "sample2"}, + {"features": [6], "labels": [True], "ident": "sample3"}, ] result: XYData = self.collator(data) expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) - expected_y = torch.tensor([[0, 1], [1, 0], [1, 0]]) # None is replaced by 0 + expected_y = torch.tensor( + [[False, True], [True, False], [True, False]] + ) # None -> False expected_mask_for_x = torch.tensor( [[True, True, False], [True, True, True], [True, False, False]] ) @@ -138,11 +144,13 @@ def test_process_label_rows(self) -> None: """ Test the process_label_rows method to ensure it pads label sequences correctly. 
""" - labels: Tuple = ([1, 0], [0, 1, 1], [1]) + labels: Tuple = ([True, False], [False, True, True], [True]) result: torch.Tensor = self.collator.process_label_rows(labels) - expected_output = torch.tensor([[1, 0, 0], [0, 1, 1], [1, 0, 0]]) + expected_output = torch.tensor( + [[True, False, False], [False, True, True], [True, False, False]] + ) self.assertTrue(torch.equal(result, expected_output)) From f9ca653d76b9a8434b1a1a487ee57b796156b40a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 1 Sep 2024 13:33:51 +0200 Subject: [PATCH 11/46] test for XYBaseDataModule --- .../dataset_classes/testXYBaseDataModule.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/unit/dataset_classes/testXYBaseDataModule.py diff --git a/tests/unit/dataset_classes/testXYBaseDataModule.py b/tests/unit/dataset_classes/testXYBaseDataModule.py new file mode 100644 index 00000000..d8aabc67 --- /dev/null +++ b/tests/unit/dataset_classes/testXYBaseDataModule.py @@ -0,0 +1,76 @@ +import unittest +from unittest.mock import PropertyMock, patch + +from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.reader import ProteinDataReader + + +class TestXYBaseDataModule(unittest.TestCase): + """ + Unit tests for the methods of the XYBaseDataModule class. + """ + + @classmethod + @patch.object(XYBaseDataModule, "_name", new_callable=PropertyMock) + def setUpClass(cls, mock_name_property) -> None: + """ + Set up a base instance of XYBaseDataModule for testing. + """ + + # Mock the _name property of XYBaseDataModule + mock_name_property.return_value = "MockedXYBaseDataModule" + + # Assign a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) + XYBaseDataModule.READER = ProteinDataReader + + # Initialize the module with a label_filter + cls.module = XYBaseDataModule( + label_filter=1, # Provide a label_filter + balance_after_filter=1.0, # Balance ratio + ) + + def test_filter_labels_valid_index(self) -> None: + """ + Test the _filter_labels method with a valid label_filter index. + """ + self.module.label_filter = 1 + row = { + "features": ["feature1", "feature2"], + "labels": [0, 3, 1, 2], # List of labels + } + filtered_row = self.module._filter_labels(row) + expected_labels = [3] # Only the label at index 1 should be kept + + self.assertEqual(filtered_row["labels"], expected_labels) + + row = { + "features": ["feature1", "feature2"], + "labels": [True, False, True, True], + } + self.assertEqual(self.module._filter_labels(row)["labels"], [False]) + + def test_filter_labels_no_filter(self) -> None: + """ + Test the _filter_labels method with no label_filter index. + """ + # Update the module to have no label filter + self.module.label_filter = None + row = {"features": ["feature1", "feature2"], "labels": [False, True]} + # Handle the case where the index is out of bounds + with self.assertRaises(TypeError): + self.module._filter_labels(row) + + def test_filter_labels_invalid_index(self) -> None: + """ + Test the _filter_labels method with an invalid label_filter index. 
+ """ + # Set an invalid label filter index (e.g., greater than the number of labels) + self.module.label_filter = 10 + row = {"features": ["feature1", "feature2"], "labels": [False, True]} + # Handle the case where the index is out of bounds + with self.assertRaises(IndexError): + self.module._filter_labels(row) + + +if __name__ == "__main__": + unittest.main() From d8016aa6459548f8981c43473706a80c9748fca9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 2 Sep 2024 00:25:58 +0200 Subject: [PATCH 12/46] test for DynamicDataset --- .../dataset_classes/testDynamicDataset.py | 231 ++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 tests/unit/dataset_classes/testDynamicDataset.py diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py new file mode 100644 index 00000000..ae69952a --- /dev/null +++ b/tests/unit/dataset_classes/testDynamicDataset.py @@ -0,0 +1,231 @@ +import unittest +from typing import Tuple +from unittest.mock import PropertyMock, patch + +import pandas as pd + +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.reader import ProteinDataReader + + +class TestDynamicDataset(unittest.TestCase): + """ + Test case for _DynamicDataset functionality, ensuring correct data splits and integrity + of train, validation, and test datasets. + """ + + @classmethod + @patch.multiple(_DynamicDataset, __abstractmethods__=frozenset()) + @patch.object(_DynamicDataset, "base_dir", new_callable=PropertyMock) + @patch.object(_DynamicDataset, "_name", new_callable=PropertyMock) + def setUpClass( + cls, mock_base_dir_property: PropertyMock, mock_name_property: PropertyMock + ) -> None: + """ + Set up a base instance of _DynamicDataset for testing with mocked properties. + """ + + # Mocking properties + mock_base_dir_property.return_value = "MockedBaseDirProperty" + mock_name_property.return_value = "MockedNameProperty" + + # Assigning a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) + _DynamicDataset.READER = ProteinDataReader + + # Creating an instance of the dataset + cls.dataset: _DynamicDataset = _DynamicDataset() + + # Dataset with a balanced distribution of labels + X = [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + [13, 14], + [15, 16], + [17, 18], + [19, 20], + [21, 22], + [23, 24], + [25, 26], + [27, 28], + [29, 30], + [31, 32], + ] + y = [ + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + ] + cls.df = pd.DataFrame( + {"ident": [f"id{i + 1}" for i in range(len(X))], "features": X, "labels": y} + ) + + def test_get_test_split_valid(self) -> None: + """ + Test splitting the dataset into train and test sets and verify balance and non-overlap. 
+ """ + self.dataset.train_split = 0.5 + # Test size will be 0.25 * 16 = 4 + train_df, test_df = self.dataset.get_test_split(self.df, seed=0) + + # Assert the correct number of rows in train and test sets + self.assertEqual(len(train_df), 12, "Train set should contain 12 samples.") + self.assertEqual(len(test_df), 4, "Test set should contain 4 samples.") + + # Check positive and negative label counts in train and test sets + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( + test_df + ) + + # Ensure that the train and test sets have balanced positives and negatives + self.assertEqual( + train_pos_count, train_neg_count, "Train set labels should be balanced." + ) + self.assertEqual( + test_pos_count, test_neg_count, "Test set labels should be balanced." + ) + + # Assert there is no overlap between train and test sets + train_idents = set(train_df["ident"]) + test_idents = set(test_df["ident"]) + self.assertEqual( + len(train_idents.intersection(test_idents)), + 0, + "Train and test sets should not overlap.", + ) + + def test_get_test_split_missing_labels(self) -> None: + """ + Test the behavior when the 'labels' column is missing in the dataset. + """ + df_missing_labels = pd.DataFrame({"ident": ["id1", "id2"]}) + with self.assertRaises( + KeyError, msg="Expected KeyError when 'labels' column is missing." + ): + self.dataset.get_test_split(df_missing_labels) + + def test_get_test_split_seed_consistency(self) -> None: + """ + Test that splitting the dataset with the same seed produces consistent results. + """ + train_df1, test_df1 = self.dataset.get_test_split(self.df, seed=42) + train_df2, test_df2 = self.dataset.get_test_split(self.df, seed=42) + + pd.testing.assert_frame_equal( + train_df1, + train_df2, + obj="Train sets should be identical for the same seed.", + ) + pd.testing.assert_frame_equal( + test_df1, test_df2, obj="Test sets should be identical for the same seed." + ) + + def test_get_train_val_splits_given_test(self) -> None: + """ + Test splitting the dataset into train and validation sets and verify balance and non-overlap. 
+ """ + self.dataset.use_inner_cross_validation = False + self.dataset.train_split = 0.5 + df_train_main, test_df = self.dataset.get_test_split(self.df, seed=0) + train_df, val_df = self.dataset.get_train_val_splits_given_test( + df_train_main, test_df, seed=42 + ) + + # Ensure there is no overlap between train and test sets + train_idents = set(train_df["ident"]) + test_idents = set(test_df["ident"]) + self.assertEqual( + len(train_idents.intersection(test_idents)), + 0, + "Train and test sets should not overlap.", + ) + + # Ensure there is no overlap between validation and test sets + val_idents = set(val_df["ident"]) + self.assertEqual( + len(val_idents.intersection(test_idents)), + 0, + "Validation and test sets should not overlap.", + ) + + # Ensure there is no overlap between train and validation sets + self.assertEqual( + len(train_idents.intersection(val_idents)), + 0, + "Train and validation sets should not overlap.", + ) + + # Check positive and negative label counts in train and validation sets + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) + + # Ensure that the train and validation sets have balanced positives and negatives + self.assertEqual( + train_pos_count, train_neg_count, "Train set labels should be balanced." + ) + self.assertEqual( + val_pos_count, val_neg_count, "Validation set labels should be balanced." + ) + + def test_get_train_val_splits_given_test_consistency(self) -> None: + """ + Test that splitting the dataset into train and validation sets with the same seed produces consistent results. + """ + test_df = self.df.iloc[12:] # Assume rows 12 onward are for testing + train_df1, val_df1 = self.dataset.get_train_val_splits_given_test( + self.df, test_df, seed=42 + ) + train_df2, val_df2 = self.dataset.get_train_val_splits_given_test( + self.df, test_df, seed=42 + ) + + pd.testing.assert_frame_equal( + train_df1, + train_df2, + obj="Train sets should be identical for the same seed.", + ) + pd.testing.assert_frame_equal( + val_df1, + val_df2, + obj="Validation sets should be identical for the same seed.", + ) + + @staticmethod + def get_positive_negative_labels_counts(df: pd.DataFrame) -> Tuple[int, int]: + """ + Count the number of True and False values within the labels column. + + Args: + df (pd.DataFrame): The DataFrame containing the 'labels' column. + + Returns: + Tuple[int, int]: A tuple containing the counts of True and False values, respectively. 
+ """ + true_count = sum(sum(label) for label in df["labels"]) + false_count = sum(len(label) - sum(label) for label in df["labels"]) + return true_count, false_count + + +if __name__ == "__main__": + unittest.main() From 0c7c5b8fab7612bbcfc8c7feba8d07d7b797a3d9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 2 Sep 2024 00:53:59 +0200 Subject: [PATCH 13/46] add relevant msg to each assert statement --- tests/unit/collators/testDefaultCollator.py | 20 ++++- tests/unit/collators/testRaggedCollator.py | 73 ++++++++++++++----- .../dataset_classes/testDynamicDataset.py | 4 +- .../dataset_classes/testXYBaseDataModule.py | 25 +++++-- tests/unit/readers/testChemDataReader.py | 33 +++++++-- tests/unit/readers/testDataReader.py | 6 +- tests/unit/readers/testDeepChemDataReader.py | 31 ++++++-- tests/unit/readers/testProteinDataReader.py | 49 +++++++++++-- tests/unit/readers/testSelfiesReader.py | 43 +++++++++-- 9 files changed, 227 insertions(+), 57 deletions(-) diff --git a/tests/unit/collators/testDefaultCollator.py b/tests/unit/collators/testDefaultCollator.py index 29b1cc91..73f09c75 100644 --- a/tests/unit/collators/testDefaultCollator.py +++ b/tests/unit/collators/testDefaultCollator.py @@ -27,13 +27,23 @@ def test_call_with_valid_data(self) -> None: ] result: XYData = self.collator(data) - self.assertIsInstance(result, XYData) + self.assertIsInstance( + result, XYData, "The result should be an instance of XYData." + ) expected_x = ([1.0, 2.0], [3.0, 4.0]) expected_y = ([True, False, True], [False, False, True]) - self.assertEqual(result.x, expected_x) - self.assertEqual(result.y, expected_y) + self.assertEqual( + result.x, + expected_x, + "The feature data 'x' does not match the expected output.", + ) + self.assertEqual( + result.y, + expected_y, + "The label data 'y' does not match the expected output.", + ) def test_call_with_empty_data(self) -> None: """ @@ -45,7 +55,9 @@ def test_call_with_empty_data(self) -> None: self.collator(data) self.assertEqual( - str(context.exception), "not enough values to unpack (expected 2, got 0)" + str(context.exception), + "not enough values to unpack (expected 2, got 0)", + "The exception message for empty data is not as expected.", ) diff --git a/tests/unit/collators/testRaggedCollator.py b/tests/unit/collators/testRaggedCollator.py index 81947b47..d31776a6 100644 --- a/tests/unit/collators/testRaggedCollator.py +++ b/tests/unit/collators/testRaggedCollator.py @@ -40,20 +40,30 @@ def test_call_with_valid_data(self) -> None: ) expected_lens_for_x = torch.tensor([2, 3, 1]) - self.assertTrue(torch.equal(result.x, expected_x)) - self.assertTrue(torch.equal(result.y, expected_y)) + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output.", + ) self.assertTrue( torch.equal( result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x - ) + ), + "The mask tensor does not match the expected output.", ) self.assertTrue( torch.equal( result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x - ) + ), + "The lens tensor does not match the expected output.", ) self.assertEqual( - result.additional_fields["idents"], ("sample1", "sample2", "sample3") + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output.", ) def test_call_with_missing_entire_labels(self) -> None: @@ -75,23 +85,35 @@ def 
test_call_with_missing_entire_labels(self) -> None: expected_mask_for_x = torch.tensor([[True, True], [True, False]]) expected_lens_for_x = torch.tensor([2, 1]) - self.assertTrue(torch.equal(result.x, expected_x)) - self.assertTrue(torch.equal(result.y, expected_y)) + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output when labels are missing.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output when labels are missing.", + ) self.assertTrue( torch.equal( result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x - ) + ), + "The mask tensor does not match the expected output when labels are missing.", ) self.assertTrue( torch.equal( result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x - ) + ), + "The lens tensor does not match the expected output when labels are missing.", ) self.assertEqual( - result.additional_fields["loss_kwargs"]["non_null_labels"], [0, 2] + result.additional_fields["loss_kwargs"]["non_null_labels"], + [0, 2], + "The non-null labels list does not match the expected output.", ) self.assertEqual( - result.additional_fields["idents"], ("sample1", "sample2", "sample3") + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output when labels are missing.", ) def test_call_with_none_in_labels(self) -> None: @@ -115,20 +137,30 @@ def test_call_with_none_in_labels(self) -> None: ) expected_lens_for_x = torch.tensor([2, 3, 1]) - self.assertTrue(torch.equal(result.x, expected_x)) - self.assertTrue(torch.equal(result.y, expected_y)) + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output when labels contain None.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output when labels contain None.", + ) self.assertTrue( torch.equal( result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x - ) + ), + "The mask tensor does not match the expected output when labels contain None.", ) self.assertTrue( torch.equal( result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x - ) + ), + "The lens tensor does not match the expected output when labels contain None.", ) self.assertEqual( - result.additional_fields["idents"], ("sample1", "sample2", "sample3") + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output when labels contain None.", ) def test_call_with_empty_data(self) -> None: @@ -137,7 +169,9 @@ def test_call_with_empty_data(self) -> None: """ data: List[Dict] = [] - with self.assertRaises(Exception): + with self.assertRaises( + Exception, msg="Expected an Error when no data is provided" + ): self.collator(data) def test_process_label_rows(self) -> None: @@ -152,7 +186,10 @@ def test_process_label_rows(self) -> None: [[True, False, False], [False, True, True], [True, False, False]] ) - self.assertTrue(torch.equal(result, expected_output)) + self.assertTrue( + torch.equal(result, expected_output), + "The processed label rows tensor does not match the expected output.", + ) if __name__ == "__main__": diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py index ae69952a..50b9287a 100644 --- a/tests/unit/dataset_classes/testDynamicDataset.py +++ 
b/tests/unit/dataset_classes/testDynamicDataset.py @@ -26,8 +26,8 @@ def setUpClass( """ # Mocking properties - mock_base_dir_property.return_value = "MockedBaseDirProperty" - mock_name_property.return_value = "MockedNameProperty" + mock_base_dir_property.return_value = "MockedBaseDirPropertyDynamicDataset" + mock_name_property.return_value = "MockedNamePropertyDynamicDataset" # Assigning a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) _DynamicDataset.READER = ProteinDataReader diff --git a/tests/unit/dataset_classes/testXYBaseDataModule.py b/tests/unit/dataset_classes/testXYBaseDataModule.py index d8aabc67..4c2d21dc 100644 --- a/tests/unit/dataset_classes/testXYBaseDataModule.py +++ b/tests/unit/dataset_classes/testXYBaseDataModule.py @@ -12,13 +12,13 @@ class TestXYBaseDataModule(unittest.TestCase): @classmethod @patch.object(XYBaseDataModule, "_name", new_callable=PropertyMock) - def setUpClass(cls, mock_name_property) -> None: + def setUpClass(cls, mock_name_property: PropertyMock) -> None: """ Set up a base instance of XYBaseDataModule for testing. """ # Mock the _name property of XYBaseDataModule - mock_name_property.return_value = "MockedXYBaseDataModule" + mock_name_property.return_value = "MockedNamePropXYBaseDataModule" # Assign a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) XYBaseDataModule.READER = ProteinDataReader @@ -41,13 +41,21 @@ def test_filter_labels_valid_index(self) -> None: filtered_row = self.module._filter_labels(row) expected_labels = [3] # Only the label at index 1 should be kept - self.assertEqual(filtered_row["labels"], expected_labels) + self.assertEqual( + filtered_row["labels"], + expected_labels, + "The filtered labels do not match the expected labels.", + ) row = { "features": ["feature1", "feature2"], "labels": [True, False, True, True], } - self.assertEqual(self.module._filter_labels(row)["labels"], [False]) + self.assertEqual( + self.module._filter_labels(row)["labels"], + [False], + "The filtered labels for the boolean case do not match the expected labels.", + ) def test_filter_labels_no_filter(self) -> None: """ @@ -57,7 +65,9 @@ def test_filter_labels_no_filter(self) -> None: self.module.label_filter = None row = {"features": ["feature1", "feature2"], "labels": [False, True]} # Handle the case where the index is out of bounds - with self.assertRaises(TypeError): + with self.assertRaises( + TypeError, msg="Expected a TypeError when no label filter is provided." 
+ ): self.module._filter_labels(row) def test_filter_labels_invalid_index(self) -> None: @@ -68,7 +78,10 @@ def test_filter_labels_invalid_index(self) -> None: self.module.label_filter = 10 row = {"features": ["feature1", "feature2"], "labels": [False, True]} # Handle the case where the index is out of bounds - with self.assertRaises(IndexError): + with self.assertRaises( + IndexError, + msg="Expected an IndexError when the label filter index is out of bounds.", + ): self.module._filter_labels(row) diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py index 3d7b5e6f..fde8604f 100644 --- a/tests/unit/readers/testChemDataReader.py +++ b/tests/unit/readers/testChemDataReader.py @@ -27,7 +27,14 @@ def setUpClass(cls, mock_file: mock_open) -> None: """ cls.reader = ChemDataReader(token_path="/mock/path") # After initializing, cls.reader.cache should now be set to ['C', 'O', 'N', '=', '1', '('] - assert cls.reader.cache == ["C", "O", "N", "=", "1", "("] + assert cls.reader.cache == [ + "C", + "O", + "N", + "=", + "1", + "(", + ], "Initial cache does not match expected values." def test_read_data(self) -> None: """ @@ -48,7 +55,11 @@ def test_read_data(self) -> None: EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2] ] result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The output of _read_data does not match the expected tokenized values.", + ) def test_read_data_with_new_token(self) -> None: """ @@ -62,12 +73,24 @@ def test_read_data_with_new_token(self) -> None: expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The output for new token '[H-]' does not match the expected values.", + ) # Verify that '[H-]' was added to the cache - self.assertIn("[H-]", self.reader.cache) + self.assertIn( + "[H-]", + self.reader.cache, + "The new token '[H-]' was not added to the cache.", + ) # Ensure it's at the correct index - self.assertEqual(self.reader.cache.index("[H-]"), index_for_last_token) + self.assertEqual( + self.reader.cache.index("[H-]"), + index_for_last_token, + "The new token '[H-]' was not added at the correct index in the cache.", + ) if __name__ == "__main__": diff --git a/tests/unit/readers/testDataReader.py b/tests/unit/readers/testDataReader.py index 8a8af053..745c0ace 100644 --- a/tests/unit/readers/testDataReader.py +++ b/tests/unit/readers/testDataReader.py @@ -45,7 +45,11 @@ def test_to_data(self) -> None: "extra_key": "extra_value", } - self.assertEqual(self.reader.to_data(row), expected) + self.assertEqual( + self.reader.to_data(row), + expected, + "The to_data method did not process the input row as expected.", + ) if __name__ == "__main__": diff --git a/tests/unit/readers/testDeepChemDataReader.py b/tests/unit/readers/testDeepChemDataReader.py index 23ac35d5..31a63dd1 100644 --- a/tests/unit/readers/testDeepChemDataReader.py +++ b/tests/unit/readers/testDeepChemDataReader.py @@ -27,7 +27,12 @@ def setUpClass(cls, mock_file: mock_open) -> None: """ cls.reader = DeepChemDataReader(token_path="/mock/path") # After initializing, cls.reader.cache should now be set to ['C', 'O', 'c', ')'] - assert cls.reader.cache == ["C", "O", "c", ")"] + assert cls.reader.cache == [ + "C", + "O", + "c", + ")", + ], "Cache initialization did not match expected tokens." 
def test_read_data(self) -> None: """ @@ -58,7 +63,11 @@ def test_read_data(self) -> None: EMBEDDING_OFFSET + len(self.reader.cache) + 3, # [Ni-2] (new token) ] result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected tokenized output for the SMILES string.", + ) def test_read_data_with_new_token(self) -> None: """ @@ -72,12 +81,24 @@ def test_read_data_with_new_token(self) -> None: expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected output for a SMILES string with a new token.", + ) # Verify that '[H-]' was added to the cache - self.assertIn("[H-]", self.reader.cache) + self.assertIn( + "[H-]", + self.reader.cache, + "The new token '[H-]' was not added to the cache as expected.", + ) # Ensure it's at the correct index - self.assertEqual(self.reader.cache.index("[H-]"), index_for_last_token) + self.assertEqual( + self.reader.cache.index("[H-]"), + index_for_last_token, + "The new token '[H-]' was not added to the correct index in the cache.", + ) if __name__ == "__main__": diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py index 6e5f325c..c5bc5e9a 100644 --- a/tests/unit/readers/testProteinDataReader.py +++ b/tests/unit/readers/testProteinDataReader.py @@ -25,7 +25,14 @@ def setUpClass(cls, mock_file: mock_open) -> None: """ cls.reader = ProteinDataReader(token_path="/mock/path") # After initializing, cls.reader.cache should now be set to ['M', 'K', 'T', 'F', 'R', 'N'] - assert cls.reader.cache == ["M", "K", "T", "F", "R", "N"] + assert cls.reader.cache == [ + "M", + "K", + "T", + "F", + "R", + "N", + ], "Cache initialization did not match expected tokens." def test_read_data(self) -> None: """ @@ -44,7 +51,11 @@ def test_read_data(self) -> None: EMBEDDING_OFFSET + 5, # N ] result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected tokenized output.", + ) def test_read_data_with_new_token(self) -> None: """ @@ -63,12 +74,22 @@ def test_read_data_with_new_token(self) -> None: ] result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The _read_data method did not correctly handle a new token.", + ) # Verify that 'Y' was added to the cache - self.assertIn("Y", self.reader.cache) + self.assertIn( + "Y", self.reader.cache, "The new token 'Y' was not added to the cache." 
+ ) # Ensure it's at the correct index - self.assertEqual(self.reader.cache.index("Y"), len(self.reader.cache) - 1) + self.assertEqual( + self.reader.cache.index("Y"), + len(self.reader.cache) - 1, + "The new token 'Y' was not added at the correct index in the cache.", + ) def test_read_data_with_invalid_token(self) -> None: """ @@ -79,7 +100,11 @@ def test_read_data_with_invalid_token(self) -> None: with self.assertRaises(KeyError) as context: self.reader._read_data(raw_data) - self.assertIn("Invalid token 'Z' encountered", str(context.exception)) + self.assertIn( + "Invalid token 'Z' encountered", + str(context.exception), + "The KeyError did not contain the expected message for an invalid token.", + ) def test_read_data_with_empty_sequence(self) -> None: """ @@ -88,7 +113,11 @@ def test_read_data_with_empty_sequence(self) -> None: raw_data = "" result = self.reader._read_data(raw_data) - self.assertEqual(result, []) + self.assertEqual( + result, + [], + "The _read_data method did not return an empty list for an empty input sequence.", + ) def test_read_data_with_repeated_tokens(self) -> None: """ @@ -99,7 +128,11 @@ def test_read_data_with_repeated_tokens(self) -> None: expected_output: List[int] = [EMBEDDING_OFFSET + 0] * 5 # All tokens are 'M' result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The _read_data method did not correctly handle repeated tokens.", + ) if __name__ == "__main__": diff --git a/tests/unit/readers/testSelfiesReader.py b/tests/unit/readers/testSelfiesReader.py index 019a0f59..411fc63b 100644 --- a/tests/unit/readers/testSelfiesReader.py +++ b/tests/unit/readers/testSelfiesReader.py @@ -26,8 +26,12 @@ def setUpClass(cls, mock_file: mock_open) -> None: mock_file: Mock object for file operations. """ cls.reader = SelfiesReader(token_path="/mock/path") - # After initializing, cls.reader.cache should now be set to ['[C]', '[O]', '[N]', '[=]', '[1]', '[('] - assert cls.reader.cache == ["[C]", "[O]", "[=C]"] + # After initializing, cls.reader.cache should now be set to ['[C]', '[O]', '[=C]'] + assert cls.reader.cache == [ + "[C]", + "[O]", + "[=C]", + ], "Cache initialization did not match expected tokens." 
def test_read_data(self) -> None: """ @@ -62,7 +66,11 @@ def test_read_data(self) -> None: ] result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected tokenized output.", + ) def test_read_data_with_new_token(self) -> None: """ @@ -76,12 +84,24 @@ def test_read_data_with_new_token(self) -> None: expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] result = self.reader._read_data(raw_data) - self.assertEqual(result, expected_output) + self.assertEqual( + result, + expected_output, + "The _read_data method did not correctly handle a new token.", + ) # Verify that '[H-1]' was added to the cache, "[H-]" translated to "[H-1]" in SELFIES - self.assertIn("[H-1]", self.reader.cache) + self.assertIn( + "[H-1]", + self.reader.cache, + "The new token '[H-1]' was not added to the cache.", + ) # Ensure it's at the correct index - self.assertEqual(self.reader.cache.index("[H-1]"), index_for_last_token) + self.assertEqual( + self.reader.cache.index("[H-1]"), + index_for_last_token, + "The new token '[H-1]' was not added at the correct index in the cache.", + ) def test_read_data_with_invalid_selfies(self) -> None: """ @@ -90,10 +110,17 @@ def test_read_data_with_invalid_selfies(self) -> None: raw_data = "[C][O][INVALID][N]" result = self.reader._read_data(raw_data) - self.assertIsNone(result) + self.assertIsNone( + result, + "The _read_data method did not return None for an invalid SELFIES string.", + ) # Verify that the error count was incremented - self.assertEqual(self.reader.error_count, 1) + self.assertEqual( + self.reader.error_count, + 1, + "The error count was not incremented for an invalid SELFIES string.", + ) if __name__ == "__main__": From c0aaeeaef84efa06b0a68879ddf3e0874c749138 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Sep 2024 17:34:03 +0200 Subject: [PATCH 14/46] test data class for chebi ontology --- tests/unit/mock_data/ontology_mock_data.py | 146 +++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 tests/unit/mock_data/ontology_mock_data.py diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py new file mode 100644 index 00000000..27fd511f --- /dev/null +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -0,0 +1,146 @@ +class ChebiMockOntology: + """ + Nodes: + - CHEBI:12345 (Compound A) + - CHEBI:54321 (Compound B) + - CHEBI:67890 (Compound C) + - CHEBI:11111 (Compound D) + - CHEBI:22222 (Compound E) + - CHEBI:99999 (Compound F) + - CHEBI:77533 (Compound F, Obsolete node) + - CHEBI:77564 (Compound H, Obsolete node) + - CHEBI:88888 (Compound I) + + Valid Edges: + - CHEBI:54321 -> CHEBI:12345 + - CHEBI:67890 -> CHEBI:12345 + - CHEBI:67890 -> CHEBI:88888 + - CHEBI:11111 -> CHEBI:54321 + - CHEBI:77564 -> CHEBI:54321 (Ignored due to obsolete status) + - CHEBI:22222 -> CHEBI:67890 + - CHEBI:12345 -> CHEBI:99999 + - CHEBI:77533 -> CHEBI:99999 (Ignored due to obsolete status) + """ + + @staticmethod + def get_nodes(): + return {12345, 54321, 67890, 11111, 22222, 99999, 88888} + + @staticmethod + def get_number_of_nodes(): + return len(ChebiMockOntology.get_nodes()) + + @staticmethod + def get_edges_of_transitive_closure_graph(): + return { + (54321, 12345), + (54321, 99999), + (67890, 12345), + (67890, 99999), + (67890, 88888), + (11111, 54321), + (11111, 12345), + (11111, 99999), + (22222, 67890), + (22222, 12345), + (22222, 99999), + (22222, 88888), + (12345, 
99999), + } + + @staticmethod + def get_number_of_transitive_edges(): + return len(ChebiMockOntology.get_edges_of_transitive_closure_graph()) + + @staticmethod + def get_edges(): + return { + (54321, 12345), + (67890, 12345), + (67890, 88888), + (11111, 54321), + (22222, 67890), + (12345, 99999), + } + + @staticmethod + def get_number_of_edges(): + return len(ChebiMockOntology.get_edges()) + + @staticmethod + def get_obsolete_nodes_ids(): + return {77533, 77564} + + @staticmethod + def get_raw_data(): + # Create mock terms with a complex hierarchy, names, and SMILES strings + return """ + [Term] + id: CHEBI:12345 + name: Compound A + subset: 2_STAR + property_value: http://purl.obolibrary.org/obo/chebi/formula "C26H35ClN4O6S" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/charge "0" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/monoisotopicmass "566.19658" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/mass "567.099" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/inchikey "ROXPMFGZZQEKHB-IUKKYPGJSA-N" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/inchi "InChI=1S/C26H35ClN4O6S/c1-16(2)28-26(34)30(5)14-23-17(3)13-31(18(4)15-32)25(33)21-7-6-8-22(24(21)37-23)29-38(35,36)20-11-9-19(27)10-12-20/h6-12,16-18,23,29,32H,13-15H2,1-5H3,(H,28,34)/t17-,18-,23+/m0/s1" xsd:string + xref: LINCS:LSM-20139 + is_a: CHEBI:54321 + is_a: CHEBI:67890 + + [Term] + id: CHEBI:54321 + name: Compound B + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1O" xsd:string + is_a: CHEBI:11111 + is_a: CHEBI:77564 + + [Term] + id: CHEBI:67890 + name: Compound C + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1N" xsd:string + is_a: CHEBI:22222 + + [Term] + id: CHEBI:11111 + name: Compound D + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1F" xsd:string + + [Term] + id: CHEBI:22222 + name: Compound E + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1Cl" xsd:string + + [Term] + id: CHEBI:99999 + name: Compound F + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1Br" xsd:string + is_a: CHEBI:12345 + + [Term] + id: CHEBI:77533 + name: Compound G + is_a: CHEBI:99999 + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=C1Br" xsd:string + is_obsolete: true + + [Term] + id: CHEBI:77564 + name: Compound H + property_value: http://purl.obolibrary.org/obo/chebi/smiles "CC=C1Br" xsd:string + is_obsolete: true + + [Typedef] + id: has_major_microspecies_at_pH_7_3 + name: has major microspecies at pH 7.3 + is_cyclic: true + is_transitive: false + + [Term] + id: CHEBI:88888 + name: Compound I + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1[Mg+]" xsd:string + is_a: CHEBI:67890 + """ From 764216e91e032693b90b9044eccc2fb411fcfad5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Sep 2024 17:38:13 +0200 Subject: [PATCH 15/46] test for term callback + mock data changes --- .../dataset_classes/testChebiTermCallback.py | 67 +++++++++++++ tests/unit/mock_data/__init__.py | 0 tests/unit/mock_data/ontology_mock_data.py | 98 +++++++++++++++---- 3 files changed, 144 insertions(+), 21 deletions(-) create mode 100644 tests/unit/dataset_classes/testChebiTermCallback.py create mode 100644 tests/unit/mock_data/__init__.py diff --git a/tests/unit/dataset_classes/testChebiTermCallback.py 
b/tests/unit/dataset_classes/testChebiTermCallback.py new file mode 100644 index 00000000..7b22d1a2 --- /dev/null +++ b/tests/unit/dataset_classes/testChebiTermCallback.py @@ -0,0 +1,67 @@ +import unittest +from typing import Any, Dict + +import fastobo +from fastobo.term import TermFrame + +from chebai.preprocessing.datasets.chebi import term_callback +from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology + + +class TestChebiTermCallback(unittest.TestCase): + """ + Unit tests for the `term_callback` function used in processing ChEBI ontology terms. + """ + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the test class by loading ChEBI term data and storing it in a dictionary + where keys are the term IDs and values are TermFrame instances. + """ + cls.callback_input_data: Dict[int, TermFrame] = { + int(term_doc.id.local): term_doc + for term_doc in fastobo.loads(ChebiMockOntology.get_raw_data()) + if term_doc and ":" in str(term_doc.id) + } + + def test_process_valid_terms(self) -> None: + """ + Test that `term_callback` correctly processes valid ChEBI terms. + """ + + expected_result: Dict[str, Any] = { + "id": 12345, + "parents": [54321, 67890], + "has_part": set(), + "name": "Compound A", + "smiles": "C1=CC=CC=C1", + } + + actual_dict: Dict[str, Any] = term_callback( + self.callback_input_data.get(expected_result["id"]) + ) + self.assertEqual( + expected_result, + actual_dict, + msg="term_callback should correctly extract information from valid ChEBI terms.", + ) + + def test_skip_obsolete_terms(self) -> None: + """ + Test that `term_callback` correctly skips obsolete ChEBI terms. + """ + + term_callback_output = [ + term_callback(self.callback_input_data.get(ident)) + for ident in ChebiMockOntology.get_obsolete_nodes_ids() + ] + self.assertEqual( + term_callback_output, + [], + msg="The term_callback function should skip obsolete terms and return an empty list.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/mock_data/__init__.py b/tests/unit/mock_data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index 27fd511f..11d5c9ce 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -1,5 +1,12 @@ +from typing import Set, Tuple + + class ChebiMockOntology: """ + A mock ontology representing a simplified ChEBI (Chemical Entities of Biological Interest) structure. + This class is used for testing purposes and includes nodes and edges representing chemical compounds + and their relationships in a graph structure. + Nodes: - CHEBI:12345 (Compound A) - CHEBI:54321 (Compound B) @@ -7,7 +14,7 @@ class ChebiMockOntology: - CHEBI:11111 (Compound D) - CHEBI:22222 (Compound E) - CHEBI:99999 (Compound F) - - CHEBI:77533 (Compound F, Obsolete node) + - CHEBI:77533 (Compound G, Obsolete node) - CHEBI:77564 (Compound H, Obsolete node) - CHEBI:88888 (Compound I) @@ -16,64 +23,113 @@ class ChebiMockOntology: - CHEBI:67890 -> CHEBI:12345 - CHEBI:67890 -> CHEBI:88888 - CHEBI:11111 -> CHEBI:54321 - - CHEBI:77564 -> CHEBI:54321 (Ignored due to obsolete status) - CHEBI:22222 -> CHEBI:67890 - CHEBI:12345 -> CHEBI:99999 - - CHEBI:77533 -> CHEBI:99999 (Ignored due to obsolete status) + + The class also includes methods to retrieve nodes, edges, and transitive closure of the graph. 
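+    The transitive closure listed by get_edges_of_transitive_closure_graph can
+    be cross-checked from the valid edges with networkx (a verification sketch
+    only; the code under test builds its own graph):
+
+        import networkx as nx
+
+        g = nx.DiGraph([(54321, 12345), (67890, 12345), (67890, 88888),
+                        (11111, 54321), (22222, 67890), (12345, 99999)])
+        assert set(nx.transitive_closure(g).edges()) == \
+            ChebiMockOntology.get_edges_of_transitive_closure_graph()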
""" @staticmethod - def get_nodes(): + def get_nodes() -> Set[int]: + """ + Get the set of valid node IDs in the mock ontology. + + Returns: + - Set[int]: A set of integers representing the valid ChEBI node IDs. + """ return {12345, 54321, 67890, 11111, 22222, 99999, 88888} @staticmethod - def get_number_of_nodes(): + def get_number_of_nodes() -> int: + """ + Get the number of valid nodes in the mock ontology. + + Returns: + - int: The number of valid nodes. + """ return len(ChebiMockOntology.get_nodes()) @staticmethod - def get_edges_of_transitive_closure_graph(): + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of valid edges in the mock ontology. + + Returns: + - Set[Tuple[int, int]]: A set of tuples representing the directed edges + between ChEBI nodes. + """ return { (54321, 12345), - (54321, 99999), (67890, 12345), - (67890, 99999), (67890, 88888), (11111, 54321), - (11111, 12345), - (11111, 99999), (22222, 67890), - (22222, 12345), - (22222, 99999), - (22222, 88888), (12345, 99999), } @staticmethod - def get_number_of_transitive_edges(): - return len(ChebiMockOntology.get_edges_of_transitive_closure_graph()) + def get_number_of_edges() -> int: + """ + Get the number of valid edges in the mock ontology. + + Returns: + - int: The number of valid edges. + """ + return len(ChebiMockOntology.get_edges()) @staticmethod - def get_edges(): + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges derived from the transitive closure of the mock ontology graph. + + Returns: + - Set[Tuple[int, int]]: A set of tuples representing the directed edges + in the transitive closure of the ChEBI graph. + """ return { (54321, 12345), + (54321, 99999), (67890, 12345), + (67890, 99999), (67890, 88888), (11111, 54321), + (11111, 12345), + (11111, 99999), (22222, 67890), + (22222, 12345), + (22222, 99999), + (22222, 88888), (12345, 99999), } @staticmethod - def get_number_of_edges(): - return len(ChebiMockOntology.get_edges()) + def get_number_of_transitive_edges() -> int: + """ + Get the number of edges in the transitive closure of the mock ontology graph. + + Returns: + - int: The number of edges in the transitive closure graph. + """ + return len(ChebiMockOntology.get_edges_of_transitive_closure_graph()) @staticmethod - def get_obsolete_nodes_ids(): + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the mock ontology. + + Returns: + - Set[int]: A set of integers representing the obsolete ChEBI node IDs. + """ return {77533, 77564} @staticmethod - def get_raw_data(): - # Create mock terms with a complex hierarchy, names, and SMILES strings + def get_raw_data() -> str: + """ + Get the raw data representing the mock ontology in OBO format. + + Returns: + - str: A string containing the raw OBO data for the mock ChEBI terms. 
+ """ return """ [Term] id: CHEBI:12345 From 1dd8428bbfc46ebf5aa445cc542851cfd8df4f5a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 5 Sep 2024 20:10:35 +0200 Subject: [PATCH 16/46] test for chebidataextractor + changes in mock data --- .../dataset_classes/testChebiDataExtractor.py | 214 ++++++++++++++++++ tests/unit/mock_data/ontology_mock_data.py | 80 ++++++- 2 files changed, 291 insertions(+), 3 deletions(-) create mode 100644 tests/unit/dataset_classes/testChebiDataExtractor.py diff --git a/tests/unit/dataset_classes/testChebiDataExtractor.py b/tests/unit/dataset_classes/testChebiDataExtractor.py new file mode 100644 index 00000000..cb52e68f --- /dev/null +++ b/tests/unit/dataset_classes/testChebiDataExtractor.py @@ -0,0 +1,214 @@ +import unittest +from unittest.mock import PropertyMock, mock_open, patch + +import networkx as nx +import pandas as pd + +from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor +from chebai.preprocessing.reader import ChemDataReader +from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology + + +class TestChEBIDataExtractor(unittest.TestCase): + + @classmethod + @patch.multiple(_ChEBIDataExtractor, __abstractmethods__=frozenset()) + @patch.object(_ChEBIDataExtractor, "base_dir", new_callable=PropertyMock) + @patch.object(_ChEBIDataExtractor, "_name", new_callable=PropertyMock) + def setUpClass( + cls, mock_base_dir_property: PropertyMock, mock_name_property: PropertyMock + ) -> None: + """ + Set up a base instance of _DynamicDataset for testing with mocked properties. + """ + + # Mocking properties + mock_base_dir_property.return_value = "MockedBaseDirPropertyChebiDataExtractor" + mock_name_property.return_value = "MockedNamePropertyChebiDataExtractor" + + # Assigning a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) + _ChEBIDataExtractor.READER = ChemDataReader + + # Creating an instance of the dataset + cls.extractor: _ChEBIDataExtractor = _ChEBIDataExtractor( + chebi_version=231, chebi_version_train=200 + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + def test_extract_class_hierarchy(self, mock_open): + # Mock the output of fastobo.loads + graph = self.extractor._extract_class_hierarchy("fake_path") + + # Validate the graph structure + self.assertIsInstance( + graph, nx.DiGraph, "The result should be a directed graph." 
+ ) + + # Check nodes + actual_nodes = set(graph.nodes) + self.assertEqual( + set(ChebiMockOntology.get_nodes()), + actual_nodes, + "The graph nodes do not match the expected nodes.", + ) + + # Check edges + actual_edges = set(graph.edges) + self.assertEqual( + ChebiMockOntology.get_edges_of_transitive_closure_graph(), + actual_edges, + "The graph edges do not match the expected edges.", + ) + + # Check number of nodes and edges + self.assertEqual( + ChebiMockOntology.get_number_of_nodes(), + len(actual_nodes), + "The number of nodes should match the actual number of nodes in the graph.", + ) + + self.assertEqual( + ChebiMockOntology.get_number_of_transitive_edges(), + len(actual_edges), + "The number of transitive edges should match the actual number of transitive edges in the graph.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + @patch.object( + _ChEBIDataExtractor, + "select_classes", + return_value=ChebiMockOntology.get_nodes(), + ) + def test_graph_to_raw_dataset(self, mock_open, mock_select_classes): + graph = self.extractor._extract_class_hierarchy("fake_path") + data_df = self.extractor._graph_to_raw_dataset(graph) + + pd.testing.assert_frame_equal( + data_df, + ChebiMockOntology.get_data_in_dataframe(), + obj="The DataFrame should match the expected structure", + ) + + @patch( + "builtins.open", new_callable=mock_open, read_data=b"Mocktestdata" + ) # Mocking open as a binary file + @patch("pandas.read_pickle") + def test_load_dict(self, mock_open, mock_read_pickle): + + # Mock the DataFrame returned by read_pickle + mock_df = pd.DataFrame( + { + "id": [12345, 67890, 11111, 54321], # Corrected ID + "name": ["A", "B", "C", "D"], + "SMILES": ["C1CCCCC1", "O=C=O", "C1CC=CC1", "C[Mg+]"], + 12345: [True, False, False, True], + 67890: [False, True, True, False], + 11111: [True, False, True, False], + } + ) + mock_read_pickle.return_value = mock_df # Mock the return value of read_pickle + + # Call the actual function (with open correctly mocked) + generator = self.extractor._load_dict("data/tests") + result = list(generator) # Collect all output from the generator + + # Expected output for comparison + expected_result = [ + {"features": "C1CCCCC1", "labels": [True, False, True], "ident": 12345}, + {"features": "O=C=O", "labels": [False, True, False], "ident": 67890}, + {"features": "C1CC=CC1", "labels": [False, True, True], "ident": 11111}, + { + "features": "C[Mg+]", + "labels": [True, False, False], + "ident": 54321, + }, # Corrected ID + ] + + # Assert if the result matches the expected output + self.assertEqual( + result, + expected_result, + "The loaded dictionary should match the expected structure.", + ) + + @patch("builtins.open", new_callable=mock_open) + @patch.object(_ChEBIDataExtractor, "_name", new_callable=PropertyMock) + @patch.object(_ChEBIDataExtractor, "processed_dir_main", new_callable=PropertyMock) + @patch.object( + _ChEBIDataExtractor, "_chebi_version_train_obj", new_callable=PropertyMock + ) + def test_setup_pruned_test_set( + self, + mock_chebi_version_train_obj, + mock_processed_dir_main, + mock_name_property, + mock_open_file, + ): + # Mock the content for the two open calls (original classes and new classes) + mock_orig_classes = "12345\n67890\n88888\n54321\n77777\n" + mock_new_classes = "12345\n67890\n99999\n77777\n" + + # Use side_effect to simulate the two different file reads + mock_open_file.side_effect = [ + mock_open( + read_data=mock_orig_classes + ).return_value, # First open() for 
orig_classes + mock_open( + read_data=mock_new_classes + ).return_value, # Second open() for new_classes + ] + + # Mock the attributes used in the method + mock_processed_dir_main.return_value = "/mock/path/to/current" + mock_chebi_version_train_obj.return_value.processed_dir_main = ( + "/mock/path/to/train" + ) + + # Mock DataFrame to simulate the test dataset + mock_df = pd.DataFrame( + { + "labels": [ + [ + True, + False, + True, + False, + True, + ], # First test instance labels (match orig_classes) + [False, True, False, True, False], + ] # Second test instance labels + } + ) + + # Call the method under test + pruned_df = self.extractor._setup_pruned_test_set(mock_df) + + # Expected DataFrame labels after pruning (only "12345", "67890", "77777",and "99999" remain) + expected_labels = [[True, False, False, True], [False, True, False, False]] + + # Check if the pruned DataFrame still has the same number of rows + self.assertEqual( + len(pruned_df), + len(mock_df), + "The pruned DataFrame should have the same number of rows.", + ) + + # Check that the labels are correctly pruned + for i in range(len(pruned_df)): + self.assertEqual( + pruned_df.iloc[i]["labels"], + expected_labels[i], + f"Row {i}'s labels should be pruned correctly.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index 11d5c9ce..61b4462a 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -1,4 +1,7 @@ -from typing import Set, Tuple +from collections import OrderedDict +from typing import List, Set, Tuple + +import pandas as pd class ChebiMockOntology: @@ -30,14 +33,14 @@ class ChebiMockOntology: """ @staticmethod - def get_nodes() -> Set[int]: + def get_nodes() -> List[int]: """ Get the set of valid node IDs in the mock ontology. Returns: - Set[int]: A set of integers representing the valid ChEBI node IDs. 
""" - return {12345, 54321, 67890, 11111, 22222, 99999, 88888} + return [11111, 12345, 22222, 54321, 67890, 88888, 99999] @staticmethod def get_number_of_nodes() -> int: @@ -200,3 +203,74 @@ def get_raw_data() -> str: property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1[Mg+]" xsd:string is_a: CHEBI:67890 """ + + @staticmethod + def get_data_in_dataframe(): + data = OrderedDict( + id=[ + 12345, + 54321, + 67890, + 11111, + 22222, + 99999, + 88888, + ], + name=[ + "Compound A", + "Compound B", + "Compound C", + "Compound D", + "Compound E", + "Compound F", + "Compound I", + ], + SMILES=[ + "C1=CC=CC=C1", + "C1=CC=CC=C1O", + "C1=CC=CC=C1N", + "C1=CC=CC=C1F", + "C1=CC=CC=C1Cl", + "C1=CC=CC=C1Br", + "C1=CC=CC=C1[Mg+]", + ], + # Relationships { + # 12345: [11111, 54321, 22222, 67890], + # 67890: [22222], + # 99999: [67890, 11111, 54321, 22222, 12345], + # 54321: [11111], + # 88888: [22222, 67890] + # 11111: [] + # 22222: [] + # } + **{ + # -row- [11111, 12345, 22222, 54321, 67890, 88888, 99999] + 11111: [False, False, False, False, False, False, False], + 12345: [True, True, True, True, True, False, False], + 22222: [False, False, False, False, False, False, False], + 54321: [True, False, False, True, False, False, False], + 67890: [False, False, True, False, True, False, False], + 88888: [False, False, True, False, True, True, False], + 99999: [True, True, True, True, True, False, True], + } + ) + + data_df = pd.DataFrame(data) + + # ------------- Code Approach ------- + # ancestors_of_nodes = {} + # for parent, child in ChebiMockOntology.get_edges_of_transitive_closure_graph(): + # if child not in ancestors_of_nodes: + # ancestors_of_nodes[child] = set() + # if parent not in ancestors_of_nodes: + # ancestors_of_nodes[parent] = set() + # ancestors_of_nodes[child].add(parent) + # ancestors_of_nodes[child].add(child) + # + # # For each node in the ontology, create a column to check if it's an ancestor of any other node or itself + # for node in ChebiMockOntology.get_nodes(): + # data_df[node] = data_df['id'].apply( + # lambda x: (x == node) or (node in ancestors_of_nodes[x]) + # ) + + return data_df From f3519b566410ef1d20f9020258bceabe57199f74 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 5 Sep 2024 21:13:59 +0200 Subject: [PATCH 17/46] mock reader for all + test_setup_pruned_test_set changes --- .../dataset_classes/testChebiDataExtractor.py | 78 +++++++++++-------- .../dataset_classes/testDynamicDataset.py | 25 +++--- .../dataset_classes/testXYBaseDataModule.py | 8 +- 3 files changed, 62 insertions(+), 49 deletions(-) diff --git a/tests/unit/dataset_classes/testChebiDataExtractor.py b/tests/unit/dataset_classes/testChebiDataExtractor.py index cb52e68f..0559e090 100644 --- a/tests/unit/dataset_classes/testChebiDataExtractor.py +++ b/tests/unit/dataset_classes/testChebiDataExtractor.py @@ -1,11 +1,10 @@ import unittest -from unittest.mock import PropertyMock, mock_open, patch +from unittest.mock import MagicMock, PropertyMock, mock_open, patch import networkx as nx import pandas as pd from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor -from chebai.preprocessing.reader import ChemDataReader from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology @@ -16,30 +15,39 @@ class TestChEBIDataExtractor(unittest.TestCase): @patch.object(_ChEBIDataExtractor, "base_dir", new_callable=PropertyMock) @patch.object(_ChEBIDataExtractor, "_name", new_callable=PropertyMock) def setUpClass( - cls, mock_base_dir_property: PropertyMock, mock_name_property: 
PropertyMock + cls, mock_name_property: PropertyMock, mock_base_dir_property: PropertyMock ) -> None: """ - Set up a base instance of _DynamicDataset for testing with mocked properties. + Set up a base instance of _ChEBIDataExtractor for testing with mocked properties. """ - # Mocking properties mock_base_dir_property.return_value = "MockedBaseDirPropertyChebiDataExtractor" mock_name_property.return_value = "MockedNamePropertyChebiDataExtractor" - # Assigning a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) - _ChEBIDataExtractor.READER = ChemDataReader + # Mock Data Reader + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + _ChEBIDataExtractor.READER = ReaderMock - # Creating an instance of the dataset + # Create an instance of the dataset cls.extractor: _ChEBIDataExtractor = _ChEBIDataExtractor( chebi_version=231, chebi_version_train=200 ) + # Mock instance for _chebi_version_train_obj + mock_train_obj = MagicMock() + mock_train_obj.processed_dir_main = "/mock/path/to/train" + cls.extractor._chebi_version_train_obj = mock_train_obj + @patch( "builtins.open", new_callable=mock_open, read_data=ChebiMockOntology.get_raw_data(), ) - def test_extract_class_hierarchy(self, mock_open): + def test_extract_class_hierarchy(self, mock_open: mock_open) -> None: + """ + Test the extraction of class hierarchy and validate the structure of the resulting graph. + """ # Mock the output of fastobo.loads graph = self.extractor._extract_class_hierarchy("fake_path") @@ -87,22 +95,31 @@ def test_extract_class_hierarchy(self, mock_open): "select_classes", return_value=ChebiMockOntology.get_nodes(), ) - def test_graph_to_raw_dataset(self, mock_open, mock_select_classes): + def test_graph_to_raw_dataset( + self, mock_select_classes: PropertyMock, mock_open: mock_open + ) -> None: + """ + Test conversion of a graph to a raw dataset and compare it with the expected DataFrame. + """ graph = self.extractor._extract_class_hierarchy("fake_path") data_df = self.extractor._graph_to_raw_dataset(graph) pd.testing.assert_frame_equal( data_df, ChebiMockOntology.get_data_in_dataframe(), - obj="The DataFrame should match the expected structure", + obj="The DataFrame should match the expected structure.", ) @patch( "builtins.open", new_callable=mock_open, read_data=b"Mocktestdata" ) # Mocking open as a binary file @patch("pandas.read_pickle") - def test_load_dict(self, mock_open, mock_read_pickle): - + def test_load_dict( + self, mock_read_pickle: PropertyMock, mock_open: mock_open + ) -> None: + """ + Test loading data from a pickled file and verify the generator output. 
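+        The expected dicts below imply a row-wise transformation roughly like
+        this sketch (the column layout is assumed from the mocked frame, not
+        taken from the real _load_dict):
+
+            def load_dict_sketch(df):
+                for _, row in df.iterrows():
+                    yield {"features": row["SMILES"],
+                           "labels": list(row.iloc[3:]),
+                           "ident": row["id"]}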
+ """ # Mock the DataFrame returned by read_pickle mock_df = pd.DataFrame( { @@ -114,22 +131,21 @@ def test_load_dict(self, mock_open, mock_read_pickle): 11111: [True, False, True, False], } ) - mock_read_pickle.return_value = mock_df # Mock the return value of read_pickle + mock_read_pickle.return_value = mock_df - # Call the actual function (with open correctly mocked) generator = self.extractor._load_dict("data/tests") - result = list(generator) # Collect all output from the generator + result = list(generator) + + # Convert NumPy arrays to lists for comparison + for item in result: + item["labels"] = list(item["labels"]) # Expected output for comparison expected_result = [ {"features": "C1CCCCC1", "labels": [True, False, True], "ident": 12345}, {"features": "O=C=O", "labels": [False, True, False], "ident": 67890}, {"features": "C1CC=CC1", "labels": [False, True, True], "ident": 11111}, - { - "features": "C[Mg+]", - "labels": [True, False, False], - "ident": 54321, - }, # Corrected ID + {"features": "C[Mg+]", "labels": [True, False, False], "ident": 54321}, ] # Assert if the result matches the expected output @@ -140,18 +156,15 @@ def test_load_dict(self, mock_open, mock_read_pickle): ) @patch("builtins.open", new_callable=mock_open) - @patch.object(_ChEBIDataExtractor, "_name", new_callable=PropertyMock) @patch.object(_ChEBIDataExtractor, "processed_dir_main", new_callable=PropertyMock) - @patch.object( - _ChEBIDataExtractor, "_chebi_version_train_obj", new_callable=PropertyMock - ) def test_setup_pruned_test_set( self, - mock_chebi_version_train_obj, - mock_processed_dir_main, - mock_name_property, - mock_open_file, - ): + mock_processed_dir_main: PropertyMock, + mock_open_file: mock_open, + ) -> None: + """ + Test the pruning of the test set to match classes in the training set. 
+ """ # Mock the content for the two open calls (original classes and new classes) mock_orig_classes = "12345\n67890\n88888\n54321\n77777\n" mock_new_classes = "12345\n67890\n99999\n77777\n" @@ -168,9 +181,6 @@ def test_setup_pruned_test_set( # Mock the attributes used in the method mock_processed_dir_main.return_value = "/mock/path/to/current" - mock_chebi_version_train_obj.return_value.processed_dir_main = ( - "/mock/path/to/train" - ) # Mock DataFrame to simulate the test dataset mock_df = pd.DataFrame( @@ -191,7 +201,7 @@ def test_setup_pruned_test_set( # Call the method under test pruned_df = self.extractor._setup_pruned_test_set(mock_df) - # Expected DataFrame labels after pruning (only "12345", "67890", "77777",and "99999" remain) + # Expected DataFrame labels after pruning (only "12345", "67890", "77777", and "99999" remain) expected_labels = [[True, False, False, True], [False, True, False, False]] # Check if the pruned DataFrame still has the same number of rows diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py index 50b9287a..1ff6c26d 100644 --- a/tests/unit/dataset_classes/testDynamicDataset.py +++ b/tests/unit/dataset_classes/testDynamicDataset.py @@ -1,11 +1,10 @@ import unittest from typing import Tuple -from unittest.mock import PropertyMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pandas as pd from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.reader import ProteinDataReader class TestDynamicDataset(unittest.TestCase): @@ -29,8 +28,10 @@ def setUpClass( mock_base_dir_property.return_value = "MockedBaseDirPropertyDynamicDataset" mock_name_property.return_value = "MockedNamePropertyDynamicDataset" - # Assigning a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) - _DynamicDataset.READER = ProteinDataReader + # Mock Data Reader + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + _DynamicDataset.READER = ReaderMock # Creating an instance of the dataset cls.dataset: _DynamicDataset = _DynamicDataset() @@ -72,7 +73,7 @@ def setUpClass( [True, False], [True, True], ] - cls.df = pd.DataFrame( + cls.data_df = pd.DataFrame( {"ident": [f"id{i + 1}" for i in range(len(X))], "features": X, "labels": y} ) @@ -82,7 +83,7 @@ def test_get_test_split_valid(self) -> None: """ self.dataset.train_split = 0.5 # Test size will be 0.25 * 16 = 4 - train_df, test_df = self.dataset.get_test_split(self.df, seed=0) + train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) # Assert the correct number of rows in train and test sets self.assertEqual(len(train_df), 12, "Train set should contain 12 samples.") @@ -127,8 +128,8 @@ def test_get_test_split_seed_consistency(self) -> None: """ Test that splitting the dataset with the same seed produces consistent results. 
""" - train_df1, test_df1 = self.dataset.get_test_split(self.df, seed=42) - train_df2, test_df2 = self.dataset.get_test_split(self.df, seed=42) + train_df1, test_df1 = self.dataset.get_test_split(self.data_df, seed=42) + train_df2, test_df2 = self.dataset.get_test_split(self.data_df, seed=42) pd.testing.assert_frame_equal( train_df1, @@ -145,7 +146,7 @@ def test_get_train_val_splits_given_test(self) -> None: """ self.dataset.use_inner_cross_validation = False self.dataset.train_split = 0.5 - df_train_main, test_df = self.dataset.get_test_split(self.df, seed=0) + df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) train_df, val_df = self.dataset.get_train_val_splits_given_test( df_train_main, test_df, seed=42 ) @@ -192,12 +193,12 @@ def test_get_train_val_splits_given_test_consistency(self) -> None: """ Test that splitting the dataset into train and validation sets with the same seed produces consistent results. """ - test_df = self.df.iloc[12:] # Assume rows 12 onward are for testing + test_df = self.data_df.iloc[12:] # Assume rows 12 onward are for testing train_df1, val_df1 = self.dataset.get_train_val_splits_given_test( - self.df, test_df, seed=42 + self.data_df, test_df, seed=42 ) train_df2, val_df2 = self.dataset.get_train_val_splits_given_test( - self.df, test_df, seed=42 + self.data_df, test_df, seed=42 ) pd.testing.assert_frame_equal( diff --git a/tests/unit/dataset_classes/testXYBaseDataModule.py b/tests/unit/dataset_classes/testXYBaseDataModule.py index 4c2d21dc..8e3575ab 100644 --- a/tests/unit/dataset_classes/testXYBaseDataModule.py +++ b/tests/unit/dataset_classes/testXYBaseDataModule.py @@ -1,8 +1,7 @@ import unittest -from unittest.mock import PropertyMock, patch +from unittest.mock import MagicMock, PropertyMock, patch from chebai.preprocessing.datasets.base import XYBaseDataModule -from chebai.preprocessing.reader import ProteinDataReader class TestXYBaseDataModule(unittest.TestCase): @@ -21,7 +20,10 @@ def setUpClass(cls, mock_name_property: PropertyMock) -> None: mock_name_property.return_value = "MockedNamePropXYBaseDataModule" # Assign a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) - XYBaseDataModule.READER = ProteinDataReader + # Mock Data Reader + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + XYBaseDataModule.READER = ReaderMock # Initialize the module with a label_filter cls.module = XYBaseDataModule( From fc0fd47389ea60a7573b4de7645c1a133816245d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 5 Sep 2024 21:43:10 +0200 Subject: [PATCH 18/46] fix for misalignment between x an y in RaggedCollator - https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 --- tests/unit/collators/testRaggedCollator.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/unit/collators/testRaggedCollator.py b/tests/unit/collators/testRaggedCollator.py index d31776a6..d9ab2b1d 100644 --- a/tests/unit/collators/testRaggedCollator.py +++ b/tests/unit/collators/testRaggedCollator.py @@ -78,12 +78,15 @@ def test_call_with_missing_entire_labels(self) -> None: result: XYData = self.collator(data) - expected_x = torch.tensor([[1, 2], [6, 0]]) + # https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) expected_y = torch.tensor( [[True, False], [True, False]] ) # True -> 1, False -> 0 - expected_mask_for_x = torch.tensor([[True, True], [True, False]]) - 
expected_lens_for_x = torch.tensor([2, 1]) + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) self.assertTrue( torch.equal(result.x, expected_x), @@ -110,6 +113,11 @@ def test_call_with_missing_entire_labels(self) -> None: [0, 2], "The non-null labels list does not match the expected output.", ) + self.assertEqual( + len(result.additional_fields["loss_kwargs"]["non_null_labels"]), + result.y.shape[1], + "The length of the non-null labels list must match the size of the target label variable.", + ) self.assertEqual( result.additional_fields["idents"], ("sample1", "sample2", "sample3"), From f7f163142c86480c08d31d9b686baba2eabcc81a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 6 Sep 2024 12:24:58 +0200 Subject: [PATCH 19/46] test for ChebiOverX --- tests/unit/dataset_classes/testChEBIOverX.py | 123 +++++++++++++++ tests/unit/mock_data/ontology_mock_data.py | 34 ++++- 2 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 tests/unit/dataset_classes/testChEBIOverX.py diff --git a/tests/unit/dataset_classes/testChEBIOverX.py b/tests/unit/dataset_classes/testChEBIOverX.py new file mode 100644 index 00000000..78d85dd4 --- /dev/null +++ b/tests/unit/dataset_classes/testChEBIOverX.py @@ -0,0 +1,123 @@ +import unittest +from unittest.mock import PropertyMock, mock_open, patch + +from chebai.preprocessing.datasets.chebi import ChEBIOverX +from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology + + +class TestChEBIOverX(unittest.TestCase): + @classmethod + @patch.multiple(ChEBIOverX, __abstractmethods__=frozenset()) + @patch.object(ChEBIOverX, "processed_dir_main", new_callable=PropertyMock) + def setUpClass(cls, mock_processed_dir_main: PropertyMock) -> None: + """ + Set up the ChEBIOverX instance with a mock processed directory path and a test graph. + + Args: + mock_processed_dir_main (PropertyMock): Mocked property for the processed directory path. + """ + mock_processed_dir_main.return_value = "/mock/processed_dir" + cls.chebi_extractor = ChEBIOverX(chebi_version=231) + cls.test_graph = ChebiMockOntology.get_transitively_closed_graph() + + @patch("builtins.open", new_callable=mock_open) + def test_select_classes(self, mock_open_file: mock_open) -> None: + """ + Test the select_classes method to ensure it correctly selects nodes based on the threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.chebi_extractor.THRESHOLD = 3 + selected_classes = self.chebi_extractor.select_classes(self.test_graph) + + # Check if the returned selected classes match the expected list + expected_classes = sorted([11111, 22222, 67890]) + self.assertListEqual( + selected_classes, + expected_classes, + "The selected classes do not match the expected output for the given threshold of 3.", + ) + + # Expected data as string + expected_lines = "\n".join(map(str, expected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + "The written lines do not match the expected lines for the given threshold of 3.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_no_classes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the select_classes method when no nodes meet the successor threshold.
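The writelines assertions used throughout these select_classes tests rely on MagicMock recording the generator passed to the mocked handle without consuming it, so the test can drain it afterwards (once). A standalone sketch of the same capture pattern:

    from unittest.mock import mock_open, patch

    m = mock_open()
    with patch("builtins.open", m):
        with open("classes.txt", "w") as f:
            f.writelines(f"{i}\n" for i in [11111, 22222, 67890])

    written = "".join(m().writelines.call_args[0][0])
    assert written == "11111\n22222\n67890\n"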
+ + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.chebi_extractor.THRESHOLD = 5 + selected_classes = self.chebi_extractor.select_classes(self.test_graph) + + # Expected empty result + self.assertEqual( + selected_classes, + [], + "The selected classes list should be empty when no nodes meet the threshold of 5.", + ) + + # Expected data as string + expected_lines = "" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + "The written lines do not match the expected lines when no nodes meet the threshold of 5.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_all_nodes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the select_classes method when all nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.chebi_extractor.THRESHOLD = 0 + selected_classes = self.chebi_extractor.select_classes(self.test_graph) + + expected_classes = sorted(ChebiMockOntology.get_nodes()) + # Check if the returned selected classes match the expected list + self.assertListEqual( + selected_classes, + expected_classes, + "The selected classes do not match the expected output when all nodes meet the threshold of 0.", + ) + + # Expected data as string + expected_lines = "\n".join(map(str, expected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + "The written lines do not match the expected lines when all nodes meet the threshold of 0.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index 61b4462a..e6c14a93 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -1,6 +1,7 @@ from collections import OrderedDict -from typing import List, Set, Tuple +from typing import Dict, List, Set, Tuple +import networkx as nx import pandas as pd @@ -30,6 +31,18 @@ class ChebiMockOntology: - CHEBI:12345 -> CHEBI:99999 The class also includes methods to retrieve nodes, edges, and transitive closure of the graph. + + Visual Representation Graph with Valid Nodes and Edges: + + 22222 + / + 11111 67890 + \\ / \ + 54321 / 88888 + \\ / + 12345 + \ + 99999 """ @staticmethod @@ -205,7 +218,7 @@ def get_raw_data() -> str: """ @staticmethod - def get_data_in_dataframe(): + def get_data_in_dataframe() -> pd.DataFrame: data = OrderedDict( id=[ 12345, @@ -274,3 +287,20 @@ def get_data_in_dataframe(): # ) return data_df + + @staticmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Create a directed graph, compute its transitive closure, and return it. + + Returns: + g (nx.DiGraph): A transitively closed directed graph. 
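Both threshold tests above hinge on counting successors in the transitively closed graph. A rough sketch of the selection rule being tested, assuming the production select_classes differs only in bookkeeping:

    import networkx as nx

    def select_by_successor_count(g: nx.DiGraph, threshold: int) -> list:
        # keep classes with at least `threshold` subclasses in the closed graph
        return sorted(n for n in g.nodes if len(list(g.successors(n))) >= threshold)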
+ """ + g = nx.DiGraph() + + for node in ChebiMockOntology.get_nodes(): + g.add_node(node, **{"smiles": "test_smiles_placeholder"}) + + g.add_edges_from(ChebiMockOntology.get_edges_of_transitive_closure_graph()) + + return g From bf45bb5360eceadf7f8fb7c651a42d8208de20ec Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 6 Sep 2024 13:52:12 +0200 Subject: [PATCH 20/46] test for ChebiXOverPartial --- .../dataset_classes/testChebiOverXPartial.py | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 tests/unit/dataset_classes/testChebiOverXPartial.py diff --git a/tests/unit/dataset_classes/testChebiOverXPartial.py b/tests/unit/dataset_classes/testChebiOverXPartial.py new file mode 100644 index 00000000..c2515d75 --- /dev/null +++ b/tests/unit/dataset_classes/testChebiOverXPartial.py @@ -0,0 +1,108 @@ +import unittest +from unittest.mock import mock_open, patch + +import networkx as nx + +from chebai.preprocessing.datasets.chebi import ChEBIOverXPartial +from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology + + +class TestChEBIOverX(unittest.TestCase): + + @classmethod + @patch.multiple(ChEBIOverXPartial, __abstractmethods__=frozenset()) + def setUpClass(cls) -> None: + """ + Set up the ChEBIOverXPartial instance with a mock processed directory path and a test graph. + """ + cls.chebi_extractor = ChEBIOverXPartial(top_class_id=11111, chebi_version=231) + cls.test_graph = ChebiMockOntology.get_transitively_closed_graph() + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + def test_extract_class_hierarchy(self, mock_open: mock_open) -> None: + """ + Test the extraction of class hierarchy and validate the structure of the resulting graph. + """ + # Mock the output of fastobo.loads + self.chebi_extractor.top_class_id = 11111 + graph: nx.DiGraph = self.chebi_extractor.extract_class_hierarchy("fake_path") + + # Validate the graph structure + self.assertIsInstance( + graph, nx.DiGraph, "The result should be a directed graph." 
+ ) + + # Check nodes + expected_nodes = {11111, 54321, 12345, 99999} + expected_edges = { + (54321, 12345), + (54321, 99999), + (11111, 54321), + (11111, 12345), + (11111, 99999), + (12345, 99999), + } + self.assertEqual( + set(graph.nodes), + expected_nodes, + f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + ) + + # Check edges + self.assertEqual( + expected_edges, + set(graph.edges), + "The graph edges do not match the expected edges.", + ) + + # Check number of nodes and edges + self.assertEqual( + len(graph.nodes), + len(expected_nodes), + "The number of nodes should match the actual number of nodes in the graph.", + ) + + self.assertEqual( + len(expected_edges), + len(graph.edges), + "The number of transitive edges should match the actual number of transitive edges in the graph.", + ) + + self.chebi_extractor.top_class_id = 22222 + graph = self.chebi_extractor.extract_class_hierarchy("fake_path") + + # Check nodes with top class as 22222 + self.assertEqual( + set(graph.nodes), + {67890, 88888, 12345, 99999, 22222}, + f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + def test_extract_class_hierarchy_with_bottom_cls( + self, mock_open: mock_open + ) -> None: + """ + Test the extraction of class hierarchy and validate the structure of the resulting graph. + """ + self.chebi_extractor.top_class_id = 88888 + graph: nx.DiGraph = self.chebi_extractor.extract_class_hierarchy("fake_path") + + # Check nodes with top class as 88888 + self.assertEqual( + set(graph.nodes), + {self.chebi_extractor.top_class_id}, + f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + ) + + +if __name__ == "__main__": + unittest.main() From 17bf5843df4ade5dde7264ee926cb7123cb97289 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Sep 2024 11:26:58 +0200 Subject: [PATCH 21/46] Mock data for GOUniProt --- tests/unit/mock_data/ontology_mock_data.py | 459 ++++++++++++++++++++- 1 file changed, 457 insertions(+), 2 deletions(-) diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index e6c14a93..dbce56d2 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from collections import OrderedDict from typing import Dict, List, Set, Tuple @@ -5,7 +6,115 @@ import pandas as pd -class ChebiMockOntology: +class MockOntologyGraphData(ABC): + """ + Abstract base class for mocking ontology graph data. + + This class provides a set of static methods that must be implemented by subclasses + to return various elements of an ontology graph such as nodes, edges, and dataframes. + """ + + @staticmethod + @abstractmethod + def get_nodes() -> List[int]: + """ + Get a list of node IDs in the ontology graph. + + Returns: + List[int]: A list of node IDs. + """ + pass + + @staticmethod + @abstractmethod + def get_number_of_nodes() -> int: + """ + Get the number of nodes in the ontology graph. + + Returns: + int: The total number of nodes. + """ + pass + + @staticmethod + @abstractmethod + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples where each tuple represents an edge between two nodes. 
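The node sets expected by the partial-hierarchy assertions above are exactly the chosen top class plus everything reachable below it. A hedged sketch of that restriction (ChEBIOverXPartial's internals may differ), assuming edges point from parent to child:

    import networkx as nx

    def hierarchy_below(g: nx.DiGraph, top: int) -> nx.DiGraph:
        keep = {top} | nx.descendants(g, top)  # {11111, 54321, 12345, 99999} for top=11111
        return g.subgraph(keep).copy()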
+ """ + pass + + @staticmethod + @abstractmethod + def get_number_of_edges() -> int: + """ + Get the number of edges in the ontology graph. + + Returns: + int: The total number of edges. + """ + pass + + @staticmethod + @abstractmethod + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the transitive closure of the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples representing the transitive closure edges. + """ + pass + + @staticmethod + @abstractmethod + def get_number_of_transitive_edges() -> int: + """ + Get the number of edges in the transitive closure of the ontology graph. + + Returns: + int: The total number of transitive edges. + """ + pass + + @staticmethod + @abstractmethod + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the ontology graph. + + Returns: + Set[int]: A set of obsolete node IDs. + """ + pass + + @staticmethod + @abstractmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Get the transitive closure of the ontology graph. + + Returns: + nx.DiGraph: A directed graph representing the transitive closure of the ontology graph. + """ + pass + + @staticmethod + @abstractmethod + def get_data_in_dataframe() -> pd.DataFrame: + """ + Get the ontology data as a Pandas DataFrame. + + Returns: + pd.DataFrame: A DataFrame containing ontology data. + """ + pass + + +class ChebiMockOntology(MockOntologyGraphData): """ A mock ontology representing a simplified ChEBI (Chemical Entities of Biological Interest) structure. This class is used for testing purposes and includes nodes and edges representing chemical compounds @@ -265,7 +374,7 @@ def get_data_in_dataframe() -> pd.DataFrame: 67890: [False, False, True, False, True, False, False], 88888: [False, False, True, False, True, True, False], 99999: [True, True, True, True, True, False, True], - } + }, ) data_df = pd.DataFrame(data) @@ -304,3 +413,349 @@ def get_transitively_closed_graph() -> nx.DiGraph: g.add_edges_from(ChebiMockOntology.get_edges_of_transitive_closure_graph()) return g + + +class GOUniProtMockData(MockOntologyGraphData): + """ + A mock ontology representing a simplified version of the Gene Ontology (GO) structure with nodes and edges + representing GO terms and their relationships in a directed acyclic graph (DAG). + + Nodes: + - GO_1 + - GO_2 + - GO_3 + - GO_4 + - GO_5 + - GO_6 + + Edges (Parent-Child Relationships): + - GO_1 -> GO_2 + - GO_1 -> GO_3 + - GO_2 -> GO_4 + - GO_2 -> GO_5 + - GO_3 -> GO_4 + - GO_4 -> GO_6 + + This mock ontology structure is useful for testing methods related to GO hierarchy, graph extraction, and transitive + closure operations. + + The class also includes methods to retrieve nodes, edges, and transitive closure of the graph. + + Visual Representation Graph with Valid Nodes and Edges: + + GO_1 + / \ + GO_2 GO_3 + / \ / + GO_5 GO_4 + \ + GO_6 + + Valid Swiss Proteins with mapping to valid GO ids + Swiss_Prot_1 -> GO_2, GO_3, GO_5 + Swiss_Prot_2 -> GO_2, GO_5 + """ + + @staticmethod + def get_nodes() -> List[int]: + """ + Get a sorted list of node IDs. + + Returns: + List[int]: A sorted list of node IDs in the ontology graph. + """ + return sorted([1, 2, 3, 4, 5, 6]) + + @staticmethod + def get_number_of_nodes() -> int: + """ + Get the total number of nodes in the ontology graph. + + Returns: + int: The number of nodes. 
+ """ + return len(GOUniProtMockData.get_nodes()) + + @staticmethod + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples where each tuple represents an edge between two nodes. + """ + return {(1, 2), (1, 3), (2, 4), (2, 5), (3, 4), (4, 6)} + + @staticmethod + def get_number_of_edges() -> int: + """ + Get the total number of edges in the ontology graph. + + Returns: + int: The number of edges. + """ + return len(GOUniProtMockData.get_edges()) + + @staticmethod + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the transitive closure of the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples representing edges in the transitive closure graph. + """ + return { + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (1, 6), + (2, 4), + (2, 5), + (2, 6), + (3, 4), + (3, 6), + (4, 6), + } + + @staticmethod + def get_number_of_transitive_edges() -> int: + """ + Get the total number of edges in the transitive closure graph. + + Returns: + int: The number of transitive edges. + """ + return len(GOUniProtMockData.get_edges_of_transitive_closure_graph()) + + @staticmethod + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the ontology graph. + + Returns: + Set[int]: A set of node IDs representing obsolete nodes. + """ + return {7, 8} + + @staticmethod + def get_GO_raw_data() -> str: + """ + Get raw data in string format for GO ontology. + + This data simulates a basic GO ontology in a format typically used for testing. + + Returns: + str: The raw GO data in string format. + """ + return """ + [Term] + id: GO:0000001 + name: GO_1 + namespace: molecular_function + def: "OBSOLETE. Assists in the correct assembly of ribosomes or ribosomal subunits in vivo, but is not a component of the assembled ribosome when performing its normal biological function." [GOC:jl, PMID:12150913] + comment: This term was made obsolete because it refers to a class of gene products and a biological process rather than a molecular function. + synonym: "ribosomal chaperone activity" EXACT [] + xref: MetaCyc:BETAGALACTOSID-RXN + xref: Reactome:R-HSA-189062 "lactose + H2O => D-glucose + D-galactose" + xref: Reactome:R-HSA-5658001 "Defective LCT does not hydrolyze Lac" + xref: RHEA:10076 + + [Term] + id: GO:0000002 + name: GO_2 + namespace: biological_process + is_a: GO:0000001 ! hydrolase activity, hydrolyzing O-glycosyl compounds + + [Term] + id: GO:0000003 + name: GO_3 + namespace: cellular_component + is_a: GO:0000001 ! regulation of DNA recombination + + [Term] + id: GO:0000004 + name: GO_4 + namespace: biological_process + is_a: GO:0000003 ! regulation of DNA recombination + is_a: GO:0000002 ! hydrolase activity, hydrolyzing O-glycosyl compounds + + [Term] + id: GO:0000005 + name: GO_5 + namespace: molecular_function + is_a: GO:0000002 ! regulation of DNA recombination + + [Term] + id: GO:0000006 + name: GO_6 + namespace: cellular_component + is_a: GO:0000004 ! glucoside transport + + [Term] + id: GO:0000007 + name: GO_7 + namespace: biological_process + is_a: GO:0000003 ! glucoside transport + is_obsolete: true + + [Term] + id: GO:0000008 + name: GO_8 + namespace: molecular_function + is_a: GO:0000001 ! 
glucoside transport + is_obsolete: true + + [Typedef] + id: term_tracker_item + name: term tracker item + namespace: external + xref: IAO:0000233 + is_metadata_tag: true + is_class_level: true + """ + + @staticmethod + def protein_sequences() -> Dict[str, str]: + """ + Get the protein sequences for Swiss-Prot proteins. + + Returns: + Dict[str, str]: A dictionary where keys are Swiss-Prot IDs and values are their respective sequences. + """ + return { + "Swiss_Prot_1": "MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK".replace( + " ", "" + ), + "Swiss_Prot_2": "EKGLIVGHFS GIKYKGEKAQ ASEVDVNKMC CWVSKFKDAM RRYQGIQTCK".replace( + " ", "" + ), + } + + @staticmethod + def get_UniProt_raw_data() -> str: + """ + Get raw data in string format for UniProt proteins. + + This mock data contains six Swiss-Prot proteins with different properties: + - Swiss_Prot_1 and Swiss_Prot_2 are valid proteins. + - Swiss_Prot_3 has a sequence length greater than 1002. + - Swiss_Prot_4 contains "X", a non-valid amino acid in its sequence. + - Swiss_Prot_5 has no GO IDs mapped to it. + - Swiss_Prot_6 has GO IDs mapped, but no evidence codes. + + Returns: + str: The raw UniProt data in string format. + """ + protein_sq_1 = GOUniProtMockData.protein_sequences()["Swiss_Prot_1"] + protein_sq_2 = GOUniProtMockData.protein_sequences()["Swiss_Prot_2"] + raw_str = ( + f"ID Swiss_Prot_1 Reviewed; {len(protein_sq_1)} AA. \n" + + "AC Q6GZX4;\n" + + "DR GO; GO:0000002; C:membrane; EXP:UniProtKB-KW.\n" + + "DR GO; GO:0000003; C:membrane; IDA:UniProtKB-KW.\n" + + "DR GO; GO:0000005; P:regulation of viral transcription; IPI:InterPro.\n" + + "DR GO; GO:0000004; P:regulation of viral transcription; IEA:SGD.\n" + + f"SQ SEQUENCE {len(protein_sq_1)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + + f" {protein_sq_1}\n" + + "//\n" + + f"ID Swiss_Prot_2 Reviewed; {len(protein_sq_2)} AA.\n" + + "AC DCGZX4;\n" + + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + + "DR GO; GO:0000002; P:regulation of viral transcription; IMP:InterPro.\n" + + "DR GO; GO:0000005; P:regulation of viral transcription; IGI:InterPro.\n" + + "DR GO; GO:0000006; P:regulation of viral transcription; IEA:PomBase.\n" + + f"SQ SEQUENCE {len(protein_sq_2)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + + f" {protein_sq_2}\n" + + "//\n" + + "ID Swiss_Prot_3 Reviewed; 1165 AA.\n" + + "AC Q6GZX4;\n" + + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + + "DR GO; GO:0000002; P:regulation of viral transcription; IEP:InterPro.\n" + + "DR GO; GO:0000005; P:regulation of viral transcription; TAS:InterPro.\n" + + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" + + "SQ SEQUENCE 1165 AA; 129118 MW; FE2984658CED53A8 CRC64;\n" + + " MRVVVNAKAL EVPVGMSFTE WTRTLSPGSS PRFLAWNPVR PRTFKDVTDP FWNGKVFDLL\n" + + " GVVNGKDDLL FPASEIQEWL EYAPNVDLAE LERIFVATHR HRGMMGFAAA VQDSLVHVDP\n" + + " DSVDVTRVKD GLHKELDEHA SKAAATDVRL KRLRSVKPVD GFSDPVLIRT VFSVTVPEFG\n" + + " DRTAYEIVDS AVPTGSCPYI SAGPFVKTIP GFKPAPEWPA QTAHAEGAVF FKADAEFPDT\n" + + " KPLKDMYRKY SGAAVVPGDV TYPAVITFDV PQGSRHVPPE DFAARVAESL SLDLRGRPLV\n" + + " EMGRVVSVRL DGMRFRPYVL TDLLVSDPDA SHVMQTDELN RAHKIKGTVY AQVCGTGQTV\n" + + " SFQEKTDEDS GEAYISLRVR ARDRKGVEEL MEAAGRVMAI YSRRESEIVS FYALYDKTVA\n" + + " KEAAPPRPPR KSKAPEPTGD KADRKLLRTL APDIFLPTYS RKCLHMPVIL RGAELEDARK\n" + + " KGLNLMDFPL FGESERLTYA CKHPQHPYPG LRANLLPNKA KYPFVPCCYS KDQAVRPNSK\n" + + " WTAYTTGNAE ARRQGRIREG VMQAEPLPEG ALIFLRRVLG QETGSKFFAL RTTGVPETPV\n" + + " NAVHVAVFQR SLTAEEQAEE RAAMALDPSA MGACAQELYV EPDVDWDRWR REMGDPNVPF\n" + + " NLLKYFRALE 
TRYDCDIYIM DNKGIIHTKA VRGRLRYRSR RPTVILHLRE ESCVPVMTPP\n" + " SDWTRGPVRN GILTFSPIDP ITVKLHDLYQ DSRPVYVDGV RVPPLRSDWL PCSGQVVDRA\n" + " GKARVFVVTP TGKMSRGSFT LVTWPMPPLA APILRTDTGF PRGRSDSPLS FLGSRFVPSG\n" + " YRRSVETGAI REITGILDGA CEACLLTHDP VLVPDPSWSD GGPPVYEDPV PSRALEGFTG\n" + " AEKKARMLVE YAKKAISIRE GSCTQESVRS FAANGGFVVS PGALDGMKVF NPRFEAPGPF\n" + " AEADWAVKVP DVKTARRLVY ALRVASVNGT CPVQEYASAS LVPNFYKTST DFVQSPAYTI\n" + " NVWRNDLDQS AVKKTRRAVV DWERGLAVPW PLPETELGFS YSLRFAGISR TFMAMNHPTW\n" + " ESAAFAALTW AKSGYCPGVT SNQIPEGEKV PTYACVKGMK PAKVLESGDG TLKLDKSSYG\n" + " DVRVSGVMIY RASEGKPMQY VSLLM\n" + "//\n" + "ID Swiss_Prot_4 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; EXP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + "ID Swiss_Prot_5 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + "ID Swiss_Prot_6 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000005; P:regulation of viral transcription;\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//" + ) + + return raw_str + + @staticmethod + def get_data_in_dataframe() -> pd.DataFrame: + """ + Get a mock DataFrame representing UniProt data. + + The DataFrame contains Swiss-Prot protein data, including identifiers, accessions, GO terms, sequences, + and binary label columns representing whether each protein is associated with certain GO classes. + + Returns: + pd.DataFrame: A DataFrame containing mock UniProt data with columns for 'swiss_id', 'accession', 'go_ids', 'sequence', + and binary labels for GO classes. + """ + expected_data = OrderedDict( + swiss_id=["Swiss_Prot_1", "Swiss_Prot_2"], + accession=["Q6GZX4", "DCGZX4"], + go_ids=[[2, 3, 5], [2, 5]], + sequence=list(GOUniProtMockData.protein_sequences().values()), + **{ + # SP_1, SP_2 + 1: [False, False], + 2: [True, True], + 3: [True, False], + 4: [False, False], + 5: [True, True], + 6: [False, False], + }, + ) + return pd.DataFrame(expected_data) + + @staticmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Get the transitive closure of the ontology graph. + + Returns: + nx.DiGraph: A directed graph representing the transitive closure of the ontology graph.
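The mapping tests later in this patch series only need to recover accession, GO IDs and sequence from records like the ones above. For orientation, a hedged sketch of extracting the GO cross-references from one record (the real extractor additionally filters on evidence codes, which this ignores):

    def go_ids(record: str) -> list:
        ids = []
        for line in record.splitlines():
            line = line.strip()
            if line.startswith("DR") and "GO:" in line:
                ids.append(int(line.split("GO:")[1].split(";")[0]))  # "GO:0000002" -> 2
        return ids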
+ """ + pass From c6c5a59990b6933d785898d6001595a94a5396be Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Sep 2024 11:27:26 +0200 Subject: [PATCH 22/46] test for GOUniProtDataExtractor --- .../testGOUniProDataExtractor.py | 217 ++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 tests/unit/dataset_classes/testGOUniProDataExtractor.py diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py new file mode 100644 index 00000000..7394405d --- /dev/null +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -0,0 +1,217 @@ +import unittest +from unittest.mock import MagicMock, PropertyMock, mock_open, patch + +import fastobo +import networkx as nx +import pandas as pd + +from chebai.preprocessing.datasets.go_uniprot import _GOUniProtDataExtractor +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + + +class TestGOUniProtDataExtractor(unittest.TestCase): + """ + Unit tests for the _GOUniProtDataExtractor class. + """ + + @classmethod + @patch.multiple(_GOUniProtDataExtractor, __abstractmethods__=frozenset()) + @patch.object(_GOUniProtDataExtractor, "base_dir", new_callable=PropertyMock) + @patch.object(_GOUniProtDataExtractor, "_name", new_callable=PropertyMock) + def setUpClass( + cls, mock_name_property: PropertyMock, mock_base_dir_property: PropertyMock + ) -> None: + """ + Class setup for mocking abstract properties of _GOUniProtDataExtractor. + """ + mock_base_dir_property.return_value = "MockedBaseDirPropGOUniProtDataExtractor" + mock_name_property.return_value = "MockedNamePropGOUniProtDataExtractor" + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + _GOUniProtDataExtractor.READER = ReaderMock + + cls.extractor = _GOUniProtDataExtractor() + + def test_term_callback(self) -> None: + """ + Test the term_callback method for correct parsing and filtering of GO terms. + """ + self.extractor.go_branch = "all" + term_mapping = {} + for term in fastobo.loads(GOUniProtMockData.get_GO_raw_data()): + if isinstance(term, fastobo.typedef.TypedefFrame): + continue + term_mapping[self.extractor._parse_go_id(term.id)] = term + + # Test individual term callback + term_dict = self.extractor.term_callback(term_mapping[4]) + expected_dict = {"go_id": 4, "parents": [3, 2], "name": "GO_4"} + self.assertEqual( + term_dict, + expected_dict, + "The term_callback did not return the expected dictionary.", + ) + + # Test filtering valid terms + valid_terms_docs = set() + for term_id, term_doc in term_mapping.items(): + if self.extractor.term_callback(term_doc): + valid_terms_docs.add(term_id) + + self.assertEqual( + valid_terms_docs, + set(GOUniProtMockData.get_nodes()), + "The valid terms do not match expected nodes.", + ) + + # Test that obsolete terms are filtered out + self.assertFalse( + any( + self.extractor.term_callback(term_mapping[obs_id]) + for obs_id in GOUniProtMockData.get_obsolete_nodes_ids() + ), + "Obsolete terms should not be present.", + ) + + # Test filtering by GO branch (e.g., BP) + self.extractor.go_branch = "BP" + BP_terms = { + term_id + for term_id, term in term_mapping.items() + if self.extractor.term_callback(term) + } + self.assertEqual( + BP_terms, {2, 4}, "The BP terms do not match the expected set." + ) + + @patch( + "fastobo.load", return_value=fastobo.loads(GOUniProtMockData.get_GO_raw_data()) + ) + def test_extract_class_hierarchy(self, mock_load) -> None: + """ + Test the extraction of the class hierarchy from the ontology. 
+ """ + graph = self.extractor._extract_class_hierarchy("fake_path") + + # Validate the graph structure + self.assertIsInstance( + graph, nx.DiGraph, "The result should be a directed graph." + ) + + # Check nodes + actual_nodes = set(graph.nodes) + self.assertEqual( + set(GOUniProtMockData.get_nodes()), + actual_nodes, + "The graph nodes do not match the expected nodes.", + ) + + # Check edges + actual_edges = set(graph.edges) + self.assertEqual( + GOUniProtMockData.get_edges_of_transitive_closure_graph(), + actual_edges, + "The graph edges do not match the expected edges.", + ) + + # Check number of nodes and edges + self.assertEqual( + GOUniProtMockData.get_number_of_nodes(), + len(actual_nodes), + "The number of nodes should match the actual number of nodes in the graph.", + ) + + self.assertEqual( + GOUniProtMockData.get_number_of_transitive_edges(), + len(actual_edges), + "The number of transitive edges should match the actual number of transitive edges in the graph.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + def test_get_swiss_to_go_mapping(self, mock_open) -> None: + """ + Test the extraction of SwissProt to GO term mapping. + """ + mapping_df = self.extractor._get_swiss_to_go_mapping() + expected_df = GOUniProtMockData.get_data_in_dataframe().iloc[:, :4] + + pd.testing.assert_frame_equal( + mapping_df, + expected_df, + obj="The SwissProt to GO mapping DataFrame does not match the expected DataFrame.", + ) + + @patch( + "fastobo.load", return_value=fastobo.loads(GOUniProtMockData.get_GO_raw_data()) + ) + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + @patch.object( + _GOUniProtDataExtractor, + "select_classes", + return_value=GOUniProtMockData.get_nodes(), + ) + def test_graph_to_raw_dataset( + self, mock_select_classes, mock_open, mock_load + ) -> None: + """ + Test the conversion of the class hierarchy graph to a raw dataset. + """ + graph = self.extractor._extract_class_hierarchy("fake_path") + actual_df = self.extractor._graph_to_raw_dataset(graph) + expected_df = GOUniProtMockData.get_data_in_dataframe() + + pd.testing.assert_frame_equal( + actual_df, + expected_df, + obj="The raw dataset DataFrame does not match the expected DataFrame.", + ) + + @patch("builtins.open", new_callable=mock_open, read_data=b"Mocktestdata") + @patch("pandas.read_pickle") + def test_load_dict( + self, mock_read_pickle: PropertyMock, mock_open: mock_open + ) -> None: + """ + Test the loading of the dictionary from a DataFrame. 
+ """ + mock_df = GOUniProtMockData.get_data_in_dataframe() + mock_read_pickle.return_value = mock_df + + generator = self.extractor._load_dict("data/tests") + result = list(generator) + + # Convert NumPy arrays to lists for comparison + for item in result: + item["labels"] = list(item["labels"]) + + # Expected output for comparison + expected_result = [ + { + "features": mock_df["sequence"][0], + "labels": mock_df.iloc[0, 4:].to_list(), + "ident": mock_df["swiss_id"][0], + }, + { + "features": mock_df["sequence"][1], + "labels": mock_df.iloc[1, 4:].to_list(), + "ident": mock_df["swiss_id"][1], + }, + ] + + self.assertEqual( + result, + expected_result, + "The loaded dictionary does not match the expected structure.", + ) + + +if __name__ == "__main__": + unittest.main() From 427bc60a1e6d6d33a7fbfd7a7707224f3922a894 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Sep 2024 12:29:32 +0200 Subject: [PATCH 23/46] update test to new method name _extract_class_hierarchy --- tests/unit/dataset_classes/testChebiOverXPartial.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/dataset_classes/testChebiOverXPartial.py b/tests/unit/dataset_classes/testChebiOverXPartial.py index c2515d75..a8c53408 100644 --- a/tests/unit/dataset_classes/testChebiOverXPartial.py +++ b/tests/unit/dataset_classes/testChebiOverXPartial.py @@ -29,7 +29,7 @@ def test_extract_class_hierarchy(self, mock_open: mock_open) -> None: """ # Mock the output of fastobo.loads self.chebi_extractor.top_class_id = 11111 - graph: nx.DiGraph = self.chebi_extractor.extract_class_hierarchy("fake_path") + graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") # Validate the graph structure self.assertIsInstance( @@ -73,7 +73,7 @@ def test_extract_class_hierarchy(self, mock_open: mock_open) -> None: ) self.chebi_extractor.top_class_id = 22222 - graph = self.chebi_extractor.extract_class_hierarchy("fake_path") + graph = self.chebi_extractor._extract_class_hierarchy("fake_path") # Check nodes with top class as 22222 self.assertEqual( @@ -94,7 +94,7 @@ def test_extract_class_hierarchy_with_bottom_cls( Test the extraction of class hierarchy and validate the structure of the resulting graph. 
""" self.chebi_extractor.top_class_id = 88888 - graph: nx.DiGraph = self.chebi_extractor.extract_class_hierarchy("fake_path") + graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") # Check nodes with top class as 88888 self.assertEqual( From c01ecde837227eb4c4e99afb95063aa58d7cb9cb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Sep 2024 13:12:22 +0200 Subject: [PATCH 24/46] test for GOUniProtOverX --- .../dataset_classes/testGoUniProtOverX.py | 139 ++++++++++++++++++ tests/unit/mock_data/ontology_mock_data.py | 5 +- 2 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 tests/unit/dataset_classes/testGoUniProtOverX.py diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py new file mode 100644 index 00000000..282091b5 --- /dev/null +++ b/tests/unit/dataset_classes/testGoUniProtOverX.py @@ -0,0 +1,139 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +import networkx as nx +import pandas as pd + +from chebai.preprocessing.datasets.go_uniprot import _GOUniProtOverX +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + + +class TestGOUniProtOverX(unittest.TestCase): + @classmethod + @patch.multiple(_GOUniProtOverX, __abstractmethods__=frozenset()) + def setUpClass(cls) -> None: + """ + Set up the class for tests by initializing the extractor, graph, and input DataFrame. + """ + cls.extractor = _GOUniProtOverX() + cls.test_graph: nx.DiGraph = GOUniProtMockData.get_transitively_closed_graph() + cls.input_df: pd.DataFrame = GOUniProtMockData.get_data_in_dataframe().iloc[ + :, :4 + ] + + @patch("builtins.open", new_callable=mock_open) + def test_select_classes(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method to ensure it selects classes based on the threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + # Set threshold for testing + self.extractor.THRESHOLD = 2 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: GO terms 1, 2, and 5 should be selected based on the threshold + expected_selected_classes: List[int] = sorted([1, 2, 5]) + + # Check if the selected classes are as expected + self.assertEqual( + selected_classes, + expected_selected_classes, + msg="The selected classes do not match the expected output for threshold 2.", + ) + + # Expected data as string + expected_lines: str = "\n".join(map(str, expected_selected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines for the given threshold of 2.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_no_classes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method when no nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. 
+ """ + self.extractor.THRESHOLD = 5 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: No classes should meet the threshold of 5 + expected_selected_classes: List[int] = [] + + # Check if the selected classes are as expected + self.assertEqual( + selected_classes, + expected_selected_classes, + msg="The selected classes list should be empty when no nodes meet the threshold of 5.", + ) + + # Expected data as string + expected_lines: str = "" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines when no nodes meet the threshold of 5.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_all_nodes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method when all nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.extractor.THRESHOLD = 0 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: All nodes except those not referenced by any protein (4 and 6) should be selected + expected_classes: List[int] = sorted([1, 2, 3, 5]) + + # Check if the returned selected classes match the expected list + self.assertListEqual( + selected_classes, + expected_classes, + msg="The selected classes do not match the expected output when all nodes meet the threshold of 0.", + ) + + # Expected data as string + expected_lines: str = "\n".join(map(str, expected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines when all nodes meet the threshold of 0.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index dbce56d2..d516a7a0 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -758,4 +758,7 @@ def get_transitively_closed_graph() -> nx.DiGraph: Returns: nx.DiGraph: A directed graph representing the transitive closure of the ontology graph. 
""" - pass + g = nx.DiGraph() + g.add_nodes_from(node for node in ChebiMockOntology.get_nodes()) + g.add_edges_from(GOUniProtMockData.get_edges_of_transitive_closure_graph()) + return g From dfd084e6c49ef10d1f4c22388fe2c01217c8cde6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 10 Sep 2024 15:21:24 +0200 Subject: [PATCH 25/46] test for _load_data_from_file for Tox21MolNet --- .../testGOUniProDataExtractor.py | 2 +- tests/unit/dataset_classes/testTox21MolNet.py | 115 ++++++++++ tests/unit/mock_data/tox_mock_data.py | 201 ++++++++++++++++++ 3 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 tests/unit/dataset_classes/testTox21MolNet.py create mode 100644 tests/unit/mock_data/tox_mock_data.py diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py index 7394405d..1b60aa97 100644 --- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -27,7 +27,7 @@ def setUpClass( mock_base_dir_property.return_value = "MockedBaseDirPropGOUniProtDataExtractor" mock_name_property.return_value = "MockedNamePropGOUniProtDataExtractor" ReaderMock = MagicMock() - ReaderMock.name.return_value = "MockedReader" + ReaderMock.name.return_value = "MockedReaderGOUniProtDataExtractor" _GOUniProtDataExtractor.READER = ReaderMock cls.extractor = _GOUniProtDataExtractor() diff --git a/tests/unit/dataset_classes/testTox21MolNet.py b/tests/unit/dataset_classes/testTox21MolNet.py new file mode 100644 index 00000000..3639f5d1 --- /dev/null +++ b/tests/unit/dataset_classes/testTox21MolNet.py @@ -0,0 +1,115 @@ +import os +import unittest +from typing import Dict, List +from unittest.mock import MagicMock, mock_open, patch + +import torch +from sklearn.model_selection import GroupShuffleSplit + +from chebai.preprocessing.datasets.tox21 import Tox21MolNet +from tests.unit.mock_data.tox_mock_data import Tox21MockData + + +class TestTox21MolNet(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + """Initialize a Tox21MolNet instance for testing.""" + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReaderTox21MolNet" + Tox21MolNet.READER = ReaderMock + cls.data_module = Tox21MolNet() + # cls.data_module.raw_dir = "/mock/raw_dir" + # cls.data_module.processed_dir = "/mock/processed_dir" + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=Tox21MockData.get_raw_data(), + ) + def test_load_data_from_file(self, mock_open_file: mock_open) -> None: + """ + Test the `_load_data_from_file` method for correct CSV parsing. + + Args: + mock_open_file (mock_open): Mocked open function to simulate file reading. + """ + expected_data = Tox21MockData.get_processed_data() + actual_data = self.data_module._load_data_from_file("fake/file/path.csv") + + self.assertEqual( + list(actual_data), + expected_data, + "The loaded data does not match the expected output.", + ) + + @patch.object( + Tox21MolNet, + "_load_data_from_file", + return_value=Tox21MockData.get_processed_data(), + ) + @patch("torch.save") + def test_setup_processed_simple_split( + self, mock_load_data: MagicMock, mock_torch_save: MagicMock + ) -> None: + """ + Test the `setup_processed` method for basic data splitting and saving. + + Args: + mock_load_data (MagicMock): Mocked `_load_data_from_file` method to provide controlled data. + mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. 
+ """ + self.data_module.setup_processed() + + # # Check that torch.save was called for train, test, and validation splits + # self.assertEqual( + # mock_torch_save.call_count, + # 3, + # "torch.save should have been called exactly three times for train, test, and validation splits." + # ) + + # @patch("os.path.isfile", return_value=False) + # @patch.object(Tox21MolNet, + # "_load_data_from_file", + # return_value= Tox21MockData.get_processed_grouped_data()) + # @patch("torch.save") + # @patch("torch.load") + # @patch("chebai.preprocessing.datasets.tox21.GroupShuffleSplit") + # def test_setup_processed_group_split( + # self, + # mock_group_split: MagicMock, + # mock_torch_load: MagicMock, + # mock_save: MagicMock, + # mock_load_data: MagicMock, + # mock_isfile: MagicMock + # ) -> None: + # """ + # Test the `setup_processed` method for group-based data splitting and saving. + # + # Args: + # mock_save (MagicMock): Mocked `torch.save` function to avoid file writes. + # mock_load_data (MagicMock): Mocked `_load_data_from_file` method to provide controlled data. + # mock_isfile (MagicMock): Mocked `os.path.isfile` function to simulate file presence. + # mock_group_split (MagicMock): Mocked `GroupShuffleSplit` to control data splitting behavior. + # """ + # mock_group_split.return_value = GroupShuffleSplit(n_splits=1, train_size=0.7) + # self.data_module.setup_processed() + # + # # Load the test split + # test_split_path = os.path.join(self.data_module.processed_dir, "test.pt") + # test_split = torch.load(test_split_path) + # + # # Check if torch.save was called with correct arguments + # mock_save.assert_any_call([mock_data[1]], "/mock/processed_dir/test.pt") + # mock_save.assert_any_call([mock_data[0]], "/mock/processed_dir/train.pt") + # mock_save.assert_any_call([mock_data[1]], "/mock/processed_dir/validation.pt") + # # Check that torch.save was called for train, test, and validation splits + # self.assertEqual( + # mock_torch_save.call_count, + # 3, + # "torch.save should have been called exactly three times for train, test, and validation splits." + # ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/mock_data/tox_mock_data.py b/tests/unit/mock_data/tox_mock_data.py new file mode 100644 index 00000000..912d172c --- /dev/null +++ b/tests/unit/mock_data/tox_mock_data.py @@ -0,0 +1,201 @@ +class Tox21MockData: + """ + A utility class providing mock data for testing the Tox21MolNet dataset. + + This class includes static methods that return mock data in various formats, simulating + the raw and processed data of the Tox21MolNet dataset. The mock data is used for unit tests + to verify the functionality of methods within the Tox21MolNet class without relying on actual + data files. + """ + + @staticmethod + def get_raw_data() -> str: + """ + Returns a raw CSV string that simulates the raw data of the Tox21MolNet dataset. 
+ """ + return ( + "NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53," + + "mol_id,smiles\n" + + "0,0,1,0,1,1,0,1,0,,1,0,TOX958,Nc1ccc([N+](=O)[O-])cc1N\n" + + ",,,,,,,,,1,,,TOX31681,Nc1cc(C(F)(F)F)ccc1S\n" + + "0,0,0,0,0,0,0,,0,0,0,0,TOX5110,CC(C)(C)OOC(C)(C)CCC(C)(C)OOC(C)(C)C\n" + + "0,0,0,0,0,0,0,0,0,0,0,0,TOX6619,O=S(=O)(Cl)c1ccccc1\n" + + "0,0,0,,0,0,,,0,,1,,TOX27679,CCCCCc1ccco1\n" + + "0,,1,,,,0,,1,1,1,1,TOX2801,Oc1c(Cl)cc(Cl)c2cccnc12\n" + + "0,0,0,0,,0,,,0,0,,1,TOX2808,CN(C)CCCN1c2ccccc2Sc2ccc(Cl)cc21\n" + + "0,,0,1,,,,1,0,,1,,TOX29085,CCCCCCCCCCCCCCn1cc[n+](C)c1\n" + ) + + @staticmethod + def get_processed_data() -> list: + """ + Returns a list of dictionaries simulating the processed data for the Tox21MolNet dataset. + Each dictionary contains 'ident', 'features', and 'labels'. + """ + return [ + { + "ident": "TOX958", + "features": "Nc1ccc([N+](=O)[O-])cc1N", + "labels": [ + False, + False, + True, + False, + True, + True, + False, + True, + False, + None, + True, + False, + ], + }, + { + "ident": "TOX31681", + "features": "Nc1cc(C(F)(F)F)ccc1S", + "labels": [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + True, + None, + None, + ], + }, + { + "ident": "TOX5110", + "features": "CC(C)(C)OOC(C)(C)CCC(C)(C)OOC(C)(C)C", + "labels": [ + False, + False, + False, + False, + False, + False, + False, + None, + False, + False, + False, + False, + ], + }, + { + "ident": "TOX6619", + "features": "O=S(=O)(Cl)c1ccccc1", + "labels": [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + }, + { + "ident": "TOX27679", + "features": "CCCCCc1ccco1", + "labels": [ + False, + False, + False, + None, + False, + False, + None, + None, + False, + None, + True, + None, + ], + }, + { + "ident": "TOX2801", + "features": "Oc1c(Cl)cc(Cl)c2cccnc12", + "labels": [ + False, + None, + True, + None, + None, + None, + False, + None, + True, + True, + True, + True, + ], + }, + { + "ident": "TOX2808", + "features": "CN(C)CCCN1c2ccccc2Sc2ccc(Cl)cc21", + "labels": [ + False, + False, + False, + False, + None, + False, + None, + None, + False, + False, + None, + True, + ], + }, + { + "ident": "TOX29085", + "features": "CCCCCCCCCCCCCCn1cc[n+](C)c1", + "labels": [ + False, + None, + False, + True, + None, + None, + None, + True, + False, + None, + True, + None, + ], + }, + ] + + @staticmethod + def get_processed_grouped_data(): + """ + Returns a list of dictionaries simulating the processed data for the Tox21MolNet dataset. + Each dictionary contains 'ident', 'features', and 'labels'. + """ + processed_data = Tox21MockData.get_processed_data() + groups = ["A", "A", "B", "B", "C", "C", "C", "C"] + + assert len(processed_data) == len( + groups + ), "The number of processed data entries does not match the number of groups." 
+ + # Combine processed data with their corresponding groups + grouped_data = [ + {**data, "group": group, "original": True} + for data, group in zip(processed_data, groups) + ] + + return grouped_data From 77956d473b88f71cc0fa7b262da9b595849fa92e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 16 Sep 2024 13:06:47 +0200 Subject: [PATCH 26/46] _load_data_from_file test case Tox21Challenge --- .../dataset_classes/testTox21Challenge.py | 43 ++++ tests/unit/dataset_classes/testTox21MolNet.py | 10 +- tests/unit/mock_data/ontology_mock_data.py | 132 +++++------ tests/unit/mock_data/tox_mock_data.py | 214 +++++++++++++++++- 4 files changed, 317 insertions(+), 82 deletions(-) create mode 100644 tests/unit/dataset_classes/testTox21Challenge.py diff --git a/tests/unit/dataset_classes/testTox21Challenge.py b/tests/unit/dataset_classes/testTox21Challenge.py new file mode 100644 index 00000000..4b23c487 --- /dev/null +++ b/tests/unit/dataset_classes/testTox21Challenge.py @@ -0,0 +1,43 @@ +import os +import unittest +from unittest.mock import MagicMock, mock_open, patch + +from rdkit import Chem + +from chebai.preprocessing.datasets.tox21 import Tox21Challenge +from chebai.preprocessing.reader import ChemDataReader +from tests.unit.mock_data.tox_mock_data import Tox21ChallengeMockData + + +class TestTox21Challenge(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the Tox21Challenge instance and mock data for testing. + """ + Tox21Challenge.READER = ChemDataReader + cls.tox21 = Tox21Challenge() + + @patch("rdkit.Chem.SDMolSupplier") + def test_load_data_from_file(self, mock_sdmol_supplier) -> None: + """ + Test the _load_data_from_file method to ensure it correctly loads data from an SDF file. + """ + # Use ForwardSDMolSupplier to read the mock data from the binary string; + # the path is a placeholder since builtins.open is mocked here + mock_file = mock_open(read_data=Tox21ChallengeMockData.get_raw_train_data()) + with patch("builtins.open", mock_file): + with open("fake/path/tox21_10k_data_all.sdf", "rb") as f: + suppl = Chem.ForwardSDMolSupplier(f) + + mock_sdmol_supplier.return_value = suppl + + actual_data = self.tox21._load_data_from_file("fake/path") + self.assertEqual(Tox21ChallengeMockData.data_in_dict_format(), actual_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testTox21MolNet.py b/tests/unit/dataset_classes/testTox21MolNet.py index 3639f5d1..0a2d67b1 100644 --- a/tests/unit/dataset_classes/testTox21MolNet.py +++ b/tests/unit/dataset_classes/testTox21MolNet.py @@ -7,7 +7,7 @@ from sklearn.model_selection import GroupShuffleSplit from chebai.preprocessing.datasets.tox21 import Tox21MolNet -from tests.unit.mock_data.tox_mock_data import Tox21MockData +from tests.unit.mock_data.tox_mock_data import Tox21MolNetMockData class TestTox21MolNet(unittest.TestCase): @@ -25,7 +25,7 @@ def setUpClass(cls) -> None: @patch( "builtins.open", new_callable=mock_open, - read_data=Tox21MockData.get_raw_data(), + read_data=Tox21MolNetMockData.get_raw_data(), ) def test_load_data_from_file(self, mock_open_file: mock_open) -> None: """ Test the `_load_data_from_file` method for correct CSV parsing. Args: mock_open_file (mock_open): Mocked open function to simulate file reading.
""" - expected_data = Tox21MockData.get_processed_data() + expected_data = Tox21MolNetMockData.get_processed_data() actual_data = self.data_module._load_data_from_file("fake/file/path.csv") self.assertEqual( @@ -46,7 +46,7 @@ def test_load_data_from_file(self, mock_open_file: mock_open) -> None: @patch.object( Tox21MolNet, "_load_data_from_file", - return_value=Tox21MockData.get_processed_data(), + return_value=Tox21MolNetMockData.get_processed_data(), ) @patch("torch.save") def test_setup_processed_simple_split( @@ -71,7 +71,7 @@ def test_setup_processed_simple_split( # @patch("os.path.isfile", return_value=False) # @patch.object(Tox21MolNet, # "_load_data_from_file", - # return_value= Tox21MockData.get_processed_grouped_data()) + # return_value= Tox21MolNetMockData.get_processed_grouped_data()) # @patch("torch.save") # @patch("torch.load") # @patch("chebai.preprocessing.datasets.tox21.GroupShuffleSplit") diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index d516a7a0..478a2bbb 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -651,72 +651,72 @@ def get_UniProt_raw_data() -> str: protein_sq_2 = GOUniProtMockData.protein_sequences()["Swiss_Prot_2"] raw_str = ( f"ID Swiss_Prot_1 Reviewed; {len(protein_sq_1)} AA. \n" - + "AC Q6GZX4;\n" - + "DR GO; GO:0000002; C:membrane; EXP:UniProtKB-KW.\n" - + "DR GO; GO:0000003; C:membrane; IDA:UniProtKB-KW.\n" - + "DR GO; GO:0000005; P:regulation of viral transcription; IPI:InterPro.\n" - + "DR GO; GO:0000004; P:regulation of viral transcription; IEA:SGD.\n" - + f"SQ SEQUENCE {len(protein_sq_1)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - + f" {protein_sq_1}\n" - + "//\n" - + f"ID Swiss_Prot_2 Reviewed; {len(protein_sq_2)} AA.\n" - + "AC DCGZX4;\n" - + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" - + "DR GO; GO:0000002; P:regulation of viral transcription; IMP:InterPro.\n" - + "DR GO; GO:0000005; P:regulation of viral transcription; IGI:InterPro.\n" - + "DR GO; GO:0000006; P:regulation of viral transcription; IEA:PomBase.\n" - + f"SQ SEQUENCE {len(protein_sq_2)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - + f" {protein_sq_2}\n" - + "//\n" - + "ID Swiss_Prot_3 Reviewed; 1165 AA.\n" - + "AC Q6GZX4;\n" - + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" - + "DR GO; GO:0000002; P:regulation of viral transcription; IEP:InterPro.\n" - + "DR GO; GO:0000005; P:regulation of viral transcription; TAS:InterPro.\n" - + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" - + "SQ SEQUENCE 1165 AA; 129118 MW; FE2984658CED53A8 CRC64;\n" - + " MRVVVNAKAL EVPVGMSFTE WTRTLSPGSS PRFLAWNPVR PRTFKDVTDP FWNGKVFDLL\n" - + " GVVNGKDDLL FPASEIQEWL EYAPNVDLAE LERIFVATHR HRGMMGFAAA VQDSLVHVDP\n" - + " DSVDVTRVKD GLHKELDEHA SKAAATDVRL KRLRSVKPVD GFSDPVLIRT VFSVTVPEFG\n" - + " DRTAYEIVDS AVPTGSCPYI SAGPFVKTIP GFKPAPEWPA QTAHAEGAVF FKADAEFPDT\n" - + " KPLKDMYRKY SGAAVVPGDV TYPAVITFDV PQGSRHVPPE DFAARVAESL SLDLRGRPLV\n" - + " EMGRVVSVRL DGMRFRPYVL TDLLVSDPDA SHVMQTDELN RAHKIKGTVY AQVCGTGQTV\n" - + " SFQEKTDEDS GEAYISLRVR ARDRKGVEEL MEAAGRVMAI YSRRESEIVS FYALYDKTVA\n" - + " KEAAPPRPPR KSKAPEPTGD KADRKLLRTL APDIFLPTYS RKCLHMPVIL RGAELEDARK\n" - + " KGLNLMDFPL FGESERLTYA CKHPQHPYPG LRANLLPNKA KYPFVPCCYS KDQAVRPNSK\n" - + " WTAYTTGNAE ARRQGRIREG VMQAEPLPEG ALIFLRRVLG QETGSKFFAL RTTGVPETPV\n" - + " NAVHVAVFQR SLTAEEQAEE RAAMALDPSA MGACAQELYV EPDVDWDRWR REMGDPNVPF\n" - + " NLLKYFRALE TRYDCDIYIM DNKGIIHTKA VRGRLRYRSR RPTVILHLRE ESCVPVMTPP\n" - + " 
SDWTRGPVRN GILTFSPIDP ITVKLHDLYQ DSRPVYVDGV RVPPLRSDWL PCSGQVVDRA\n" - + " GKARVFVVTP TGKMSRGSFT LVTWPMPPLA APILRTDTGF PRGRSDSPLS FLGSRFVPSG\n" - + " YRRSVETGAI REITGILDGA CEACLLTHDP VLVPDPSWSD GGPPVYEDPV PSRALEGFTG\n" - + " AEKKARMLVE YAKKAISIRE GSCTQESVRS FAANGGFVVS PGALDGMKVF NPRFEAPGPF\n" - + " AEADWAVKVP DVKTARRLVY ALRVASVNGT CPVQEYASAS LVPNFYKTST DFVQSPAYTI\n" - + " NVWRNDLDQS AVKKTRRAVV DWERGLAVPW PLPETELGFS YSLRFAGISR TFMAMNHPTW\n" - + " ESAAFAALTW AKSGYCPGVT SNQIPEGEKV PTYACVKGMK PAKVLESGDG TLKLDKSSYG\n" - + " DVRVSGVMIY RASEGKPMQY VSLLM\n" - + "//\n" - + "ID Swiss_Prot_4 Reviewed; 60 AA.\n" - + "AC Q6GZX4;\n" - + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" - + "DR GO; GO:0000002; P:regulation of viral transcription; EXP:InterPro.\n" - + "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" - + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" - + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - + " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" - + "//\n" - + "ID Swiss_Prot_5 Reviewed; 60 AA.\n" - + "AC Q6GZX4;\n" - + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" - + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" - + "//\n" - + "ID Swiss_Prot_5 Reviewed; 60 AA.\n" - + "AC Q6GZX4;\n" - + "DR GO; GO:0000005; P:regulation of viral transcription;\n" - + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" - + "//" + "AC Q6GZX4;\n" + "DR GO; GO:0000002; C:membrane; EXP:UniProtKB-KW.\n" + "DR GO; GO:0000003; C:membrane; IDA:UniProtKB-KW.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IPI:InterPro.\n" + "DR GO; GO:0000004; P:regulation of viral transcription; IEA:SGD.\n" + f"SQ SEQUENCE {len(protein_sq_1)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_1}\n" + "//\n" + f"ID Swiss_Prot_2 Reviewed; {len(protein_sq_2)} AA.\n" + "AC DCGZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; IMP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IGI:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; IEA:PomBase.\n" + f"SQ SEQUENCE {len(protein_sq_2)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_2}\n" + "//\n" + "ID Swiss_Prot_3 Reviewed; 1165 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; IEP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; TAS:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" + "SQ SEQUENCE 1165 AA; 129118 MW; FE2984658CED53A8 CRC64;\n" + " MRVVVNAKAL EVPVGMSFTE WTRTLSPGSS PRFLAWNPVR PRTFKDVTDP FWNGKVFDLL\n" + " GVVNGKDDLL FPASEIQEWL EYAPNVDLAE LERIFVATHR HRGMMGFAAA VQDSLVHVDP\n" + " DSVDVTRVKD GLHKELDEHA SKAAATDVRL KRLRSVKPVD GFSDPVLIRT VFSVTVPEFG\n" + " DRTAYEIVDS AVPTGSCPYI SAGPFVKTIP GFKPAPEWPA QTAHAEGAVF FKADAEFPDT\n" + " KPLKDMYRKY SGAAVVPGDV TYPAVITFDV PQGSRHVPPE DFAARVAESL SLDLRGRPLV\n" + " EMGRVVSVRL DGMRFRPYVL TDLLVSDPDA SHVMQTDELN RAHKIKGTVY AQVCGTGQTV\n" + " SFQEKTDEDS GEAYISLRVR ARDRKGVEEL MEAAGRVMAI YSRRESEIVS FYALYDKTVA\n" + " KEAAPPRPPR KSKAPEPTGD KADRKLLRTL APDIFLPTYS RKCLHMPVIL RGAELEDARK\n" + " KGLNLMDFPL FGESERLTYA CKHPQHPYPG LRANLLPNKA KYPFVPCCYS KDQAVRPNSK\n" + " WTAYTTGNAE ARRQGRIREG VMQAEPLPEG ALIFLRRVLG QETGSKFFAL 
RTTGVPETPV\n" + " NAVHVAVFQR SLTAEEQAEE RAAMALDPSA MGACAQELYV EPDVDWDRWR REMGDPNVPF\n" + " NLLKYFRALE TRYDCDIYIM DNKGIIHTKA VRGRLRYRSR RPTVILHLRE ESCVPVMTPP\n" + " SDWTRGPVRN GILTFSPIDP ITVKLHDLYQ DSRPVYVDGV RVPPLRSDWL PCSGQVVDRA\n" + " GKARVFVVTP TGKMSRGSFT LVTWPMPPLA APILRTDTGF PRGRSDSPLS FLGSRFVPSG\n" + " YRRSVETGAI REITGILDGA CEACLLTHDP VLVPDPSWSD GGPPVYEDPV PSRALEGFTG\n" + " AEKKARMLVE YAKKAISIRE GSCTQESVRS FAANGGFVVS PGALDGMKVF NPRFEAPGPF\n" + " AEADWAVKVP DVKTARRLVY ALRVASVNGT CPVQEYASAS LVPNFYKTST DFVQSPAYTI\n" + " NVWRNDLDQS AVKKTRRAVV DWERGLAVPW PLPETELGFS YSLRFAGISR TFMAMNHPTW\n" + " ESAAFAALTW AKSGYCPGVT SNQIPEGEKV PTYACVKGMK PAKVLESGDG TLKLDKSSYG\n" + " DVRVSGVMIY RASEGKPMQY VSLLM\n" + "//\n" + "ID Swiss_Prot_4 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; EXP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + "ID Swiss_Prot_5 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + "ID Swiss_Prot_5 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000005; P:regulation of viral transcription;\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//" ) return raw_str diff --git a/tests/unit/mock_data/tox_mock_data.py b/tests/unit/mock_data/tox_mock_data.py index 912d172c..96b31a91 100644 --- a/tests/unit/mock_data/tox_mock_data.py +++ b/tests/unit/mock_data/tox_mock_data.py @@ -1,4 +1,4 @@ -class Tox21MockData: +class Tox21MolNetMockData: """ A utility class providing mock data for testing the Tox21MolNet dataset. @@ -15,15 +15,15 @@ def get_raw_data() -> str: """ return ( "NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53," - + "mol_id,smiles\n" - + "0,0,1,0,1,1,0,1,0,,1,0,TOX958,Nc1ccc([N+](=O)[O-])cc1N\n" - + ",,,,,,,,,1,,,TOX31681,Nc1cc(C(F)(F)F)ccc1S\n" - + "0,0,0,0,0,0,0,,0,0,0,0,TOX5110,CC(C)(C)OOC(C)(C)CCC(C)(C)OOC(C)(C)C\n" - + "0,0,0,0,0,0,0,0,0,0,0,0,TOX6619,O=S(=O)(Cl)c1ccccc1\n" - + "0,0,0,,0,0,,,0,,1,,TOX27679,CCCCCc1ccco1\n" - + "0,,1,,,,0,,1,1,1,1,TOX2801,Oc1c(Cl)cc(Cl)c2cccnc12\n" - + "0,0,0,0,,0,,,0,0,,1,TOX2808,CN(C)CCCN1c2ccccc2Sc2ccc(Cl)cc21\n" - + "0,,0,1,,,,1,0,,1,,TOX29085,CCCCCCCCCCCCCCn1cc[n+](C)c1\n" + "mol_id,smiles\n" + "0,0,1,0,1,1,0,1,0,,1,0,TOX958,Nc1ccc([N+](=O)[O-])cc1N\n" + ",,,,,,,,,1,,,TOX31681,Nc1cc(C(F)(F)F)ccc1S\n" + "0,0,0,0,0,0,0,,0,0,0,0,TOX5110,CC(C)(C)OOC(C)(C)CCC(C)(C)OOC(C)(C)C\n" + "0,0,0,0,0,0,0,0,0,0,0,0,TOX6619,O=S(=O)(Cl)c1ccccc1\n" + "0,0,0,,0,0,,,0,,1,,TOX27679,CCCCCc1ccco1\n" + "0,,1,,,,0,,1,1,1,1,TOX2801,Oc1c(Cl)cc(Cl)c2cccnc12\n" + "0,0,0,0,,0,,,0,0,,1,TOX2808,CN(C)CCCN1c2ccccc2Sc2ccc(Cl)cc21\n" + "0,,0,1,,,,1,0,,1,,TOX29085,CCCCCCCCCCCCCCn1cc[n+](C)c1\n" ) @staticmethod @@ -185,7 +185,7 @@ def get_processed_grouped_data(): Returns a list of dictionaries simulating the processed data for the Tox21MolNet dataset. Each dictionary contains 'ident', 'features', and 'labels'. 
""" - processed_data = Tox21MockData.get_processed_data() + processed_data = Tox21MolNetMockData.get_processed_data() groups = ["A", "A", "B", "B", "C", "C", "C", "C"] assert len(processed_data) == len( @@ -199,3 +199,195 @@ def get_processed_grouped_data(): ] return grouped_data + + +class Tox21ChallengeMockData: + + MOL_BINARY_STR = ( + b"cyclobutane\n" + b" RDKit 2D\n\n" + b" 4 4 0 0 0 0 0 0 0 0999 V2000\n" + b" 1.0607 -0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n" + b" -0.0000 -1.0607 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n" + b" -1.0607 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n" + b" 0.0000 1.0607 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n" + b" 1 2 1 0\n" + b" 2 3 1 0\n" + b" 3 4 1 0\n" + b" 4 1 1 0\n" + b"M END\n\n" + ) + + SMILES_OF_MOL = "C1CCC1" + # Feature encoding of SMILES as per chebai/preprocessing/bin/smiles_token/tokens.txt + FEATURE_OF_SMILES = [19, 42, 19, 19, 19, 42] + + @staticmethod + def get_raw_train_data(): + raw_str = ( + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"25848\n\n" + b"> \n" + b"0\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"2384\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"0\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"27102\n\n" + b"> \n" + b"0\n\n" + b"> \n" + b"0\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"26792\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"26401\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"25973\n\n" + b"$$$$\n" + ) + return raw_str + + @staticmethod + def data_in_dict_format(): + data_list = [ + { + "labels": [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + 0, + None, + None, + ], + "ident": "25848", + }, + { + "labels": [ + 0, + None, + None, + 1, + None, + None, + None, + None, + None, + None, + None, + None, + ], + "ident": "2384", + }, + { + "labels": [ + 0, + None, + 0, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ], + "ident": "27102", + }, + { + "labels": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + "ident": "26792", + }, + { + "labels": [ + None, + None, + None, + None, + None, + None, + None, + 1, + None, + 1, + None, + None, + ], + "ident": "26401", + }, + { + "labels": [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ], + "ident": "25973", + }, + ] + + for dict_ in data_list: + dict_["features"] = Tox21ChallengeMockData.FEATURE_OF_SMILES + dict_["group"] = None + + return data_list From a3670b0ca2a73ebb417bb4d45dea8e87d61937ac Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 17 Sep 2024 12:23:24 +0200 Subject: [PATCH 27/46] test for Tox21Chal --- .../dataset_classes/testTox21Challenge.py | 95 +++++++++++++- tests/unit/mock_data/tox_mock_data.py | 122 +++++++++++++++++- 2 files changed, 206 insertions(+), 11 deletions(-) diff --git a/tests/unit/dataset_classes/testTox21Challenge.py b/tests/unit/dataset_classes/testTox21Challenge.py index 4b23c487..9986c82f 100644 --- a/tests/unit/dataset_classes/testTox21Challenge.py +++ b/tests/unit/dataset_classes/testTox21Challenge.py @@ -1,28 +1,37 @@ -import os import unittest -from unittest.mock import MagicMock, mock_open, patch +from 
unittest.mock import mock_open, patch from rdkit import Chem from chebai.preprocessing.datasets.tox21 import Tox21Challenge from chebai.preprocessing.reader import ChemDataReader -from tests.unit.mock_data.tox_mock_data import Tox21ChallengeMockData +from tests.unit.mock_data.tox_mock_data import ( + Tox21ChallengeMockData, + Tox21MolNetMockData, +) class TestTox21Challenge(unittest.TestCase): + """ + Unit tests for the Tox21Challenge class. + """ @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: """ Set up the Tox21Challenge instance and mock data for testing. + This is run once for the test class. """ Tox21Challenge.READER = ChemDataReader cls.tox21 = Tox21Challenge() @patch("rdkit.Chem.SDMolSupplier") - def test_load_data_from_file(self, mock_sdmol_supplier) -> None: + def test_load_data_from_file(self, mock_sdmol_supplier: patch) -> None: """ - Test the _load_data_from_file method to ensure it correctly loads data from an SDF file. + Test the `_load_data_from_file` method to ensure it correctly loads data from an SDF file. + + Args: + mock_sdmol_supplier (patch): A mock of the RDKit SDMolSupplier. """ # Use ForwardSDMolSupplier to read the mock data from the binary string mock_file = mock_open(read_data=Tox21ChallengeMockData.get_raw_train_data()) @@ -36,7 +45,79 @@ def test_load_data_from_file(self, mock_sdmol_supplier) -> None: mock_sdmol_supplier.return_value = suppl actual_data = self.tox21._load_data_from_file("fake/path") - self.assertEqual(Tox21ChallengeMockData.data_in_dict_format(), actual_data) + expected_data = Tox21ChallengeMockData.data_in_dict_format() + + self.assertEqual( + actual_data, + expected_data, + "The loaded data from file does not match the expected data.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=Tox21MolNetMockData.get_raw_data(), + ) + def test_load_dict(self, mock_open_file: mock_open) -> None: + """ + Test the `_load_dict` method to ensure correct CSV parsing. + + Args: + mock_open_file (mock_open): Mocked open function to simulate file reading. + """ + expected_data = Tox21MolNetMockData.get_processed_data() + actual_data = self.tox21._load_dict("fake/file/path.csv") + + self.assertEqual( + list(actual_data), + expected_data, + "The loaded data from CSV does not match the expected processed data.", + ) + + @patch.object(Tox21Challenge, "_load_data_from_file", return_value="test") + @patch("builtins.open", new_callable=mock_open) + @patch("torch.save") + @patch("os.path.join") + def test_setup_processed( + self, + mock_join: patch, + mock_torch_save: patch, + mock_open_file: mock_open, + mock_load_file: patch, + ) -> None: + """ + Test the `setup_processed` method to ensure it processes and saves data correctly. + + Args: + mock_join (patch): Mock of os.path.join to simulate file path joining. + mock_torch_save (patch): Mock of torch.save to simulate saving processed data. + mock_open_file (mock_open): Mocked open function to simulate file reading. + mock_load_file (patch): Mocked data loading method. 
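+
+        The two mocked file handles below stand in for the `test.smiles` and
+        `test_results.txt` inputs that `setup_processed` reads in sequence.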
+ """ + # Simulated raw and processed directories + path_str = "fake/test/path" + mock_join.return_value = path_str + + # Mock the file content for test.smiles and score.txt + mock_open_file.side_effect = [ + mock_open( + read_data=Tox21ChallengeMockData.get_raw_smiles_data() + ).return_value, + mock_open( + read_data=Tox21ChallengeMockData.get_raw_score_txt_data() + ).return_value, + ] + + # Call setup_processed to simulate the data processing workflow + self.tox21.setup_processed() + + # Assert that torch.save was called with the correct processed data + expected_test_data = Tox21ChallengeMockData.get_setup_processed_output_data() + mock_torch_save.assert_called_with(expected_test_data, path_str) + + self.assertTrue( + mock_torch_save.called, "The processed data was not saved as expected." + ) if __name__ == "__main__": diff --git a/tests/unit/mock_data/tox_mock_data.py b/tests/unit/mock_data/tox_mock_data.py index 96b31a91..32745c38 100644 --- a/tests/unit/mock_data/tox_mock_data.py +++ b/tests/unit/mock_data/tox_mock_data.py @@ -1,3 +1,6 @@ +from typing import Dict, List + + class Tox21MolNetMockData: """ A utility class providing mock data for testing the Tox21MolNet dataset. @@ -27,7 +30,7 @@ def get_raw_data() -> str: ) @staticmethod - def get_processed_data() -> list: + def get_processed_data() -> List[Dict]: """ Returns a list of dictionaries simulating the processed data for the Tox21MolNet dataset. Each dictionary contains 'ident', 'features', and 'labels'. @@ -180,7 +183,7 @@ def get_processed_data() -> list: ] @staticmethod - def get_processed_grouped_data(): + def get_processed_grouped_data() -> List[Dict]: """ Returns a list of dictionaries simulating the processed data for the Tox21MolNet dataset. Each dictionary contains 'ident', 'features', and 'labels'. @@ -223,7 +226,7 @@ class Tox21ChallengeMockData: FEATURE_OF_SMILES = [19, 42, 19, 19, 19, 42] @staticmethod - def get_raw_train_data(): + def get_raw_train_data() -> bytes: raw_str = ( Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" b"25848\n\n" @@ -280,7 +283,7 @@ def get_raw_train_data(): return raw_str @staticmethod - def data_in_dict_format(): + def data_in_dict_format() -> List[Dict]: data_list = [ { "labels": [ @@ -391,3 +394,114 @@ def data_in_dict_format(): dict_["group"] = None return data_list + + @staticmethod + def get_raw_smiles_data() -> str: + """ + Returns mock SMILES data in a tab-delimited format (mocks test.smiles file). + + The data represents molecules and their associated sample IDs. + + Returns: + str: A string containing SMILES representations and corresponding sample IDs. + """ + return ( + "#SMILES\tSample ID\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00260869-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261776-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261380-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261842-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261662-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261190-01\n" + ) + + @staticmethod + def get_raw_score_txt_data() -> str: + """ + Returns mock score data in a tab-delimited format (mocks test_results.txt file). + + The data represents toxicity test results for different molecular samples, including several toxicity endpoints. + + Returns: + str: A string containing toxicity scores for each molecular sample and corresponding toxicity endpoints. 
+ """ + return ( + "Sample ID\tNR-AhR\tNR-AR\tNR-AR-LBD\tNR-Aromatase\tNR-ER\tNR-ER-LBD\tNR-PPAR-gamma\t" + "SR-ARE\tSR-ATAD5\tSR-HSE\tSR-MMP\tSR-p53\n" + "NCGC00260869-01\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n" + "NCGC00261776-01\t1\t1\t1\t1\t1\t1\t1\t1\t1\t1\t1\t1\n" + "NCGC00261380-01\tx\tx\tx\tx\tx\tx\tx\tx\tx\tx\tx\tx\n" + "NCGC00261842-01\t0\t0\t0\tx\t0\t0\t0\t0\t0\t0\tx\t1\n" + "NCGC00261662-01\t1\t0\t0\tx\t1\t1\t1\tx\t1\t1\tx\t1\n" + "NCGC00261190-01\tx\t0\t0\tx\t1\t0\t0\t1\t0\t0\t1\t1\n" + ) + + @staticmethod + def get_setup_processed_output_data() -> List[Dict]: + """ + Returns mock processed data used for testing the `setup_processed` method. + + The data contains molecule identifiers and their corresponding toxicity labels for multiple endpoints. + Each dictionary in the list represents a molecule with its associated labels, features, and group information. + + Returns: + List[Dict]: A list of dictionaries where each dictionary contains: + - "features": The SMILES features of the molecule. + - "labels": A list of toxicity endpoint labels (0, 1, or None). + - "ident": The sample identifier. + - "group": None (default value for the group key). + """ + + # "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", + # "SR-HSE", "SR-MMP", "SR-p53", + data_list = [ + { + "labels": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "ident": "NCGC00260869-01", + }, + { + "labels": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "ident": "NCGC00261776-01", + }, + { + "labels": [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ], + "ident": "NCGC00261380-01", + }, + { + "labels": [0, 0, 0, None, 0, 0, 0, 0, 0, 0, None, 1], + "ident": "NCGC00261842-01", + }, + { + "labels": [0, 0, 1, None, 1, 1, 1, None, 1, 1, None, 1], + "ident": "NCGC00261662-01", + }, + { + "labels": [0, 0, None, None, 1, 0, 0, 1, 0, 0, 1, 1], + "ident": "NCGC00261190-01", + }, + ] + + complete_list = [] + for dict_ in data_list: + complete_list.append( + { + "features": Tox21ChallengeMockData.FEATURE_OF_SMILES, + **dict_, + "group": None, + } + ) + + return complete_list From ac3ac19deed760fb422a60f8f8b2e84bc45540cb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 17 Sep 2024 13:12:35 +0200 Subject: [PATCH 28/46] patch `os.makedirs` in tests to avoid creating directories --- tests/unit/dataset_classes/testChEBIOverX.py | 4 +- .../dataset_classes/testChebiDataExtractor.py | 6 +- .../dataset_classes/testChebiOverXPartial.py | 3 +- .../dataset_classes/testDynamicDataset.py | 6 +- .../testGOUniProDataExtractor.py | 6 +- .../dataset_classes/testGoUniProtOverX.py | 3 +- .../dataset_classes/testTox21Challenge.py | 3 +- tests/unit/dataset_classes/testTox21MolNet.py | 55 +------------------ .../dataset_classes/testXYBaseDataModule.py | 3 +- 9 files changed, 29 insertions(+), 60 deletions(-) diff --git a/tests/unit/dataset_classes/testChEBIOverX.py b/tests/unit/dataset_classes/testChEBIOverX.py index 78d85dd4..270b868c 100644 --- a/tests/unit/dataset_classes/testChEBIOverX.py +++ b/tests/unit/dataset_classes/testChEBIOverX.py @@ -9,11 +9,13 @@ class TestChEBIOverX(unittest.TestCase): @classmethod @patch.multiple(ChEBIOverX, __abstractmethods__=frozenset()) @patch.object(ChEBIOverX, "processed_dir_main", new_callable=PropertyMock) - def setUpClass(cls, mock_processed_dir_main: PropertyMock) -> None: + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs, mock_processed_dir_main: PropertyMock) -> None: """ Set up the 
ChEBIOverX instance with a mock processed directory path and a test graph. Args: + mock_makedirs: This patches os.makedirs to do nothing mock_processed_dir_main (PropertyMock): Mocked property for the processed directory path. """ mock_processed_dir_main.return_value = "/mock/processed_dir" diff --git a/tests/unit/dataset_classes/testChebiDataExtractor.py b/tests/unit/dataset_classes/testChebiDataExtractor.py index 0559e090..8da900da 100644 --- a/tests/unit/dataset_classes/testChebiDataExtractor.py +++ b/tests/unit/dataset_classes/testChebiDataExtractor.py @@ -14,8 +14,12 @@ class TestChEBIDataExtractor(unittest.TestCase): @patch.multiple(_ChEBIDataExtractor, __abstractmethods__=frozenset()) @patch.object(_ChEBIDataExtractor, "base_dir", new_callable=PropertyMock) @patch.object(_ChEBIDataExtractor, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) def setUpClass( - cls, mock_name_property: PropertyMock, mock_base_dir_property: PropertyMock + cls, + mock_makedirs, + mock_name_property: PropertyMock, + mock_base_dir_property: PropertyMock, ) -> None: """ Set up a base instance of _ChEBIDataExtractor for testing with mocked properties. diff --git a/tests/unit/dataset_classes/testChebiOverXPartial.py b/tests/unit/dataset_classes/testChebiOverXPartial.py index a8c53408..7720d301 100644 --- a/tests/unit/dataset_classes/testChebiOverXPartial.py +++ b/tests/unit/dataset_classes/testChebiOverXPartial.py @@ -11,7 +11,8 @@ class TestChEBIOverX(unittest.TestCase): @classmethod @patch.multiple(ChEBIOverXPartial, __abstractmethods__=frozenset()) - def setUpClass(cls) -> None: + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs) -> None: """ Set up the ChEBIOverXPartial instance with a mock processed directory path and a test graph. """ diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py index 1ff6c26d..e42c3e7e 100644 --- a/tests/unit/dataset_classes/testDynamicDataset.py +++ b/tests/unit/dataset_classes/testDynamicDataset.py @@ -17,8 +17,12 @@ class TestDynamicDataset(unittest.TestCase): @patch.multiple(_DynamicDataset, __abstractmethods__=frozenset()) @patch.object(_DynamicDataset, "base_dir", new_callable=PropertyMock) @patch.object(_DynamicDataset, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) def setUpClass( - cls, mock_base_dir_property: PropertyMock, mock_name_property: PropertyMock + cls, + mock_makedirs, + mock_base_dir_property: PropertyMock, + mock_name_property: PropertyMock, ) -> None: """ Set up a base instance of _DynamicDataset for testing with mocked properties. 
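Every hunk in this patch applies the same recipe: `os.makedirs` is patched to a no-op for the
duration of `setUpClass`, so instantiating a data module during test setup never creates its
directory tree on disk. A minimal sketch of the pattern, with `MyDataModule` as a hypothetical
stand-in for the real data modules:

    import os
    import unittest
    from unittest.mock import MagicMock, patch


    class MyDataModule:
        """Hypothetical module that prepares its directories on construction."""

        def __init__(self) -> None:
            os.makedirs("data/processed", exist_ok=True)


    class TestMyDataModule(unittest.TestCase):
        @classmethod
        @patch("os.makedirs", return_value=None)
        def setUpClass(cls, mock_makedirs: MagicMock) -> None:
            # os.makedirs is a no-op mock while setUpClass runs, so the
            # constructor call below touches no real directories.
            cls.data_module = MyDataModule()

        def test_instance_created(self) -> None:
            self.assertIsInstance(self.data_module, MyDataModule)

As in the patched tests, `@classmethod` stays outermost and the injected mock arrives as the
first argument after `cls`.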
diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py index 1b60aa97..976334f0 100644 --- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -18,8 +18,12 @@ class TestGOUniProtDataExtractor(unittest.TestCase): @patch.multiple(_GOUniProtDataExtractor, __abstractmethods__=frozenset()) @patch.object(_GOUniProtDataExtractor, "base_dir", new_callable=PropertyMock) @patch.object(_GOUniProtDataExtractor, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) def setUpClass( - cls, mock_name_property: PropertyMock, mock_base_dir_property: PropertyMock + cls, + mock_makedirs, + mock_name_property: PropertyMock, + mock_base_dir_property: PropertyMock, ) -> None: """ Class setup for mocking abstract properties of _GOUniProtDataExtractor. diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py index 282091b5..d4157770 100644 --- a/tests/unit/dataset_classes/testGoUniProtOverX.py +++ b/tests/unit/dataset_classes/testGoUniProtOverX.py @@ -12,7 +12,8 @@ class TestGOUniProtOverX(unittest.TestCase): @classmethod @patch.multiple(_GOUniProtOverX, __abstractmethods__=frozenset()) - def setUpClass(cls) -> None: + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs) -> None: """ Set up the class for tests by initializing the extractor, graph, and input DataFrame. """ diff --git a/tests/unit/dataset_classes/testTox21Challenge.py b/tests/unit/dataset_classes/testTox21Challenge.py index 9986c82f..b94c8ca4 100644 --- a/tests/unit/dataset_classes/testTox21Challenge.py +++ b/tests/unit/dataset_classes/testTox21Challenge.py @@ -17,7 +17,8 @@ class TestTox21Challenge(unittest.TestCase): """ @classmethod - def setUpClass(cls) -> None: + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs) -> None: """ Set up the Tox21Challenge instance and mock data for testing. This is run once for the test class. diff --git a/tests/unit/dataset_classes/testTox21MolNet.py b/tests/unit/dataset_classes/testTox21MolNet.py index 0a2d67b1..c995e701 100644 --- a/tests/unit/dataset_classes/testTox21MolNet.py +++ b/tests/unit/dataset_classes/testTox21MolNet.py @@ -13,14 +13,13 @@ class TestTox21MolNet(unittest.TestCase): @classmethod - def setUpClass(cls) -> None: + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs) -> None: """Initialize a Tox21MolNet instance for testing.""" ReaderMock = MagicMock() ReaderMock.name.return_value = "MockedReaderTox21MolNet" Tox21MolNet.READER = ReaderMock cls.data_module = Tox21MolNet() - # cls.data_module.raw_dir = "/mock/raw_dir" - # cls.data_module.processed_dir = "/mock/processed_dir" @patch( "builtins.open", @@ -59,57 +58,9 @@ def test_setup_processed_simple_split( mock_load_data (MagicMock): Mocked `_load_data_from_file` method to provide controlled data. mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. """ + # Facing technical error here self.data_module.setup_processed() - # # Check that torch.save was called for train, test, and validation splits - # self.assertEqual( - # mock_torch_save.call_count, - # 3, - # "torch.save should have been called exactly three times for train, test, and validation splits." 
- # ) - - # @patch("os.path.isfile", return_value=False) - # @patch.object(Tox21MolNet, - # "_load_data_from_file", - # return_value= Tox21MolNetMockData.get_processed_grouped_data()) - # @patch("torch.save") - # @patch("torch.load") - # @patch("chebai.preprocessing.datasets.tox21.GroupShuffleSplit") - # def test_setup_processed_group_split( - # self, - # mock_group_split: MagicMock, - # mock_torch_load: MagicMock, - # mock_save: MagicMock, - # mock_load_data: MagicMock, - # mock_isfile: MagicMock - # ) -> None: - # """ - # Test the `setup_processed` method for group-based data splitting and saving. - # - # Args: - # mock_save (MagicMock): Mocked `torch.save` function to avoid file writes. - # mock_load_data (MagicMock): Mocked `_load_data_from_file` method to provide controlled data. - # mock_isfile (MagicMock): Mocked `os.path.isfile` function to simulate file presence. - # mock_group_split (MagicMock): Mocked `GroupShuffleSplit` to control data splitting behavior. - # """ - # mock_group_split.return_value = GroupShuffleSplit(n_splits=1, train_size=0.7) - # self.data_module.setup_processed() - # - # # Load the test split - # test_split_path = os.path.join(self.data_module.processed_dir, "test.pt") - # test_split = torch.load(test_split_path) - # - # # Check if torch.save was called with correct arguments - # mock_save.assert_any_call([mock_data[1]], "/mock/processed_dir/test.pt") - # mock_save.assert_any_call([mock_data[0]], "/mock/processed_dir/train.pt") - # mock_save.assert_any_call([mock_data[1]], "/mock/processed_dir/validation.pt") - # # Check that torch.save was called for train, test, and validation splits - # self.assertEqual( - # mock_torch_save.call_count, - # 3, - # "torch.save should have been called exactly three times for train, test, and validation splits." - # ) - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/dataset_classes/testXYBaseDataModule.py b/tests/unit/dataset_classes/testXYBaseDataModule.py index 8e3575ab..64dfbe40 100644 --- a/tests/unit/dataset_classes/testXYBaseDataModule.py +++ b/tests/unit/dataset_classes/testXYBaseDataModule.py @@ -11,7 +11,8 @@ class TestXYBaseDataModule(unittest.TestCase): @classmethod @patch.object(XYBaseDataModule, "_name", new_callable=PropertyMock) - def setUpClass(cls, mock_name_property: PropertyMock) -> None: + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs, mock_name_property: PropertyMock) -> None: """ Set up a base instance of XYBaseDataModule for testing. """ From 44a1dfda8f92627b3bab97f62ab9101452a2754e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 22 Sep 2024 12:39:42 +0200 Subject: [PATCH 29/46] add test case for invalid token/input to read_data --- tests/unit/readers/testChemDataReader.py | 10 ++++++++++ tests/unit/readers/testDeepChemDataReader.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py index fde8604f..0c1c4d6f 100644 --- a/tests/unit/readers/testChemDataReader.py +++ b/tests/unit/readers/testChemDataReader.py @@ -92,6 +92,16 @@ def test_read_data_with_new_token(self) -> None: "The new token '[H-]' was not added at the correct index in the cache.", ) + def test_read_data_with_invalid_input(self) -> None: + """ + Test the _read_data method with an invalid input. + The invalid token should raise an error or be handled appropriately. 
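+        Here "%INVALID%" is not tokenizable as SMILES, so `_read_data` is
+        expected to raise a ValueError.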
+ """ + raw_data = "%INVALID%" + + with self.assertRaises(ValueError): + self.reader._read_data(raw_data) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/readers/testDeepChemDataReader.py b/tests/unit/readers/testDeepChemDataReader.py index 31a63dd1..dc29c9a6 100644 --- a/tests/unit/readers/testDeepChemDataReader.py +++ b/tests/unit/readers/testDeepChemDataReader.py @@ -100,6 +100,16 @@ def test_read_data_with_new_token(self) -> None: "The new token '[H-]' was not added to the correct index in the cache.", ) + def test_read_data_with_invalid_input(self) -> None: + """ + Test the _read_data method with an invalid input string. + The invalid token should raise an error or be handled appropriately. + """ + raw_data = "CBr))(OCI" + + with self.assertRaises(Exception): + self.reader._read_data(raw_data) + if __name__ == "__main__": unittest.main() From aab0fea1df5801b047e0f1ba9e3d2bce9f928f91 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 25 Sep 2024 13:57:54 +0200 Subject: [PATCH 30/46] test case for `Tox21MolNet.setup_processed` simple split --- tests/unit/dataset_classes/testTox21MolNet.py | 43 +++++++++++++++---- tests/unit/mock_data/tox_mock_data.py | 5 ++- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/tests/unit/dataset_classes/testTox21MolNet.py b/tests/unit/dataset_classes/testTox21MolNet.py index c995e701..042a6ae4 100644 --- a/tests/unit/dataset_classes/testTox21MolNet.py +++ b/tests/unit/dataset_classes/testTox21MolNet.py @@ -42,25 +42,52 @@ def test_load_data_from_file(self, mock_open_file: mock_open) -> None: "The loaded data does not match the expected output.", ) - @patch.object( - Tox21MolNet, - "_load_data_from_file", - return_value=Tox21MolNetMockData.get_processed_data(), + @patch( + "builtins.open", + new_callable=mock_open, + read_data=Tox21MolNetMockData.get_raw_data(), ) @patch("torch.save") def test_setup_processed_simple_split( - self, mock_load_data: MagicMock, mock_torch_save: MagicMock + self, + mock_torch_save, + mock_open_file: mock_open, ) -> None: """ Test the `setup_processed` method for basic data splitting and saving. Args: - mock_load_data (MagicMock): Mocked `_load_data_from_file` method to provide controlled data. - mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. + mock_torch_save : Mocked `torch.save` function to avoid actual file writes. + mock_open_file (mock_open): Mocked `open` builtin-method to provide custom data. 
""" - # Facing technical error here self.data_module.setup_processed() + # Verify if torch.save was called for each split + self.assertEqual(mock_torch_save.call_count, 3) + call_args_list = mock_torch_save.call_args_list + self.assertIn("test", call_args_list[0][0][1]) + self.assertIn("train", call_args_list[1][0][1]) + self.assertIn("validation", call_args_list[2][0][1]) + + # Check for non-overlap between train, test, and validation + test_split = [d["ident"] for d in call_args_list[0][0][0]] + train_split = [d["ident"] for d in call_args_list[1][0][0]] + validation_split = [d["ident"] for d in call_args_list[2][0][0]] + + # Assert no overlap between splits + self.assertTrue( + set(train_split).isdisjoint(test_split), + "There is an overlap between the train and test splits.", + ) + self.assertTrue( + set(train_split).isdisjoint(validation_split), + "There is an overlap between the train and validation splits.", + ) + self.assertTrue( + set(test_split).isdisjoint(validation_split), + "There is an overlap between the test and validation splits.", + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/mock_data/tox_mock_data.py b/tests/unit/mock_data/tox_mock_data.py index 32745c38..b5f85bda 100644 --- a/tests/unit/mock_data/tox_mock_data.py +++ b/tests/unit/mock_data/tox_mock_data.py @@ -35,7 +35,7 @@ def get_processed_data() -> List[Dict]: Returns a list of dictionaries simulating the processed data for the Tox21MolNet dataset. Each dictionary contains 'ident', 'features', and 'labels'. """ - return [ + data_list = [ { "ident": "TOX958", "features": "Nc1ccc([N+](=O)[O-])cc1N", @@ -182,6 +182,9 @@ def get_processed_data() -> List[Dict]: }, ] + data_with_group = [{**data, "group": None} for data in data_list] + return data_with_group + @staticmethod def get_processed_grouped_data() -> List[Dict]: """ From fc8182e0cc80187fcdf6ce8d9b0e783030378c5e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 25 Sep 2024 19:11:35 +0200 Subject: [PATCH 31/46] test case for `Tox21MolNet.setup_processed` group split --- tests/unit/dataset_classes/testTox21MolNet.py | 117 ++++++++++++++---- 1 file changed, 93 insertions(+), 24 deletions(-) diff --git a/tests/unit/dataset_classes/testTox21MolNet.py b/tests/unit/dataset_classes/testTox21MolNet.py index 042a6ae4..5d5f3497 100644 --- a/tests/unit/dataset_classes/testTox21MolNet.py +++ b/tests/unit/dataset_classes/testTox21MolNet.py @@ -1,21 +1,21 @@ -import os import unittest -from typing import Dict, List +from typing import List from unittest.mock import MagicMock, mock_open, patch -import torch -from sklearn.model_selection import GroupShuffleSplit - from chebai.preprocessing.datasets.tox21 import Tox21MolNet from tests.unit.mock_data.tox_mock_data import Tox21MolNetMockData class TestTox21MolNet(unittest.TestCase): - @classmethod @patch("os.makedirs", return_value=None) - def setUpClass(cls, mock_makedirs) -> None: - """Initialize a Tox21MolNet instance for testing.""" + def setUpClass(cls, mock_makedirs: MagicMock) -> None: + """ + Initialize a Tox21MolNet instance for testing. + + Args: + mock_makedirs (MagicMock): Mocked `os.makedirs` function. 
+ """ ReaderMock = MagicMock() ReaderMock.name.return_value = "MockedReaderTox21MolNet" Tox21MolNet.READER = ReaderMock @@ -39,7 +39,7 @@ def test_load_data_from_file(self, mock_open_file: mock_open) -> None: self.assertEqual( list(actual_data), expected_data, - "The loaded data does not match the expected output.", + "The loaded data does not match the expected output from the file.", ) @patch( @@ -50,42 +50,111 @@ def test_load_data_from_file(self, mock_open_file: mock_open) -> None: @patch("torch.save") def test_setup_processed_simple_split( self, - mock_torch_save, + mock_torch_save: MagicMock, mock_open_file: mock_open, ) -> None: """ Test the `setup_processed` method for basic data splitting and saving. Args: - mock_torch_save : Mocked `torch.save` function to avoid actual file writes. - mock_open_file (mock_open): Mocked `open` builtin-method to provide custom data. + mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. + mock_open_file (mock_open): Mocked `open` function to simulate file reading. + """ + self.data_module.setup_processed() + + # Verify if torch.save was called for each split (train, test, validation) + self.assertEqual( + mock_torch_save.call_count, 3, "Expected torch.save to be called 3 times." + ) + call_args_list = mock_torch_save.call_args_list + self.assertIn("test", call_args_list[0][0][1], "Missing 'test' split.") + self.assertIn("train", call_args_list[1][0][1], "Missing 'train' split.") + self.assertIn( + "validation", call_args_list[2][0][1], "Missing 'validation' split." + ) + + # Check for non-overlap between train, test, and validation splits + test_split: List[str] = [d["ident"] for d in call_args_list[0][0][0]] + train_split: List[str] = [d["ident"] for d in call_args_list[1][0][0]] + validation_split: List[str] = [d["ident"] for d in call_args_list[2][0][0]] + + self.assertTrue( + set(train_split).isdisjoint(test_split), + "Overlap detected between the train and test splits.", + ) + self.assertTrue( + set(train_split).isdisjoint(validation_split), + "Overlap detected between the train and validation splits.", + ) + self.assertTrue( + set(test_split).isdisjoint(validation_split), + "Overlap detected between the test and validation splits.", + ) + + @patch.object( + Tox21MolNet, + "_load_data_from_file", + return_value=Tox21MolNetMockData.get_processed_grouped_data(), + ) + @patch("torch.save") + def test_setup_processed_with_group_split( + self, mock_torch_save: MagicMock, mock_load_file: MagicMock + ) -> None: + """ + Test the `setup_processed` method for group-based splitting and saving. + + Args: + mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. + mock_load_file (MagicMock): Mocked `_load_data_from_file` to provide custom data. """ + self.data_module.train_split = 0.5 self.data_module.setup_processed() # Verify if torch.save was called for each split - self.assertEqual(mock_torch_save.call_count, 3) + self.assertEqual( + mock_torch_save.call_count, 3, "Expected torch.save to be called 3 times." + ) call_args_list = mock_torch_save.call_args_list - self.assertIn("test", call_args_list[0][0][1]) - self.assertIn("train", call_args_list[1][0][1]) - self.assertIn("validation", call_args_list[2][0][1]) + self.assertIn("test", call_args_list[0][0][1], "Missing 'test' split.") + self.assertIn("train", call_args_list[1][0][1], "Missing 'train' split.") + self.assertIn( + "validation", call_args_list[2][0][1], "Missing 'validation' split." 
+ ) - # Check for non-overlap between train, test, and validation - test_split = [d["ident"] for d in call_args_list[0][0][0]] - train_split = [d["ident"] for d in call_args_list[1][0][0]] - validation_split = [d["ident"] for d in call_args_list[2][0][0]] + # Check for non-overlap between train, test, and validation splits (based on 'ident') + test_split: List[str] = [d["ident"] for d in call_args_list[0][0][0]] + train_split: List[str] = [d["ident"] for d in call_args_list[1][0][0]] + validation_split: List[str] = [d["ident"] for d in call_args_list[2][0][0]] - # Assert no overlap between splits self.assertTrue( set(train_split).isdisjoint(test_split), - "There is an overlap between the train and test splits.", + "Overlap detected between the train and test splits (based on 'ident').", ) self.assertTrue( set(train_split).isdisjoint(validation_split), - "There is an overlap between the train and validation splits.", + "Overlap detected between the train and validation splits (based on 'ident').", ) self.assertTrue( set(test_split).isdisjoint(validation_split), - "There is an overlap between the test and validation splits.", + "Overlap detected between the test and validation splits (based on 'ident').", + ) + + # Check for non-overlap between train, test, and validation splits (based on 'group') + test_split_grp: List[str] = [d["group"] for d in call_args_list[0][0][0]] + train_split_grp: List[str] = [d["group"] for d in call_args_list[1][0][0]] + validation_split_grp: List[str] = [d["group"] for d in call_args_list[2][0][0]] + + self.assertTrue( + set(train_split_grp).isdisjoint(test_split_grp), + "Overlap detected between the train and test splits (based on 'group').", + ) + self.assertTrue( + set(train_split_grp).isdisjoint(validation_split_grp), + "Overlap detected between the train and validation splits (based on 'group').", + ) + self.assertTrue( + set(test_split_grp).isdisjoint(validation_split_grp), + "Overlap detected between the test and validation splits (based on 'group').", ) From e4caae8c68368bffb9b018d35b1298f3887a5500 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 25 Sep 2024 19:14:06 +0200 Subject: [PATCH 32/46] add group key + convert generator to list --- chebai/preprocessing/datasets/tox21.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/tox21.py b/chebai/preprocessing/datasets/tox21.py index 4bdfbdee..98d78009 100644 --- a/chebai/preprocessing/datasets/tox21.py +++ b/chebai/preprocessing/datasets/tox21.py @@ -68,7 +68,7 @@ def download(self) -> None: def setup_processed(self) -> None: """Processes and splits the dataset.""" print("Create splits") - data = self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv")) + data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv"))) groups = np.array([d["group"] for d in data]) if not all(g is None for g in groups): split_size = int(len(set(groups)) * self.train_split) @@ -145,7 +145,10 @@ def _load_data_from_file(self, input_file_path: str) -> List[Dict]: labels = [ bool(int(l)) if l else None for l in (row[k] for k in self.HEADERS) ] - yield dict(features=smiles, labels=labels, ident=row["mol_id"]) + group = row.get("group", None) + yield dict( + features=smiles, labels=labels, ident=row["mol_id"], group=group + ) class Tox21Challenge(XYBaseDataModule): From 1d3ecbe327b63324c52347ccc806a25c51471d40 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 26 Sep 2024 00:17:14 +0200 Subject: [PATCH 33/46] update chebi test as per modified 
term_callback --- .../dataset_classes/testChebiTermCallback.py | 10 +++++--- tests/unit/mock_data/ontology_mock_data.py | 25 ++++++------------- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/tests/unit/dataset_classes/testChebiTermCallback.py b/tests/unit/dataset_classes/testChebiTermCallback.py index 7b22d1a2..8680760e 100644 --- a/tests/unit/dataset_classes/testChebiTermCallback.py +++ b/tests/unit/dataset_classes/testChebiTermCallback.py @@ -51,11 +51,13 @@ def test_skip_obsolete_terms(self) -> None: """ Test that `term_callback` correctly skips obsolete ChEBI terms. """ + term_callback_output = [] + for ident in ChebiMockOntology.get_obsolete_nodes_ids(): + raw_term = self.callback_input_data.get(ident) + term_dict = term_callback(raw_term) + if term_dict: + term_callback_output.append(term_dict) - term_callback_output = [ - term_callback(self.callback_input_data.get(ident)) - for ident in ChebiMockOntology.get_obsolete_nodes_ids() - ] self.assertEqual( term_callback_output, [], diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index 478a2bbb..40d9674e 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -356,24 +356,15 @@ def get_data_in_dataframe() -> pd.DataFrame: "C1=CC=CC=C1Br", "C1=CC=CC=C1[Mg+]", ], - # Relationships { - # 12345: [11111, 54321, 22222, 67890], - # 67890: [22222], - # 99999: [67890, 11111, 54321, 22222, 12345], - # 54321: [11111], - # 88888: [22222, 67890] - # 11111: [] - # 22222: [] - # } **{ - # -row- [11111, 12345, 22222, 54321, 67890, 88888, 99999] - 11111: [False, False, False, False, False, False, False], - 12345: [True, True, True, True, True, False, False], - 22222: [False, False, False, False, False, False, False], - 54321: [True, False, False, True, False, False, False], - 67890: [False, False, True, False, True, False, False], - 88888: [False, False, True, False, True, True, False], - 99999: [True, True, True, True, True, False, True], + # -row- [12345, 54321, 67890, 11111, 22222, 99999, 88888] + 11111: [True, True, False, True, False, True, False], + 12345: [True, False, False, False, False, True, False], + 22222: [True, False, True, False, True, True, True], + 54321: [True, True, False, False, False, True, False], + 67890: [True, False, True, False, False, True, True], + 88888: [False, False, False, False, False, False, True], + 99999: [False, False, False, False, False, True, False], }, ) From 35a621cee6cfd3c732d6e851ba2bc320defa760d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 26 Sep 2024 00:30:49 +0200 Subject: [PATCH 34/46] group key not needed for Tox21Chal._load_dict - group key needed in Tox21MolNet but not needed for Tox21Chal._load_dict --- tests/unit/dataset_classes/testTox21Challenge.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/dataset_classes/testTox21Challenge.py b/tests/unit/dataset_classes/testTox21Challenge.py index b94c8ca4..fedde8e5 100644 --- a/tests/unit/dataset_classes/testTox21Challenge.py +++ b/tests/unit/dataset_classes/testTox21Challenge.py @@ -67,6 +67,9 @@ def test_load_dict(self, mock_open_file: mock_open) -> None: mock_open_file (mock_open): Mocked open function to simulate file reading. 
""" expected_data = Tox21MolNetMockData.get_processed_data() + for item in expected_data: + item.pop("group", None) + actual_data = self.tox21._load_dict("fake/file/path.csv") self.assertEqual( From 016134f815c810f989566b94759514588cd09e02 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 1 Oct 2024 20:33:02 +0200 Subject: [PATCH 35/46] Obsolete terms being the parent of valid terms --- tests/unit/mock_data/ontology_mock_data.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index 40d9674e..0c713334 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -532,12 +532,21 @@ def get_obsolete_nodes_ids() -> Set[int]: @staticmethod def get_GO_raw_data() -> str: """ - Get raw data in string format for GO ontology. + Get raw data in string format for a basic Gene Ontology (GO) structure. - This data simulates a basic GO ontology in a format typically used for testing. + This data simulates a basic GO ontology format typically used for testing purposes. + The data will include valid and obsolete GO terms with various relationships between them. + + Scenarios covered: + - Obsolete terms being the parent of valid terms. + - Valid terms being the parent of obsolete terms. + - Both direct and indirect hierarchical relationships between terms. + + The data is designed to help test the proper handling of obsolete and valid GO terms, + ensuring that the ontology parser can correctly manage both cases. Returns: - str: The raw GO data in string format. + str: The raw GO data in string format, structured as test input. """ return """ [Term] @@ -557,6 +566,7 @@ def get_GO_raw_data() -> str: name: GO_2 namespace: biological_process is_a: GO:0000001 ! hydrolase activity, hydrolyzing O-glycosyl compounds + is_a: GO:0000008 ! hydrolase activity, hydrolyzing O-glycosyl compounds [Term] id: GO:0000003 @@ -594,7 +604,6 @@ def get_GO_raw_data() -> str: id: GO:0000008 name: GO_8 namespace: molecular_function - is_a: GO:0000001 ! 
glucoside transport
         is_obsolete: true
 
         [Typedef]

From b479d5aff52adb580346ea70f3736e4ef876ac1a Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 5 Oct 2024 12:54:10 +0200
Subject: [PATCH 36/46] remove absolute path for mocked open func

---
 tests/unit/dataset_classes/testTox21Challenge.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/dataset_classes/testTox21Challenge.py b/tests/unit/dataset_classes/testTox21Challenge.py
index fedde8e5..9ad2af21 100644
--- a/tests/unit/dataset_classes/testTox21Challenge.py
+++ b/tests/unit/dataset_classes/testTox21Challenge.py
@@ -38,7 +38,7 @@ def test_load_data_from_file(self, mock_sdmol_supplier: patch) -> None:
         mock_file = mock_open(read_data=Tox21ChallengeMockData.get_raw_train_data())
         with patch("builtins.open", mock_file):
             with open(
-                r"G:\github-aditya0by0\chebai_data\tox21_challenge\tox21_10k_data_all.sdf\tox21_10k_data_all.sdf",
+                r"fake/path",
                 "rb",
             ) as f:
                 suppl = Chem.ForwardSDMolSupplier(f)

From adedc093435a8fb53d5bdb8c0210f65204c4d45d Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 5 Oct 2024 16:17:58 +0200
Subject: [PATCH 37/46] test single label split scenario implemented in #54

---
 .../dataset_classes/testChebiOverXPartial.py  | 66 +++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/tests/unit/dataset_classes/testChebiOverXPartial.py b/tests/unit/dataset_classes/testChebiOverXPartial.py
index 7720d301..76584ebf 100644
--- a/tests/unit/dataset_classes/testChebiOverXPartial.py
+++ b/tests/unit/dataset_classes/testChebiOverXPartial.py
@@ -104,6 +104,72 @@ def test_extract_class_hierarchy_with_bottom_cls(
             f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.",
         )
 
+    @patch("pandas.DataFrame.to_csv")
+    @patch("pandas.read_pickle")
+    @patch.object(ChEBIOverXPartial, "_get_data_size", return_value=4.0)
+    @patch("torch.load")
+    @patch(
+        "builtins.open",
+        new_callable=mock_open,
+        read_data=ChebiMockOntology.get_raw_data(),
+    )
+    def test_single_label_data_split(
+        self, mock_open, mock_load, mock_get_data_size, mock_read_pickle, mock_to_csv
+    ) -> None:
+        """
+        Test the single-label data splitting functionality of the ChEBIOverXPartial extractor.
+
+        This test mocks several key methods (file operations, torch loading, and pandas functions)
+        to ensure that the class hierarchy is properly extracted, data is processed into a raw dataset,
+        and the data splitting logic works as intended without actual file I/O.
+
+        It also verifies that there is no overlap between training, validation, and test sets.
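+
+        With `top_class_id` 11111 and `THRESHOLD` 3, the mock ontology is
+        expected to yield exactly one qualifying class, i.e. one label per sample.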
+ """ + self.chebi_extractor.top_class_id = 11111 + self.chebi_extractor.THRESHOLD = 3 + self.chebi_extractor.chebi_version_train = None + + graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") + data_df = self.chebi_extractor._graph_to_raw_dataset(graph) + + mock_read_pickle.return_value = data_df + data_pt = self.chebi_extractor._load_data_from_file("fake/path") + + # Verify that the data contains only 1 label + self.assertEqual(len(data_pt[0]["labels"]), 1) + + mock_load.return_value = data_pt + + # Retrieve the data splits (train, validation, and test) + train_split = self.chebi_extractor.dynamic_split_dfs["train"] + validation_split = self.chebi_extractor.dynamic_split_dfs["validation"] + test_split = self.chebi_extractor.dynamic_split_dfs["test"] + + train_idents = set(train_split["ident"]) + val_idents = set(validation_split["ident"]) + test_idents = set(test_split["ident"]) + + # Ensure there is no overlap between train and test sets + self.assertEqual( + len(train_idents.intersection(test_idents)), + 0, + "Train and test sets should not overlap.", + ) + + # Ensure there is no overlap between validation and test sets + self.assertEqual( + len(val_idents.intersection(test_idents)), + 0, + "Validation and test sets should not overlap.", + ) + + # Ensure there is no overlap between train and validation sets + self.assertEqual( + len(train_idents.intersection(val_idents)), + 0, + "Train and validation sets should not overlap.", + ) + if __name__ == "__main__": unittest.main() From 65c2d9bd6cdd1241b2b9d2cb0c69bc892f760274 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Oct 2024 17:04:10 +0200 Subject: [PATCH 38/46] test output format for Tox21MolNet._load_data_from_file --- tests/unit/dataset_classes/testTox21MolNet.py | 37 ++++++++++++++----- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/tests/unit/dataset_classes/testTox21MolNet.py b/tests/unit/dataset_classes/testTox21MolNet.py index 5d5f3497..86cbb752 100644 --- a/tests/unit/dataset_classes/testTox21MolNet.py +++ b/tests/unit/dataset_classes/testTox21MolNet.py @@ -2,7 +2,10 @@ from typing import List from unittest.mock import MagicMock, mock_open, patch +import torch + from chebai.preprocessing.datasets.tox21 import Tox21MolNet +from chebai.preprocessing.reader import ChemDataReader from tests.unit.mock_data.tox_mock_data import Tox21MolNetMockData @@ -16,9 +19,7 @@ def setUpClass(cls, mock_makedirs: MagicMock) -> None: Args: mock_makedirs (MagicMock): Mocked `os.makedirs` function. """ - ReaderMock = MagicMock() - ReaderMock.name.return_value = "MockedReaderTox21MolNet" - Tox21MolNet.READER = ReaderMock + Tox21MolNet.READER = ChemDataReader cls.data_module = Tox21MolNet() @patch( @@ -28,20 +29,38 @@ def setUpClass(cls, mock_makedirs: MagicMock) -> None: ) def test_load_data_from_file(self, mock_open_file: mock_open) -> None: """ - Test the `_load_data_from_file` method for correct CSV parsing. + Test the `_load_data_from_file` method for correct output. Args: mock_open_file (mock_open): Mocked open function to simulate file reading. 
""" - expected_data = Tox21MolNetMockData.get_processed_data() actual_data = self.data_module._load_data_from_file("fake/file/path.csv") - self.assertEqual( - list(actual_data), - expected_data, - "The loaded data does not match the expected output from the file.", + first_instance = next(actual_data) + + # Check for required keys + required_keys = ["features", "labels", "ident"] + for key in required_keys: + self.assertIn( + key, first_instance, f"'{key}' key is missing in the output data." + ) + + self.assertTrue( + all(isinstance(feature, int) for feature in first_instance["features"]), + "Not all elements in 'features' are integers.", ) + # Check that 'features' can be converted to a tensor + features = first_instance["features"] + try: + tensor_features = torch.tensor(features) + self.assertTrue( + tensor_features.ndim > 0, + "'features' should be convertible to a non-empty tensor.", + ) + except Exception as e: + self.fail(f"'features' cannot be converted to a tensor: {str(e)}") + @patch( "builtins.open", new_callable=mock_open, From a63c010f46cce5780d4f4068a01268ecec292e64 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Oct 2024 17:40:10 +0200 Subject: [PATCH 39/46] DynamicDataset: check split stratification --- .../dataset_classes/testDynamicDataset.py | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py index e42c3e7e..c8846273 100644 --- a/tests/unit/dataset_classes/testDynamicDataset.py +++ b/tests/unit/dataset_classes/testDynamicDataset.py @@ -216,6 +216,142 @@ def test_get_train_val_splits_given_test_consistency(self) -> None: obj="Validation sets should be identical for the same seed.", ) + def test_get_test_split_stratification(self) -> None: + """ + Test that the split into train and test sets maintains the stratification of labels. 
+ """ + self.dataset.train_split = 0.5 + train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) + + number_of_labels = len(self.data_df["labels"][0]) + + # Check the label distribution in the original dataset + original_pos_count, original_neg_count = ( + self.get_positive_negative_labels_counts(self.data_df) + ) + total_count = len(self.data_df) * number_of_labels + + # Calculate the expected proportions + original_pos_proportion = original_pos_count / total_count + original_neg_proportion = original_neg_count / total_count + + # Check the label distribution in the train set + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + train_total_count = len(train_df) * number_of_labels + + # Calculate the train set proportions + train_pos_proportion = train_pos_count / train_total_count + train_neg_proportion = train_neg_count / train_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + train_pos_proportion, + original_pos_proportion, + places=1, + msg="Train set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + train_neg_proportion, + original_neg_proportion, + places=1, + msg="Train set labels should maintain original negative label proportion.", + ) + + # Check the label distribution in the test set + test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( + test_df + ) + test_total_count = len(test_df) * number_of_labels + + # Calculate the test set proportions + test_pos_proportion = test_pos_count / test_total_count + test_neg_proportion = test_neg_count / test_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + test_pos_proportion, + original_pos_proportion, + places=1, + msg="Test set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + test_neg_proportion, + original_neg_proportion, + places=1, + msg="Test set labels should maintain original negative label proportion.", + ) + + def test_get_train_val_splits_given_test_stratification(self) -> None: + """ + Test that the split into train and validation sets maintains the stratification of labels. 
+ """ + self.dataset.use_inner_cross_validation = False + self.dataset.train_split = 0.5 + df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) + train_df, val_df = self.dataset.get_train_val_splits_given_test( + df_train_main, test_df, seed=42 + ) + + number_of_labels = len(self.data_df["labels"][0]) + + # Check the label distribution in the original dataset + original_pos_count, original_neg_count = ( + self.get_positive_negative_labels_counts(self.data_df) + ) + total_count = len(self.data_df) * number_of_labels + + # Calculate the expected proportions + original_pos_proportion = original_pos_count / total_count + original_neg_proportion = original_neg_count / total_count + + # Check the label distribution in the train set + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + train_total_count = len(train_df) * number_of_labels + + # Calculate the train set proportions + train_pos_proportion = train_pos_count / train_total_count + train_neg_proportion = train_neg_count / train_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + train_pos_proportion, + original_pos_proportion, + places=1, + msg="Train set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + train_neg_proportion, + original_neg_proportion, + places=1, + msg="Train set labels should maintain original negative label proportion.", + ) + + # Check the label distribution in the validation set + val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) + val_total_count = len(val_df) * number_of_labels + + # Calculate the validation set proportions + val_pos_proportion = val_pos_count / val_total_count + val_neg_proportion = val_neg_count / val_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + val_pos_proportion, + original_pos_proportion, + places=1, + msg="Validation set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + val_neg_proportion, + original_neg_proportion, + places=1, + msg="Validation set labels should maintain original negative label proportion.", + ) + @staticmethod def get_positive_negative_labels_counts(df: pd.DataFrame) -> Tuple[int, int]: """ From e3c4b6e2c2ec1b5f30a883b0d53be71532d8adf7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 12 Oct 2024 15:50:56 +0200 Subject: [PATCH 40/46] fix testcase for GO --- tests/unit/dataset_classes/testGOUniProDataExtractor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py index 976334f0..dcde90bc 100644 --- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -1,11 +1,12 @@ import unittest -from unittest.mock import MagicMock, PropertyMock, mock_open, patch +from unittest.mock import PropertyMock, mock_open, patch import fastobo import networkx as nx import pandas as pd from chebai.preprocessing.datasets.go_uniprot import _GOUniProtDataExtractor +from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData @@ -30,9 +31,8 @@ def setUpClass( """ mock_base_dir_property.return_value = "MockedBaseDirPropGOUniProtDataExtractor" mock_name_property.return_value = "MockedNamePropGOUniProtDataExtractor" - ReaderMock = 
MagicMock() - ReaderMock.name.return_value = "MockedReaderGOUniProtDataExtractor" - _GOUniProtDataExtractor.READER = ReaderMock + + _GOUniProtDataExtractor.READER = ProteinDataReader cls.extractor = _GOUniProtDataExtractor() From c1ddd17667c3532b0ca80b1196b2e6c0bb855f7f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 20 Oct 2024 11:44:56 +0200 Subject: [PATCH 41/46] update testcase as per transitive GO ids --- .../unit/dataset_classes/testGOUniProDataExtractor.py | 10 +++++++++- tests/unit/mock_data/ontology_mock_data.py | 4 ++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py index dcde90bc..9da48bee 100644 --- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -1,4 +1,5 @@ import unittest +from collections import OrderedDict from unittest.mock import PropertyMock, mock_open, patch import fastobo import networkx as nx import pandas as pd @@ -141,7 +142,14 @@ def test_get_swiss_to_go_mapping(self, mock_open) -> None: Test the extraction of SwissProt to GO term mapping. """ mapping_df = self.extractor._get_swiss_to_go_mapping() - expected_df = GOUniProtMockData.get_data_in_dataframe().iloc[:, :4] + expected_df = pd.DataFrame( + OrderedDict( + swiss_id=["Swiss_Prot_1", "Swiss_Prot_2"], + accession=["Q6GZX4", "DCGZX4"], + go_ids=[[2, 3, 5], [2, 5]], + sequence=list(GOUniProtMockData.protein_sequences().values()), + ) + ) pd.testing.assert_frame_equal( mapping_df, diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index 0c713334..d6feb33d 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -736,11 +736,11 @@ def get_data_in_dataframe() -> pd.DataFrame: expected_data = OrderedDict( swiss_id=["Swiss_Prot_1", "Swiss_Prot_2"], accession=["Q6GZX4", "DCGZX4"], - go_ids=[[2, 3, 5], [2, 5]], + go_ids=[[1, 2, 3, 5], [1, 2, 5]], sequence=list(GOUniProtMockData.protein_sequences().values()), **{ # SP_1, SP_2 - 1: [False, False], + 1: [True, True], 2: [True, True], 3: [True, False], 4: [False, False], From bf6bc4aa14999afdfa3d4ebec017791dc6edad09 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 20 Oct 2024 11:51:14 +0200 Subject: [PATCH 42/46] remove test for Tox21MolNet - this test will be added in another branch later once #53 is completed --- tests/unit/dataset_classes/testTox21MolNet.py | 181 ------------------ 1 file changed, 181 deletions(-) delete mode 100644 tests/unit/dataset_classes/testTox21MolNet.py diff --git a/tests/unit/dataset_classes/testTox21MolNet.py b/tests/unit/dataset_classes/testTox21MolNet.py deleted file mode 100644 index 86cbb752..00000000 --- a/tests/unit/dataset_classes/testTox21MolNet.py +++ /dev/null @@ -1,181 +0,0 @@ -import unittest -from typing import List -from unittest.mock import MagicMock, mock_open, patch - -import torch - -from chebai.preprocessing.datasets.tox21 import Tox21MolNet -from chebai.preprocessing.reader import ChemDataReader -from tests.unit.mock_data.tox_mock_data import Tox21MolNetMockData - - -class TestTox21MolNet(unittest.TestCase): - @classmethod - @patch("os.makedirs", return_value=None) - def setUpClass(cls, mock_makedirs: MagicMock) -> None: - """ - Initialize a Tox21MolNet instance for testing. - - Args: - mock_makedirs (MagicMock): Mocked `os.makedirs` function.
- """ - Tox21MolNet.READER = ChemDataReader - cls.data_module = Tox21MolNet() - - @patch( - "builtins.open", - new_callable=mock_open, - read_data=Tox21MolNetMockData.get_raw_data(), - ) - def test_load_data_from_file(self, mock_open_file: mock_open) -> None: - """ - Test the `_load_data_from_file` method for correct output. - - Args: - mock_open_file (mock_open): Mocked open function to simulate file reading. - """ - actual_data = self.data_module._load_data_from_file("fake/file/path.csv") - - first_instance = next(actual_data) - - # Check for required keys - required_keys = ["features", "labels", "ident"] - for key in required_keys: - self.assertIn( - key, first_instance, f"'{key}' key is missing in the output data." - ) - - self.assertTrue( - all(isinstance(feature, int) for feature in first_instance["features"]), - "Not all elements in 'features' are integers.", - ) - - # Check that 'features' can be converted to a tensor - features = first_instance["features"] - try: - tensor_features = torch.tensor(features) - self.assertTrue( - tensor_features.ndim > 0, - "'features' should be convertible to a non-empty tensor.", - ) - except Exception as e: - self.fail(f"'features' cannot be converted to a tensor: {str(e)}") - - @patch( - "builtins.open", - new_callable=mock_open, - read_data=Tox21MolNetMockData.get_raw_data(), - ) - @patch("torch.save") - def test_setup_processed_simple_split( - self, - mock_torch_save: MagicMock, - mock_open_file: mock_open, - ) -> None: - """ - Test the `setup_processed` method for basic data splitting and saving. - - Args: - mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. - mock_open_file (mock_open): Mocked `open` function to simulate file reading. - """ - self.data_module.setup_processed() - - # Verify if torch.save was called for each split (train, test, validation) - self.assertEqual( - mock_torch_save.call_count, 3, "Expected torch.save to be called 3 times." - ) - call_args_list = mock_torch_save.call_args_list - self.assertIn("test", call_args_list[0][0][1], "Missing 'test' split.") - self.assertIn("train", call_args_list[1][0][1], "Missing 'train' split.") - self.assertIn( - "validation", call_args_list[2][0][1], "Missing 'validation' split." - ) - - # Check for non-overlap between train, test, and validation splits - test_split: List[str] = [d["ident"] for d in call_args_list[0][0][0]] - train_split: List[str] = [d["ident"] for d in call_args_list[1][0][0]] - validation_split: List[str] = [d["ident"] for d in call_args_list[2][0][0]] - - self.assertTrue( - set(train_split).isdisjoint(test_split), - "Overlap detected between the train and test splits.", - ) - self.assertTrue( - set(train_split).isdisjoint(validation_split), - "Overlap detected between the train and validation splits.", - ) - self.assertTrue( - set(test_split).isdisjoint(validation_split), - "Overlap detected between the test and validation splits.", - ) - - @patch.object( - Tox21MolNet, - "_load_data_from_file", - return_value=Tox21MolNetMockData.get_processed_grouped_data(), - ) - @patch("torch.save") - def test_setup_processed_with_group_split( - self, mock_torch_save: MagicMock, mock_load_file: MagicMock - ) -> None: - """ - Test the `setup_processed` method for group-based splitting and saving. - - Args: - mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. - mock_load_file (MagicMock): Mocked `_load_data_from_file` to provide custom data. 
- """ - self.data_module.train_split = 0.5 - self.data_module.setup_processed() - - # Verify if torch.save was called for each split - self.assertEqual( - mock_torch_save.call_count, 3, "Expected torch.save to be called 3 times." - ) - call_args_list = mock_torch_save.call_args_list - self.assertIn("test", call_args_list[0][0][1], "Missing 'test' split.") - self.assertIn("train", call_args_list[1][0][1], "Missing 'train' split.") - self.assertIn( - "validation", call_args_list[2][0][1], "Missing 'validation' split." - ) - - # Check for non-overlap between train, test, and validation splits (based on 'ident') - test_split: List[str] = [d["ident"] for d in call_args_list[0][0][0]] - train_split: List[str] = [d["ident"] for d in call_args_list[1][0][0]] - validation_split: List[str] = [d["ident"] for d in call_args_list[2][0][0]] - - self.assertTrue( - set(train_split).isdisjoint(test_split), - "Overlap detected between the train and test splits (based on 'ident').", - ) - self.assertTrue( - set(train_split).isdisjoint(validation_split), - "Overlap detected between the train and validation splits (based on 'ident').", - ) - self.assertTrue( - set(test_split).isdisjoint(validation_split), - "Overlap detected between the test and validation splits (based on 'ident').", - ) - - # Check for non-overlap between train, test, and validation splits (based on 'group') - test_split_grp: List[str] = [d["group"] for d in call_args_list[0][0][0]] - train_split_grp: List[str] = [d["group"] for d in call_args_list[1][0][0]] - validation_split_grp: List[str] = [d["group"] for d in call_args_list[2][0][0]] - - self.assertTrue( - set(train_split_grp).isdisjoint(test_split_grp), - "Overlap detected between the train and test splits (based on 'group').", - ) - self.assertTrue( - set(train_split_grp).isdisjoint(validation_split_grp), - "Overlap detected between the train and validation splits (based on 'group').", - ) - self.assertTrue( - set(test_split_grp).isdisjoint(validation_split_grp), - "Overlap detected between the test and validation splits (based on 'group').", - ) - - -if __name__ == "__main__": - unittest.main() From b915b0db7f11710e8f5eabb81c070995cab13844 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 22 Oct 2024 12:53:35 +0200 Subject: [PATCH 43/46] Revert "add group key + convert generator to list" This reverts commit e4caae8c68368bffb9b018d35b1298f3887a5500. 
--- chebai/preprocessing/datasets/tox21.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/chebai/preprocessing/datasets/tox21.py b/chebai/preprocessing/datasets/tox21.py index 98d78009..4bdfbdee 100644 --- a/chebai/preprocessing/datasets/tox21.py +++ b/chebai/preprocessing/datasets/tox21.py @@ -68,7 +68,7 @@ def download(self) -> None: def setup_processed(self) -> None: """Processes and splits the dataset.""" print("Create splits") - data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv"))) + data = self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv")) groups = np.array([d["group"] for d in data]) if not all(g is None for g in groups): split_size = int(len(set(groups)) * self.train_split) @@ -145,10 +145,7 @@ def _load_data_from_file(self, input_file_path: str) -> List[Dict]: labels = [ bool(int(l)) if l else None for l in (row[k] for k in self.HEADERS) ] - group = row.get("group", None) - yield dict( - features=smiles, labels=labels, ident=row["mol_id"], group=group - ) + yield dict(features=smiles, labels=labels, ident=row["mol_id"]) class Tox21Challenge(XYBaseDataModule): From a71b199f3b569ec0d13e5c96e43fdd036308060b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 2 Nov 2024 16:40:21 +0100 Subject: [PATCH 44/46] update swiss data for pretraining test --- tests/unit/mock_data/ontology_mock_data.py | 110 +++++++++++++------ 1 file changed, 79 insertions(+), 31 deletions(-) diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index d6feb33d..92a070cb 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -632,17 +632,44 @@ def protein_sequences() -> Dict[str, str]: ), } + @staticmethod + def proteins_for_pretraining() -> List[str]: + """ + Returns a list of protein IDs that will be used for pretraining based on mock UniProt data. + + Proteins include those with: + - No GO classes or invalid GO classes (missing required evidence codes). + + Returns: + List[str]: A list of protein IDs that do not meet validation criteria. + """ + return [ + "Swiss_Prot_5", # No GO classes associated + "Swiss_Prot_6", # GO class with no evidence code + "Swiss_Prot_7", # GO class with invalid evidence code + ] + @staticmethod def get_UniProt_raw_data() -> str: """ Get raw data in string format for UniProt proteins. - This mock data contains six Swiss-Prot proteins with different properties: - - Swiss_Prot_1 and Swiss_Prot_2 are valid proteins. - - Swiss_Prot_3 has a sequence length greater than 1002. - - Swiss_Prot_4 contains "X", a non-valid amino acid in its sequence. - - Swiss_Prot_5 has no GO IDs mapped to it. - - Swiss_Prot_6 has GO IDs mapped, but no evidence codes. + This mock data contains eleven Swiss-Prot proteins with different properties: + - **Swiss_Prot_1**: A valid protein with three valid GO classes and one invalid GO class. + - **Swiss_Prot_2**: Another valid protein with two valid GO classes and one invalid. + - **Swiss_Prot_3**: Contains valid GO classes but has a sequence length > 1002. + - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'X'. + - **Swiss_Prot_5**: Has a sequence but no GO classes associated. + - **Swiss_Prot_6**: Has GO classes without any associated evidence codes. + - **Swiss_Prot_7**: Has a GO class with an invalid evidence code. + - **Swiss_Prot_8**: Has a sequence length > 1002 and only an invalid GO class.
+ - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'X', in its sequence. + - **Swiss_Prot_10**: Has a valid GO class but lacks a sequence. + - **Swiss_Prot_11**: Has only an invalid GO class and lacks a sequence. + + Note: + A valid GO label is one that has one of the following evidence codes + (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). Returns: str: The raw UniProt data in string format. @@ -650,6 +677,7 @@ def get_UniProt_raw_data() -> str: protein_sq_1 = GOUniProtMockData.protein_sequences()["Swiss_Prot_1"] protein_sq_2 = GOUniProtMockData.protein_sequences()["Swiss_Prot_2"] raw_str = ( + # Below protein with 3 valid associated GO classes and one invalid GO class f"ID Swiss_Prot_1 Reviewed; {len(protein_sq_1)} AA. \n" "AC Q6GZX4;\n" "DR GO; GO:0000002; C:membrane; EXP:UniProtKB-KW.\n" @@ -659,6 +687,7 @@ def get_UniProt_raw_data() -> str: f"SQ SEQUENCE {len(protein_sq_1)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" f" {protein_sq_1}\n" "//\n" + # Below protein with 2 valid associated GO classes and one invalid GO class f"ID Swiss_Prot_2 Reviewed; {len(protein_sq_2)} AA.\n" "AC DCGZX4;\n" "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" @@ -668,34 +697,17 @@ def get_UniProt_raw_data() -> str: f"SQ SEQUENCE {len(protein_sq_2)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" f" {protein_sq_2}\n" "//\n" - "ID Swiss_Prot_3 Reviewed; 1165 AA.\n" + # Below protein with all valid associated GO classes but sequence length greater than 1002 + f"ID Swiss_Prot_3 Reviewed; {len(protein_sq_1 * 25)} AA.\n" "AC Q6GZX4;\n" "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" "DR GO; GO:0000002; P:regulation of viral transcription; IEP:InterPro.\n" "DR GO; GO:0000005; P:regulation of viral transcription; TAS:InterPro.\n" "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" - "SQ SEQUENCE 1165 AA; 129118 MW; FE2984658CED53A8 CRC64;\n" - " MRVVVNAKAL EVPVGMSFTE WTRTLSPGSS PRFLAWNPVR PRTFKDVTDP FWNGKVFDLL\n" - " GVVNGKDDLL FPASEIQEWL EYAPNVDLAE LERIFVATHR HRGMMGFAAA VQDSLVHVDP\n" - " DSVDVTRVKD GLHKELDEHA SKAAATDVRL KRLRSVKPVD GFSDPVLIRT VFSVTVPEFG\n" - " DRTAYEIVDS AVPTGSCPYI SAGPFVKTIP GFKPAPEWPA QTAHAEGAVF FKADAEFPDT\n" - " KPLKDMYRKY SGAAVVPGDV TYPAVITFDV PQGSRHVPPE DFAARVAESL SLDLRGRPLV\n" - " EMGRVVSVRL DGMRFRPYVL TDLLVSDPDA SHVMQTDELN RAHKIKGTVY AQVCGTGQTV\n" - " SFQEKTDEDS GEAYISLRVR ARDRKGVEEL MEAAGRVMAI YSRRESEIVS FYALYDKTVA\n" - " KEAAPPRPPR KSKAPEPTGD KADRKLLRTL APDIFLPTYS RKCLHMPVIL RGAELEDARK\n" - " KGLNLMDFPL FGESERLTYA CKHPQHPYPG LRANLLPNKA KYPFVPCCYS KDQAVRPNSK\n" - " WTAYTTGNAE ARRQGRIREG VMQAEPLPEG ALIFLRRVLG QETGSKFFAL RTTGVPETPV\n" - " NAVHVAVFQR SLTAEEQAEE RAAMALDPSA MGACAQELYV EPDVDWDRWR REMGDPNVPF\n" - " NLLKYFRALE TRYDCDIYIM DNKGIIHTKA VRGRLRYRSR RPTVILHLRE ESCVPVMTPP\n" - " SDWTRGPVRN GILTFSPIDP ITVKLHDLYQ DSRPVYVDGV RVPPLRSDWL PCSGQVVDRA\n" - " GKARVFVVTP TGKMSRGSFT LVTWPMPPLA APILRTDTGF PRGRSDSPLS FLGSRFVPSG\n" - " YRRSVETGAI REITGILDGA CEACLLTHDP VLVPDPSWSD GGPPVYEDPV PSRALEGFTG\n" - " AEKKARMLVE YAKKAISIRE GSCTQESVRS FAANGGFVVS PGALDGMKVF NPRFEAPGPF\n" - " AEADWAVKVP DVKTARRLVY ALRVASVNGT CPVQEYASAS LVPNFYKTST DFVQSPAYTI\n" - " NVWRNDLDQS AVKKTRRAVV DWERGLAVPW PLPETELGFS YSLRFAGISR TFMAMNHPTW\n" - " ESAAFAALTW AKSGYCPGVT SNQIPEGEKV PTYACVKGMK PAKVLESGDG TLKLDKSSYG\n" - " DVRVSGVMIY RASEGKPMQY VSLLM\n" + f"SQ SEQUENCE {len(protein_sq_1 * 25)} AA; 129118 MW; FE2984658CED53A8 CRC64;\n" + f" {protein_sq_1 * 25}\n" "//\n" + # Below protein has a valid GO class association but an invalid amino acid `X` in its sequence "ID Swiss_Prot_4 Reviewed; 60 AA.\n" "AC 
Q6GZX4;\n" "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" @@ -705,18 +717,54 @@ def get_UniProt_raw_data() -> str: "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" + # Below protein with a sequence string but no GO class "ID Swiss_Prot_5 Reviewed; 60 AA.\n" "AC Q6GZX4;\n" "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" - "ID Swiss_Prot_5 Reviewed; 60 AA.\n" + # Below protein with a sequence string and with NO `valid` associated GO class (no evidence code) + "ID Swiss_Prot_6 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000023; P:regulation of viral transcription;\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with a sequence string and with NO `valid` associated GO class (invalid evidence code) + "ID Swiss_Prot_7 Reviewed; 60 AA.\n" "AC Q6GZX4;\n" - "DR GO; GO:0000005; P:regulation of viral transcription;\n" + "DR GO; GO:0000024; P:regulation of viral transcription; IEA:SGD.\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" - "//" + "//\n" + # Below protein with sequence length greater than 1002 but with an `Invalid` associated GO class + f"ID Swiss_Prot_8 Reviewed; {len(protein_sq_2 * 25)} AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000025; P:regulation of viral transcription; IC:Inferred.\n" + f"SQ SEQUENCE {len(protein_sq_2 * 25)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_2 * 25}\n" + "//\n" + # Below protein with a sequence string but an invalid amino acid `X` in its sequence + "ID Swiss_Prot_9 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with a `valid` associated GO class but without a sequence string + "ID Swiss_Prot_10 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000027; P:regulation of viral transcription; EXP:InterPro.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " \n" + "//\n" + # Below protein with an `Invalid` associated GO class but without a sequence string + "ID Swiss_Prot_11 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000028; P:regulation of viral transcription; ND:NoData.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " \n" + "//\n" ) return raw_str From 8abd14d7e72ec62efd5aba801c8fe547d04a12ea Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 2 Nov 2024 16:41:20 +0100 Subject: [PATCH 45/46] add test for protein pretraining class --- .../testProteinPretrainingData.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 tests/unit/dataset_classes/testProteinPretrainingData.py diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py new file mode 100644 index 00000000..d3046fdf --- /dev/null +++ b/tests/unit/dataset_classes/testProteinPretrainingData.py @@ -0,0 +1,71 @@ +import unittest +from unittest.mock import PropertyMock, mock_open, patch +from chebai.preprocessing.datasets.protein_pretraining import _ProteinPretrainingData +from chebai.preprocessing.reader import ProteinDataReader +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + +
+class TestProteinPretrainingData(unittest.TestCase): + """ + Unit tests for the _ProteinPretrainingData class. + Tests focus on data parsing and validation checks for protein pretraining. + """ + + @classmethod + @patch.multiple(_ProteinPretrainingData, __abstractmethods__=frozenset()) + @patch.object(_ProteinPretrainingData, "base_dir", new_callable=PropertyMock) + @patch.object(_ProteinPretrainingData, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass( + cls, + mock_makedirs, + mock_name_property: PropertyMock, + mock_base_dir_property: PropertyMock, + ) -> None: + """ + Class setup for mocking abstract properties of _ProteinPretrainingData. + + Mocks the required abstract properties and sets up the data extractor. + """ + mock_base_dir_property.return_value = "MockedBaseDirPropProteinPretrainingData" + mock_name_property.return_value = "MockedNameProp_ProteinPretrainingData" + + # Set the READER class for the pretraining data + _ProteinPretrainingData.READER = ProteinDataReader + + # Initialize the extractor instance + cls.extractor = _ProteinPretrainingData() + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + def test_parse_protein_data_for_pretraining(self, mock_open_file: mock_open) -> None: + """ + Tests the _parse_protein_data_for_pretraining method. + + Verifies that: + - The parsed DataFrame contains the expected protein IDs. + - The protein sequences are not empty. + """ + # Parse the pretraining data + pretrain_df = self.extractor._parse_protein_data_for_pretraining() + list_of_pretrain_swiss_ids = GOUniProtMockData.proteins_for_pretraining() + + # Assert that all expected Swiss-Prot IDs are present in the DataFrame + self.assertEqual( + set(pretrain_df['swiss_id']), + set(list_of_pretrain_swiss_ids), + msg="The parsed DataFrame does not contain the expected Swiss-Prot IDs for pretraining." + ) + + # Assert that all sequences are not empty + self.assertTrue( + pretrain_df['sequence'].str.len().gt(0).all(), + msg="Some protein sequences in the pretraining DataFrame are empty." + ) + + +if __name__ == "__main__": + unittest.main() From aae57d355b03b96600ded3e41c8e88361855d624 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 2 Nov 2024 16:44:12 +0100 Subject: [PATCH 46/46] test : reformat with precommit --- .../dataset_classes/testProteinPretrainingData.py | 13 ++++++++----- tests/unit/mock_data/ontology_mock_data.py | 8 ++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py index d3046fdf..cb6b0688 100644 --- a/tests/unit/dataset_classes/testProteinPretrainingData.py +++ b/tests/unit/dataset_classes/testProteinPretrainingData.py @@ -1,5 +1,6 @@ import unittest from unittest.mock import PropertyMock, mock_open, patch + from chebai.preprocessing.datasets.protein_pretraining import _ProteinPretrainingData from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData @@ -41,7 +42,9 @@ def setUpClass( new_callable=mock_open, read_data=GOUniProtMockData.get_UniProt_raw_data(), ) - def test_parse_protein_data_for_pretraining(self, mock_open_file: mock_open) -> None: + def test_parse_protein_data_for_pretraining( + self, mock_open_file: mock_open + ) -> None: """ Tests the _parse_protein_data_for_pretraining method. 
@@ -55,15 +58,15 @@ def test_parse_protein_data_for_pretraining(self, mock_open_file: mock_open) -> # Assert that all expected Swiss-Prot IDs are present in the DataFrame self.assertEqual( - set(pretrain_df['swiss_id']), + set(pretrain_df["swiss_id"]), set(list_of_pretrain_swiss_ids), - msg="The parsed DataFrame does not contain the expected Swiss-Prot IDs for pretraining." + msg="The parsed DataFrame does not contain the expected Swiss-Prot IDs for pretraining.", ) # Assert that all sequences are not empty self.assertTrue( - pretrain_df['sequence'].str.len().gt(0).all(), - msg="Some protein sequences in the pretraining DataFrame are empty." + pretrain_df["sequence"].str.len().gt(0).all(), + msg="Some protein sequences in the pretraining DataFrame are empty.", ) diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index 92a070cb..a05b89f1 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -751,19 +751,19 @@ def get_UniProt_raw_data() -> str: "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" - # Below protein with a `valid` associated GO class but without a sequence string + # Below protein with a `valid` associated GO class but without a sequence string "ID Swiss_Prot_10 Reviewed; 60 AA.\n" "AC Q6GZX4;\n" "DR GO; GO:0000027; P:regulation of viral transcription; EXP:InterPro.\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " \n" + " \n" "//\n" - # Below protein with an `Invalid` associated GO class but without a sequence string + # Below protein with an `Invalid` associated GO class but without a sequence string "ID Swiss_Prot_11 Reviewed; 60 AA.\n" "AC Q6GZX4;\n" "DR GO; GO:0000028; P:regulation of viral transcription; ND:NoData.\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " \n" + " \n" "//\n" )
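Background for the last three patches: per the GOUniProtMockData docstrings, a protein qualifies for pretraining when it has a usable sequence (no invalid amino acid such as "X" and at most 1002 residues) but no GO label backed by a valid evidence code (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). The sketch below illustrates that selection against the mocked flat-file layout only; it is not the implementation in chebai.preprocessing.datasets.protein_pretraining, the helper names (iter_records, parse_record, pretraining_candidates) are hypothetical, and the exact filters are assumptions drawn from the mock data's docstrings and comments.

from typing import Dict, Iterator, List

# Evidence codes treated as valid, taken from the GOUniProtMockData docstring.
VALID_EVIDENCE_CODES = {"EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "TAS", "IC"}


def iter_records(raw: str) -> Iterator[List[str]]:
    """Yield the lines of each flat-file record; records are terminated by '//'."""
    record: List[str] = []
    for line in raw.splitlines():
        if line.strip() == "//":
            if record:
                yield record
            record = []
        else:
            record.append(line)


def parse_record(lines: List[str]) -> Dict:
    """Extract the Swiss-Prot id, validly-evidenced GO ids and the sequence."""
    swiss_id = None
    valid_go_ids: List[int] = []
    sequence_parts: List[str] = []
    in_sequence = False
    for line in lines:
        tag = line[:2]
        if tag == "ID":
            swiss_id = line.split()[1]
        elif tag == "DR" and "GO;" in line:
            # e.g. "DR GO; GO:0000002; C:membrane; EXP:UniProtKB-KW."
            parts = [p.strip() for p in line.split(";")]
            go_id = int(parts[1].split(":")[1])  # "GO:0000002" -> 2
            evidence = parts[3].split(":")[0] if len(parts) > 3 else ""
            if evidence in VALID_EVIDENCE_CODES:
                valid_go_ids.append(go_id)
        elif tag == "SQ":
            in_sequence = True
        elif in_sequence:
            # Sequence continuation lines; drop the layout spaces.
            sequence_parts.append(line.replace(" ", ""))
    return dict(
        swiss_id=swiss_id,
        go_ids=sorted(valid_go_ids),
        sequence="".join(sequence_parts),
    )


def pretraining_candidates(raw: str, max_len: int = 1002) -> List[str]:
    """Proteins with a usable sequence but no validly-evidenced GO label."""
    candidates: List[str] = []
    for rec in map(parse_record, iter_records(raw)):
        seq = rec["sequence"]
        if seq and len(seq) <= max_len and "X" not in seq and not rec["go_ids"]:
            candidates.append(rec["swiss_id"])
    return candidates

Under these assumptions, applying pretraining_candidates to GOUniProtMockData.get_UniProt_raw_data() should select Swiss_Prot_5, Swiss_Prot_6 and Swiss_Prot_7, matching proteins_for_pretraining().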