`solve_triangular` gradients are wrong for certain options #1230
Description
Right now we assume the inputs to `solve_triangular` are always consistent with the settings passed to the function. For example, if one sets `lower`, we assume that the incoming matrix really is lower triangular. If you set `unit_diagonal`, we assume there are ones on the diagonal, and so on.

This isn't actually how the function works. If you set `unit_diagonal`, for instance, the main diagonal is simply ignored in the computation; the function doesn't care whether the matrix literally has a unit diagonal or not. This is useful when working with LU factors, for example. Instead of storing L (lower triangular with unit diagonal) and U (upper triangular) separately, it is common to store `LU = tril(L, k=-1) + triu(U)`. Then `solve_triangular(LU, b, lower=True, unit_diagonal=True)` is exactly equivalent to `solve(L, b)`, with no need to rebuild the individual matrices.

At the moment the analytical gradients don't take this into account, so setting `unit_diagonal=True` on a matrix whose diagonal is not all ones produces non-zero gradients for the diagonal entries. That is wrong: those entries never enter the computation, so their sensitivities should be exactly zero.

Gradients are also incorrect when `trans != "N"`, because we have no special logic in `SolveBase.L_op`
to handle the transpose argument (see #1229).
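The LU storage trick can be checked numerically with SciPy's `scipy.linalg.solve_triangular`, used here only as a stand-in for the symbolic op because it accepts the same `lower` and `unit_diagonal` keywords. This is a minimal sketch: the random factors are illustrative assumptions, and the packed matrix is built by hand rather than by any particular factorization routine. It also overwrites the packed diagonal with junk to show that, with `unit_diagonal=True`, that diagonal is never read.

```python
import numpy as np
from scipy.linalg import solve, solve_triangular

rng = np.random.default_rng(0)
n = 4

# Explicit factors: L is unit lower triangular, U is upper triangular.
L = np.tril(rng.normal(size=(n, n)), k=-1) + np.eye(n)
U = np.triu(rng.normal(size=(n, n))) + n * np.eye(n)  # keep U's diagonal away from zero
b = rng.normal(size=n)

# Packed storage: strictly lower part of L, plus U including its diagonal.
LU = np.tril(L, k=-1) + np.triu(U)

# With unit_diagonal=True the diagonal of LU (which belongs to U) is never read,
# so this is a solve with L even though LU's diagonal is not all ones.
x_packed = solve_triangular(LU, b, lower=True, unit_diagonal=True)
x_ref = solve(L, b)
print(np.allclose(x_packed, x_ref))

# Overwriting the diagonal with junk does not change the result.
LU_junk = LU.copy()
np.fill_diagonal(LU_junk, 123.0)
x_junk = solve_triangular(LU_junk, b, lower=True, unit_diagonal=True)
print(np.allclose(x_junk, x_ref))

# The upper solve reads only the upper triangle and the diagonal, i.e. U itself.
y_packed = solve_triangular(LU, b, lower=False)
print(np.allclose(y_packed, solve(U, b)))
```

Each comparison should print `True`, which is why a gradient with respect to the diagonal of `LU` in the `unit_diagonal=True` solve has to be zero.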
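One way the gradients could respect these options is sketched below in NumPy/SciPy, assuming real inputs and `trans` in `{"N", "T"}`. `triangular_solve_grads` and `finite_diff_grad_A` are hypothetical helpers written for this illustration; they are not the library's `SolveBase.L_op` and not a proposed patch. The idea: for `trans="N"` the adjoint solve is transposed and the dense gradient is `-outer(b_bar, x)`; for `trans="T"` the adjoint solve is untransposed and the gradient is `-outer(x, b_bar)`. In both cases every entry the solver never reads (the other triangle, and the diagonal when `unit_diagonal=True`) is then zeroed. A central finite-difference check compares the result against the actual solver for all eight option combinations.

```python
import numpy as np
from scipy.linalg import solve_triangular


def triangular_solve_grads(A, b, g, lower=True, unit_diagonal=False, trans="N"):
    """Hypothetical helper: gradients of sum(g * x) with respect to A and b, where
    x = solve_triangular(A, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans).
    Handles real inputs and trans in {"N", "T"} only.
    """
    x = solve_triangular(A, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
    # The adjoint solve uses the same stored matrix with the opposite transpose flag.
    adj_trans = "T" if trans == "N" else "N"
    b_bar = solve_triangular(A, g, lower=lower, unit_diagonal=unit_diagonal, trans=adj_trans)
    # Dense gradient with respect to the effective triangular matrix.
    A_bar = -np.outer(b_bar, x) if trans == "N" else -np.outer(x, b_bar)
    # Keep only entries the solver actually reads: one triangle, plus the
    # diagonal only when unit_diagonal is False.
    if lower:
        return np.tril(A_bar, k=-1 if unit_diagonal else 0), b_bar
    return np.triu(A_bar, k=1 if unit_diagonal else 0), b_bar


def finite_diff_grad_A(A, b, g, eps=1e-6, **kwargs):
    """Central finite differences of sum(g * x) with respect to every entry of A."""
    G = np.zeros_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            Ap, Am = A.copy(), A.copy()
            Ap[i, j] += eps
            Am[i, j] -= eps
            fp = g @ solve_triangular(Ap, b, **kwargs)
            fm = g @ solve_triangular(Am, b, **kwargs)
            G[i, j] = (fp - fm) / (2 * eps)
    return G


rng = np.random.default_rng(0)
n = 4
A = rng.normal(size=(n, n)) + n * np.eye(n)  # diagonal is deliberately not all ones
b = rng.normal(size=n)
g = rng.normal(size=n)  # upstream gradient dL/dx

results = {}
for lower in (True, False):
    for unit_diagonal in (True, False):
        for trans in ("N", "T"):
            kw = dict(lower=lower, unit_diagonal=unit_diagonal, trans=trans)
            A_bar, _ = triangular_solve_grads(A, b, g, **kw)
            ok = np.allclose(A_bar, finite_diff_grad_A(A, b, g, **kw), atol=1e-6)
            results[(lower, unit_diagonal, trans)] = ok
            print(kw, ok)
```

The masking step is what removes the spurious diagonal sensitivities, and swapping the outer-product order is what handles the transpose; without either, the finite-difference comparison would be expected to fail for the corresponding options.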