Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ShiftVector{S,D,T} type as a generic type, which KPoint{D,T} is an alias of #206

Merged
merged 11 commits into from
Nov 29, 2023
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- `eachvertex` iterator for the vertices of the parallelepiped representation of a unit cell.
- `StateDensity{T}` type which combines `EnergyOccupancy{T}` with a density of states value at the
energy provided.
- `ByCoordinate{D}` traits: `ByCartesianCoordinate{D}` and `ByFractionalCoordinate{D}`.
- `ShiftVector{S,D,T} <: StaticVector{D,T}` type describing the shift of a lattice or data defined
on it with respect to the origin, along with an optional weight parameter.

### Changed
- The Types section of the documentation has been split up into separate sections for lattice
basis vectors, atoms and crystal representations, and data grids.
- `KPoint{D}` is now `KPoint{D,T}`, which is an alias for `ShiftVector{ByReciprocalSpace,D,T}`.

### Fixed
- The default definition of `Electrum.DataSpace(x)` is now `DataSpace(typeof(x))`, not
Expand Down
1 change: 1 addition & 0 deletions docs/src/api/data.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Electrum.FFTBins

## k-points
```@docs
Electrum.ShiftVector
Electrum.KPoint
Electrum.KPointMesh
Electrum.nkpt
Expand Down
7 changes: 5 additions & 2 deletions src/Electrum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ include("lattices.jl")
export RealBasis, ReciprocalBasis, AbstractBasis
export eachvertex, basis, lengths, volume, angles_cos, angles_rad, angles_deg, gram, isdiag, qr,
triangularize, maxHKLindex
include("vectors.jl")
export ShiftVector, KPoint
export weight
# Methods and structs for working with atomic positions
include("atoms.jl")
export NamedAtom, AbstractAtomPosition, FractionalAtomPosition, CartesianAtomPosition,
Expand All @@ -127,8 +130,8 @@ export AbstractCrystal, Crystal, CrystalWithDatasets
export data, generators, set_transform!
# Weighed k-points and k-point meshes
include("data/kpoints.jl")
export KPoint, KPointMesh
export weight, nkpt
export KPointMesh
export nkpt
# Energy/occupancy pairs
include("data/energies.jl")
export AbstractEnergyData, EnergyOccupancy, StateDensity, EnergiesOccupancies
Expand Down
53 changes: 0 additions & 53 deletions src/data/kpoints.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,3 @@
"""
KPoint{D,T<:Real} <: StaticVector{D,T}

Stores a k-point as reduced reciprocal space coordiantes with an associated weight that corresponds
to the number of symmetry-equivalent k-points, stored as an integer.
"""
struct KPoint{D,T<:Real} <: StaticVector{D,T}
point::SVector{D,T}
weight::Int
KPoint{D,T}(pt::StaticVector, wt::Integer = 1) where {D,T} = new(pt, wt)
end

# Needed to resolve method ambiguities
KPoint{D,T}(::StaticArray, ::Integer = 1) where {D,T} = error("Argument must be a vector.")
KPoint{D}(::StaticArray, ::Integer = 1) where D = error("Argument must be a vector.")
KPoint(::StaticArray, ::Integer = 1) = error("Argument must be a vector.")

KPoint{D}(pt::StaticVector, wt::Integer = 1) where D = KPoint{D,eltype(pt)}(pt, wt)
KPoint(pt::StaticVector, wt::Integer = 1) = KPoint{length(pt),eltype(pt)}(pt, wt)

KPoint{D,T}(pt::AbstractVector, wt::Integer = 1) where {D,T} = KPoint(SVector{D,T}(pt), wt)
KPoint{D}(pt::AbstractVector, wt::Integer = 1) where D = KPoint(SVector{D}(pt), wt)

KPoint(pt::Real...; weight::Integer = 1) = KPoint(SVector(pt), weight)

Base.hash(k::KPoint, h::UInt) = hash(k.point, hash(k.weight, h))
Base.:(==)(k1::KPoint, k2::KPoint) = k1.point == k2.point && k1.weight == k2.weight

Base.IndexStyle(::Type{<:KPoint}) = IndexLinear()
Base.getindex(k::KPoint, i::Int) = k.point[i]

Tuple(k::KPoint) = Tuple(k.point)
# Base.convert(T::Type{<:AbstractVector}, k::KPoint) = convert(T, k.point)

Base.zero(::Type{KPoint{D}}) where D = KPoint(zero(SVector{D,Bool}))
Base.zero(::Type{KPoint{D,T}}) where {D,T} = KPoint(zero(SVector{D,T}))

"""
weight(k::KPoint) -> Int

Returns the weight associated with a k-point.
"""
weight(k::KPoint) = k.weight

# TODO: can we implement a remainder that excludes -0.5?
"""
truncate(k::KPoint) -> KPoint

Moves a k-point so that its values lie within the range [-1/2, 1/2]. The weight is preserved.
"""
Base.truncate(k::KPoint) = KPoint(rem.(k.point, 1, RoundNearest), k.weight)

#---Generated lists of k-points--------------------------------------------------------------------#
"""
KPointMesh{D,T} <: AbstractVector{KPoint{D,T}}

Expand Down
10 changes: 1 addition & 9 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,7 @@ function Base.show(io::IO, ::MIME"text/plain", l::AbstractAtomList; kwargs...)
end
end

#---Types from data/kpoints.jl========-------------------------------------------------------------#

Base.summary(io::IO, k::KPoint) = print(io, typeof(k), " with weight ", k.weight)

function Base.show(io::IO, k::KPoint)
print(io, KPoint, '(')
join(io, k.point, ", ")
print(io, ", weight = ", weight(k), ')')
end
#---Types from data/kpoints.jl---------------------------------------------------------------------#

function Base.summary(io::IO, k::KPointMesh)
print(io, length(k), "-element ", typeof(k), " (total weight ", sum(weight.(k)), ')')
Expand Down
96 changes: 96 additions & 0 deletions src/vectors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#---Shift vectors----------------------------------------------------------------------------------#
"""
ShiftVector{S<:BySpace,D,T} <: StaticVector{D,T}

A vector in fractional coordinates representing a shift of a lattice or lattice dataset from the
origin. This wraps a `SVector{D,T}` with an optional weight parameter of type `T` that may be useful
when working with symmetrical structures or k-points in the irreducible Brillouin zone. If it is not
explicitly set, it defaults to 1.

If constructors do not explicitly reference an element type, the element type is automatically
inferred by promoting the types of the arguments.

# Type aliases

`ShiftVector` is a general way of working with vectors which shift a lattice or data within it, and
for this reason we define an alias for representing k-points:

const KPoint = ShiftVector{ByReciprocalSpace}
"""
struct ShiftVector{S<:BySpace,D,T<:Real} <: StaticVector{D,T}
vector::SVector{D,T}
weight::T
function ShiftVector{S,D,T}(vector::StaticVector, weight::Real = oneunit(T)) where {S,D,T}
return new(vector, weight)
end
end

const KPoint = ShiftVector{ByReciprocalSpace}
@doc (@doc ShiftVector) KPoint

# Needed to resolve method ambiguities
ShiftVector{S,D,T}(::StaticArray, ::Real = 1) where {S,D,T} = error("Argument must be a vector.")
ShiftVector{S,D}(::StaticArray, ::Real = 1) where {S,D} = error("Argument must be a vector.")
ShiftVector{S}(::StaticArray, ::Real = 1) where S = error("Argument must be a vector.")

function ShiftVector{S,D}(vector::StaticVector, weight::Real = true) where {S,D}
T = promote_type(eltype(vector), typeof(weight))
return ShiftVector{S,D,T}(vector, weight)
end

function ShiftVector{S}(vector::StaticVector{D}, weight::Real = true) where {S,D}
T = promote_type(eltype(vector), typeof(weight))
return ShiftVector{S,D,T}(vector, weight)
end

function ShiftVector{S,D,T}(vector::AbstractVector, weight::Real = true) where {S,D,T}
return ShiftVector{S,D,T}(SVector{D}(vector), weight)
end

function ShiftVector{S,D}(vector::AbstractVector, weight::Real = true) where {S,D}
T = promote_type(eltype(vector), typeof(weight))
return ShiftVector{S,D,T}(SVector{D}(vector), weight)
end

ShiftVector{S}(coord::Real...; weight::Real = 1) where S = ShiftVector{S}(SVector(coord), weight)

Base.hash(s::ShiftVector, h::UInt) = hash(s.vector, hash(s.weight, h))

function Base.:(==)(u::ShiftVector{S1}, v::ShiftVector{S2}) where {S1,S2}
return (S1 === S2 && u.vector == v.vector && u.weight == v.weight)
end

Base.IndexStyle(::Type{<:ShiftVector}) = IndexLinear()
Base.getindex(s::ShiftVector, i::Int) = s.vector[i]

Base.Tuple(s::ShiftVector) = Tuple(s.vector)

Base.zero(::Type{ShiftVector{S,D}}) where {S,D} = ShiftVector{S}(zero(SVector{D,Bool}))
Base.zero(::Type{ShiftVector{S,D,T}}) where {S,D,T} = ShiftVector{S}(zero(SVector{D,T}))

"""
weight(k::ShiftVector{S,D,T}) -> T

Returns the weight associated with a `ShiftVector`.
"""
weight(s::ShiftVector) = s.weight

# TODO: can we implement a remainder that excludes -0.5?
"""
truncate(s::ShiftVector) -> ShiftVector

Moves a `ShiftVector` so that its values lie within the range [-1/2, 1/2]. The weight is preserved.
"""
Base.truncate(s::ShiftVector) = (typeof(s))(rem.(s.vector, 1, RoundNearest), s.weight)

DataSpace(::Type{<:ShiftVector{S,D}}) where {S,D} = S{D}()
ByCoordinate(::Type{<:ShiftVector{S,D}}) where {S,D} = ByFractionalCoordinate{D}()

Base.summary(io::IO, s::ShiftVector) = print(io, typeof(s), " with weight ", s.weight)
Base.show(io::IO, s::ShiftVector) = print(io, typeof(s), '(', s.vector, ", ", s.weight, ')')

Check warning on line 90 in src/vectors.jl

View check run for this annotation

Codecov / codecov/patch

src/vectors.jl#L89-L90

Added lines #L89 - L90 were not covered by tests

function Base.show(io::IO, k::KPoint)
print(io, KPoint, '(')
join(io, k.vector, ", ")
print(io, ", weight = ", weight(k), ')')

Check warning on line 95 in src/vectors.jl

View check run for this annotation

Codecov / codecov/patch

src/vectors.jl#L92-L95

Added lines #L92 - L95 were not covered by tests
end
26 changes: 24 additions & 2 deletions test/kpoints.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
@testset "k-points" begin
@testset "Shift vectors" begin
# Equality and hashing
@test KPoint(0, 0, 0, weight = 1) == KPoint(0, 0, 0, weight = 1)
@test KPoint(0, 0, 0, weight = 1) != KPoint(0, 0, 0, weight = 2)
@test hash(KPoint(0, 0, 0, weight = 1)) == hash(KPoint(0, 0, 0, weight = 1))
@test hash(KPoint(0, 0, 0, weight = 1)) != hash(KPoint(0, 0, 0, weight = 2))
@test zero(ShiftVector{Electrum.ByRealSpace,3}) != zero(KPoint{3})
@test zero(KPoint{3,Float32}) == zero(KPoint{3,Int})
# Constructors and zero k-point
@test zero(KPoint{3}) === zero(KPoint{3,Bool})
@test zero(KPoint{3}) === KPoint(false, false, false; weight = true)
@test zero(KPoint{3,Int}) === KPoint(0, 0, 0, weight = 1)
@test zero(KPoint{3,Int}) === KPoint(SVector{3}(0, 0, 0), 1)
@test zero(KPoint{3,Float64}) === KPoint{3}(SVector{3}(0, 0, 0), 1.0)
@test zero(KPoint{3,Float32}) === KPoint{3,Float32}(SVector{3}(0, 0, 0), 1.0)
@test zero(KPoint{3,Float64}) === KPoint{3}([0, 0, 0], 1.0)
@test zero(KPoint{3,Float32}) === KPoint{3,Float32}([0, 0, 0], 1.0)
@test_throws Exception KPoint(SMatrix{3,1}([0, 0, 0]))
@test_throws Exception KPoint{3}(SMatrix{3,1}([0, 0, 0]))
@test_throws Exception KPoint{3,Int}(SMatrix{3,1}([0, 0, 0]))
# Traits
@test Electrum.DataSpace(zero(KPoint{3})) === Electrum.ByReciprocalSpace{3}()
@test Electrum.ByCoordinate(zero(KPoint{3})) === Electrum.ByFractionalCoordinate{3}()
# Truncation
# TODO: see note for trunc() in Electrum.jl/src/kpoints.jl
# TODO: see note for trunc() in src/vectors.jl
@test truncate(KPoint(1, 2, 3)) == KPoint(0, 0, 0)
@test truncate(KPoint(0.5, -0.5, 1.5)) == KPoint(0.5, -0.5, -0.5)
@test truncate(KPoint(0.5, -0.5 + eps(Float64), 1.5)) == KPoint(0.5, -0.5 + eps(Float64), -0.5)
@test truncate(KPoint(0.5, -0.5 - eps(Float64), 1.5)) == KPoint(0.5, 0.5 - eps(Float64), -0.5)
# Length measurement
@test length(KPoint(1, 2, 3, 4, 5)) == 5
@test length(KPoint{2}) == 2
@test size(KPoint(1, 2, 3, 4, 5)) == (5,)
@test size(KPoint{2}) == (2,)
@test convert(Vector, KPoint(0.1, 0.2, 0.3)) == [0.1, 0.2, 0.3]
@test convert(Vector, KPoint(0.1, 0.2, 0.3)) isa Vector{<:Real}
@test convert(SVector, KPoint(0.1, 0.2, 0.3)) === SVector{3}(0.1, 0.2, 0.3)
Expand Down
Loading