Skip to content

Commit 5747d77

Browse files
JJ
J
authored and
J
committed
migrating library from previous page
1 parent c15b2a9 commit 5747d77

29 files changed

+4015
-0
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# folded_optimization

cvxpylayers/.DS_Store

6 KB
Binary file not shown.

cvxpylayers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.1.4"
166 Bytes
Binary file not shown.
200 Bytes
Binary file not shown.

cvxpylayers/jax/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from cvxpylayers.jax.cvxpylayer import CvxpyLayer # noqa: F401

cvxpylayers/jax/cvxpylayer.py

+322
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
import diffcp
2+
import cvxpy as cp
3+
from cvxpy.reductions.solvers.conic_solvers.scs_conif import \
4+
dims_to_solver_dict
5+
import numpy as np
6+
import time
7+
from functools import partial
8+
9+
try:
10+
import jax
11+
except ImportError:
12+
raise ImportError("Unable to import jax. Please install from "
13+
"https://github.com/google/jax")
14+
from jax import core
15+
import jax.numpy as jnp
16+
17+
18+
def CvxpyLayer(problem, parameters, variables, gp=False):
19+
"""Construct a CvxpyLayer
20+
21+
Args:
22+
problem: The CVXPY problem; must be DPP.
23+
parameters: A list of CVXPY Parameters in the problem; the order
24+
of the Parameters determines the order in which parameter
25+
values must be supplied in the forward pass. Must include
26+
every parameter involved in problem.
27+
variables: A list of CVXPY Variables in the problem; the order of the
28+
Variables determines the order of the optimal variable
29+
values returned from the forward pass.
30+
gp: Whether to parse the problem using DGP (True or False).
31+
32+
Returns:
33+
A callable that solves the problem.
34+
"""
35+
36+
if gp:
37+
if not problem.is_dgp(dpp=True):
38+
raise ValueError('Problem must be DPP.')
39+
else:
40+
if not problem.is_dcp(dpp=True):
41+
raise ValueError('Problem must be DPP.')
42+
43+
if not set(problem.parameters()) == set(parameters):
44+
raise ValueError("The layer's parameters must exactly match "
45+
"problem.parameters")
46+
if not set(variables).issubset(set(problem.variables())):
47+
raise ValueError("Argument variables must be a subset of "
48+
"problem.variables")
49+
if not isinstance(parameters, list) and \
50+
not isinstance(parameters, tuple):
51+
raise ValueError("The layer's parameters must be provided as "
52+
"a list or tuple")
53+
if not isinstance(variables, list) and \
54+
not isinstance(variables, tuple):
55+
raise ValueError("The layer's variables must be provided as "
56+
"a list or tuple")
57+
58+
var_dict = {v.id for v in variables}
59+
60+
# Construct compiler
61+
param_order = parameters
62+
if gp:
63+
for param in parameters:
64+
if param.value is None:
65+
raise ValueError("An initial value for each parameter is "
66+
"required when gp=True.")
67+
data, solving_chain, _ = problem.get_problem_data(
68+
solver=cp.SCS, gp=True)
69+
compiler = data[cp.settings.PARAM_PROB]
70+
dgp2dcp = solving_chain.get(cp.reductions.Dgp2Dcp)
71+
param_ids = [p.id for p in compiler.parameters]
72+
old_params_to_new_params = (
73+
dgp2dcp.canon_methods._parameters
74+
)
75+
else:
76+
data, _, _ = problem.get_problem_data(solver=cp.SCS)
77+
compiler = data[cp.settings.PARAM_PROB]
78+
param_ids = [p.id for p in param_order]
79+
dgp2dcp = None
80+
cone_dims = dims_to_solver_dict(data["dims"])
81+
82+
info = {}
83+
CvxpyLayerFn_p = core.Primitive("CvxpyLayerFn_" + str(hash(problem)))
84+
85+
@partial(jax.custom_vjp, nondiff_argnums=(0,))
86+
def CvxpyLayerFn(solver_args, *params):
87+
return CvxpyLayerFn_p.bind(solver_args, *params)
88+
89+
def CvxpyLayerFn_impl(solver_args, *params):
90+
"""Solve problem (or a batch of problems) corresponding to `params`
91+
92+
Args:
93+
solver_args: a dict of optional arguments, to send to `diffcp`.
94+
Keys should be the names of keyword arguments.
95+
params: a sequence of JAX arrays; the n-th argument specifies
96+
the value for the n-th CVXPY Parameter. These arrays
97+
can be batched: if a array has 3 dimensions, then its
98+
first dimension is interpreted as the batch size. These
99+
arrays must all have the same dtype.
100+
101+
Returns:
102+
a list of optimal variable values, one for each CVXPY Variable
103+
supplied to the constructor.
104+
"""
105+
if len(params) != len(param_ids):
106+
raise ValueError('An array must be provided for each CVXPY '
107+
'parameter; received %d arrays, expected %d' % (
108+
len(params), len(param_ids)))
109+
110+
dtype, batch, batch_sizes, batch_size = batch_info(
111+
params, param_order)
112+
113+
if gp:
114+
param_map = {}
115+
# construct a list of params for the DCP problem
116+
for param, value in zip(param_order, params):
117+
if param in old_params_to_new_params:
118+
new_id = old_params_to_new_params[param].id
119+
param_map[new_id] = jnp.log(value)
120+
else:
121+
new_id = param.id
122+
param_map[new_id] = value
123+
params_numpy = [np.array(param_map[pid]) for pid in param_ids]
124+
else:
125+
params_numpy = [np.array(p) for p in params]
126+
127+
# canonicalize problem
128+
start = time.time()
129+
As, bs, cs, cone_dicts, shapes = [], [], [], [], []
130+
for i in range(batch_size):
131+
params_numpy_i = [
132+
p if sz == 0 else p[i]
133+
for p, sz in zip(params_numpy, batch_sizes)]
134+
c, _, neg_A, b = compiler.apply_parameters(
135+
dict(zip(param_ids, params_numpy_i)),
136+
keep_zeros=True)
137+
A = -neg_A # cvxpy canonicalizes -A
138+
As.append(A)
139+
bs.append(b)
140+
cs.append(c)
141+
cone_dicts.append(cone_dims)
142+
shapes.append(A.shape)
143+
info['canon_time'] = time.time() - start
144+
info['shapes'] = shapes
145+
146+
# compute solution and derivative function
147+
start = time.time()
148+
try:
149+
xs, _, _, _, DT_batch = diffcp.solve_and_derivative_batch(
150+
As, bs, cs, cone_dicts, **solver_args)
151+
info['DT_batch'] = DT_batch
152+
except diffcp.SolverError as e:
153+
print(
154+
"Please consider re-formulating your problem so that "
155+
"it is always solvable or increasing the number of "
156+
"solver iterations.")
157+
raise e
158+
info['solve_time'] = time.time() - start
159+
160+
# extract solutions and append along batch dimension
161+
start = time.time()
162+
sol = [[] for i in range(len(variables))]
163+
for i in range(batch_size):
164+
sltn_dict = compiler.split_solution(
165+
xs[i], active_vars=var_dict)
166+
for j, v in enumerate(variables):
167+
sol[j].append(jnp.expand_dims(jnp.array(
168+
sltn_dict[v.id], dtype=dtype), axis=0))
169+
sol = [jnp.concatenate(s, axis=0) for s in sol]
170+
171+
if not batch:
172+
sol = [jnp.squeeze(s, axis=0) for s in sol]
173+
174+
if gp:
175+
sol = [jnp.exp(s) for s in sol]
176+
177+
return tuple(sol)
178+
179+
CvxpyLayerFn_p.def_impl(CvxpyLayerFn_impl)
180+
181+
def CvxpyLayerFn_fwd_vjp(solver_args, *params):
182+
sol = CvxpyLayerFn(solver_args, *params)
183+
return sol, (params, sol)
184+
185+
def CvxpyLayerFn_bwd_vjp(solver_args, res, dvars):
186+
params, sol = res
187+
dtype, batch, batch_sizes, batch_size = batch_info(
188+
params, param_order)
189+
190+
# Use info here to retrieve this from the forward pass because
191+
# the residual in JAX's vjp doesn't allow non-JAX types to be
192+
# easily returned. This works when calling this serially,
193+
# but will break if this is called in parallel.
194+
shapes = info['shapes']
195+
DT_batch = info['DT_batch']
196+
197+
if gp:
198+
# derivative of exponential recovery transformation
199+
dvars = [dvar*s for dvar, s in zip(dvars, sol)]
200+
201+
dvars_numpy = [np.array(dvar) for dvar in dvars]
202+
203+
if not batch:
204+
dvars_numpy = [np.expand_dims(dvar, 0) for dvar in dvars_numpy]
205+
206+
# differentiate from cvxpy variables to cone problem data
207+
dxs, dys, dss = [], [], []
208+
for i in range(batch_size):
209+
del_vars = {}
210+
for v, dv in zip(variables, [dv[i] for dv in dvars_numpy]):
211+
del_vars[v.id] = dv
212+
dxs.append(compiler.split_adjoint(del_vars))
213+
dys.append(np.zeros(shapes[i][0]))
214+
dss.append(np.zeros(shapes[i][0]))
215+
216+
dAs, dbs, dcs = DT_batch(dxs, dys, dss)
217+
218+
# differentiate from cone problem data to cvxpy parameters
219+
start = time.time()
220+
grad = [[] for _ in range(len(param_ids))]
221+
for i in range(batch_size):
222+
del_param_dict = compiler.apply_param_jac(
223+
dcs[i], -dAs[i], dbs[i])
224+
for j, pid in enumerate(param_ids):
225+
grad[j] += [jnp.expand_dims(jnp.array(
226+
del_param_dict[pid], dtype=dtype), axis=0)]
227+
grad = [jnp.concatenate(g, axis=0) for g in grad]
228+
if gp:
229+
# differentiate through the log transformation of params
230+
dcp_grad = grad
231+
grad = []
232+
dparams = {pid: g for pid, g in zip(param_ids, dcp_grad)}
233+
for param, value in zip(param_order, params):
234+
v = 0.0 if param.id not in dparams else dparams[param.id]
235+
if param in old_params_to_new_params:
236+
dcp_param_id = old_params_to_new_params[param].id
237+
# new_param.value == log(param), apply chain rule
238+
v += (1.0 / value) * dparams[dcp_param_id]
239+
grad.append(v)
240+
info['dcanon_time'] = time.time() - start
241+
242+
if not batch:
243+
grad = [jnp.squeeze(g, axis=0) for g in grad]
244+
else:
245+
for i, sz in enumerate(batch_sizes):
246+
if sz == 0:
247+
grad[i] = jnp.sum(grad[i], axis=0)
248+
249+
return tuple(grad)
250+
251+
CvxpyLayerFn.defvjp(CvxpyLayerFn_fwd_vjp, CvxpyLayerFn_bwd_vjp)
252+
253+
# Default solver_args to an optional empty dict
254+
def f(*params, **kwargs):
255+
solver_args = kwargs.get('solver_args', {})
256+
return CvxpyLayerFn(solver_args, *params)
257+
258+
return f
259+
260+
261+
def batch_info(params, param_order):
262+
# infer dtype and whether or not params are batched
263+
dtype = params[0].dtype
264+
265+
batch_sizes = []
266+
for i, (p, q) in enumerate(zip(params, param_order)):
267+
# check dtype, device of params
268+
if p.dtype != dtype:
269+
raise ValueError(
270+
"Two or more parameters have different dtypes. "
271+
"Expected parameter %d to have dtype %s but "
272+
"got dtype %s." %
273+
(i, str(dtype), str(p.dtype))
274+
)
275+
276+
# check and extract the batch size for the parameter
277+
# 0 means there is no batch dimension for this parameter
278+
# and we assume the batch dimension is non-zero
279+
if p.ndim == q.ndim:
280+
batch_size = 0
281+
elif p.ndim == q.ndim + 1:
282+
batch_size = p.shape[0]
283+
if batch_size == 0:
284+
raise ValueError(
285+
"The batch dimension for parameter {} is zero "
286+
"but should be non-zero.".format(i))
287+
else:
288+
raise ValueError(
289+
"Invalid parameter size passed in. Expected "
290+
"parameter {} to have have {} or {} dimensions "
291+
"but got {} dimensions".format(
292+
i, q.ndim, q.ndim + 1, p.ndim))
293+
294+
batch_sizes.append(batch_size)
295+
296+
# validate the parameter shape
297+
p_shape = p.shape if batch_size == 0 else p.shape[1:]
298+
if not np.all(p_shape == param_order[i].shape):
299+
raise ValueError(
300+
"Inconsistent parameter shapes passed in. "
301+
"Expected parameter {} to have non-batched shape of "
302+
"{} but got {}.".format(
303+
i,
304+
q.shape,
305+
p.shape))
306+
307+
batch_sizes = np.array(batch_sizes)
308+
batch = np.any(batch_sizes > 0)
309+
310+
if batch:
311+
nonzero_batch_sizes = batch_sizes[batch_sizes > 0]
312+
batch_size = nonzero_batch_sizes[0]
313+
if np.any(nonzero_batch_sizes != batch_size):
314+
raise ValueError(
315+
"Inconsistent batch sizes passed in. Expected "
316+
"parameters to have no batch size or all the same "
317+
"batch size but got sizes: {}.".format(
318+
batch_sizes))
319+
else:
320+
batch_size = 1
321+
322+
return dtype, batch, batch_sizes, batch_size

0 commit comments

Comments
 (0)