Differentiable generic functions and type constraints #5749
-
Hello, I have just started using Slang for its autodiff capability, so this might be trivial or a non-issue. This is with Slang 2024.11 and Vulkan, but I don't see any related change since 2024.11. I am confused by the behavior of differentiable generic functions. In my project, I've defined a bunch of generic math utility functions such as:
Then I wanted to write a differentiable function that called one of them:
However, I found that the derivative was always 0.
My question is: is this the intended behavior? It seems to me that when generic functions that accept both differentiable and non-differentiable types are marked as differentiable, one should expect one of the following two reasonable outcomes:
However, neither of the above currently happens, and we simply get a silent failure even with differentiable types (`float`). This is confusing (and took me a long while to figure out). Thanks.
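A minimal sketch of the kind of setup described above might look like the following. The original snippets were not preserved, so the names `squareGeneric`, `f`, and `derivativeAtThree` are hypothetical. The comment on the derivative states the mathematically expected value and the behavior reported in this thread; it is not output from running this exact code.

```slang
// Hypothetical generic utility: the only constraint is __BuiltinArithmeticType,
// which does not conform to IDifferentiable.
[Differentiable]
T squareGeneric<T : __BuiltinArithmeticType>(T x)
{
    return x * x;
}

// Differentiable caller that instantiates the generic with T = float.
[Differentiable]
float f(float x)
{
    return squareGeneric(x);
}

// Forward-mode derivative of f at x = 3 with an input tangent of 1.
float derivativeAtThree()
{
    DifferentialPair<float> result = fwd_diff(f)(diffPair(3.0, 1.0));
    // Mathematically d/dx (x * x) = 2 * x = 6.0 at x = 3; the thread reports
    // that the derivative instead comes back as 0 with no diagnostic.
    return result.d;
}
```

The key point is that `float` is differentiable at the call site, but inside the generic body `T` is only known to satisfy `__BuiltinArithmeticType`, so no derivative is propagated through it.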
-
Thank you for opening a discussion on this! I understand that having gradients disappear silently can be very annoying to track down. We do have a gradient data-flow analysis pass that warns about these cases (i.e. it requires the `no_diff` modifier when calling a non-differentiable function with differentiable inputs). It looks like this case isn't detected because the function is marked `[Differentiable]`.
This is a slightly tricky situation. You are correct that it's confusing that this is allowed but silently returns 0, since `__BuiltinArithmeticType` is non-differentiable. One fix is to warn users if the function has no side effects and no differentiable inputs/outputs. However, more complicated cases with a mix of differentiable and non-differentiable types will still pose a problem.
A more solid fix would be to extend the …
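For comparison, here is a sketch of the case the gradient data-flow pass is meant to catch: a differentiable function passing a differentiable argument to a function that is not marked `[Differentiable]`. The names `helperNoDiff` and `g` are hypothetical. Wrapping the call in `no_diff` is the annotation the diagnostic asks for; it marks the dropped gradient as intentional.

```slang
// Hypothetical helper that is NOT marked [Differentiable].
float helperNoDiff(float x)
{
    return x * x;
}

[Differentiable]
float g(float x)
{
    // Calling helperNoDiff(x) directly here would be flagged by the data-flow
    // analysis, because a differentiable input flows into a non-differentiable
    // call. no_diff states explicitly that no gradient should flow through it.
    return no_diff(helperNoDiff(x));
}

// By contrast, a generic like squareGeneric<T : __BuiltinArithmeticType> is
// itself marked [Differentiable], so the pass accepts the call, even though
// T, constrained only by __BuiltinArithmeticType, contributes no derivative.
```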