add Trinity model and debug the lem #256
Conversation
Kindly suggest changing the flag 'init2b' to 'only2b'.
dptb/nnops/loss.py (Outdated)
    ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
    ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
elif batch.max() >= 1:
    slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
    slices = [0] + slices
    ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
    ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
This should be changed to torch.tensor(i.shape[0]), otherwise it raises an error: torch.stack operates on tensors, not plain ints. The previous i.sum() was fine, but switching to the bare i.shape[0] will fail.
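For context, a minimal standalone illustration of the failure mode (variable names here are illustrative, not from the diff):

```python
import torch

x = torch.randn(10)
chunks = x.split([4, 6])  # one chunk per structure in a toy batch

# torch.stack needs a sequence of tensors; plain Python ints raise a TypeError:
# torch.stack([c.shape[0] for c in chunks])                        # TypeError
counts = torch.stack([torch.tensor(c.shape[0]) for c in chunks])   # tensor([4, 6])
sums = torch.stack([c.sum() for c in chunks])                      # fine: c.sum() is a 0-d tensor
```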
Done
…lements Initialize full_mask_to_diag tensor to track diagonal orbital pairs in the reduced matrix. This helps in identifying diagonal elements during further processing.
…alculation The shift_mu function was extracted to avoid code duplication across multiple loss classes. The onsite shift calculation was simplified by using a more accurate formula that accounts for both node and edge features.
for orbs, islice in self.orbpair_maps.items():
    fio, fjo = orbs.split('-')
    if fio == fjo:
        self.full_mask_to_diag[islice] = True
Here I added a new attribute: it collects the indices in the full-basis feature where the two orbitals of the pair are the same. These are exactly the entries that need the *0.5 factor in feature_to_block!
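A small self-contained sketch of the idea (the orbital-pair keys and slice lengths below are made up for illustration and do not reflect the actual dptb index layout):

```python
import torch

# Hypothetical mapping from orbital pairs to feature slices, mimicking orbpair_maps.
orbpair_maps = {
    "1s-1s": slice(0, 1),    # same orbital on both sides -> diagonal entries
    "1s-2p": slice(1, 4),
    "2p-2p": slice(4, 13),   # same orbital on both sides -> diagonal entries
}
n_features = 13

full_mask_to_diag = torch.zeros(n_features, dtype=torch.bool)
for orbs, islice in orbpair_maps.items():
    fio, fjo = orbs.split("-")
    if fio == fjo:
        # mark the feature indices whose two orbitals are identical
        full_mask_to_diag[islice] = True

# These are the entries that get the 0.5 factor in the feature-to-block conversion.
print(full_mask_to_diag.nonzero().flatten())
```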
norm_ss_e = (ref_data[AtomicDataDict.EDGE_OVERLAP_KEY] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]).sum(dim=-1)
norm_ss_e_diag = (ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][:,idp.full_mask_to_diag] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][:,idp.full_mask_to_diag]).sum(dim=-1)

return mu_n, mu_e, mu_e_diag, norm_ss_n, norm_ss_e, norm_ss_e_diag
I added this function; it computes each of the individual terms.
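A rough sketch of what such a helper might compute. Only the two norm_ss_e* lines appear verbatim in the diff; the mu_* numerators below are my assumption (overlap-weighted sums of the features), included purely to make the example self-contained:

```python
import torch

def onsite_shift_terms(node_H, node_S, edge_H, edge_S, diag_mask):
    # Assumed numerators: overlap-weighted sums over the last (feature) dimension.
    mu_n = (node_H * node_S).sum(dim=-1)
    mu_e = (edge_H * edge_S).sum(dim=-1)
    mu_e_diag = (edge_H[:, diag_mask] * edge_S[:, diag_mask]).sum(dim=-1)

    # These mirror the norm_ss_e / norm_ss_e_diag lines shown in the diff.
    norm_ss_n = (node_S * node_S).sum(dim=-1)
    norm_ss_e = (edge_S * edge_S).sum(dim=-1)
    norm_ss_e_diag = (edge_S[:, diag_mask] * edge_S[:, diag_mask]).sum(dim=-1)
    return mu_n, mu_e, mu_e_diag, norm_ss_n, norm_ss_e, norm_ss_e_diag
```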
mu = mu_n + 2 * mu_e - mu_e_diag
ss = norm_ss_n + 2 * norm_ss_e - norm_ss_e_diag
mu = mu / ss
Optimized the logic in this part!
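Restated as a formula, the shift computed above is

$$
\mu = \frac{\mu_n + 2\,\mu_e - \mu_{e,\mathrm{diag}}}{\lVert S\rVert^2_n + 2\,\lVert S\rVert^2_e - \lVert S\rVert^2_{e,\mathrm{diag}}},
$$

where the factor of 2 presumably counts each edge block together with its Hermitian partner, with the diagonal terms removing the double count (my reading of the code, not stated in the PR).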
This pull request introduces a new `ScalarMLPFunction` module for modular MLP functionality, integrates the `Trinity` embedding method, and refactors existing code to enhance modularity and maintainability. Key changes include the addition of new functionality for two-body predictions in the `e3tb` method, improved handling of onsite shifts in loss computation, and cleanup of duplicate code.

New Features and Modules:
- Added `ScalarMLPFunction` to `dptb/nn/base.py`, which implements a modular MLP with configurable dimensions, nonlinearity, initialization, and optional dropout and batch normalization. This replaces duplicate implementations in other files (see the sketch after this list). [1] [2] [3]
- Introduced the `Trinity` embedding method in `dptb/nn/embedding/__init__.py` and integrated it into the `e3tb` prediction workflow in `dptb/nn/deeptb.py`. This enables additional embedding options for specific use cases. [1] [2] [3]
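A minimal sketch of what such a configurable scalar MLP might look like; the constructor arguments below are illustrative assumptions, not the actual `ScalarMLPFunction` signature in `dptb/nn/base.py`:

```python
import torch
import torch.nn as nn

class ScalarMLP(nn.Module):
    """Illustrative stand-in for a configurable scalar MLP (hypothetical API)."""
    def __init__(self, in_dim, hidden_dims, out_dim,
                 nonlinearity=nn.SiLU, dropout=0.0, batch_norm=False):
        super().__init__()
        layers, prev = [], in_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            if batch_norm:
                layers.append(nn.BatchNorm1d(h))   # optional normalization
            layers.append(nonlinearity())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))  # optional dropout
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# usage: mlp = ScalarMLP(16, [64, 64], 8, dropout=0.1); y = mlp(torch.randn(4, 16))
```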
Enhancements to `e3tb` Predictions:
- Added two-body prediction functionality to `e3tb` by introducing `edge_prediction_h2` and `h2miltonian` modules in `dptb/nn/deeptb.py`. These handle additional edge attributes and Hamiltonian calculations. [1] [2] [3]

Refactoring and Code Cleanup:
- Removed duplicate `ScalarMLPFunction` implementations from `dptb/nn/embedding/lem.py` and `dptb/nn/embedding/slem.py`, replacing them with the centralized implementation in `dptb/nn/base.py`. This reduces redundancy and simplifies maintenance. [1] [2]
- Updated `lem.py` to use the new `ScalarMLPFunction` and fixed dimension mismatches in latent and output scalars. [1] [2]

Loss Computation Improvements:
- Improved onsite shift handling in `dptb/nnops/loss.py` to account for overlap weights and avoid overflow issues. This ensures more accurate loss adjustment for batch data (a simplified illustration follows).
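For intuition, a simplified sketch of applying a per-structure shift to batched node features, loosely following the diff excerpts above (the shift `mu` is assumed to be already computed per structure; the real code also handles edge features and the overlap-weighted normalization):

```python
import torch

batch = torch.tensor([0, 0, 1, 1, 1])   # structure index of each atom in a toy batch
node_H = torch.randn(5, 4)              # reference node (onsite) features
node_S = torch.randn(5, 4)              # node overlap features
mu = torch.tensor([0.3, -0.1])          # one onsite shift per structure (assumed given)

# Broadcast the per-structure shift onto its atoms and shift H by mu * S,
# mirroring ref_data[NODE_FEATURES] += mu * ref_data[NODE_OVERLAP] in the diff.
node_H_shifted = node_H + mu[batch].unsqueeze(-1) * node_S
```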