From 77a13f3f494659d8b6b86b488ea2ad28fd57bc6d Mon Sep 17 00:00:00 2001
From: Sanchit Baweja <99529694+sanchit45@users.noreply.github.com>
Date: Fri, 16 Jun 2023 01:07:54 +0530
Subject: [PATCH] Created using Colaboratory
---
Inference.ipynb | 1550 +++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 1550 insertions(+)
create mode 100644 Inference.ipynb
diff --git a/Inference.ipynb b/Inference.ipynb
new file mode 100644
index 0000000..fe389b7
--- /dev/null
+++ b/Inference.ipynb
@@ -0,0 +1,1550 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "mount_file_id": "1OcWUdMLPn4DUqxnHXzX-i9SNjEG6V0ed",
+ "authorship_tag": "ABX9TyOssri3CxykKoQxqNpBaRp/",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+        ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "d2oXw1Rcf3Iz",
+ "outputId": "680e32ad-b163-4174-ca83-56bbf6e94a00"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
+ "[nltk_data] Package stopwords is already up-to-date!\n",
+ "[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
+ "[nltk_data] Package wordnet is already up-to-date!\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "device(type='cpu')"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 3
+ }
+ ],
+ "source": [
+ "import nltk\n",
+ "nltk.download('stopwords')\n",
+ "nltk.download('wordnet')\n",
+ "import random\n",
+ "import re\n",
+ "import csv\n",
+ "from collections import Counter\n",
+ "from functools import partial\n",
+ "from pathlib import Path\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+        "from IPython.display import display, HTML\n",
+ "from sklearn.feature_extraction.text import TfidfVectorizer # TF-IDF\n",
+ "from sklearn.metrics import classification_report\n",
+ "from tqdm import tqdm, tqdm_notebook\n",
+ "\n",
+ "# PyTorch modules\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from torch import optim\n",
+ "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from torch.utils.data.dataset import random_split\n",
+        "# nltk text processors (nltk itself is imported above)\n",
+ "from nltk.corpus import stopwords\n",
+ "from nltk.tokenize import wordpunct_tokenize\n",
+ "from nltk.stem import WordNetLemmatizer\n",
+ "\n",
+ "%matplotlib inline\n",
+ "%config InlineBackend.figure_formats = ['svg']\n",
+ "plt.style.use('ggplot')\n",
+ "tqdm.pandas()\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "device"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "class FeedfowardTextClassifier(nn.Module):\n",
+        "    \"\"\"Two-hidden-layer feed-forward classifier over BoW / TF-IDF vectors.\"\"\"\n",
+        "\n",
+        "    def __init__(self, device, vocab_size, hidden1, hidden2, num_labels, batch_size):\n",
+        "        super(FeedfowardTextClassifier, self).__init__()\n",
+        "        self.device = device\n",
+        "        self.batch_size = batch_size\n",
+        "        self.fc1 = nn.Linear(vocab_size, hidden1)\n",
+        "        self.fc2 = nn.Linear(hidden1, hidden2)\n",
+        "        self.fc3 = nn.Linear(hidden2, num_labels)\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        # Remember the actual batch size (it can shrink on the last batch)\n",
+        "        batch_size = len(x)\n",
+        "        if batch_size != self.batch_size:\n",
+        "            self.batch_size = batch_size\n",
+        "        x = torch.FloatTensor(x)  # accepts plain Python lists\n",
+        "        x = F.relu(self.fc1(x))\n",
+        "        x = F.relu(self.fc2(x))\n",
+        "        return torch.sigmoid(self.fc3(x))"
+ ],
+ "metadata": {
+ "id": "KwuXIy3J79tF"
+ },
+ "execution_count": 79,
+ "outputs": []
+ },
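+  {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "A quick shape sanity check (an added sketch, not part of the original run): feed one all-zeros bag-of-words row of the width the saved checkpoints use (1001 features, per the printed models below) and confirm the classifier returns one probability pair per example."
+      ]
+  },
+  {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Added sketch: throwaway model instance to check the output shape.\n",
+        "# vocab_size=1001 is an assumption taken from fc1's in_features printed below.\n",
+        "sanity_model = FeedfowardTextClassifier(\n",
+        "    device=device, vocab_size=1001, hidden1=100, hidden2=50,\n",
+        "    num_labels=2, batch_size=1,\n",
+        ")\n",
+        "dummy_batch = [[0.0] * 1001]  # a single all-zeros example\n",
+        "print(sanity_model(dummy_batch).shape)  # expect torch.Size([1, 2])"
+      ]
+  },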
+ {
+ "cell_type": "code",
+ "source": [
+        "# Load the saved bag-of-words checkpoint: weights plus vocabulary mappings\n",
+        "model = torch.load(\"/content/drive/MyDrive/bow.pth\")\n",
+        "\n",
+        "token_idx_mapping = model[\"token2idx\"]\n",
+        "idx_token_mapping = model[\"index2token\"]\n",
+        "\n",
+        "bow_model = FeedfowardTextClassifier(\n",
+        "    vocab_size=len(token_idx_mapping),\n",
+        "    hidden1=100,\n",
+        "    hidden2=50,\n",
+        "    num_labels=2,\n",
+        "    device=device,\n",
+        "    batch_size=528,\n",
+        ")\n",
+        "\n",
+        "bow_model.load_state_dict(model[\"state_dict\"])\n",
+        "print(bow_model)\n",
+        "print(model[\"state_dict\"])"
+ ],
+ "metadata": {
+ "id": "RZCeCptaf_vF",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "5de58c37-2183-4c50-fd43-188710100790"
+ },
+ "execution_count": 80,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "FeedfowardTextClassifier(\n",
+ " (fc1): Linear(in_features=1001, out_features=100, bias=True)\n",
+ " (fc2): Linear(in_features=100, out_features=50, bias=True)\n",
+ " (fc3): Linear(in_features=50, out_features=2, bias=True)\n",
+ ")\n",
+ "OrderedDict([('fc1.weight', tensor([[ 0.0641, 0.0777, 0.0736, ..., -0.0120, 0.0778, 0.0525],\n",
+ " [ 0.0108, 0.0332, 0.0205, ..., 0.0196, 0.0811, 0.0330],\n",
+ " [ 0.0385, 0.0363, 0.0224, ..., -0.0252, 0.0674, 0.0468],\n",
+ " ...,\n",
+ " [-0.0169, 0.0522, -0.0166, ..., 0.0073, -0.1093, -0.0380],\n",
+ " [ 0.0412, 0.0564, 0.0108, ..., 0.0010, 0.0973, 0.0652],\n",
+ " [ 0.0522, 0.0629, 0.0614, ..., 0.0130, 0.0807, 0.0153]])), ('fc1.bias', tensor([ 0.1516, -0.0235, -0.0126, 0.1621, 0.1890, -0.0100, 0.0177, -0.0038,\n",
+ " 0.1844, 0.0191, 0.1438, -0.0268, 0.1655, 0.1424, -0.0088, 0.1266,\n",
+ " -0.0149, 0.1494, -0.0268, -0.0415, 0.1577, 0.1955, -0.0312, 0.1441,\n",
+ " 0.0134, 0.1635, -0.0129, -0.0006, -0.0392, 0.1374, 0.0815, -0.0284,\n",
+ " -0.0372, 0.1256, -0.0046, 0.1263, -0.0021, -0.0031, 0.1666, 0.1318,\n",
+ " 0.1748, -0.0235, -0.0113, 0.1577, 0.1764, -0.0418, 0.1985, 0.0011,\n",
+ " 0.1286, 0.0003, 0.0092, -0.0190, 0.1538, -0.0131, -0.0257, 0.0412,\n",
+ " 0.1660, 0.1364, -0.0123, -0.0176, 0.0010, 0.0683, -0.0101, 0.1356,\n",
+ " 0.1348, -0.0073, -0.0297, 0.0048, 0.1192, -0.0377, -0.0225, 0.0698,\n",
+ " 0.0054, 0.0387, -0.0028, -0.0203, 0.1546, -0.0049, -0.0368, 0.0029,\n",
+ " 0.0535, -0.0246, 0.0227, -0.0214, 0.1341, 0.0152, -0.0214, 0.1828,\n",
+ " 0.0167, 0.0261, 0.0019, 0.0129, -0.0393, -0.0222, 0.1220, -0.0156,\n",
+ " -0.0046, 0.1622, -0.0135, 0.1019])), ('fc2.weight', tensor([[-0.0009, 0.0691, -0.0061, ..., -0.0125, -0.0497, 0.0171],\n",
+ " [ 0.0574, 0.0357, 0.1850, ..., -0.1763, 0.0296, 0.0415],\n",
+ " [ 0.0888, -0.1652, -0.1152, ..., 0.0689, -0.1314, 0.0317],\n",
+ " ...,\n",
+ " [ 0.0522, 0.0959, 0.1603, ..., -0.1666, 0.0266, 0.0918],\n",
+ " [ 0.1451, 0.0095, 0.1040, ..., -0.0018, 0.0703, 0.1030],\n",
+ " [ 0.0909, -0.1567, -0.1130, ..., 0.2051, -0.0170, 0.0962]])), ('fc2.bias', tensor([-0.0952, 0.0583, 0.0702, 0.0970, -0.0771, -0.0140, 0.0481, -0.0786,\n",
+ " 0.0871, -0.0383, -0.0989, 0.1007, 0.1372, 0.0631, -0.0539, 0.0111,\n",
+ " 0.0458, -0.0723, 0.0471, 0.0109, 0.0978, 0.0913, 0.1427, 0.0962,\n",
+ " 0.0551, 0.0504, -0.0860, 0.0875, 0.1323, 0.0986, -0.0560, -0.0186,\n",
+ " 0.0868, -0.0094, -0.0374, 0.0252, 0.0578, 0.0504, -0.1001, 0.0689,\n",
+ " 0.0503, -0.0639, 0.0086, 0.0872, 0.0899, 0.0606, 0.1024, -0.0206,\n",
+ " -0.0636, 0.0927])), ('fc3.weight', tensor([[ 0.1036, -0.1349, 0.1535, 0.2534, 0.1018, -0.2792, -0.1729, -0.2809,\n",
+ " -0.0762, -0.1080, -0.0304, 0.2792, 0.3222, -0.1937, 0.0067, -0.2276,\n",
+ " -0.2002, 0.1160, -0.1804, -0.1340, 0.3016, -0.0546, 0.1546, 0.2769,\n",
+ " 0.2378, -0.2121, 0.0523, -0.1221, 0.0939, 0.2682, -0.2137, -0.0166,\n",
+ " -0.2132, -0.1239, 0.0444, 0.0005, -0.0818, -0.0844, 0.0393, -0.0764,\n",
+ " -0.1519, -0.2586, -0.2240, -0.0493, -0.0410, -0.0172, 0.2522, -0.2151,\n",
+ " -0.0805, 0.2371],\n",
+ " [-0.0469, 0.1692, -0.3003, -0.1693, 0.0580, 0.0204, 0.0977, 0.0448,\n",
+ " 0.1486, 0.1134, 0.0028, -0.2857, -0.1327, 0.1828, -0.0789, 0.2077,\n",
+ " 0.1039, 0.0295, 0.1280, -0.1332, -0.2520, 0.1778, -0.2842, -0.2033,\n",
+ " -0.2871, 0.2008, -0.1146, 0.2140, -0.3136, -0.1811, 0.1098, -0.0269,\n",
+ " 0.0226, 0.0795, 0.0456, 0.2498, 0.1454, 0.1424, -0.1027, 0.1556,\n",
+ " 0.0394, 0.0533, 0.2207, 0.2238, 0.1800, 0.2094, -0.1223, 0.0894,\n",
+ " 0.1945, -0.1664]])), ('fc3.bias', tensor([ 0.1394, -0.1215]))])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "tfidf_model = FeedfowardTextClassifier(\n",
+ " vocab_size=len(token_idx_mapping),\n",
+ " hidden1=100,\n",
+ " hidden2=50,\n",
+ " num_labels=2,\n",
+ " device=device,\n",
+ " batch_size=528,\n",
+ ")\n",
+ "tfidf_model"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "KONGwXnt-qV5",
+ "outputId": "73e3f007-7596-4703-85fa-4589709282a2"
+ },
+ "execution_count": 81,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "FeedfowardTextClassifier(\n",
+ " (fc1): Linear(in_features=1001, out_features=100, bias=True)\n",
+ " (fc2): Linear(in_features=100, out_features=50, bias=True)\n",
+ " (fc3): Linear(in_features=50, out_features=2, bias=True)\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 81
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "tfidf = torch.load(\"/content/drive/MyDrive/tfidf.pth\")\n",
+        "tfidf_model.load_state_dict(tfidf[\"state_dict\"])\n",
+ "print(tfidf_model)"
+ ],
+ "metadata": {
+ "id": "X6qhHj6ggNFH",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "00c52409-4a3e-4615-a506-6938cc68d907"
+ },
+ "execution_count": 82,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "FeedfowardTextClassifier(\n",
+ " (fc1): Linear(in_features=1001, out_features=100, bias=True)\n",
+ " (fc2): Linear(in_features=100, out_features=50, bias=True)\n",
+ " (fc3): Linear(in_features=50, out_features=2, bias=True)\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
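+  {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "A small added consistency check (a sketch): both checkpoints were built against the same vocabulary, so the loaded first layer should expect exactly as many input features as `token2idx` has tokens."
+      ]
+  },
+  {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Added sketch: the loaded layer width should match the vocabulary size,\n",
+        "# otherwise load_state_dict above would have failed.\n",
+        "assert tfidf_model.fc1.in_features == len(token_idx_mapping)\n",
+        "print(\"fc1 expects\", tfidf_model.fc1.in_features, \"features;\",\n",
+        "      \"token2idx has\", len(token_idx_mapping), \"tokens\")"
+      ]
+  },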
+ {
+ "cell_type": "code",
+ "source": [
+        "def build_vocab(corpus):\n",
+        "    # Assign each previously unseen token the next free integer index\n",
+        "    vocab = {}\n",
+        "    for token in corpus:\n",
+        "        if token not in vocab:\n",
+        "            vocab[token] = len(vocab)\n",
+        "    return vocab\n",
+        "\n",
+        "def build_index2token(vocab):\n",
+        "    # Invert the token -> index mapping\n",
+        "    return {index: token for token, index in vocab.items()}\n",
+        "\n",
+        "def tokenize(text, stop_words, lemmatizer):\n",
+        "    text = re.sub(r'[^\\w\\s]', '', text)  # remove special characters\n",
+        "    text = text.lower()  # lowercase\n",
+        "    tokens = wordpunct_tokenize(text)  # tokenize\n",
+        "    tokens = [lemmatizer.lemmatize(token) for token in tokens]  # noun lemmatization\n",
+        "    tokens = [lemmatizer.lemmatize(token, \"v\") for token in tokens]  # verb lemmatization\n",
+        "    tokens = [token for token in tokens if token not in stop_words]  # remove stopwords\n",
+        "    return tokens\n",
+        "\n",
+        "def build_bow_vector(sequence, idx2token):\n",
+        "    vector = [0] * len(idx2token)\n",
+        "    for token_idx in sequence:\n",
+        "        if token_idx not in idx2token:\n",
+        "            raise ValueError('Wrong sequence index found!')\n",
+        "        vector[token_idx] += 1\n",
+        "    return vector\n",
+        "\n",
+        "def preprocess(data, token_idx_mapping, idx_token_mapping, feature):\n",
+        "    stop_words = set(stopwords.words('english'))\n",
+        "    lemmatizer = WordNetLemmatizer()\n",
+        "    tokens = tokenize(data, stop_words, lemmatizer)\n",
+        "    # Keep one index per unique token known to the checkpoint's vocabulary\n",
+        "    vocab = build_vocab(tokens)\n",
+        "    doc = list(vocab.keys())\n",
+        "    sequence = [token_idx_mapping[token] for token in doc if token in token_idx_mapping]\n",
+        "    bow_vector = build_bow_vector(sequence, idx_token_mapping)\n",
+        "\n",
+        "    # Known issue: this vectorizer is fit on the input alone, so its\n",
+        "    # vocabulary (and vector width) cannot match the trained model's\n",
+        "    # 1001 features; the \"tfidf\" branch therefore only prints the\n",
+        "    # vectors and implicitly returns None for now.\n",
+        "    vectorizer = TfidfVectorizer()\n",
+        "    tfidf_vectors = vectorizer.fit_transform(tokens).toarray()\n",
+        "    tfidf_vector = tfidf_vectors.tolist()\n",
+        "\n",
+        "    if feature == \"bow\":\n",
+        "        return bow_vector\n",
+        "    elif feature == \"tfidf\":\n",
+        "        print(tfidf_vector)"
+ ],
+ "metadata": {
+ "id": "yLdfzBhqg33w"
+ },
+ "execution_count": 83,
+ "outputs": []
+ },
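+  {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "To make the encoding concrete, here is an added walk-through (a sketch using the helpers above): tokenize a sentence, keep the tokens present in the checkpoint's `token2idx` mapping, and count them into a vector of the saved vocabulary's width."
+      ]
+  },
+  {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Added sketch: trace one sentence through the same steps preprocess()\n",
+        "# applies for feature=\"bow\".\n",
+        "example_text = \"People share knowledge and ask questions\"\n",
+        "stop_words = set(stopwords.words('english'))\n",
+        "lemmatizer = WordNetLemmatizer()\n",
+        "example_tokens = tokenize(example_text, stop_words, lemmatizer)\n",
+        "example_sequence = [token_idx_mapping[t] for t in example_tokens if t in token_idx_mapping]\n",
+        "example_vector = build_bow_vector(example_sequence, idx_token_mapping)\n",
+        "print(example_tokens)\n",
+        "print(example_sequence)\n",
+        "print(len(example_vector), \"dims,\", sum(example_vector), \"non-zero counts\")"
+      ]
+  },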
+ {
+ "cell_type": "code",
+ "source": [
+        "data = \"Quora is a place to gain and share knowledge. It's a platform to ask questions and connect with people who contribute unique insights and quality answers\""
+ ],
+ "metadata": {
+ "id": "bXs4jeEX-ABZ"
+ },
+ "execution_count": 84,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "preprocess(data, token_idx_mapping, idx_token_mapping, feature=\"bow\")"
+ ],
+ "metadata": {
+ "id": "Pfpty-TyEOoz",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "e22b0755-e7fc-496d-aea7-8ad7126b720b"
+ },
+ "execution_count": 85,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 1,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 1,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 1,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 1,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 1,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 1,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " 0,\n",
+ " ...]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 85
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+        "# Use the BoW model (the cell just below); the TF-IDF path still has input issues"
+ ],
+ "metadata": {
+ "id": "dAXON2IvCuEK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "# Score one BoW-encoded example with the loaded classifier\n",
+        "bow_model.eval()\n",
+        "bow_model.to(device)\n",
+        "y_pred = []\n",
+        "\n",
+        "with torch.no_grad():\n",
+        "    inputs = preprocess(data, token_idx_mapping, idx_token_mapping, feature=\"bow\")\n",
+        "    probs = bow_model([inputs])  # wrap in a list: the model expects a batch\n",
+        "    probs = probs.detach().cpu().numpy()\n",
+        "    predictions = np.argmax(probs, axis=1)\n",
+        "    y_pred.extend(predictions)\n",
+        "\n",
+        "print(y_pred)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "esm5puNChJRd",
+ "outputId": "00a9e473-4c48-4a4c-8fbd-d5f0605bb544"
+ },
+ "execution_count": 86,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "[0]\n"
+ ]
+ }
+ ]
+ },
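+  {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "The steps above can be wrapped into a small helper (an added sketch). The notebook never states what classes 0 and 1 mean, so the `LABELS` names here are hypothetical placeholders."
+      ]
+  },
+  {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Added sketch: reusable BoW prediction helper. LABELS is hypothetical;\n",
+        "# substitute the real class names from training.\n",
+        "LABELS = {0: \"class 0\", 1: \"class 1\"}\n",
+        "\n",
+        "def predict_bow(text):\n",
+        "    vec = preprocess(text, token_idx_mapping, idx_token_mapping, feature=\"bow\")\n",
+        "    with torch.no_grad():\n",
+        "        probs = bow_model([vec])\n",
+        "    return LABELS[int(np.argmax(probs.cpu().numpy(), axis=1)[0])]\n",
+        "\n",
+        "print(predict_bow(data))"
+      ]
+  },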
+ {
+ "cell_type": "code",
+ "source": [
+        "# Work in progress: preprocess() has no return for feature=\"tfidf\",\n",
+        "# so this cell currently fails with the TypeError captured below\n",
+        "tfidf_model.eval()\n",
+        "tfidf_model.to(device)\n",
+        "y_pred = []\n",
+        "\n",
+        "with torch.no_grad():\n",
+        "    inputs = preprocess(data, token_idx_mapping, idx_token_mapping, feature=\"tfidf\")\n",
+        "    probs = tfidf_model([inputs])\n",
+        "    probs = probs.detach().cpu().numpy()\n",
+        "    predictions = np.argmax(probs, axis=1)\n",
+        "    y_pred.extend(predictions)\n",
+        "\n",
+        "print(y_pred)"
+ ],
+ "metadata": {
+ "id": "byLj-3hbiO-j",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "outputId": "ebf2a6a8-5298-4d0a-d3d9-081f74b2da3f"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " (0, 16)\t0.39028945195301484\n",
+ " (0, 8)\t0.455453973865905\n",
+ " (0, 15)\t0.657873467373268\n",
+ " (0, 11)\t0.455453973865905\n",
+ " (1, 7)\t0.20259317165969928\n",
+ " (1, 14)\t0.40518634331939857\n",
+ " (1, 17)\t0.20259317165969928\n",
+ " (1, 4)\t0.20259317165969928\n",
+ " (1, 0)\t0.40518634331939857\n",
+ " (1, 12)\t0.40518634331939857\n",
+ " (1, 5)\t0.17104307638950533\n",
+ " (1, 9)\t0.17104307638950533\n",
+ " (1, 13)\t0.20259317165969928\n",
+ " (1, 1)\t0.34208615277901067\n",
+ " (1, 16)\t0.1465709212673861\n",
+ " (1, 8)\t0.34208615277901067\n",
+ " (1, 11)\t0.17104307638950533\n",
+ " (2, 7)\t0.21211528370190072\n",
+ " (2, 14)\t0.42423056740380144\n",
+ " (2, 17)\t0.21211528370190072\n",
+ " (2, 4)\t0.21211528370190072\n",
+ " (2, 0)\t0.42423056740380144\n",
+ " (2, 12)\t0.42423056740380144\n",
+ " (2, 5)\t0.17908229767263645\n",
+ " (2, 9)\t0.17908229767263645\n",
+ " (2, 13)\t0.21211528370190072\n",
+ " (2, 1)\t0.17908229767263645\n",
+ " (2, 16)\t0.15345992311775974\n",
+ " (2, 8)\t0.3581645953452729\n",
+ " (2, 11)\t0.17908229767263645\n",
+ " (3, 6)\t0.657873467373268\n",
+ " (3, 5)\t0.455453973865905\n",
+ " (3, 1)\t0.455453973865905\n",
+ " (3, 16)\t0.39028945195301484\n",
+ " (4, 10)\t0.6817217130641468\n",
+ " (4, 2)\t0.5590215568768956\n",
+ " (4, 9)\t0.4719644106114538\n",
+ " (5, 3)\t0.7732623667832087\n",
+ " (5, 2)\t0.6340862024337309\n"
+ ]
+ },
+ {
+ "output_type": "error",
+ "ename": "TypeError",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpreprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtoken_idx_mapping\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0midx_token_mapping\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfeature\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"tfidf\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprobs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtfidf_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1499\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1502\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mTypeError\u001b[0m: must be real number, not NoneType"
+ ]
+ }
+ ]
+  },
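+  {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "One possible way to unblock the TF-IDF path (an added sketch, under stated assumptions): the model expects vectors aligned to the saved 1001-token vocabulary, but a `TfidfVectorizer` fit on a single input builds a different, much smaller vocabulary. Since `tfidf.pth` does not appear to carry the training-time IDF weights, a stopgap is to L2-normalize the bag-of-words counts so the input at least has the right width; real TF-IDF scores would need the saved vectorizer."
+      ]
+  },
+  {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Added sketch, not real TF-IDF: give the model an input of the right\n",
+        "# width by L2-normalizing the BoW counts. Proper scores would need the\n",
+        "# IDF weights from training, which this notebook does not load.\n",
+        "bow_vec = preprocess(data, token_idx_mapping, idx_token_mapping, feature=\"bow\")\n",
+        "arr = np.asarray(bow_vec, dtype=np.float32)\n",
+        "norm = float(np.linalg.norm(arr))\n",
+        "tf_like = (arr / norm).tolist() if norm > 0 else arr.tolist()\n",
+        "\n",
+        "with torch.no_grad():\n",
+        "    probs = tfidf_model([tf_like])\n",
+        "print(np.argmax(probs.cpu().numpy(), axis=1))"
+      ]
+  }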
+ ]
+}
\ No newline at end of file