# simulate_6B.py (132 lines, 120 loc, 5.82 KB)
# TiN/Hf(Al)O/Hf/TiN devices from Figure 2 (C)
import enum
from enum import Enum, auto
from memtorch.mn.Module import supported_module_parameters
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import memtorch
from memtorch.mn.Module import patch_model
from memtorch.map.Module import naive_tune
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities
from memtorch.bh.crossbar.Crossbar import init_crossbar
import copy
from pprint import pprint
from mobilenetv2 import MobileNetV2
from scipy.interpolate import interp1d
import torchvision
def test(model, test_loader):
    """Return the top-1 accuracy, in percent, of ``model`` over ``test_loader``.

    Args:
        model: Callable network, invoked as ``model(batch)`` on inputs that
            have been moved to the module-level ``device``.
        test_loader: Iterable of ``(data, target)`` batches that also exposes
            a ``dataset`` attribute; ``len(test_loader.dataset)`` is used as
            the total sample count.

    Returns:
        float: ``100 * correct / len(test_loader.dataset)``.
    """
    correct = 0
    # Evaluation only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(device))
            # Index of the largest score along dim 1 is the predicted class.
            pred = output.max(1)[1]
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()
    return 100. * float(correct) / float(len(test_loader.dataset))
def update_patched_model(patched_model, model):
    """Re-map the weights of ``model`` onto the crossbars of ``patched_model``.

    For every ``memtorch.mn.Conv2d`` / ``memtorch.mn.Linear`` found in
    ``patched_model``, the weight of the identically named module in ``model``
    is mapped with ``naive_map`` (double-column scheme, module-level ``r_on``
    and ``r_off``), written to crossbars 0 (positive) and 1 (negative), and
    copied into the patched layer's ``weight``.

    Args:
        patched_model: Network whose memristive layers are updated in place.
        model: Network providing the source weights; its module names must
            match those of ``patched_model``.

    Returns:
        The same ``patched_model`` object.

    Raises:
        AttributeError: If ``model`` has no module with a matching name.
    """
    # named_modules() yields dotted paths (e.g. "features.0.conv") for nested
    # layers, which a plain getattr() on the root cannot resolve, so index the
    # source modules by their full path instead.
    source_modules = dict(model.named_modules())
    for name, m in list(patched_model.named_modules()):
        if not isinstance(m, (memtorch.mn.Conv2d, memtorch.mn.Linear)):
            continue
        try:
            source = source_modules[name]
        except KeyError:
            raise AttributeError(
                "model has no module named %r" % (name,)) from None
        weight = source.weight.data
        pos_conductance_matrix, neg_conductance_matrix = naive_map(
            weight, r_on, r_off, scheme=memtorch.bh.Scheme.DoubleColumn)
        m.crossbars[0].write_conductance_matrix(
            pos_conductance_matrix, transistor=True, programming_routine=None)
        m.crossbars[1].write_conductance_matrix(
            neg_conductance_matrix, transistor=True, programming_routine=None)
        m.weight.data = weight
    return patched_model
# Linear map of the stop voltage from the interval [1.3, 1.9] onto [0, 1].
scale_input = interp1d([1.3, 1.9], [0, 1])


def scale_p_0(p_0, p_1, v_stop, cell_size=10):
    """Rescale ``p_0`` as a function of ``v_stop``.

    The result follows ``10 ** (k * shape(v_stop)) / exp(p_1 * cell_size)``
    where ``shape(v) = 1 - (2 * scale_input(v) - 1) ** 2`` and ``k`` is chosen
    so that the function returns ``p_0`` itself at the reference voltage 1.50.
    At either end of the ``scale_input`` range the shape term is zero and the
    result reduces to ``1 / exp(p_1 * cell_size)``.

    Args:
        p_0: Value to rescale (returned unchanged for ``v_stop == 1.50``).
        p_1: Exponent coefficient applied to ``cell_size``.
        v_stop: Stop voltage; must lie inside the range accepted by
            ``scale_input``.
        cell_size: Multiplier of ``p_1`` inside the exponential.

    Returns:
        The rescaled value as a NumPy scalar/0-d array.
    """
    reference_v_stop = 1.50
    normalised_v_stop = scale_input(v_stop)
    size_factor = np.exp(p_1 * cell_size)
    reference_level = p_0 * size_factor
    # Parabolic shape term evaluated at the reference and at the requested voltage.
    reference_shape = 1 - (2 * scale_input(reference_v_stop) - 1) ** (2)
    requested_shape = 1 - (2 * normalised_v_stop - 1) ** (2)
    exponent_gain = np.log10(reference_level) / reference_shape
    return (10 ** (exponent_gain * requested_shape)) / size_factor
