-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
34 lines (31 loc) · 1.13 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import torch.nn.functional as f
class Net(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Net, self).__init__()
'''
-Método:
- __init__: Este es el constructor de la clase Net
-Argumentos:
- input_size : Número de neuronas en la entrada.
- hidden_size: Número de neuronas en la capa oculta.
- output_size: Número de neuronas en la capa de salida.
'''
self.linear1 = nn.Linear(input_size, hidden_size)
self.linear2 = nn.Linear(hidden_size, hidden_size)
self.linear3 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
def forward(self, x):
'''
- Método
- forward: Esta función realizara el paso hacie adelante en la red.
- Argumentos:
- x: Le pasamos de parámetro un tensor de torch.
- Retorna:
- out: Nos retorna la inferencia del modelo.
'''
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
out = self.linear3(x)
return out