diff --git a/lib/stormpy/__init__.py b/lib/stormpy/__init__.py
index 93a83ad..8ffd318 100644
--- a/lib/stormpy/__init__.py
+++ b/lib/stormpy/__init__.py
@@ -137,11 +137,13 @@ def perform_bisimulation(model, properties, bisimulation_type):
         return core._perform_bisimulation(model, formulae, bisimulation_type)
 
 
-def model_checking(model, property):
+def model_checking(model, property, only_initial_states=False):
     """
     Perform model checking on model for property.
     :param model: Model.
     :param property: Property to check for.
+    :param only_initial_states: If True, only results for initial states are computed.
+                                If False, results for all states are computed.
     :return: Model checking result.
     :rtype: CheckResult
     """
@@ -151,10 +153,10 @@ def model_checking(model, property):
         formula = property
 
     if model.supports_parameters:
-        task = core.ParametricCheckTask(formula, False)
+        task = core.ParametricCheckTask(formula, only_initial_states)
         return core._parametric_model_checking_sparse_engine(model, task)
     else:
-        task = core.CheckTask(formula, False)
+        task = core.CheckTask(formula, only_initial_states)
         return core._model_checking_sparse_engine(model, task)
 
 
diff --git a/tests/core/test_modelchecking.py b/tests/core/test_modelchecking.py
index b0c1c1b..70b6ae5 100644
--- a/tests/core/test_modelchecking.py
+++ b/tests/core/test_modelchecking.py
@@ -69,6 +69,17 @@ class TestModelChecking:
         reference = [0.16666666666666663, 0.3333333333333333, 0, 0.6666666666666666, 0, 0, 0, 1, 0, 0, 0, 0, 0]
         assert all(map(math.isclose, result.get_values(), reference))
 
+    def test_model_checking_only_initial(self):
+        program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm"))
+        formulas = stormpy.parse_properties_for_prism_program("Pmax=? [F{\"coin_flips\"}<=3 \"one\"]", program)
+        model = stormpy.build_model(program, formulas)
+        assert len(model.initial_states) == 1
+        initial_state = model.initial_states[0]
+        assert initial_state == 0
+        result = stormpy.model_checking(model, formulas[0], only_initial_states=True)
+        assert not result.result_for_all_states
+        assert math.isclose(result.at(initial_state), 0.125)
+
     def test_model_checking_prob01(self):
         program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm"))
         formulaPhi = stormpy.parse_properties("true")[0]