Skip to content

Remove N from def of Cell to make it more type stable #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 6, 2022
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@ version = "0.5.3"
[deps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructHelpers = "4093c41a-2008-41fd-82b8-e3f9d02b504f"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
spglib_jll = "ac4a9f1e-bdb2-5204-990c-47c8b2f70d4e"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0"
StructHelpers = "0.1"
julia = "1.3"
UnPack = "1.0"
spglib_jll = "1.14"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
1 change: 0 additions & 1 deletion src/Spglib.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module Spglib

using spglib_jll: libsymspg
using UnPack: @unpack

# All public methods
export get_symmetry,
Expand Down
49 changes: 32 additions & 17 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using StaticArrays: MMatrix, MVector
using StructHelpers: @batteries

export Cell, Dataset, SpacegroupType, basis_vectors
export Cell, Dataset, SpacegroupType, basis_vectors, natoms

"""
Cell(lattice, positions, types, magmoms=zeros(length(types)))
Expand All @@ -11,31 +11,45 @@ The basic input data type of `Spglib`.
Lattice parameters `lattice` are given by a ``3×3`` matrix with floating point values,
where ``𝐚``, ``𝐛``, and ``𝐜`` are given as columns.
Fractional atomic positions `positions` are given
by a ``3×N`` matrix with floating point values, where ``N`` is the number of atoms.
by a vector of ``N`` vectors with floating point values, where ``N`` is the number of atoms.
Numbers to distinguish atomic species `types` are given by a list of ``N`` integers.
The collinear polarizations `magmoms` only work with `get_symmetry` and are given
as a list of ``N`` floating point values.
"""
struct Cell{N,L,P,T,M}
struct Cell{L,P,T,M}
lattice::MMatrix{3,3,L,9}
positions::MMatrix{3,N,P}
types::MVector{N,T}
magmoms::MVector{N,M}
positions::Vector{MVector{3,P}}
types::Vector{T}
magmoms::Vector{M}
end
function Cell(lattice, positions, types, magmoms = zeros(length(types)))
if lattice isa AbstractVector
lattice = hcat(lattice...)
if !(lattice isa AbstractMatrix)
lattice = reduce(hcat, lattice) # Use `reduce` can make it type stable
end
if positions isa AbstractVector
positions = hcat(positions...)
N = length(types)
if positions isa AbstractMatrix
P = eltype(positions)
if size(positions) == (3, 3)
error("ambiguous `positions` size 3×3! Use a vector of `Vector`s instead!")
elseif size(positions) == (3, N)
positions = collect(eachcol(positions))
elseif size(positions) == (N, 3)
positions = collect(eachrow(positions))
else
throw(DimensionMismatch( "the `positions` has a different number of atoms from the `types`!"))
end
else # positions isa AbstractVector or a Tuple
P = eltype(Base.promote_typeof(positions...))
positions = collect(map(MVector{3,P}, positions))
end
N, L, P, T, M =
length(types), eltype(lattice), eltype(positions), eltype(types), eltype(magmoms)
return Cell{N,L,P,T,M}(lattice, positions, types, magmoms)
L, T, M = eltype(lattice), eltype(types), eltype(magmoms)
return Cell{L,P,T,M}(lattice, positions, types, magmoms)
end

@batteries Cell eq = true hash = true

natoms(cell::Cell) = length(cell.types)

"""
basis_vectors(cell::Cell)

Expand All @@ -48,15 +62,16 @@ end

# This is an internal function, do not export!
function _expand_cell(cell::Cell)
@unpack lattice, positions, types, magmoms = cell
lattice, positions, types, magmoms =
cell.lattice, cell.positions, cell.types, cell.magmoms
# Reference: https://github.com/mdavezac/spglib.jl/blob/master/src/spglib.jl#L32-L35 and https://github.com/spglib/spglib/blob/444e061/python/spglib/spglib.py#L953-L975
clattice = Base.cconvert(Matrix{Cdouble}, lattice) |> transpose
cpositions = Base.cconvert(Matrix{Cdouble}, positions)
clattice = Base.cconvert(Matrix{Cdouble}, transpose(lattice))
cpositions = Base.cconvert(Matrix{Cdouble}, reduce(hcat, positions))
ctypes = Cint[findfirst(isequal(u), unique(types)) for u in types]
if magmoms !== nothing
magmoms = Base.cconvert(Vector{Cdouble}, magmoms)
end
return Cell(clattice, cpositions, ctypes, magmoms)
return clattice, cpositions, ctypes, magmoms
end

# This is an internal type, do not export!
Expand Down
2 changes: 1 addition & 1 deletion src/reciprocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function get_ir_reciprocal_mesh(
end
@assert all(isone(x) || iszero(x) for x in is_shift)
# Prepare for input
@unpack lattice, positions, types = _expand_cell(cell)
lattice, positions, types = _expand_cell(cell)
mesh = Base.cconvert(Vector{Cint}, mesh)
is_shift = Base.cconvert(Vector{Cint}, is_shift)
is_time_reversal = Base.cconvert(Cint, is_time_reversal)
Expand Down
2 changes: 1 addition & 1 deletion src/standardize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function standardize_cell(
no_idealize = false,
symprec = 1e-5,
)
@unpack lattice, positions, types = _expand_cell(cell)
lattice, positions, types = _expand_cell(cell)
to_primitive = Base.cconvert(Cint, to_primitive)
no_idealize = Base.cconvert(Cint, no_idealize)
num_atom = Base.cconvert(Cint, length(types))
Expand Down
14 changes: 7 additions & 7 deletions src/symmetry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function get_symmetry!(
if !(size(rotation, 1) == size(rotation, 2) == size(translation, 1) == 3)
throw(DimensionMismatch("`rotation` & `translation` don't have the right size!"))
end
@unpack lattice, positions, types = _expand_cell(cell)
lattice, positions, types = _expand_cell(cell)
rotation = Base.cconvert(Array{Cint,3}, rotation)
translation = Base.cconvert(Matrix{Cdouble}, translation)
max_size = Base.cconvert(Cint, size(rotation, 3))
Expand Down Expand Up @@ -82,7 +82,7 @@ function get_symmetry_with_collinear_spin!(
cell::Cell,
symprec = 1e-5,
) where {T}
@unpack lattice, positions, types, magmoms = _expand_cell(cell)
lattice, positions, types, magmoms = _expand_cell(cell)
rotation = Base.cconvert(Array{Cint,3}, rotation)
translation = Base.cconvert(Matrix{Cdouble}, translation)
equivalent_atoms = Base.cconvert(Vector{Cint}, equivalent_atoms)
Expand Down Expand Up @@ -220,7 +220,7 @@ end
Return the exact number of symmetry operations. An error is thrown when it fails.
"""
function get_multiplicity(cell::Cell, symprec = 1e-5)
@unpack lattice, positions, types = _expand_cell(cell)
lattice, positions, types = _expand_cell(cell)
num_atom = Base.cconvert(Cint, length(types))
num_sym = ccall(
(:spg_get_multiplicity, libsymspg),
Expand All @@ -244,7 +244,7 @@ end
Search symmetry operations of an input unit cell structure.
"""
function get_dataset(cell::Cell, symprec = 1e-5)
@unpack lattice, positions, types = _expand_cell(cell)
lattice, positions, types = _expand_cell(cell)
num_atom = Base.cconvert(Cint, length(types))
ptr = ccall(
(:spg_get_dataset, libsymspg),
Expand All @@ -270,7 +270,7 @@ end
Search symmetry operations of an input unit cell structure, using a given Hall number.
"""
function get_dataset_with_hall_number(cell::Cell, hall_number::Integer, symprec = 1e-5)
@unpack lattice, positions, types = _expand_cell(cell)
lattice, positions, types = _expand_cell(cell)
num_atom = Base.cconvert(Cint, length(types))
hall_number = Base.cconvert(Cint, hall_number)
ptr = ccall(
Expand Down Expand Up @@ -332,7 +332,7 @@ end
Return the space group type in Hermann–Mauguin (international) notation.
"""
function get_international(cell::Cell, symprec = 1e-5)
@unpack lattice, positions, types = _expand_cell(cell)
lattice, positions, types = _expand_cell(cell)
symbol = Vector{Cchar}(undef, 11)
exitcode = ccall(
(:spg_get_international, libsymspg),
Expand All @@ -357,7 +357,7 @@ end
Return the space group type in Schoenflies notation.
"""
function get_schoenflies(cell::Cell, symprec = 1e-5)
@unpack lattice, positions, types = _expand_cell(cell)
lattice, positions, types = _expand_cell(cell)
symbol = Vector{Cchar}(undef, 7)
exitcode = ccall(
(:spg_get_schoenflies, libsymspg),
Expand Down
8 changes: 4 additions & 4 deletions test/standardize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
-1.999719735 1.999719735 0.0
0.0 0.0 8.57154746
]
@test primitive_cell.positions ≈ [ # Python results
@test reduce(hcat, primitive_cell.positions) ≈ [ # Python results
0.15311561 0.34688439 0.65311561 0.84688439
0.84688439 0.65311561 0.34688439 0.15311561
0.1203133 0.6203133 0.3796867 0.8796867
Expand Down Expand Up @@ -69,7 +69,7 @@ end
-1.99971973 1.99971973 0.0
0.0 0.0 8.57154746
]
@test primitive_cell.positions ≈ [ # Python results
@test reduce(hcat, primitive_cell.positions) ≈ [ # Python results
0.15311561 0.34688439 0.65311561 0.84688439
0.84688439 0.65311561 0.34688439 0.15311561
0.1203133 0.6203133 0.3796867 0.8796867
Expand Down Expand Up @@ -105,7 +105,7 @@ end
2.0 -2.0 2.0
2.0 2.0 -2.0
]
@test new_cell.positions ≈ [0.0 0.0 0.0]'
@test new_cell.positions ≈ [[0.0, 0.0, 0.0]]
@test new_cell.types == [1]
end
@testset "Test `refine_cell`" begin
Expand All @@ -123,7 +123,7 @@ end
0.0 4.0 0.0
0.0 0.0 4.0
]
@test new_cell.positions ≈ [
@test reduce(hcat, new_cell.positions) ≈ [
0.0 0.5
0.0 0.5
0.0 0.5
Expand Down