Skip to content

Commit

Permalink
Merge pull request #182 from dihm/ACStark_fixes
Browse files Browse the repository at this point in the history
AC stark fixes and extensions
  • Loading branch information
nikolasibalic authored Oct 4, 2024
2 parents ece00f9 + e4e8fb6 commit 2e00f5b
Show file tree
Hide file tree
Showing 3 changed files with 597 additions and 122 deletions.
202 changes: 147 additions & 55 deletions arc/calculations_atom_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -3343,7 +3343,7 @@ def _onePhotonCoupling(self, ns, ls, js, mjs, nt, lt, jt, mjt, q, s=0.5):
if (ns == nt) and (ls == lt) and (js == jt) and (mjs == mjt):
return False
# transitions that change l by 1
elif (abs(ls - lt) == 1) and (mjs - mjt == q):
elif (abs(ls - lt) == 1) and (mjt - mjs == q):
if ls - lt == js - jt:
return True
elif (js == jt) and ((js == ls + s) or (jt == lt + s)):
Expand Down Expand Up @@ -3380,11 +3380,11 @@ def _twoPhotonCoupling(self, ns, ls, js, mjs, nt, lt, jt, mjt, q, s=0.5):
elif (
(abs(ls - lt) == 2)
and (ls - lt == js - jt)
and ((mjs - mjt) / 2 == q)
and ((mjt - mjs) / 2 == q)
):
return True
# transitions that don't change l
elif ((ls - lt) == 0) and (js == jt) and ((mjs - mjt) / 2 == q):
elif ((ls - lt) == 0) and (js == jt) and ((mjt - mjs) / 2 == q):
return True
else:
return False
Expand All @@ -3395,12 +3395,13 @@ def defineBasis(
l,
j,
mj,
q,
nMin,
nMax,
maxL,
q=0,
nMin=None,
nMax=None,
maxL=None,
Bz=0,
edN=0,
basisStates=None,
progressOutput=False,
debugOutput=False,
s=0.5,
Expand All @@ -3419,14 +3420,18 @@ def defineBasis(
l (int): angular orbital momentum of the state
j (flaot): total angular momentum of the state
mj (float): projection of total angular momentum of the state
q (int): polarization of coupling field is spherical basis.
q (int, optional): polarization of coupling field is spherical basis.
Must be -1, 0, or 1: corresponding to sigma-, pi, sigma+
nMin (int): *minimal* principal quantum number of the states to
Default is 0.
nMin (int, optional): *minimal* principal quantum number of the states to
be included in the basis for calculation
nMax (int): *maximal* principal quantum number of the states to
If not provided, `basisStates` must be provided.
nMax (int, optional): *maximal* principal quantum number of the states to
be included in the basis for calculation
maxL (int): *maximal* value of orbital angular momentum for the
If not provided, `basisStates` must be provided.
maxL (int, optional): *maximal* value of orbital angular momentum for the
states to be included in the basis for calculation
If not provided, `basisStates` must be provided.
Bz (float, optional): magnetic field directed along z-axis in
units of Tesla. Calculation will be correct only for weak
magnetic fields, where paramagnetic term is much stronger
Expand All @@ -3437,6 +3442,9 @@ def defineBasis(
only include single-photon dipole-allowed transitions.
Setting to 2 means include up to 2 photon transitions.
Higher numbers not supported.
basisStates (list of states, optional): Manually specify the basis.
Defaults to None, in which case it creates the basis as normal.
If specified, `nMin`, `nMax`, `maxL`, and `edN` are ignored.
progressOutput (:obj:`bool`, optional): if True prints the
progress of calculation; Set to false by default.
debugOutput (:obj:`bool`, optional): if True prints additional
Expand All @@ -3454,21 +3462,47 @@ def defineBasis(
self.l = l
self.j = j
self.mj = mj
self.q = q
if edN in [0, 1, 2]:
self.edN = edN
else:
raise ValueError("EN must be 0, 1, or 2")
self.nMin = nMin
self.nMax = nMax
self.maxL = maxL
self.Bz = Bz
self.s = s
# save calculation details END

self._findBasisStates(progressOutput, debugOutput)
# basis definition
if basisStates is not None:
self._defineBasisStates(basisStates)
elif nMin is not None and nMax is not None and maxL is not None:
# options to control automatically finding basis states
self.q = q
if edN in [0, 1, 2]:
self.edN = edN
else:
raise ValueError("edN must be 0, 1, or 2")
self.nMin = nMin
self.nMax = nMax
self.maxL = maxL
self._findBasisStates(progressOutput, debugOutput)
else:
raise ValueError(
"Input arguments are not complete. "
+ "Either specify nMin, nMax, maxL or basisStates"
)
# generate the hamiltonian
self._buildHamiltonian(progressOutput, debugOutput)

def _defineBasisStates(self, states_list):
"""
Use the user-provided list of states to define the basis
Args:
states_list (list): List of state quantum numbers in the form
`[n, l, j, mj]`.
"""
self.basisStates = states_list
# find index of target state
t_state = [self.n, self.l, self.j, self.mj]
try:
self.indexOfCoupledState = states_list.index(t_state)
except ValueError:
raise ValueError(f"Target state {t_state} not in states list")
self.targetState = states_list[self.indexOfCoupledState]

def _findBasisStates(self, progressOutput=False, debugOutput=False):
"""
Creates the list of basis states we want to include.
Expand Down Expand Up @@ -3501,18 +3535,17 @@ def _findBasisStates(self, progressOutput=False, debugOutput=False):
for tn in range(nMin, nMax):
for tl in range(min(maxL + 1, tn)):
for tj in np.linspace(tl - s, tl + s, round(2 * s + 1)):
# skip test state if unphysical
if abs(mj + q) - 0.1 > tj:
continue
# ensure we add the target state
if (n == tn) and (l == tl) and (j == tj):
states.append([tn, tl, tj, mj])
indexOfCoupledState = index
# adding all manifold states
elif (
(edN == 0)
and (abs(mj) + q - 0.1 <= tj)
and (
tn >= self.atom.groundStateN
or [tn, tl, tj] in self.atom.extraLevels
)
elif (edN == 0) and (
tn >= self.atom.groundStateN
or [tn, tl, tj] in self.atom.extraLevels
):
states.append([tn, tl, tj, mj + q])
index += 1
Expand Down Expand Up @@ -3743,12 +3776,22 @@ def __init__(self, atom):
"""
self.transProbs = []
"""
Probability to transition from the target state to another state in the basis.
Long-time averaged probability to transition from the target state to another state in the basis.
Calculated using Eq. 19 of Shirley, Physical Review (1965).
"""
self.targetShifts = []
"""
This is the shift of the target state relative to the zero field energy for an applied
field of :obj:`eField` and :obj:`freq`. Given in units of Hz.
.. note::
Accurate calculation of the target shifts relies on
the energy ordering of the eigenstates to remain the same.
Situations with large shifts can result in re-ordering which makes
this calculation (and the associated :obj:`transProbs`) invalid.
In such a case, it is best to perform a more informed analysis on the
base eigenvalues and eigenenergies.
"""

def defineShirleyHamiltonian(self, fn, debugOutput=False):
Expand Down Expand Up @@ -3843,24 +3886,30 @@ def diagonalise(
"""

# get basic info about solve structure from class
dim0 = len(self.basisStates)
dim0 = len(self.basisStates) # atomic basis size
dim1 = 2 * self.fn + 1 # floquet basis size
targetEnergy = self.targetEnergy

# index of first basis state in k=0 block diagonal
refInd = self.fn * dim0
# index of target state in basis
tarInd = self.indexOfCoupledState + refInd

# ensure inputs are numpy arrays, if scalars, 0d-arrays
self.eFields = np.array(eFields, ndmin=1)
self.freqs = np.array(freqs, ndmin=1)

# pre-allocation of results array
eig = np.zeros(
(*self.eFields.shape, *self.freqs.shape, dim0 * (2 * self.fn + 1)),
(*self.eFields.shape, *self.freqs.shape, dim0 * dim1),
dtype=np.double,
)
eigVec = np.zeros(
(
*self.eFields.shape,
*self.freqs.shape,
dim0 * (2 * self.fn + 1),
dim0 * (2 * self.fn + 1),
dim0 * dim1,
dim0 * dim1,
),
dtype=np.complex128,
)
Expand Down Expand Up @@ -3901,29 +3950,14 @@ def diagonalise(
eigVec[it.multi_index] = egvector

# get transition probabilities from target state to other basis states
# index of first basis state in k=0 block diagonal
refInd = self.fn * dim0
# index of target state in basis
tarInd = self.indexOfCoupledState + refInd
transProbs[it.multi_index] = np.array(
[
np.sum(
[
np.abs(
np.conj(egvector[refInd + k * dim0 + i])
* egvector[tarInd]
)
** 2
for k in range(-self.fn, self.fn + 1, 1)
]
)
for i in range(0, dim0, 1)
]
# note: conj not necessary since all eigenvalues are real
transProbs[it.multi_index] = (
(np.abs(egvector * egvector[tarInd].conj()) ** 2)
.reshape((dim1, dim0, dim1 * dim0))
.sum(axis=(0, -1))
)
# get the target shift by finding the max overlap with the target state
evInd = np.argmax(
np.abs(egvector[tarInd].conj() * egvector[tarInd]) ** 2
)
evInd = np.argmax(np.abs(egvector[tarInd]) ** 2)
if np.count_nonzero(ev == ev[evInd]) > 1:
warnings.warn(
"Multiple states have same overlap with target. Only saving first one."
Expand All @@ -3946,6 +3980,64 @@ def diagonalise(
self.transProbs = transProbs.squeeze()
self.targetShifts = targetShifts.squeeze()

def calcTransitionProbability(self, tevals):
"""
Calculates the time-dependent transition probability.
Note that the calculation assumes all population is in
the target state at time :math:`t=0` and
that this probability is averaged over initial phases
of the driving field, therefore is primarily valid in the CW
driving case.
This function implements Eq. 18 of [1]_ and is primarily
for calculation the Rabi flops in population under high driving fields.
Args:
tevals (float or numpy.ndarray): Times to evaluate the transition probability.
In units of seconds. Input is coerced to a numpy array.
Returns:
numpy.ndarray: Transition probability to go from the target state
to any state in the basis, as a function of t.
Output tensor shape is of the form
`(tevals, efields, freqs, floquet_basis)`,
where `efields`, `freqs`, and `floquet_basis` sizes are taken
from the result of the :meth:`diagonalise` function.
References:
.. [1] J. H. Shirley, Physical Review **138**, B979 (1965)
https://link.aps.org/doi/10.1103/PhysRev.138.B979
"""

t_eval = np.array(tevals, ndmin=1)
dim0 = len(self.basisStates)
dim1 = 2 * self.fn + 1
tarInd = self.indexOfCoupledState + (self.fn * dim0)

ut = np.exp(
-1.0j * 2 * np.pi * np.einsum("i,...j->i...j", t_eval, self.eigs)
)
# reshape separates atomic and floquet basis expansions
# final result sums along floquet expansion only
Pab = (
(
np.abs(
np.einsum(
"...km,l...m,...m->l...k",
self.eigVectors,
ut,
self.eigVectors[..., tarInd, :].conj(),
)
)
** 2
)
.reshape(ut.shape[:-1] + (dim1, dim0))
.sum(axis=-2)
)

return Pab.squeeze() # remove 0-d time, if applicable


class RWAStarkShift(StarkBasisGenerator):
"""
Expand Down
Loading

0 comments on commit 2e00f5b

Please sign in to comment.