Broadcasting behavior within scipy.optimize.elementwise

I recently discovered the elementwise optimizers, and they seem ideally suited to my use cases. However, I’m running into an issue with how *args is broadcast, and would appreciate hearing about any workaround. Or perhaps this could become an enhancement request.

I’m trying to find the roots x of a function f(x, y), where x is an array of shape (N, M, 1) (the last dimension is currently retained for broadcasting purposes) and y is an array of shape (N, M, L) that is held constant. In other words, it’s an (N, M) array of scalar root solves, each with an extra argument that is a vector of shape (L,). Ideally I would like to use elementwise rather than iterate over the first two dimensions.

In my initial testing, it appears that x is broadcast up to shape (N, M, L), and then y is passed in elementwise rather than retaining the (L,)-sized dimension. Is there a way to preserve this dimension for an elementwise find_root? If not, would this be worth an enhancement request?

I suppose I could split y into L arrays of shape (N, M, 1) and stack them inside the residual function, but that doesn’t seem like it would be very efficient.

I went ahead and tried this and it worked. This is the basic outline of what I did:

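import numpy as np
from scipy.optimize import elementwise

# N, M, L, f(x, y) and bracket = (xl, xr) are defined elsewhere
x = np.ones((N, M))     # one scalar unknown per (n, m) entry
y = np.ones((N, M, L))  # one length-L vector per (n, m) entry

def residual(x, *y_split):
    # Rebuild the trailing (L,) axis from the L separate arguments
    y = np.stack(y_split, axis=-1)
    return f(x, y)

# L arrays of shape (N, M), one per component along the last axis
y_split = [y_elem[..., 0] for y_elem in np.split(y, y.shape[-1], axis=-1)]
result = elementwise.find_root(residual, bracket, args=(*y_split,))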

I’m still open to suggestions and improvements!

When it rains, it pours, apparently! Please join the conversation in gh-24869. I just helped another user with a similar problem. We discussed a PR (already open, gh-24657) and documentation improvements that will make this sort of thing easier.

Your solution is interesting! Another approach is to leave y out of args altogether: pass the index arrays i = np.arange(N)[:, np.newaxis] and j = np.arange(M) as args=(i, j), and use them inside the function to index into your array y. I’ll post an example solution in the issue.

I knew I should have done some more searching… I’ll take a look at those, thank you!

And the indexing approach is also an interesting workaround. I try not to rely on that sort of scoping trick (accessing y inside the function even though it’s only defined outside its scope), but it would admittedly work pretty nicely in my use case.

EDIT: Ah, but of course I forgot about binding it in the function definition, e.g.

def residual(x, i, j, y=y): ...

So that’s not even an issue!
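To spell out the difference for future readers, here is a small sketch with made-up names (lookup_global, lookup_bound) and a toy array. A default argument is evaluated once, when the def statement runs, so the function keeps its own reference to the array even if the outer name is later rebound.

```python
import numpy as np

y = np.arange(6.0).reshape(2, 3)

def lookup_global(i, j):
    # Looks up the module-level name `y` each time it is called
    return y[i, j]

def lookup_bound(i, j, y=y):
    # The default was evaluated when this def statement ran
    return y[i, j]

before = (lookup_global(1, 2), lookup_bound(1, 2))
y = np.zeros((2, 3))  # rebind the outer name to a different array
after = (lookup_global(1, 2), lookup_bound(1, 2))
```

After the rebinding, lookup_global sees the new array while lookup_bound still uses the original one. The default holds a reference rather than a copy, so in-place changes to the original array would still be visible to lookup_bound.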