def model_sudden(layer, cycle_count, v_stop):
    """Apply a sudden-failure endurance model to every crossbar of ``layer``.

    Each crossbar's conductance matrix is inverted to resistances. Once
    ``cycle_count`` exceeds the LRS threshold, every device whose resistance
    is below ``convergence_point_lrs`` is set to ``stable_resistance_lrs``;
    once it exceeds the HRS threshold, every device whose resistance is above
    ``convergence_point_hrs`` is set to ``stable_resistance_hrs``. Thresholds
    are ``p_0 * exp(p_1 * cell_size)`` with ``p_0`` rescaled by ``scale_p_0``
    for the given ``v_stop``.

    Args:
        layer: Object exposing ``crossbars``, each with a
            ``conductance_matrix`` tensor; modified in place.
        cycle_count: Number of cycles compared against the failure thresholds.
        v_stop: Stop voltage forwarded to ``scale_p_0``.

    Returns:
        The same ``layer`` object.
    """
    cell_size = 20
    convergence_point_lrs = 1e5
    stable_resistance_lrs = 5e7
    p_0_lrs = 10500.207573382977
    p_1_lrs = 0.2519450238812669
    convergence_point_hrs = 1e5
    stable_resistance_hrs = 5e7
    p_0_hrs = 10629.493769115974
    p_1_hrs = 0.24199726610964553
    p_0_lrs = scale_p_0(p_0_lrs, p_1_lrs, v_stop, cell_size)
    p_0_hrs = scale_p_0(p_0_hrs, p_1_hrs, v_stop, cell_size)
    threshold_lrs = p_0_lrs * np.exp(p_1_lrs * cell_size)
    threshold_hrs = p_0_hrs * np.exp(p_1_hrs * cell_size)
    # The failure decision does not depend on the crossbar, so take it once.
    lrs_failed = cycle_count > threshold_lrs
    hrs_failed = cycle_count > threshold_hrs
    for crossbar in layer.crossbars:
        resistance = 1 / crossbar.conductance_matrix
        # Classify devices from their pre-failure resistance so that devices
        # just moved by the LRS rule are not re-classified by the HRS rule.
        lrs_devices = resistance < convergence_point_lrs
        hrs_devices = resistance > convergence_point_hrs
        if lrs_failed:
            resistance[lrs_devices] = stable_resistance_lrs
        if hrs_failed:
            # The original assigned through a "<" mask here, which never
            # touched the HRS devices selected by the ">" check.
            resistance[hrs_devices] = stable_resistance_hrs
        crossbar.conductance_matrix = 1 / resistance
    return layer
def model_degradation(model, cycle_count, v_stop):
    """Run ``model_sudden`` on every supported memristive layer of ``model``.

    A layer is processed when its exact type is one of the values of
    ``supported_module_parameters``.

    Args:
        model: Network whose layers are degraded in place.
        cycle_count: Cycle count forwarded to ``model_sudden``.
        v_stop: Stop voltage forwarded to ``model_sudden``.

    Returns:
        The same ``model`` object.
    """
    for name, module in list(model.named_modules()):
        if type(module) not in supported_module_parameters.values():
            continue
        degraded = model_sudden(module, cycle_count, v_stop)
        if degraded is module:
            # Layer was updated in place; nothing needs to be re-registered.
            continue
        # ``name`` is a dotted path for nested layers. Setting it on the root
        # would register a stray child called e.g. "features.0.conv" instead
        # of replacing the layer, so attach the result to its direct parent.
        parent_path, _, child_name = name.rpartition('.')
        parent = model
        if parent_path:
            for part in parent_path.split('.'):
                parent = getattr(parent, part)
        setattr(parent, child_name, degraded)
    return model
# Prefer the GPU, but fall back to the CPU so the script still starts on hosts
# without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 256
# Test-time preprocessing: tensor conversion followed by per-channel
# normalisation with mean 0.5 and std 0.5.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=1)
# Reference device model and the resistance bounds used when mapping weights.
reference_memristor = memtorch.bh.memristor.VTEAM
r_on = 14000
r_off = 300000
reference_memristor_params = {'time_series_resolution': 1e-10, 'r_off': r_off, 'r_on': r_on}
model = MobileNetV2().to(device)
# map_location lets a checkpoint saved from a GPU be restored on whichever
# device was selected above.
# NOTE(review): strict=False silently ignores missing/unexpected keys -
# confirm that this is intended for 'trained_model.pt'.
model.load_state_dict(torch.load('trained_model.pt', map_location=device), strict=False)
model.eval()
# Keyword options forwarded unchanged to memtorch's patch_model; Linear and
# Conv2d layers are the ones requested for patching.
patch_options = {
    'memristor_model': reference_memristor,
    'memristor_model_params': reference_memristor_params,
    'module_parameters_to_patch': [torch.nn.Linear, torch.nn.Conv2d],
    'mapping_routine': naive_map,
    'transistor': True,
    'programming_routine': None,
    'scheme': memtorch.bh.Scheme.DoubleColumn,
    'tile_shape': (128, 128),
    'max_input_voltage': 0.3,
    'ADC_resolution': 8,
    'ADC_overflow_rate': 0.,
    'quant_method': 'linear',
}
patched_model = patch_model(model, **patch_options)
# Drop the module-level reference to the unpatched network.
del model
patched_model.tune_()
# Sweep grid: cycle counts 1e1 ... 1e9 and ten stop voltages in [1.3, 1.9].
times_to_reprogram = 10 ** np.arange(1, 10, dtype=np.float64)
v_stop_values = np.linspace(1.3, 1.9, 10, endpoint=True)
result_columns = ['times_reprogramed', 'v_stop', 'test_set_accuracy']
df = pd.DataFrame(columns=result_columns)
# DataFrame.append was removed in pandas 2.0, so results are accumulated in a
# plain list and the frame is rebuilt from it.
result_rows = []
for time_to_reprogram in times_to_reprogram:
    cycle_count = time_to_reprogram
    for v_stop in v_stop_values:
        print('time_to_reprogram: %f, v_stop: %f' % (time_to_reprogram, v_stop))
        # Degrade a fresh copy each time so runs do not affect one another.
        patched_model_copy = copy.deepcopy(patched_model)
        patched_model_copy = model_degradation(patched_model_copy, cycle_count, v_stop)
        accuracy = test(patched_model_copy, test_loader)
        del patched_model_copy
        result_rows.append({'times_reprogramed': time_to_reprogram, 'v_stop': v_stop, 'test_set_accuracy': accuracy})
        df = pd.DataFrame(result_rows, columns=result_columns)
        # Rewritten after every run so partial results survive an interruption.
        df.to_csv('6B.csv', index=False)