From aa6a3d214283c449b37aa75dc1dec60064ddefce Mon Sep 17 00:00:00 2001
From: Sebastian Junges <sebastian.junges@gmail.com>
Date: Wed, 23 Dec 2020 22:37:16 -0800
Subject: [PATCH] sampling from a distribution and from a choice

---
 src/storm/generator/Choice.cpp     |  5 +++++
 src/storm/generator/Choice.h       | 11 +++++++++++
 src/storm/storage/Distribution.cpp | 14 ++++++++++++++
 src/storm/storage/Distribution.h   | 11 +++++++++++
 4 files changed, 41 insertions(+)

diff --git a/src/storm/generator/Choice.cpp b/src/storm/generator/Choice.cpp
index 5d69c98ab..c021d59e5 100644
--- a/src/storm/generator/Choice.cpp
+++ b/src/storm/generator/Choice.cpp
@@ -44,6 +44,11 @@ namespace storm {
                 this->addOriginData(other.originData.get());
             }
         }
+
+        template<typename ValueType, typename StateType>
+        StateType Choice<ValueType, StateType>::sampleFromDistribution(ValueType const& quantile) const {
+            return distribution.sampleFromDistribution(quantile);
+        }
         
         template<typename ValueType, typename StateType>
         typename storm::storage::Distribution<ValueType, StateType>::iterator Choice<ValueType, StateType>::begin() {
diff --git a/src/storm/generator/Choice.h b/src/storm/generator/Choice.h
index 1dbc41760..f30e3c83b 100644
--- a/src/storm/generator/Choice.h
+++ b/src/storm/generator/Choice.h
@@ -28,6 +28,17 @@ namespace storm {
              * Adds the given choice to the current one.
              */
             void add(Choice const& other);
+
+            /**
+             * Given a value q, find the event in the ordered distribution that corresponds to this prob.
+             * Example: Given a (sub)distribution { x -> 0.4, y -> 0.3, z -> 0.2 },
+             * A value q in [0,0.4] yields x, q in [0.4, 0.7] yields y, and q in [0.7, 0.9] yields z.
+             * Any other value for q yields undefined behavior.
+             *
+             * @param quantile q, a value in the CDF.
+             * @return A state
+             */
+            StateType sampleFromDistribution(ValueType const& quantile) const;
             
             /*!
              * Returns an iterator to the distribution associated with this choice.
diff --git a/src/storm/storage/Distribution.cpp b/src/storm/storage/Distribution.cpp
index 2290c611c..0d4ceb8eb 100644
--- a/src/storm/storage/Distribution.cpp
+++ b/src/storm/storage/Distribution.cpp
@@ -176,6 +176,20 @@ namespace storm {
                 entry.second /= sum;
             }
         }
+
+        template<typename ValueType, typename StateType>
+        StateType Distribution<ValueType, StateType>::sampleFromDistribution(const ValueType &quantile) const {
+            ValueType sum = storm::utility::zero<ValueType>();
+            storm::utility::ConstantsComparator<ValueType> comp;
+            for (auto const& entry: distribution) {
+                sum += entry.second;
+                if (comp.isLess(quantile,sum)) {
+                    return entry.first;
+                }
+            }
+            STORM_LOG_ASSERT(false,"This point should not be reached.");
+            return 0;
+        }
     
         
         template class Distribution<double>;
diff --git a/src/storm/storage/Distribution.h b/src/storm/storage/Distribution.h
index c3ac58dcc..3c3af5db2 100644
--- a/src/storm/storage/Distribution.h
+++ b/src/storm/storage/Distribution.h
@@ -148,6 +148,17 @@ namespace storm {
              * Normalizes the distribution such that the values sum up to one.
              */
             void normalize();
+
+            /**
+            * Given a value q, find the event in the ordered distribution that corresponds to this prob.
+            * Example: Given a (sub)distribution { x -> 0.4, y -> 0.3, z -> 0.2 },
+            * A value q in [0,0.4] yields x, q in [0.4, 0.7] yields y, and q in [0.7, 0.9] yields z.
+            * Any other value for q yields undefined behavior.
+            *
+            * @param quantile q, a value in the CDF.
+            * @return A state
+            */
+            StateType sampleFromDistribution(ValueType const& quantile) const;
             
         private:
             // A list of states and the probabilities that are assigned to them.