Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LPSE Pump Equation #28

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion adept/lpse2d/core/integrator.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
import equinox as eqx
import diffrax

from adept.lpse2d.core import epw
from adept.lpse2d.core import epw, pump


class Stepper(diffrax.Euler):
@@ -27,12 +27,14 @@ class VectorField(eqx.Module):

cfg: Dict
epw: eqx.Module
pump: eqx.Module
complex_state_vars: List

def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.epw = epw.EPW2D(cfg)
self.pump = pump.Pump2D(cfg)
self.complex_state_vars = ["e0", "phi"]

def unpack_y(self, y):
@@ -46,6 +48,7 @@ def unpack_y(self, y):

def __call__(self, t, y, args):
new_y = self.epw(t, self.unpack_y(y), args)
new_y = self.pump(t, self.unpack_y(y), args)

for k in y.keys():
y[k] = y[k].view(jnp.float64)
97 changes: 97 additions & 0 deletions adept/lpse2d/core/pump.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import Dict, Tuple

import diffrax
import jax
from jax import numpy as jnp
import equinox as eqx
import numpy as np
from theory import electrostatic
from adept.lpse2d.core.driver import Driver


class Pump2D(eqx.Module):
cfg: Dict
dt: float
wp0: float
w0: float
n0: float
nuei: float
kax_sq: jax.Array
kx: jax.Array
ky: jax.Array
transverse_mask: jax.Array

def __init__(self, cfg):
self.cfg = cfg
self.dt = cfg["grid"]["dt"]
self.wp0 = cfg["plasma"]["wp0"]
self.w0 = cfg["drivers"]["E0"]["w0"]
self.n0 = np.sqrt(self.wp0)
self.nuei = -cfg["units"]["derived"]["nuei_norm"]
self.kax_sq = cfg["grid"]["kx"][:, None] ** 2.0 + cfg["grid"]["ky"][None, :] ** 2.0
self.kx = cfg["grid"]["kx"]
self.ky = cfg["grid"]["ky"]
self.transverse_mask = None

def _calc_div_(self, arr: jax.Array) -> jax.Array:
arrk = jnp.fft.fft2(arr)
divk = self.kx[:, None] * arrk[..., 0] + self.ky[None, :] * arrk[..., 1]
return jnp.fft.ifft2(divk)

def calc_damping(self, nb) -> Tuple[jax.Array, jax.Array]:
return nb / self.w0**2.0 * self.nuei / 2.0

def calc_oscillation(self, nb: jax.Array) -> jax.Array:
"""

calculates 1j / (2 w0) * w0^2 - nb - wp0^2 * nb / n0

"""
coeff = self.w0**2.0 - nb - self.wp0**2 * nb / self.n0
return 1j / 2 / self.w0 * coeff

def calc_nabla(self, e0: jax.Array) -> jax.Array:
"""
Calculates the spatial advection term

"""
nabla2 = jnp.fft.fft2(e0, axis=(0, 1)) * (-self.kax_sq)[:, :, None] # (ikx^2 + iky^2) * E0(kx, ky)
div = self._calc_div_(e0) # div(E0)
grad_x, grad_y = self.kx[:, None] * div, self.ky[None, :] * div # kx * div(E0), ky * div(E0)
term = nabla2 - 1j * jnp.concatenate(
[grad_x[..., None], grad_y[..., None]], axis=-1
) # (ikx^2 + iky^2) * E0(kx, ky) - i * (kx * div(E0), ky * div(E0))
term *= 1j * self.c_light**2.0 / 2.0 / self.w0 # * i * c^2 / 2 / w0
return jnp.fft.ifft2(term)

def calc_epw_term(self, t: float, eh: jax.Array, nb: jax.Array) -> jax.Array:
"""
Calculates the pump depletion term

"""
coeff = 1j / 2.0 / self.w0 * jnp.exp(1j * (self.w0 - 2 * self.wp0) * t)

div_eh = self._calc_div_(eh)
term = nb / self.n0 * (eh * div_eh) * self.transverse_mask

return coeff * term

def get_eh_x(self, phi: jax.Array) -> jax.Array:
ehx = -jnp.fft.ifft2(1j * self.kx[:, None] * phi)
ehy = -jnp.fft.ifft2(1j * self.ky[None, :] * phi)

return jnp.concatenate([ehx[..., None], ehy[..., None]], axis=-1) * self.kx.size * self.ky.size / 4

def __call__(self, t, y, args):
e0 = y["e0"]
phi = y["phi"]
nb = y["nb"]

