generated from minitorch/Module-0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
95 lines (75 loc) · 1.94 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import math
import random
from dataclasses import dataclass
from typing import List, Tuple
def make_pts(N: int) -> List[Tuple[float, float]]:
    """Sample ``N`` random 2-D points.

    Each coordinate comes from ``random.random()`` and so lies in
    ``[0.0, 1.0)``. For every point, ``x_1`` is drawn before ``x_2``, so the
    output is reproducible after ``random.seed(...)``.

    Args:
        N: number of points to generate; ``N <= 0`` yields an empty list.

    Returns:
        A list of ``N`` ``(x_1, x_2)`` tuples.
    """
    # Tuple displays evaluate left to right, so the draw order (x_1, then x_2)
    # is the same as in an explicit loop and seeded results are unchanged.
    return [(random.random(), random.random()) for _ in range(N)]
@dataclass
class Graph:
    """Labelled set of 2-D points returned by the dataset generators in this module.

    Attributes:
        N: the point count that was passed to the generator.
        X: the sampled points as ``(x_1, x_2)`` tuples.
        y: integer class labels, one per point in ``X`` and in the same order
            (the generators in this module emit 0 or 1).
    """

    N: int  # point count passed to the generator
    X: List[Tuple[float, float]]  # point coordinates
    y: List[int]  # class label for each point of X, same order
def simple(N: int) -> Graph:
    """Build a dataset split vertically down the middle.

    A point is labelled 1 when its first coordinate is below 0.5 (the left
    half) and 0 otherwise.

    Args:
        N: number of points to sample with ``make_pts``.

    Returns:
        Graph holding the sampled points and their labels.
    """
    points = make_pts(N)
    labels = [int(first < 0.5) for first, _second in points]
    return Graph(N, points, labels)
def diag(N: int) -> Graph:
    """Build a dataset whose positive class is a corner triangle.

    A point is labelled 1 when ``x_1 + x_2 < 0.5``, i.e. it lies below the
    line ``x_1 + x_2 = 0.5``, and 0 otherwise.

    Args:
        N: number of points to sample with ``make_pts``.

    Returns:
        Graph holding the sampled points and their labels.
    """
    points = make_pts(N)
    labels = []
    for a, b in points:
        below_line = a + b < 0.5
        labels.append(1 if below_line else 0)
    return Graph(N, points, labels)
def split(N: int) -> Graph:
    """Build a dataset whose positive class is two outer vertical bands.

    A point is labelled 1 when ``x_1 < 0.2`` or ``x_1 > 0.8``. Points in the
    central strip ``0.2 <= x_1 <= 0.8`` get 0.

    Args:
        N: number of points to sample with ``make_pts``.

    Returns:
        Graph holding the sampled points and their labels.
    """
    points = make_pts(N)

    def band_label(first: float) -> int:
        # Only the first coordinate matters; outside the central strip -> 1.
        return 1 if first < 0.2 or first > 0.8 else 0

    return Graph(N, points, [band_label(p[0]) for p in points])
def xor(N: int) -> Graph:
    """Build an XOR-style dataset.

    A point is labelled 1 when one coordinate is strictly below 0.5 and the
    other strictly above it (the two off-diagonal quadrants). A point with
    either coordinate exactly 0.5 gets 0.

    Args:
        N: number of points to sample with ``make_pts``.

    Returns:
        Graph holding the sampled points and their labels.
    """
    points = make_pts(N)
    labels = [
        1 if (a < 0.5 < b) or (b < 0.5 < a) else 0
        for a, b in points
    ]
    return Graph(N, points, labels)
def circle(N: int) -> Graph:
    """Build a dataset whose negative class is a disc around the centre.

    A point is labelled 1 when its squared distance from ``(0.5, 0.5)``
    exceeds 0.1 (radius ~0.316). Points on or inside that circle get 0.

    Args:
        N: number of points to sample with ``make_pts``.

    Returns:
        Graph holding the sampled points and their labels.
    """
    points = make_pts(N)
    labels = []
    for a, b in points:
        dx = a - 0.5
        dy = b - 0.5
        # Explicit products (not ** or hypot) keep the float rounding identical.
        labels.append(int(dx * dx + dy * dy > 0.1))
    return Graph(N, points, labels)
def spiral(N: int) -> Graph:
    """Build a two-class dataset of interleaved spiral arms centred on (0.5, 0.5).

    Arm 0 has points ``(x(t) + 0.5, y(t) + 0.5)``. Arm 1 has points
    ``(y(-t) + 0.5, x(-t) + 0.5)``. Here ``x(t) = t*cos(t)/20``,
    ``y(t) = t*sin(t)/20`` and ``t = 10 * i / steps`` for step indices
    ``i`` starting at 5.

    Args:
        N: total number of points. Arm 0 gets ``N // 2`` points and arm 1 gets
            the remaining ``N - N // 2``, so exactly ``N`` points are returned
            even when ``N`` is odd. ``N`` of 0 gives an empty dataset.

    Returns:
        Graph with ``N`` points, labelled 0 for arm 0 and 1 for arm 1.
    """

    def x(t: float) -> float:
        return t * math.cos(t) / 20.0

    def y(t: float) -> float:
        return t * math.sin(t) / 20.0

    n_first = N // 2
    # The second arm absorbs the leftover point when N is odd, so that
    # len(X) == len(labels) == N matches the N stored on the Graph.
    n_second = N - n_first
    # Step denominator: identical to the previous N // 2 for N >= 2, clamped
    # to 1 so that N < 2 no longer raises ZeroDivisionError.
    steps = max(n_first, 1)

    X = [
        (x(10.0 * (float(i) / steps)) + 0.5, y(10.0 * (float(i) / steps)) + 0.5)
        for i in range(5, 5 + n_first)
    ]
    X += [
        (y(-10.0 * (float(i) / steps)) + 0.5, x(-10.0 * (float(i) / steps)) + 0.5)
        for i in range(5, 5 + n_second)
    ]
    labels = [0] * n_first + [1] * n_second
    return Graph(N, X, labels)
# Registry of the dataset generators defined above, keyed by name. Each value
# takes the number of points N and returns a Graph. Insertion order is the
# order listed here.
datasets = {
    "Simple": simple,
    "Diag": diag,
    "Split": split,
    "Xor": xor,
    "Circle": circle,
    "Spiral": spiral,
}