
Problem loading a finetuned model. #135

Closed
rodgzilla opened this issue Dec 20, 2018 · 8 comments


@rodgzilla
Contributor

Hi!

There is a problem with the way models are saved and loaded. The following code should crash but doesn't:

import torch
from pytorch_pretrained_bert import BertForSequenceClassification

model_fn = 'model.bin'
bert_model = 'bert-base-multilingual-cased'
model = BertForSequenceClassification.from_pretrained(bert_model, num_labels = 16)

model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), model_fn)
print(model_to_save.num_labels)

model_state_dict = torch.load(model_fn)
loaded_model = BertForSequenceClassification.from_pretrained(bert_model, state_dict = model_state_dict)
print(loaded_model.num_labels)

This code prints:

16
2

The code should raise an exception when trying to load the weights of the task-specific linear layer. I'm guessing that the problem comes from PreTrainedBertModel.from_pretrained.

I would be happy to submit a PR fixing this problem, but I'm not used to working with the PyTorch loading mechanisms. @thomwolf, could you give me some guidance?

Cheers!

@rodgzilla
Contributor Author

rodgzilla commented Dec 20, 2018

OK, I managed to find the problem. It comes from:

https://github.com/huggingface/pytorch-pretrained-BERT/blob/7fb94ab934b2ad1041613fc93c61d13105faf98a/pytorch_pretrained_bert/modeling.py#L534-L540

When trying to load classifier.weight and classifier.bias, the following lines get added to error_msgs:

size mismatch for classifier.weight: copying a param with shape torch.Size([16, 768]) from checkpoint, the shape in current model is torch.Size([2, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([2]).

First, I think that we should add a check on error_msgs in from_pretrained. I don't really see any other option than printing an error message and exiting the program, since the default behavior (keeping the classifier layer randomly initialized) can be frustrating for the user (I speak from experience ^^).
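For example, a rough sketch of such a check (assuming we simply raise from whatever ended up in error_msgs):

# Sketch only: after load(model, prefix=...) has filled error_msgs,
# fail loudly instead of returning a partially initialized model.
if len(error_msgs) > 0:
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
        model.__class__.__name__, '\n\t'.join(error_msgs)))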

To fix this, we should probably fetch the number of labels of the saved model and use it to instantiate the model being created before loading the saved weights. Unfortunately I don't really know how to do that cleanly. Any ideas?
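One naive option (just a sketch, assuming the classification head is always named classifier) would be to read the label count off the checkpoint before building the model:

model_state_dict = torch.load(model_fn)
num_labels = model_state_dict['classifier.weight'].size(0)
loaded_model = BertForSequenceClassification.from_pretrained(
    bert_model, state_dict=model_state_dict, num_labels=num_labels)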

Another possible "fix" would be to force the user to provide a num_labels argument when loading a pre-trained classification model, with the following code in BertForSequenceClassification:

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        if 'num_labels' not in kwargs:
            raise ValueError('num_labels should be given when loading a pre-trained classification model')
        return super().from_pretrained(*args, **kwargs)

And even with this code, we are not able to check that the num_labels value matches the saved model. I don't really like the idea of forcing the user to provide information that the checkpoint already contains.
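At best we could add a sanity check against the checkpoint when a state_dict is passed explicitly, something along these lines (sketch only):

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        if 'num_labels' not in kwargs:
            raise ValueError('num_labels should be given when loading a pre-trained classification model')
        state_dict = kwargs.get('state_dict')
        # Sketch: if a checkpoint is provided, make sure its classifier head agrees with num_labels.
        if state_dict is not None and 'classifier.weight' in state_dict:
            if state_dict['classifier.weight'].size(0) != kwargs['num_labels']:
                raise ValueError('num_labels does not match the classifier shape in the checkpoint')
        return super().from_pretrained(*args, **kwargs)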

@HamidMoghaddam

HamidMoghaddam commented Jan 3, 2019

Just pass num_labels when you load your model:

model_state_dict = torch.load(model_fn)
loaded_model = BertForSequenceClassification.from_pretrained(bert_model, state_dict = model_state_dict, num_labels = 16)
print(loaded_model.num_labels)

@rodgzilla
Contributor Author

As mentioned in my previous posts, I think that the library should either fetch the number of labels from the saved file or force the user to provide a num_labels argument.

While what you are proposing fixes my problem, I would like to prevent it for other users in the future by patching the library code.

@thomwolf
Member

thomwolf commented Jan 7, 2019

I see, thanks @rodgzilla. Indeed, not using error_msgs is bad practice; let's raise these errors.

Regarding fetching the number of labels, I understand your point, but it would probably add too much custom logic to the library for the moment, so let's go with your simpler solution of making the number of labels mandatory for now (we should have done that from the beginning).

@thomwolf thomwolf closed this as completed Jan 7, 2019
@ugm2

ugm2 commented Jan 12, 2021

Hi everyone!
I had to come here to learn that I had to include num_labels when loading the model, because the error was misleading.
Also, I didn't know how many labels there were, so I had to guess.
The model I was trying to load:
biobert-base-cased-v1.1-mnli
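For reference, the label count can usually be read from the checkpoint's config instead of guessed (hub id assumed here; it may need the organization prefix):

from transformers import AutoConfig

# Hub id assumed; adjust to the exact path of the checkpoint you are loading.
config = AutoConfig.from_pretrained('dmis-lab/biobert-base-cased-v1.1-mnli')
print(config.num_labels)  # number of labels the saved classifier head expects
print(config.id2label)    # index -> label-name mapping, when the author filled it in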

@kaankork

kaankork commented Jul 30, 2021

I'm also facing a similar problem with the same model as @ugm2: biobert-base-cased-v1.1-mnli

In my case I know the exact num_labels and provide it as an argument when loading the model.
How can I solve this?

RuntimeError: Error(s) in loading state_dict for BertForSequenceClassification:
	size mismatch for classifier.weight: copying a param with shape torch.Size([3, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([10]).

@LysandreJik
Member

With the latest transformers versions, you can use the recently introduced (#12664) ignore_mismatched_sizes=True parameter of the from_pretrained method to specify that you'd rather drop the layers with incompatible shapes than raise a RuntimeError.
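For example (path and label count are placeholders):

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    'path/to/your/checkpoint',      # local directory or hub id
    num_labels=10,                  # the label count your task actually needs
    ignore_mismatched_sizes=True,   # drop and re-initialize layers whose shapes don't match
)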

@b-hritz2000

b-hritz2000 commented Feb 23, 2025

Issue:

Exception in thread Thread-1 (process_files):
Traceback (most recent call last):
  File "C:\Program Files\Python312\Lib\threading.py", line 1073, in _bootstrap_inner
    self.run()
  File "C:\Program Files\Python312\Lib\threading.py", line 1010, in run
    self._target(*self._args, **self._kwargs)
  File "d:\HMB820419\OneDrive - TATA MOTORS LTD\Documents\UV\Digitization Initiatives\1. BOM Classifier\GUI based 400 Line item classifier_V18 trail using retrainer.py", line 725, in process_files
    model = BertForSequenceClassification.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\HMB820419\AppData\Roaming\Python\Python312\site-packages\transformers\modeling_utils.py", line 4014, in from_pretrained
    ) = cls._load_pretrained_model(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\HMB820419\AppData\Roaming\Python\Python312\site-packages\transformers\modeling_utils.py", line 4559, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for BertForSequenceClassification:
        size mismatch for classifier.weight: copying a param with shape torch.Size([370, 768]) from checkpoint, the shape in current model is torch.Size([371, 768]).
        size mismatch for classifier.bias: copying a param with shape torch.Size([370]) from checkpoint, the shape in current model is torch.Size([371]).    
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

I am frustrated by this issue. I have a fine-tune-and-fit function where I update the classifier bias and weights and then save the model, but I still get this error. When the model path already exists, I want to load the existing checkpoint with my updated weights (which apparently have to be saved before they can be loaded) and then fit the model. But when I reload the model, it throws the same classifier weight mismatch error shown above. I am not a full-time coder and am doing this for my company's work, but the issue does not seem to go away. I do not want to use ignore_mismatched_sizes=True.

I think the problem is that the fine-tuned model is somehow not getting saved correctly. My task is to save the fine-tuned model and then use the updated model to fit it on a training dataset.
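For reference, the label count that was actually written to disk can be inspected before loading, which shows whether the 370-vs-371 mismatch comes from the save side or the load side (sketch, using the model_path defined in the code below):

from transformers import BertConfig

# Sketch: read back what save_pretrained actually wrote to the checkpoint directory.
saved_config = BertConfig.from_pretrained(model_path)
print('num_labels in saved config:', saved_config.num_labels)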
Probable problematic area:

def fine_tune_and_fit(training_set): 
 # Load BERT tokenizer from the saved model directory
 tokenizer = BertTokenizer.from_pretrained(model_path)

 # Combine the training data columns into one list of strings
 training_set_combined = (
     training_set['Design Group'].fillna('').astype(str) + ', ' +
     training_set['Sub Group'].fillna('').astype(str) + ', ' +
     training_set['Model'].fillna('').astype(str) + ', ' +
     training_set['Variant'].fillna('').astype(str) + ', ' +
     training_set['Fuel Type'].fillna('').astype(str) + ', ' +
     training_set['Description'].fillna('').astype(str)
 )
 training_set_combined = training_set_combined.astype(str).tolist()

 # Tokenize the training data
 training_encodings = tokenizer(training_set_combined, padding=True, truncation=True, return_tensors='pt')

 # ---------- Update or Initialize the Label Encoder ----------
 label_encoder_path = os.path.join(model_path, "label_encoder.pkl")
 new_labels = set(training_set['Target'].unique())
 if os.path.exists(label_encoder_path):
     # Load the previous label encoder
     with open(label_encoder_path, "rb") as f:
         old_label_encoder = pickle.load(f)
     # Create the union of the old and new classes
     updated_classes = sorted(set(old_label_encoder.classes_).union(new_labels))
     label_encoder = LabelEncoder()
     label_encoder.fit(updated_classes)
     logging.info(f"Label encoder updated. Total classes: {len(label_encoder.classes_)}")
 else:
     label_encoder = LabelEncoder()
     label_encoder.fit(list(new_labels))
     logging.info(f"Label encoder initialized with {len(label_encoder.classes_)} classes.")

 # Transform the targets using the updated label encoder
 training_labels = label_encoder.transform(training_set['Target'])

 # ---------- Prepare DataLoader ----------
 dataset = TensorDataset(
     training_encodings['input_ids'],
     training_encodings['attention_mask'],
     torch.tensor(training_labels, dtype=torch.long)
 )
 train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
 
 # ---------- Load or Initialize the Model ----------
 if os.path.exists(model_path):
     # Load existing fine-tuned model
     model = BertForSequenceClassification.from_pretrained(model_path)
     logging.info("Model loaded for fine-tuning.")
     new_num_labels = len(label_encoder.classes_)
     if model.config.num_labels != new_num_labels:
         # Update both the configuration and internal attributes
         old_classes = list(old_label_encoder.classes_) if os.path.exists(label_encoder_path) else []
         new_classes = list(label_encoder.classes_)
         model.config.num_labels = new_num_labels
         model.num_labels = new_num_labels  # Update internal attribute used during forward()
         model.classifier = update_classifier(model, old_classes, new_classes)
         logging.info(f"Model classifier head updated to {new_num_labels} output neurons.")
         # Immediately save the updated model so that the checkpoint reflects these changes.
         model.save_pretrained(model_path, safe_serialization=False)
         tokenizer.save_pretrained(model_path)
 else:
     new_num_labels = len(label_encoder.classes_)
     model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=new_num_labels)
     logging.info("New model initialized for training from scratch.")
 
 optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)
 
 # ---------- Training Loop ----------
 model.train()
 num_epochs = 15  # Change as needed
 for epoch in range(num_epochs):
     total_loss = 0
     for batch_idx, (input_ids, attention_mask, labels) in enumerate(train_loader):
         optimizer.zero_grad()
         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
         loss = outputs.loss
         loss.backward()
         optimizer.step()
         total_loss += loss.item()
         logging.info(f"Epoch {epoch + 1}, Batch {batch_idx + 1}: Loss {loss.item():.4f}")
     avg_loss = total_loss / len(train_loader)
     logging.info(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss:.4f}")
 
 # ---------- Save the Fine-Tuned Model, Tokenizer, and Label Encoder ----------

 with open(label_encoder_path, "wb") as f:
     pickle.dump(label_encoder, f)
 logging.info("Model, tokenizer, and label encoder saved.")
 model.save_pretrained(model_path, safe_serialization=False)
 tokenizer.save_pretrained(model_path)
     # Print saved model information
 print("=== Reloaded Model Information ===")
 print("Number of labels (model config):", model.config.num_labels)
 print("Classifier weight shape:", model.classifier.weight.shape)
 print("Classifier bias shape:", model.classifier.bias.shape)

 
 return model, tokenizer, label_encoder
def process_files():
    global model, tokenizer, label_encoder  # Ensure these are recognized as global if set elsewhere

    '''try:'''
        # Load BOM file
    if bom_file_path:
        sanitized_bom = load_data(bom_file_path)
        logging.info(f"Loaded BOM file: {bom_file_path}")
        
        
        # Design group extraction
        design_group_extraction(sanitized_bom)

        # Add user-selected columns to the BOM (Model, Variant, Fuel Type)
        sanitized_bom = add_user_selected_columns(
            sanitized_bom,
            model_selected=model_dropdown.get(),  # Assuming these are OptionMenus
            variant_selected=variant_dropdown.get(),
            fuel_type_selected=fuel_type_dropdown.get()
        )
        '''print(sanitized_bom.head())'''

        if not model_dropdown.get() or not variant_dropdown.get() or not fuel_type_dropdown.get():
            # Stop the loading animation
            stop_loading_animation()

            # Show error message in the loading label
            loading_label.config(text="Error: Please select values for Model, Variant, and Fuel Type.")
            logging.error("User has not selected values for all dropdowns.")
            return  # Exit early, as we can't proceed without these selections

        # Check if 'Design Group' is created successfully
        if 'Design Group' not in sanitized_bom.columns:
            loading_label.config(text="Design Group column could not be created. Please check the BOM file.")
            return
        
        combined_dataset = pd.DataFrame()  # Initialize an empty DataFrame to combine datasets

        # Process learning dataset if provided
        if learning_dataset_path:
            # Load learning dataset and fine-tune the model
            training_set = load_data(learning_dataset_path)
            training_set['Target'] = training_set[['Aggregate', 'VMT', 'Sub Aggregate', 'Generic Part Name']].apply(lambda x: ';'.join(x.astype(str)), axis=1)

            # Load existing training dataset if it exists
            existing_training_data_path = os.path.join(model_path, "combined_training_dataset.xlsx")
            if os.path.exists(existing_training_data_path):
                existing_training_data = pd.read_excel(existing_training_data_path)
            else:
                existing_training_data = pd.DataFrame()  # If file doesn't exist, start fresh

            # Combine existing training data with new training set

            combined_dataset = pd.concat([existing_training_data, training_set], ignore_index=True)
            print(combined_dataset.head())

            # Fine-tune the model with the combined dataset
            model, tokenizer, label_encoder = fine_tune_and_fit(training_set)
                    # Print saved model information
            print("=== Reloaded Model Information ===")
            print("Number of labels (model config):", model.config.num_labels)
            print("Classifier weight shape:", model.classifier.weight.shape)
            print("Classifier bias shape:", model.classifier.bias.shape)


            # Save the combined dataset to Excel for future reference
            save_combined_dataset(combined_dataset)

        else:
            # Load the pre-trained model if no new learning dataset is provided
            if os.path.exists(model_path):
                # First, load the label encoder from disk
                with open(os.path.join(model_path, "label_encoder.pkl"), 'rb') as f:
                    label_encoder = pickle.load(f)
                
                # Dynamically set the number of labels based on the loaded label encoder
                num_labels = len(label_encoder.classes_)
                
                # Load the pre-trained model with the dynamic number of output labels
                model = BertForSequenceClassification.from_pretrained(
                    model_path,
                    num_labels=371,
                )
                tokenizer = BertTokenizer.from_pretrained(model_path)
                
                # Reinitialize the classifier layer using the dynamic length
                import torch.nn as nn
                model.classifier = nn.Linear(model.config.hidden_size, num_labels)
                
                logging.info("Pre-trained model, tokenizer, and label encoder loaded.")

        # Fit the model to the sanitized BOM data
        output_df = fit_model_to_bom(model, tokenizer, label_encoder, sanitized_bom)

        # Replace all NaN values with blanks in the final output
        output_df.fillna('', inplace=True)

        # Create a separate file for entries with confidence scores below 0.75
        low_confidence_df = output_df[output_df['Confidence Score'] < 0.75]
        if not low_confidence_df.empty:
            low_confidence_file_path = bom_file_path.replace('.xlsx', '_low_confidence_output.xlsx')
            low_confidence_df.to_excel(low_confidence_file_path, index=False)
            logging.info(f"Low confidence output saved to {low_confidence_file_path}")

        # Instead of saving directly, open the file in Excel
        with NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
            output_file_path = temp_file.name
            output_df.to_excel(output_file_path, index=False)
            logging.info(f"Temporary file saved at {output_file_path}")
            
            # Open the temporary file with Excel
            try:
                if os.name == 'nt':  # For Windows
                    subprocess.Popen(['start', 'excel', output_file_path], shell=True)
                elif os.name == 'posix':  # For macOS or Linux
                    subprocess.Popen(['open', output_file_path])
                loading_label.config(text="File opened in Excel. Please save it to your preferred location.")
            except Exception as e:
                logging.error(f"Failed to open Excel: {e}")
                loading_label.config(text="Error opening file in Excel.")             
    else:
        loading_label.config(text="Please upload a BOM file.")
    
    window.after(1000, reset_gui) #reset after 1 sec

Full Code:

from pathlib import Path
from tkinter import Tk, Canvas, Button, PhotoImage, filedialog, Label, messagebox
import pandas as pd
import os
import logging
import pickle
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import threading
import subprocess
from tempfile import NamedTemporaryFile
from tkinter import ttk
from tkinter import Toplevel
import gc
import torch

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define paths for the assets (change as needed)
OUTPUT_PATH = Path(__file__).parent
ASSETS_PATH = OUTPUT_PATH / Path(r"D:\HMB820419\OneDrive - TATA MOTORS LTD\Documents\UV\Digitization Initiatives\1. BOM Classifier\assets\frame0")

# Pre-trained model path
model_path = r"D:\HMB820419\OneDrive - TATA MOTORS LTD\Documents\UV\Digitization Initiatives\1. BOM Classifier\Dataset\Test Model_with BIW_Classification and Sub Group"

# Initialize the main window
window = Tk()
window.geometry("850x650")  # Increased height for additional labels
window.configure(bg="#FFFFFF")
window.title("BOM Classifier")

# Canvas for layout
canvas = Canvas(window, bg="#FFFFFF", height=650, width=850, bd=0, highlightthickness=0, relief="ridge")
canvas.place(x=0, y=0)

image_image_1 = PhotoImage(file=ASSETS_PATH / "image_1.png")
canvas.create_image(425.0, 44.0, image=image_image_1)

canvas.create_text(425.0, 44.0, anchor="center", text="BOM CLASSIFIER", fill="#FFFFFF", font=("Cantarell Bold", 32 * -1))

# Button coordinates
button_1_x = 484.0
button_1_y = 231.0
button_2_x = 46.0
button_2_y = 231.0
button_3_x = 284.0
button_3_y = 417.0
button_width = 278.0

# Button text labels
canvas.create_text(button_1_x + button_width / 2, button_1_y - 20, anchor="center", text="Enter the BOM file", fill="#14035E", font=("Cantarell Bold", 16 * -1))
canvas.create_text(button_2_x + button_width / 2, button_2_y - 20, anchor="center", text="Learning Dataset", fill="#14035E", font=("Cantarell Bold", 16 * -1))

# Initialize file paths
bom_file_path = ""
learning_dataset_path = ""

# Button images
button_image_1 = PhotoImage(file=ASSETS_PATH / "button_1.png")
button_1 = Button(image=button_image_1, borderwidth=0, highlightthickness=0, command=lambda: upload_bom_file(), relief="flat", bg="#FFFFFF")
button_1.place(x=button_1_x, y=button_1_y, width=button_width, height=44.0)

button_image_2 = PhotoImage(file=ASSETS_PATH / "button_2.png")
button_2 = Button(image=button_image_2, borderwidth=0, highlightthickness=0, command=lambda: upload_learning_dataset(), relief="flat", bg="#FFFFFF")
button_2.place(x=button_2_x, y=button_2_y, width=button_width, height=44.0)

button_image_3 = PhotoImage(file=ASSETS_PATH / "button_3.png")
button_3 = Button(image=button_image_3, borderwidth=0, highlightthickness=0, command=lambda: execute_files(), relief="flat", bg="#FFFFFF")
button_3.place(x=button_3_x, y=button_3_y, width=283.0, height=47.0)

# Button for comparing three BOM files
compare_button_x = 484.0
compare_button_y = 320.0  # Position below the existing buttons
compare_button_width = 278.0

canvas.create_text(compare_button_x + compare_button_width / 2, compare_button_y - 20, anchor="center", text="Compare BOMs", fill="#14035E", font=("Cantarell Bold", 16 * -1))



# Assuming you have a list of predefined options for the Model, Variant, and Fuel Type
model_options = ["Tiago", "Tigor", "Punch","Altroz","Nexon","Curvv","Harrier","Safari","Sierra"]
variant_options = ["Base", "Mid", "High", "High+"]
fuel_type_options = ["Petrol", "Diesel"]

# Dropdowns for selecting Model, Variant, and Fuel Type
model_dropdown = ttk.Combobox(window, values=model_options)
variant_dropdown = ttk.Combobox(window, values=variant_options)
fuel_type_dropdown = ttk.Combobox(window, values=fuel_type_options)

model_dropdown.set("Select Model")      # Default message for Model dropdown
variant_dropdown.set("Select Variant")  # Default message for Variant dropdown
fuel_type_dropdown.set("Select Fuel Type")  # Default message for Fuel Type dropdown

model_dropdown.place(x=50, y=350)  # Position as needed
variant_dropdown.place(x=50, y=400)  # Position as needed
fuel_type_dropdown.place(x=50, y=450)  # Position as needed

# Labels for displaying uploaded file names
bom_label = Label(window, text="", bg="#FFFFFF", font=("Cantarell", 12))
bom_label.place(x=button_1_x, y=button_1_y + 50)

learning_label = Label(window, text="", bg="#FFFFFF", font=("Cantarell", 12))
learning_label.place(x=button_2_x, y=button_2_y + 50)

# Label for loading indicator
loading_label = Label(window, text="", bg="#FFFFFF", font=("Cantarell", 12))
loading_label.place(x=button_3_x, y=button_3_y + 50)

# Label for completion message
completion_label = Label(window, text="", bg="#FFFFFF", font=("Cantarell Bold", 14))
completion_label.place(x=button_3_x, y=button_3_y + 80)

comparison_label = Label(window, text="", bg="#FFFFFF", font=("Cantarell", 12))
comparison_label.place(x=compare_button_x, y=compare_button_y + 50)  # Place it below the button

window.resizable(False, False)

# Initialize the paths for the three BOM files
bom_file_paths_for_comparison = []

def upload_bom_files_for_comparison():
    global bom_file_paths_for_comparison
    bom_file_paths_for_comparison = filedialog.askopenfilenames(title="Select the BOM files", filetypes=[("Excel files", "*.xlsx;*.xls")])
    
    # Ensure the user selects at least 2 and at most 3 files
    if 2 <= len(bom_file_paths_for_comparison) <= 3:
        # Update the label with the file names
        comparison_label.config(text=f"BOM files for comparison: {', '.join([os.path.basename(file) for file in bom_file_paths_for_comparison])}")
    else:
        messagebox.showerror("Error", "Please select 2 or 3 BOM files.")


def upload_bom_file():
    global bom_file_path
    bom_file_path = filedialog.askopenfilename(title="Select the BOM Excel file", filetypes=[("Excel files", "*.xlsx;*.xls")])
    if bom_file_path:
        bom_label.config(text=f"BOM file: {os.path.basename(bom_file_path)}")

def upload_learning_dataset():
    global learning_dataset_path
    learning_dataset_path = filedialog.askopenfilename(title="Select the Learning Dataset Excel file", filetypes=[("Excel files", "*.xlsx;*.xls")])
    if learning_dataset_path:
        learning_label.config(text=f"Learning dataset: {os.path.basename(learning_dataset_path)}")

def load_data(file_path):
    """Load data from an Excel file."""
    try:
        return pd.read_excel(file_path)
    except Exception as e:
        logging.error(f"Failed to load {file_path}: {e}")
        return pd.DataFrame()

import gc
import os
import torch
import torch.nn as nn

def update_classifier(model, old_classes, new_classes):
    """
    Create a new classifier head that preserves the weights for classes in old_classes.
    
    Args:
      model: The BertForSequenceClassification model.
      old_classes: List of class names from the previous label encoder.
      new_classes: List of updated class names from the new label encoder.
      
    Returns:
      new_classifier: A new nn.Linear layer with updated weights.
    """
    # Get the old classifier layer and its weight shape.
    old_classifier = model.classifier  # This is an instance of nn.Linear
    hidden_size = old_classifier.in_features
    new_num_labels = len(new_classes)
    
    # Create a new classifier layer with the new number of classes.
    new_classifier = nn.Linear(hidden_size, new_num_labels)
    
    # Get the old classifier's weights and bias.
    with torch.no_grad():
        old_weights = old_classifier.weight.data.clone()
        old_bias = old_classifier.bias.data.clone()
    
        # For each new class, if it existed before, copy its weights.
        for new_idx, label in enumerate(new_classes):
            if label in old_classes:
                old_idx = old_classes.index(label)
                new_classifier.weight.data[new_idx] = old_weights[old_idx]
                new_classifier.bias.data[new_idx] = old_bias[old_idx]
            # Otherwise, the weight remains with its default initialization.
    
    return new_classifier

def fine_tune_and_fit(training_set): 
   
    # Load BERT tokenizer from the saved model directory
    tokenizer = BertTokenizer.from_pretrained(model_path)

    # Combine the training data columns into one list of strings
    training_set_combined = (
        training_set['Design Group'].fillna('').astype(str) + ', ' +
        training_set['Sub Group'].fillna('').astype(str) + ', ' +
        training_set['Model'].fillna('').astype(str) + ', ' +
        training_set['Variant'].fillna('').astype(str) + ', ' +
        training_set['Fuel Type'].fillna('').astype(str) + ', ' +
        training_set['Description'].fillna('').astype(str)
    )
    training_set_combined = training_set_combined.astype(str).tolist()

    # Tokenize the training data
    training_encodings = tokenizer(training_set_combined, padding=True, truncation=True, return_tensors='pt')

    # ---------- Update or Initialize the Label Encoder ----------
    label_encoder_path = os.path.join(model_path, "label_encoder.pkl")
    new_labels = set(training_set['Target'].unique())
    if os.path.exists(label_encoder_path):
        # Load the previous label encoder
        with open(label_encoder_path, "rb") as f:
            old_label_encoder = pickle.load(f)
        # Create the union of the old and new classes
        updated_classes = sorted(set(old_label_encoder.classes_).union(new_labels))
        label_encoder = LabelEncoder()
        label_encoder.fit(updated_classes)
        logging.info(f"Label encoder updated. Total classes: {len(label_encoder.classes_)}")
    else:
        label_encoder = LabelEncoder()
        label_encoder.fit(list(new_labels))
        logging.info(f"Label encoder initialized with {len(label_encoder.classes_)} classes.")

    # Transform the targets using the updated label encoder
    training_labels = label_encoder.transform(training_set['Target'])

    # ---------- Prepare DataLoader ----------
    dataset = TensorDataset(
        training_encodings['input_ids'],
        training_encodings['attention_mask'],
        torch.tensor(training_labels, dtype=torch.long)
    )
    train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
    
    # ---------- Load or Initialize the Model ----------
    if os.path.exists(model_path):
        # Load existing fine-tuned model
        model = BertForSequenceClassification.from_pretrained(model_path)
        logging.info("Model loaded for fine-tuning.")
        new_num_labels = len(label_encoder.classes_)
        if model.config.num_labels != new_num_labels:
            # Update both the configuration and internal attributes
            old_classes = list(old_label_encoder.classes_) if os.path.exists(label_encoder_path) else []
            new_classes = list(label_encoder.classes_)
            model.config.num_labels = new_num_labels
            model.num_labels = new_num_labels  # Update internal attribute used during forward()
            model.classifier = update_classifier(model, old_classes, new_classes)
            logging.info(f"Model classifier head updated to {new_num_labels} output neurons.")
            # Immediately save the updated model so that the checkpoint reflects these changes.
            model.save_pretrained(model_path, safe_serialization=False)
            tokenizer.save_pretrained(model_path)
    else:
        new_num_labels = len(label_encoder.classes_)
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=new_num_labels)
        logging.info("New model initialized for training from scratch.")
    
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)
    
    # ---------- Training Loop ----------
    model.train()
    num_epochs = 15  # Change as needed
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, (input_ids, attention_mask, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            logging.info(f"Epoch {epoch + 1}, Batch {batch_idx + 1}: Loss {loss.item():.4f}")
        avg_loss = total_loss / len(train_loader)
        logging.info(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss:.4f}")
    
    # ---------- Save the Fine-Tuned Model, Tokenizer, and Label Encoder ----------

    with open(label_encoder_path, "wb") as f:
        pickle.dump(label_encoder, f)
    logging.info("Model, tokenizer, and label encoder saved.")
    model.save_pretrained(model_path, safe_serialization=False)
    tokenizer.save_pretrained(model_path)

        # Print saved model information
    print("=== Reloaded Model Information ===")
    print("Number of labels (model config):", model.config.num_labels)
    print("Classifier weight shape:", model.classifier.weight.shape)
    print("Classifier bias shape:", model.classifier.bias.shape)

    
    return model, tokenizer, label_encoder


def fit_model_to_bom(model, tokenizer, label_encoder, sanitized_bom):
    # Tokenize the sanitized BOM for predictions
    sanitized_bom_combined = sanitized_bom['Design Group'].fillna('').astype(str)+', ' +sanitized_bom['Sub Group'].fillna('').astype(str) + ', ' + sanitized_bom['Model'].fillna('').astype(str)
    sanitized_bom_combined += ', ' + sanitized_bom['Variant'].fillna('').astype(str)
    sanitized_bom_combined += ', ' + sanitized_bom['Fuel Type'].fillna('').astype(str)
    sanitized_bom_combined += ', ' + sanitized_bom['Description'].fillna('').astype(str)
    sanitized_bom_encodings = tokenizer(sanitized_bom_combined.astype(str).tolist(), padding=True, truncation=True, return_tensors='pt')

    # Prediction
    model.eval()
    with torch.no_grad():
        predictions = model(**sanitized_bom_encodings)
        predicted_classes = torch.argmax(predictions.logits, dim=-1)
        confidence_scores = torch.softmax(predictions.logits, dim=-1).max(dim=-1).values

    # Create output DataFrame starting with the sanitized BOM columns
    output_df = sanitized_bom.copy()  # Start with the original sanitized BOM
    output_df['Predicted Target'] = label_encoder.inverse_transform(predicted_classes.numpy())
    output_df['Confidence Score'] = confidence_scores.numpy()
    output_df['Predicted Target'] = output_df['Predicted Target'].fillna('').astype(str)
    print(output_df['Predicted Target'])

    # Split predicted target into separate columns
    predicted_columns = output_df['Predicted Target'].astype(str).str.split(';', expand=True)
    predicted_columns.columns = ['Aggregate', 'VMT', 'Sub Aggregate', 'Generic Part Name']
    output_df=output_df.drop(['Predicted Target'], axis=1)

    # Combine the predicted columns with the output DataFrame
    output_df = pd.concat([output_df, predicted_columns], axis=1)

    # Fill BOM part numbers starting with 'G' with NA values
    output_df.loc[output_df['Part No'].astype(str).str.startswith('G', na=False), 'Description'] = np.nan

    #Drop NA values
    output_df.dropna(subset=['Description'], inplace=True)

    return output_df

def extract_sub_group(part_no):
    part_no = str(part_no)  # Ensure it's a string
    if len(part_no) >= 10 and part_no[8:10].isalpha():  # 9th & 10th digit are text
        return part_no[8:10]  # Use 9th & 10th digits
    elif len(part_no) >= 8:
        return part_no[6:8]  # Use 7th & 8th digits
    return np.nan  # Return NaN if the length is insufficient

def design_group_extraction(sanitized_bom):
    sanitized_bom['Design Group'] = sanitized_bom['Part No'].apply(lambda x: str(x)[4:6] if len(str(x)) >= 6 else np.nan)
    sanitized_bom['Sub Group'] = sanitized_bom['Part No'].apply(extract_sub_group)  # Extract Sub Group
    
def save_combined_dataset(combined_dataset):
    combined_dataset_path = os.path.join(model_path, "combined_training_dataset.xlsx")
    combined_dataset.to_excel(combined_dataset_path, index=False)
    logging.info("Combined dataset saved.")

import itertools

# Initialize the loading animation state
loading_text_cycle = itertools.cycle([".", "..", "..."])
is_loading = False  # Flag to control the animation

def stop_loading_animation():
    global is_loading
    is_loading = False  # Stop the loading animation
    loading_label.config(text="Processing complete!")  # Final message

def update_loading_animation():
    if is_loading:
        # Update the loading label with the next state
        loading_label.config(text="Processing" + next(loading_text_cycle))
        # Schedule the next update after 500 milliseconds (adjust for speed)
        window.after(500, update_loading_animation)

def start_loading_animation():
    global is_loading
    is_loading = True  # Set the loading flag to True
    update_loading_animation()  # Start the animation loop

def add_user_selected_columns(sanitized_bom, model_selected, variant_selected, fuel_type_selected):
    sanitized_bom['Model'] = model_selected
    sanitized_bom['Variant'] = variant_selected
    sanitized_bom['Fuel Type'] = fuel_type_selected
    return sanitized_bom

def execute_files():
    global model, tokenizer, label_encoder  # Ensure these are recognized as global if set elsewhere

    # Clear previous labels
    completion_label.config(text="")

    # Start the loading animation
    start_loading_animation()

    if bom_file_paths_for_comparison:
        # Compare the BOMs in a separate thread
        comparison_thread = threading.Thread(target=compare_boms)
        comparison_thread.start()
    else:
    # Run the main task in a separate thread to avoid freezing the UI
        task_thread = threading.Thread(target=process_files)
        task_thread.start()

def reset_gui():
    """Reset all the necessary components of the GUI to start fresh."""
    # Clear selected dropdowns
    model_dropdown.set("Select Model")
    variant_dropdown.set("Select Variant")
    fuel_type_dropdown.set("Select Fuel Type")

    # Reset labels
    bom_label.config(text="")
    learning_label.config(text="")
    loading_label.config(text="")
    completion_label.config(text="")

    # Reset paths
    global bom_file_path, learning_dataset_path
    bom_file_path = ""
    learning_dataset_path = ""

    # Re-enable the upload buttons
    button_1.config(state="normal")
    button_2.config(state="normal")
    button_3.config(state="normal")

    logging.info("GUI reset successfully!")



def create_comparator(bom_outputs, Aggregation_Type):
    """Aggregate BOMs by SUB Aggregate level and Generic Part Name level, filtered by Aggregation Type."""

    # Define base aggregation rules dynamically
    base_aggregation_rules = {
        "Rolling Chassis": {
            "BIW": ["BIW-Underbody"],
            "Ft Suspension": None,
            "Rr Suspension": None,
            "Wheels": None,
            "Exhaust Systems": None,
            "Seats": None,
            "Seat Belts": None,
            "Brakes": None,
            "Controls": None,
            "Engine Mounts": None,
            "Fuel Systems": None,
            "Steering": None,
            "Charging Sys": None
        },
        "Drive Away Chassis Extra": {
            "Cooling System": None,
            "Engine_Diesel": None,
            "Engine_Petrol": None,
            "TA": None
        }
    }

    # Dynamically derive Drive Away Chassis and Top Hat rules
    def combine_rules(base, extra):
        combined = base.copy() if base else {}
        for sub_aggregate, parts in (extra or {}).items():
            if sub_aggregate not in combined:
                combined[sub_aggregate] = parts
            elif combined[sub_aggregate] is None or parts is None:
                combined[sub_aggregate] = None
            else:
                combined[sub_aggregate] = list(set(combined[sub_aggregate]) | set(parts))
        return combined

    rolling_chassis_rules = base_aggregation_rules["Rolling Chassis"]
    drive_away_chassis_rules = combine_rules(
        rolling_chassis_rules, base_aggregation_rules["Drive Away Chassis Extra"]
    )

    # Final aggregation rules for filtering
    aggregation_rules = {
        "Rolling Chassis": rolling_chassis_rules,
        "Drive Away Chassis": drive_away_chassis_rules
    }

    # Initialize an empty DataFrame for the comparator
    final_df = pd.DataFrame()

    # Combine and group all BOM outputs
    combined_df = pd.concat(
        [
            output_df.groupby(['Sub Aggregate', 'Generic Part Name'])
            .agg({'Price': 'sum'})
            .reset_index()
            .assign(BOM_Source=f'BOM {idx + 1}')
            for idx, output_df in enumerate(bom_outputs)
        ],
        axis=0,
    )

    # Extract all unique sub-aggregates for Full Body aggregation (dynamically)
    full_body_sub_aggregates = combined_df['Sub Aggregate'].unique()

    # Pivot the combined data to have BOM Sources as columns
    pivoted_df = combined_df.pivot_table(
        index=['Sub Aggregate', 'Generic Part Name'],
        columns='BOM_Source',
        values='Price',
        aggfunc='sum',
    ).reset_index()

    # Filter data based on Aggregation Type
    if Aggregation_Type == "Full Vehicle":
        # For Full Body, we include everything
        pivoted_df = pivoted_df[pivoted_df['Sub Aggregate'].isin(full_body_sub_aggregates)]
    
    elif Aggregation_Type == "Drive Away Chassis":
        # For Drive Away Chassis, we include parts based on the defined rules
        rules = aggregation_rules["Drive Away Chassis"]
        filtered_rows = [
            pivoted_df[
                (pivoted_df['Sub Aggregate'] == sub_aggregate) & 
                (
                    rules[sub_aggregate] is None or 
                    pivoted_df['Generic Part Name'].isin(rules[sub_aggregate])
                )
            ]
            for sub_aggregate in rules if sub_aggregate in pivoted_df['Sub Aggregate'].unique()
        ]
        pivoted_df = pd.concat(filtered_rows, ignore_index=True)

    elif Aggregation_Type == "Top Hat":
        # For Top Hat, compare Full Body vs. Drive Away Chassis at Generic Part Name level
        full_body_df = pivoted_df[pivoted_df['Sub Aggregate'].isin(full_body_sub_aggregates)]
        drive_away_chassis_df = pivoted_df[pivoted_df['Sub Aggregate'].isin(drive_away_chassis_rules.keys())]

        # For each sub-aggregate, we remove Generic Part Names that are in both Full Body and Drive Away Chassis
        top_hat_df = pd.DataFrame()
        for sub_aggregate in full_body_sub_aggregates:
            # Get the rows for this sub-aggregate from both Full Body and Drive Away Chassis
            full_body_parts = full_body_df[full_body_df['Sub Aggregate'] == sub_aggregate]
            drive_away_chassis_parts = drive_away_chassis_df[drive_away_chassis_df['Sub Aggregate'] == sub_aggregate]
            
            # Identify the Generic Part Names that are in Full Body but not in Drive Away Chassis
            unique_parts = full_body_parts[~full_body_parts['Generic Part Name'].isin(drive_away_chassis_parts['Generic Part Name'])]
            
            # Append the filtered parts for this sub-aggregate
            top_hat_df = pd.concat([top_hat_df, unique_parts], ignore_index=True)

        # Now, we need to add the BIW parts from Full Body if they are missing in the Drive Away Chassis
        biw_parts_full_body = full_body_df[
            full_body_df['Sub Aggregate'] == 'BIW'
        ][full_body_df['Generic Part Name'].isin(['BIW-Upperbody', 'BIW-Closures', 'BIW-Other'])]

        # Append these BIW parts to the Top Hat result
        top_hat_df = pd.concat([top_hat_df, biw_parts_full_body], ignore_index=True)

        # Set the Top Hat output as the final DataFrame
        pivoted_df = top_hat_df
    
    elif Aggregation_Type == "Rolling Chassis":
        # For Rolling Chassis, include parts based on the defined rules
        rules = aggregation_rules["Rolling Chassis"]
        filtered_rows = [
            pivoted_df[
                (pivoted_df['Sub Aggregate'] == sub_aggregate) & 
                (
                    rules[sub_aggregate] is None or 
                    pivoted_df['Generic Part Name'].isin(rules[sub_aggregate])
                )
            ]
            for sub_aggregate in rules if sub_aggregate in pivoted_df['Sub Aggregate'].unique()
        ]
        pivoted_df = pd.concat(filtered_rows, ignore_index=True)


    # Add subtotals at the Sub Aggregate level
    subtotal_rows = []
    for sub_aggregate, group in pivoted_df.groupby('Sub Aggregate'):
        subtotal = group.iloc[:, 2:].sum(axis=0)  # Sum only price columns
        subtotal_row = {'Sub Aggregate': sub_aggregate, 'Generic Part Name': 'Subtotal'}
        subtotal_row.update(subtotal.to_dict())
        subtotal_rows.append(subtotal_row)

    subtotal_df = pd.DataFrame(subtotal_rows)
    final_df = pd.concat([pivoted_df, subtotal_df], ignore_index=True)

    # Sort and clean up
    final_df['Sort Order'] = final_df['Generic Part Name'].apply(lambda x: 1 if x == 'Subtotal' else 0)
    final_df = final_df.sort_values(by=['Sub Aggregate', 'Sort Order']).drop(columns=['Sort Order'])

    # Add the grand total row
    grand_total = final_df.loc[final_df['Generic Part Name'] != 'Subtotal', final_df.columns[2:]].sum(axis=0)
    grand_total_row = {'Sub Aggregate': 'Grand Total', 'Generic Part Name': ''}
    grand_total_row.update(grand_total.to_dict())
    final_df = pd.concat([final_df, pd.DataFrame([grand_total_row])], ignore_index=True)

    return final_df

def display_comparator(final_comparator_df):
    """Display the comparator DataFrame with Sub Aggregate and Generic Part Name aggregation."""
    # Save the comparator DataFrame to an Excel file
    with NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
        output_file_path = temp_file.name
        
        with pd.ExcelWriter(output_file_path, engine='xlsxwriter') as writer:
            # Write the final comparator DataFrame to a single sheet
            final_comparator_df.to_excel(writer, sheet_name='Comparator Summary', index=False)
            
            # Get the workbook and worksheet objects
            workbook = writer.book
            worksheet = writer.sheets['Comparator Summary']
            
            # Format the columns for better readability
            format_bold = workbook.add_format({'bold': True})
            format_subtotal = workbook.add_format({'italic': True, 'bg_color': '#F2F2F2', 'bold': True})
            format_grand_total = workbook.add_format({'bold': True, 'bg_color': '#D9EAD3'})
            
            # Apply formatting to rows based on their content
            for row_num, row_data in final_comparator_df.iterrows():
                if row_data['Generic Part Name'] == 'Subtotal':
                    worksheet.set_row(row_num + 1, None, format_subtotal)
                elif row_data['Sub Aggregate'] == 'Grand Total':
                    worksheet.set_row(row_num + 1, None, format_grand_total)
                else:
                    worksheet.set_row(row_num + 1)  # Default formatting
            
            # Auto-adjust column widths
            for col_num, col_data in enumerate(final_comparator_df.columns):
                max_width = max(final_comparator_df[col_data].astype(str).map(len).max(), len(str(col_data))) + 2
                worksheet.set_column(col_num, col_num, max_width)

        logging.info(f"Comparator file saved at {output_file_path}")
        
        # Try to open the Excel file automatically
        try:
            if os.name == 'nt':  # For Windows
                subprocess.Popen(['start', 'excel', output_file_path], shell=True)
            elif os.name == 'posix':  # For macOS or Linux
                subprocess.Popen(['open', output_file_path])
            loading_label.config(text="Comparator file opened in Excel.")
        except Exception as e:
            logging.error(f"Failed to open Excel: {e}")
            loading_label.config(text="Error opening comparator file in Excel.")


def process_files():
    global model, tokenizer, label_encoder  # Ensure these are recognized as global if set elsewhere

    '''try:'''
        # Load BOM file
    if bom_file_path:
        sanitized_bom = load_data(bom_file_path)
        logging.info(f"Loaded BOM file: {bom_file_path}")
        
        
        # Design group extraction
        design_group_extraction(sanitized_bom)

        # Add user-selected columns to the BOM (Model, Variant, Fuel Type)
        sanitized_bom = add_user_selected_columns(
            sanitized_bom,
            model_selected=model_dropdown.get(),  # Assuming these are OptionMenus
            variant_selected=variant_dropdown.get(),
            fuel_type_selected=fuel_type_dropdown.get()
        )
        '''print(sanitized_bom.head())'''

        if not model_dropdown.get() or not variant_dropdown.get() or not fuel_type_dropdown.get():
            # Stop the loading animation
            stop_loading_animation()

            # Show error message in the loading label
            loading_label.config(text="Error: Please select values for Model, Variant, and Fuel Type.")
            logging.error("User has not selected values for all dropdowns.")
            return  # Exit early, as we can't proceed without these selections

        # Check if 'Design Group' is created successfully
        if 'Design Group' not in sanitized_bom.columns:
            loading_label.config(text="Design Group column could not be created. Please check the BOM file.")
            return
        
        combined_dataset = pd.DataFrame()  # Initialize an empty DataFrame to combine datasets

        # Process learning dataset if provided
        if learning_dataset_path:
            # Load learning dataset and fine-tune the model
            training_set = load_data(learning_dataset_path)
            training_set['Target'] = training_set[['Aggregate', 'VMT', 'Sub Aggregate', 'Generic Part Name']].apply(lambda x: ';'.join(x.astype(str)), axis=1)

            # Load existing training dataset if it exists
            existing_training_data_path = os.path.join(model_path, "combined_training_dataset.xlsx")
            if os.path.exists(existing_training_data_path):
                existing_training_data = pd.read_excel(existing_training_data_path)
            else:
                existing_training_data = pd.DataFrame()  # If file doesn't exist, start fresh

            # Combine existing training data with new training set

            combined_dataset = pd.concat([existing_training_data, training_set], ignore_index=True)
            print(combined_dataset.head())

            # Fine-tune the model with the combined dataset
            model, tokenizer, label_encoder = fine_tune_and_fit(training_set)
                    # Print saved model information
            print("=== Reloaded Model Information ===")
            print("Number of labels (model config):", model.config.num_labels)
            print("Classifier weight shape:", model.classifier.weight.shape)
            print("Classifier bias shape:", model.classifier.bias.shape)


            # Save the combined dataset to Excel for future reference
            save_combined_dataset(combined_dataset)

        else:
            # Load the pre-trained model if no new learning dataset is provided
            if os.path.exists(model_path):
                # First, load the label encoder from disk
                with open(os.path.join(model_path, "label_encoder.pkl"), 'rb') as f:
                    label_encoder = pickle.load(f)
                
                # Dynamically set the number of labels based on the loaded label encoder
                num_labels = len(label_encoder.classes_)
                
                # Load the pre-trained model with the dynamic number of output labels
                model = BertForSequenceClassification.from_pretrained(
                    model_path,
                    num_labels=371,
                )
                tokenizer = BertTokenizer.from_pretrained(model_path)
                
                # Reinitialize the classifier layer using the dynamic length
                import torch.nn as nn
                model.classifier = nn.Linear(model.config.hidden_size, num_labels)
                
                logging.info("Pre-trained model, tokenizer, and label encoder loaded.")

        # Fit the model to the sanitized BOM data
        output_df = fit_model_to_bom(model, tokenizer, label_encoder, sanitized_bom)

        # Replace all NaN values with blanks in the final output
        output_df.fillna('', inplace=True)

        # Create a separate file for entries with confidence scores below 0.75
        low_confidence_df = output_df[output_df['Confidence Score'] < 0.75]
        if not low_confidence_df.empty:
            low_confidence_file_path = bom_file_path.replace('.xlsx', '_low_confidence_output.xlsx')
            low_confidence_df.to_excel(low_confidence_file_path, index=False)
            logging.info(f"Low confidence output saved to {low_confidence_file_path}")

        # Instead of saving directly, open the file in Excel
        with NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
            output_file_path = temp_file.name
            output_df.to_excel(output_file_path, index=False)
            logging.info(f"Temporary file saved at {output_file_path}")
            
            # Open the temporary file with Excel
            try:
                if os.name == 'nt':  # For Windows
                    subprocess.Popen(['start', 'excel', output_file_path], shell=True)
                elif os.name == 'posix':  # For macOS or Linux
                    subprocess.Popen(['open', output_file_path])
                loading_label.config(text="File opened in Excel. Please save it to your preferred location.")
            except Exception as e:
                logging.error(f"Failed to open Excel: {e}")
                loading_label.config(text="Error opening file in Excel.")             
    else:
        loading_label.config(text="Please upload a BOM file.")
    
    window.after(1000, reset_gui) #reset after 1 sec
    '''except Exception as e:
        logging.error(f"Error during file processing: {e}")
        loading_label.config(text=f"Error: {e}")
        stop_loading_animation()
        reset_gui()  # Reset GUI in case of an error as well'''

def open_comparison_window():
    # Create a new top-level window
    comparison_window = Toplevel(window)
    comparison_window.title("BOM Comparison")
    comparison_window.geometry("600x650")  # Adjust size to accommodate the new dropdown
    # Bring the window to the front
    comparison_window.lift()
    comparison_window.focus_force()

    # Add a label and dropdown for aggregation type selection
    Label(comparison_window, text="Select Aggregation Type").place(x=50, y=20)
    aggregation_dropdown = ttk.Combobox(comparison_window, values=["Full Vehicle", "Top Hat", "Drive Away Chassis", "Rolling Chassis"])
    aggregation_dropdown.set("Select Aggregation Type")  # Set default value
    aggregation_dropdown.place(x=200, y=20)

    # BOM file selection labels
    Label(comparison_window, text="Upload BOM 1").place(x=50, y=70)
    Label(comparison_window, text="Upload BOM 2").place(x=50, y=170)
    Label(comparison_window, text="Upload BOM 3").place(x=50, y=270)

    # Create upload buttons for each BOM
    upload_bom1_button = Button(comparison_window, text="Upload BOM 1", command=lambda: upload_bom(0, comparison_window))
    upload_bom2_button = Button(comparison_window, text="Upload BOM 2", command=lambda: upload_bom(1, comparison_window))
    upload_bom3_button = Button(comparison_window, text="Upload BOM 3", command=lambda: upload_bom(2, comparison_window))

    upload_bom1_button.place(x=200, y=70)
    upload_bom2_button.place(x=200, y=170)
    upload_bom3_button.place(x=200, y=270)

    # Dropdowns for Model, Variant, Fuel Type for each BOM
    dropdown_y_positions = [370, 420, 470]
    labels = ["Select Model", "Select Variant", "Select Fuel Type"]

    # Create dropdowns for each BOM and store in bom_dropdowns list
    bom_dropdowns = []
    for i in range(3):  # Loop for 3 BOMs
        bom_dropdowns.append({
            'model': ttk.Combobox(comparison_window, values=model_options),
            'variant': ttk.Combobox(comparison_window, values=variant_options),
            'fuel_type': ttk.Combobox(comparison_window, values=fuel_type_options)
        })
        # Set default value for dropdowns
        bom_dropdowns[i]['model'].set(labels[0])
        bom_dropdowns[i]['variant'].set(labels[1])
        bom_dropdowns[i]['fuel_type'].set(labels[2])

        # Place dropdowns on the window
        bom_dropdowns[i]['model'].place(x=50, y=dropdown_y_positions[i])
        bom_dropdowns[i]['variant'].place(x=200, y=dropdown_y_positions[i])
        bom_dropdowns[i]['fuel_type'].place(x=350, y=dropdown_y_positions[i])

    # Add a button to execute the comparison after all BOMs are uploaded
    execute_button = Button(
        comparison_window,
        text="Compare BOMs",
        command=lambda: compare_boms(bom_dropdowns, aggregation_dropdown.get())  # Pass aggregation type to comparison
    )
    execute_button.place(x=250, y=550)


# Function to upload BOM files
def upload_bom(bom_index, comparison_window):
# Ensure bom_file_paths_for_comparison is initialized
    global bom_file_paths_for_comparison
    if len(bom_file_paths_for_comparison) < 3:
        bom_file_paths_for_comparison = [None] * 3  # Initialize with a size of 3

    file_path = filedialog.askopenfilename(title=f"Select BOM {bom_index + 1}", filetypes=[("Excel files", "*.xlsx;*.xls")])
    if file_path:
        bom_file_paths_for_comparison[bom_index] = file_path
        print(f"BOM {bom_index + 1} uploaded: {file_path}")  # Debugging statement

                # After file selection, ensure the comparison window remains in focus and above
        comparison_window.lift()  # Bring the comparison window to the front again
        comparison_window.focus_force()  # Ensure it stays focused above the main window
        
def compare_boms(bom_dropdowns,Agg_Type):
    """Compare the 3 uploaded BOM files and create a comparator."""
    global model, tokenizer, label_encoder

    try:
        # Ensure 3 BOM files have been selected
        if len(bom_file_paths_for_comparison) not in [2, 3]:
            messagebox.showerror("Error", "Please upload 2 or 3 BOM files.")
            return

        selected_data = []
        for i in range(len(bom_file_paths_for_comparison)):
            # Get selected values for Model, Variant, and Fuel Type for each BOM
            model_selected = bom_dropdowns[i]['model'].get()
            variant_selected = bom_dropdowns[i]['variant'].get()
            fuel_type_selected = bom_dropdowns[i]['fuel_type'].get()

            # Add this selection data to a list for further processing
            selected_data.append({
                'model': model_selected,
                'variant': variant_selected,
                'fuel_type': fuel_type_selected
            })

        # List to store the outputs for all BOMs
        bom_outputs = []

        # Process each BOM file
        for idx, bom_file_path in enumerate([path for path in bom_file_paths_for_comparison if path is not None]):
            sanitized_bom = load_data(bom_file_path)  # Function to load BOM file (load_data is assumed)
            logging.info(f"Loaded BOM file: {bom_file_path}")

            #Adding Design group
            design_group_extraction(sanitized_bom)

            # Add the model, variant, and fuel type columns to the BOM data
            sanitized_bom['BOM Source'] = f'BOM {idx + 1}'  # Adds a new column like 'BOM 1', 'BOM 2', etc.
            sanitized_bom = add_user_selected_columns(
                sanitized_bom,
                model_selected=selected_data[idx]['model'],
                variant_selected=selected_data[idx]['variant'],
                fuel_type_selected=selected_data[idx]['fuel_type']
            )
            logging.info(f"Processing BOM {idx + 1} with model: {selected_data[idx]['model']}, variant: {selected_data[idx]['variant']}, fuel type: {selected_data[idx]['fuel_type']}")
            if os.path.exists(model_path):
                model = BertForSequenceClassification.from_pretrained(model_path)
                tokenizer = BertTokenizer.from_pretrained(model_path)

                    # Load label encoder
                with open(os.path.join(model_path, "label_encoder.pkl"), 'rb') as f:
                    label_encoder = pickle.load(f)
                logging.info("Pre-trained model and tokenizer loaded.")
            else:
                loading_label.config(text="No model found. Please upload a learning dataset first to train the model.")
                return
            #printing BOM
            print(sanitized_bom.head())

            # Fit the model to the sanitized BOM data
            output_df = fit_model_to_bom(model, tokenizer, label_encoder, sanitized_bom)
            bom_outputs.append(output_df)

        # Combine BOM outputs and aggregate prices on SUB Aggregate level
        final_comparator_df = create_comparator(bom_outputs,Agg_Type)

        # Display or save the comparator
        display_comparator(final_comparator_df)

    except Exception as e:
        logging.error(f"Error during file processing: {e}")
        messagebox.showerror("Error", f"An error occurred during BOM processing: {e}")
        reset_gui()  # Reset GUI in case of an error

compare_button_image = PhotoImage(file=ASSETS_PATH / "button_4.png")  # Use a new image for this button
compare_button = Button(image=compare_button_image, borderwidth=0, highlightthickness=0, command=open_comparison_window, relief="flat", bg="#FFFFFF")
compare_button.place(x=compare_button_x, y=compare_button_y, width=compare_button_width, height=44.0)

# Start the GUI
window.mainloop()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

7 participants