Skip to content

Commit

Permalink
fix homogenization for learn&apply and format
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha Demin committed Sep 17, 2024
1 parent 7a1a776 commit 426515b
Show file tree
Hide file tree
Showing 55 changed files with 436 additions and 1,677 deletions.
2 changes: 1 addition & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ align_pair_arrow = true
always_for_in = true
format_markdown = false
indent = 4
margin = 92
margin = 100
normalize_line_endings = "unix"
remove_extra_newlines = true
short_to_long_function_def = false
Expand Down
10 changes: 1 addition & 9 deletions src/Groebner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,15 @@ const _threaded = Ref(true)

function __init__()
_threaded[] = !(get(ENV, "GROEBNER_NO_THREADED", "") == "1")

# Setup the global logger
_groebner_log_lock[] = ReentrantLock()
logger_update(loglevel=Logging.Info)

nothing
end

###
# Includes

include("utils/logging.jl")
include("utils/invariants.jl")
include("utils/simd.jl")
include("utils/packed.jl")
include("utils/plots.jl")

# Test systems, such as katsura, cyclic, etc
include("utils/examples.jl")
Expand Down Expand Up @@ -176,8 +169,7 @@ export groebner_with_change_matrix
export isgroebner
export normalform

export Lex,
DegLex, DegRevLex, InputOrdering, WeightedOrdering, ProductOrdering, MatrixOrdering
export Lex, DegLex, DegRevLex, InputOrdering, WeightedOrdering, ProductOrdering, MatrixOrdering

@doc read(joinpath(dirname(@__DIR__), "README.md"), String) Groebner

Expand Down
3 changes: 1 addition & 2 deletions src/arithmetic/CompositeNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ Base.isone(ci::CompositeNumber) = all(isone, ci.data)

Base.zero(::Type{CompositeNumber{N, T}}) where {N, T} =
CompositeNumber(ntuple(_ -> zero(T), Val(N)))
Base.one(::Type{CompositeNumber{N, T}}) where {N, T} =
CompositeNumber(ntuple(_ -> one(T), N))
Base.one(::Type{CompositeNumber{N, T}}) where {N, T} = CompositeNumber(ntuple(_ -> one(T), N))

Base.inv(a::CompositeNumber{N, T}) where {N, T} = CompositeNumber(inv.(a.data))

Expand Down
3 changes: 1 addition & 2 deletions src/arithmetic/QQ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
# Arithmetic in the rationals.

# All implementations of arithmetic in the rationals are a subtype of this
abstract type AbstractArithmeticQQ{AccumType, CoeffType} <:
AbstractArithmetic{AccumType, CoeffType} end
abstract type AbstractArithmeticQQ{AccumType, CoeffType} <: AbstractArithmetic{AccumType, CoeffType} end

# Standard arithmetic that uses Base.GMP.MPQ
struct ArithmeticQQ{AccumType, CoeffType} <: AbstractArithmeticQQ{AccumType, CoeffType}
Expand Down
38 changes: 9 additions & 29 deletions src/arithmetic/Zp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ abstract type AbstractArithmetic{AccumType, CoeffType} end
# - divisor(arithmetic) : returns the prime number p
# - mod_p(x, arithmetic) : returns x modulo p
# - inv_mod_p(x, arithmetic) : returns x^(-1) modulo p
abstract type AbstractArithmeticZp{AccumType, CoeffType} <:
AbstractArithmetic{AccumType, CoeffType} end
abstract type AbstractArithmeticZp{AccumType, CoeffType} <: AbstractArithmetic{AccumType, CoeffType} end

