Skip to content

Commit

Permalink
Merge pull request #228 from Jake-Moss/main
Browse files Browse the repository at this point in the history
Add some mpoly context util functions
  • Loading branch information
oscarbenjamin authored Sep 22, 2024
2 parents 9fb7f2a + d1a1438 commit d76d40b
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 194 deletions.
2 changes: 2 additions & 0 deletions src/flint/flint_base/flint_base.pxd
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flint.flintlib.types.mpoly cimport ordering_t
from flint.flintlib.types.flint cimport slong

cdef class flint_ctx:
pass
Expand Down Expand Up @@ -53,6 +54,7 @@ cdef class flint_mpoly(flint_elem):
cdef _isub_mpoly_(self, other)
cdef _imul_mpoly_(self, other)

cdef _compose_gens_(self, ctx, slong *mapping)

cdef class flint_mat(flint_elem):
pass
Expand Down
165 changes: 156 additions & 9 deletions src/flint/flint_base/flint_base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from flint.flintlib.types.flint cimport (
FLINT_BITS as _FLINT_BITS,
FLINT_VERSION as _FLINT_VERSION,
__FLINT_RELEASE as _FLINT_RELEASE,
slong,
)
from flint.utils.flint_exceptions import DomainError
from flint.flintlib.types.mpoly cimport ordering_t
Expand Down Expand Up @@ -344,13 +345,20 @@ cdef class flint_mpoly_context(flint_elem):
return tuple(self.gen(i) for i in range(self.nvars()))

def variable_to_index(self, var: Union[int, str]) -> int:
"""Convert a variable name string or possible index to its index in the context."""
"""
Convert a variable name string or possible index to its index in the context.

If ``var`` is negative, return the index of the ``self.nvars() + var``
"""
if isinstance(var, str):
try:
i = self.names().index(var)
except ValueError:
raise ValueError("variable not in context")
elif isinstance(var, int):
if var < 0:
var = self.nvars() + var

if not 0 <= var < self.nvars():
raise IndexError("generator index out of range")
i = var
Expand Down Expand Up @@ -379,7 +387,7 @@ cdef class flint_mpoly_context(flint_elem):
names = (names,)

for name in names:
if isinstance(name, str):
if isinstance(name, (str, bytes)):
res.append(name)
else:
base, num = name
Expand Down Expand Up @@ -415,10 +423,14 @@ cdef class flint_mpoly_context(flint_elem):
return ctx

@classmethod
def from_context(cls, ctx: flint_mpoly_context):
def from_context(cls, ctx: flint_mpoly_context, names=None, ordering=None):
"""
Get a new context from an existing one. Optionally override ``names`` or
``ordering``.
"""
return cls.get(
ordering=ctx.ordering(),
names=ctx.names(),
names=ctx.names() if names is None else names,
ordering=ctx.ordering() if ordering is None else ordering,
)

def _any_as_scalar(self, other):
Expand Down Expand Up @@ -451,6 +463,62 @@ cdef class flint_mpoly_context(flint_elem):
exp_vec = (0,) * self.nvars()
return self.from_dict({tuple(exp_vec): coeff})

def drop_gens(self, gens: Iterable[str | int]):
"""
Get a context with the specified generators removed.
>>> from flint import fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get(('x', 'y', 'z', 'a', 'b'))
>>> ctx.drop_gens(('x', -2))
fmpz_mpoly_ctx(3, '<Ordering.lex: 'lex'>', ('y', 'z', 'b'))
"""
nvars = self.nvars()
gen_idxs = set(self.variable_to_index(i) for i in gens)

if len(gens) > nvars:
raise ValueError(f"expected at most {nvars} unique generators, got {len(gens)}")

names = self.names()
remaining_gens = []
for i in range(nvars):
if i not in gen_idxs:
remaining_gens.append(names[i])

return self.from_context(self, names=remaining_gens)

