add Trinity model and debug the lem #256


Open: wants to merge 9 commits into main

Conversation

@floatingCatty (Member) commented on May 29, 2025

This pull request introduces a new ScalarMLPFunction module for modular MLP functionality, integrates the Trinity embedding method, and refactors existing code to enhance modularity and maintainability. Key changes include the addition of new functionality for two-body predictions in the e3tb method, improved handling of onsite shifts in loss computation, and cleanup of duplicate code.

New Features and Modules:

  • Added ScalarMLPFunction to dptb/nn/base.py, which implements a modular MLP with configurable dimensions, nonlinearity, initialization, and optional dropout and batch normalization. This replaces duplicate implementations in other files.

  • Introduced the Trinity embedding method in dptb/nn/embedding/__init__.py and integrated it into the e3tb prediction workflow in dptb/nn/deeptb.py. This enables additional embedding options for specific use cases.
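A minimal sketch of what a module like ScalarMLPFunction might look like (the class body, constructor arguments, and defaults here are assumptions for illustration, not the actual dptb/nn/base.py signature):

```python
import torch
import torch.nn as nn

class ScalarMLPFunction(nn.Module):
    """Sketch of a modular MLP: configurable hidden dimensions,
    nonlinearity, and optional dropout / batch normalization."""
    def __init__(self, in_features, hidden_features, out_features,
                 nonlinearity=nn.SiLU, dropout=0.0, batchnorm=False):
        super().__init__()
        dims = [in_features] + list(hidden_features) + [out_features]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:  # no activation after the output layer
                if batchnorm:
                    layers.append(nn.BatchNorm1d(dims[i + 1]))
                layers.append(nonlinearity())
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

mlp = ScalarMLPFunction(8, [32, 32], 4)
out = mlp(torch.randn(5, 8))
print(tuple(out.shape))  # (5, 4)
```

Centralizing a module like this is what lets the duplicate copies in lem.py and slem.py be deleted: every caller imports one implementation, and configuration differences stay in the constructor arguments.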

Enhancements to e3tb Predictions:

  • Added functionality for two-body predictions in e3tb by introducing edge_prediction_h2 and h2miltonian modules in dptb/nn/deeptb.py. These handle additional edge attributes and Hamiltonian calculations.

Refactoring and Code Cleanup:

  • Removed duplicate ScalarMLPFunction implementations from dptb/nn/embedding/lem.py and dptb/nn/embedding/slem.py, replacing them with the centralized implementation in dptb/nn/base.py. This reduces redundancy and simplifies maintenance.

  • Updated lem.py to use the new ScalarMLPFunction and fixed dimension mismatches in the latent and output scalars.

Loss Computation Improvements:

  • Enhanced the onsite shift calculation in dptb/nnops/loss.py to account for overlap weights and avoid overflow issues. This ensures more accurate loss adjustment for batch data.

@Franklalalala (Collaborator) commented:

Kindly suggesting the flag 'init2b' be renamed to 'only2b':
'only2b'=true: train the two-body part only and skip the heavy message passing.
'only2b'=false: freeze the two-body part and train the message passing.
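The proposed switch could be wired up roughly as below. This is a hedged sketch: the Toy class, the set_trainable helper, and the submodule names embedding / edge_prediction_h2 are illustrative stand-ins, not the actual dptb API.

```python
import torch.nn as nn

class Toy(nn.Module):
    """Stand-in model with a message-passing part and a two-body head."""
    def __init__(self):
        super().__init__()
        self.embedding = nn.Linear(4, 4)           # heavy message passing
        self.edge_prediction_h2 = nn.Linear(4, 4)  # two-body head

def set_trainable(model, only2b: bool):
    # only2b=True: train the two-body head, freeze message passing.
    # only2b=False: freeze the two-body head, train message passing.
    for p in model.embedding.parameters():
        p.requires_grad = not only2b
    for p in model.edge_prediction_h2.parameters():
        p.requires_grad = only2b

m = Toy()
set_trainable(m, only2b=True)
print(m.embedding.weight.requires_grad)  # False
```

Flipping requires_grad is the usual way to express this two-stage schedule; the optimizer then only updates the unfrozen stage.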

@QG-phy (Collaborator) commented on Jun 5, 2025

If we use the two-center (two-body) part first and then add the E3 part, do we still need the statistical analysis of the data?

@floatingCatty (Member, Author) commented:

> If we use the two-center (two-body) part first and then add the E3 part, do we still need the statistical analysis of the data?

No longer needed; I'll leave a comment about it.

    ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
    ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
elif batch.max() >= 1:
    slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i-1] for i in range(1, len(data["__slices__"]["pos"]))]
    slices = [0] + slices
    ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])      # before this PR
    ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])  # after this PR
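A simplified toy reproduction of this per-structure counting (the mask values and slice sizes are invented, and the leading-zero slice of the original is omitted). It also shows the pitfall discussed in the review: i.sum() yields a 0-d tensor that torch.stack accepts, while i.shape[0] is a plain int and must be wrapped in torch.tensor first.

```python
import torch

# Flat per-atom mask across a batch of two structures (3 and 4 atoms).
mask_to_ndiag = torch.tensor([1, 1, 0, 1, 0, 1, 1])
slices = [3, 4]  # atoms per structure, as derived from data["__slices__"]["pos"]

# i.sum() returns a 0-d tensor, so torch.stack works directly:
ndiag_batch = torch.stack([i.sum() for i in mask_to_ndiag.split(slices)])
print(ndiag_batch)  # tensor([2, 3])

# i.shape[0] is a plain Python int; it must be wrapped before stacking:
natom_batch = torch.stack([torch.tensor(i.shape[0]) for i in mask_to_ndiag.split(slices)])
print(natom_batch)  # tensor([3, 4])
```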
Collaborator: This has to be torch.tensor(i.shape[0]), otherwise it raises an error, because torch.stack operates on tensors, not on ints. The previous i.sum() was fine, but after changing it to i.shape[0] it breaks.

Member Author: Done

floatingCatty and others added 4 commits June 6, 2025 08:30
…lements

Initialize full_mask_to_diag tensor to track diagonal orbital pairs in the reduced matrix. This helps in identifying diagonal elements during further processing.
…alculation

The shift_mu function was extracted to avoid code duplication across multiple loss classes. The onsite shift calculation was simplified by using a more accurate formula that accounts for both node and edge features.
for orbs, islice in self.orbpair_maps.items():
    fio, fjo = orbs.split('-')
    if fio == fjo:
        self.full_mask_to_diag[islice] = True

Collaborator: I added a new attribute here. It collects the indices in the full-basis features whose two orbitals are identical, which is exactly the part that needs the *0.5 factor in feature-to-block!
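A self-contained toy version of that loop (the orbpair_maps contents are invented; the slice layout mimics an s/p basis where the 2p-2p block spans a 3x3 = 9-element feature slice):

```python
import torch

# Hypothetical orbital-pair -> feature-slice map for a tiny s/p basis.
orbpair_maps = {
    "2s-2s": slice(0, 1),    # 1 element,  same orbitals -> diagonal
    "2s-2p": slice(1, 4),    # 3 elements, mixed pair    -> off-diagonal
    "2p-2p": slice(4, 13),   # 9 elements, same orbitals -> diagonal
}

full_mask_to_diag = torch.zeros(13, dtype=torch.bool)
for orbs, islice in orbpair_maps.items():
    fio, fjo = orbs.split("-")
    if fio == fjo:
        full_mask_to_diag[islice] = True

print(full_mask_to_diag.sum().item())  # 10
```

The resulting boolean mask can then index feature tensors directly, as in the EDGE_OVERLAP_KEY snippet below, selecting only the same-orbital columns.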

norm_ss_e = (ref_data[AtomicDataDict.EDGE_OVERLAP_KEY] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]).sum(dim=-1)
norm_ss_e_diag = (ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][:,idp.full_mask_to_diag] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][:,idp.full_mask_to_diag]).sum(dim=-1)

return mu_n, mu_e, mu_e_diag, norm_ss_n, norm_ss_e, norm_ss_e_diag
Collaborator: I added this function; it computes each of these components.


mu = mu_n + 2 * mu_e - mu_e_diag
ss = norm_ss_n + 2 * norm_ss_e - norm_ss_e_diag
mu = mu / ss
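A worked numeric check of the combination above, assuming mu_* and norm_ss_* are the scalar components returned by the helper. The values are made up for illustration; the structure (node term once, edge terms doubled, diagonal edge part subtracted) follows the snippet itself.

```python
# Invented component values for illustration only.
mu_n, mu_e, mu_e_diag = 1.0, 2.0, 0.5
norm_ss_n, norm_ss_e, norm_ss_e_diag = 4.0, 3.0, 1.0

mu = mu_n + 2 * mu_e - mu_e_diag                  # 1 + 4 - 0.5 = 4.5
ss = norm_ss_n + 2 * norm_ss_e - norm_ss_e_diag   # 4 + 6 - 1   = 9.0
mu = mu / ss
print(mu)  # 0.5
```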
Collaborator: Optimized the logic of this part!
