From 77a13f3f494659d8b6b86b488ea2ad28fd57bc6d Mon Sep 17 00:00:00 2001 From: Sanchit Baweja <99529694+sanchit45@users.noreply.github.com> Date: Fri, 16 Jun 2023 01:07:54 +0530 Subject: [PATCH] Created using Colaboratory --- Inference.ipynb | 1550 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1550 insertions(+) create mode 100644 Inference.ipynb diff --git a/Inference.ipynb b/Inference.ipynb new file mode 100644 index 0000000..fe389b7 --- /dev/null +++ b/Inference.ipynb @@ -0,0 +1,1550 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "mount_file_id": "1OcWUdMLPn4DUqxnHXzX-i9SNjEG6V0ed", + "authorship_tag": "ABX9TyOssri3CxykKoQxqNpBaRp/", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d2oXw1Rcf3Iz", + "outputId": "680e32ad-b163-4174-ca83-56bbf6e94a00" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[nltk_data] Downloading package stopwords to /root/nltk_data...\n", + "[nltk_data] Package stopwords is already up-to-date!\n", + "[nltk_data] Downloading package wordnet to /root/nltk_data...\n", + "[nltk_data] Package wordnet is already up-to-date!\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "device(type='cpu')" + ] + }, + "metadata": {}, + "execution_count": 3 + } + ], + "source": [ + "import nltk\n", + "nltk.download('stopwords')\n", + "nltk.download('wordnet')\n", + "import random\n", + "import re\n", + "import csv\n", + "from collections import Counter\n", + "from functools import partial\n", + "from pathlib import Path\n", + "\n", + 
class FeedfowardTextClassifier(nn.Module):
    """Two-hidden-layer feed-forward classifier over dense text feature vectors.

    Takes a batch of BoW / TF-IDF vectors (list of lists or array-like) and
    returns per-class sigmoid activations.

    Args:
        device: torch.device on which input tensors are created.
        vocab_size: width of the input feature vector.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_labels: number of output classes.
        batch_size: expected batch size (tracked for reference only; forward
            re-syncs it to the actual batch, preserving the original behavior).
    """

    def __init__(self, device, vocab_size, hidden1, hidden2, num_labels, batch_size):
        super(FeedfowardTextClassifier, self).__init__()
        self.device = device
        self.batch_size = batch_size
        self.fc1 = nn.Linear(vocab_size, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_labels)

    def forward(self, x):
        """Return sigmoid class activations of shape (batch, num_labels)."""
        # Keep the tracked batch size in sync with the actual batch
        # (original behavior, retained for interface compatibility).
        batch_size = len(x)
        if batch_size != self.batch_size:
            self.batch_size = batch_size
        # FIX: the original used torch.FloatTensor(x), which ignores
        # self.device entirely — the stored device was dead state and CUDA
        # inference would fail with a device mismatch. Create the tensor on
        # the configured device instead.
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): per-element sigmoid (not softmax) over num_labels
        # outputs — rows do not sum to 1, but the downstream argmax is
        # unaffected.
        return torch.sigmoid(self.fc3(x))
# Load the saved Bag-of-Words checkpoint and rebuild the classifier.
# FIX: map_location keeps this working on CPU-only runtimes even when the
# checkpoint was saved from a GPU session (the original relied on the
# default, which fails without CUDA for GPU-saved tensors).
# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source.
model = torch.load("/content/drive/MyDrive/bow.pth", map_location=device)

token_idx_mapping = model["token2idx"]   # original assigned this twice
idx_token_mapping = model["index2token"]

bow_model = FeedfowardTextClassifier(
    vocab_size=len(token_idx_mapping),
    hidden1=100,
    hidden2=50,
    num_labels=2,
    device=device,
    batch_size=528,
)
bow_model.load_state_dict(model["state_dict"])