###
# ArithmeticZp
Expand Down Expand Up @@ -132,8 +131,7 @@ inv_mod_p(a::T, arithm::SpecializedArithmeticZp{T}) where {T} = invmod(a, diviso

# Exploits the case when the representation of the prime has spare leading bits
# and thus delays reduction modulo a prime
struct DelayedArithmeticZp{AccumType, CoeffType, Add} <:
AbstractArithmeticZp{AccumType, CoeffType}
struct DelayedArithmeticZp{AccumType, CoeffType, Add} <: AbstractArithmeticZp{AccumType, CoeffType}
multiplier::AccumType
shift::UInt8
divisor::AccumType
Expand Down Expand Up @@ -192,8 +190,7 @@ inv_mod_p(a::T, arithm::DelayedArithmeticZp{T}) where {T} = invmod(a, divisor(ar
# CompositeArithmeticZp

# Can operate on several integers at once
struct CompositeArithmeticZp{AccumType, CoeffType, TT} <:
AbstractArithmeticZp{AccumType, CoeffType}
struct CompositeArithmeticZp{AccumType, CoeffType, TT} <: AbstractArithmeticZp{AccumType, CoeffType}
arithmetics::TT

function CompositeArithmeticZp(
Expand All @@ -202,11 +199,7 @@ struct CompositeArithmeticZp{AccumType, CoeffType, TT} <:
p::CompositeNumber{N, CoeffType}
) where {N, AccumType <: CoeffZp, CoeffType <: CoeffZp}
ai = ntuple(i -> SpecializedArithmeticZp(AccumType, CoeffType, p.data[i]), N)
new{
CompositeNumber{N, AccumType},
CompositeNumber{N, CoeffType},
Tuple{map(typeof, ai)...}
}(
new{CompositeNumber{N, AccumType}, CompositeNumber{N, CoeffType}, Tuple{map(typeof, ai)...}}(
ai
)
end
Expand All @@ -227,8 +220,7 @@ end
# SignedArithmeticZp

# Uses signed types and exploits the trick with adding p^2 to negative values
struct SignedArithmeticZp{AccumType, CoeffType} <:
AbstractArithmeticZp{AccumType, CoeffType}
struct SignedArithmeticZp{AccumType, CoeffType} <: AbstractArithmeticZp{AccumType, CoeffType}
p::AccumType
p2::AccumType # = p*p

Expand Down Expand Up @@ -320,8 +312,7 @@ end
###
# FloatingPointArithmeticZp

struct FloatingPointArithmeticZp{AccumType, CoeffType} <:
AbstractArithmeticZp{AccumType, CoeffType}
struct FloatingPointArithmeticZp{AccumType, CoeffType} <: AbstractArithmeticZp{AccumType, CoeffType}
multiplier::AccumType
divisor::AccumType

Expand All @@ -346,12 +337,7 @@ divisor(arithm::FloatingPointArithmeticZp) = arithm.divisor
a - mod.divisor * c # may be fused
end

@inline function fma_mod_p(
a1::T,
a2::C,
a3::T,
mod::FloatingPointArithmeticZp{T, C}
) where {T, C}
@inline function fma_mod_p(a1::T, a2::C, a3::T, mod::FloatingPointArithmeticZp{T, C}) where {T, C}
b = muladd(a1, a2, a3)
mod_p(b, mod)
end
Expand Down Expand Up @@ -431,11 +417,7 @@ function select_arithmetic(

if CoeffType <: CompositeCoeffZp
if hint === :signed || CoeffType <: CompositeNumber{N, T} where {N, T <: Signed}
return SignedCompositeArithmeticZp(
AccumType,
CoeffType,
CoeffType(characteristic)
)
return SignedCompositeArithmeticZp(AccumType, CoeffType, CoeffType(characteristic))
elseif hint === :floating
return FloatingPointCompositeArithmeticZp(
AccumType,
Expand All @@ -457,7 +439,6 @@ function select_arithmetic(

if hint === :delayed
if iszero(leading_zeros(characteristic) - (8 >> 1) * sizeof(AccumType))
@log :warn "Cannot use $hint arithmetic with characteristic $characteristic"
@assert false
end
return DelayedArithmeticZp(AccumType, CoeffType, CoeffType(characteristic))
Expand All @@ -478,8 +459,7 @@ function select_arithmetic(
# looks like it rarely pays off to use delayed modular arithmetic in
# such cases
if !(AccumType === UInt128)
if !using_wide_type_for_coeffs &&
leading_zeros(CoeffType(characteristic)) > 4 ||
if !using_wide_type_for_coeffs && leading_zeros(CoeffType(characteristic)) > 4 ||
using_wide_type_for_coeffs &&
((8 * sizeof(CoeffType)) >> 1) -
(8 * sizeof(CoeffType) - leading_zeros(CoeffType(characteristic))) > 4
Expand Down
31 changes: 7 additions & 24 deletions src/f4/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ end
function basis_well_formed(ring::PolyRing, basis::Basis, hashtable::MonomialHashtable)
(isempty(basis.monoms) || isempty(basis.coeffs)) && error("Basis cannot be empty")
!(basis.n_filled >= basis.n_processed) && error("Basis cannot be empty")
!is_sorted_by_lead_increasing(basis, hashtable) &&
error("Basis elements must be sorted")
!is_sorted_by_lead_increasing(basis, hashtable) && error("Basis elements must be sorted")
for i in basis.n_filled
isempty(basis.monoms[i]) && error("Zero polynomials are not allowed.")
!(length(basis.monoms[i]) == length(basis.coeffs[i])) && error("Beda!")
Expand Down Expand Up @@ -239,8 +238,7 @@ function basis_changematrix_addmul!(
if !haskey(poly, new_monom_id)
poly[new_monom_id] = zero(CoeffType)
end
poly[new_monom_id] =
mod_p(poly[new_monom_id] + ref_cf * AccumType(cf), arithmetic)
poly[new_monom_id] = mod_p(poly[new_monom_id] + ref_cf * AccumType(cf), arithmetic)
end
end
nothing
Expand Down Expand Up @@ -312,10 +310,7 @@ function basis_shallow_copy_with_new_coeffs(
basis.is_redundant,
basis.nonredundant_indices,
basis.divmasks,
basis_changematrix_shallow_copy_with_new_type(
basis.changematrix,
new_sparse_row_coeffs
)
basis_changematrix_shallow_copy_with_new_type(basis.changematrix, new_sparse_row_coeffs)
)
end

Expand All @@ -341,10 +336,7 @@ function basis_deep_copy_with_new_coeffs(
copy(basis.is_redundant),
copy(basis.nonredundant_indices),
copy(basis.divmasks),
basis_changematrix_deep_copy_with_new_type(
basis.changematrix,
new_sparse_row_coeffs
)
basis_changematrix_deep_copy_with_new_type(basis.changematrix, new_sparse_row_coeffs)
)
end

Expand Down Expand Up @@ -412,10 +404,7 @@ end

function basis_make_monic!(
basis::Basis{C},
arithmetic::Union{
FloatingPointCompositeArithmeticZp{A, C},
FloatingPointArithmeticZp{A, C}
},
arithmetic::Union{FloatingPointCompositeArithmeticZp{A, C}, FloatingPointArithmeticZp{A, C}},
changematrix::Bool
) where {A <: Union{CoeffZp, CompositeCoeffZp}, C <: Union{CoeffZp, CompositeCoeffZp}}
cfs = basis.coeffs
Expand Down Expand Up @@ -700,10 +689,7 @@ function basis_standardize!(
perm
end

function basis_get_monoms_by_identifiers(
basis::Basis,
ht::MonomialHashtable{M}
) where {M <: Monom}
function basis_get_monoms_by_identifiers(basis::Basis, ht::MonomialHashtable{M}) where {M <: Monom}
monoms = Vector{Vector{M}}(undef, basis.n_nonredundant)
@inbounds for i in 1:(basis.n_nonredundant)
idx = basis.nonredundant_indices[i]
Expand All @@ -725,10 +711,7 @@ function basis_export_coeffs(basis::Basis{C}) where {C <: Coeff}
coeffs
end

function basis_export_data(
basis::Basis{C},
ht::MonomialHashtable{M}
) where {M <: Monom, C <: Coeff}
function basis_export_data(basis::Basis{C}, ht::MonomialHashtable{M}) where {M <: Monom, C <: Coeff}
exps = basis_get_monoms_by_identifiers(basis, ht)
coeffs = basis_export_coeffs(basis)
exps, coeffs
Expand Down
9 changes: 1 addition & 8 deletions src/f4/f4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,7 @@ function f4_find_multiplied_reducer!(
@invariant quotient_hash == monom_hash(quotient, ht.hash_vector)

hashtable_resize_if_needed!(ht, length(poly))
row = matrix_polynomial_multiple_to_row!(
matrix,
symbol_ht,
ht,
quotient_hash,
quotient,
poly
)
row = matrix_polynomial_multiple_to_row!(matrix, symbol_ht, ht, quotient_hash, quotient, poly)
matrix.nrows_filled_upper += 1
row_id = matrix.nrows_filled_upper
@inbounds matrix.upper_rows[row_id] = row
Expand Down
3 changes: 1 addition & 2 deletions src/f4/hashtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ end
# Initialization

hashtable_resize_threshold() = 0.4
hashtable_needs_resize(size, load, added) =
(load + added) / size > hashtable_resize_threshold()
hashtable_needs_resize(size, load, added) = (load + added) / size > hashtable_resize_threshold()

function hashtable_initialize(
ring::PolyRing{Ord},
Expand Down
40 changes: 10 additions & 30 deletions src/f4/learn_apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ function f4_initialize_structs_with_trace(
sort_input=sort_input
)

trace =
trace_initialize(ring, basis_deepcopy(basis), basis, hashtable, permutation, params)
trace = trace_initialize(ring, basis_deepcopy(basis), basis, hashtable, permutation, params)

trace, basis, pairset, hashtable, permutation
end
Expand Down Expand Up @@ -65,8 +64,7 @@ function f4_reduction_learn!(
params;
batched_ht_insert=true
)
pivot_indices =
map(i -> Int32(basis.monoms[basis.n_processed + i][1]), 1:(matrix.npivots))
pivot_indices = map(i -> Int32(basis.monoms[basis.n_processed + i][1]), 1:(matrix.npivots))
push!(trace.matrix_pivot_indices, pivot_indices)
matrix_pivot_signature =
matrix_compute_pivot_signature(basis.monoms, basis.n_processed + 1, matrix.npivots)
Expand Down Expand Up @@ -105,8 +103,7 @@ function f4_reducegb_learn!(
matrix.upper_to_mult[row_idx] = hashtable_insert!(ht, etmp)
symbol_ht.labels[uprows[row_idx][1]] = UNKNOWN_PIVOT_COLUMN
end
trace.nonredundant_indices_before_reduce =
basis.nonredundant_indices[1:(basis.n_nonredundant)]
trace.nonredundant_indices_before_reduce = basis.nonredundant_indices[1:(basis.n_nonredundant)]

# needed for correct column count in symbol hashtable
matrix.ncols_left = matrix.nrows_filled_upper
Expand Down Expand Up @@ -285,35 +282,21 @@ function f4_reduction_apply!(
end

flag = linalg_main!(matrix, basis, params, trace, linalg=LinearAlgebra(:apply, :sparse))
if !flag
@log :info "In apply, some of the matrix rows unexpectedly reduced to zero."
return false, false
end
!flag && return (false, false)

matrix_convert_rows_to_basis_elements!(matrix, basis, ht, symbol_ht, params)

# Check that the leading terms were not reduced to zero accidentally
pivot_indices = trace.matrix_pivot_indices[f4_iteration]
@inbounds for i in 1:(matrix.npivots)
sgn = basis.monoms[basis.n_processed + i][1]
if sgn != pivot_indices[i]
@log :info "In apply, some leading terms cancelled out!"
return false, false
end
sgn != pivot_indices[i] && return (false, false)
end

if cache_column_order
matrix_pivot_signature = matrix_compute_pivot_signature(
basis.monoms,
basis.n_processed + 1,
matrix.npivots
)
matrix_pivot_signature =
matrix_compute_pivot_signature(basis.monoms, basis.n_processed + 1, matrix.npivots)
if matrix_pivot_signature != trace.matrix_pivot_signatures[f4_iteration]
@log :info """
In apply, on iteration $(f4_iteration) of F4, some terms cancelled out.
hash (expected): $(trace.matrix_pivot_signatures[f4_iteration])
hash (got): $(matrix_pivot_signature)"""
# return false, false
return false, false
end
end
Expand Down Expand Up @@ -461,10 +444,8 @@ function f4_autoreduce_apply!(
end

flag = linalg_autoreduce!(matrix, basis, params, linalg=LinearAlgebra(:apply, :sparse))
if !flag
@log :info "In apply, the final autoreduction of the basis failed"
return false
end
!flag && return false

matrix_convert_rows_to_basis_elements!(matrix, basis, hashtable, symbol_ht, params)

basis.n_filled = matrix.npivots + basis.n_processed
Expand All @@ -473,8 +454,7 @@ function f4_autoreduce_apply!(
output_nonredundant = trace.output_nonredundant_indices
for i in 1:length(output_nonredundant)
basis.nonredundant_indices[i] = output_nonredundant[i]
basis.divmasks[i] =
hashtable.divmasks[basis.monoms[basis.nonredundant_indices[i]][1]]
basis.divmasks[i] = hashtable.divmasks[basis.monoms[basis.nonredundant_indices[i]][1]]
end
basis.n_nonredundant = length(output_nonredundant)
true
Expand Down
Loading

0 comments on commit 426515b

Please sign in to comment.