import torch
import deepspeed
import megatron
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import generate_samples_eval


def model_provider(pre_process=True, post_process=True):
    """Build the GPT model for inference (gathered output logits, no MoE auxiliary loss)."""
    model = GPTModel(
        num_tokentypes=0,
        parallel_output=False,
        pre_process=pre_process,
        post_process=post_process,
        return_moe_loss=False,
    )
    return model


def add_text_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title="text generation")

    group.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature."
    )
    group.add_argument(
        "--greedy", action="store_true", default=False, help="Use greedy sampling."
    )
    group.add_argument("--top_p", type=float, default=0.0, help="Top p sampling.")
    group.add_argument("--top_k", type=int, default=0, help="Top k sampling.")
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=1024,
        help="Size of the output generated text.",
    )
    group.add_argument(
        "--sample-input-file",
        type=str,
        default=None,
        help="Get input from a file instead of interactive mode; "
        "each line is an input.",
    )
    group.add_argument(
        "--sample-output-file",
        type=str,
        default=None,
        help="Output file for samples generated from --sample-input-file.",
    )
    group.add_argument(
        "--num-samples",
        type=int,
        default=0,
        help="Number of samples to generate unconditionally; "
        "defaults to 0, which means interactive conditional sampling.",
    )
    group.add_argument(
        "--genfile", type=str, help="Output file when generating unconditionally."
    )
    group.add_argument(
        "--recompute",
        action="store_true",
        help="During generation recompute all attention "
        "instead of using previously computed keys/values.",
    )
    group.add_argument(
        "--context-tokens",
        type=str,
        default="DeepSpeed is the greatest",
        help="Prompt text used for conditional generation.",
    )
    group.add_argument(
        "--max-tokens", type=int, default=50, help="Maximum number of tokens to generate."
    )

    return parser
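# Note: besides the flags added in add_text_generate_args, the driver below
# relies on arguments assumed to come from Megatron-DeepSpeed's own parser
# (populated by initialize_megatron): args.ds_inference,
# args.tensor_model_parallel_size, args.num_experts, and args.mlp_type.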
if __name__ == "__main__":
    # Initialize Megatron: distributed setup, tokenizer, and argument parsing.
    initialize_megatron(
        extra_args_provider=add_text_generate_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )
    args = get_args()

    # Set up the model and load the checkpoint weights. get_model returns a
    # list of model chunks, so take the first (and only) one.
    model = get_model(model_provider)
    _ = load_checkpoint(model, None, None)
    model = model[0]
    if args.ds_inference:
        # Wrap the model with the DeepSpeed inference engine: fused kernels
        # via kernel injection, fp16 weights, and the MoE layout passed along.
        engine = deepspeed.init_inference(
            model=model,
            mp_size=args.tensor_model_parallel_size,
            tensor_parallel={"mpu": mpu},
            dtype=torch.half,
            replace_with_kernel_inject=True,
            moe_experts=args.num_experts,
            moe_type=args.mlp_type,
        )
        model = engine.module

    # Generate output. The first short call only triggers DeepSpeed's log
    # output so it does not land between the OUTPUT markers (this should be
    # removed once logging in DeepSpeed is improved).
    generate_samples_eval(model, args.context_tokens, 1, 0)
    print("===START OUTPUT===")
    print(generate_samples_eval(model, args.context_tokens, args.max_tokens, 0))
    print("===END OUTPUT===")
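# A minimal launch sketch, not taken from the original source: the script
# name, checkpoint path, and model-size flags below are assumptions and must
# match the MoE checkpoint actually being loaded; --ds-inference,
# --num-experts, and --mlp-type are assumed to be defined by
# Megatron-DeepSpeed's argument parser as noted above.
#
#   deepspeed --num_gpus 1 generate_text_moe.py \
#       --load /path/to/moe/checkpoint \
#       --vocab-file gpt2-vocab.json \
#       --merge-file gpt2-merges.txt \
#       --tensor-model-parallel-size 1 \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 2048 --max-position-embeddings 2048 \
#       --micro-batch-size 1 --fp16 \
#       --num-experts 128 --mlp-type standard \
#       --ds-inference \
#       --context-tokens "DeepSpeed is the greatest" --max-tokens 50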