You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
5.1 KiB
64 lines
5.1 KiB
#include "tracker.h"
|
|
#include "src/helpers.h"
|
|
#include <storm-pomdp/generator/BeliefSupportTracker.h>
|
|
#include <storm-pomdp/generator/NondeterministicBeliefTracker.h>
|
|
|
|
|
|
template<typename ValueType> using SparsePomdp = storm::models::sparse::Pomdp<ValueType>;
|
|
template<typename ValueType> using SparsePomdpTracker = storm::generator::BeliefSupportTracker<ValueType>;
|
|
|
|
template<typename ValueType> using NDPomdpTrackerSparse = storm::generator::NondeterministicBeliefTracker<ValueType, storm::generator::SparseBeliefState<ValueType>>;
|
|
template<typename ValueType> using NDPomdpTrackerDense = storm::generator::NondeterministicBeliefTracker<ValueType, storm::generator::ObservationDenseBeliefState<ValueType>>;
|
|
|
|
|
|
template<typename ValueType>
|
|
void define_tracker(py::module& m, std::string const& vtSuffix) {
|
|
py::class_<storm::generator::BeliefSupportTracker<ValueType>> tracker(m, ("BeliefSupportTracker" + vtSuffix).c_str(), "Tracker for BeliefSupports");
|
|
tracker.def(py::init<SparsePomdp<ValueType> const&>(), py::arg("pomdp"));
|
|
tracker.def("get_current_belief_support", &SparsePomdpTracker<ValueType>::getCurrentBeliefSupport, "What is the support given the trace so far");
|
|
tracker.def("track", &SparsePomdpTracker<ValueType>::track, py::arg("action"), py::arg("observation"));
|
|
|
|
py::class_<storm::generator::SparseBeliefState<ValueType>> sbel(m, ("SparseBeliefState" + vtSuffix).c_str(), "Belief state in sparse format");
|
|
sbel.def("get", &storm::generator::SparseBeliefState<ValueType>::get, py::arg("state"));
|
|
sbel.def_property_readonly("risk", &storm::generator::SparseBeliefState<ValueType>::getRisk);
|
|
sbel.def("__str__", &storm::generator::SparseBeliefState<ValueType>::toString);
|
|
sbel.def_property_readonly("is_valid", &storm::generator::SparseBeliefState<ValueType>::isValid);
|
|
//
|
|
// py::class_<storm::generator::ObservationDenseBeliefState<double>> dbel(m, "DenseBeliefStateDouble", "Belief state in dense format");
|
|
// dbel.def("get", &storm::generator::ObservationDenseBeliefState<double>::get, py::arg("state"));
|
|
// dbel.def_property_readonly("risk", &storm::generator::ObservationDenseBeliefState<double>::getRisk);
|
|
// dbel.def("__str__", &storm::generator::ObservationDenseBeliefState<double>::toString);
|
|
|
|
py::class_<typename NDPomdpTrackerSparse<ValueType>::Options> opts(m, ("NondeterministicBeliefTracker" + vtSuffix + "SparseOptions").c_str(), "Options for the corresponding tracker");
|
|
opts.def(py::init<>());
|
|
opts.def_readwrite("track_timeout", &NDPomdpTrackerSparse<ValueType>::Options::trackTimeOut);
|
|
opts.def_readwrite("reduction_timeout", &NDPomdpTrackerSparse<ValueType>::Options::timeOut);
|
|
opts.def_readwrite("reduction_wiggle", &NDPomdpTrackerSparse<ValueType>::Options::wiggle);
|
|
|
|
py::class_<NDPomdpTrackerSparse<ValueType>> ndetbelieftracker(m, ("NondeterministicBeliefTracker" + vtSuffix + "Sparse").c_str(), "Tracker for belief states and uncontrollable actions");
|
|
ndetbelieftracker.def(py::init<SparsePomdp<ValueType> const&, typename NDPomdpTrackerSparse<ValueType>::Options>(), py::arg("pomdp"), py::arg("options"));
|
|
ndetbelieftracker.def("reset", &NDPomdpTrackerSparse<ValueType>::reset);
|
|
ndetbelieftracker.def("set_risk", &NDPomdpTrackerSparse<ValueType>::setRisk, py::arg("risk"));
|
|
ndetbelieftracker.def("obtain_current_risk",&NDPomdpTrackerSparse<ValueType>::getCurrentRisk, py::arg("max")=true);
|
|
ndetbelieftracker.def("track", &NDPomdpTrackerSparse<ValueType>::track, py::arg("observation"));
|
|
ndetbelieftracker.def("obtain_beliefs", &NDPomdpTrackerSparse<ValueType>::getCurrentBeliefs);
|
|
ndetbelieftracker.def("size", &NDPomdpTrackerSparse<ValueType>::getNumberOfBeliefs);
|
|
ndetbelieftracker.def("dimension", &NDPomdpTrackerSparse<ValueType>::getCurrentDimension);
|
|
ndetbelieftracker.def("obtain_last_observation", &NDPomdpTrackerSparse<ValueType>::getCurrentObservation);
|
|
ndetbelieftracker.def("reduce",&NDPomdpTrackerSparse<ValueType>::reduce);
|
|
ndetbelieftracker.def("reduction_timed_out", &NDPomdpTrackerSparse<ValueType>::hasTimedOut);
|
|
|
|
// py::class_<NDPomdpTrackerDense<double>> ndetbelieftrackerd(m, "NondeterministicBeliefTrackerDoubleDense", "Tracker for belief states and uncontrollable actions");
|
|
// ndetbelieftrackerd.def(py::init<SparsePomdp<double> const&>(), py::arg("pomdp"));
|
|
// ndetbelieftrackerd.def("reset", &NDPomdpTrackerDense<double>::reset);
|
|
// ndetbelieftrackerd.def("set_risk", &NDPomdpTrackerDense<double>::setRisk, py::arg("risk"));
|
|
// ndetbelieftrackerd.def("obtain_current_risk",&NDPomdpTrackerDense<double>::getCurrentRisk, py::arg("max")=true);
|
|
// ndetbelieftrackerd.def("track", &NDPomdpTrackerDense<double>::track, py::arg("observation"));
|
|
// ndetbelieftrackerd.def("obtain_beliefs", &NDPomdpTrackerDense<double>::getCurrentBeliefs);
|
|
// ndetbelieftrackerd.def("obtain_last_observation", &NDPomdpTrackerDense<double>::getCurrentObservation);
|
|
// ndetbelieftrackerd.def("reduce",&NDPomdpTrackerDense<double>::reduce);
|
|
|
|
}
|
|
|
|
template void define_tracker<double>(py::module& m, std::string const& vtSuffix);
|
|
template void define_tracker<storm::RationalNumber>(py::module& m, std::string const& vtSuffix);
|