|
| 1 | +""" |
| 2 | +Apply SCPH force constant corrections to ALAMODE HDF5 files. |
| 3 | +
|
| 4 | +This script reads force constants from an HDF5 file, applies temperature- |
| 5 | +dependent corrections from a dfc2 file, and saves the updated force constants |
| 6 | +to a new HDF5 file. |
| 7 | +""" |
| 8 | + |
| 9 | +import argparse |
| 10 | +import h5py |
| 11 | +import numpy as np |
| 12 | + |
| 13 | + |
| 14 | +class FC2Data: |
| 15 | + """Container for force constant data from HDF5 file.""" |
| 16 | + |
| 17 | + def __init__(self, fname_h5): |
| 18 | + """Load force constant data from HDF5 file.""" |
| 19 | + with h5py.File(fname_h5, 'r') as h5file: |
| 20 | + self.values = h5file['ForceConstants/Order2/force_constant_values'][:] |
| 21 | + self.atom_indices = h5file['ForceConstants/Order2/atom_indices'][:] |
| 22 | + self.coord_indices = h5file['ForceConstants/Order2/coord_indices'][:] |
| 23 | + self.shift_vectors = h5file['ForceConstants/Order2/shift_vectors'][:] |
| 24 | + self.lattice_vectors = h5file['PrimitiveCell/lattice_vector'][:] |
| 25 | + self.fractional_coords = h5file['PrimitiveCell/fractional_coordinate'][:] |
| 26 | + |
| 27 | + def calculate_fractional_shifts(self): |
| 28 | + """Calculate fractional shift vectors for each force constant entry.""" |
| 29 | + inv_lattice = np.linalg.inv(self.lattice_vectors.T) |
| 30 | + shifts_frac = [] |
| 31 | + |
| 32 | + for shift, atoms in zip(self.shift_vectors, self.atom_indices): |
| 33 | + # Convert Cartesian shift to fractional coordinates |
| 34 | + shift_cart_to_frac = inv_lattice @ shift |
| 35 | + # Calculate relative position in fractional coordinates |
| 36 | + shift_frac = np.round( |
| 37 | + shift_cart_to_frac |
| 38 | + - self.fractional_coords[atoms[1]] |
| 39 | + + self.fractional_coords[atoms[0]] |
| 40 | + ) |
| 41 | + shifts_frac.append(shift_frac) |
| 42 | + |
| 43 | + return np.array(shifts_frac) |
| 44 | + |
| 45 | + |
| 46 | +class DFC2Correction: |
| 47 | + """Container for force constant corrections from dfc2 file.""" |
| 48 | + |
| 49 | + def __init__(self, fname_dfc2, temperature): |
| 50 | + """ |
| 51 | + Parse dfc2 correction file for a specific temperature. |
| 52 | + |
| 53 | + Args: |
| 54 | + fname_dfc2: Path to dfc2 correction file |
| 55 | + temperature: Temperature in Kelvin |
| 56 | + """ |
| 57 | + self.temperature = temperature |
| 58 | + self._parse_file(fname_dfc2) |
| 59 | + |
| 60 | + def _parse_file(self, fname_dfc2): |
| 61 | + """Parse the dfc2 file format.""" |
| 62 | + with open(fname_dfc2, 'r') as f: |
| 63 | + # Read primitive cell lattice vectors |
| 64 | + lattice = [] |
| 65 | + for _ in range(3): |
| 66 | + lattice.append([float(x) for x in f.readline().split()]) |
| 67 | + self.lattice = np.array(lattice) |
| 68 | + |
| 69 | + # Read number of atoms and elements |
| 70 | + natoms, _ = [int(x) for x in f.readline().split()] |
| 71 | + _ = f.readline() # element names |
| 72 | + |
| 73 | + # Read atomic positions and indices |
| 74 | + positions = [] |
| 75 | + indices = [] |
| 76 | + for _ in range(natoms): |
| 77 | + line = f.readline().split() |
| 78 | + positions.append([float(x) for x in line[0:3]]) |
| 79 | + indices.append(int(line[3]) - 1) |
| 80 | + |
| 81 | + self.positions = np.array(positions) |
| 82 | + self.indices = np.array(indices) |
| 83 | + |
| 84 | + # Read corrections for the specified temperature |
| 85 | + self._read_corrections(f) |
| 86 | + |
| 87 | + def _read_corrections(self, f): |
| 88 | + """Read correction data for the target temperature.""" |
| 89 | + shifts = [] |
| 90 | + atoms = [] |
| 91 | + coords = [] |
| 92 | + values = [] |
| 93 | + |
| 94 | + current_temp_match = False |
| 95 | + |
| 96 | + for line in f: |
| 97 | + if line.startswith('#'): |
| 98 | + if 'Temp' in line: |
| 99 | + temp_in_file = float(line.split('=')[1].strip()) |
| 100 | + current_temp_match = abs(temp_in_file - self.temperature) < 0.01 |
| 101 | + elif current_temp_match: |
| 102 | + parts = line.split() |
| 103 | + if len(parts) == 8: |
| 104 | + shifts.append([int(x) for x in parts[0:3]]) |
| 105 | + atoms.append([int(parts[3]), int(parts[5])]) |
| 106 | + coords.append([int(parts[4]), int(parts[6])]) |
| 107 | + values.append(float(parts[7])) |
| 108 | + |
| 109 | + if len(values) == 0: |
| 110 | + raise ValueError( |
| 111 | + f"No corrections found at T={self.temperature} K" |
| 112 | + ) |
| 113 | + |
| 114 | + self.shifts = np.array(shifts) |
| 115 | + self.atoms = np.array(atoms) |
| 116 | + self.coords = np.array(coords) |
| 117 | + self.values = np.array(values) |
| 118 | + |
| 119 | + print(f"Loaded {len(values)} corrections at T={self.temperature} K") |
| 120 | + |
| 121 | + |
| 122 | +class FC2Updater: |
| 123 | + """Update force constants with corrections using efficient lookup.""" |
| 124 | + |
| 125 | + @staticmethod |
| 126 | + def create_composite_key(atoms, coords, shifts): |
| 127 | + """ |
| 128 | + Create a structured array for efficient sorting and searching. |
| 129 | + |
| 130 | + Combines atom indices, coordinate indices, and shifts into a single |
| 131 | + sortable key similar to std::tuple in C++. |
| 132 | + """ |
| 133 | + n = len(atoms) |
| 134 | + shifts_rounded = np.round(shifts).astype(np.int32) |
| 135 | + |
| 136 | + dtype = [ |
| 137 | + ('atom0', np.int32), ('atom1', np.int32), |
| 138 | + ('coord0', np.int32), ('coord1', np.int32), |
| 139 | + ('shift0', np.int32), ('shift1', np.int32), ('shift2', np.int32) |
| 140 | + ] |
| 141 | + |
| 142 | + keys = np.zeros(n, dtype=dtype) |
| 143 | + keys['atom0'] = atoms[:, 0] |
| 144 | + keys['atom1'] = atoms[:, 1] |
| 145 | + keys['coord0'] = coords[:, 0] |
| 146 | + keys['coord1'] = coords[:, 1] |
| 147 | + keys['shift0'] = shifts_rounded[:, 0] |
| 148 | + keys['shift1'] = shifts_rounded[:, 1] |
| 149 | + keys['shift2'] = shifts_rounded[:, 2] |
| 150 | + |
| 151 | + return keys |
| 152 | + |
| 153 | + @staticmethod |
| 154 | + def update(fc2_data, shifts_frac, dfc2_correction): |
| 155 | + """ |
| 156 | + Apply corrections to force constants using binary search. |
| 157 | + |
| 158 | + Time complexity: O(M log M + N log M) where M is the number of |
| 159 | + force constants and N is the number of corrections. |
| 160 | + """ |
| 161 | + fc2_updated = fc2_data.values.copy() |
| 162 | + |
| 163 | + # Create sorted lookup table for force constants |
| 164 | + fc2_keys = FC2Updater.create_composite_key( |
| 165 | + fc2_data.atom_indices, |
| 166 | + fc2_data.coord_indices, |
| 167 | + shifts_frac |
| 168 | + ) |
| 169 | + sort_idx = np.argsort(fc2_keys) |
| 170 | + fc2_keys_sorted = fc2_keys[sort_idx] |
| 171 | + |
| 172 | + # Apply each correction |
| 173 | + n_matched = 0 |
| 174 | + n_total = len(dfc2_correction.values) |
| 175 | + |
| 176 | + for i in range(n_total): |
| 177 | + # Skip negligible corrections |
| 178 | + if abs(dfc2_correction.values[i]) < 1e-10: |
| 179 | + continue |
| 180 | + |
| 181 | + # Create query key |
| 182 | + query_key = FC2Updater.create_composite_key( |
| 183 | + dfc2_correction.atoms[i:i+1], |
| 184 | + dfc2_correction.coords[i:i+1], |
| 185 | + dfc2_correction.shifts[i:i+1] |
| 186 | + )[0] |
| 187 | + |
| 188 | + # Binary search (equivalent to C++ std::lower_bound) |
| 189 | + idx = np.searchsorted(fc2_keys_sorted, query_key) |
| 190 | + |
| 191 | + # Check if match found |
| 192 | + if idx < len(fc2_keys_sorted) and fc2_keys_sorted[idx] == query_key: |
| 193 | + original_idx = sort_idx[idx] |
| 194 | + fc2_updated[original_idx] += dfc2_correction.values[i] |
| 195 | + n_matched += 1 |
| 196 | + else: |
| 197 | + print( |
| 198 | + f"Warning: No match for correction {i}: " |
| 199 | + f"atoms={dfc2_correction.atoms[i]}, " |
| 200 | + f"coords={dfc2_correction.coords[i]}, " |
| 201 | + f"shift={dfc2_correction.shifts[i]}" |
| 202 | + ) |
| 203 | + |
| 204 | + print(f"Applied {n_matched}/{n_total} nonzero corrections") |
| 205 | + return fc2_updated |
| 206 | + |
| 207 | + |
| 208 | +class HDF5Writer: |
| 209 | + """Write updated force constants to HDF5 file.""" |
| 210 | + |
| 211 | + @staticmethod |
| 212 | + def copy_with_updated_fc2(fname_in, fname_out, fc2_updated): |
| 213 | + """ |
| 214 | + Copy HDF5 file with updated force constant values. |
| 215 | + |
| 216 | + All data and attributes are preserved except the force constant |
| 217 | + values, which are replaced with the updated values. |
| 218 | + """ |
| 219 | + with h5py.File(fname_in, 'r') as f_in, \ |
| 220 | + h5py.File(fname_out, 'w') as f_out: |
| 221 | + |
| 222 | + def copy_item(name, obj): |
| 223 | + """Recursively copy datasets and groups.""" |
| 224 | + # Replace force constants with updated values |
| 225 | + if name == 'ForceConstants/Order2/force_constant_values': |
| 226 | + f_out.create_dataset(name, data=fc2_updated) |
| 227 | + return |
| 228 | + |
| 229 | + if isinstance(obj, h5py.Dataset): |
| 230 | + # Handle scalar vs array datasets |
| 231 | + data = obj[()] if obj.shape == () else obj[:] |
| 232 | + |
| 233 | + f_out.create_dataset( |
| 234 | + name, |
| 235 | + data=data, |
| 236 | + dtype=obj.dtype, |
| 237 | + compression=obj.compression |
| 238 | + ) |
| 239 | + |
| 240 | + # Copy attributes |
| 241 | + for attr_name, attr_value in obj.attrs.items(): |
| 242 | + f_out[name].attrs[attr_name] = attr_value |
| 243 | + |
| 244 | + elif isinstance(obj, h5py.Group): |
| 245 | + if name not in f_out: |
| 246 | + f_out.create_group(name) |
| 247 | + |
| 248 | + # Copy attributes |
| 249 | + for attr_name, attr_value in obj.attrs.items(): |
| 250 | + f_out[name].attrs[attr_name] = attr_value |
| 251 | + |
| 252 | + # Copy all items |
| 253 | + f_in.visititems(copy_item) |
| 254 | + |
| 255 | + # Copy root attributes |
| 256 | + for attr_name, attr_value in f_in.attrs.items(): |
| 257 | + f_out.attrs[attr_name] = attr_value |
| 258 | + |
| 259 | + print(f"Saved updated force constants to: {fname_out}") |
| 260 | + |
| 261 | + |
| 262 | +def main(): |
| 263 | + """Main entry point for the script.""" |
| 264 | + parser = argparse.ArgumentParser( |
| 265 | + description="Apply SCPH corrections to force constants in HDF5 file" |
| 266 | + ) |
| 267 | + parser.add_argument( |
| 268 | + '--input', '-i', |
| 269 | + required=True, |
| 270 | + help="Input HDF5 file with force constants" |
| 271 | + ) |
| 272 | + parser.add_argument( |
| 273 | + '--output', '-o', |
| 274 | + required=True, |
| 275 | + help="Output HDF5 file for updated force constants" |
| 276 | + ) |
| 277 | + parser.add_argument( |
| 278 | + '--dfc2', |
| 279 | + required=True, |
| 280 | + help="File containing force constant corrections" |
| 281 | + ) |
| 282 | + parser.add_argument( |
| 283 | + '--temp', |
| 284 | + type=float, |
| 285 | + required=True, |
| 286 | + help="Temperature (K) for corrections" |
| 287 | + ) |
| 288 | + args = parser.parse_args() |
| 289 | + |
| 290 | + print(f"Loading force constants from: {args.input}") |
| 291 | + fc2_data = FC2Data(args.input) |
| 292 | + |
| 293 | + print(f"Calculating fractional shifts...") |
| 294 | + shifts_frac = fc2_data.calculate_fractional_shifts() |
| 295 | + |
| 296 | + print(f"Loading corrections from: {args.dfc2}") |
| 297 | + dfc2_correction = DFC2Correction(args.dfc2, args.temp) |
| 298 | + |
| 299 | + print(f"Applying corrections...") |
| 300 | + fc2_updated = FC2Updater.update(fc2_data, shifts_frac, dfc2_correction) |
| 301 | + |
| 302 | + print(f"Writing results...") |
| 303 | + HDF5Writer.copy_with_updated_fc2(args.input, args.output, fc2_updated) |
| 304 | + |
| 305 | + print("Done!") |
| 306 | + |
| 307 | + |
| 308 | +if __name__ == "__main__": |
| 309 | + main() |
0 commit comments