From 5a6f3f0d6b132c7b640938b6b799377239098197 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 4 Mar 2026 14:16:45 +0100 Subject: [PATCH 1/5] allow input_dim and out_dim of 0 --- chebai/models/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index df060e9a..b22f28f8 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -48,8 +48,8 @@ def __init__( if exclude_hyperparameter_logging is None: exclude_hyperparameter_logging = tuple() self.criterion = criterion - assert out_dim is not None and out_dim > 0, "out_dim must be specified" - assert input_dim is not None and input_dim > 0, "input_dim must be specified" + assert out_dim is not None, "out_dim must be specified" + assert input_dim is not None, "input_dim must be specified" self.out_dim = out_dim self.input_dim = input_dim print( From bd7f1f4cd2aca77388ccc7a5cf044d3565f8f8bf Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 4 Mar 2026 14:17:15 +0100 Subject: [PATCH 2/5] dont pass config to base model --- chebai/models/electra.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 36430773..20678b84 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -39,8 +39,8 @@ class ElectraPre(ChebaiBaseNet): replace_p (float): Probability of replacing tokens during training. """ - def __init__(self, config: Dict[str, Any] = None, **kwargs: Any): - super().__init__(config=config, **kwargs) + def __init__(self, config: Dict[str, Any], **kwargs: Any): + super().__init__(**kwargs) self.generator_config = ElectraConfig(**config["generator"]) self.generator = ElectraForMaskedLM(self.generator_config) From 94aab596e9d58a1eed9f05f61c765aa4622ad41b Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 4 Mar 2026 14:18:24 +0100 Subject: [PATCH 3/5] catch missing labels if no labels exist --- chebai/preprocessing/collate.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/collate.py b/chebai/preprocessing/collate.py index 308ed6c7..06b4ff34 100644 --- a/chebai/preprocessing/collate.py +++ b/chebai/preprocessing/collate.py @@ -87,7 +87,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: *((d["features"], d["labels"], d.get("ident")) for d in data) ) missing_labels = [ - d.get("missing_labels", [False for _ in y[0]]) for d in data + d.get( + "missing_labels", + [False for _ in y[0]] if y[0] is not None else [False], + ) + for d in data ] if any(x is not None for x in y): From 9618c2f4648006339b9c5ade6ea35acfb73d8284 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 4 Mar 2026 14:19:00 +0100 Subject: [PATCH 4/5] only create classes.txt path for dynamic datasets --- chebai/preprocessing/datasets/base.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index e295a3ed..0e164508 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -514,11 +514,13 @@ def setup(self, *args, **kwargs) -> None: rank_zero_info(f"Check for processed data in {self.processed_dir}") rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}") - rank_zero_info(f"Looking for files: {self.processed_file_names}") if any( not os.path.isfile(os.path.join(self.processed_dir, f)) for f in self.processed_file_names ): + rank_zero_info( + f"Did not find one of: {', '.join(self.processed_file_names)} in {self.processed_dir}" + ) self.setup_processed() self._after_setup(**kwargs) @@ -627,17 +629,17 @@ def raw_file_names_dict(self) -> dict: raise NotImplementedError @property - def classes_txt_file_path(self) -> str: + def classes_txt_file_path(self) -> Optional[str]: """ - Returns the filename for the classes text file. + Returns the filename for the classes text file (for labeled datasets that produce a list of labels). Returns: - str: The filename for the classes text file. + Optional[str]: The filename for the classes text file. """ # This property also used in following places: # - chebai/result/prediction.py: to load class names for csv columns names # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path` - return os.path.join(self.processed_dir_main, "classes.txt") + return None class MergedDataset(XYBaseDataModule): @@ -1406,3 +1408,16 @@ def processed_file_names_dict(self) -> dict: if self.n_token_limit is not None: return {"data": f"data_maxlen{self.n_token_limit}.pt"} return {"data": "data.pt"} + + @property + def classes_txt_file_path(self) -> str: + """ + Returns the filename for the classes text file. + + Returns: + str: The filename for the classes text file. + """ + # This property also used in following places: + # - chebai/result/prediction.py: to load class names for csv columns names + # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path` + return os.path.join(self.processed_dir_main, "classes.txt") From 8f6d5e03e04fee26acc1eb23cbefa32e573e3f06 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 4 Mar 2026 14:19:36 +0100 Subject: [PATCH 5/5] change message for loading properties --- chebai/preprocessing/datasets/pubchem.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 8cc208b9..2f169b0e 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -195,8 +195,9 @@ def _set_processed_data_props(self): self._num_of_labels = 0 self._feature_vector_size = 0 - print(f"Number of labels for loaded data: {self._num_of_labels}") - print(f"Feature vector size: {self._feature_vector_size}") + print( + f"Number of labels and feature vector size set to: {self._num_of_labels} / {self._feature_vector_size} (default values, not used for self-supervised learning)" + ) def _perform_data_preparation(self, *args, **kwargs): """