Hi,
Historically, scipy assumed that the default floating-point type is double-precision float64. Now that we’re moving towards the array API world, we face torch and jax, which default to float32 unless explicitly told otherwise.
Thus, there’s a question of how we test scipy on these new frameworks. Note that here I am only talking about the scipy test suite.
End users are free to configure their scipy-using workloads as they see fit.
At the moment, scipy tests are inconsistent: we configure JAX to default to float64, but torch keeps defaulting to float32.
This leads to having to write tests along the lines of
xp_assert_close(..., atol=1e-7 if is_torch(xp) else 1e-15)
The ability to test things in the f32 world is apparently important to some of us, and there is a body of tests which were converted to use this kind of pattern.
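For reference, the two tolerances in that pattern track the machine epsilon of each type. Here is a quick NumPy sketch (not taken from the scipy test suite) of why a float32 result cannot meet a float64-grade tolerance:

```python
import numpy as np

# Machine epsilon: the gap between 1.0 and the next representable float.
eps32 = float(np.finfo(np.float32).eps)  # about 1.19e-07
eps64 = float(np.finfo(np.float64).eps)  # about 2.22e-16

# Round-off from merely storing 0.1 in float32, measured in float64.
err = abs(float(np.float32(0.1)) - 0.1)  # about 1.5e-09

# The float32 error already breaks a 1e-15 tolerance but passes 1e-7.
print(err > 1e-15, err < 1e-7)  # prints: True True
```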
The use of float64 for jax is needed AFAICS due to jax’s specifics. Not using fp64 for torch is, IMO, an oversight.
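A minimal sketch of what configuring both backends to float64 could look like, e.g. in a conftest.py. The function name and the SCIPY_TEST_FP32 environment variable are made up for illustration and are not scipy's actual test configuration. The library calls, jax.config.update("jax_enable_x64", True) and torch.set_default_dtype, are the real ones.

```python
import os


def configure_default_float_dtypes() -> dict:
    """Set float64 as the default float dtype for JAX and torch, if installed.

    SCIPY_TEST_FP32 is a hypothetical environment variable: setting it to "1"
    keeps torch at float32, e.g. for a dedicated fp32 CI run.
    Returns a dict recording what was configured, for inspection.
    """
    want_fp32 = os.environ.get("SCIPY_TEST_FP32", "0") == "1"
    configured = {}

    try:
        import jax
    except ImportError:
        pass
    else:
        # JAX only provides 64-bit types once x64 mode is switched on;
        # this should happen early, before test arrays are created.
        jax.config.update("jax_enable_x64", True)
        configured["jax"] = "float64"

    try:
        import torch
    except ImportError:
        pass
    else:
        dtype = torch.float32 if want_fp32 else torch.float64
        # Affects tensors created from Python floats, e.g. torch.tensor(1.0).
        torch.set_default_dtype(dtype)
        configured["torch"] = str(dtype)

    return configured


# In a conftest.py this could run at import time, before tests are collected.
print(configure_default_float_dtypes())
```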
There’s a discussion in a recent PR: "ENH: signal.windows: add array API support (take 2)" by ev-br, scipy/scipy#21783.
Proposal:
- configure the scipy test suite for torch to default to float64
- cook up a backend-agnostic way of querying the default fp dtype and use it in tests. Tests would then use
atol=1e-7 if default_fp_type(xp) == xp.float32 else 1e-15
instead of if is_torch(xp). (This uses a made-up default_fp_type function; of course we can spell it differently.)
- if desired, add a separate CI run with fp32 as the default.
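One possible spelling of the made-up default_fp_type helper, assuming the backend follows the array API standard rule that xp.asarray on a Python float returns an array of the default real floating-point dtype. Both functions below are hypothetical, not existing scipy API:

```python
import numpy as np


def default_fp_type(xp):
    """Hypothetical helper: the default real floating dtype of namespace xp.

    Relies on the array API rule that an array built from a Python float
    gets the namespace's default real floating-point dtype.
    """
    return xp.asarray(1.0).dtype


def default_atol(xp, fp32_atol=1e-7, fp64_atol=1e-15):
    """Pick an absolute tolerance from the default float dtype, not the backend name."""
    return fp32_atol if default_fp_type(xp) == xp.float32 else fp64_atol


# NumPy defaults to float64, so the tighter tolerance is chosen.
print(default_fp_type(np), default_atol(np))  # prints: float64 1e-15
```

For namespaces implementing the 2023.12 inspection API, xp.__array_namespace_info__().default_dtypes()["real floating"] may be an alternative way to ask the same question.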
In a recent community meeting, Tyler mentioned that there are even emerging frameworks that do not have float64 at all. If we want to test scipy with those, we could add a separate CI run for them and explicitly skip the tests that hard-require float64. That would be a separate effort, IMO.
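Skipping float64-only tests on such frameworks could hinge on a capability check along these lines. supports_float64 is a hypothetical helper, and it assumes a backend without float64 simply does not expose an xp.float64 attribute:

```python
import types

import numpy as np


def supports_float64(xp) -> bool:
    """Hypothetical check: does namespace xp expose a float64 dtype at all?

    Tests that hard-require float64 could call pytest.skip() when this is False.
    """
    return getattr(xp, "float64", None) is not None


# A stand-in namespace without float64, for illustration only.
fp32_only = types.SimpleNamespace(float32="float32")
print(supports_float64(np), supports_float64(fp32_only))  # prints: True False
```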
Thoughts?
Evgeni