diff --git a/src/dft/dft.cpp b/src/dft/dft.cpp index 2345012..459347d 100644 --- a/src/dft/dft.cpp +++ b/src/dft/dft.cpp @@ -43,6 +43,9 @@ void define_dft(py::module& m) { return dft.getElement(dft.getTopLevelIndex()); }, "Get top level element") .def("get_element", &DFT::getElement, "Get DFT element at index", py::arg("index")) + .def("get_element_by_name", [](DFT& dft, std::string const& name) { + return dft.getElement(dft.getIndex(name)); + }, "Get DFT element by name", py::arg("name")) .def("modularisation", &DFT::topModularisation, "Split DFT into independent modules") .def("symmetries", [](DFT& dft) { return dft.findSymmetries(dft.colourDFT()); @@ -59,6 +62,9 @@ void define_dft(py::module& m) { return dft.getElement(dft.getTopLevelIndex()); }, "Get top level element") .def("get_element", &DFT::getElement, "Get DFT element at index", py::arg("index")) + .def("get_element_by_name", [](DFT& dft, std::string const& name) { + return dft.getElement(dft.getIndex(name)); + }, "Get DFT element by name", py::arg("name")) .def("modularisation", &DFT::topModularisation, "Split DFT into independent modules") .def("symmetries", [](DFT& dft) { return dft.findSymmetries(dft.colourDFT()); diff --git a/tests/dft/test_dft.py b/tests/dft/test_dft.py index d68b029..3dff554 100644 --- a/tests/dft/test_dft.py +++ b/tests/dft/test_dft.py @@ -1,4 +1,5 @@ import os +import pytest import stormpy import stormpy.logic @@ -30,6 +31,16 @@ class TestDftElement: assert dft.nr_dynamic() == 0 assert tle.id == 2 assert tle.name == "A" + b = dft.get_element(0) + assert b.id == 0 + assert b.name == "B" + c = dft.get_element_by_name("C") + assert c.id == 1 + assert c.name == "C" + # Invalid name should raise exception + with pytest.raises(RuntimeError) as exception: + d = dft.get_element_by_name("D") + assert "InvalidArgumentException" in str(exception.value) @dft class TestDftSymmetries: