diff --git a/cypress/platform/knsv2.html b/cypress/platform/knsv2.html index 831a5d33ae..c01729c84d 100644 --- a/cypress/platform/knsv2.html +++ b/cypress/platform/knsv2.html @@ -59,8 +59,16 @@
 stateDiagram-v2
-     Second --> Third
-     Second --> Fourth
+    state fork_state <>
+      [*] --> fork_state
+      fork_state --> State2
+      fork_state --> State3
+
+      state join_state <>
+      State2 --> join_state
+      State3 --> join_state
+      join_state --> State4
+      State4 --> [*]
   
diff --git a/packages/mermaid/src/diagrams/state/stateCommon.ts b/packages/mermaid/src/diagrams/state/stateCommon.ts
index 7d80f41e0e..e847d1514e 100644
--- a/packages/mermaid/src/diagrams/state/stateCommon.ts
+++ b/packages/mermaid/src/diagrams/state/stateCommon.ts
@@ -28,6 +28,7 @@ export const G_EDGE_LABELTYPE = 'text';
 export const G_EDGE_THICKNESS = 'normal';
 
 export const CSS_EDGE = 'transition';
+export const CSS_DIAGRAM = 'statediagram';
 
 export default {
   DEFAULT_DIAGRAM_DIRECTION,
@@ -44,4 +45,5 @@ export default {
   G_EDGE_LABELTYPE,
   G_EDGE_THICKNESS,
   CSS_EDGE,
+  CSS_DIAGRAM,
 };
diff --git a/packages/mermaid/src/diagrams/state/stateDb.js b/packages/mermaid/src/diagrams/state/stateDb.js
index e9a4148b96..89bfb55d2e 100644
--- a/packages/mermaid/src/diagrams/state/stateDb.js
+++ b/packages/mermaid/src/diagrams/state/stateDb.js
@@ -27,6 +27,7 @@ import {
   G_EDGE_THICKNESS,
   CSS_EDGE,
 } from './stateCommon.js';
+import { rect } from 'dagre-d3-es/src/dagre-js/intersect/index.js';
 
 const START_NODE = '[*]';
 const START_TYPE = 'start';
@@ -555,9 +556,36 @@ const dataFetcher = (parentId, doc, nodes, edges) => {
 
   stateKeys.forEach((key) => {
     const item = currentDocument.states[key];
+    console.log('Item:', item);
+
+    let itemShape = 'rect';
+    if (item.type === 'default' && item.id === 'root_start') {
+      itemShape = 'stateStart';
+    }
+    if (item.type === 'default' && item.id === 'root_end') {
+      itemShape = 'stateEnd';
+    }
+
+    if (item.type === 'fork' || item.type === 'join') {
+      itemShape = 'forkJoin';
+    }
+
+    if (item.type === 'choice') {
+      itemShape = 'choice';
+    }
+
+    if (item.id === '' && item.type === 'default') {
+      //ignore this item
+      return;
+    }
+
+    if (item.id === '' && item.type === 'default') {
+      //ignore this item
+      return;
+    }
 
     if (parentId) {
-      nodes.push({ ...item, labelText: item.id, labelType: 'text', parentId, shape: 'rect' });
+      nodes.push({ ...item, labelText: item.id, labelType: 'text', parentId, shape: itemShape });
     } else {
       nodes.push({
         ...item,
@@ -565,7 +593,7 @@ const dataFetcher = (parentId, doc, nodes, edges) => {
         // description: item.id,
         labelType: 'text',
         labelStyle: '',
-        shape: 'rect',
+        shape: itemShape,
         padding: 15,
         classes: ' statediagram-state',
       });
diff --git a/packages/mermaid/src/diagrams/state/stateRenderer-v3-unified.ts b/packages/mermaid/src/diagrams/state/stateRenderer-v3-unified.ts
index bf72d3cbbd..4a44d83454 100644
--- a/packages/mermaid/src/diagrams/state/stateRenderer-v3-unified.ts
+++ b/packages/mermaid/src/diagrams/state/stateRenderer-v3-unified.ts
@@ -8,6 +8,7 @@ import { render } from '../../rendering-util/render.js';
 import insertElementsForSize, {
   getDiagramElements,
 } from '../../rendering-util/inserElementsForSize.js';
+import { setupViewPortForSVG } from '../../rendering-util/setupViewPortForSVG.js';
 import {
   DEFAULT_DIAGRAM_DIRECTION,
   DEFAULT_NESTED_DOC_DIR,
@@ -15,6 +16,7 @@ import {
   STMT_RELATION,
   DEFAULT_STATE_TYPE,
   DIVIDER_TYPE,
+  CSS_DIAGRAM,
 } from './stateCommon.js';
 
 // Configuration
@@ -93,6 +95,8 @@ export const draw = async function (text: string, id: string, _version: string,
   data4Layout.diagramId = id;
   console.log('REF1:', data4Layout);
   await render(data4Layout, svg, element);
+  const padding = 8;
+  setupViewPortForSVG(svg, padding, CSS_DIAGRAM, conf.useMaxWidth);
 };
 
 export default {
diff --git a/packages/mermaid/src/rendering-util/rendering-elements/nodes.js b/packages/mermaid/src/rendering-util/rendering-elements/nodes.js
index 88db04393d..bf25716de0 100644
--- a/packages/mermaid/src/rendering-util/rendering-elements/nodes.js
+++ b/packages/mermaid/src/rendering-util/rendering-elements/nodes.js
@@ -1,5 +1,9 @@
 import { log } from '$root/logger.js';
-import { rect } from './shapes/rect.js';
+import { rect } from './shapes/rect.ts';
+import { stateStart } from './shapes/stateStart.ts';
+import { stateEnd } from './shapes/stateEnd.ts';
+import { forkJoin } from './shapes/forkJoin.ts';
+import { choice } from './shapes/choice.ts';
 import { getConfig } from '$root/diagram-api/diagramAPI.js';
 
 const formatClass = (str) => {
@@ -11,6 +15,10 @@ const formatClass = (str) => {
 
 const shapes = {
   rect,
+  stateStart,
+  stateEnd,
+  forkJoin,
+  choice,
 };
 
 let nodeElems = {};
@@ -19,9 +27,9 @@ export const insertNode = async (elem, node, dir) => {
   let newEl;
   let el;
 
-  console.log('insertNode element', elem, elem.node(), rect);
   // debugger;
   // Add link when appropriate
+  console.log('node.link', node.link);
   if (node.link) {
     let target;
     if (getConfig().securityLevel === 'sandbox') {
diff --git a/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts b/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts
new file mode 100644
index 0000000000..04d4466add
--- /dev/null
+++ b/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts
@@ -0,0 +1,37 @@
+import intersect from '../intersect/index.js';
+import type { Node } from '$root/rendering-util/types.d.ts';
+import type { SVG } from '$root/diagram-api/types.js';
+
+export const choice = (parent: SVG, node: Node) => {
+  const shapeSvg = parent
+    .insert('g')
+    .attr('class', 'node default')
+    .attr('id', node.domId || node.id);
+
+  const s = 28;
+  const points = [
+    { x: 0, y: s / 2 },
+    { x: s / 2, y: 0 },
+    { x: 0, y: -s / 2 },
+    { x: -s / 2, y: 0 },
+  ];
+
+  const choice = shapeSvg.insert('polygon', ':first-child').attr(
+    'points',
+    points
+      .map(function (d) {
+        return d.x + ',' + d.y;
+      })
+      .join(' ')
+  );
+  // center the circle around its coordinate
+  choice.attr('class', 'state-start').attr('r', 7).attr('width', 28).attr('height', 28);
+  node.width = 28;
+  node.height = 28;
+
+  node.intersect = function (point) {
+    return intersect.circle(node, 14, point);
+  };
+
+  return shapeSvg;
+};
diff --git a/packages/mermaid/src/rendering-util/rendering-elements/shapes/forkJoin.ts b/packages/mermaid/src/rendering-util/rendering-elements/shapes/forkJoin.ts
new file mode 100644
index 0000000000..657749051e
--- /dev/null
+++ b/packages/mermaid/src/rendering-util/rendering-elements/shapes/forkJoin.ts
@@ -0,0 +1,51 @@
+import { log } from '$root/logger.js';
+import { updateNodeBounds } from './util.js';
+import intersect from '../intersect/index.js';
+import type { Node } from '$root/rendering-util/types.d.ts';
+import type { SVG } from '$root/diagram-api/types.js';
+
+export const forkJoin = (parent: SVG, node: Node, dir: string) => {
+  const shapeSvg = parent
+    .insert('g')
+    .attr('class', 'node default')
+    .attr('id', node.domId || node.id);
+
+  let width = 70;
+  let height = 10;
+
+  if (dir === 'LR') {
+    width = 10;
+    height = 70;
+  }
+
+  const shape = shapeSvg
+    .append('rect')
+    .attr('x', (-1 * width) / 2)
+    .attr('y', (-1 * height) / 2)
+    .attr('width', width)
+    .attr('height', height)
+    .attr('class', 'fork-join');
+
+  updateNodeBounds(node, shape);
+  let nodeHeight = 0;
+  let nodeWidth = 0;
+  let nodePadding = 10;
+  if (node.height) {
+    nodeHeight = node.height;
+  }
+  if (node.width) {
+    nodeWidth = node.width;
+  }
+
+  if (node.padding) {
+    nodePadding = node.padding;
+  }
+
+  node.height = nodeHeight + nodePadding / 2;
+  node.width = nodeWidth + nodePadding / 2;
+  node.intersect = function (point) {
+    return intersect.rect(node, point);
+  };
+
+  return shapeSvg;
+};
diff --git a/packages/mermaid/src/rendering-util/rendering-elements/shapes/rect.ts b/packages/mermaid/src/rendering-util/rendering-elements/shapes/rect.ts
index 03abb1f937..30469e8d99 100644
--- a/packages/mermaid/src/rendering-util/rendering-elements/shapes/rect.ts
+++ b/packages/mermaid/src/rendering-util/rendering-elements/shapes/rect.ts
@@ -4,6 +4,7 @@ import intersect from '../intersect/index.js';
 import type { Node } from '$root/rendering-util/types.d.ts';
 import rough from 'roughjs';
 import { select } from 'd3';
+
 /**
  *
  * @param rect
diff --git a/packages/mermaid/src/rendering-util/rendering-elements/shapes/stateEnd.ts b/packages/mermaid/src/rendering-util/rendering-elements/shapes/stateEnd.ts
new file mode 100644
index 0000000000..3f968fe861
--- /dev/null
+++ b/packages/mermaid/src/rendering-util/rendering-elements/shapes/stateEnd.ts
@@ -0,0 +1,26 @@
+import { log } from '$root/logger.js';
+import { updateNodeBounds } from './util.js';
+import intersect from '../intersect/index.js';
+import type { Node } from '$root/rendering-util/types.d.ts';
+import type { SVG } from '$root/diagram-api/types.js';
+
+export const stateEnd = (parent: SVG, node: Node) => {
+  const shapeSvg = parent
+    .insert('g')
+    .attr('class', 'node default')
+    .attr('id', node.domId || node.id);
+  const innerCircle = shapeSvg.insert('circle', ':first-child');
+  const circle = shapeSvg.insert('circle', ':first-child');
+
+  circle.attr('class', 'state-start').attr('r', 7).attr('width', 14).attr('height', 14);
+
+  innerCircle.attr('class', 'state-end').attr('r', 5).attr('width', 10).attr('height', 10);
+
+  updateNodeBounds(node, circle);
+
+  node.intersect = function (point) {
+    return intersect.circle(node, 7, point);
+  };
+
+  return shapeSvg;
+};
diff --git a/packages/mermaid/src/rendering-util/rendering-elements/shapes/stateStart.ts b/packages/mermaid/src/rendering-util/rendering-elements/shapes/stateStart.ts
new file mode 100644
index 0000000000..20dc861e92
--- /dev/null
+++ b/packages/mermaid/src/rendering-util/rendering-elements/shapes/stateStart.ts
@@ -0,0 +1,24 @@
+import { log } from '$root/logger.js';
+import { updateNodeBounds } from './util.js';
+import intersect from '../intersect/index.js';
+import type { Node } from '$root/rendering-util/types.d.ts';
+import type { SVG } from '$root/diagram-api/types.js';
+
+export const stateStart = (parent: SVG, node: Node) => {
+  const shapeSvg = parent
+    .insert('g')
+    .attr('class', 'node default')
+    .attr('id', node.domId || node.id);
+  const circle = shapeSvg.insert('circle', ':first-child');
+
+  // center the circle around its coordinate
+  circle.attr('class', 'state-start').attr('r', 7).attr('width', 14).attr('height', 14);
+
+  updateNodeBounds(node, circle);
+
+  node.intersect = function (point) {
+    return intersect.circle(node, 7, point);
+  };
+
+  return shapeSvg;
+};
diff --git a/packages/mermaid/src/rendering-util/setupViewPortForSVG.ts b/packages/mermaid/src/rendering-util/setupViewPortForSVG.ts
new file mode 100644
index 0000000000..1fa2de1fd4
--- /dev/null
+++ b/packages/mermaid/src/rendering-util/setupViewPortForSVG.ts
@@ -0,0 +1,40 @@
+import { configureSvgSize } from '$root/setupGraphViewbox.js';
+import type { SVG } from '$root/diagram-api/types.js';
+import { log } from '$root/logger.js';
+
+export const setupViewPortForSVG = (
+  svg: SVG,
+  padding: number,
+  cssDiagram: string,
+  useMaxWidth: boolean
+) => {
+  // Initialize the SVG element and set the diagram class
+  svg.attr('class', cssDiagram);
+
+  // Calculate the dimensions and position with padding
+  const { width, height, x, y } = calculateDimensionsWithPadding(svg, padding);
+
+  // Configure the size and aspect ratio of the SVG
+  configureSvgSize(svg, height, width, useMaxWidth);
+
+  // Update the viewBox to ensure all content is visible with padding
+  const viewBox = createViewBox(x, y, width, height, padding);
+  svg.attr('viewBox', viewBox);
+
+  // Log the viewBox configuration for debugging
+  log.debug(`viewBox configured: ${viewBox}`);
+};
+
+const calculateDimensionsWithPadding = (svg: SVG, padding: number) => {
+  const bounds = svg.node()?.getBBox() || { width: 0, height: 0, x: 0, y: 0 };
+  return {
+    width: bounds.width + padding * 2,
+    height: bounds.height + padding * 2,
+    x: bounds.x,
+    y: bounds.y,
+  };
+};
+
+const createViewBox = (x: number, y: number, width: number, height: number, padding: number) => {
+  return `${x - padding} ${y - padding} ${width} ${height}`;
+};
diff --git a/packages/mermaid/src/rendering-util/types.d.ts b/packages/mermaid/src/rendering-util/types.d.ts
index 38a48b38d3..e69537e814 100644
--- a/packages/mermaid/src/rendering-util/types.d.ts
+++ b/packages/mermaid/src/rendering-util/types.d.ts
@@ -33,8 +33,10 @@ interface Node {
   tooltip?: string;
   type: string;
   width?: number;
-  intersect?: (point: any) => any;
+  height?: number;
+
   // Specific properties for State Diagram nodes TODO remove and use generic properties
+  intersect?: (point: any) => any;
   style?: string;
   class?: string;
   borders?: string;