
deepseek r1 running #102

Open
qihqi wants to merge 6 commits into main

Conversation

@qihqi (Collaborator) commented Feb 10, 2025

No description provided.

@pgmoka (Collaborator) left a comment

Left a couple of comments.

Comment on lines +430 to +431
# self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
# self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

We have changed how we do caching, as seen in the lines below. Do we have a reason to keep these comments?

Comment on lines +419 to +420
# self.k_cache[:bsz, start_pos:end_pos] = k
# self.v_cache[:bsz, start_pos:end_pos] = v

See comment below.

Comment on lines +633 to +647
# assert args.n_routed_experts % world_size == 0
# self.n_routed_experts = args.n_routed_experts
# self.n_local_experts = args.n_routed_experts // world_size
# self.n_activated_experts = args.n_activated_experts
# self.experts_start_idx = rank * self.n_local_experts
# self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(model_args)
# self.experts = nn.ModuleList(
# [
# Expert(args.dim, args.moe_inter_dim)
# if self.experts_start_idx <= i < self.experts_end_idx
# else None
# for i in range(self.n_routed_experts)
# ]
# )

Do we have a reason to leave these as comments rather than removing them?

Comment on lines +768 to +769
# if seqlen > 1:
# mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)

Do we have a reason to leave these as comments rather than removing them?

@@ -544,6 +543,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
print('i am here')

Test print. Please remove. If appropriate, we could use some logging at info level.

I see there are a couple of other print calls in the model functions. I would consider applying the same criterion to those.
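The replacement suggested above can be sketched with Python's standard logging module. The function name, logger setup, and message below are illustrative, not taken from the PR's model code:

```python
import logging

# Module-level logger, following the usual convention of naming it after
# the module; the name and message here are illustrative only.
logger = logging.getLogger(__name__)

def select_expert_groups(topk_groups: int) -> None:
    # Instead of a bare `print('i am here')`, log at INFO level so the
    # message can be filtered or silenced via the logging configuration.
    logger.info("selecting top-%d expert groups", topk_groups)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    select_expert_groups(4)
```

Unlike a print, this line costs almost nothing when INFO is disabled and carries the module name and level in its output, which makes stray debug messages easy to locate and remove.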

@@ -26,6 +26,7 @@ WORKDIR /workspaces
# Install torchax
RUN git clone https://github.com/pytorch/xla.git
WORKDIR /workspaces/xla/torchax
RUN git checkout hanq_torchax1

Is this a temporary fix while pytorch/xla@master...hanq_torchax1 is not merged?


tokens = name.split(".")
for i, t in enumerate(tokens):
if is_integer(t):

Why not use something like type(t) is int?
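For context on this suggestion: name.split(".") always yields strings, so a type(t) is int check would never be True here, and a string-level test is needed instead. A minimal sketch of what an is_integer helper for this case presumably looks like (the helper body and the example name are assumptions, not copied from the PR):

```python
def is_integer(s: str) -> bool:
    # Tokens produced by str.split are always strings, so `type(t) is int`
    # is always False for them; check the string's content instead.
    return s.isdigit()

# Example: pick out the numeric layer index from a dotted parameter name.
tokens = "layers.3.attn.wq".split(".")
numeric = [t for t in tokens if is_integer(t)]  # ['3']
```

Note that str.isdigit() rejects negative numbers and signs ("-1".isdigit() is False), which is usually fine for layer indices in parameter names.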

Comment on lines +94 to +95
name0 = "tp0"
# name1 = "tp1"

Should we add the definition of "name0"?

@pytest.mark.deepseek
def test_single_device_compile():

Is this test no longer useful?


This file contains many helpers for sharding. Should we create a file that contains the sharding tooling written here?

In that case, adding unit tests for these sharding methods could be helpful.
