Skip to content

Commit

Permalink
gh-36830: improved integer vectors efficiency -Enhancement
Browse files Browse the repository at this point in the history
    
Fixes #36816

1)  **Added** the **cardinality function** using stars and bars to the
**IntegerVector_nk** class
2) Fixed the errors in IntegerVectors_n and IntegerVectors_k now
**IntegerVectors_n(0)cardinality() and IntegerVectors_k(0).cardinality()
returns 1 and both of them are Finite EnumeratedSets**
    
URL: #36830
Reported by: Aman Moon
Reviewer(s): Aman Moon, Jukka Kohonen, Martin Rubey
  • Loading branch information
Release Manager committed Dec 17, 2023
2 parents 5d0093c + dedee8d commit 44a0d2d
Showing 1 changed file with 232 additions and 26 deletions.
258 changes: 232 additions & 26 deletions src/sage/combinat/integer_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from sage.rings.integer import Integer


def is_gale_ryser(r,s):
def is_gale_ryser(r, s):
r"""
Tests whether the given sequences satisfy the condition
of the Gale-Ryser theorem.
Expand Down Expand Up @@ -314,20 +314,20 @@ def gale_ryser_theorem(p1, p2, algorithm="gale",
"""
from sage.matrix.constructor import matrix

if not is_gale_ryser(p1,p2):
if not is_gale_ryser(p1, p2):
return False

if algorithm == "ryser": # ryser's algorithm
if algorithm == "ryser": # ryser's algorithm
from sage.combinat.permutation import Permutation

# Sorts the sequences if they are not, and remembers the permutation
# applied
tmp = sorted(enumerate(p1), reverse=True, key=lambda x:x[1])
tmp = sorted(enumerate(p1), reverse=True, key=lambda x: x[1])
r = [x[1] for x in tmp]
r_permutation = [x-1 for x in Permutation([x[0]+1 for x in tmp]).inverse()]
m = len(r)

tmp = sorted(enumerate(p2), reverse=True, key=lambda x:x[1])
tmp = sorted(enumerate(p2), reverse=True, key=lambda x: x[1])
s = [x[1] for x in tmp]
s_permutation = [x-1 for x in Permutation([x[0]+1 for x in tmp]).inverse()]

