Skip to content

Commit 50183cc

Browse files
authored
Merge pull request #94 from singularitti:cell
Remove `N` from def of `Cell` to make it more type stable
2 parents ce8a5e6 + 035243c commit 50183cc

File tree

7 files changed

+48
-36
lines changed

7 files changed

+48
-36
lines changed

Project.toml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,16 @@ version = "0.5.3"
66
[deps]
77
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
88
StructHelpers = "4093c41a-2008-41fd-82b8-e3f9d02b504f"
9-
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
109
spglib_jll = "ac4a9f1e-bdb2-5204-990c-47c8b2f70d4e"
1110

12-
[extras]
13-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
14-
1511
[compat]
1612
StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0"
1713
StructHelpers = "0.1"
1814
julia = "1.3"
19-
UnPack = "1.0"
2015
spglib_jll = "1.14"
2116

17+
[extras]
18+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
19+
2220
[targets]
2321
test = ["Test"]

src/Spglib.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module Spglib
22

33
using spglib_jll: libsymspg
4-
using UnPack: @unpack
54

65
# All public methods
76
export get_symmetry,

src/model.jl

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using StaticArrays: MMatrix, MVector
22
using StructHelpers: @batteries
33

4-
export Cell, Dataset, SpacegroupType, basis_vectors
4+
export Cell, Dataset, SpacegroupType, basis_vectors, natoms
55

66
"""
77
Cell(lattice, positions, types, magmoms=zeros(length(types)))
@@ -11,31 +11,45 @@ The basic input data type of `Spglib`.
1111
Lattice parameters `lattice` are given by a ``3×3`` matrix with floating point values,
1212
where ``𝐚``, ``𝐛``, and ``𝐜`` are given as columns.
1313
Fractional atomic positions `positions` are given
14-
by a ``3×N`` matrix with floating point values, where ``N`` is the number of atoms.
14+
by a vector of ``N`` vectors with floating point values, where ``N`` is the number of atoms.
1515
Numbers to distinguish atomic species `types` are given by a list of ``N`` integers.
1616
The collinear polarizations `magmoms` only work with `get_symmetry` and are given
1717
as a list of ``N`` floating point values.
1818
"""
19-
struct Cell{N,L,P,T,M}
19+
struct Cell{L,P,T,M}
2020
lattice::MMatrix{3,3,L,9}
21-
positions::MMatrix{3,N,P}
22-
types::MVector{N,T}
23-
magmoms::MVector{N,M}
21+
positions::Vector{MVector{3,P}}
22+
types::Vector{T}
23+
magmoms::Vector{M}
2424
end
2525
function Cell(lattice, positions, types, magmoms = zeros(length(types)))
26-
if lattice isa AbstractVector
27-
lattice = hcat(lattice...)
26+
if !(lattice isa AbstractMatrix)
27+
lattice = reduce(hcat, lattice) # Use `reduce` can make it type stable
2828
end
29-
if positions isa AbstractVector
30-
positions = hcat(positions...)
29+
N = length(types)
30+
if positions isa AbstractMatrix
31+
P = eltype(positions)
32+
if size(positions) == (3, 3)
33+
error("ambiguous `positions` size 3×3! Use a vector of `Vector`s instead!")
34+
elseif size(positions) == (3, N)
35+
positions = collect(eachcol(positions))
36+
elseif size(positions) == (N, 3)
37+
positions = collect(eachrow(positions))
38+
else
39+
throw(DimensionMismatch( "the `positions` has a different number of atoms from the `types`!"))
40+
end
41+
else # positions isa AbstractVector or a Tuple
42+
P = eltype(Base.promote_typeof(positions...))
43+
positions = collect(map(MVector{3,P}, positions))
3144
end
32-
N, L, P, T, M =
33-
length(types), eltype(lattice), eltype(positions), eltype(types), eltype(magmoms)
34-
return Cell{N,L,P,T,M}(lattice, positions, types, magmoms)
45+
L, T, M = eltype(lattice), eltype(types), eltype(magmoms)
46+
return Cell{L,P,T,M}(lattice, positions, types, magmoms)
3547
end
3648

