Skip to content

Fix JVPs of power, divmod, and slice_update with partly traced inputs#3636

Open
qflen wants to merge 1 commit into
ml-explore:mainfrom
qflen:fix/jvp-power-divmod-slice-update
Open

Fix JVPs of power, divmod, and slice_update with partly traced inputs#3636
qflen wants to merge 1 commit into
ml-explore:mainfrom
qflen:fix/jvp-power-divmod-slice-update

Conversation

@qflen

@qflen qflen commented Jun 6, 2026

Copy link
Copy Markdown
Contributor

Fixes #3634

Same class as #3627/#3629 (fixed by #3633):

Primitive::jvp implementations that break the packed-tangents convention (tangents[i] is the tangent of input argnums[i]; a jvp returns one tangent per output).

jvp bug fix
Power delegated to vjp, which scales every partial by cotangents[0], so with both inputs traced both partials used the first input's tangent and the second tangent was ignored one vjp call per traced input with its own tangent as the cotangent, summed
DivMod returned one tangent for a two-output primitive, so the second output never entered the tangent map and its consumer was invoked with empty argnums/tangents — segfault return one zero tangent per output (divmod stays non-differentiable, matching its vjp)
DynamicSliceUpdate ignored argnums and read tangents[0]/tangents[1] unconditionally — OOB segfault with one traced input, mispaired tangents for argnums = {0, 2} untraced inputs get zero tangents; traced start indices throw, matching the vjp

Before

_, (j,) = mx.jvp(lambda a, b: a**b, [mx.array(2.0), mx.array(3.0)],
                 [mx.array(0.0), mx.array(1.0)])
# j = 0.0, silently wrong (vjp and finite differences give ln(2)·8 ≈ 5.5452)

The divmod and slice_update repros in #3634 segfaulted.

After

j = 5.5452, jvp matches vjp and finite differences, and all repros pass.

Added a regression test covering tangent pairing for power, per-output tangents for divmod (including a consumer of the second output), and dynamic slice_update with each subset of traced inputs. Each case fails on unfixed main (wrong value for power, segfault for the other two).

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] JVPs of power, divmod, and slice_update mishandle partly traced inputs

2 participants