From 512da83a429808f48e480d3e09a429154b895d83 Mon Sep 17 00:00:00 2001
From: dehnert
Date: Mon, 27 Jun 2016 19:10:42 +0200
Subject: [PATCH] added proper mult_add to gmm++

Former-commit-id: 03a4f13a47be604cf56b44b4f3e7c334d44af371
---
 .../3rdparty/gmm-5.0/include/gmm/gmm_blas.h   | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/resources/3rdparty/gmm-5.0/include/gmm/gmm_blas.h b/resources/3rdparty/gmm-5.0/include/gmm/gmm_blas.h
index 2c9069c5b..4c7d82866 100644
--- a/resources/3rdparty/gmm-5.0/include/gmm/gmm_blas.h
+++ b/resources/3rdparty/gmm-5.0/include/gmm/gmm_blas.h
@@ -1858,6 +1858,40 @@ namespace gmm {
                     linalg_traits<L1>::sub_orientation>::potype());
     }
   }
+
+  /** Multiply-accumulate into a separate result vector: l4 = l1*l2 + l3.
+   *  @param l1 the matrix.
+   *  @param l2 the vector multiplied by l1; its size must equal the number
+   *            of columns of l1.
+   *  @param l3 the vector added to the product; its size must equal the
+   *            number of rows of l1.
+   *  @param l4 the result vector; it must have the same size as l3.
+   *
+   *  If l1 has no rows nothing is written. If l1 has rows but no columns
+   *  the product is empty and l3 is copied to l4. When l2 shares its origin
+   *  with l4, l2 is copied to a temporary first so that writing l4 cannot
+   *  change entries of l2 that still have to be read.
+   */
+  template <typename L1, typename L2, typename L3, typename L4> inline
+  void mult_add(const L1& l1, const L2& l2, const L3& l3, L4& l4) {
+    size_type m = mat_nrows(l1), n = mat_ncols(l1);
+    if (!m) return;
+    GMM_ASSERT2(n==vect_size(l2) && m==vect_size(l3) && vect_size(l3) == vect_size(l4), "dimensions mismatch");
+    if (!n) { copy(l3, l4); return; }
+    // l4 is the only vector written to, so only an overlap of l2 with l4 is
+    // hazardous; the row-wise kernel reads entry i of l3 before writing l4[i].
+    if (!same_origin(l2, l4)) {
+      mult_add_spec(l1, l2, l3, l4, typename principal_orientation_type<typename
+                    linalg_traits<L1>::sub_orientation>::potype());
+    }
+    else {
+      GMM_WARNING2("Warning, A temporary is used for mult\n");
+      typename temporary_vector<L2>::vector_type temp(vect_size(l2));
+      copy(l2, temp);
+      mult_add_spec(l1, temp, l3, l4, typename principal_orientation_type<typename
+                    linalg_traits<L1>::sub_orientation>::potype());
+    }
+  }
 
   ///@cond DOXY_SHOW_ALL_FUNCTIONS
   template <typename L1, typename L2, typename L3> inline
@@ -1893,6 +1927,20 @@ namespace gmm {
       *it += vect_sp(linalg_traits<L1>::row(itr), l2);
   }
 
+  /** Row-wise kernel for l4 = l1*l2 + l3 with a dense result vector: entry i
+   *  of l4 receives the scalar product of row i of l1 with l2, plus entry i
+   *  of l3. The traversal is driven by the iterator range of l3.
+   */
+  template <typename L1, typename L2, typename L3, typename L4>
+  void mult_add_by_row(const L1& l1, const L2& l2, const L3& l3, L4& l4, abstract_dense) {
+    typename linalg_traits<L3>::const_iterator add_it=vect_begin(l3), add_ite=vect_end(l3);
+    typename linalg_traits<L4>::iterator target_it=vect_begin(l4);
+    typename linalg_traits<L1>::const_row_iterator
+      itr = mat_row_const_begin(l1);
+    for (; add_it != add_ite; ++add_it, ++target_it, ++itr)
+      *target_it = vect_sp(linalg_traits<L1>::row(itr), l2) + *add_it;
+  }
+
   template <typename L1, typename L2, typename L3>
   void mult_add_by_col(const L1& l1, const L2& l2, L3& l3, abstract_dense) {
     size_type nc = mat_ncols(l1);
@@ -1922,6 +1970,13 @@ namespace gmm {
   void mult_add_spec(const L1& l1, const L2& l2, L3& l3, row_major)
   { mult_add_by_row(l1, l2, l3, typename linalg_traits<L3>::storage_type()); }
 
+  /** Row-major dispatch for l4 = l1*l2 + l3: forwards to mult_add_by_row,
+   *  selected by the storage type of the result vector l4.
+   */
+  template <typename L1, typename L2, typename L3, typename L4> inline
+  void mult_add_spec(const L1& l1, const L2& l2, const L3& l3, L4& l4, row_major)
+  { mult_add_by_row(l1, l2, l3, l4, typename linalg_traits<L4>::storage_type()); }
+
   template <typename L1, typename L2, typename L3> inline
   void mult_add_spec(const L1& l1, const L2& l2, L3& l3, col_major)
   { mult_add_by_col(l1, l2, l3, typename linalg_traits<L3>::storage_type()); }