Bug about PromptEHR #19

Open
xansar opened this issue Apr 17, 2024 · 0 comments

xansar commented Apr 17, 2024

When I tried to load the pretrained PromptEHR model with from_pretrained, the following error occurred:

AttributeError                            Traceback (most recent call last)
Input In [9], in <cell line: 5>()
      3 vocs = data['voc']
      4 model = PromptEHR()
----> 5 model.from_pretrained()

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/pytrial/tasks/trial_simulation/sequence/promptehr.py:222, in PromptEHR.from_pretrained(self, input_dir)
    211 def from_pretrained(self, input_dir='./simulation/pretrained_promptEHR'):
    212     '''
    213     Load pretrained PromptEHR model and make patient EHRs generation.
    214     Pretrained model was learned from MIMIC-III patient sequence data.
   (...)
    220         to this folder.
    221     '''
--> 222     self.model.from_pretrained(input_dir=input_dir)
    223     self.config.update(self.model.config)

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/promptehr.py:359, in PromptEHR.from_pretrained(self, input_dir)
    356     print(f'Download pretrained PromptEHR model, save to {input_dir}.')
    358 print('Load pretrained PromptEHR model from', input_dir)
--> 359 self.load_model(input_dir)

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/promptehr.py:298, in PromptEHR.load_model(self, checkpoint)
    295 self._load_tokenizer(data_tokenizer_file, model_tokenizer_file)
    297 # load configuration
--> 298 self.configuration = EHRBartConfig(self.data_tokenizer, self.model_tokenizer, n_num_feature=self.config['n_num_feature'], cat_cardinalities=self.config['cat_cardinalities'])
    299 self.configuration.from_pretrained(checkpoint)
    301 # build model

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/modeling_config.py:24, in EHRBartConfig(data_tokenizer, model_tokenizer, **kwargs)
     22 bart_config = BartConfig.from_pretrained('facebook/bart-base')
     23 kwargs.update(model_tokenizer.get_num_tokens)
---> 24 kwargs['data_tokenizer_num_vocab'] = len(data_tokenizer)
     25 if 'd_prompt_hidden' not in kwargs:
     26     kwargs['d_prompt_hidden'] = 128

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/tokenization_utils.py:431, in PreTrainedTokenizer.__len__(self)
    426 def __len__(self):
    427     """
    428     Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise if
    429     there is a hole in the vocab, we will add tokenizers at a wrong index.
    430     """
--> 431     return len(set(self.get_vocab().keys()))

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/models/bart/tokenization_bart.py:243, in BartTokenizer.get_vocab(self)
    242 def get_vocab(self):
--> 243     return dict(self.encoder, **self.added_tokens_encoder)

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/tokenization_utils.py:391, in PreTrainedTokenizer.added_tokens_encoder(self)
    385 @property
    386 def added_tokens_encoder(self) -> Dict[str, int]:
    387     """
    388     Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
    389     optimisation in `self._added_tokens_encoder` for the slow tokenizers.
    390     """
--> 391     return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}

AttributeError: 'DataTokenizer' object has no attribute '_added_tokens_decoder'
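
The lookup that fails is inside transformers' added_tokens_encoder accessor, so I suspect the installed transformers release is newer than the one the pretrained DataTokenizer was saved with. A quick way to record the relevant package versions (the distribution names below are my assumption and may differ from the import names):

import importlib.metadata as metadata

# Print the installed versions of the packages that appear in the traceback.
for pkg in ("transformers", "promptehr", "pytrial"):
    try:
        print(pkg, metadata.version(pkg))
    except metadata.PackageNotFoundError:
        print(pkg, "distribution not found")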

——————————————————————————
My code is:

from pytrial.tasks.trial_simulation.data import SequencePatient
from pytrial.data.demo_data import load_synthetic_ehr_sequence
data = load_synthetic_ehr_sequence()

train_data = SequencePatient(
    data={
        'v': data['visit'],
        'y': data['y'],
        'x': data['feature'],
        },
    metadata={
        'visit': {'mode': 'dense'},
        'label': {'mode': 'tensor'},
        'voc': data['voc'],
        'max_visit': 20,
        'n_num_feature': data['n_num_feature'],
        'cat_cardinalities': data['cat_cardinalities'],
    }
)

from pytrial.tasks.trial_simulation.sequence import PromptEHR

vocs = data['voc']
model = PromptEHR()
model.from_pretrained()

For comparison, I can load BartTokenizer directly without any problem:

from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
tokenizer

BartTokenizer(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True),
}

tokenizer.added_tokens_decoder

{0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True)}

Could you please help me fix this bug?
