Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add TGLFInputs and TGLFNN #477

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions torax/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import chex
import jax
from jax import numpy as jnp

from torax import array_typing
from torax import constants
from torax import geometry
Expand Down Expand Up @@ -406,6 +407,23 @@ def _calculate_lambda_ei(
"""
return 15.2 - 0.5 * jnp.log(ne / 1e20) + jnp.log(temp_el)

def _calculate_lambda_ee(
temp_el: jax.Array,
ne: jax.Array,
) -> jax.Array:
"""Calculates Coulomb logarithm for electron-ion collisions.

See Wesson 3rd edition p727.

Args:
temp_el: Electron temperature in keV.
ne: Electron density in m^-3.

Returns:
Coulomb logarithm.
"""
return 14.9 - 0.5 * jnp.log(ne / 1e20) + jnp.log(temp_el)


def fast_ion_fractional_heating_formula(
birth_energy: float | array_typing.ArrayFloat,
Expand Down
173 changes: 173 additions & 0 deletions torax/transport_model/tglf_based_transport_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class and utils for TGLF-based models."""

import chex
from jax import numpy as jnp

from torax import geometry
from torax import physics
from torax import state
from torax.constants import CONSTANTS
from torax.transport_model import quasilinear_transport_model
from torax.transport_model import runtime_params as runtime_params_lib


@chex.dataclass
class RuntimeParams(quasilinear_transport_model.RuntimeParams):
pass


@chex.dataclass(frozen=True)
class DynamicRuntimeParams(quasilinear_transport_model.DynamicRuntimeParams):
pass


@chex.dataclass
class RuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider):
pass


@chex.dataclass(frozen=True)
class TGLFInputs(quasilinear_transport_model.QuasilinearInputs):
r"""Dimensionless inputs to the TGLF model.

See https://gafusion.github.io/doc/tglf/tglf_table.html for definitions.
"""

# Ti/Te
Ti_over_Te: chex.Array
# dRmaj/dr
dRmaj: chex.Array
# q
q: chex.Array
# r/q dq/dr
s_hat: chex.Array
# nu_ei (see note in prepare_tglf_inputs)
ei_collision_freq: chex.Array
# Elongation kappa
kappa: chex.Array
# r/kappa dkappa/dr
kappa_shear: chex.Array
# Triangularity delta
delta: chex.Array
# r ddelta/dr
delta_shear: chex.Array
# Electron pressure defined w.r.t B_unit
beta_e: chex.Array
# Effective charge
Zeff: chex.Array


class TGLFBasedTransportModel(quasilinear_transport_model.QuasilinearTransportModel):
"""Base class for TGLF-based transport models."""

def _prepare_tglf_inputs(
Zeff_face: chex.Array,
nref: chex.Numeric,
q_correction_factor: chex.Numeric,
transport: DynamicRuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
) -> TGLFInputs:
# Shorthand for the appropriate variables
Te = core_profiles.temp_el
Ti = core_profiles.temp_ion
ne = core_profiles.ne

# Reference velocity and length, used for normalisation
vref = (Te.face_value() / (core_profiles.Ai * CONSTANTS.mp)) ** 0.5
lref = geo.Rmin[-1] # Minor radius at LCFS

# Temperature gradients
Ti_over_Te = Ti.face_value() / Te.face_value()
Ate = -lref / Te.face_value() * Te.face_grad()
Ati = -lref / Ti.face_value() * Ti.face_grad()

# Density gradient
# Note: nref cancels, as 1/(ne*nref) * (ne_grad * nref) = 1/ne * ne_grad
Ane = -lref / ne.face_value() * core_profiles.ne.face_grad()

# Electron-electron collision frequency
# Note: In the TGLF docs, XNUE is mislabelled.
# It is actually the electron-electron collision frequency
# See https://pyrokinetics.readthedocs.io/en/latest/user_guide/collisions.html
Lambda_ee = physics._calculate_lambda_ee(Te, ne)
normalised_nu_ee = (4 * jnp.pi * ne * CONSTANTS.qe**4 * Lambda_ee) / (
CONSTANTS.me**0.5 * (2 * Te) ** 1.5
)
nu_ee = normalised_nu_ee / (vref / lref)

# Safety factor
# Need to recalculate since in the nonlinear solver psi has intermediate
# states in the iterative solve
q, _ = physics.calc_q_from_psi(
geo=geo,
psi=core_profiles.psi,
q_correction_factor=q_correction_factor,
)
# Shear uses rho_face_norm
# TODO: check whether this should be midplane R
s_hat = physics.calc_s_from_psi(geo, core_profiles.psi) # = r/q dq/dr
theo-brown marked this conversation as resolved.
Show resolved Hide resolved

# Electron beta
p_e = ne * (Te * 1e3) # ne in m^-3, Te in eV
# B_unit = q/r dpsi/dr
B_unit = (
q / geo.rho_face_norm * jnp.gradient(core_profiles.psi, geo.rho_face_norm)
)
beta_e = 8 * jnp.pi * p_e / B_unit**2

# Geometry
Rmaj = geo.Rmaj
Rmin = geo.Rmin
dRmaj = jnp.gradient(geo.Rmaj, geo.rho_face_norm)
kappa = geo.elongation_face
# Elongation
kappa_shear = geo.rho_face_norm / kappa * jnp.gradient(kappa, geo.rho_face_norm)
# Triangularity
delta = geo.delta_face
delta_shear = geo.delta_face * jnp.gradient(geo.delta_face, geo.rho_face_norm)

