You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

135 lines
5.2 KiB

  1. import pytest
  2. with pytest.suppress(ImportError):
  3. import numpy as np
  4. ref = np.array([[ 0, 3, 0, 0, 0, 11],
  5. [22, 0, 0, 0, 17, 11],
  6. [ 7, 5, 0, 1, 0, 11],
  7. [ 0, 0, 0, 0, 0, 11],
  8. [ 0, 0, 14, 0, 8, 11]])
  9. def assert_equal_ref(mat):
  10. np.testing.assert_array_equal(mat, ref)
  11. def assert_sparse_equal_ref(sparse_mat):
  12. assert_equal_ref(sparse_mat.todense())
  13. @pytest.requires_eigen_and_numpy
  14. def test_fixed():
  15. from pybind11_tests import fixed_r, fixed_c, fixed_passthrough_r, fixed_passthrough_c
  16. assert_equal_ref(fixed_c())
  17. assert_equal_ref(fixed_r())
  18. assert_equal_ref(fixed_passthrough_r(fixed_r()))
  19. assert_equal_ref(fixed_passthrough_c(fixed_c()))
  20. assert_equal_ref(fixed_passthrough_r(fixed_c()))
  21. assert_equal_ref(fixed_passthrough_c(fixed_r()))
  22. @pytest.requires_eigen_and_numpy
  23. def test_dense():
  24. from pybind11_tests import dense_r, dense_c, dense_passthrough_r, dense_passthrough_c
  25. assert_equal_ref(dense_r())
  26. assert_equal_ref(dense_c())
  27. assert_equal_ref(dense_passthrough_r(dense_r()))
  28. assert_equal_ref(dense_passthrough_c(dense_c()))
  29. assert_equal_ref(dense_passthrough_r(dense_c()))
  30. assert_equal_ref(dense_passthrough_c(dense_r()))
  31. @pytest.requires_eigen_and_numpy
  32. def test_nonunit_stride_from_python():
  33. from pybind11_tests import double_row, double_col, double_mat_cm, double_mat_rm
  34. counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3))
  35. first_row = counting_mat[0, :]
  36. first_col = counting_mat[:, 0]
  37. assert np.array_equal(double_row(first_row), 2.0 * first_row)
  38. assert np.array_equal(double_col(first_row), 2.0 * first_row)
  39. assert np.array_equal(double_row(first_col), 2.0 * first_col)
  40. assert np.array_equal(double_col(first_col), 2.0 * first_col)
  41. counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3))
  42. slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]]
  43. for slice_idx, ref_mat in enumerate(slices):
  44. assert np.array_equal(double_mat_cm(ref_mat), 2.0 * ref_mat)
  45. assert np.array_equal(double_mat_rm(ref_mat), 2.0 * ref_mat)
  46. @pytest.requires_eigen_and_numpy
  47. def test_nonunit_stride_to_python():
  48. from pybind11_tests import diagonal, diagonal_1, diagonal_n, block
  49. assert np.all(diagonal(ref) == ref.diagonal())
  50. assert np.all(diagonal_1(ref) == ref.diagonal(1))
  51. for i in range(-5, 7):
  52. assert np.all(diagonal_n(ref, i) == ref.diagonal(i)), "diagonal_n({})".format(i)
  53. assert np.all(block(ref, 2, 1, 3, 3) == ref[2:5, 1:4])
  54. assert np.all(block(ref, 1, 4, 4, 2) == ref[1:, 4:])
  55. assert np.all(block(ref, 1, 4, 3, 2) == ref[1:4, 4:])
  56. @pytest.requires_eigen_and_numpy
  57. def test_eigen_ref_to_python():
  58. from pybind11_tests import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
  59. chols = [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]
  60. for i, chol in enumerate(chols, start=1):
  61. mymat = chol(np.array([[1, 2, 4], [2, 13, 23], [4, 23, 77]]))
  62. assert np.all(mymat == np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])), "cholesky{}".format(i)
  63. @pytest.requires_eigen_and_numpy
  64. def test_special_matrix_objects():
  65. from pybind11_tests import incr_diag, symmetric_upper, symmetric_lower
  66. assert np.all(incr_diag(7) == np.diag([1, 2, 3, 4, 5, 6, 7]))
  67. asymm = np.array([[ 1, 2, 3, 4],
  68. [ 5, 6, 7, 8],
  69. [ 9, 10, 11, 12],
  70. [13, 14, 15, 16]])
  71. symm_lower = np.array(asymm)
  72. symm_upper = np.array(asymm)
  73. for i in range(4):
  74. for j in range(i + 1, 4):
  75. symm_lower[i, j] = symm_lower[j, i]
  76. symm_upper[j, i] = symm_upper[i, j]
  77. assert np.all(symmetric_lower(asymm) == symm_lower)
  78. assert np.all(symmetric_upper(asymm) == symm_upper)
  79. @pytest.requires_eigen_and_numpy
  80. def test_dense_signature(doc):
  81. from pybind11_tests import double_col, double_row, double_mat_rm
  82. assert doc(double_col) == "double_col(arg0: numpy.ndarray[float32[m, 1]]) -> numpy.ndarray[float32[m, 1]]"
  83. assert doc(double_row) == "double_row(arg0: numpy.ndarray[float32[1, n]]) -> numpy.ndarray[float32[1, n]]"
  84. assert doc(double_mat_rm) == "double_mat_rm(arg0: numpy.ndarray[float32[m, n]]) -> numpy.ndarray[float32[m, n]]"
  85. @pytest.requires_eigen_and_scipy
  86. def test_sparse():
  87. from pybind11_tests import sparse_r, sparse_c, sparse_passthrough_r, sparse_passthrough_c
  88. assert_sparse_equal_ref(sparse_r())
  89. assert_sparse_equal_ref(sparse_c())
  90. assert_sparse_equal_ref(sparse_passthrough_r(sparse_r()))
  91. assert_sparse_equal_ref(sparse_passthrough_c(sparse_c()))
  92. assert_sparse_equal_ref(sparse_passthrough_r(sparse_c()))
  93. assert_sparse_equal_ref(sparse_passthrough_c(sparse_r()))
  94. @pytest.requires_eigen_and_scipy
  95. def test_sparse_signature(doc):
  96. from pybind11_tests import sparse_passthrough_r, sparse_passthrough_c
  97. assert doc(sparse_passthrough_r) == "sparse_passthrough_r(arg0: scipy.sparse.csr_matrix[float32]) -> scipy.sparse.csr_matrix[float32]"
  98. assert doc(sparse_passthrough_c) == "sparse_passthrough_c(arg0: scipy.sparse.csc_matrix[float32]) -> scipy.sparse.csc_matrix[float32]"