3749
@batteries Cell eq = true hash = true
3850

51+
natoms(cell::Cell) = length(cell.types)
52+
3953
"""
4054
basis_vectors(cell::Cell)
4155
@@ -48,15 +62,16 @@ end
4862

4963
# This is an internal function, do not export!
5064
function _expand_cell(cell::Cell)
51-
@unpack lattice, positions, types, magmoms = cell
65+
lattice, positions, types, magmoms =
66+
cell.lattice, cell.positions, cell.types, cell.magmoms
5267
# Reference: https://github.com/mdavezac/spglib.jl/blob/master/src/spglib.jl#L32-L35 and https://github.com/spglib/spglib/blob/444e061/python/spglib/spglib.py#L953-L975
53-
clattice = Base.cconvert(Matrix{Cdouble}, lattice) |> transpose
54-
cpositions = Base.cconvert(Matrix{Cdouble}, positions)
68+
clattice = Base.cconvert(Matrix{Cdouble}, transpose(lattice))
69+
cpositions = Base.cconvert(Matrix{Cdouble}, reduce(hcat, positions))
5570
ctypes = Cint[findfirst(isequal(u), unique(types)) for u in types]
5671
if magmoms !== nothing
5772
magmoms = Base.cconvert(Vector{Cdouble}, magmoms)
5873
end
59-
return Cell(clattice, cpositions, ctypes, magmoms)
74+
return clattice, cpositions, ctypes, magmoms
6075
end
6176

6277
# This is an internal type, do not export!

src/reciprocal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function get_ir_reciprocal_mesh(
3737
end
3838
@assert all(isone(x) || iszero(x) for x in is_shift)
3939
# Prepare for input
40-
@unpack lattice, positions, types = _expand_cell(cell)
40+
lattice, positions, types = _expand_cell(cell)
4141
mesh = Base.cconvert(Vector{Cint}, mesh)
4242
is_shift = Base.cconvert(Vector{Cint}, is_shift)
4343
is_time_reversal = Base.cconvert(Cint, is_time_reversal)

src/standardize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function standardize_cell(
1414
no_idealize = false,
1515
symprec = 1e-5,
1616
)
17-
@unpack lattice, positions, types = _expand_cell(cell)
17+
lattice, positions, types = _expand_cell(cell)
1818
to_primitive = Base.cconvert(Cint, to_primitive)
1919
no_idealize = Base.cconvert(Cint, no_idealize)
2020
num_atom = Base.cconvert(Cint, length(types))

src/symmetry.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function get_symmetry!(
4141
if !(size(rotation, 1) == size(rotation, 2) == size(translation, 1) == 3)
4242
throw(DimensionMismatch("`rotation` & `translation` don't have the right size!"))
4343
end
44-
@unpack lattice, positions, types = _expand_cell(cell)
44+
lattice, positions, types = _expand_cell(cell)
4545
rotation = Base.cconvert(Array{Cint,3}, rotation)
4646
translation = Base.cconvert(Matrix{Cdouble}, translation)
4747
max_size = Base.cconvert(Cint, size(rotation, 3))
@@ -82,7 +82,7 @@ function get_symmetry_with_collinear_spin!(
8282
cell::Cell,
8383
symprec = 1e-5,
8484
) where {T}
85-
@unpack lattice, positions, types, magmoms = _expand_cell(cell)
85+
lattice, positions, types, magmoms = _expand_cell(cell)
8686
rotation = Base.cconvert(Array{Cint,3}, rotation)
8787
translation = Base.cconvert(Matrix{Cdouble}, translation)
8888
equivalent_atoms = Base.cconvert(Vector{Cint}, equivalent_atoms)
@@ -220,7 +220,7 @@ end
220220
Return the exact number of symmetry operations. An error is thrown when it fails.
221221
"""
222222
function get_multiplicity(cell::Cell, symprec = 1e-5)
223-
@unpack lattice, positions, types = _expand_cell(cell)
223+
lattice, positions, types = _expand_cell(cell)
224224
num_atom = Base.cconvert(Cint, length(types))
225225
num_sym = ccall(
226226
(:spg_get_multiplicity, libsymspg),
@@ -244,7 +244,7 @@ end
244244
Search symmetry operations of an input unit cell structure.
245245
"""
246246
function get_dataset(cell::Cell, symprec = 1e-5)
247-
@unpack lattice, positions, types = _expand_cell(cell)
247+
lattice, positions, types = _expand_cell(cell)
248248
num_atom = Base.cconvert(Cint, length(types))
249249
ptr = ccall(
250250
(:spg_get_dataset, libsymspg),
@@ -270,7 +270,7 @@ end
270270
Search symmetry operations of an input unit cell structure, using a given Hall number.
271271
"""
272272
function get_dataset_with_hall_number(cell::Cell, hall_number::Integer, symprec = 1e-5)
273-
@unpack lattice, positions, types = _expand_cell(cell)
273+
lattice, positions, types = _expand_cell(cell)
274274
num_atom = Base.cconvert(Cint, length(types))
275275
hall_number = Base.cconvert(Cint, hall_number)
276276
ptr = ccall(
@@ -332,7 +332,7 @@ end
332332
Return the space group type in Hermann–Mauguin (international) notation.
333333
"""
334334
function get_international(cell::Cell, symprec = 1e-5)
335-
@unpack lattice, positions, types = _expand_cell(cell)
335+
lattice, positions, types = _expand_cell(cell)
336336
symbol = Vector{Cchar}(undef, 11)
337337
exitcode = ccall(
338338
(:spg_get_international, libsymspg),
@@ -357,7 +357,7 @@ end
357357
Return the space group type in Schoenflies notation.
358358
"""
359359
function get_schoenflies(cell::Cell, symprec = 1e-5)
360-
@unpack lattice, positions, types = _expand_cell(cell)
360+
lattice, positions, types = _expand_cell(cell)
361361
symbol = Vector{Cchar}(undef, 7)
362362
exitcode = ccall(
363363
(:spg_get_schoenflies, libsymspg),

test/standardize.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
-1.999719735 1.999719735 0.0
2727
0.0 0.0 8.57154746
2828
]
29-
@test primitive_cell.positions [ # Python results
29+
@test reduce(hcat, primitive_cell.positions) [ # Python results
3030
0.15311561 0.34688439 0.65311561 0.84688439
3131
0.84688439 0.65311561 0.34688439 0.15311561
3232
0.1203133 0.6203133 0.3796867 0.8796867
@@ -69,7 +69,7 @@ end
6969
-1.99971973 1.99971973 0.0
7070
0.0 0.0 8.57154746
7171
]
72-
@test primitive_cell.positions [ # Python results
72+
@test reduce(hcat, primitive_cell.positions) [ # Python results
7373
0.15311561 0.34688439 0.65311561 0.84688439
7474
0.84688439 0.65311561 0.34688439 0.15311561
7575
0.1203133 0.6203133 0.3796867 0.8796867
@@ -105,7 +105,7 @@ end
105105
2.0 -2.0 2.0
106106
2.0 2.0 -2.0
107107
]
108-
@test new_cell.positions [0.0 0.0 0.0]'
108+
@test new_cell.positions [[0.0, 0.0, 0.0]]
109109
@test new_cell.types == [1]
110110
end
111111
@testset "Test `refine_cell`" begin
@@ -123,7 +123,7 @@ end
123123
0.0 4.0 0.0
124124
0.0 0.0 4.0
125125
]
126-
@test new_cell.positions [
126+
@test reduce(hcat, new_cell.positions) [
127127
0.0 0.5
128128
0.0 0.5
129129
0.0 0.5

0 commit comments

Comments
 (0)