File tree 1 file changed +41
-0
lines changed
1 file changed +41
-0
lines changed Original file line number Diff line number Diff line change
1
+ import torch
2
+ from torch import cat
3
+
4
+ # helpers
5
+
6
+ def exists (v ):
7
+ return v is not None
8
+
9
+ def default (v , d ):
10
+ return v if exists (v ) else d
11
+
12
+ # evolution - start with the most minimal, a population of 3
13
+ # 1 is natural selected out, the other 2 performs crossover
14
+
15
+ def select_and_crossover (
16
+ codes , # Float[3 ...]
17
+ fitness , # Float[3]
18
+ ):
19
+ assert codes .shape [0 ] == fitness .shape [0 ] == 3
20
+
21
+ # selection
22
+
23
+ top2 = fitness .topk (2 , dim = - 1 ).indices
24
+ codes = codes [top2 ]
25
+
26
+ # crossover
27
+
28
+ child = codes .mean (dim = 0 , keepdim = True )
29
+ codes = cat ((codes , child ))
30
+
31
+ return codes
32
+
33
+ # class
34
+
35
+ class EvoVQ (Module ):
36
+ def __init__ (self ):
37
+ super ().__init__ ()
38
+ raise NotImplementedError
39
+
40
+ def forward (self , x ):
41
+ raise NotImplementedError
You can’t perform that action at this time.
0 commit comments