Skip to content

Commit ad1ab72

Browse files
committed
Fix to_hf and to_ks for GDF SCF methods (fix pyscf#2788)
1 parent befb2fe commit ad1ab72

File tree

3 files changed

+98
-10
lines changed

3 files changed

+98
-10
lines changed

pyscf/dft/test/test_h2o.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pyscf import gto
1919
from pyscf import lib
2020
from pyscf import dft
21+
from pyscf import scf
2122
try:
2223
from pyscf.dispersion import dftd3, dftd4
2324
except ImportError:
@@ -639,6 +640,88 @@ def test_init(self):
639640
self.assertTrue(isinstance(mol_u.DKS(), dft.dks.UDKS))
640641
#TODO: self.assertTrue(isinstance(dft.X2C(mol_r), x2c.dft.UKS))
641642

643+
def test_to_hf(self):
644+
self.assertEqual(dft.RKS(h2o).to_rhf().__class__, scf.rhf.RHF)
645+
self.assertEqual(dft.RKS(h2o).to_uhf().__class__, scf.uhf.UHF)
646+
self.assertEqual(dft.RKS(h2o).to_ghf().__class__, scf.ghf.GHF)
647+
self.assertEqual(dft.RKS(h2o).to_hf() .__class__, scf.rhf.RHF)
648+
self.assertEqual(dft.RKS(h2o).to_rks().__class__, dft.rks.RKS)
649+
self.assertEqual(dft.RKS(h2o).to_uks().__class__, dft.uks.UKS)
650+
self.assertEqual(dft.RKS(h2o).to_gks().__class__, dft.gks.GKS)
651+
652+
self.assertEqual(dft.UKS(h2o).to_rhf().__class__, scf.rhf.RHF)
653+
self.assertEqual(dft.UKS(h2o).to_uhf().__class__, scf.uhf.UHF)
654+
self.assertEqual(dft.UKS(h2o).to_ghf().__class__, scf.ghf.GHF)
655+
self.assertEqual(dft.UKS(h2o).to_hf() .__class__, scf.uhf.UHF)
656+
self.assertEqual(dft.UKS(h2o).to_rks().__class__, dft.rks.RKS)
657+
self.assertEqual(dft.UKS(h2o).to_uks().__class__, dft.uks.UKS)
658+
self.assertEqual(dft.UKS(h2o).to_gks().__class__, dft.gks.GKS)
659+
660+
self.assertEqual(dft.GKS(h2o).to_ghf().__class__, scf.ghf.GHF)
661+
self.assertEqual(dft.GKS(h2o).to_hf() .__class__, scf.ghf.GHF)
662+
self.assertEqual(dft.GKS(h2o).to_gks().__class__, dft.gks.GKS)
663+
664+
self.assertEqual(dft.RKS(h2o).density_fit().to_rhf().__class__, scf.rhf.RHF(h2o).density_fit().__class__)
665+
self.assertEqual(dft.RKS(h2o).density_fit().to_uhf().__class__, scf.uhf.UHF(h2o).density_fit().__class__)
666+
self.assertEqual(dft.RKS(h2o).density_fit().to_ghf().__class__, scf.ghf.GHF(h2o).density_fit().__class__)
667+
self.assertEqual(dft.RKS(h2o).density_fit().to_hf() .__class__, scf.rhf.RHF(h2o).density_fit().__class__)
668+
self.assertEqual(dft.RKS(h2o).density_fit().to_rks().__class__, dft.rks.RKS(h2o).density_fit().__class__)
669+
self.assertEqual(dft.RKS(h2o).density_fit().to_uks().__class__, dft.uks.UKS(h2o).density_fit().__class__)
670+
self.assertEqual(dft.RKS(h2o).density_fit().to_gks().__class__, dft.gks.GKS(h2o).density_fit().__class__)
671+
672+
self.assertEqual(dft.UKS(h2o).density_fit().to_rhf().__class__, scf.rhf.RHF(h2o).density_fit().__class__)
673+
self.assertEqual(dft.UKS(h2o).density_fit().to_uhf().__class__, scf.uhf.UHF(h2o).density_fit().__class__)
674+
self.assertEqual(dft.UKS(h2o).density_fit().to_ghf().__class__, scf.ghf.GHF(h2o).density_fit().__class__)
675+
self.assertEqual(dft.UKS(h2o).density_fit().to_hf() .__class__, scf.uhf.UHF(h2o).density_fit().__class__)
676+
self.assertEqual(dft.UKS(h2o).density_fit().to_rks().__class__, dft.rks.RKS(h2o).density_fit().__class__)
677+
self.assertEqual(dft.UKS(h2o).density_fit().to_uks().__class__, dft.uks.UKS(h2o).density_fit().__class__)
678+
self.assertEqual(dft.UKS(h2o).density_fit().to_gks().__class__, dft.gks.GKS(h2o).density_fit().__class__)
679+
680+
self.assertEqual(dft.GKS(h2o).density_fit().to_ghf().__class__, scf.ghf.GHF(h2o).density_fit().__class__)
681+
self.assertEqual(dft.GKS(h2o).density_fit().to_hf() .__class__, scf.ghf.GHF(h2o).density_fit().__class__)
682+
self.assertEqual(dft.GKS(h2o).density_fit().to_gks().__class__, dft.gks.GKS(h2o).density_fit().__class__)
683+
684+
def test_to_ks(self):
685+
self.assertEqual(scf.RHF(h2o).to_rhf().__class__, scf.rhf.RHF)
686+
self.assertEqual(scf.RHF(h2o).to_uhf().__class__, scf.uhf.UHF)
687+
self.assertEqual(scf.RHF(h2o).to_ghf().__class__, scf.ghf.GHF)
688+
self.assertEqual(scf.RHF(h2o).to_ks() .__class__, dft.rks.RKS)
689+
self.assertEqual(scf.RHF(h2o).to_rks().__class__, dft.rks.RKS)
690+
self.assertEqual(scf.RHF(h2o).to_uks().__class__, dft.uks.UKS)
691+
self.assertEqual(scf.RHF(h2o).to_gks().__class__, dft.gks.GKS)
692+
693+
self.assertEqual(scf.UHF(h2o).to_rhf().__class__, scf.rhf.RHF)
694+
self.assertEqual(scf.UHF(h2o).to_uhf().__class__, scf.uhf.UHF)
695+
self.assertEqual(scf.UHF(h2o).to_ghf().__class__, scf.ghf.GHF)
696+
self.assertEqual(scf.UHF(h2o).to_ks() .__class__, dft.uks.UKS)
697+
self.assertEqual(scf.UHF(h2o).to_rks().__class__, dft.rks.RKS)
698+
self.assertEqual(scf.UHF(h2o).to_uks().__class__, dft.uks.UKS)
699+
self.assertEqual(scf.UHF(h2o).to_gks().__class__, dft.gks.GKS)
700+
701+
self.assertEqual(scf.GHF(h2o).to_ghf().__class__, scf.ghf.GHF)
702+
self.assertEqual(scf.GHF(h2o).to_ks() .__class__, dft.gks.GKS)
703+
self.assertEqual(scf.GHF(h2o).to_gks().__class__, dft.gks.GKS)
704+
705+
self.assertEqual(scf.RHF(h2o).density_fit().to_rhf().__class__, scf.rhf.RHF(h2o).density_fit().__class__)
706+
self.assertEqual(scf.RHF(h2o).density_fit().to_uhf().__class__, scf.uhf.UHF(h2o).density_fit().__class__)
707+
self.assertEqual(scf.RHF(h2o).density_fit().to_ghf().__class__, scf.ghf.GHF(h2o).density_fit().__class__)
708+
self.assertEqual(scf.RHF(h2o).density_fit().to_ks() .__class__, dft.rks.RKS(h2o).density_fit().__class__)
709+
self.assertEqual(scf.RHF(h2o).density_fit().to_rks().__class__, dft.rks.RKS(h2o).density_fit().__class__)
710+
self.assertEqual(scf.RHF(h2o).density_fit().to_uks().__class__, dft.uks.UKS(h2o).density_fit().__class__)
711+
self.assertEqual(scf.RHF(h2o).density_fit().to_gks().__class__, dft.gks.GKS(h2o).density_fit().__class__)
712+
713+
self.assertEqual(scf.UHF(h2o).density_fit().to_rhf().__class__, scf.rhf.RHF(h2o).density_fit().__class__)
714+
self.assertEqual(scf.UHF(h2o).density_fit().to_uhf().__class__, scf.uhf.UHF(h2o).density_fit().__class__)
715+
self.assertEqual(scf.UHF(h2o).density_fit().to_ghf().__class__, scf.ghf.GHF(h2o).density_fit().__class__)
716+
self.assertEqual(scf.UHF(h2o).density_fit().to_ks() .__class__, dft.uks.UKS(h2o).density_fit().__class__)
717+
self.assertEqual(scf.UHF(h2o).density_fit().to_rks().__class__, dft.rks.RKS(h2o).density_fit().__class__)
718+
self.assertEqual(scf.UHF(h2o).density_fit().to_uks().__class__, dft.uks.UKS(h2o).density_fit().__class__)
719+
self.assertEqual(scf.UHF(h2o).density_fit().to_gks().__class__, dft.gks.GKS(h2o).density_fit().__class__)
720+
721+
self.assertEqual(scf.GHF(h2o).density_fit().to_ghf().__class__, scf.ghf.GHF(h2o).density_fit().__class__)
722+
self.assertEqual(scf.GHF(h2o).density_fit().to_ks() .__class__, dft.gks.GKS(h2o).density_fit().__class__)
723+
self.assertEqual(scf.GHF(h2o).density_fit().to_gks().__class__, dft.gks.GKS(h2o).density_fit().__class__)
724+
642725
if __name__ == "__main__":
643726
print("Full Tests for H2O")
644727
unittest.main()

pyscf/lib/misc.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -870,9 +870,7 @@ def fn(obj, *args, **kwargs):
870870
fn.__name__ = name
871871
return fn
872872

873-
@functools.lru_cache(None)
874-
def _define_class(name, bases):
875-
return type(name, bases, {})
873+
_registered_classes = {}
876874

877875
def make_class(bases, name=None, attrs=None):
878876
'''
@@ -883,14 +881,15 @@ def make_class(bases, name=None, attrs=None):
883881
class {name}(*bases):
884882
__dict__ = attrs
885883
'''
884+
global _registered_classes
886885
if name is None:
887886
name = ''.join(getattr(x, '__name_mixin__', x.__name__) for x in bases)
888887

889-
cls = _define_class(name, bases)
890-
cls.__name_mixin__ = name
891-
if attrs is not None:
892-
for key, val in attrs.items():
893-
setattr(cls, key, val)
888+
cls = _registered_classes.get((name, bases))
889+
if cls is None:
890+
if attrs is None:
891+
attrs = {}
892+
_registered_classes[name, bases] = cls = type(name, bases, attrs)
894893
return cls
895894

896895
def set_class(obj, bases, name=None, attrs=None):
@@ -929,7 +928,8 @@ def drop_class(cls, base_cls, name_mixin=None):
929928

930929
# rebuild the dynamic_mixin class
931930
attrs = {**cls.__dict__, '__name_mixin__': cls_name}
932-
cls_undressed = type(cls_name, tuple(filter_bases), attrs)
931+
cls_undressed = make_class(tuple(filter_bases), cls_name, attrs)
932+
cls_undressed.__module__ = cls.__module__
933933
return cls_undressed
934934

935935
def replace_class(cls, old_cls, new_cls):
@@ -951,7 +951,9 @@ def replace_class(cls, old_cls, new_cls):
951951

952952
name = cls.__name__.replace(old_cls.__name__, new_cls.__name__)
953953
attrs = {**cls.__dict__, '__name_mixin__': name}
954-
return type(name, tuple(bases), attrs)
954+
_cls = make_class(tuple(bases), name, attrs)
955+
_cls.__module__ = cls.__module__
956+
return _cls
955957

956958
def overwrite_mro(obj, mro):
957959
'''A hacky function to overwrite the __mro__ attribute'''

pyscf/scf/hf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,6 +2303,9 @@ def _transfer_attrs_(self, dst):
23032303
'''This helper function transfers attributes from one SCF object to
23042304
another SCF object. It is invoked by to_ks and to_hf methods.
23052305
'''
2306+
if hasattr(self, 'with_df') and not hasattr(dst, 'with_df'):
2307+
# Handle DF_SCF instances for to_xxx methods
2308+
dst = dst.density_fit(auxbasis=self.with_df.auxbasis)
23062309
# Search for all tracked attributes, including those in base classes
23072310
cls_keys = [getattr(cls, '_keys', ()) for cls in dst.__class__.__mro__[:-1]]
23082311
dst_keys = set(dst.__dict__).union(*cls_keys)

0 commit comments

Comments
 (0)