|
8 | 8 |
|
9 | 9 | # tpl imports |
10 | 10 | import torch |
11 | | -from transformers import pipeline |
| 11 | +from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer |
12 | 12 |
|
13 | 13 | # local imports |
14 | 14 | from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config |
|
33 | 33 | parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)') |
34 | 34 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size for generation (default: 16)')
35 | 35 | parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)') |
| 36 | +device_group = parser.add_mutually_exclusive_group() |
| 37 | +device_group.add_argument('--device_map', help='Path to the device map JSON file or the string "auto"') |
| 38 | +device_group.add_argument('--device', type=int, help='Device to use for inference') |
| 39 | +device_group.add_argument('--axonn', action='store_true', help='Use AxoNN for inference') |
36 | 40 | args = parser.parse_args() |
37 | 41 |
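Since --device_map, --device, and --axonn live in a mutually exclusive group, argparse rejects any combination of them. A minimal, self-contained sketch of that behavior, mirroring the options added above:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--device_map')
    group.add_argument('--device', type=int)
    group.add_argument('--axonn', action='store_true')

    parser.parse_args(['--device', '0'])             # fine: Namespace(device=0, ...)
    parser.parse_args(['--device', '0', '--axonn'])  # exits: "argument --axonn: not allowed with argument --device"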
|
38 | 42 | """ Load prompts """ |
|
96 | 100 | # and repeat them for however many samples we want to generate per prompt |
97 | 101 | prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)] |
98 | 102 |
|
| 103 | +""" Set device kwarg for inference """ |
| 104 | +device_kwarg = {} |
| 105 | +USE_AXONN = False |
| 106 | +if args.device_map: |
| 107 | + if args.device_map == "auto": |
| 108 | + device_kwarg["device_map"] = "auto" |
| 109 | + else: |
| 110 | + with open(args.device_map, 'r') as json_file: |
| 111 | + device_map = json.load(json_file) |
| 112 | + device_kwarg["device_map"] = device_map |
| 113 | +elif args.device is not None:
| 114 | + device_kwarg["device"] = args.device |
| 115 | +elif args.axonn: |
| 116 | + from mpi4py import MPI |
| 117 | + from axonn import axonn as ax |
| 118 | + from modify_llama import monkey_patch_llama_with_axonn |
| 119 | + world_size = MPI.COMM_WORLD.Get_size() |
| 120 | + rank = MPI.COMM_WORLD.Get_rank() |
| 121 | + if rank == 0: |
| 122 | + print(f"Using AxoNN with {world_size} GPUs.") |
| 123 | + ax.init(G_data=1, G_inter=1, G_intra_r=1, G_intra_c=1, G_intra_d=world_size) |
| 124 | + if "llama" in args.model: |
| 125 | + monkey_patch_llama_with_axonn() |
| 126 | + USE_AXONN = True |
| 127 | + device_kwarg["device"] = "cuda" |
| 128 | +else: |
| 129 | + device_kwarg["device"] = 0 |
| 130 | + |
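When --device_map points at a file, the script expects a JSON object it can hand to the pipeline as a device map. The exact contents depend on the model; as an assumption, following the Hugging Face Accelerate convention of mapping module names to device indices, such a file might look like:

    {
        "model.embed_tokens": 0,
        "model.layers.0": 0,
        "model.layers.1": 1,
        "lm_head": 1
    }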
| 131 | +""" Load model and tokenizer """ |
| 132 | +model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=inference_config.get_dtype()) |
| 133 | +if USE_AXONN: |
| 134 | + model = model.to("cuda") |
| 135 | +tokenizer = AutoTokenizer.from_pretrained(args.model) |
| 136 | + |
99 | 137 | """ Initialize HuggingFace pipeline for generation """ |
100 | | -generator = pipeline(model=args.model, torch_dtype=inference_config.get_dtype(), device=0) |
| 138 | +generator = pipeline(task='text-generation', model=model, tokenizer=tokenizer, **device_kwarg) |
101 | 139 | inference_config.init_padding(generator.tokenizer) |
102 | 140 |
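The init_padding call above suggests the tokenizer may ship without a pad token, which is common for causal-LM checkpoints. As an illustration only (the real logic lives in the utils helpers), such a setup step often amounts to:

    # Illustration of a typical fallback, not the repo's actual init_padding implementation.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token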
|
103 | 141 | """ Create a prompt data set to pass to generate method """ |
104 | 142 | prompt_dataset = PromptDataset([inference_config.format_prompt(p["prompt"]) for p in prompts_repeated]) |
105 | 145 | generated_outputs = generator( |
106 | 146 | prompt_dataset, |
107 | 147 | max_new_tokens=args.max_new_tokens, |
|
114 | 154 | ) |
115 | 155 |
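For the loop that follows in the file, each item yielded by a text-generation pipeline is normally a list of dicts carrying a "generated_text" key. A hedged sketch of unpacking one output; the clean_output call shown is a hypothetical use of the helper imported from utils, not its actual signature:

    # Sketch only: assumes default num_return_sequences and return_full_text settings.
    for prompt, outputs in zip(prompts_repeated, generated_outputs):
        for out in outputs:
            text = out["generated_text"]
            # cleaned = clean_output(text, prompt["prompt"])   # hypothetical usage of the utils helper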
|
116 | 156 | """ Iterate over prompts and generate code """ |
117 | | -if not args.restart and args.cache is not None: |
| 157 | +if not args.restart and args.cache is not None and os.path.exists(args.cache):
118 | 158 | with open(args.cache, 'r') as jsonl_file: |
119 | 159 | responses = [json.loads(line) for line in jsonl_file] |
120 | 160 | responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted |
|
140 | 180 | responses.append(cur_prompt) |
141 | 181 |
|
142 | 182 | if not args.restart and args.cache is not None: |
143 | | - with open(args.cache, 'a') as jsonl_file: |
144 | | - jsonl_file.write(json.dumps(cur_prompt) + "\n") |
| 183 | + if not USE_AXONN or rank == 0: |
| 184 | + with open(args.cache, 'a') as jsonl_file: |
| 185 | + jsonl_file.write(json.dumps(cur_prompt) + "\n") |
145 | 186 |
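One subtlety in the guard above: rank is only assigned inside the --axonn branch, so "if not USE_AXONN or rank == 0" relies on short-circuit evaluation to avoid a NameError on the non-AxoNN paths. A standalone sketch of the same single-writer pattern, assuming mpi4py is installed:

    from mpi4py import MPI

    def is_writer(use_mpi: bool) -> bool:
        # Short-circuits: the rank is only queried when MPI parallelism is actually in use.
        return (not use_mpi) or MPI.COMM_WORLD.Get_rank() == 0

    if is_writer(use_mpi=False):
        print("this process appends to the cache and writes the final output")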
|
146 | 187 | if idx != 0 and idx % args.num_samples_per_prompt == 0: |
147 | 188 | print(f"Tokens per second: {total_tokens / (time.time() - start_time):.2f}") |
|
151 | 192 | print(f"Generated {len(responses)} code samples in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens per second)") |
152 | 193 |
|
153 | 194 | """ Save responses to JSON file """ |
154 | | -with open(args.output, 'w') as output_file: |
155 | | - json.dump(responses, output_file, indent=4) |
| 195 | +if not USE_AXONN or rank == 0: |
| 196 | + with open(args.output, 'w') as output_file: |
| 197 | + json.dump(responses, output_file, indent=4) |