diff --git a/src/skillmodels/qr.py b/src/skillmodels/qr.py index c690eac7..965b9cc3 100644 --- a/src/skillmodels/qr.py +++ b/src/skillmodels/qr.py @@ -19,16 +19,13 @@ def _householder(r: jax.Array, tau: jax.Array): """ m = r.shape[0] n = tau.shape[0] + r = jnp.tril(jnp.fill_diagonal(r, 1, inplace=False)) # Calculate Householder Vector which is saved in the lower triangle of R v1 = jnp.expand_dims(r[:, 0], 1) - v1 = v1.at[0:0].set(0) - v1 = v1.at[0].set(1) h = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1)) # Multiply all Householder Vectors Q = H(1)*H(2)...*H(n) for i in range(1, n): vi = jnp.expand_dims(r[:, i], 1) - vi = vi.at[0:i].set(0) - vi = vi.at[i].set(1) h = h - tau[i] * (h @ vi) @ jnp.transpose(vi) return h[:, :n]