problem on prediction stage #113
hey @lordfiftysix Could you point me to some code where I can reproduce the error? I am assuming you have trained the model?
yes
Then if you could please point me to some code? Otherwise, maybe I can try later with some dataset I might have and "report back" the results here :)
I suppose I am fine with the second option |
here is some fully functioning code:

```python
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.datasets import load_womens_ecommerce
from pytorch_widedeep.models import BasicRNN, WideDeep
from pytorch_widedeep.preprocessing import TextPreprocessor

if __name__ == "__main__":
    df = load_womens_ecommerce(as_frame=True)

    # to be safe, but one can be more gentle here
    df = df.dropna().reset_index(drop=True)

    # just aesthetics
    df.columns = [c.lower().replace(" ", "_") for c in df.columns]

    # the reviews are a bit imbalanced, so we turn the problem into a binary
    # classification
    df["target"] = (df.rating >= 4).astype("int")

    text_col = "review_text"
    target = "target"

    # train/test split
    train, test = train_test_split(df, test_size=0.2, stratify=df.target)

    # processing
    text_processor = TextPreprocessor(text_col=text_col)
    X_train = text_processor.fit_transform(train)
    X_test = text_processor.transform(test)

    # model definition. The model component needs to be wrapped with the
    # WideDeep class
    basic_rnn = BasicRNN(
        vocab_size=len(text_processor.vocab.itos),
        embed_dim=100,
        hidden_dim=64,
        n_layers=3,
        bidirectional=True,
        rnn_dropout=0.5,
        padding_idx=1,
        head_hidden_dims=[100, 50],
    )

    model = WideDeep(deeptext=basic_rnn, pred_dim=1)

    # Train
    trainer = Trainer(model, objective="binary")
    trainer.fit(
        X_text=X_train,
        target=train[target].values,
        n_epochs=1,
        batch_size=256,
        val_split=0.2,
    )

    # predict
    preds = trainer.predict(X_text=X_test)
```
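A quick sanity check on the output of the script above might look like the following. This is just a sketch, assuming `predict` with the binary objective returns a 1-D array of 0/1 class labels; small mock arrays stand in for `preds` and `test[target].values` so the snippet is self-contained:

```python
import numpy as np

# mock stand-ins for `preds` and `test[target].values` from the script above
preds = np.array([1, 0, 1, 1, 0])
y_true = np.array([1, 0, 0, 1, 0])

# fraction of test rows where the predicted label matches the true label
accuracy = (preds == y_true).mean()
print(accuracy)  # 0.8
```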
It did not work. I am trying to do multi-output regression. Here is some more of my code, and I am still getting the same error at the prediction stage.
To do multi-output regression or multi-label classification we would need to modify the code. In fact I don't know what the … Anyway, if you can point me towards a notebook/colab with some small dataset or mock data, that would save me a lot of time. Otherwise I will try to mock some data myself and dig into this later.
Hey, I wonder if you were ever able to dig into this problem. I can confirm that I have 6 columns and a few thousand rows as my output, so RMSE probably won't work. That being said, I am struggling to do multi-output regression on these 6 target columns given a single input text column.
Hey, sorry @lordfiftysix, I am buried at work these days, so apologies for the late reply. No, I did not have the time 🙁. Maybe you could consider this as 6 independent problems and then combine the losses? Alternatively, you could code a custom loss yourself, although this might not be straightforward. Let me see if I get a sec towards the end of the week. Otherwise I will see if @5uperpalo can look into it. @5uperpalo let's have a chat and see if we can code a custom loss that takes multiple inputs and produces a single output.
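To make the "6 independent problems, combined losses" idea concrete, here is a minimal plain-PyTorch sketch. Note this is NOT part of pytorch_widedeep's API; the class name, the optional `weights` argument, and the decision to sum rather than average the per-column losses are all assumptions for illustration:

```python
import torch
import torch.nn as nn


class MultiTargetMSELoss(nn.Module):
    """Combine per-column MSE losses over several regression targets.

    Hypothetical sketch: treats each of the n_targets columns as an
    independent regression problem and sums the per-column MSEs.
    """

    def __init__(self, weights=None):
        super().__init__()
        self.weights = weights  # optional per-target weights (assumption)

    def forward(self, preds, targets):
        # preds and targets have shape (batch_size, n_targets)
        per_target = ((preds - targets) ** 2).mean(dim=0)  # one MSE per column
        if self.weights is not None:
            per_target = per_target * torch.as_tensor(
                self.weights, dtype=per_target.dtype
            )
        return per_target.sum()


# quick check with mock data: 8 rows, 6 target columns
loss_fn = MultiTargetMSELoss()
preds = torch.zeros(8, 6)
targets = torch.ones(8, 6)
print(loss_fn(preds, targets).item())  # 6.0: MSE of 1.0 per column, summed
```

Summing (rather than averaging) keeps each target's gradient at full scale; per-target weights let you down-weight columns with larger natural variance.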
I am attempting to run the following code:
where x_test consists of a single column with text descriptions
and I am getting the following output
```
predict: 75%|███████▌ | 3/4 [00:00<00:00, 12.85it/s]
/usr/local/lib/python3.7/dist-packages/pytorch_widedeep/training/trainer.py in
--> 581 return np.vstack(preds_l).squeeze(1)
ValueError: cannot select an axis to squeeze out which has size not equal to one
```
I am wondering how to go about resolving this. I have already tried expanding the dims of x_test and resizing it, but I am still getting the same issue.
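The ValueError above comes from NumPy itself: `squeeze(1)` only succeeds when axis 1 has size exactly 1, which holds for single-output predictions but not when each prediction row has 6 columns. A minimal repro, plus a shape-agnostic alternative (an illustration of the fix, not the library's actual code):

```python
import numpy as np

# single-output case: batches of (batch, 1) predictions, squeeze(1) works
preds_l = [np.zeros((4, 1)), np.zeros((2, 1))]
out = np.vstack(preds_l).squeeze(1)
print(out.shape)  # (6,)

# multi-output case: axis 1 has size 6, so squeeze(1) raises
preds_l = [np.zeros((4, 6)), np.zeros((2, 6))]
try:
    np.vstack(preds_l).squeeze(1)
except ValueError as e:
    print(e)  # cannot select an axis to squeeze out which has size not equal to one

# shape-agnostic alternative: only squeeze when axis 1 is actually size 1
stacked = np.vstack(preds_l)
out = stacked.squeeze(1) if stacked.shape[1] == 1 else stacked
print(out.shape)  # (6, 6)
```

This also explains why expanding or resizing x_test does not help: the shape mismatch is in the stacked predictions, not in the input.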