deepseek r1 running #102
Conversation
Left a couple of comments.
# self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
# self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
We have changed how we do caching, as seen in the lines below. Do we have a reason to keep these comments?
# self.k_cache[:bsz, start_pos:end_pos] = k
# self.v_cache[:bsz, start_pos:end_pos] = v
See comment below.
# assert args.n_routed_experts % world_size == 0
# self.n_routed_experts = args.n_routed_experts
# self.n_local_experts = args.n_routed_experts // world_size
# self.n_activated_experts = args.n_activated_experts
# self.experts_start_idx = rank * self.n_local_experts
# self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(model_args)
# self.experts = nn.ModuleList(
#     [
#         Expert(args.dim, args.moe_inter_dim)
#         if self.experts_start_idx <= i < self.experts_end_idx
#         else None
#         for i in range(self.n_routed_experts)
#     ]
# )
Do we have a reason to leave these as comments rather than removing them?
# if seqlen > 1:
#     mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
Do we have a reason to leave these as comments rather than removing them?
@@ -544,6 +543,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
else:
    group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
    indices = group_scores.topk(self.topk_groups, dim=-1)[1]
    print('i am here')
This looks like a test print; please remove it. If appropriate, we could use logging at the info level instead.
I see there are a couple of other prints in model functions; I would consider applying the same criteria to those.
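A minimal sketch of the suggested change, assuming the standard-library logging module; the logger setup and message text are illustrative, not taken from this PR:

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Replaces print('i am here'); info-level messages can be turned on or
# off through the logging configuration instead of being edited out.
logger.info("reached the group-score branch of the gate forward pass")

A module-level logger also tags each message with the module name, which makes it easy to filter these traces per file.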
@@ -26,6 +26,7 @@ WORKDIR /workspaces
# Install torchax
RUN git clone https://github.com/pytorch/xla.git
WORKDIR /workspaces/xla/torchax
RUN git checkout hanq_torchax1
Is this a temporary fix until pytorch/xla@master...hanq_torchax1 is merged?
tokens = name.split(".")
for i, t in enumerate(tokens):
    if is_integer(t):
Why not use something like type(t) is int?
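For reference, a minimal sketch of what an is_integer helper over string tokens might look like; its actual definition isn't shown in this diff, so this is an assumption:

# Assumed shape of the helper; not taken from this PR. Note that
# name.split(".") always yields strings, so the check has to inspect
# the string's content rather than its type.
def is_integer(t: str) -> bool:
    return t.isdigit()

tokens = "layers.3.attention.wq.weight".split(".")
print([t for t in tokens if is_integer(t)])  # prints ['3']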
name0 = "tp0" | ||
# name1 = "tp1" |
Should we add the definition of "name0"?
@pytest.mark.deepseek
def test_single_device_compile():
Is this test no longer useful?
This file contains many helpers for sharding. Should we create a separate file for the sharding tooling written here?
In that case, adding unit tests for these sharding methods could be helpful.
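A hypothetical sketch of what such a unit test could look like; shard_axis_for is an illustrative stand-in, not a helper from this PR:

import pytest

# Illustrative stand-in for one of the sharding helpers discussed above;
# the real helpers may have different names and signatures.
def shard_axis_for(name: str) -> int:
    return 0 if name.endswith("wq.weight") else -1

@pytest.mark.parametrize(
    "name, axis",
    [("layers.0.attn.wq.weight", 0), ("layers.0.norm.weight", -1)],
)
def test_shard_axis_for(name, axis):
    assert shard_axis_for(name) == axis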