# Print the architecture summary only; dumping the full state_dict tensors
# (as the original did) floods the console with no diagnostic value.
print(bow_model)
0.0003, 0.0092, -0.0190, 0.1538, -0.0131, -0.0257, 0.0412,\n", + " 0.1660, 0.1364, -0.0123, -0.0176, 0.0010, 0.0683, -0.0101, 0.1356,\n", + " 0.1348, -0.0073, -0.0297, 0.0048, 0.1192, -0.0377, -0.0225, 0.0698,\n", + " 0.0054, 0.0387, -0.0028, -0.0203, 0.1546, -0.0049, -0.0368, 0.0029,\n", + " 0.0535, -0.0246, 0.0227, -0.0214, 0.1341, 0.0152, -0.0214, 0.1828,\n", + " 0.0167, 0.0261, 0.0019, 0.0129, -0.0393, -0.0222, 0.1220, -0.0156,\n", + " -0.0046, 0.1622, -0.0135, 0.1019])), ('fc2.weight', tensor([[-0.0009, 0.0691, -0.0061, ..., -0.0125, -0.0497, 0.0171],\n", + " [ 0.0574, 0.0357, 0.1850, ..., -0.1763, 0.0296, 0.0415],\n", + " [ 0.0888, -0.1652, -0.1152, ..., 0.0689, -0.1314, 0.0317],\n", + " ...,\n", + " [ 0.0522, 0.0959, 0.1603, ..., -0.1666, 0.0266, 0.0918],\n", + " [ 0.1451, 0.0095, 0.1040, ..., -0.0018, 0.0703, 0.1030],\n", + " [ 0.0909, -0.1567, -0.1130, ..., 0.2051, -0.0170, 0.0962]])), ('fc2.bias', tensor([-0.0952, 0.0583, 0.0702, 0.0970, -0.0771, -0.0140, 0.0481, -0.0786,\n", + " 0.0871, -0.0383, -0.0989, 0.1007, 0.1372, 0.0631, -0.0539, 0.0111,\n", + " 0.0458, -0.0723, 0.0471, 0.0109, 0.0978, 0.0913, 0.1427, 0.0962,\n", + " 0.0551, 0.0504, -0.0860, 0.0875, 0.1323, 0.0986, -0.0560, -0.0186,\n", + " 0.0868, -0.0094, -0.0374, 0.0252, 0.0578, 0.0504, -0.1001, 0.0689,\n", + " 0.0503, -0.0639, 0.0086, 0.0872, 0.0899, 0.0606, 0.1024, -0.0206,\n", + " -0.0636, 0.0927])), ('fc3.weight', tensor([[ 0.1036, -0.1349, 0.1535, 0.2534, 0.1018, -0.2792, -0.1729, -0.2809,\n", + " -0.0762, -0.1080, -0.0304, 0.2792, 0.3222, -0.1937, 0.0067, -0.2276,\n", + " -0.2002, 0.1160, -0.1804, -0.1340, 0.3016, -0.0546, 0.1546, 0.2769,\n", + " 0.2378, -0.2121, 0.0523, -0.1221, 0.0939, 0.2682, -0.2137, -0.0166,\n", + " -0.2132, -0.1239, 0.0444, 0.0005, -0.0818, -0.0844, 0.0393, -0.0764,\n", + " -0.1519, -0.2586, -0.2240, -0.0493, -0.0410, -0.0172, 0.2522, -0.2151,\n", + " -0.0805, 0.2371],\n", + " [-0.0469, 0.1692, -0.3003, -0.1693, 0.0580, 0.0204, 0.0977, 0.0448,\n", + " 0.1486, 
# Rebuild the same architecture for TF-IDF features and load its weights.
tfidf_model = FeedfowardTextClassifier(
    vocab_size=len(token_idx_mapping),
    hidden1=100,
    hidden2=50,
    num_labels=2,
    device=device,
    batch_size=528,
)

# FIX: map_location makes the load safe on CPU-only runtimes; see the BoW
# loading cell for the same rationale. torch.load unpickles arbitrary
# objects — trusted checkpoints only.
tfidf = torch.load("/content/drive/MyDrive/tfidf.pth", map_location=device)
tfidf_model.load_state_dict(tfidf["state_dict"])
print(tfidf_model)
def build_vocab(corpus):
    """Map each distinct token to a stable integer id, in first-seen order."""
    vocab = {}
    for token in corpus:
        if token not in vocab:
            vocab[token] = len(vocab)
    return vocab


def build_index2token(vocab):
    """Invert a token->index mapping into index->token."""
    return {idx: token for token, idx in vocab.items()}


def tokenize(text, stop_words, lemmatizer):
    """Lowercase, strip punctuation, word-tokenize, lemmatize, drop stopwords.

    Lemmatization runs twice on purpose: first with the default noun POS,
    then with the verb POS, matching the training-time pipeline.
    """
    text = re.sub(r'[^\w\s]', '', text)  # remove special characters
    tokens = wordpunct_tokenize(text.lower())
    tokens = [lemmatizer.lemmatize(token) for token in tokens]        # noun lemmatizer
    tokens = [lemmatizer.lemmatize(token, "v") for token in tokens]   # verb lemmatizer
    return [token for token in tokens if token not in stop_words]


def build_bow_vector(sequence, idx2token):
    """Count token-index occurrences into a dense vector of len(idx2token).

    Raises:
        ValueError: if an index in `sequence` is not a key of `idx2token`.
    """
    vector = [0] * len(idx2token)
    for token_idx in sequence:
        if token_idx not in idx2token:
            raise ValueError('Wrong sequence index found!')
        vector[token_idx] += 1
    return vector


def preprocess(data, token_idx_mapping, idx_token_mapping, feature):
    """Turn one raw text into a model-ready feature vector.

    Args:
        data: raw input string.
        token_idx_mapping: token->index mapping saved with the trained model.
        idx_token_mapping: index->token mapping saved with the trained model.
        feature: "bow" or "tfidf".

    Returns:
        "bow": a count vector of length len(idx_token_mapping).
        "tfidf": list of per-token TF-IDF rows (see NOTE below).

    Raises:
        ValueError: for an unknown `feature` (the original silently
        returned None here).
    """
    stop_words = set(stopwords.words('english'))
    lemmatizer = WordNetLemmatizer()
    tokens = tokenize(data, stop_words, lemmatizer)

    if feature == "bow":
        # NOTE(review): tokens are deduplicated via build_vocab before
        # counting, so entries are effectively 0/1 rather than true counts —
        # this mirrors the original; confirm it matches how the BoW model
        # was trained before "fixing" it.
        unique_tokens = list(build_vocab(tokens).keys())
        sequence = [token_idx_mapping[t] for t in unique_tokens
                    if t in token_idx_mapping]
        return build_bow_vector(sequence, idx_token_mapping)

    if feature == "tfidf":
        vectorizer = TfidfVectorizer()
        tfidf_vectors = vectorizer.fit_transform(tokens).toarray()
        # FIX: the original printed the vectors and fell through, returning
        # None — which crashed the model downstream with "TypeError: must be
        # real number, not NoneType" (see the notebook traceback). Return
        # the vectors instead.
        # NOTE(review): fitting a fresh TfidfVectorizer on a single input
        # yields vectors whose width is that input's own vocabulary, not the
        # model's 1001-dim training vocabulary; the training-time vectorizer
        # should be saved and reused for correct TF-IDF inference.
        return tfidf_vectors.tolist()

    raise ValueError(f"unknown feature type: {feature!r}")
" 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", 
+ " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 1,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 
0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 1,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + 
" 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 1,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 1,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 1,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", 
+ " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 
def _predict(model, text, feature):
    """Run one text through `model`; return the predicted class index list.

    Moves the model to the global `device`, switches it to eval mode, and
    disables gradient tracking for inference.
    """
    model.eval()
    model.to(device)
    with torch.no_grad():
        inputs = preprocess(text, token_idx_mapping, idx_token_mapping,
                            feature=feature)
        probs = model([inputs]).detach().cpu().numpy()
    return list(np.argmax(probs, axis=1))


# Bag-of-Words inference. (FIX: the original duplicated this whole block
# for each model; the shared logic now lives in _predict.)
y_pred = _predict(bow_model, data, "bow")
print(y_pred)

# TF-IDF inference.
# NOTE(review): the "tfidf" path in preprocess() fits a fresh vectorizer per
# input, so its vector width does not match the model's 1001-dim fc1 — the
# training-time vectorizer must be reused before these predictions are
# meaningful.
y_pred = _predict(tfidf_model, data, "tfidf")
print(y_pred)
1000 + }, + "outputId": "ebf2a6a8-5298-4d0a-d3d9-081f74b2da3f" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " (0, 16)\t0.39028945195301484\n", + " (0, 8)\t0.455453973865905\n", + " (0, 15)\t0.657873467373268\n", + " (0, 11)\t0.455453973865905\n", + " (1, 7)\t0.20259317165969928\n", + " (1, 14)\t0.40518634331939857\n", + " (1, 17)\t0.20259317165969928\n", + " (1, 4)\t0.20259317165969928\n", + " (1, 0)\t0.40518634331939857\n", + " (1, 12)\t0.40518634331939857\n", + " (1, 5)\t0.17104307638950533\n", + " (1, 9)\t0.17104307638950533\n", + " (1, 13)\t0.20259317165969928\n", + " (1, 1)\t0.34208615277901067\n", + " (1, 16)\t0.1465709212673861\n", + " (1, 8)\t0.34208615277901067\n", + " (1, 11)\t0.17104307638950533\n", + " (2, 7)\t0.21211528370190072\n", + " (2, 14)\t0.42423056740380144\n", + " (2, 17)\t0.21211528370190072\n", + " (2, 4)\t0.21211528370190072\n", + " (2, 0)\t0.42423056740380144\n", + " (2, 12)\t0.42423056740380144\n", + " (2, 5)\t0.17908229767263645\n", + " (2, 9)\t0.17908229767263645\n", + " (2, 13)\t0.21211528370190072\n", + " (2, 1)\t0.17908229767263645\n", + " (2, 16)\t0.15345992311775974\n", + " (2, 8)\t0.3581645953452729\n", + " (2, 11)\t0.17908229767263645\n", + " (3, 6)\t0.657873467373268\n", + " (3, 5)\t0.455453973865905\n", + " (3, 1)\t0.455453973865905\n", + " (3, 16)\t0.39028945195301484\n", + " (4, 10)\t0.6817217130641468\n", + " (4, 2)\t0.5590215568768956\n", + " (4, 9)\t0.4719644106114538\n", + " (5, 3)\t0.7732623667832087\n", + " (5, 2)\t0.6340862024337309\n" + ] + }, + { + "output_type": "error", + "ename": "TypeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m 
\u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpreprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtoken_idx_mapping\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0midx_token_mapping\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfeature\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"tfidf\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprobs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtfidf_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1499\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1502\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: must be real number, not NoneType" + ] + } + ] + } + ] +} \ No newline at end of file