Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Optional Activation node to NodeUnit #22888

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

centwang
Copy link
Contributor

No description provided.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment on lines 111 to 113
const Node& target_node_;
const Node* p_activation_node_; // Optional activation node for the QDQ group, nullptr if not present.
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const Node& target_node_;
const Node* p_activation_node_; // Optional activation node for the QDQ group, nullptr if not present.
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs
const Node& target_node_;
const Node* p_activation_node_; // Optional activation node for the QDQ group, nullptr if not present.
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs

@@ -87,6 +89,7 @@ class NodeUnit {
ProviderType GetExecutionProviderType() const noexcept;

const Node& GetNode() const noexcept { return target_node_; }
const Node* GetActivationNode() const noexcept { return p_activation_node_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have a comment explaining what the 'activation' node is and how it is expected to be used.

IIUC

  • if you're using the QDQ node unit for the quantized version of the target node the activation node can be ignored as it's made redundant by the values of the Q node
    • so it's not really a 'fusion' per se as we're not combining the values of the Clip/Relu with the Q, we're ignoring it
  • if you are falling back to higher precision and dropping the DQ/Q nodes, you need to keep both the target node and activation node if present

If that's correct I'd almost be inclined to call it something like redundant_clip_node (given Relu is a form of Clip).

Also as the OpenVINO EP (IIRC) is doing the fallback to higher precision does it need an update to be aware of the activation node in the NodeUnit?

}
return true;
}
bool GetQSalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bool GetQSalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,
bool GetQScalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,

Would be good to refactor some of the utils here as there seems to be a fair bit of duplication.

e.g. maybe a general purpose helper that reads the scale and zp values (scalar or otherwise), and has a bool to indicate if they're scalar. that helper could be used by many of the utils here.

int32_t& data_type) {
assert(q_node.OpType() == QOpName);
const auto& q_input_defs = q_node.InputDefs();
if (q_input_defs.size() != 3 || !q_input_defs[2]->Exists()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zp is optional and defaults to zero, so do we need to require 3 inputs here?

@@ -49,7 +49,7 @@ std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, const Nod
}
} // namespace

bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we don't use hungarian notation anywhere else, so could we use activation_node instead of p_activation_node?

Comment on lines +227 to +229
if (p_activation_node) {
return false;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the activation node is made redundant by the Q, what's the reason we can't create a QDQ node unit for this sort of operator?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants