From c33a8df85ffd76a25a28cfe215f1cd0bad5a2528 Mon Sep 17 00:00:00 2001
From: Jip Spel <jip.spel@cs.rwth-aachen.de>
Date: Thu, 31 Jan 2019 18:34:26 +0100
Subject: [PATCH] Eliminate selfloop introduced by SCC elimination

---
 src/storm-pars-cli/storm-pars.cpp             | 10 ++++-
 .../stateelimination/EliminatorBase.cpp       | 42 ++++++++++++++++++-
 .../solver/stateelimination/EliminatorBase.h  |  4 +-
 3 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/src/storm-pars-cli/storm-pars.cpp b/src/storm-pars-cli/storm-pars.cpp
index 60fda624b..bfc81bf72 100644
--- a/src/storm-pars-cli/storm-pars.cpp
+++ b/src/storm-pars-cli/storm-pars.cpp
@@ -564,7 +564,7 @@ namespace storm {
 
             if (parSettings.isSccEliminationSet()) {
                 // TODO: check for correct Model type
-                std::cout << "Applying scc elimination" << std::endl;
+                STORM_PRINT("Applying scc elimination" << std::endl);
                 auto sparseModel = model->as<storm::models::sparse::Model<ValueType>>();
                 auto matrix = sparseModel->getTransitionMatrix();
                 auto backwardsTransitionMatrix = matrix.transpose();
@@ -572,6 +572,7 @@ namespace storm {
                 auto decomposition = storm::storage::StronglyConnectedComponentDecomposition<ValueType>(matrix, false, false);
 
                 storm::storage::BitVector selectedStates(matrix.getRowCount());
+                storm::storage::BitVector selfLoopStates(matrix.getRowCount());
                 for (auto i = 0; i < decomposition.size(); ++i) {
                     auto scc = decomposition.getBlock(i);
                     if (scc.size() > 1) {
@@ -588,6 +589,7 @@ namespace storm {
                             }
                             if (found) {
                                 entryStates.push_back(state);
+                                selfLoopStates.set(state);
                             } else {
                                 selectedStates.set(state);
                             }
@@ -607,6 +609,10 @@ namespace storm {
                 for(auto state : selectedStates) {
                     stateEliminator.eliminateState(state, true);
                 }
+                for (auto state : selfLoopStates) {
+                    auto row = flexibleMatrix.getRow(state);
+                    stateEliminator.eliminateLoop(state);
+                }
                 selectedStates.complement();
                 auto keptRows = matrix.getRowFilter(selectedStates);
                 storm::storage::SparseMatrix<ValueType> newTransitionMatrix = flexibleMatrix.createSparseMatrix(keptRows, selectedStates);
@@ -620,7 +626,7 @@ namespace storm {
                 model = std::make_shared<storm::models::sparse::Dtmc<ValueType>>(std::move(newTransitionMatrix), sparseModel->getStateLabeling().getSubLabeling(selectedStates));
 
 
-                std::cout << "SCC Elimination applied" << std::endl;
+                STORM_PRINT("SCC Elimination applied" << std::endl);
             }
 
             if (parSettings.isMonotonicityAnalysisSet()) {
diff --git a/src/storm/solver/stateelimination/EliminatorBase.cpp b/src/storm/solver/stateelimination/EliminatorBase.cpp
index 041b610bf..bc523b073 100644
--- a/src/storm/solver/stateelimination/EliminatorBase.cpp
+++ b/src/storm/solver/stateelimination/EliminatorBase.cpp
@@ -249,7 +249,47 @@ namespace storm {
                     elementsWithEntryInColumnEqualRow.shrink_to_fit();
                 }
             }
-            
+
+            template<typename ValueType, ScalingMode Mode>
+            void EliminatorBase<ValueType, Mode>::eliminateLoop(uint64_t state) {
+                // Start by finding value of the selfloop.
+                bool hasEntryInColumn = false;
+                ValueType columnValue = storm::utility::zero<ValueType>();
+                FlexibleRowType& entriesInRow = matrix.getRow(state);
+                for (auto entryIt = entriesInRow.begin(), entryIte = entriesInRow.end(); entryIt != entryIte; ++entryIt) {
+                    if (entryIt->getColumn() == state) {
+                        columnValue = entryIt->getValue();
+                        hasEntryInColumn = true;
+                    }
+                }
+
+                // Scale all entries in this row.
+                // Depending on the scaling mode, we scale the other entries of the row.
+                STORM_LOG_TRACE((hasEntryInColumn ? "State has entry in column." : "State does not have entry in column."));
+                if (Mode == ScalingMode::Divide) {
+                    STORM_LOG_ASSERT(hasEntryInColumn, "The scaling mode 'divide' requires an element in the given column.");
+                    STORM_LOG_ASSERT(storm::utility::isZero(columnValue), "The scaling mode 'divide' requires a non-zero element in the given column.");
+                    columnValue = storm::utility::one<ValueType>() / columnValue;
+                } else if (Mode == ScalingMode::DivideOneMinus) {
+                    if (hasEntryInColumn) {
+                        STORM_LOG_ASSERT(columnValue != storm::utility::one<ValueType>(), "The scaling mode 'divide-one-minus' requires a non-one value in the given column.");
+                        columnValue = storm::utility::one<ValueType>() / (storm::utility::one<ValueType>() - columnValue);
+                        columnValue = storm::utility::simplify(columnValue);
+                    }
+                }
+
+                if (hasEntryInColumn) {
+                    for (auto entryIt = entriesInRow.begin(), entryIte = entriesInRow.end(); entryIt != entryIte; ++entryIt) {
+                        // Scale the entries in a different column, set state transition probability to 0.
+                        if (entryIt->getColumn() != state) {
+                            entryIt->setValue(storm::utility::simplify((ValueType) (entryIt->getValue() * columnValue)));
+                        } else {
+                            entryIt->setValue(storm::utility::zero<ValueType>());
+                        }
+                    }
+                }
+            }
+
             template<typename ValueType, ScalingMode Mode>
             void EliminatorBase<ValueType, Mode>::updateValue(storm::storage::sparse::state_type const&, ValueType const&) {
                 // Intentionally left empty.
diff --git a/src/storm/solver/stateelimination/EliminatorBase.h b/src/storm/solver/stateelimination/EliminatorBase.h
index 996da58e6..0d66ab18a 100644
--- a/src/storm/solver/stateelimination/EliminatorBase.h
+++ b/src/storm/solver/stateelimination/EliminatorBase.h
@@ -22,7 +22,9 @@ namespace storm {
                 virtual ~EliminatorBase() = default;
 
                 void eliminate(uint64_t row, uint64_t column, bool clearRow);
-                
+
+                void eliminateLoop(uint64_t row);
+
                 // Provide virtual methods that can be customized by subclasses to govern side-effect of the elimination.
                 virtual void updateValue(storm::storage::sparse::state_type const& state, ValueType const& loopProbability);
                 virtual void updatePredecessor(storm::storage::sparse::state_type const& predecessor, ValueType const& probability, storm::storage::sparse::state_type const& state);