Browse Source

choice labeling

tempestpy_adaptions
radioGiorgio 6 years ago
parent
commit
b2d7b1e096
  1. 86
      src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp
  2. 1
      src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h

86
src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp

@ -18,11 +18,19 @@ namespace storm {
storm::storage::sparse::ModelComponents<ValueType> components; storm::storage::sparse::ModelComponents<ValueType> components;
components.transitionMatrix = buildTransitions(); components.transitionMatrix = buildTransitions();
components.stateLabeling = buildStateLabeling(); components.stateLabeling = buildStateLabeling();
components.choiceLabeling = buildChoiceLabeling(components.transitionMatrix);
// Now delete unreachable states. // Now delete unreachable states.
storm::storage::BitVector allStates(components.transitionMatrix.getRowGroupCount(), true); storm::storage::BitVector allStates(components.transitionMatrix.getRowGroupCount(), true);
auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates); auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates);
storm::storage::BitVector enabledActions(components.transitionMatrix.getRowCount());
for (uint64_t state : reachableStates) {
for (uint64_t row = components.transitionMatrix.getRowGroupIndices()[state]; row < components.transitionMatrix.getRowGroupIndices()[state + 1]; ++ row) {
enabledActions.set(row);
}
}
components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates); components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates);
components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates); components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates);
components.choiceLabeling = components.choiceLabeling->getSubLabeling(enabledActions);
// build the remaining components // build the remaining components
for (auto const& rewModel : model.getRewardModels()) { for (auto const& rewModel : model.getRewardModels()) {
@ -130,12 +138,20 @@ namespace storm {
for (uint64_t modelState = 0; modelState < model.getNumberOfStates(); ++ modelState) { for (uint64_t modelState = 0; modelState < model.getNumberOfStates(); ++ modelState) {
if (forceLabeling) { if (forceLabeling) {
for (uint64_t memoryState = 0; memoryState < memory.getNumberOfStates(); ++ memoryState) { for (uint64_t memoryState = 0; memoryState < memory.getNumberOfStates(); ++ memoryState) {
if (labeling.getLabelsOfState(getProductState(modelState, memoryState)).empty()) {
std::ostringstream stream;
stream << "s" << modelState;
std::string labelName = stream.str();
addLabel(labelName, getProductState(modelState, memoryState));
}
{
std::ostringstream stream; std::ostringstream stream;
stream << "m" << memoryState; stream << "m" << memoryState;
std::string labelName = stream.str(); std::string labelName = stream.str();
addLabel(labelName, getProductState(modelState, memoryState)); addLabel(labelName, getProductState(modelState, memoryState));
} }
} }
}
uint64_t entryCount = 0; uint64_t entryCount = 0;
for (uint64_t row = origTransitions.getRowGroupIndices()[modelState]; row < origTransitions.getRowGroupIndices()[modelState + 1]; ++ row) { for (uint64_t row = origTransitions.getRowGroupIndices()[modelState]; row < origTransitions.getRowGroupIndices()[modelState + 1]; ++ row) {
for (auto const& entry : origTransitions.getRow(row)) { for (auto const& entry : origTransitions.getRow(row)) {
@ -143,7 +159,9 @@ namespace storm {
for (uint64_t memoryState = 0; memoryState < memory.getNumberOfStates(); ++ memoryState) { for (uint64_t memoryState = 0; memoryState < memory.getNumberOfStates(); ++ memoryState) {
uint64_t productState = getProductState(modelState, memoryState) + 1 + entryCount; uint64_t productState = getProductState(modelState, memoryState) + 1 + entryCount;
// origin state // origin state
if (model.getStateLabeling().getLabelsOfState(modelState).empty()) {
if ( model.getStateLabeling().getLabelsOfState(modelState).empty() or
(model.getStateLabeling().getLabelsOfState(modelState).size() == 1 and model.getStateLabeling().getStateHasLabel("init", modelState)
and not labeling.getStateHasLabel("init", productState)) ){
if (forceLabeling) { if (forceLabeling) {
std::ostringstream stream; std::ostringstream stream;
stream << "s" << modelState; stream << "s" << modelState;
@ -152,9 +170,11 @@ namespace storm {
} }
} else { } else {
for (auto const& labelName : model.getStateLabeling().getLabelsOfState(modelState)) { for (auto const& labelName : model.getStateLabeling().getLabelsOfState(modelState)) {
if (labelName != "init") {
addLabel(labelName, productState); addLabel(labelName, productState);
} }
} }
}
// memory labeling // memory labeling
if (forceLabeling) { if (forceLabeling) {
std::ostringstream stream; std::ostringstream stream;
@ -175,7 +195,9 @@ namespace storm {
} }
} }
// arrival state // arrival state
if (model.getStateLabeling().getLabelsOfState(successor).empty()) {
if ( model.getStateLabeling().getLabelsOfState(successor).empty() or
(model.getStateLabeling().getLabelsOfState(successor).size() == 1 and model.getStateLabeling().getStateHasLabel("init", successor)
and not labeling.getStateHasLabel("init", productState)) ){
if (forceLabeling) { if (forceLabeling) {
std::ostringstream stream; std::ostringstream stream;
stream << "s" << successor; stream << "s" << successor;
@ -184,10 +206,12 @@ namespace storm {
} }
} else { } else {
for (auto const& labelName : model.getStateLabeling().getLabelsOfState(successor)) { for (auto const& labelName : model.getStateLabeling().getLabelsOfState(successor)) {
if (labelName != "init") {
addLabel(labelName, productState); addLabel(labelName, productState);
} }
} }
} }
}
++ entryCount; ++ entryCount;
} }
} }
@ -195,6 +219,54 @@ namespace storm {
return labeling; return labeling;
} }
template<typename SparseModelType>
storm::models::sparse::ChoiceLabeling SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::buildChoiceLabeling(storm::storage::SparseMatrix<ValueType> const& transitions) const {
storm::storage::SparseMatrix<ValueType> const& origTransitions = model.getTransitionMatrix();
storm::models::sparse::ChoiceLabeling labeling(transitions.getRowCount());
auto addLabel = [&] (std::string const& labelName, uint64_t row) -> void {
if (not labeling.containsLabel(labelName)) {
labeling.addLabel(labelName);
}
labeling.addLabelToChoice(labelName, row);
};
uint64_t row = 0;
for (uint64_t modelState = 0; modelState < model.getNumberOfStates(); ++ modelState) {
for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++ memState) {
for (uint64_t origRow = origTransitions.getRowGroupIndices()[modelState]; origRow < origTransitions.getRowGroupIndices()[modelState + 1]; ++origRow) {
if (forceLabeling and (not model.getOptionalChoiceLabeling()
or model.getChoiceLabeling().getLabelsOfChoice(origRow).empty())) {
std::ostringstream stream;
stream << "a" << origRow;
std::string labelName = stream.str();
addLabel(labelName, row);
} else if (model.getOptionalChoiceLabeling()) {
for (auto const &labelName : model.getChoiceLabeling().getLabelsOfChoice(origRow)) {
addLabel(labelName, row);
}
}
++row;
}
// transition states
for (uint64_t origRow = origTransitions.getRowGroupIndices()[modelState]; origRow < origTransitions.getRowGroupIndices()[modelState + 1]; ++origRow) {
for (auto const& entry : origTransitions.getRow(origRow)) {
for (auto const& memStatePrime : memory.getTransitions(memState)) {
if (forceLabeling) {
std::ostringstream stream;
stream << "m" << memStatePrime;
std::string labelName = stream.str();
addLabel(labelName, row);
}
++row;
}
}
}
}
}
return labeling;
}
template<typename SparseModelType> template<typename SparseModelType>
storm::models::sparse::StandardRewardModel<typename SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::ValueType> SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::buildRewardModel(storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel, storm::storage::BitVector const& reachableStates, storm::storage::SparseMatrix<ValueType> const& resultTransitionMatrix) const { storm::models::sparse::StandardRewardModel<typename SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::ValueType> SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::buildRewardModel(storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel, storm::storage::BitVector const& reachableStates, storm::storage::SparseMatrix<ValueType> const& resultTransitionMatrix) const {
boost::optional<std::vector<ValueType>> stateRewards, actionRewards; boost::optional<std::vector<ValueType>> stateRewards, actionRewards;
@ -235,6 +307,7 @@ namespace storm {
template<typename SparseModelType> template<typename SparseModelType>
std::vector<uint64_t> SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::generateOffsetVector(storm::storage::BitVector const& reachableStates) { std::vector<uint64_t> SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::generateOffsetVector(storm::storage::BitVector const& reachableStates) {
uint64_t numberOfStates = model.getNumberOfStates() * memory.getNumberOfStates() * (1 + model.getNumberOfTransitions()); uint64_t numberOfStates = model.getNumberOfStates() * memory.getNumberOfStates() * (1 + model.getNumberOfTransitions());
STORM_LOG_ASSERT(reachableStates.size() == numberOfStates, "wrong size for the vector reachableStates");
uint64_t offset = 0; uint64_t offset = 0;
std::vector<uint64_t> offsetVector(numberOfStates); std::vector<uint64_t> offsetVector(numberOfStates);
for (uint64_t state = 0; state < numberOfStates; ++ state) { for (uint64_t state = 0; state < numberOfStates; ++ state) {
@ -262,18 +335,17 @@ namespace storm {
template<typename SparseModelType> template<typename SparseModelType>
uint64_t SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::getModelState(uint64_t productState) const { uint64_t SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::getModelState(uint64_t productState) const {
uint64_t productStateWithOffset = productState + (fullProductStatesOffset.empty() ? 0 : fullProductStatesOffset[productState]);
// binary search in the vector containing the product states indices // binary search in the vector containing the product states indices
auto search = std::upper_bound(productStates.begin(), productStates.end(), productState);
uint64_t index = search - productStates.begin() - 1;
return index - (fullProductStatesOffset.empty() ? 0 : fullProductStatesOffset[index]);
auto search = std::upper_bound(productStates.begin(), productStates.end(), productStateWithOffset);
return search - productStates.begin() - 1;
} }
template<typename SparseModelType> template<typename SparseModelType>
uint64_t SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::getMemoryState(uint64_t productState) const { uint64_t SparseModelNondeterministicTransitionsBasedMemoryProduct<SparseModelType>::getMemoryState(uint64_t productState) const {
uint64_t modelState = getModelState(productState); uint64_t modelState = getModelState(productState);
uint64_t offset = productState - productStates[modelState]; uint64_t offset = productState - productStates[modelState];
uint64_t index = offset / (1 + model.getTransitionMatrix().getRowGroupEntryCount(modelState));
return index - (fullProductStatesOffset.empty() ? 0 : fullProductStatesOffset[index]);
return offset / (1 + model.getTransitionMatrix().getRowGroupEntryCount(modelState));
} }
template class SparseModelNondeterministicTransitionsBasedMemoryProduct<storm::models::sparse::Mdp<double>>; template class SparseModelNondeterministicTransitionsBasedMemoryProduct<storm::models::sparse::Mdp<double>>;

1
src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h

@ -36,6 +36,7 @@ namespace storm {
private: private:
storm::storage::SparseMatrix<ValueType> buildTransitions(); storm::storage::SparseMatrix<ValueType> buildTransitions();
storm::models::sparse::StateLabeling buildStateLabeling() const; storm::models::sparse::StateLabeling buildStateLabeling() const;
storm::models::sparse::ChoiceLabeling buildChoiceLabeling(storm::storage::SparseMatrix<ValueType> const& transitions) const;
storm::models::sparse::StandardRewardModel<ValueType> buildRewardModel(storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel, storm::storage::BitVector const& reachableStates, storm::storage::SparseMatrix<ValueType> const& resultTransitionMatrix) const; storm::models::sparse::StandardRewardModel<ValueType> buildRewardModel(storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel, storm::storage::BitVector const& reachableStates, storm::storage::SparseMatrix<ValueType> const& resultTransitionMatrix) const;
std::vector<uint64_t> generateOffsetVector(storm::storage::BitVector const& reachableStates); std::vector<uint64_t> generateOffsetVector(storm::storage::BitVector const& reachableStates);

Loading…
Cancel
Save