Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR adds a missing split functionality to the synthetic_data module that was accidentally omitted during a refactoring of the pivae synthetic data code, which was causing failures in task.py.
- Adds a
splitmethod to handle train/valid/all data splits - Implements 80/20 train/validation split logic
- Provides compatibility with existing pivae task code
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
| train_idx = np.arange(tot_len)[:int(tot_len*0.8)] | ||
| valid_idx = np.arange(tot_len)[int(tot_len*0.8):] |
There was a problem hiding this comment.
[nitpick] The hardcoded 0.8 split ratio should be extracted as a configurable parameter or class constant to improve maintainability and allow for different split ratios.
| if split == 'train': | ||
| self.neural = self.neural[train_idx] | ||
| self.index = self.index[train_idx] | ||
| self.idx = train_idx |
There was a problem hiding this comment.
The method modifies the instance state by setting self.idx, but this attribute is not defined in init. This could lead to inconsistent state if split() is called multiple times or if other code expects self.idx to always exist.
| elif split == 'valid': | ||
| self.neural = self.neural[valid_idx] | ||
| self.index = self.index[valid_idx] | ||
| self.idx = valid_idx |
There was a problem hiding this comment.
The method modifies the instance state by setting self.idx, but this attribute is not defined in init. This could lead to inconsistent state if split() is called multiple times or if other code expects self.idx to always exist.
During refactoring of the pivae synthetic data we did not pull over the
splitfunction, which causes task.py to fail due to these lines:CEBRA/third_party/pivae/task.py
Lines 220 to 221 in e982248
This pulls over the split function to the main package.
TODO: check if the pivae runs now