Expand All @@ -340,12 +340,12 @@ def gale_ryser_theorem(p1, p2, algorithm="gale",
k = i + 1
while k < m and r[i] == r[k]:
k += 1
if t >= k - i: # == number rows of the same length
if t >= k - i: # == number rows of the same length
for j in range(i, k):
r[j] -= 1
c[j] = 1
t -= k - i
else: # Remove the t last rows of that length
else: # Remove the t last rows of that length
for j in range(k-t, k):
r[j] -= 1
c[j] = 1
Expand All @@ -366,17 +366,17 @@ def gale_ryser_theorem(p1, p2, algorithm="gale",
k1, k2 = len(p1), len(p2)
p = MixedIntegerLinearProgram(solver=solver)
b = p.new_variable(binary=True)
for (i,c) in enumerate(p1):
p.add_constraint(p.sum([b[i,j] for j in range(k2)]) == c)
for (i,c) in enumerate(p2):
p.add_constraint(p.sum([b[j,i] for j in range(k1)]) == c)
for (i, c) in enumerate(p1):
p.add_constraint(p.sum([b[i, j] for j in range(k2)]) == c)
for (i, c) in enumerate(p2):
p.add_constraint(p.sum([b[j, i] for j in range(k1)]) == c)
p.set_objective(None)
p.solve()
b = p.get_values(b, convert=ZZ, tolerance=integrality_tolerance)
M = [[0]*k2 for i in range(k1)]
for i in range(k1):
for j in range(k2):
M[i][j] = b[i,j]
M[i][j] = b[i, j]
return matrix(M)

else:
Expand Down Expand Up @@ -780,6 +780,43 @@ def __contains__(self, x):
return False
return True

def _unrank_helper(self, x, rtn):
"""
Return the element at rank ``x`` by iterating through all integer vectors beginning with ``rtn``.
INPUT:
- ``x`` - a nonnegative integer
- ``rtn`` - a list of nonnegative integers
EXAMPLES::
sage: IV = IntegerVectors(k=5)
sage: IV._unrank_helper(10, [2,0,0,0,0])
[1, 0, 0, 0, 1]
sage: IV = IntegerVectors(n=7)
sage: IV._unrank_helper(100, [7,0,0,0])
[2, 0, 0, 5]
sage: IV = IntegerVectors(n=12, k=7)
sage: IV._unrank_helper(1000, [12,0,0,0,0,0,0])
[5, 3, 1, 1, 1, 1, 0]
"""
ptr = 0
while True:
current_rank = self.rank(rtn)
if current_rank < x:
rtn[ptr+1] = rtn[ptr]
rtn[ptr] = 0
ptr += 1
elif current_rank > x:
rtn[ptr] -= 1
rtn[ptr-1] += 1
else:
return self._element_constructor_(rtn)


class IntegerVectors_all(UniqueRepresentation, IntegerVectors):
"""
Expand Down Expand Up @@ -839,7 +876,10 @@ def __init__(self, n):
sage: TestSuite(IV).run()
"""
self.n = n
IntegerVectors.__init__(self, category=InfiniteEnumeratedSets())
if self.n == 0:
IntegerVectors.__init__(self, category=EnumeratedSets())
else:
IntegerVectors.__init__(self, category=InfiniteEnumeratedSets())

def _repr_(self):
"""
Expand Down Expand Up @@ -898,6 +938,68 @@ def __contains__(self, x):
return False
return sum(x) == self.n

def rank(self, x):
"""
Return the rank of a given element.
INPUT:
- ``x`` -- a list with ``sum(x) == n``
EXAMPLES::
sage: IntegerVectors(n=5).rank([5,0])
1
sage: IntegerVectors(n=5).rank([3,2])
3
"""
if sum(x) != self.n:
raise ValueError("argument is not a member of IntegerVectors({},{})".format(self.n, None))

n, k, s = self.n, len(x), 0
r = binomial(k + n - 1, n + 1)
for i in range(k - 1):
s += x[k - 1 - i]
r += binomial(s + i, i + 1)
return r

def unrank(self, x):
"""
Return the element at given rank x.
INPUT:
- ``x`` -- an integer.
EXAMPLES::
sage: IntegerVectors(n=5).unrank(2)
[4, 1]
sage: IntegerVectors(n=10).unrank(10)
[1, 9]
"""
rtn = [self.n]
while self.rank(rtn) <= x:
rtn.append(0)
rtn.pop()

return IntegerVectors._unrank_helper(self, x, rtn)

def cardinality(self):
"""
Return the cardinality of ``self``.
EXAMPLES::
sage: IntegerVectors(n=0).cardinality()
1
sage: IntegerVectors(n=10).cardinality()
+Infinity
"""
if self.n == 0:
return Integer(1)
return PlusInfinity()


class IntegerVectors_k(UniqueRepresentation, IntegerVectors):
"""
Expand All @@ -912,7 +1014,10 @@ def __init__(self, k):
sage: TestSuite(IV).run()
"""
self.k = k
IntegerVectors.__init__(self, category=InfiniteEnumeratedSets())
if self.k == 0:
IntegerVectors.__init__(self, category=EnumeratedSets())
else:
IntegerVectors.__init__(self, category=InfiniteEnumeratedSets())

def _repr_(self):
"""
Expand Down Expand Up @@ -968,6 +1073,75 @@ def __contains__(self, x):
return False
return len(x) == self.k

def rank(self, x):
"""
Return the rank of a given element.
INPUT:
- ``x`` -- a list with ``len(x) == k``
EXAMPLES::
sage: IntegerVectors(k=5).rank([0,0,0,0,0])
0
sage: IntegerVectors(k=5).rank([1,1,0,0,0])
7
"""
if len(x) != self.k:
raise ValueError("argument is not a member of IntegerVectors({},{})".format(None, self.k))

n, k, s = sum(x), self.k, 0
r = binomial(n + k - 1, k)
for i in range(k - 1):
s += x[k - 1 - i]
r += binomial(s + i, i + 1)
return r

def unrank(self, x):
"""
Return the element at given rank x.
INPUT:
- ``x`` -- an integer such that x < self.cardinality()``
EXAMPLES::
sage: IntegerVectors(k=5).unrank(10)
[1, 0, 0, 0, 1]
sage: IntegerVectors(k=5).unrank(15)
[0, 0, 2, 0, 0]
sage: IntegerVectors(k=0).unrank(0)
[]
"""
if self.k == 0 and x != 0:
raise IndexError(f"Index {x} is out of range for the IntegerVector.")
rtn = [0]*self.k
if self.k == 0 and x == 0:
return rtn

while self.rank(rtn) <= x:
rtn[0] += 1
rtn[0] -= 1

return IntegerVectors._unrank_helper(self, x, rtn)

def cardinality(self):
"""
Return the cardinality of ``self``.
EXAMPLES::
sage: IntegerVectors(k=0).cardinality()
1
sage: IntegerVectors(k=10).cardinality()
+Infinity
"""
if self.k == 0:
return Integer(1)
return PlusInfinity()


class IntegerVectors_nk(UniqueRepresentation, IntegerVectors):
"""
Expand Down Expand Up @@ -1010,11 +1184,11 @@ def _list_rec(self, n, k):
res = []

if k == 1:
return [ (n, ) ]
return [(n, )]

for nbar in range(n + 1):
n_diff = n - nbar
for rest in self._list_rec( nbar , k - 1):
for rest in self._list_rec(nbar, k - 1):
res.append((n_diff,) + rest)
return res

Expand Down Expand Up @@ -1153,17 +1327,49 @@ def rank(self, x):
if x not in self:
raise ValueError("argument is not a member of IntegerVectors({},{})".format(self.n, self.k))

n = self.n
k = self.k

r = 0
k, s, r = self.k, 0, 0
for i in range(k - 1):
k -= 1
n -= x[i]
r += binomial(k + n - 1, k)

s += x[k - 1 - i]
r += binomial(s + i, i + 1)
return r

def unrank(self, x):
"""
Return the element at given rank x.
INPUT:
- ``x`` -- an integer such that ``x < self.cardinality()``
EXAMPLES::
sage: IntegerVectors(4,5).unrank(30)
[1, 0, 1, 0, 2]
sage: IntegerVectors(2,3).unrank(5)
[0, 0, 2]
"""
if x >= self.cardinality():
raise IndexError(f"Index {x} is out of range for the IntegerVector.")
rtn = [0]*self.k
rtn[0] = self.n
return IntegerVectors._unrank_helper(self, x, rtn)

def cardinality(self):
"""
Return the cardinality of ``self``.
EXAMPLES::
sage: IntegerVectors(3,5).cardinality()
35
sage: IntegerVectors(99, 3).cardinality()
5050
sage: IntegerVectors(10^9 - 1, 3).cardinality()
500000000500000000
"""
n, k = self.n, self.k
return Integer(binomial(n + k - 1, n))


class IntegerVectors_nnondescents(UniqueRepresentation, IntegerVectors):
r"""
Expand Down Expand Up @@ -1320,11 +1526,11 @@ def __init__(self, n=None, k=None, **constraints):
category = FiniteEnumeratedSets()
else:
category = EnumeratedSets()
elif k is not None and 'max_part' in constraints: # n is None
elif k is not None and 'max_part' in constraints: # n is None
category = FiniteEnumeratedSets()
else:
category = EnumeratedSets()
IntegerVectors.__init__(self, category=category) # placeholder category
IntegerVectors.__init__(self, category=category) # placeholder category

def _repr_(self):
"""
Expand Down

0 comments on commit 44a0d2d

Please sign in to comment.