Skip to content

falcalde21/Quick-Stat-TFM

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

6 Commits
 
 
 
 
 
 
 
 

Repository files navigation

Quick, Stat!: Un análisis completo de la base de datos Quick, Draw! usando redes recurrentes

Este repositorio contiene el modelo y los pesos del clasificador RNN utilizado en el TFM. Es un modelo implementado en Keras utilizando TensorFlow 2. El modelo clasifica las 345 categorías del dataset de QuickDraw en formato TFRecord con una precisión del 76.8%.

Contenidos

  • model_weights-0.768.h5: Pesos del modelo entrenado.
  • create_TFRecord_datasets.py: Código utilizado para convertir el dataset de QuickDraw a formato TFRecord, creando conjuntos de datos de entrenamiento y validación.
  • train_RNN.py: Código utilizado para entrenar el modelo RNN.

Definición del Modelo

El siguiente código muestra la función build_model que define la arquitectura del modelo RNN:

from tensorflow.keras import layers, models

def build_model():
    model = models.Sequential()
    # Capa de entrada y filtro
    model.add(layers.Input(shape=(3711, 3)))
    model.add(layers.Masking(mask_value=0.0))
    # Capas convolucionales
    model.add(layers.Conv1D(48, 5, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv1D(64, 5, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv1D(96, 3, activation='relu'))
    model.add(layers.BatchNormalization())
    # LSTM
    model.add(layers.Bidirectional(layers.LSTM(128, return_sequences=True)))
    model.add(layers.Bidirectional(layers.LSTM(128, return_sequences=True)))
    # Pooling y Softmax
    model.add(layers.GlobalAveragePooling1D())
    model.add(layers.Dense(345, activation='softmax'))
    return model

Cargar el Modelo Preentrenado

Para utilizar el modelo ya entrenado se pueden cargar los pesos utilizando el siguiente código:

# Cargar el modelo y los pesos
MODEL_WEIGHTS_PATH = 'model_weights-0.768.h5'
trained_model = build_model()
trained_model.load_weights(MODEL_WEIGHTS_PATH)

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages