@@ -8,7 +8,7 @@

 # tpl imports
 import torch
-from transformers import pipeline
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

 # local imports
 from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config

@@ -33,6 +33,10 @@
 parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
 parser.add_argument('--batch_size', type=int, default=16, help='Batch size for generation (default: 16)')
 parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)')
+device_group = parser.add_mutually_exclusive_group()
+device_group.add_argument('--device_map', help='Path to the device map JSON file or the string "auto"')
+device_group.add_argument('--device', type=int, help='Device to use for inference')
+device_group.add_argument('--axonn', action='store_true', help='Use AxoNN for inference')
 args = parser.parse_args()

 """ Load prompts """
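
For reference, --device_map expects a JSON file in the format used by the transformers/accelerate device_map argument: a mapping from module names (or module-name prefixes) to devices (GPU indices, or strings such as "cpu"). Below is a minimal sketch of producing such a file; the module names are hypothetical placeholders and depend on the actual model architecture (inspect model.named_modules() for the real ones).

# Sketch: write a hand-crafted device map to JSON.
# The module names below are placeholders, not taken from this repository.
import json

device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 1,
    "model.norm": 1,
    "lm_head": 1,
}

with open("device_map.json", "w") as f:
    json.dump(device_map, f, indent=2)

The resulting path is passed as --device_map device_map.json; passing --device_map auto instead lets accelerate place modules automatically.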

@@ -96,12 +100,48 @@
 # and repeat them for however many samples we want to generate per prompt
 prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)]

+""" Set device kwarg for inference """
+device_kwarg = {}
+USE_AXONN = False
+if args.device_map:
+    if args.device_map == "auto":
+        device_kwarg["device_map"] = "auto"
+    else:
+        with open(args.device_map, 'r') as json_file:
+            device_map = json.load(json_file)
+        device_kwarg["device_map"] = device_map
+elif args.device is not None:
+    device_kwarg["device"] = args.device
+elif args.axonn:
+    from mpi4py import MPI
+    from axonn import axonn as ax
+    from modify_llama import monkey_patch_llama_with_axonn
+    world_size = MPI.COMM_WORLD.Get_size()
+    rank = MPI.COMM_WORLD.Get_rank()
+    if rank == 0:
+        print(f"Using AxoNN with {world_size} GPUs.")
+    ax.init(G_data=1, G_inter=1, G_intra_r=1, G_intra_c=1, G_intra_d=world_size)
+    if "llama" in args.model:
+        monkey_patch_llama_with_axonn()
+    USE_AXONN = True
+    device_kwarg["device"] = "cuda"
+else:
+    device_kwarg["device"] = 0
+
+""" Load model and tokenizer """
+model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=inference_config.get_dtype())
+if USE_AXONN:
+    model = model.to("cuda")
+tokenizer = AutoTokenizer.from_pretrained(args.model)
+
 """ Initialize HuggingFace pipeline for generation """
-generator = pipeline(model=args.model, torch_dtype=inference_config.get_dtype(), device=0)
+generator = pipeline(task='text-generation', model=model, tokenizer=tokenizer, **device_kwarg)
 inference_config.init_padding(generator.tokenizer)

 """ Create a prompt data set to pass to generate method """
 prompt_dataset = PromptDataset([inference_config.format_prompt(p["prompt"]) for p in prompts_repeated])
+# Note: the prompt dataset stays on the CPU; the pipeline moves each batch of
+# input tensors to the model's device during generation.
 generated_outputs = generator(
     prompt_dataset,
     max_new_tokens=args.max_new_tokens,
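
PromptDataset comes from the repository's utils module; passing a torch Dataset (rather than a plain Python list) to the pipeline lets it stream and batch the prompts instead of materializing every output at once. The following is a hypothetical sketch of what such a wrapper typically looks like, not the actual implementation:

# Hypothetical sketch of a PromptDataset-style wrapper: a torch Dataset over
# already-formatted prompt strings that a transformers pipeline can iterate over.
from torch.utils.data import Dataset

class PromptDatasetSketch(Dataset):
    def __init__(self, prompts):
        self.prompts = list(prompts)

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]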

@@ -114,7 +154,7 @@
 )

 """ Iterate over prompts and generate code """
-if not args.restart and args.cache is not None:
+if not args.restart and args.cache is not None and os.path.exists(args.cache):
     with open(args.cache, 'r') as jsonl_file:
         responses = [json.loads(line) for line in jsonl_file]
         responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
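
The cache is a JSON Lines file with one record per generated sample; the filter above shows that each record carries at least "temperature" and "prompted" fields. A hypothetical example of a single record follows (the other keys are illustrative placeholders, not confirmed by this diff):

# Hypothetical cache record (one line of the JSONL cache file).
# Only "temperature" and "prompted" appear in the filtering code above.
import json

example_record = {
    "prompt": "...",        # placeholder
    "temperature": 0.2,
    "prompted": False,
    "outputs": ["..."],     # placeholder
}
print(json.dumps(example_record))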

@@ -140,8 +180,9 @@
     responses.append(cur_prompt)

     if not args.restart and args.cache is not None:
-        with open(args.cache, 'a') as jsonl_file:
-            jsonl_file.write(json.dumps(cur_prompt) + "\n")
+        if not USE_AXONN or rank == 0:
+            with open(args.cache, 'a') as jsonl_file:
+                jsonl_file.write(json.dumps(cur_prompt) + "\n")

     if idx != 0 and idx % args.num_samples_per_prompt == 0:
         print(f"Tokens per second: {total_tokens / (time.time() - start_time):.2f}")
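
Because every MPI rank executes this script under AxoNN, the cache and output writes are guarded so that only rank 0 touches the files; otherwise each rank would append its own copy of every record. A minimal sketch of the same rank-guarded I/O pattern in isolation, assuming an mpi4py environment where all ranks hold the same result (as is typically the case after tensor-parallel inference):

# Minimal sketch: rank-guarded file output under MPI, so a single process
# writes while all processes compute.
import json
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

result = {"world_size": comm.Get_size(), "value": 42}  # computed on every rank

if rank == 0:
    with open("results.jsonl", "a") as f:
        f.write(json.dumps(result) + "\n")

comm.Barrier()  # optional: keep ranks in step before continuing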

@@ -151,5 +192,6 @@
 print(f"Generated {len(responses)} code samples in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens per second)")

 """ Save responses to JSON file """
-with open(args.output, 'w') as output_file:
-    json.dump(responses, output_file, indent=4)
+if not USE_AXONN or rank == 0:
+    with open(args.output, 'w') as output_file:
+        json.dump(responses, output_file, indent=4)
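
Downstream evaluation then reads the single JSON file written by rank 0 (or by the lone process in the non-AxoNN case). A small sketch of consuming it; field names beyond "temperature" and "prompted" are not visible in this diff and depend on how cur_prompt is built:

# Sketch: load the saved responses for later inspection or evaluation.
import json

with open("responses.json", "r") as f:  # the path given via --output
    responses = json.load(f)

print(f"Loaded {len(responses)} generated samples")
if responses:
    print("Keys in first record:", sorted(responses[0].keys()))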