diff --git a/resources/3rdparty/gmm-5.2/include/gmm/gmm_blas.h b/resources/3rdparty/gmm-5.2/include/gmm/gmm_blas.h index 4fbe070d3..9ee3f967a 100644 --- a/resources/3rdparty/gmm-5.2/include/gmm/gmm_blas.h +++ b/resources/3rdparty/gmm-5.2/include/gmm/gmm_blas.h @@ -1797,6 +1797,11 @@ namespace gmm { void mult_add_by_row_parallel(const L1& l1, const L2& l2, const L3& l3, L4& l4, abstract_dense) { tbb::parallel_for(tbb::blocked_range(0, vect_size(l4), 10), TbbMultAddFunctor(l1, l2, l3, l4)); } + + template + void mult_add_by_row_parallel(const L1& l1, const L2& l2, L3& l3, abstract_dense) { + tbb::parallel_for(tbb::blocked_range(0, vect_size(l3), 10), TbbMultAddFunctor(l1, l2, l3, l3)); + } #endif template @@ -1949,6 +1954,22 @@ namespace gmm { } } + /** Multiply-accumulate. l3 += l1*l2; */ + template inline + void mult_add_parallel(const L1& l1, const L2& l2, L3& l3) { + size_type m = mat_nrows(l1), n = mat_ncols(l1); + if (!m || !n) return; + GMM_ASSERT2(n==vect_size(l2) && m==vect_size(l3), "dimensions mismatch"); + if (!same_origin(l2, l3)) { + mult_add_parallel_spec(l1, l2, l3, typename principal_orientation_type::sub_orientation>::potype()); + } else { + GMM_WARNING2("Warning, A temporary is used for mult\n"); + typename temporary_vector::vector_type temp(vect_size(l2)); + copy(l2, temp); + mult_add_parallel_spec(l1, temp, l3, typename principal_orientation_type::sub_orientation>::potype()); + } + } + /** Multiply-accumulate. l4 = l1*l2 + l3; */ template inline void mult_add_parallel(const L1& l1, const L2& l2, const L3& l3, L4& l4) { @@ -2056,6 +2077,10 @@ namespace gmm { template inline void mult_add_parallel_spec(const L1& l1, const L2& l2, const L3& l3, L4& l4, row_major) { mult_add_by_row_parallel(l1, l2, l3, l4, typename linalg_traits::storage_type()); } + + template inline + void mult_add_parallel_spec(const L1& l1, const L2& l2, L3& l3, row_major) + { mult_add_by_row_parallel(l1, l2, l3, typename linalg_traits::storage_type()); } #endif template inline diff --git a/src/storm/solver/GmmxxMultiplier.cpp b/src/storm/solver/GmmxxMultiplier.cpp index d8c4c2468..6e0f727b0 100644 --- a/src/storm/solver/GmmxxMultiplier.cpp +++ b/src/storm/solver/GmmxxMultiplier.cpp @@ -120,7 +120,11 @@ namespace storm { template void GmmxxMultiplier::multAdd(std::vector const& x, std::vector const* b, std::vector& result) const { if (b) { - gmm::mult_add(gmmMatrix, x, *b, result); + if (b == &result) { + gmm::mult_add(gmmMatrix, x, result); + } else { + gmm::mult_add(gmmMatrix, x, *b, result); + } } else { gmm::mult(gmmMatrix, x, result); } @@ -257,7 +261,11 @@ namespace storm { void GmmxxMultiplier::multAddParallel(std::vector const& x, std::vector const* b, std::vector& result) const { #ifdef STORM_HAVE_INTELTBB if (b) { - gmm::mult_add_parallel(gmmMatrix, x, *b, result); + if (b == &result) { + gmm::mult_add_parallel(gmmMatrix, x, result); + } else { + gmm::mult_add_parallel(gmmMatrix, x, *b, result); + } } else { gmm::mult_parallel(gmmMatrix, x, result); }