Browse Source

fix in Matrix-vector multiplication

tempestpy_adaptions
TimQu 8 years ago
parent
commit
f16f18bbf6
  1. 25
      src/storm/storage/SparseMatrix.cpp

25
src/storm/storage/SparseMatrix.cpp

@ -1181,24 +1181,18 @@ namespace storm {
template<typename ValueType>
void SparseMatrix<ValueType>::multiplyWithVectorSequential(std::vector<ValueType> const& vector, std::vector<ValueType>& result) const {
if (&vector == &result) {
STORM_LOG_WARN("Matrix-vector-multiplication invoked but the target vector uses the same memory as the input vector. This requires to allocate auxiliary memory.");
std::vector<ValueType> tmpVector(this->getRowCount());
multiplyWithVectorSequential(vector, tmpVector);
result = std::move(tmpVector);
} else {
const_iterator it = this->begin();
const_iterator ite;
std::vector<index_type>::const_iterator rowIterator = rowIndications.begin();
typename std::vector<ValueType>::iterator resultIterator = result.begin();
typename std::vector<ValueType>::iterator resultIteratorEnd = result.end();
// If the vector to multiply with and the target vector are actually the same, we need an auxiliary variable
// to store the intermediate result.
if (&vector == &result) {
for (; resultIterator != resultIteratorEnd; ++rowIterator, ++resultIterator) {
ValueType tmpValue = storm::utility::zero<ValueType>();
for (ite = this->begin() + *(rowIterator + 1); it != ite; ++it) {
tmpValue += it->getValue() * vector[it->getColumn()];
}
*resultIterator = tmpValue;
}
} else {
for (; resultIterator != resultIteratorEnd; ++rowIterator, ++resultIterator) {
*resultIterator = storm::utility::zero<ValueType>();
@ -1212,6 +1206,12 @@ namespace storm {
#ifdef STORM_HAVE_INTELTBB
template<typename ValueType>
void SparseMatrix<ValueType>::multiplyWithVectorParallel(std::vector<ValueType> const& vector, std::vector<ValueType>& result) const {
if (&vector == &result) {
STORM_LOG_WARN("Matrix-vector-multiplication invoked but the target vector uses the same memory as the input vector. This requires to allocate auxiliary memory.");
std::vector<ValueType> tmpVector(this->getRowCount());
multiplyWithVectorParallel(vector, tmpVector);
result = std::move(tmpVector);
} else {
tbb::parallel_for(tbb::blocked_range<index_type>(0, result.size(), 10),
[&] (tbb::blocked_range<index_type> const& range) {
index_type startRow = range.begin();
@ -1232,6 +1232,7 @@ namespace storm {
}
});
}
}
#endif
template<typename ValueType>

Loading…
Cancel
Save