You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

69 lines
5.2 KiB

  1. #include "gtest/gtest.h"
  2. #include "storm-config.h"
  3. #include "src/storage/dd/CuddDdManager.h"
  4. #include "src/storage/dd/CuddAdd.h"
  5. #include "src/storage/dd/CuddBdd.h"
  6. #include "src/utility/solver.h"
  7. #include "src/settings/SettingsManager.h"
  8. #include "src/solver/SymbolicGameSolver.h"
  9. #include "src/settings/modules/NativeEquationSolverSettings.h"
  10. TEST(FullySymbolicGameSolverTest, Solve) {
  11. // Create some variables.
  12. std::shared_ptr<storm::dd::DdManager<storm::dd::DdType::CUDD>> manager(new storm::dd::DdManager<storm::dd::DdType::CUDD>());
  13. std::pair<storm::expressions::Variable, storm::expressions::Variable> state = manager->addMetaVariable("x", 1, 4);
  14. std::pair<storm::expressions::Variable, storm::expressions::Variable> pl1 = manager->addMetaVariable("a", 0, 1);
  15. std::pair<storm::expressions::Variable, storm::expressions::Variable> pl2 = manager->addMetaVariable("b", 0, 1);
  16. storm::dd::Bdd<storm::dd::DdType::CUDD> allRows = manager->getBddZero();
  17. std::set<storm::expressions::Variable> rowMetaVariables({state.first});
  18. std::set<storm::expressions::Variable> columnMetaVariables({state.second});
  19. std::vector<std::pair<storm::expressions::Variable, storm::expressions::Variable>> rowColumnMetaVariablePairs = {state};
  20. std::set<storm::expressions::Variable> player1Variables({pl1.first});
  21. std::set<storm::expressions::Variable> player2Variables({pl2.first});
  22. // Construct simple game.
  23. storm::dd::Add<storm::dd::DdType::CUDD> matrix = manager->getEncoding(state.first, 1).toAdd() * manager->getEncoding(state.second, 2).toAdd() * manager->getEncoding(pl1.first, 0).toAdd() * manager->getEncoding(pl2.first, 0).toAdd() * manager->getConstant(0.6);
  24. matrix += manager->getEncoding(state.first, 1).toAdd() * manager->getEncoding(state.second, 1).toAdd() * manager->getEncoding(pl1.first, 0).toAdd() * manager->getEncoding(pl2.first, 0).toAdd() * manager->getConstant(0.4);
  25. matrix += manager->getEncoding(state.first, 1).toAdd() * manager->getEncoding(state.second, 2).toAdd() * manager->getEncoding(pl1.first, 0).toAdd() * manager->getEncoding(pl2.first, 1).toAdd() * manager->getConstant(0.2);
  26. matrix += manager->getEncoding(state.first, 1).toAdd() * manager->getEncoding(state.second, 3).toAdd() * manager->getEncoding(pl1.first, 0).toAdd() * manager->getEncoding(pl2.first, 1).toAdd() * manager->getConstant(0.8);
  27. matrix += manager->getEncoding(state.first, 1).toAdd() * manager->getEncoding(state.second, 3).toAdd() * manager->getEncoding(pl1.first, 1).toAdd() * manager->getEncoding(pl2.first, 0).toAdd() * manager->getConstant(0.5);
  28. matrix += manager->getEncoding(state.first, 1).toAdd() * manager->getEncoding(state.second, 4).toAdd() * manager->getEncoding(pl1.first, 1).toAdd() * manager->getEncoding(pl2.first, 0).toAdd() * manager->getConstant(0.5);
  29. matrix += manager->getEncoding(state.first, 1).toAdd() * manager->getEncoding(state.second, 1).toAdd() * manager->getEncoding(pl1.first, 1).toAdd() * manager->getEncoding(pl2.first, 1).toAdd() * manager->getConstant(1);
  30. std::unique_ptr<storm::utility::solver::SymbolicGameSolverFactory<storm::dd::DdType::CUDD>> solverFactory(new storm::utility::solver::SymbolicGameSolverFactory<storm::dd::DdType::CUDD>());
  31. std::unique_ptr<storm::solver::SymbolicGameSolver<storm::dd::DdType::CUDD>> solver = solverFactory->create(matrix, allRows, rowMetaVariables, columnMetaVariables, rowColumnMetaVariablePairs, player1Variables,player2Variables);
  32. // Create solution and target state vector.
  33. storm::dd::Add<storm::dd::DdType::CUDD> x = manager->getAddZero();
  34. storm::dd::Add<storm::dd::DdType::CUDD> b = manager->getEncoding(state.first, 2).toAdd() + manager->getEncoding(state.first, 4).toAdd();
  35. // Now solve the game with different strategies for the players.
  36. storm::dd::Add<storm::dd::DdType::CUDD> result = solver->solveGame(storm::OptimizationDirection::Minimize, storm::OptimizationDirection::Minimize, x, b);
  37. result *= manager->getEncoding(state.first, 1).toAdd();
  38. result = result.sumAbstract({state.first});
  39. EXPECT_NEAR(0, result.getValue(), storm::settings::nativeEquationSolverSettings().getPrecision());
  40. x = manager->getAddZero();
  41. result = solver->solveGame(storm::OptimizationDirection::Minimize, storm::OptimizationDirection::Maximize, x, b);
  42. result *= manager->getEncoding(state.first, 1).toAdd();
  43. result = result.sumAbstract({state.first});
  44. EXPECT_NEAR(0.5, result.getValue(), storm::settings::nativeEquationSolverSettings().getPrecision());
  45. x = manager->getAddZero();
  46. result = solver->solveGame(storm::OptimizationDirection::Maximize, storm::OptimizationDirection::Minimize, x, b);
  47. result *= manager->getEncoding(state.first, 1).toAdd();
  48. result = result.sumAbstract({state.first});
  49. EXPECT_NEAR(0.2, result.getValue(), storm::settings::nativeEquationSolverSettings().getPrecision());
  50. x = manager->getAddZero();
  51. result = solver->solveGame(storm::OptimizationDirection::Maximize, storm::OptimizationDirection::Maximize, x, b);
  52. result *= manager->getEncoding(state.first, 1).toAdd();
  53. result = result.sumAbstract({state.first});
  54. EXPECT_NEAR(0.99999892625817599, result.getValue(), storm::settings::nativeEquationSolverSettings().getPrecision());
  55. }