Merge pull request #381 from pyerie/main
Create anti-malware, wikipedia bot, WikiLLM and RAG boilerplate
Showing 5 changed files with 1,018 additions and 0 deletions.
@@ -0,0 +1,50 @@
from langchain_community.llms import Ollama
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain


def create_RAG_model(input_file, llm):
    # Create the LLM (Large Language Model) from the model name passed in
    llm = Ollama(model=str(llm))
    # Define the model used to embed the info
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    # Load the PDF
    loader = PyPDFLoader(input_file)
    doc = loader.load()
    # Split the text and embed it into the vector DB
    text_splitter = RecursiveCharacterTextSplitter()
    split = text_splitter.split_documents(doc)
    vector_store = FAISS.from_documents(split, embeddings)

    # Prompt generation: giving the LLM character and purpose
    prompt = ChatPromptTemplate.from_template(
        """
        Answer the following questions only based on the given context
        <context>
        {context}
        </context>
        Question: {input}
        """
    )
    # Linking the LLM, vector DB and the prompt
    docs_chain = create_stuff_documents_chain(llm, prompt)
    retriever = vector_store.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, docs_chain)
    return retrieval_chain


# Using the retrieval chain
# Example:

'''
chain = create_RAG_model("your_file_here.pdf", "mistral")
output = chain.invoke({"input":"What is the purpose of RAG?"})
print(output["answer"])
'''
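One possible extension of this boilerplate, sketched here only as an illustration: the function above re-embeds the PDF on every call, so a small wrapper could persist the FAISS index and reuse it across runs. This assumes the same langchain_community APIs imported above; the "pdf_index" folder name is purely illustrative.

import os
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

def load_or_build_index(pdf_path, index_dir="pdf_index"):
    # Reuse a previously saved index instead of re-embedding the PDF on every run
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    if os.path.exists(index_dir):
        return FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
    docs = RecursiveCharacterTextSplitter().split_documents(PyPDFLoader(pdf_path).load())
    store = FAISS.from_documents(docs, embeddings)
    store.save_local(index_dir)  # persist for the next run
    return store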
@@ -0,0 +1,69 @@
from langchain_community.llms import Ollama
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import wikipedia as wiki
import os


# NOTE: The following function is a RAG template written by me and wasn't copied from anywhere
def create_RAG_model(url, llm):
    # Create the LLM (Large Language Model)
    llm = Ollama(model=str(llm))
    # Define the model used to embed the info
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    # Load the webpage
    loader = WebBaseLoader(str(url))
    webpage = loader.load()
    # Split the text and embed it into the vector DB
    text_splitter = RecursiveCharacterTextSplitter()
    split = text_splitter.split_documents(webpage)
    if os.path.exists("wiki_index"):
        # Reuse the saved index and add the new documents to it
        vector_store = FAISS.load_local("wiki_index", embeddings=embeddings, allow_dangerous_deserialization=True)
        vector_store.add_documents(split)
    else:
        vector_store = FAISS.from_documents(split, embeddings)
    print("[+] Finished embedding!")
    vector_store.save_local("wiki_index")

    # Prompt generation: giving the LLM character and purpose
    prompt = ChatPromptTemplate.from_template(
        """
        Answer the following questions only based on the given context
        <context>
        {context}
        </context>
        Question: {input}
        """
    )
    # Linking the LLM, vector DB and the prompt
    docs_chain = create_stuff_documents_chain(llm, prompt)
    retriever = vector_store.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, docs_chain)
    return retrieval_chain


number = int(input("Do you want me to:\n 1) Learn from a single article \n 2) Learn from articles of a given topic\n :"))
if number == 2:
    topic = input("What topic do you want me to learn?: ")
    results = wiki.search(topic)
    for result in results:
        # Embed every related article; the chain returned by the last call sees the full index
        wiki_url = str("https://en.wikipedia.org/wiki/" + str(result)).replace(' ', '_')
        chain = create_RAG_model(wiki_url, "dolphin-phi")
elif number == 1:
    wiki_url = input("Give me the URL of the article: ")
    chain = create_RAG_model(wiki_url, "dolphin-phi")

print("Type 'exit' to exit")

while True:
    query = input("Ask me a question: ")
    if query == "exit":
        break
    else:
        output = chain.invoke({"input": query})
        print(output["answer"])
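A related sketch, under the assumption that the same langchain_community FAISS wrapper is used: instead of reloading and growing "wiki_index" inside create_RAG_model, per-article indexes could be built separately and combined with merge_from. The folder names here are illustrative only.

from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = OllamaEmbeddings(model="nomic-embed-text")
merged = None
for folder in ["wiki_index_article_1", "wiki_index_article_2"]:
    part = FAISS.load_local(folder, embeddings, allow_dangerous_deserialization=True)
    if merged is None:
        merged = part
    else:
        merged.merge_from(part)  # in-place merge of the two FAISS indexes
merged.save_local("wiki_index_merged")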
@@ -0,0 +1,19 @@
import wikipedia as wiki  # Import the library
from wikipedia.exceptions import DisambiguationError  # Needed for the workaround below

topic = input("Please enter the topic: ")
results = wiki.search(topic)  # Search for related articles
print("[+] Found", len(results), "entries!")
print("Select the article: ")
for index, value in enumerate(results):  # Give the user an opportunity to choose between the articles
    print(str(index) + ")" + " " + str(value))

print("\n")
article = int(input())
try:  # Try retrieving info from the Wiki page
    page = wiki.page(results[article])
    print(str(page.title).center(1000))
    print(page.url)
    print(wiki.summary(results[article], sentences=1))
except DisambiguationError as e:  # Workaround for the disambiguation error
    print("[-] An error occurred!")
    print("URL: " + "https://en.wikipedia.org/wiki/" + str(results[article]).replace(' ', '_'))
@@ -0,0 +1,241 @@
#################################################################################
### Author: Pyerie                                                              #
### Application: A not-so-accurate ML based anti-malware solution               #
#################################################################################

print("[+] Loading.... ")
import customtkinter
from tkinter.filedialog import *
from tkinter import *
import pefile
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn import metrics
import os


# Load the PE-header dataset; 'legitimate' is the label column
dataset = pd.read_csv('database3.csv')
X = dataset.drop(['legitimate'], axis=1).values
y = dataset['legitimate'].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = DecisionTreeClassifier()

y_test = y_test.reshape(-1, 1)
for i in range(0, 10):
    clf = clf.fit(X_train, y_train)
    res1 = clf.predict(X_test)
    accuracy = metrics.accuracy_score(y_test, res1)
    # Crude formatting: an accuracy of e.g. 0.87... becomes "87%"
    accuracy = str(accuracy)[2:4] + "%"
    print("Accuracy: " + accuracy)
customtkinter.set_appearance_mode("dark")
customtkinter.set_default_color_theme("dark-blue")

window = Tk()
screen_width = window.winfo_screenwidth()
screen_height = window.winfo_screenheight()
window.geometry(str(screen_width) + "x" + str(screen_height))
window.title("eSuraksha")
window['bg'] = "#121212"

def extract_features(file):
    # Collect the PE-header fields the classifier was trained on
    features = []
    try:
        pe_obj = pefile.PE(file, fast_load=True)
    except pefile.PEFormatError as error:
        # pe_obj stays unbound here; callers catch the resulting UnboundLocalError
        print("Not PE file!")

    features.append(pe_obj.OPTIONAL_HEADER.DATA_DIRECTORY[6].Size)
    features.append(pe_obj.OPTIONAL_HEADER.DATA_DIRECTORY[6].VirtualAddress)
    features.append(pe_obj.OPTIONAL_HEADER.MajorImageVersion)
    features.append(pe_obj.OPTIONAL_HEADER.MajorOperatingSystemVersion)
    features.append(pe_obj.OPTIONAL_HEADER.DATA_DIRECTORY[0].VirtualAddress)
    features.append(pe_obj.OPTIONAL_HEADER.DATA_DIRECTORY[0].Size)
    try:
        features.append(pe_obj.OPTIONAL_HEADER.DATA_DIRECTORY[12].VirtualAddress)
    except Exception:
        features.append(0)
    features.append(pe_obj.OPTIONAL_HEADER.DATA_DIRECTORY[2].Size)
    features.append(pe_obj.OPTIONAL_HEADER.MajorLinkerVersion)
    features.append(pe_obj.FILE_HEADER.NumberOfSections)
    features.append(pe_obj.OPTIONAL_HEADER.SizeOfStackReserve)
    features.append(pe_obj.OPTIONAL_HEADER.DllCharacteristics)
    features.append(pe_obj.OPTIONAL_HEADER.AddressOfEntryPoint)
    features.append(pe_obj.OPTIONAL_HEADER.ImageBase)

    return features

toplevel_created = False
toplevel2_created = False
def single_file():
    global toplevel_created
    global toplevel2_created
    global single_file_top

    # Close whichever result window is already open
    if toplevel_created:
        single_file_top.destroy()
        toplevel_created = False
    if toplevel2_created:
        many_files.destroy()
        toplevel2_created = False

    single_file_top = Toplevel(window)
    single_file_top.geometry("350x200")
    customtkinter.set_appearance_mode("dark")
    customtkinter.set_default_color_theme("dark-blue")
    single_file_top['bg'] = "#121212"
    single_file_top.title("Scan a single file")
    toplevel_created = True
    result = customtkinter.CTkLabel(single_file_top, text="Loading...")
    result.pack()

    file_path = askopenfilename()
    try:
        features_extracted = extract_features(str(file_path))
        not_pe = False
    except UnboundLocalError as e:
        # extract_features() printed "Not PE file!" and never bound pe_obj
        not_pe = True
        result.after(0, result.destroy)
        benign_l = customtkinter.CTkLabel(single_file_top, text="Not PE file!")
        benign_l.pack()

    if not_pe != True:
        data_of_sample = np.array(features_extracted)
        data_of_sample = data_of_sample.reshape(1, -1)

        prediction = clf.predict(data_of_sample)

        if prediction == 1:
            result.after(0, result.destroy)
            malware_l = customtkinter.CTkLabel(single_file_top, fg_color="red", text="ML model detected malware!")
            malware_l.pack()
        elif prediction == 0:
            result.after(0, result.destroy)
            benign_l = customtkinter.CTkLabel(single_file_top, fg_color="green", text="No malware detected!")
            benign_l.pack()
def scan_many():
    global toplevel2_created
    global toplevel_created
    global many_files

    # Close whichever result window is already open
    if toplevel2_created:
        many_files.destroy()
        toplevel2_created = False
    if toplevel_created:
        single_file_top.destroy()
        toplevel_created = False

    many_files = Toplevel(window)
    many_files.geometry("350x200")
    customtkinter.set_appearance_mode("dark")
    customtkinter.set_default_color_theme("dark-blue")
    many_files['bg'] = "#121212"
    many_files.title("Scan a directory")
    toplevel2_created = True
    result2 = customtkinter.CTkLabel(many_files, text="Loading...")
    result2.pack()
    malware_many = []
    directory = askdirectory()
    global extracted
    for root, directory, files in os.walk(str(directory)):
        for name_of_file in files:
            path = os.path.join(str(root), str(name_of_file))

            # Only try to extract features from files with a PE-style extension
            formats_of_pe = [".acm", ".ax", ".cpl", ".dll", ".drv", ".efi", ".exe", ".mui", ".ocx", ".scr", ".sys", ".tsp", ".bin"]
            for format_i in formats_of_pe:
                if name_of_file.endswith(format_i) == True:
                    extracted = 1
                    try:
                        features_of_many = extract_features(str(path))
                    except UnboundLocalError as e:
                        # Not a valid PE despite the extension: skip this file
                        extracted = 0
                    break
            else:
                extracted = 0

            if extracted == 1:
                data_for_many = np.array(features_of_many)
                data_for_many = data_for_many.reshape(1, -1)

                prediction_for_many = clf.predict(data_for_many)

                if prediction_for_many == 1:
                    malware_many.append(str(path))

    if len(malware_many) != 0:
        result2.after(0, result2.destroy)
        malware_label2 = customtkinter.CTkLabel(many_files, text="Malware found: ")
        malware_label2.pack()
        malware_text_box = customtkinter.CTkTextbox(many_files)
        for_text_box = ''

        for name_of_malware in malware_many:
            for_text_box += "".join([name_of_malware, '\n------------------------------------------'])

        malware_text_box.insert('0.0', for_text_box)
        malware_text_box.configure(state="disabled")
        malware_text_box.pack()

    elif len(malware_many) == 0:
        result2.after(0, result2.destroy)
        benign_label = customtkinter.CTkLabel(many_files, text="No malware found!")
        benign_label.pack()

button1 = customtkinter.CTkButton(master=window, command=single_file, text="Scan a single file")
button1.pack()
button2 = customtkinter.CTkButton(master=window, command=scan_many, text="Scan a folder")
button2.pack()

window.mainloop()
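An optional sketch, not part of the committed file: since the decision tree is retrained from database3.csv on every launch, the fitted model could be cached on disk with joblib. The model file name below is illustrative, and the X_train / y_train arrays are assumed to be the ones built above.

import os
import joblib
from sklearn.tree import DecisionTreeClassifier

MODEL_PATH = "esuraksha_model.joblib"  # illustrative file name

def load_or_train(X_train, y_train):
    # Reuse a saved model if one exists, otherwise fit and save a new one
    if os.path.exists(MODEL_PATH):
        return joblib.load(MODEL_PATH)
    model = DecisionTreeClassifier().fit(X_train, y_train)
    joblib.dump(model, MODEL_PATH)
    return model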