Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch capable sampling functions + proto-type HMC/MCMC #25

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffhod/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
print("batch mode...")
37 changes: 21 additions & 16 deletions diffhod/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def Zheng07Cens(halo_mvir,
name='zheng07Cens', **kwargs):
halo_mvir = tf.math.log(halo_mvir) / tf.math.log(10.)
# Compute the mean number of centrals
p = tf.clip_by_value(0.5 * (1+tf.math.erf((halo_mvir - logMmin)/sigma_logM)), 1.e-4, 1-1.e-4)
p = tf.clip_by_value(0.5 * (1+tf.math.erf((halo_mvir - tf.reshape(logMmin,(-1,1)))/tf.reshape(sigma_logM,(-1,1)))), 1.e-4, 1-1.e-4)
return ed.RelaxedBernoulli(temperature, probs=p, name=name)

def Zheng07SatsPoisson(halo_mvir,
Expand All @@ -21,27 +21,32 @@ def Zheng07SatsPoisson(halo_mvir,
logM1=ed.Deterministic(12.4, name='logM1'),
alpha=ed.Deterministic(0.83, name='alpha'),
name='zheng07Sats', **kwargs):
M0 = 10.**logM0
M1 = 10.**logM1
rate = n_cen.distribution.probs * ((halo_mvir - M0)/M1)**alpha
rate = tf.where(halo_mvir < M0, 1e-4, rate)
M0 = tf.pow(10.,logM0)
M1 = tf.pow(10.,logM1)
rate = n_cen.distribution.probs * tf.math.pow((halo_mvir - tf.reshape(M0,(-1,1)))/(tf.reshape(M1,(-1,1))),tf.reshape(alpha,(-1,1)))
rate = tf.where(halo_mvir < tf.reshape(M0,(-1,1)), 1e-4, rate)
return ed.Poisson(rate=rate, name=name)

def Zheng07SatsRelaxedBernoulli(halo_mvir,
n_cen,
sample_shape,
logM0=ed.Deterministic(11.2, name='logM0'),
logM1=ed.Deterministic(12.4, name='logM1'),
alpha=ed.Deterministic(0.83, name='alpha'),
temperature=0.2,
name='zheng07Sats', **kwargs):
M0 = 10.**logM0
M1 = 10.**logM1
rate = n_cen.distribution.probs * (tf.nn.relu(halo_mvir - M0)/M1)**alpha
return ed.RelaxedBernoulli(temperature=temperature,
n_cen,
sample_shape,
logM0=ed.Deterministic(11.2, name='logM0'),
logM1=ed.Deterministic(12.4, name='logM1'),
alpha=ed.Deterministic(0.83, name='alpha'),
temperature=0.2,
name='zheng07Sats', **kwargs):
M0 = tf.pow(10.,logM0)
M1 = tf.pow(10.,logM1)
print(M0)

num = halo_mvir - tf.reshape(M0,(-1,1))

rate = n_cen.distribution.probs * (tf.nn.relu(num)/tf.reshape(M1,(-1,1)))**tf.reshape(alpha,(-1,1))
return ed.RelaxedBernoulli(temperature=temperature,
probs=tf.clip_by_value(rate/sample_shape[0],1.e-5,1-1e-4),
sample_shape=sample_shape)


def NFWProfile(pos,
concentration,
Rvir,
Expand Down
1 change: 1 addition & 0 deletions diffhod/mock_observables/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from diffhod.mock_observables import *
85 changes: 85 additions & 0 deletions diffhod/mock_observables/pk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
import numpy

import tensorflow as tf


def _initialize_pk(shape,boxsize,kmin,dk):
"""
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
"""
I = np.eye(len(shape), dtype='int') * -2 + 1

W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0

kmax = np.pi * np.min(shape.as_list())/np.max(boxsize) + dk/2
kedges = np.arange(kmin, kmax, dk)

k = [np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)]
kmag = sum(ki ** 2 for ki in k) ** 0.5

xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)

dig = np.digitize(kmag.flat, kedges)

xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
dig = tf.Variable(dig,dtype=tf.int32)
Nsum = tf.Variable(Nsum,dtype=tf.complex64)
return dig, Nsum, xsum, W, k, kedges


def pk(field,kmin=5,dk=0.5,shape = False,boxsize= False):
"""
Calculate the powerspectra given real space field

Args:

field: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
shape: shape of field to calculate field (can be strangely shaped?)
boxsize: length of each boxlength (can be strangly shaped?)

Returns:

kbins: the central value of the bins for plotting
power: real valued array of power in each bin

"""


#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape,boxsize,kmin,dk)


#convert field to complex for fft
field_complex = tf.dtypes.cast(field,dtype=tf.complex64)

#fast fourier transform
fft_image = tf.signal.fft3d(field_complex)

#absolute value of fast fourier transform
pk = tf.math.real(fft_image*tf.math.conj(fft_image))

#calculating powerspectra
Psum = tf.zeros(len(kedges) + 1, dtype=tf.complex64)
real = tf.reshape(tf.math.real(pk),[-1,])
imag = tf.reshape(tf.math.imag(pk),[-1,])

Psum += tf.dtypes.cast(tf.math.bincount(dig, weights=(W.flatten() * imag), minlength=xsum.size),dtype=tf.complex64)*1j
Psum += tf.dtypes.cast(tf.math.bincount(dig, weights=(W.flatten() * real), minlength=xsum.size),dtype=tf.complex64)

power = (Psum / Nsum)[1:-1] * boxsize.prod()

#normalization for powerspectra
norm = tf.dtypes.cast(tf.reduce_prod(shape),dtype=tf.float32)**2

#find central values of each bin
kbins = kedges[:-1]+ (kedges[1:] - kedges[:-1])/2

return kbins,tf.dtypes.cast(power,dtype=tf.float32)/norm
Loading