eh = self.get_eh_x(phi)

e0 = e0 * jnp.exp(self.calc_damping(nb=nb)[:, :, None])
e0 = e0 + self.dt * self.calc_nabla(e0)
e0 = e0 + self.dt * e0 * self.calc_oscillation(nb)[:, :, None]
y["e0"] = e0 + self.dt * self.calc_epw_term(t, eh, nb)

return y
37 changes: 13 additions & 24 deletions configs/envelope-2d/tpd.yaml
Original file line number Diff line number Diff line change
@@ -21,27 +21,13 @@ drivers:
y_w: 60000000.
y_r: 20.
k0: 1.0
a0: 0.0
intensity: 4.0e14
E2:
w0: 0.03375
t_c: 230.
t_w: 400.
t_r: 5.
x_c: 1400.
x_w: 600.
x_r: 20.
y_c: 0.
y_w: 2000000.
y_r: 5.
k0: 0.15
a0: 0.0
intensity: 4.0e14
a0: doesnt_matter
intensity: 1.0e5 W/cm^2

save:
t:
tmin: 0.0
tmax: 10000.0
tmax: 20000.0
nt: 32

plasma:
@@ -58,6 +44,9 @@ units:
normalizing density: 1.5e21/cc
Z: 10
Zp: 10
gas fill: N
ionization state: 6
electron temperature: 4000eV

grid:
xmin: 000.0
@@ -67,12 +56,12 @@ grid:
ymax: 1000.0
ny: 512
tmin: 0.
tmax: 10000.0
tmax: 20000.0
dt: 2.0

mlflow:
experiment: lpse2d-tpd
run: noise-test-no-density-gradient
run: tpd-test-pump-evolution

# models:
# file: None #/Users/archis/Dev/code/ergodic/laplax/weights.eqx
@@ -86,15 +75,15 @@ mlflow:

terms:
epw:
linear: True
density_gradient: True
kinetic real part: False
linear: true
density_gradient: true
kinetic real part: false
boundary:
x: absorbing
y: periodic
trapping:
active: False
active: false
kld: 0.28
nuee: 0.0000001
source:
tpd: False
tpd: true
3 changes: 2 additions & 1 deletion nersc-cpu.sh
Original file line number Diff line number Diff line change
@@ -8,5 +8,6 @@
export BASE_TEMPDIR="$PSCRATCH/tmp/"
export MLFLOW_TRACKING_URI="$PSCRATCH/mlflow"

source /global/u2/a/archis/adept/venv/bin/activate
module load conda
mamba activate adept
cd /global/u2/a/archis/adept/
3 changes: 2 additions & 1 deletion nersc-gpu.sh
Original file line number Diff line number Diff line change
@@ -10,9 +10,10 @@
export SLURM_CPU_BIND="cores"
export BASE_TEMPDIR="$PSCRATCH/tmp/"
export MLFLOW_TRACKING_URI="$PSCRATCH/mlflow"
export MLFLOW_EXPORT="True"

# copy job stuff over
module load python
module load conda
module load cudnn/8.9.3_cuda12.lua
module load cudatoolkit/12.0.lua

12 changes: 6 additions & 6 deletions queue_adept.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
BASE_TEMPDIR = None


def _queue_run_(machine, run_id):
def _queue_run_(machine, run_id, mode, run_name):
if "cpu" in machine:
base_job_file = os.environ["CPU_BASE_JOB_FILE"]
elif "gpu" in machine:
@@ -23,11 +23,11 @@ def _queue_run_(machine, run_id):
with open(base_job_file, "r") as fh:
base_job = fh.read()

with open(os.path.join(os.getcwd(), "new_job.sh"), "w") as job_file:
with open(os.path.join(os.getcwd(), f"queue-{mode}-{run_name}.sh"), "w") as job_file:
job_file.write(base_job)
job_file.writelines(f"srun python run.py --mode remote --run_id {run_id}")

os.system(f"sbatch new_job.sh")
job_file.writelines(f"\nsrun python run.py --run_id {run_id}")
os.system(f"sbatch queue-{mode}-{run_name}.sh")
time.sleep(0.1)
os.system("sqs")

@@ -55,4 +55,4 @@ def load_and_make_folders(cfg_path):
args = parser.parse_args()

cfg, run_id = load_and_make_folders(args.cfg)
_queue_run_(cfg["machine"]["calculator"], run_id)
_queue_run_(cfg["machine"], run_id, cfg["mode"], cfg["mlflow"]["run"])