From c1ec3032fa6b57aa989a145ee2e7febcca1d489d Mon Sep 17 00:00:00 2001
From: Sebastian Junges <sebastian.junges@gmail.com>
Date: Thu, 18 Feb 2021 23:21:54 -0800
Subject: [PATCH] reset to state

---
 src/storm/generator/CompressedState.cpp       | 19 +++++++++++++++++++
 src/storm/generator/CompressedState.h         |  5 ++++-
 src/storm/simulator/PrismProgramSimulator.cpp | 12 ++++++++++++
 src/storm/simulator/PrismProgramSimulator.h   |  4 ++++
 4 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/src/storm/generator/CompressedState.cpp b/src/storm/generator/CompressedState.cpp
index 041a0fac1..1cbf7438e 100644
--- a/src/storm/generator/CompressedState.cpp
+++ b/src/storm/generator/CompressedState.cpp
@@ -48,6 +48,25 @@ namespace storm {
             return result;
         }
 
+        CompressedState packStateFromValuation(expressions::SimpleValuation const& valuation, VariableInformation const& variableInformation, bool checkOutOfBounds) {
+            CompressedState result(variableInformation.getTotalBitOffset(true));
+            STORM_LOG_THROW(variableInformation.locationVariables.size() == 0, storm::exceptions::NotImplementedException, "Support for JANI is not implemented");
+            for (auto const& booleanVariable : variableInformation.booleanVariables) {
+                result.set(booleanVariable.bitOffset, valuation.getBooleanValue(booleanVariable.variable));
+            }
+            for (auto const& integerVariable : variableInformation.integerVariables) {
+                int64_t assignedValue = valuation.getIntegerValue(integerVariable.variable);
+                if (checkOutOfBounds) {
+                    STORM_LOG_THROW(assignedValue >= integerVariable.lowerBound, storm::exceptions::InvalidArgumentException, "The assignment leads to an out-of-bounds value (" << assignedValue << ") for the variable '" << integerVariable.getName() << "'.");
+                    STORM_LOG_THROW(assignedValue <= integerVariable.upperBound, storm::exceptions::InvalidArgumentException, "The assignment leads to an out-of-bounds value (" << assignedValue << ") for the variable '" << integerVariable.getName() << "'.");
+                }
+                result.setFromInt(integerVariable.bitOffset, integerVariable.bitWidth, assignedValue - integerVariable.lowerBound);
+                STORM_LOG_ASSERT(static_cast<int_fast64_t>(result.getAsInt(integerVariable.bitOffset, integerVariable.bitWidth)) + integerVariable.lowerBound == assignedValue, "Writing to the bit vector bucket failed (read " << result.getAsInt(integerVariable.bitOffset, integerVariable.bitWidth) << " but wrote " << assignedValue << ").");
+            }
+
+            return result;
+        }
+
         void extractVariableValues(CompressedState const& state, VariableInformation const& variableInformation, std::vector<int64_t>& locationValues, std::vector<bool>& booleanValues, std::vector<int64_t>& integerValues) {
             for (auto const& locationVariable : variableInformation.locationVariables) {
                 if (locationVariable.bitWidth != 0) {
diff --git a/src/storm/generator/CompressedState.h b/src/storm/generator/CompressedState.h
index 801e00fe9..b8d7fe05d 100644
--- a/src/storm/generator/CompressedState.h
+++ b/src/storm/generator/CompressedState.h
@@ -92,7 +92,10 @@ namespace storm {
         CompressedState createOutOfBoundsState(VariableInformation const& varInfo, bool roundTo64Bit = true);
 
         CompressedState createCompressedState(VariableInformation const& varInfo, std::map<storm::expressions::Variable, storm::expressions::Expression> const& stateDescription, bool checkOutOfBounds);
-    }
+
+        CompressedState packStateFromValuation(expressions::SimpleValuation const& valuation, VariableInformation const& variableInformation, bool checkOutOfBounds = false);
+
+        }
 }
 
 #endif /* STORM_GENERATOR_COMPRESSEDSTATE_H_ */
diff --git a/src/storm/simulator/PrismProgramSimulator.cpp b/src/storm/simulator/PrismProgramSimulator.cpp
index 07911ef43..0b72f6f58 100644
--- a/src/storm/simulator/PrismProgramSimulator.cpp
+++ b/src/storm/simulator/PrismProgramSimulator.cpp
@@ -110,6 +110,18 @@ namespace storm {
             return explore();
         }
 
+        template<typename ValueType>
+        bool DiscreteTimePrismProgramSimulator<ValueType>::resetToState(generator::CompressedState const& newState) {
+            currentState = newState;
+            return explore();
+        }
+
+        template<typename ValueType>
+        bool DiscreteTimePrismProgramSimulator<ValueType>::resetToState(expressions::SimpleValuation const& valuation) {
+            currentState = generator::packStateFromValuation(valuation, stateGenerator->getVariableInformation(), true);
+            return explore();
+        }
+
         template<typename ValueType>
         uint32_t DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex(generator::CompressedState const& state) {
             uint32_t newIndex = static_cast<uint32_t>(stateToId.size());
diff --git a/src/storm/simulator/PrismProgramSimulator.h b/src/storm/simulator/PrismProgramSimulator.h
index 034d2791d..99e828bf9 100644
--- a/src/storm/simulator/PrismProgramSimulator.h
+++ b/src/storm/simulator/PrismProgramSimulator.h
@@ -70,6 +70,10 @@ namespace storm {
              * @return
              */
             bool resetToInitial();
+
+            bool resetToState(generator::CompressedState const& compressedState);
+
+            bool resetToState(expressions::SimpleValuation const& valuationState);
         protected:
             bool explore();
             void clearStateCaches();