You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

259 lines
8.2 KiB

8 years ago
  1. import pytest
  2. import pybind11_tests
  3. from pybind11_tests import ConstructorStats
  4. def test_override(capture, msg):
  5. from pybind11_tests import (ExampleVirt, runExampleVirt, runExampleVirtVirtual,
  6. runExampleVirtBool)
  7. class ExtendedExampleVirt(ExampleVirt):
  8. def __init__(self, state):
  9. super(ExtendedExampleVirt, self).__init__(state + 1)
  10. self.data = "Hello world"
  11. def run(self, value):
  12. print('ExtendedExampleVirt::run(%i), calling parent..' % value)
  13. return super(ExtendedExampleVirt, self).run(value + 1)
  14. def run_bool(self):
  15. print('ExtendedExampleVirt::run_bool()')
  16. return False
  17. def get_string1(self):
  18. return "override1"
  19. def pure_virtual(self):
  20. print('ExtendedExampleVirt::pure_virtual(): %s' % self.data)
  21. class ExtendedExampleVirt2(ExtendedExampleVirt):
  22. def __init__(self, state):
  23. super(ExtendedExampleVirt2, self).__init__(state + 1)
  24. def get_string2(self):
  25. return "override2"
  26. ex12 = ExampleVirt(10)
  27. with capture:
  28. assert runExampleVirt(ex12, 20) == 30
  29. assert capture == """
  30. Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
  31. """ # noqa: E501 line too long
  32. with pytest.raises(RuntimeError) as excinfo:
  33. runExampleVirtVirtual(ex12)
  34. assert msg(excinfo.value) == 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'
  35. ex12p = ExtendedExampleVirt(10)
  36. with capture:
  37. assert runExampleVirt(ex12p, 20) == 32
  38. assert capture == """
  39. ExtendedExampleVirt::run(20), calling parent..
  40. Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
  41. """ # noqa: E501 line too long
  42. with capture:
  43. assert runExampleVirtBool(ex12p) is False
  44. assert capture == "ExtendedExampleVirt::run_bool()"
  45. with capture:
  46. runExampleVirtVirtual(ex12p)
  47. assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"
  48. ex12p2 = ExtendedExampleVirt2(15)
  49. with capture:
  50. assert runExampleVirt(ex12p2, 50) == 68
  51. assert capture == """
  52. ExtendedExampleVirt::run(50), calling parent..
  53. Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
  54. """ # noqa: E501 line too long
  55. cstats = ConstructorStats.get(ExampleVirt)
  56. assert cstats.alive() == 3
  57. del ex12, ex12p, ex12p2
  58. assert cstats.alive() == 0
  59. assert cstats.values() == ['10', '11', '17']
  60. assert cstats.copy_constructions == 0
  61. assert cstats.move_constructions >= 0
  62. def test_inheriting_repeat():
  63. from pybind11_tests import A_Repeat, B_Repeat, C_Repeat, D_Repeat, A_Tpl, B_Tpl, C_Tpl, D_Tpl
  64. class AR(A_Repeat):
  65. def unlucky_number(self):
  66. return 99
  67. class AT(A_Tpl):
  68. def unlucky_number(self):
  69. return 999
  70. obj = AR()
  71. assert obj.say_something(3) == "hihihi"
  72. assert obj.unlucky_number() == 99
  73. assert obj.say_everything() == "hi 99"
  74. obj = AT()
  75. assert obj.say_something(3) == "hihihi"
  76. assert obj.unlucky_number() == 999
  77. assert obj.say_everything() == "hi 999"
  78. for obj in [B_Repeat(), B_Tpl()]:
  79. assert obj.say_something(3) == "B says hi 3 times"
  80. assert obj.unlucky_number() == 13
  81. assert obj.lucky_number() == 7.0
  82. assert obj.say_everything() == "B says hi 1 times 13"
  83. for obj in [C_Repeat(), C_Tpl()]:
  84. assert obj.say_something(3) == "B says hi 3 times"
  85. assert obj.unlucky_number() == 4444
  86. assert obj.lucky_number() == 888.0
  87. assert obj.say_everything() == "B says hi 1 times 4444"
  88. class CR(C_Repeat):
  89. def lucky_number(self):
  90. return C_Repeat.lucky_number(self) + 1.25
  91. obj = CR()
  92. assert obj.say_something(3) == "B says hi 3 times"
  93. assert obj.unlucky_number() == 4444
  94. assert obj.lucky_number() == 889.25
  95. assert obj.say_everything() == "B says hi 1 times 4444"
  96. class CT(C_Tpl):
  97. pass
  98. obj = CT()
  99. assert obj.say_something(3) == "B says hi 3 times"
  100. assert obj.unlucky_number() == 4444
  101. assert obj.lucky_number() == 888.0
  102. assert obj.say_everything() == "B says hi 1 times 4444"
  103. class CCR(CR):
  104. def lucky_number(self):
  105. return CR.lucky_number(self) * 10
  106. obj = CCR()
  107. assert obj.say_something(3) == "B says hi 3 times"
  108. assert obj.unlucky_number() == 4444
  109. assert obj.lucky_number() == 8892.5
  110. assert obj.say_everything() == "B says hi 1 times 4444"
  111. class CCT(CT):
  112. def lucky_number(self):
  113. return CT.lucky_number(self) * 1000
  114. obj = CCT()
  115. assert obj.say_something(3) == "B says hi 3 times"
  116. assert obj.unlucky_number() == 4444
  117. assert obj.lucky_number() == 888000.0
  118. assert obj.say_everything() == "B says hi 1 times 4444"
  119. class DR(D_Repeat):
  120. def unlucky_number(self):
  121. return 123
  122. def lucky_number(self):
  123. return 42.0
  124. for obj in [D_Repeat(), D_Tpl()]:
  125. assert obj.say_something(3) == "B says hi 3 times"
  126. assert obj.unlucky_number() == 4444
  127. assert obj.lucky_number() == 888.0
  128. assert obj.say_everything() == "B says hi 1 times 4444"
  129. obj = DR()
  130. assert obj.say_something(3) == "B says hi 3 times"
  131. assert obj.unlucky_number() == 123
  132. assert obj.lucky_number() == 42.0
  133. assert obj.say_everything() == "B says hi 1 times 123"
  134. class DT(D_Tpl):
  135. def say_something(self, times):
  136. return "DT says:" + (' quack' * times)
  137. def unlucky_number(self):
  138. return 1234
  139. def lucky_number(self):
  140. return -4.25
  141. obj = DT()
  142. assert obj.say_something(3) == "DT says: quack quack quack"
  143. assert obj.unlucky_number() == 1234
  144. assert obj.lucky_number() == -4.25
  145. assert obj.say_everything() == "DT says: quack 1234"
  146. class DT2(DT):
  147. def say_something(self, times):
  148. return "DT2: " + ('QUACK' * times)
  149. def unlucky_number(self):
  150. return -3
  151. class BT(B_Tpl):
  152. def say_something(self, times):
  153. return "BT" * times
  154. def unlucky_number(self):
  155. return -7
  156. def lucky_number(self):
  157. return -1.375
  158. obj = BT()
  159. assert obj.say_something(3) == "BTBTBT"
  160. assert obj.unlucky_number() == -7
  161. assert obj.lucky_number() == -1.375
  162. assert obj.say_everything() == "BT -7"
