@@ -672,7 +672,7 @@ def get_slices3(atom_symbols,edge_indices,to_jimages):
672
672
if strategy == 3 :
673
673
return get_slices3 (atom_symbols ,edge_indices ,to_jimages )
674
674
675
- def structure2SLICESAug (self ,structure ,strategy = 3 ,num = 40 ):
675
+ def structure2SLICESAug (self ,structure ,strategy = 3 ,num = 200 ):
676
676
"""
677
677
(1) extract edge_indices, to_jimages and atom_types from a pymatgen structure object
678
678
(2) encoding edge_indices, to_jimages and atom_types into multiple equalivent SLICES strings
@@ -743,7 +743,7 @@ def get_slices3(atom_symbols,edge_indices,to_jimages):
743
743
if strategy == 3 :
744
744
SLICES_list .append (get_slices3 (atom_symbols ,edge_indices ,to_jimages ))
745
745
#calcualte how many element and edge permuatations needed. round((n/6)**(1/2))
746
- num_permutation = round (( num / 6 )** (1 / 3 ))
746
+ num_permutation = int ( math . ceil (( num / 6 )** (1 / 3 ) ))
747
747
# shuffle to get permu
748
748
permu = []
749
749
for i in range (num ):
@@ -817,7 +817,8 @@ def remove_duplicate_arrays(arrays):
817
817
SLICES_list .append (get_slices2 (atom_symbols_new ,edge_indices_new_final_flip ,to_jimages_shu_trans_per_trans_final_flip ))
818
818
if strategy == 3 :
819
819
SLICES_list .append (get_slices3 (atom_symbols_new ,edge_indices_new_final_flip ,to_jimages_shu_trans_per_trans_final_flip ))
820
- return SLICES_list
820
+ random .shuffle (SLICES_list )
821
+ return SLICES_list [:num ]
821
822
822
823
def get_dim (self ,structure ):
823
824
"""
0 commit comments