[Transform] Check for zero-param operators in LiftTransformParams #16595
This PR depends on changes made in #16594, and is marked as a draft until it lands.
Good question. This case should be handled by allowing zero-param operators to potentially appear in both functions. (See this unit test for how this looks in practice.) The case isn't handled explicitly anywhere; instead, it results from the overall lifting.
Since the zero-param operators don't depend on runtime parameters, they appear in the
I realize it's a draft, but I had a look anyway. Since the code changes also included those from #16594, it was a little difficult to see what had changed. I didn't see anything to take issue with. It wasn't entirely obvious how the delta from #16594 accomplished the stated purpose, but I think I follow it.
```
@@ -169,13 +169,11 @@ class CallNode : public ExprNode {

  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
    // skip sinfo_args check for primitive ops.
    equal->MarkGraphNode();
```
What is the reason for this change?
The `MarkGraphNode` function should be called for IR nodes that have reference equality, such as variables. Using it for `relax::CallNode` means that the result of `StructuralEqual` depends on how a `relax::Call` node was constructed, even when two calls have identical contents.
Relevant to this PR: if `R.zeros([16], "int32")` appears in both `main` and `transform_params`, the expected output generated by the TVMScript parser would contain two different `relax::Call` objects, while the output of `LiftTransformParams` would reuse the same `relax::Call` object. Because the `relax::Call` objects were being checked for reference equality, analogous to variables, `StructuralEqual` would erroneously report the test as failing.
I've added a unit test to specifically exercise this behavior, rather than implicitly relying on the tests for `LiftTransformParams`.
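To make the failure mode concrete, here is a toy model of the comparison. It is not TVM's implementation: `Node`, `ToyStructuralEqual`, and `EqualSeq` are hypothetical names, and the real `StructuralEqual` handles many more cases. Marking a node as a graph node is modeled here as requiring a consistent one-to-one pairing between left-hand and right-hand objects, roughly the behavior described above for variables.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy IR node standing in for an expression such as
// R.zeros([16], "int32"). This is NOT TVM's relax::CallNode.
struct Node {
  std::string op;                           // operator name, e.g. "zeros"
  std::vector<std::shared_ptr<Node>> args;  // operands
};
using NodePtr = std::shared_ptr<Node>;

// Toy structural-equality checker. With `mark_as_graph_node` set, the first
// successful comparison pairs an lhs object with an rhs object, and every
// later comparison involving either object must reuse exactly that pairing.
class ToyStructuralEqual {
 public:
  explicit ToyStructuralEqual(bool mark_as_graph_node)
      : mark_as_graph_node_(mark_as_graph_node) {}

  bool Equal(const NodePtr& lhs, const NodePtr& rhs) {
    if (mark_as_graph_node_) {
      auto l = lhs_to_rhs_.find(lhs.get());
      auto r = rhs_to_lhs_.find(rhs.get());
      if (l != lhs_to_rhs_.end() || r != rhs_to_lhs_.end()) {
        // Seen before: equal only if these two objects were paired together.
        return l != lhs_to_rhs_.end() && r != rhs_to_lhs_.end() &&
               l->second == rhs.get() && r->second == lhs.get();
      }
    }
    // Compare contents: operator name and operands, recursively.
    if (lhs->op != rhs->op || lhs->args.size() != rhs->args.size()) return false;
    for (std::size_t i = 0; i < lhs->args.size(); ++i) {
      if (!Equal(lhs->args[i], rhs->args[i])) return false;
    }
    if (mark_as_graph_node_) {
      lhs_to_rhs_[lhs.get()] = rhs.get();
      rhs_to_lhs_[rhs.get()] = lhs.get();
    }
    return true;
  }

 private:
  bool mark_as_graph_node_;
  std::unordered_map<const Node*, const Node*> lhs_to_rhs_;
  std::unordered_map<const Node*, const Node*> rhs_to_lhs_;
};

// Compares two sequences of expressions (e.g. every occurrence of a call
// across `main` and `transform_params`) using one shared checker.
bool EqualSeq(const std::vector<NodePtr>& a, const std::vector<NodePtr>& b,
              bool mark_as_graph_node) {
  if (a.size() != b.size()) return false;
  ToyStructuralEqual checker(mark_as_graph_node);
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (!checker.Equal(a[i], b[i])) return false;
  }
  return true;
}
```

Comparing `{shared, shared}` (one object reused, as in the output of `LiftTransformParams`) against two separately constructed `zeros` nodes (as the parser would produce) returns `false` with marking enabled and `true` without it, even though the contents are identical.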
Good catch!
```cpp
auto upstream_vars = FreeVars(bound_value);
bool depends_on_compile_time_param = std::any_of(
    upstream_vars.begin(), upstream_vars.end(),
    [&](const Var& var) -> bool { return info_.requires_compile_time_param.count(var); });
if (depends_on_compile_time_param) {
  info_.requires_compile_time_param.insert(binding->var);
}
```
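The propagation in the snippet above can be modeled outside of TVM. This is a minimal sketch, assuming bindings are visited in program order and that the set starts out seeded with the compile-time parameters themselves. `Binding`, its string-named variables, and `CollectRequiresCompileTimeParam` are hypothetical stand-ins for `relax::Var`, the result of `FreeVars(bound_value)`, and `info_.requires_compile_time_param`.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical, simplified binding `var = <expr>`, where `free_vars` lists
// the variables read by the bound expression.
struct Binding {
  std::string var;
  std::vector<std::string> free_vars;
};

// Returns every variable that depends, directly or indirectly, on at least
// one compile-time parameter. Because bindings are visited in program order,
// each binding's upstream variables have already been classified.
std::unordered_set<std::string> CollectRequiresCompileTimeParam(
    const std::vector<std::string>& compile_time_params,
    const std::vector<Binding>& bindings) {
  std::unordered_set<std::string> requires_param(compile_time_params.begin(),
                                                 compile_time_params.end());
  for (const Binding& binding : bindings) {
    bool depends_on_compile_time_param = std::any_of(
        binding.free_vars.begin(), binding.free_vars.end(),
        [&](const std::string& v) { return requires_param.count(v) > 0; });
    if (depends_on_compile_time_param) {
      requires_param.insert(binding.var);
    }
  }
  return requires_param;
}
```

With `weight` as the only compile-time parameter, a chain `w_t = f(weight)`, `w_scaled = g(w_t)` marks both variables, while a binding like `zeros = R.zeros([16], "int32")`, which has no free variables, is never marked.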
Is this the logic that is meant to handle zero-param operators? I had a really hard time figuring out what had changed relative to #16594, and this was the only major area that came up in the diff.
Yes, this section along with this line. This section collects information to determine whether this parameter depends, directly or indirectly, on one of the model weights. The linked line uses the collected information to determine whether `transform_params` should output a variable.
I've updated the comment here to indicate why the additional information is being collected.
```cpp
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
explicit SuppressCompileTime(
    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress)
    : to_suppress(to_suppress) {}
```
Question separate from the review: Are we going with the convention of having member var names end in an underscore? I'm not a partisan on that one, but we should be consistent. @tqchen
The TVM convention, as I understand it, is for private members to have a trailing underscore, while public members do not. Often, if I'm making quick subclasses of `ExprMutator`, I'll use a struct with public fields rather than a class with private fields. Since the class definition already occurs within a function scope, the visibility of the entire class is already restricted, and additional visibility restrictions on private members aren't necessary.
That said, I've updated it to be a class with a private `to_suppress_`. I figure that if a deviation from convention is big enough to raise a question, it's worth avoiding the deviation altogether.
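For reference, a minimal sketch of the two styles discussed above. The type names are hypothetical, `std::string` stands in for `relax::Var`, and the `ExprMutator` base class is omitted so the example compiles on its own.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <utility>

// Quick helper with public fields: no trailing underscore.
struct SuppressSetPublic {
  std::unordered_set<std::string> to_suppress;
};

// Class with a private member: trailing underscore on the member name.
class SuppressSetPrivate {
 public:
  // Take the set by value and move it into the member to avoid a second copy.
  explicit SuppressSetPrivate(std::unordered_set<std::string> to_suppress)
      : to_suppress_(std::move(to_suppress)) {}

  bool ShouldSuppress(const std::string& var) const {
    return to_suppress_.count(var) > 0;
  }

 private:
  std::unordered_set<std::string> to_suppress_;
};
```

A side benefit of the underscore is that the constructor parameter no longer shares its name with the member, as it did in the original `: to_suppress(to_suppress)` initializer.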
I was just curious; I don't think it's worth holding up a review over. Having a different rule for public and private members is reasonable.
Apologies there. Once #16594 lands, the "Files Changed" tab in this PR should update to only show the changes unique to this PR. In the meantime, this PR branch has its changes in a separate commit (link), where they can be viewed separately from the #16594 changes.
Thank you for the changes. The new unit test and the new comment are both helpful.
Prior to this commit, `LiftTransformParams` would extract out all variable bindings that have no runtime dependencies. As a result, expressions such as `R.zeros([16], "int32")` would be extracted out into the parameter transformation, even though they do not depend on any parameters.
This commit updates `LiftTransformParams` to only output variables that depend on at least one compile-time parameter.
The unit test for this functionality also found that `relax::Call` was erroneously calling `MarkGraphNode` in `SEqualReduce` and `SHashReduce`. This should only be called for nodes that have reference equality, such as `relax::Var`, and not for composite objects. This caused erroneous failures in the unit test when two instances of `R.zeros([16], "int32")` were being compared by reference equality in `StructuralEqual`. These extra calls to `MarkGraphNode` have been removed.
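The before/after rule in this commit message can be sketched as a predicate over per-binding flags. This is a simplified model under stated assumptions: `BindingInfo`, its two flags, `OutputBefore`, `OutputAfter`, and `SelectOutputs` are hypothetical names, not part of TVM's API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical per-binding summary; not TVM's data structure.
struct BindingInfo {
  std::string var;
  bool depends_on_runtime_input;       // needs a value only known at runtime
  bool depends_on_compile_time_param;  // reads a weight, directly or indirectly
};

// Rule before this commit: any binding without runtime dependencies was
// extracted into the parameter transformation.
bool OutputBefore(const BindingInfo& b) { return !b.depends_on_runtime_input; }

// Rule after this commit: additionally require a compile-time parameter, so
// zero-param expressions such as R.zeros([16], "int32") are not outputs.
bool OutputAfter(const BindingInfo& b) {
  return !b.depends_on_runtime_input && b.depends_on_compile_time_param;
}

// Collects the names selected by a rule, preserving binding order.
std::vector<std::string> SelectOutputs(const std::vector<BindingInfo>& bindings,
                                       bool (*rule)(const BindingInfo&)) {
  std::vector<std::string> outputs;
  for (const BindingInfo& b : bindings) {
    if (rule(b)) outputs.push_back(b.var);
  }
  return outputs;
}
```

For bindings `w_t` (weight-only), `zeros` (no inputs at all), and `y` (needs a runtime input), the old rule selects `w_t` and `zeros`, while the new rule selects only `w_t`.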
Branch updated from 8d69261 to 2390033.
With #16594 landed, I've rebased this PR on top of it, and it is now ready for review.
Thanks for responding to previous feedback.