# PyPy: Reference count > 1 causes call with noncopyable instance
# to fail in ncv1.print_nc()
@pytest.unsupported_on_pypy
@pytest.mark.skipif(not hasattr(pybind11_tests, 'NCVirt'),
                    reason="NCVirt test broken on ICPC")
def test_move_support():
    """Values returned by Python overrides are handed back to C++ correctly.

    ``NCVirtExt`` returns a fresh ``NonCopyable`` that it does not keep, and
    a ``Movable`` that it also stores on ``self``. ``NCVirtExt2`` does the
    reverse. Storing the ``NonCopyable`` makes ``print_nc`` raise
    ``RuntimeError``. The stored ``Movable`` still works; the final stats
    expect exactly one ``Movable`` copy construction. NOTE(review): that copy
    is presumably the referenced ``self.movable``; confirm against the C++
    side.

    The override bodies are deliberately written to control whether a
    reference is kept, so edit them with care.
    """
    from pybind11_tests import NCVirt, NonCopyable, Movable

    class NCVirtExt(NCVirt):
        def get_noncopyable(self, a, b):
            # Constructs and returns a new instance:
            nc = NonCopyable(a * a, b * b)
            return nc

        def get_movable(self, a, b):
            # Return a referenced copy
            self.movable = Movable(a, b)
            return self.movable

    class NCVirtExt2(NCVirt):
        def get_noncopyable(self, a, b):
            # Keep a reference: this is going to throw an exception
            self.nc = NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Return a new instance without storing it
            return Movable(a, b)

    ncv1 = NCVirtExt()
    # NonCopyable(4, 9) -> "36"; Movable(4, 5) -> "9".
    assert ncv1.print_nc(2, 3) == "36"
    assert ncv1.print_movable(4, 5) == "9"
    ncv2 = NCVirtExt2()
    assert ncv2.print_movable(7, 7) == "14"
    # Don't check the exception message here because it differs under debug/non-debug mode
    with pytest.raises(RuntimeError):
        ncv2.print_nc(9, 9)

    nc_stats = ConstructorStats.get(NonCopyable)
    mv_stats = ConstructorStats.get(Movable)
    # The survivors are presumably the objects still stored on ncv2.nc and
    # ncv1.movable; both counts drop to zero once ncv1/ncv2 are deleted.
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    del ncv1, ncv2
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    assert nc_stats.values() == ['4', '9', '9', '9']
    assert mv_stats.values() == ['4', '5', '7', '7']
    assert nc_stats.copy_constructions == 0
    assert mv_stats.copy_constructions == 1
    # Move counts are not pinned; any non-negative number is accepted.
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0