Speed up find_pairs by using a Numba-optimised kd-tree for searching M to build S (#104)
robinmessage wants to merge 3 commits into main
Conversation
mdales
left a comment
LGTM overall at a first pass.
methods/matching/find_pairs.py
Outdated
k_set = pd.read_parquet(k_parquet_filename)
k_subset = k_set.sample(
-    frac=0.1,
+    frac=1,
Isn't this just the same as k_subset = k_set?
Yes; I didn't clean this up yet as I wasn't sure if we definitely wanted to change to 100% of K instead of 10%.
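For anyone reading along, a minimal illustration (invented data, not from this PR) of why `sample(frac=1)` is almost, but not quite, the same as `k_subset = k_set`: it returns every row, but in shuffled order, which is why the `.reset_index()` matters.

```python
import pandas as pd

# Hypothetical data, for illustration only.
k_set = pd.DataFrame({"x": [1, 2, 3, 4]})

# frac=1 keeps all rows but shuffles them, unlike plain assignment.
k_subset = k_set.sample(frac=1, random_state=42).reset_index(drop=True)

assert sorted(k_subset["x"]) == sorted(k_set["x"])  # same rows...
print(k_subset["x"].tolist())  # ...but possibly in a different order
```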
methods/matching/find_pairs.py
Outdated
# Find categories in K
hard_match_category_columns = [k[hard_match_columns].to_numpy() for _, k in k_set.iterrows()]
hard_match_categories = {k.tobytes(): k for k in hard_match_category_columns}
I think this needs a comment about what's happening - I had to work through it with real data to figure out the trick being used here to get unique columns. Given you never use the keys again, I'd rather you called .values() here rather than in make_s_set_mask, as that would make it more obvious you're using this to find unique sets of columns (assuming I understand what's happening here).
Fair point, I'll tidy this
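If I've read the diff right, this is the idiom under discussion: NumPy arrays aren't hashable, so each row is keyed by its raw bytes, and the dict keeps one representative array per unique byte pattern. A standalone sketch with invented data:

```python
import numpy as np

# Hypothetical category rows; the first two are duplicates.
rows = [
    np.array([1, 0], dtype=np.int32),
    np.array([1, 0], dtype=np.int32),
    np.array([0, 1], dtype=np.int32),
]

# Arrays are unhashable, but their raw bytes are; later duplicates
# overwrite earlier ones, leaving one array per unique combination.
unique_rows = {row.tobytes(): row for row in rows}
print(list(unique_rows.values()))  # [array([1, 0]), array([0, 1])]
```

Note the comparison is on raw bytes, so this only works when every row shares the same dtype and layout, which should hold here since every row comes from the same `to_numpy()` call.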
methods/matching/find_pairs.py
Outdated
return s_include, k_miss


@jit(nopython=True, fastmath=True, error_model="numpy")
def make_s_set_mask_numba(
Very happy to delete this version.
methods/matching/find_pairs.py
Outdated
k_subset_dist_hard = np.ascontiguousarray(k_subset[hard_match_columns].to_numpy()).astype(np.int32)

# Methodology 6.5.5: S should be 10 times the size of K; to achieve this for every
# pixel in the subsample (which is 10% the size of K) we select 100 pixels.
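For what it's worth, the arithmetic in that comment, spelled out (constant names invented here for illustration):

```python
# Methodology 6.5.5 needs |S| = 10 * |K|. The subsample is 10% of K,
# so each subsampled pixel must contribute 10 / 0.1 = 100 candidates.
K_SUBSAMPLE_FRACTION = 0.1
S_TO_K_RATIO = 10

candidates_per_pixel = int(S_TO_K_RATIO / K_SUBSAMPLE_FRACTION)
assert candidates_per_pixel == 100
```

Presumably, with the move to 100% of K discussed below, the same reasoning would give 10 candidates per pixel instead.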
if value >= low[d]:
    queue.append(self.lefts[pos])
return count

def members_sample(self, point: np.ndarray, count: int, rng: np.random.Generator):
I have to confess that, due to the lack of comments, I only skim-reviewed this to try and work out what count was achieving, and then gave up. That's fine at the prototype stage, but before we merge this, some comments on the API/algorithm would be useful, as this is quite nuanced I think.
I've added some docstrings and comments; hopefully that covers what's needed, but please do come back to me on anything else.
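As an aside, something like the following is the shape of docstring being asked for; the semantics of `count` here are inferred from the surrounding discussion, not copied from the PR:

```python
import numpy as np

class KDTree:
    """Sketch only; the real tree construction lives in the PR."""

    def members_sample(self, point: np.ndarray, count: int,
                       rng: np.random.Generator):
        """Return up to `count` members of the tree that match `point`,
        sampled via `rng` rather than exhaustively enumerated, so the
        result stays bounded even in dense regions of the tree.
        """
        raise NotImplementedError("body elided; see the PR diff")
```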
Thanks for reviewing this @mdales, and you're right that it could do with some more comments in the gnarly bits and a general tidy-up. I'll do that as soon as I can and bounce it back to you (probably after Easter, unfortunately).
@mdales I think I've fixed the stuff you've reviewed and improved the comments on the other parts.
mdales
left a comment
LGTM, just a couple of things it'd be nice to tidy up.
    random_state=rng
).reset_index()
# TODO: This assumes the methodology is being updated to 100% of K
k_subset = k_set
Can we just collapse this change throughout? Then when this is merged we can bump the versions of both the code and the methodology.
Will do when I merge
rand = rand_state[0] + rand_state[3]
t = rand_state[1] << 17
rand_state[2] ^= rand_state[0]
rand_state[3] ^= rand_state[1]
rand_state[1] ^= rand_state[2]
rand_state[0] ^= rand_state[3]
rand_state[2] ^= t
rand_state[3] = (rand_state[3] >> 45) | (rand_state[3] << 19)
Sad that we have to do this, but I see it's for performance reasons. Can we at least pull this code out so that we're not baking this particular random number generator into the algorithm? The methodology does not require this specific algorithm; we've just chosen it for performance reasons.
I'm not sure I quite understand: just pull it out into a function (and hope Numba inlines it), with a comment saying it's a random number generator but that no specific algorithm is required and this one was just chosen for speed? Or something else? (We can only do this if Numba inlines it correctly, I suspect, and we'll be passing the state around, so I'm not sure it'll be much clearer.)
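If it helps, here's roughly how I'd read that suggestion: the generator step pulled into its own jitted function, reusing the four-word state array from the diff above. Whether Numba inlines it cleanly is the open question, so this is a sketch, not a tested refactor:

```python
import numpy as np
from numba import jit

@jit(nopython=True)
def next_rand(rand_state):
    # Any fast 64-bit generator would do here; the methodology does
    # not mandate this particular xoshiro-style step, it was simply
    # quick under Numba. State is a length-4 uint64 array, mutated
    # in place.
    rand = rand_state[0] + rand_state[3]
    t = rand_state[1] << 17
    rand_state[2] ^= rand_state[0]
    rand_state[3] ^= rand_state[1]
    rand_state[1] ^= rand_state[2]
    rand_state[2] ^= t
    rand_state[0] ^= rand_state[3]
    rand_state[3] = (rand_state[3] >> 45) | (rand_state[3] << 19)
    return rand
```

Calling code would then hold the state array and call `next_rand(state)` wherever the inline block sits today.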
This has none of the memory-sharing optimisations that would make it easier to run in parallel (or possible for larger projects). It also has none of the optimisations we talked about to split M into 100 or so subsets and pick from them randomly.
However, it does seem to pick reasonable pairs and has better SMD numbers than the current version.
It's also worth noting that this code moves to using 100% of K instead of 10% - is that still what we want? It was part of the original motivation for this change.
I'm happy to talk it through with anyone or do any further testing you want to suggest to make sure it's robust before I add the memory-sharing optimisation (which requires a fair bit of restructuring to thread everything through, but shouldn't change the output).
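For concreteness, a loose sketch of the subset idea mentioned above (all names and shapes invented; this is not in the PR): shuffle M once, split it into ~100 chunks, and draw candidates from one randomly chosen chunk per query, so each search touches only a bounded slice of M.

```python
import numpy as np

def pick_candidate_block(m_rows: np.ndarray, n_subsets: int,
                         rng: np.random.Generator) -> np.ndarray:
    # Shuffle M, split into roughly equal chunks, and return one
    # chunk at random; matching would then search only this block.
    chunks = np.array_split(rng.permutation(m_rows), n_subsets)
    return chunks[rng.integers(n_subsets)]
```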