def append_gens(self, *gens: str):
"""
Get a context with the specified generators appended.
>>> from flint import fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get(('x', 'y', 'z'))
>>> ctx.append_gens('a', 'b')
fmpz_mpoly_ctx(5, '<Ordering.lex: 'lex'>', ('x', 'y', 'z', 'a', 'b'))
"""
return self.from_context(self, names=self.names() + gens)

def infer_generator_mapping(self, ctx: flint_mpoly_context):
"""
Infer a mapping of generator indexes from this contexts generators, to the
provided contexts generators. Inference is done based upon generator names.
>>> from flint import fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get(('x', 'y', 'z', 'a', 'b'))
>>> ctx2 = fmpz_mpoly_ctx.get(('b', 'a'))
>>> mapping = ctx.infer_generator_mapping(ctx2)
>>> mapping # doctest: +SKIP
{3: 1, 4: 0}
>>> list(sorted(mapping.items())) # Set ordering is not stable
[(3, 1), (4, 0)]
"""
gens_to_idxs = {x: i for i, x in enumerate(self.names())}
other_gens_to_idxs = {x: i for i, x in enumerate(ctx.names())}
return {
gens_to_idxs[k]: other_gens_to_idxs[k]
for k in (gens_to_idxs.keys() & other_gens_to_idxs.keys())
}


cdef class flint_mod_mpoly_context(flint_mpoly_context):
@classmethod
def _new_(_, flint_mod_mpoly_context self, names, prime_modulus):
Expand All @@ -472,11 +540,15 @@ cdef class flint_mod_mpoly_context(flint_mpoly_context):
return *super().create_context_key(names, ordering), modulus

@classmethod
def from_context(cls, ctx: flint_mod_mpoly_context):
def from_context(cls, ctx: flint_mod_mpoly_context, names=None, ordering=None, modulus=None):
"""
Get a new context from an existing one. Optionally override ``names``,
``modulus``, or ``ordering``.
"""
return cls.get(
names=ctx.names(),
modulus=ctx.modulus(),
ordering=ctx.ordering(),
names=ctx.names() if names is None else names,
modulus=ctx.modulus() if modulus is None else modulus,
ordering=ctx.ordering() if ordering is None else ordering,
)

def is_prime(self):
Expand Down Expand Up @@ -869,6 +941,81 @@ cdef class flint_mpoly(flint_elem):
"""
return zip(self.monoms(), self.coeffs())

def unused_gens(self):
"""
Report the unused generators from this polynomial.
A generator is unused if it's maximum degree is 0.
>>> from flint import fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get(('x', 4))
>>> ctx2 = fmpz_mpoly_ctx.get(('x1', 'x2'))
>>> f = sum(ctx.gens()[1:3])
>>> f
x1 + x2
>>> f.unused_gens()
('x0', 'x3')
"""
names = self.context().names()
return tuple(names[i] for i, x in enumerate(self.degrees()) if not x)

def project_to_context(self, other_ctx, mapping: dict[str | int, str | int] = None):
"""
Project this polynomial to a different context.
This is equivalent to composing this polynomial with the generators of another
context. By default the mapping between contexts is inferred based on the name
of the generators. Generators with names that are not found within the other
context are mapped to 0. The mapping can be explicitly provided.
>>> from flint import fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get(('x', 'y', 'a', 'b'))
>>> ctx2 = fmpz_mpoly_ctx.get(('a', 'b'))
>>> x, y, a, b = ctx.gens()
>>> f = x + 2*y + 3*a + 4*b
>>> f.project_to_context(ctx2)
3*a + 4*b
>>> f.project_to_context(ctx2, mapping={"x": "a", "b": 0})
5*a
"""
cdef:
slong *c_mapping
slong i

ctx = self.context()
if not typecheck(other_ctx, type(ctx)):
raise ValueError(
f"provided context is not a {ctx.__class__.__name__}"
)
elif ctx is other_ctx:
return self

if mapping is None:
mapping = ctx.infer_generator_mapping(other_ctx)
else:
mapping = {
ctx.variable_to_index(k): other_ctx.variable_to_index(v)
for k, v in mapping.items()
}

try:
c_mapping = <slong *> libc.stdlib.malloc(ctx.nvars() * sizeof(slong *))
if c_mapping is NULL:
raise MemoryError("malloc returned a null pointer")

for i in range(ctx.nvars()):
c_mapping[i] = <slong>-1

for k, v in mapping.items():
c_mapping[k] = <slong>v

return self._compose_gens_(other_ctx, c_mapping)
finally:
libc.stdlib.free(c_mapping)

cdef _compose_gens_(self, other_ctx, slong *mapping):
raise NotImplementedError("abstract method")


cdef class flint_series(flint_elem):
"""
Expand Down
10 changes: 7 additions & 3 deletions src/flint/flintlib/types/flint.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,26 @@ cdef extern from "flint/fmpz.h":

