Skip to content

Commit

Permalink
Support for gradients (#94)
Browse files Browse the repository at this point in the history
* Support for gradients

* typing fixes

* Update graphviz2drawio/models/SvgParser.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update graphviz2drawio/mx/Node.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 15, 2024
1 parent 3986517 commit d72aa3c
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 19 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ python -m graphviz2drawio test/directed/hello.gv.txt
## Roadmap

* Migrate to uv/hatch for packaging and dep mgmt
* Support for fill gradient
* Support compatible [arrows](https://graphviz.org/docs/attr-types/arrowType/)
* Support [multiple edges](https://graphviz.org/Gallery/directed/switch.html)
* Support [edges with links](https://graphviz.org/Gallery/directed/pprof.html)
Expand Down
4 changes: 2 additions & 2 deletions graphviz2drawio/models/SVG.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def get_first(g: Element, tag: str) -> Element | None:
return g.find(f"./{NS_SVG}{tag}")


def count_tags(g: Element, tag: str) -> int:
return len(g.findall(f"./{NS_SVG}{tag}"))
def findall(g: Element, tag: str) -> list[Element]:
return g.findall(f"./{NS_SVG}{tag}")


def get_title(g: Element) -> str | None:
Expand Down
80 changes: 77 additions & 3 deletions graphviz2drawio/models/SvgParser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import re
from collections import OrderedDict
from collections.abc import Iterable
from math import isclose
from xml.etree import ElementTree

from graphviz2drawio.mx.Edge import Edge
from graphviz2drawio.mx.EdgeFactory import EdgeFactory
from graphviz2drawio.mx.Node import Node
from graphviz2drawio.mx.Node import Gradient, Node
from graphviz2drawio.mx.NodeFactory import NodeFactory

from ..mx.Curve import LINE_TOLERANCE
from ..mx.utils import adjust_color_opacity
from . import SVG
from .commented_tree_builder import COMMENT, CommentedTreeBuilder
from .CoordsTranslate import CoordsTranslate
Expand All @@ -29,17 +34,28 @@ def parse_nodes_edges_clusters(
nodes: OrderedDict[str, Node] = OrderedDict()
edges: OrderedDict[str, Edge] = OrderedDict()
clusters: OrderedDict[str, Node] = OrderedDict()
gradients = dict[str, Gradient]()

prev_comment = None
for g in root:
if g.tag == COMMENT:
prev_comment = g.text
elif SVG.is_tag(g, "defs"):
for gradient in _extract_gradients(g):
gradients[gradient[0]] = gradient[1:]
elif SVG.is_tag(g, "g"):
title = prev_comment or SVG.get_title(g)
if title is None:
raise MissingTitleError(g)
if (defs := SVG.get_first(g, "defs")) is not None:
for gradient in _extract_gradients(defs):
gradients[gradient[0]] = gradient[1:]
if g.attrib["class"] == "node":
nodes[title] = node_factory.from_svg(g, labelloc="c")
nodes[title] = node_factory.from_svg(
g,
labelloc="c",
gradients=gradients,
)
elif g.attrib["class"] == "edge":
# We need to merge edges with the same source and target
# GV represents multiple labels with multiple edges
Expand All @@ -52,6 +68,64 @@ def parse_nodes_edges_clusters(
else:
edges[edge.key_for_label] = edge
elif g.attrib["class"] == "cluster":
clusters[title] = node_factory.from_svg(g, labelloc="t")
clusters[title] = node_factory.from_svg(
g,
labelloc="t",
gradients=gradients,
)

return nodes, list(edges.values()), clusters


_stop_color_re = re.compile(r"stop-color:([^;]+);")
_stop_opacity_re = re.compile(r"stop-opacity:([^;]+);")


def _extract_stop_color(stop: ElementTree.Element) -> str | None:
style = stop.attrib.get("style", "")
if (color := _stop_color_re.search(style)) is not None:
if (opacity := _stop_opacity_re.search(style)) is not None:
return adjust_color_opacity(color.group(1), float(opacity.group(1)))
return None


def _extract_gradients(
defs: ElementTree.Element,
) -> Iterable[tuple[str, str, str, str]]:
for radial_gradient in SVG.findall(defs, "radialGradient"):
stops = SVG.findall(radial_gradient, "stop")
start_color = _extract_stop_color(stops[0])
end_color = _extract_stop_color(stops[-1])
if start_color is None or end_color is None:
continue
yield (
radial_gradient.attrib["id"],
start_color,
end_color,
"radial",
)
for linear_gradient in SVG.findall(defs, "linearGradient"):
stops = SVG.findall(linear_gradient, "stop")

start_color = _extract_stop_color(stops[0])
end_color = _extract_stop_color(stops[-1])
if start_color is None or end_color is None:
continue

y1 = float(linear_gradient.attrib["y1"])
y2 = float(linear_gradient.attrib["y2"])

gradient_direction = "north"
if isclose(y1, y2, rel_tol=LINE_TOLERANCE):
x1 = float(linear_gradient.attrib["y1"])
x2 = float(linear_gradient.attrib["y2"])
gradient_direction = "east" if x1 < x2 else "west"
elif y1 < y2:
gradient_direction = "south"

yield (
linear_gradient.attrib["id"],
start_color,
end_color,
gradient_direction,
)
18 changes: 15 additions & 3 deletions graphviz2drawio/mx/Node.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import TypeAlias

from ..models.Rect import Rect
from .GraphObj import GraphObj
from .MxConst import VERTICAL_ALIGN
from .Styles import Styles
from .Text import Text

Gradient: TypeAlias = tuple[str, str | None, str]


class Node(GraphObj):
def __init__(
Expand All @@ -12,7 +16,7 @@ def __init__(
gid: str,
rect: Rect | None,
texts: list[Text],
fill: str,
fill: str | Gradient,
stroke: str,
shape: str,
labelloc: str,
Expand Down Expand Up @@ -46,21 +50,29 @@ def texts_to_mx_value(self) -> str:
def get_node_style(self) -> str:
style_for_shape = Styles.get_for_shape(self.shape)
dashed = 1 if self.dashed else 0
additional_styling = ""

attributes = {
"fill": self.fill,
"stroke": self.stroke,
"stroke_width": self.stroke_width,
"dashed": dashed,
}
if isinstance(self.fill, str):
attributes["fill"] = self.fill
elif type(self.fill) is tuple:
attributes["fill"] = self.fill[0]
additional_styling += (
f"gradientColor={self.fill[1]};gradientDirection={self.fill[2]};"
)

if (rect := self.rect) is not None and (image_path := rect.image) is not None:
from graphviz2drawio.mx.image import image_data_for_path

attributes["image"] = image_data_for_path(image_path)

attributes["vertical_align"] = VERTICAL_ALIGN.get(self.labelloc, "middle")

return style_for_shape.format(**attributes)
return style_for_shape.format(**attributes) + additional_styling

def __repr__(self) -> str:
return (
Expand Down
36 changes: 27 additions & 9 deletions graphviz2drawio/mx/NodeFactory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from xml.etree.ElementTree import Element

from graphviz2drawio.models import SVG
Expand All @@ -6,7 +7,7 @@
from ..models.Errors import MissingIdentifiersError
from . import MxConst, Shape
from .MxConst import DEFAULT_STROKE_WIDTH
from .Node import Node
from .Node import Gradient, Node
from .RectFactory import rect_from_ellipse_svg, rect_from_image, rect_from_svg_points
from .Text import Text
from .utils import adjust_color_opacity
Expand All @@ -17,11 +18,16 @@ def __init__(self, coords: CoordsTranslate) -> None:
super().__init__()
self.coords = coords

def from_svg(self, g: Element, labelloc: str) -> Node:
def from_svg(
self,
g: Element,
labelloc: str,
gradients: dict[str, Gradient],
) -> Node:
sid = g.attrib["id"]
gid = SVG.get_title(g)
rect = None
fill = MxConst.NONE
fill: str | Gradient = MxConst.NONE
stroke = MxConst.NONE
stroke_width = DEFAULT_STROKE_WIDTH
dashed = False
Expand All @@ -36,7 +42,8 @@ def from_svg(self, g: Element, labelloc: str) -> Node:
if (polygon := SVG.get_first(g, "polygon")) is not None:
rect = rect_from_svg_points(self.coords, polygon.attrib["points"])
shape = Shape.RECT
fill, stroke = self._extract_fill_and_stroke(polygon)
fill = self._extract_fill(polygon, gradients)
stroke = self._extract_stroke(polygon)
stroke_width = polygon.attrib.get("stroke-width", DEFAULT_STROKE_WIDTH)
if "stroke-dasharray" in polygon.attrib:
dashed = True
Expand All @@ -50,10 +57,11 @@ def from_svg(self, g: Element, labelloc: str) -> Node:
)
shape = (
Shape.ELLIPSE
if SVG.count_tags(g, "ellipse") == 1
if len(SVG.findall(g, "ellipse")) == 1
else Shape.DOUBLE_CIRCLE
)
fill, stroke = self._extract_fill_and_stroke(ellipse)
fill = self._extract_fill(ellipse, gradients)
stroke = self._extract_stroke(ellipse)
stroke_width = ellipse.attrib.get("stroke-width", DEFAULT_STROKE_WIDTH)
if "stroke-dasharray" in ellipse.attrib:
dashed = True
Expand All @@ -76,15 +84,25 @@ def from_svg(self, g: Element, labelloc: str) -> Node:
dashed=dashed,
)

_fill_url_re = re.compile(r"url\(#([^)]+)\)")

@staticmethod
def _extract_fill_and_stroke(g: Element) -> tuple[str, str]:
def _extract_fill(g: Element, gradients: dict[str, Gradient]) -> str | Gradient:
fill = g.attrib.get("fill", MxConst.NONE)
stroke = g.attrib.get("stroke", MxConst.NONE)
if fill.startswith("url"):
match = NodeFactory._fill_url_re.search(fill)
if match is not None:
return gradients[match.group(1)]
if "fill-opacity" in g.attrib and fill != MxConst.NONE:
fill = adjust_color_opacity(fill, float(g.attrib["fill-opacity"]))
return fill

@staticmethod
def _extract_stroke(g: Element) -> str:
stroke = g.attrib.get("stroke", MxConst.NONE)
if "stroke-opacity" in g.attrib and stroke != MxConst.NONE:
stroke = adjust_color_opacity(stroke, float(g.attrib["stroke-opacity"]))
return fill, stroke
return stroke

def _extract_texts(self, g: Element) -> tuple[list[Text], complex | None]:
texts = []
Expand Down
5 changes: 4 additions & 1 deletion graphviz2drawio/mx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ def adjust_color_opacity(hex_color: str, opacity: float) -> str:
hex_color = hex_color.lstrip("#")

# Convert hex to RGB
r, g, b = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
try:
r, g, b = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
except ValueError:
return hex_color

# Apply opacity over white background
r = int(r * opacity + 255 * (1 - opacity))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ line-ending = "auto"
"graphviz2drawio/version.py" = ["T201"]
"doc/source/conf.py" = ["A001", "ERA001", "INP001"]
"graphviz2drawio/models/commented_tree_builder.py" = ["ANN001", "ANN201", "ANN204"]
"graphviz2drawio/models/SvgParser.py" = ["C901", "PLR0912"]

[tool.pytest.ini_options]
pythonpath = ". venv/lib/python3.12/site-packages"
Expand Down

0 comments on commit d72aa3c

Please sign in to comment.