Skip to content

Commit

Permalink
add typehint
Browse files (browse the repository at this point in the history)
Signed-off-by: ashors1 <[email protected]>
  • Loading branch information
ashors1 committed Jan 17, 2025
1 parent 8c81e07 commit 48b0ed9
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions nemo/utils/flops_formulas.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -30,7 +30,7 @@ class FLOPSConfig:
query_groups: int


def gpt3(mdl):
def gpt3(config: FLOPSConfig):
"""Model FLOPs for GPT3 family"""

vocab_size = LLM_VOCAB_SIZE_MAP["gpt3"]
Expand All @@ -40,7 +40,7 @@ def gpt3(mdl):
) * (3 * mdl.layers) + (6 * mdl.gbs * mdl.enc_seq_len * mdl.hs * vocab_size)


def llama2(mdl):
def llama2(config: FLOPSConfig):
"""Model FLOPs for llama2 family"""
vocab_size = LLM_VOCAB_SIZE_MAP["llama2"]

Expand All @@ -60,7 +60,7 @@ def llama2(mdl):
)


def llama3(mdl):
def llama3(config: FLOPSConfig):
"""Model FLOPs for llama3 family"""
vocab_size = LLM_VOCAB_SIZE_MAP["llama3"]

Expand All @@ -80,7 +80,7 @@ def llama3(mdl):
)


def nemotron(mdl):
def nemotron(config: FLOPSConfig):
"""Model FLOPs for nemotron family"""
vocab_size = LLM_VOCAB_SIZE_MAP["nemotron"]

Expand All @@ -95,32 +95,32 @@ def nemotron(mdl):
+ (12 * mdl.query_groups / mdl.attention_heads)
+ (12 * mdl.ffn_hs / mdl.hs)
+ (12 * mdl.enc_seq_len / mdl.hs)
+ (6 * vocab_size / (mdl.layers * mdl.hs))
+ (6 * vocab_size / (mdl.layers * config.hs))
)
)


def mixtral(mdl):
def mixtral(config: FLOPSConfig):
"""Model FLOPs for mixtral family"""
vocab_size = LLM_VOCAB_SIZE_MAP["mixtral"]

return (
mdl.gbs
* mdl.enc_seq_len
* mdl.layers
* mdl.hs
* mdl.hs
config.gbs
* config.enc_seq_len
* config.layers
* config.hs
* config.hs
* (
12
+ (12 * mdl.query_groups / mdl.attention_heads)
+ (18 * mdl.moe_router_topk * mdl.ffn_hs / mdl.hs)
+ (12 * mdl.enc_seq_len / mdl.hs)
+ (6 * vocab_size / (mdl.layers * mdl.hs))
+ (12 * config.query_groups / config.attention_heads)
+ (18 * config.moe_router_topk * config.ffn_hs / config.hs)
+ (12 * config.enc_seq_len / config.hs)
+ (6 * vocab_size / (config.layers * config.hs))
)
)


def bert(mdl):
def bert(config: FLOPSConfig):
"""Model FLOPs for BERT family"""
vocab_size = LLM_VOCAB_SIZE_MAP["bert"]

Expand Down

0 comments on commit 48b0ed9

Please sign in to comment.