cdef extern from *:
"""
/*
* Functions renamed in Flint 3.2.0
*/
#if __FLINT_RELEASE < 30200 /* Flint < 3.2.0 */
/* Functions renamed in Flint 3.2.0 */
#define flint_rand_init flint_randinit
#define flint_rand_clear flint_randclear
#endif
/* FIXME: add version guard when https://github.com/flintlib/flint/pull/2068 */
/* is resolved */
#define fmpz_mod_mpoly_compose_fmpz_mod_mpoly_gen(...) (void)0
"""

cdef extern from "flint/flint.h":
"""
#define SIZEOF_ULONG sizeof(ulong)
#define SIZEOF_SLONG sizeof(slong)
"""
int SIZEOF_ULONG
int SIZEOF_SLONG

ctypedef struct __FLINT_FILE:
pass
Expand Down
42 changes: 36 additions & 6 deletions src/flint/test/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -2861,6 +2861,12 @@ def test_mpolys():

ctx = get_context(("x", 2))

def mpoly(x):
return ctx.from_dict(x)

def quick_poly():
return mpoly({(0, 0): 1, (0, 1): 2, (1, 0): 3, (2, 2): 4})

assert raises(lambda : ctx.__class__("x", flint.Ordering.lex), RuntimeError)
assert raises(lambda: get_context(("x", 2), ordering="bad"), ValueError)
assert raises(lambda: get_context(("x", -1)), ValueError)
Expand All @@ -2877,17 +2883,41 @@ def test_mpolys():
assert raises(lambda: P(val={"bad": 1}, ctx=None), ValueError)
assert raises(lambda: P(val="1", ctx=None), ValueError)

ctx1 = get_context(("x", 4))
ctx2 = get_context(("x", 4), ordering="deglex")
assert ctx1.drop_gens(ctx1.names()).names() == tuple()
assert ctx1.drop_gens((ctx1.name(1), ctx1.name(2))).names() == (ctx1.name(0), ctx1.name(3))
assert ctx1.drop_gens(tuple()).names() == ctx1.names()
assert ctx1.drop_gens((-1,)).names() == ctx1.names()[:-1]

assert ctx.infer_generator_mapping(ctx) == {i: i for i in range(ctx.nvars())}
assert ctx1.infer_generator_mapping(ctx) == {0: 0, 1: 1}
assert ctx1.drop_gens(ctx.names()).infer_generator_mapping(ctx) == {}

assert quick_poly().project_to_context(ctx1) == \
ctx1.from_dict(
{(0, 0, 0, 0): 1, (0, 1, 0, 0): 2, (1, 0, 0, 0): 3, (2, 2, 0, 0): 4}
)
new_poly = quick_poly().project_to_context(ctx1)
assert ctx1.drop_gens(new_poly.unused_gens()) == ctx
assert new_poly.project_to_context(ctx) == quick_poly()

