98962ec5b27305c7934ec12e7553a7e223cf9089
java/com.sap.sailing.windestimation.lab/python/mst_graph_visualizer_graphviz.py
| ... | ... | @@ -165,21 +165,33 @@ def create_node_label_with_ports(node, best_type=None): |
| 165 | 165 | footer_content = time_part |
| 166 | 166 | html += f'<TR><TD COLSPAN="4" BGCOLOR="white"><FONT POINT-SIZE="9">{footer_content}</FONT></TD></TR>' |
| 167 | 167 | |
| 168 | - # Bottom row with ports for OUTGOING edges (minimal height) |
|
| 168 | + # Bottom row with path vote diagnostics AND ports for OUTGOING edges |
|
| 169 | 169 | # Each cell corresponds to one compartment position for proper horizontal alignment |
| 170 | + path_votes = node.get('pathVotes', {}) |
|
| 170 | 171 | html += '<TR>' |
| 171 | 172 | for type_name in TYPE_ORDER: |
| 172 | 173 | # Port for outgoing edges (named type_out) |
| 173 | 174 | out_port_name = f'{type_name}_out' |
| 174 | - # Minimal height cells just for port positioning |
|
| 175 | - html += f'<TD PORT="{out_port_name}" BGCOLOR="white" HEIGHT="1"></TD>' |
|
| 175 | + |
|
| 176 | + vote_info = path_votes.get(type_name, {}) |
|
| 177 | + path_count = vote_info.get('pathCount', 0) |
|
| 178 | + quality_sum = vote_info.get('qualitySum', 0) |
|
| 179 | + if path_count > 0: |
|
| 180 | + # Format quality sum in scientific notation if very small |
|
| 181 | + if quality_sum < 0.001: |
|
| 182 | + qs_str = f'{quality_sum:.1e}' |
|
| 183 | + else: |
|
| 184 | + qs_str = f'{quality_sum:.3f}' |
|
| 185 | + vote_content = f'<FONT POINT-SIZE="6" COLOR="gray40">{path_count}p/{qs_str}</FONT>' |
|
| 186 | + else: |
|
| 187 | + vote_content = '<FONT POINT-SIZE="6" COLOR="gray70">-</FONT>' |
|
| 188 | + html += f'<TD PORT="{out_port_name}" BGCOLOR="white">{vote_content}</TD>' |
|
| 176 | 189 | html += '</TR>' |
| 177 | 190 | |
| 178 | 191 | html += '</TABLE>>' |
| 179 | 192 | |
| 180 | 193 | return html |
| 181 | 194 | |
| 182 | - |
|
| 183 | 195 | def create_legend(): |
| 184 | 196 | """Create a legend explaining the colors.""" |
| 185 | 197 | html = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">' |
java/com.sap.sailing.windestimation/src/com/sap/sailing/windestimation/aggregator/msthmm/MstGraphExporter.java
| ... | ... | @@ -3,11 +3,19 @@ package com.sap.sailing.windestimation.aggregator.msthmm; |
| 3 | 3 | import java.io.IOException; |
| 4 | 4 | import java.io.Writer; |
| 5 | 5 | import java.text.SimpleDateFormat; |
| 6 | +import java.util.ArrayList; |
|
| 6 | 7 | import java.util.HashMap; |
| 7 | 8 | import java.util.HashSet; |
| 9 | +import java.util.List; |
|
| 8 | 10 | import java.util.Map; |
| 9 | 11 | import java.util.Set; |
| 12 | +import java.util.function.Supplier; |
|
| 13 | +import java.util.stream.Collectors; |
|
| 10 | 14 | |
| 15 | +import com.sap.sailing.windestimation.aggregator.graph.DijkstraShortestPathFinderImpl; |
|
| 16 | +import com.sap.sailing.windestimation.aggregator.graph.DijsktraShortestPathFinder; |
|
| 17 | +import com.sap.sailing.windestimation.aggregator.graph.ElementAdjacencyQualityMetric; |
|
| 18 | +import com.sap.sailing.windestimation.aggregator.graph.InnerGraphSuccessorSupplier; |
|
| 11 | 19 | import com.sap.sailing.windestimation.aggregator.hmm.GraphLevelInference; |
| 12 | 20 | import com.sap.sailing.windestimation.aggregator.hmm.GraphNode; |
| 13 | 21 | import com.sap.sailing.windestimation.aggregator.hmm.WindCourseRange; |
| ... | ... | @@ -44,8 +52,12 @@ public class MstGraphExporter { |
| 44 | 52 | final int[] nodeIdCounter = {0}; |
| 45 | 53 | assignNodeIds(root, levelToId, nodeIdCounter); |
| 46 | 54 | |
| 55 | + // Collect path vote diagnostics for debugging |
|
| 56 | + final Map<MstGraphLevel, Map<ManeuverTypeForClassification, List<Double>>> pathVotes = |
|
| 57 | + collectPathVoteDiagnostics(graphComponents); |
|
| 58 | + |
|
| 47 | 59 | // Collect best (disambiguated) classification per node |
| 48 | - final Map<String, String> bestNodePerLevel = collectBestNodePerLevel(graphComponents, levelToId); |
|
| 60 | + final Map<String, String> bestNodePerLevel = collectBestNodePerLevel(graphComponents, levelToId, pathVotes); |
|
| 49 | 61 | |
| 50 | 62 | // Derive best path edges from the best node selections |
| 51 | 63 | // This ensures edges connect nodes with red frames (best classifications) |
| ... | ... | @@ -57,7 +69,7 @@ public class MstGraphExporter { |
| 57 | 69 | final boolean[] firstNode = {true}; |
| 58 | 70 | final int[] exportNodeIdCounter = {0}; |
| 59 | 71 | final Map<MstGraphLevel, Integer> exportLevelToId = new HashMap<>(); |
| 60 | - exportNode(root, writer, firstNode, exportLevelToId, exportNodeIdCounter, 0); |
|
| 72 | + exportNode(root, writer, firstNode, exportLevelToId, exportNodeIdCounter, 0, pathVotes); |
|
| 61 | 73 | writer.write("\n ],\n"); |
| 62 | 74 | writer.write(" \"edges\": [\n"); |
| 63 | 75 | // Export edges between nodes |
| ... | ... | @@ -79,7 +91,8 @@ public class MstGraphExporter { |
| 79 | 91 | } |
| 80 | 92 | |
| 81 | 93 | private void exportNode(MstGraphLevel level, Writer writer, boolean[] firstNode, |
| 82 | - Map<MstGraphLevel, Integer> levelToId, int[] nodeIdCounter, int depth) throws IOException { |
|
| 94 | + Map<MstGraphLevel, Integer> levelToId, int[] nodeIdCounter, int depth, |
|
| 95 | + Map<MstGraphLevel, Map<ManeuverTypeForClassification, List<Double>>> pathVotes) throws IOException { |
|
| 83 | 96 | final int nodeId = nodeIdCounter[0]++; |
| 84 | 97 | levelToId.put(level, nodeId); |
| 85 | 98 | if (!firstNode[0]) { |
| ... | ... | @@ -139,11 +152,13 @@ public class MstGraphExporter { |
| 139 | 152 | writer.write(" \"tackAfter\": \"" + node.getTackAfter() + "\"\n"); |
| 140 | 153 | writer.write(" }"); |
| 141 | 154 | } |
| 142 | - writer.write("\n ]\n"); |
|
| 155 | + writer.write("\n ],\n"); |
|
| 156 | + // Add diagnostic info about path votes for this node |
|
| 157 | + writer.write(" \"pathVotes\": " + formatPathVoteDiagnostics(level, pathVotes) + "\n"); |
|
| 143 | 158 | writer.write(" }"); |
| 144 | 159 | // Recursively export children |
| 145 | 160 | for (MstGraphLevel child : level.getChildren()) { |
| 146 | - exportNode(child, writer, firstNode, levelToId, nodeIdCounter, depth + 1); |
|
| 161 | + exportNode(child, writer, firstNode, levelToId, nodeIdCounter, depth + 1, pathVotes); |
|
| 147 | 162 | } |
| 148 | 163 | } |
| 149 | 164 | |
| ... | ... | @@ -215,22 +230,112 @@ public class MstGraphExporter { |
| 215 | 230 | } |
| 216 | 231 | } |
| 217 | 232 | |
| 233 | + /** |
|
| 234 | + * Collects diagnostic information about which paths voted for which classification at each node. |
|
| 235 | + * This runs Dijkstra from each leaf and records which classification was selected and with what path quality. |
|
| 236 | + */ |
|
| 237 | + private Map<MstGraphLevel, Map<ManeuverTypeForClassification, List<Double>>> collectPathVoteDiagnostics( |
|
| 238 | + MstManeuverGraphComponents graphComponents) { |
|
| 239 | + final Map<MstGraphLevel, Map<ManeuverTypeForClassification, List<Double>>> result = new HashMap<>(); |
|
| 240 | + |
|
| 241 | + final ElementAdjacencyQualityMetric<GraphNode<MstGraphLevel>> edgeQualityMetric = (previousNode, currentNode) -> { |
|
| 242 | + return transitionProbabilitiesCalculator.getTransitionProbability(currentNode, previousNode, |
|
| 243 | + previousNode.getGraphLevel() == null ? 0.0 : previousNode.getGraphLevel().getDistanceToParent()); |
|
| 244 | + }; |
|
| 245 | + |
|
| 246 | + for (MstGraphLevel leaf : graphComponents.getLeaves()) { |
|
| 247 | + final InnerGraphSuccessorSupplier<GraphNode<MstGraphLevel>, MstGraphLevel> innerGraphSuccessorSupplier = |
|
| 248 | + new InnerGraphSuccessorSupplier<>(graphComponents, |
|
| 249 | + (final Supplier<String> nameSupplier) -> new GraphNode<MstGraphLevel>( |
|
| 250 | + null, null, new WindCourseRange(0, 360), 1.0, 0, null) { |
|
| 251 | + @Override |
|
| 252 | + public String toString() { |
|
| 253 | + return nameSupplier.get(); |
|
| 254 | + } |
|
| 255 | + }); |
|
| 256 | + |
|
| 257 | + final DijsktraShortestPathFinder<GraphNode<MstGraphLevel>> dijkstra = |
|
| 258 | + new DijkstraShortestPathFinderImpl<>( |
|
| 259 | + innerGraphSuccessorSupplier.getArtificialLeaf(leaf), |
|
| 260 | + innerGraphSuccessorSupplier.getArtificialRoot(), |
|
| 261 | + innerGraphSuccessorSupplier, edgeQualityMetric); |
|
| 262 | + |
|
| 263 | + for (GraphNode<MstGraphLevel> node : dijkstra.getShortestPath()) { |
|
| 264 | + if (node.getGraphLevel() != null) { |
|
| 265 | + Map<ManeuverTypeForClassification, List<Double>> votesForNode = |
|
| 266 | + result.computeIfAbsent(node.getGraphLevel(), k -> new HashMap<>()); |
|
| 267 | + votesForNode.computeIfAbsent(node.getManeuverType(), k -> new ArrayList<>()) |
|
| 268 | + .add(dijkstra.getPathQuality()); |
|
| 269 | + } |
|
| 270 | + } |
|
| 271 | + } |
|
| 272 | + |
|
| 273 | + return result; |
|
| 274 | + } |
|
| 275 | + |
|
| 218 | 276 | private Map<String, String> collectBestNodePerLevel(MstManeuverGraphComponents graphComponents, |
| 219 | - Map<MstGraphLevel, Integer> levelToId) { |
|
| 277 | + Map<MstGraphLevel, Integer> levelToId, |
|
| 278 | + Map<MstGraphLevel, Map<ManeuverTypeForClassification, List<Double>>> pathVotes) { |
|
| 220 | 279 | final Map<String, String> bestNodePerLevel = new HashMap<>(); |
| 221 | - // Use the MstBestPathsCalculatorImpl to get the best nodes |
|
| 222 | - final MstBestPathsCalculatorImpl calculator = new MstBestPathsCalculatorImpl(transitionProbabilitiesCalculator); |
|
| 223 | - for (final GraphLevelInference<MstGraphLevel> inference : calculator.getBestNodes(graphComponents)) { |
|
| 224 | - final MstGraphLevel level = inference.getGraphNode().getGraphLevel(); |
|
| 225 | - if (level != null) { |
|
| 280 | + |
|
| 281 | + // Determine best classification per node based on sum of path qualities |
|
| 282 | + for (Map.Entry<MstGraphLevel, Map<ManeuverTypeForClassification, List<Double>>> entry : pathVotes.entrySet()) { |
|
| 283 | + MstGraphLevel level = entry.getKey(); |
|
| 284 | + Map<ManeuverTypeForClassification, List<Double>> votes = entry.getValue(); |
|
| 285 | + |
|
| 286 | + // Find classification with highest sum of path qualities |
|
| 287 | + double maxSum = -1; |
|
| 288 | + ManeuverTypeForClassification bestType = null; |
|
| 289 | + |
|
| 290 | + for (Map.Entry<ManeuverTypeForClassification, List<Double>> typeVotes : votes.entrySet()) { |
|
| 291 | + double sum = typeVotes.getValue().stream().mapToDouble(Double::doubleValue).sum(); |
|
| 292 | + if (sum > maxSum) { |
|
| 293 | + maxSum = sum; |
|
| 294 | + bestType = typeVotes.getKey(); |
|
| 295 | + } |
|
| 296 | + } |
|
| 297 | + |
|
| 298 | + if (bestType != null) { |
|
| 226 | 299 | Integer nodeId = levelToId.get(level); |
| 227 | 300 | if (nodeId != null) { |
| 228 | - bestNodePerLevel.put(String.valueOf(nodeId), inference.getGraphNode().getManeuverType().name()); |
|
| 301 | + bestNodePerLevel.put(String.valueOf(nodeId), bestType.name()); |
|
| 229 | 302 | } |
| 230 | 303 | } |
| 231 | 304 | } |
| 305 | + |
|
| 232 | 306 | return bestNodePerLevel; |
| 233 | 307 | } |
| 308 | + |
|
| 309 | + /** |
|
| 310 | + * Formats path vote diagnostics for a node as a JSON string for inclusion in the export. |
|
| 311 | + */ |
|
| 312 | + private String formatPathVoteDiagnostics(MstGraphLevel level, |
|
| 313 | + Map<MstGraphLevel, Map<ManeuverTypeForClassification, List<Double>>> pathVotes) { |
|
| 314 | + Map<ManeuverTypeForClassification, List<Double>> votes = pathVotes.get(level); |
|
| 315 | + if (votes == null || votes.isEmpty()) { |
|
| 316 | + return "{}"; |
|
| 317 | + } |
|
| 318 | + |
|
| 319 | + StringBuilder sb = new StringBuilder(); |
|
| 320 | + sb.append("{"); |
|
| 321 | + boolean first = true; |
|
| 322 | + for (ManeuverTypeForClassification type : ManeuverTypeForClassification.values()) { |
|
| 323 | + List<Double> typeVotes = votes.get(type); |
|
| 324 | + if (typeVotes != null && !typeVotes.isEmpty()) { |
|
| 325 | + if (!first) sb.append(", "); |
|
| 326 | + first = false; |
|
| 327 | + double sum = typeVotes.stream().mapToDouble(Double::doubleValue).sum(); |
|
| 328 | + sb.append("\"").append(type.name()).append("\": {"); |
|
| 329 | + sb.append("\"pathCount\": ").append(typeVotes.size()); |
|
| 330 | + sb.append(", \"qualitySum\": ").append(sum); |
|
| 331 | + sb.append(", \"qualities\": ["); |
|
| 332 | + sb.append(typeVotes.stream().map(d -> String.format("%.6e", d)).collect(Collectors.joining(", "))); |
|
| 333 | + sb.append("]}"); |
|
| 334 | + } |
|
| 335 | + } |
|
| 336 | + sb.append("}"); |
|
| 337 | + return sb.toString(); |
|
| 338 | + } |
|
| 234 | 339 | |
| 235 | 340 | private void assignNodeIds(MstGraphLevel level, Map<MstGraphLevel, Integer> levelToId, int[] nodeIdCounter) { |
| 236 | 341 | levelToId.put(level, nodeIdCounter[0]++); |