# Gyrobohm diffusivity
# Used to unnormalise the outputs
# TODO: check this definition with Lorenzo/TGLF and ensure correct normalisation
chiGB = (
(core_profiles.Ai * CONSTANTS.mp) ** 0.5
/ (CONSTANTS.qe * geo.B0) ** 2
* (Ti.face_value() * CONSTANTS.keV2J) ** 1.5
/ lref
)
Copy link
Collaborator Author

@theo-brown theo-brown Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this definition (chiGB) correct / consistent with TGLF?

Copy link
Collaborator

@jcitrin jcitrin Nov 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use quasilinear_transport_model.calculate_chiGB , where you can set a b_unit and reference length as input.

For TGLF probably b_unit=Bunit which you need to define (it's not the same as geo.B0), and aminor for the reference length. Should double check

As always, be careful with sqrt(2) in rho_s_unit and compare to TORAX calculate_chiGB to make sure we didn't miss another input argument (to include sqrt(2) or not)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, I think this might've been a function added in that I hadn't noticed. I'll recheck the tglf docs for the reference field etc.


return TGLFInputs(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably need alphaMHD as well? See quasilinear_transport_model.calculate_alpha

Copy link
Collaborator

@lorenzozanisi lorenzozanisi Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have now included this as P_PRIME_LOC here. This is not treated as an independent dimension in the NN, as it can be computed from gradients, safety factor and BETAE and the NN should be able to figure this out. It does need to be used for running TGLF of course, which is where the code I linked plays a role.

# From QuasilinearInputs
chiGB=chiGB,
Rmin=Rmin,
Rmaj=Rmaj,
Ati=Ati,
Ate=Ate,
Ane=Ane,
# From TGLFInputs
Ti_over_Te=Ti_over_Te,
dRmaj=dRmaj,
q=q,
s_hat=s_hat,
nu_ee=nu_ee,
kappa=kappa,
kappa_shear=kappa_shear,
delta=delta,
delta_shear=delta_shear,
beta_e=beta_e,
Zeff=Zeff_face,
)
105 changes: 105 additions & 0 deletions torax/transport_model/tglfnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import chex
import jax.numpy as jnp
from flax import linen as nn


class TGLFNN(nn.Module):
"""A simple MLP with dropout layers, ReLU activation, and outputting a mean and variance."""

hidden_dimension: int
n_hidden_layers: int
dropout: float
input_means: chex.Array
input_stds: chex.Array
output_mean: float
output_std: float

@nn.compact
def __call__(
self,
x,
deterministic: bool = False,
standardise_inputs: bool = True,
standardise_outputs: bool = False,
):
if standardise_inputs:
# Transform to 0 mean and unit variance
x = (x - self.input_means) / self.input_stds

x = nn.Dense(self.hidden_dimension)(x)
x = nn.Dropout(rate=self.dropout, deterministic=deterministic)(x)
x = nn.relu(x)
for _ in range(self.n_hidden_layers):
x = nn.Dense(self.hidden_dimension)(x)
x = nn.Dropout(rate=self.dropout, deterministic=deterministic)(x)
x = nn.relu(x)
mean_and_var = nn.Dense(2)(x)
mean = mean_and_var[..., 0]
var = mean_and_var[..., 1]
var = nn.softplus(var)

if not standardise_outputs:
# Transform back from 0 mean and unit variance
mean = mean * self.output_std + self.output_mean
var = var * self.output_std**2

return jnp.stack([mean, var], axis=-1)


class EnsembleTGLFNN(nn.Module):
"""An ensemble of TGLFNN models."""

input_means: chex.Array
input_stds: chex.Array
output_mean: chex.Array
output_std: chex.Array
n_models: int = 5
hidden_dimension: int = 512
n_hidden_layers: int = 4
dropout: float = 0.05

def setup(
self,
):
self.models = [
TGLFNN(
hidden_dimension=self.hidden_dimension,
n_hidden_layers=self.n_hidden_layers,
dropout=self.dropout,
input_means=self.input_means,
input_stds=self.input_stds,
output_mean=self.output_mean,
output_std=self.output_std,
)
for i in range(self.n_models)
]

def __call__(self, x, *args, **kwargs):
# Shape is batch size x 2 x n_models
outputs = jnp.stack(
[model(x, *args, **kwargs) for model in self.models], axis=-1
)
# Shape is batch_size
mean = jnp.mean(outputs[:, 0, :], axis=-1)
aleatoric_uncertainty = jnp.mean(outputs[:, 1, :], axis=-1)
epistemic_uncertainty = jnp.var(outputs[:, 0, :], axis=-1)
return jnp.stack([mean, aleatoric_uncertainty + epistemic_uncertainty], axis=-1)

def get_params_from_pytorch_state_dict(self, pytorch_state_dict: dict):
params = {}
for i in range(self.n_models):
model_dict = {}
for j in range(self.n_hidden_layers + 2): # +2 for input and output layers
# j*3 to skip dropout and activation
layer_dict = {
"kernel": jnp.array(
pytorch_state_dict[f"models.{i}.model.{j*3}.weight"]
).T,
"bias": jnp.array(
pytorch_state_dict[f"models.{i}.model.{j*3}.bias"]
).T,
}
model_dict[f"Dense_{j}"] = layer_dict
params[f"models_{i}"] = model_dict

return params
Loading