new_poly = quick_poly().project_to_context(ctx2)
new_ctx = ctx2.drop_gens(new_poly.unused_gens())
assert new_ctx != ctx
assert new_poly != quick_poly()

new_ctx = new_ctx.from_context(new_ctx, ordering=ctx.ordering())
assert new_ctx == ctx
assert new_poly.project_to_context(new_ctx) == quick_poly()

assert ctx.append_gens(*ctx1.names()[-2:]) == ctx1

assert P(val={(0, 0): 1}, ctx=ctx) == ctx.from_dict({(0, 0): 1})
assert P(ctx=ctx).context() == ctx
assert P(1, ctx=ctx).is_one()
assert ctx.gen(1) == ctx.from_dict({(0, 1): 1})

def mpoly(x):
return ctx.from_dict(x)

def quick_poly():
return mpoly({(0, 0): 1, (0, 1): 2, (1, 0): 3, (2, 2): 4})

assert ctx.nvars() == 2
assert ctx.ordering() == flint.Ordering.lex

Expand Down
41 changes: 17 additions & 24 deletions src/flint/types/fmpq_mpoly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from flint.flintlib.functions.fmpq_mpoly cimport (
fmpq_mpoly_add_fmpq,
fmpq_mpoly_clear,
fmpq_mpoly_compose_fmpq_mpoly,
fmpq_mpoly_compose_fmpq_mpoly_gen,
fmpq_mpoly_ctx_init,
fmpq_mpoly_degrees_fmpz,
fmpq_mpoly_derivative,
Expand Down Expand Up @@ -547,28 +548,6 @@ cdef class fmpq_mpoly(flint_mpoly):

return res

# def terms(self):
# """
# Return the terms of this polynomial as a list of fmpq_mpolys.

# >>> ctx = fmpq_mpoly_ctx.get(('x', 2), 'lex')
# >>> f = ctx.from_dict({(0, 0): 1, (1, 0): 2, (0, 1): 3, (1, 1): 4})
# >>> f.terms()
# [4*x0*x1, 2*x0, 3*x1, 1]

# """
# cdef:
# fmpq_mpoly term
# slong i

# res = []
# for i in range(len(self)):
# term = create_fmpq_mpoly(self.ctx)
# fmpq_mpoly_get_term(term.val, self.val, i, self.ctx.val)
# res.append(term)

# return res

def subs(self, dict_args) -> fmpq_mpoly:
"""
Partial evaluate this polynomial with select constants. Keys must be generator names or generator indices,
Expand Down Expand Up @@ -699,9 +678,11 @@ cdef class fmpq_mpoly(flint_mpoly):
Return a dictionary of variable name to degree.
>>> ctx = fmpq_mpoly_ctx.get(('x', 4), 'lex')
>>> p = ctx.from_dict({(1, 0, 0, 0): 1, (0, 2, 0, 0): 2, (0, 0, 3, 0): 3})
>>> p = sum(x**i for i, x in enumerate(ctx.gens()))
>>> p
x1 + x2^2 + x3^3 + 1
>>> p.degrees()
(1, 2, 3, 0)
(0, 1, 2, 3)
"""
cdef:
slong nvars = self.ctx.nvars()
Expand Down Expand Up @@ -1119,6 +1100,18 @@ cdef class fmpq_mpoly(flint_mpoly):
fmpz_mpoly_deflation(shift.val, stride.val, self.val.zpoly, self.ctx.val.zctx)
return list(stride), list(shift)

cdef _compose_gens_(self, ctx, slong *mapping):
cdef fmpq_mpoly res = create_fmpq_mpoly(ctx)
fmpq_mpoly_compose_fmpq_mpoly_gen(
res.val,
self.val,
mapping,
self.ctx.val,
(<fmpq_mpoly_ctx>ctx).val
)

return res


cdef class fmpq_mpoly_vec:
"""
Expand Down
Loading

0 comments on commit d76d40b

Please sign in to comment.