-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathtask3.py
33 lines (29 loc) · 927 Bytes
/
task3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
def get_indices(N, n_batches, split_ratio):
    """Generate uniformly spaced walk-forward batches of indices over 0..N-1.

    Each batch is an array ``[i, j, k]`` with ``i <= j <= k`` where
    ``(k - j) = split_ratio * (j - i)``. Consecutive batches are shifted
    forward by ``k - j`` (the length of the second segment). The first batch
    starts at ``i = 0`` and the last one ends at ``k = N - 1``. Positions are
    rounded half-up to integers, so the ratio holds only approximately when
    the segment lengths are not whole numbers.

    Args:
        N (int): total counts; indices range from 0 to N - 1.
        n_batches (int): number of batches; nothing is generated when it is
            zero or negative.
        split_ratio (float): ratio (k - j) / (j - i), defines the position
            of j in [i, j, k]; must be non-negative.

    Returns:
        generator yielding, per batch, a new ``np.ndarray`` of three ints
        ``[i, j, k]``.

    Raises:
        ValueError: if ``N < 1`` or ``split_ratio < 0`` (raised when
            ``get_indices`` is called, not on first iteration).

    Example:
        >>> [b.tolist() for b in get_indices(100, 5, 0.25)]
        [[0, 44, 55], [11, 55, 66], [22, 66, 77], [33, 77, 88], [44, 88, 99]]
    """
    # Validate here (outside the generator) so bad arguments fail at call
    # time instead of being deferred until the first next().
    if N < 1:
        raise ValueError(f"N must be at least 1, got {N}")
    if split_ratio < 0:
        raise ValueError(f"split_ratio must be non-negative, got {split_ratio}")
    return _iter_indices(N, n_batches, split_ratio)


def _iter_indices(N, n_batches, split_ratio):
    """Yield the [i, j, k] batches for already-validated arguments."""
    if n_batches <= 0:
        return
    # The whole range [0, N-1] is covered by one first segment plus
    # n_batches shifts of the second segment:
    #     n_batches * second_len + first_len = N - 1,
    #     second_len = split_ratio * first_len.
    first_len = (N - 1) / (n_batches * split_ratio + 1)
    second_len = split_ratio * first_len
    for batch in range(n_batches):
        start = batch * second_len
        bounds = np.array([start, start + first_len, start + first_len + second_len])
        # Round half-up; for the last batch the float sum may be off from
        # N - 1 by a rounding error, which this absorbs.
        yield np.floor(bounds + 0.5).astype(int)
def main():
    """Print every batch produced by ``get_indices(100, 5, 0.25)``."""
    batch_iter = get_indices(100, 5, 0.25)
    for batch in batch_iter:
        print(batch)
    # expected batches:
    # [0, 44, 55]
    # [11, 55, 66]
    # [22, 66, 77]
    # [33, 77, 88]
    # [44, 88, 99]


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()