Skip to content

Commit d930538

Browse files
committed
benefits need to be seen with a pop of size 3
1 parent abdc7be commit d930538

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

vector_quantize_pytorch/evo_vq.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
from torch import cat
3+
4+
# helpers
5+
6+
def exists(v):
7+
return v is not None
8+
9+
def default(v, d):
10+
return v if exists(v) else d
11+
12+
# evolution - start with the most minimal, a population of 3
13+
# 1 is natural selected out, the other 2 performs crossover
14+
15+
def select_and_crossover(
16+
codes, # Float[3 ...]
17+
fitness, # Float[3]
18+
):
19+
assert codes.shape[0] == fitness.shape[0] == 3
20+
21+
# selection
22+
23+
top2 = fitness.topk(2, dim = -1).indices
24+
codes = codes[top2]
25+
26+
# crossover
27+
28+
child = codes.mean(dim = 0, keepdim = True)
29+
codes = cat((codes, child))
30+
31+
return codes
32+
33+
# class
34+
35+
class EvoVQ(Module):
36+
def __init__(self):
37+
super().__init__()
38+
raise NotImplementedError
39+
40+
def forward(self, x):
41+
raise NotImplementedError

0 commit comments

Comments
 (0)