Coverage for CIResults/tests/test_sandbox.py: 100%

180 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-19 09:20 +0000

1from unittest import skipIf 

2from unittest.mock import patch, MagicMock 

3from django.test import TestCase 

4 

5from CIResults.sandbox.io import IOWrapper, Server, Client 

6from CIResults.sandbox.lockdown import LockDown 

7 

8from io import BytesIO 

9import json 

10import sys 

11import os 

12 

13 

def create_pipe():
    """Create an OS pipe and wrap both ends as binary file objects.

    Returns:
        A ``(reader, writer)`` tuple: the read end opened in ``'rb'`` mode
        and the write end opened in ``'wb'`` mode.
    """
    read_fd, write_fd = os.pipe()
    reader = os.fdopen(read_fd, 'rb')
    writer = os.fdopen(write_fd, 'wb')
    return reader, writer

17 

18 

class IOWrapperTests(TestCase):
    """Tests for IOWrapper: stream selection and the send()/read() round trip."""

    # Contains non-ASCII characters so the encode/decode path is exercised
    MSG = "Some string with non-ascii characters - éè"

    def _create_pipe(self):
        """Return a ``(reader, writer)`` pipe whose ends are closed when the test ends."""
        p_in, p_out = create_pipe()
        # Cleanups run in LIFO order: the writer is closed (and flushed) before
        # the reader, so a final flush never targets a pipe without a reader.
        self.addCleanup(p_in.close)
        self.addCleanup(p_out.close)
        return p_in, p_out

    def test_streams__default_values(self):
        wrapper = IOWrapper()
        self.assertEqual(wrapper.stream_in, sys.stdin.buffer)
        self.assertEqual(wrapper.stream_out, sys.stdout.buffer)

    def test_streams__overriden(self):
        p_in, p_out = self._create_pipe()
        wrapper = IOWrapper(p_in, p_out)

        self.assertEqual(wrapper.stream_in, p_in)
        self.assertEqual(wrapper.stream_out, p_out)

    def test_send_then_read__byte_array(self):
        p_in, p_out = self._create_pipe()
        wrapper = IOWrapper(p_in, p_out)

        wrapper.send(self.MSG.encode())
        self.assertEqual(wrapper.read(), self.MSG)

    def test_send_then_read__string(self):
        p_in, p_out = self._create_pipe()
        wrapper = IOWrapper(p_in, p_out)

        wrapper.send(self.MSG)
        self.assertEqual(wrapper.read(), self.MSG)

    def test_read__header_too_short(self):
        # Input too short to even contain a message header
        wrapper = IOWrapper(BytesIO(b'hel'))
        self.assertRaisesMessage(IOError, "Invalid message format", wrapper.read)

    def test_read__message_too_short(self):
        # Input whose payload ends before the expected message length
        wrapper = IOWrapper(BytesIO(b'hello world'))
        self.assertRaisesMessage(IOError, "The message read is shorter than expected", wrapper.read)

55 

56 

class ServerTests(TestCase):
    """Drive a Server (with lockdown disabled) through a pair of in-process pipes."""

    def setUp(self):
        self.pc_r, self.pc_w = create_pipe()  # For the client -> server communication
        self.ps_r, self.ps_w = create_pipe()  # For server -> client communication

        # Close every pipe end when the test finishes (even if setUp fails
        # later on). Cleanups run in LIFO order, so each writer is closed
        # before its matching reader.
        for stream in (self.pc_r, self.pc_w, self.ps_r, self.ps_w):
            self.addCleanup(stream.close)

        # Now create a server, and don't forget to swap TX and RX!
        self.client = IOWrapper(self.ps_r, self.pc_w)
        self.server = Server(self.pc_r, self.ps_w, lockdown=False)

    def test_init(self):
        self.assertEqual(self.server.iowrapper.stream_in, self.pc_r)
        self.assertEqual(self.server.iowrapper.stream_out, self.ps_w)

    def test_serve_request__not_a_json(self):
        self.client.send("hello world")
        self.assertRaises(json.decoder.JSONDecodeError, self.server.serve_request)

    def _do_request(self, request):
        """Send ``request`` as JSON, let the server serve it, and return the decoded reply."""
        self.client.send(json.dumps(request))
        self.server.serve_request()
        return json.loads(self.client.read())

    def test_serve_request__non_existing_method(self):
        ret = self._do_request({'method': 'missing', 'kwargs': {}, 'version': 1})
        self.assertEqual(ret, {'return_code': 404})

    def test_rpc__set_user_script__invalid_syntax(self):
        ret = self._do_request({'method': IOWrapper.EXEC_USER_SCRIPT, 'version': 1,
                                'kwargs': {'script': "gfdsgfdg"}})
        self.assertEqual(ret, {'return_code': 400, 'reason': "Compilation error: name 'gfdsgfdg' is not defined"})

    def test_rpc__set_user_script__missing_script(self):
        ret = self._do_request({'method': IOWrapper.EXEC_USER_SCRIPT, 'version': 1,
                                'kwargs': {}})
        self.assertEqual(ret, {'return_code': 400, 'reason': "The 'script' parameter is missing"})

    def _exec_usr_function(self, script, func_name, usr_kwargs=None):
        """Install ``script`` on the server, then call ``func_name`` from it.

        Asserts that the script is accepted (return code 200) and returns the
        decoded reply of the CALL_USER_FUNCTION request.
        """
        # Avoid a shared mutable default argument
        if usr_kwargs is None:
            usr_kwargs = {}

        ret = self._do_request({'method': IOWrapper.EXEC_USER_SCRIPT, 'version': 1,
                                'kwargs': {'script': script}})
        self.assertEqual(ret, {'return_code': 200})

        # Now try calling the function
        return self._do_request({'method': IOWrapper.CALL_USER_FUNCTION, 'version': 1,
                                 'kwargs': {'user_function': func_name, 'user_kwargs': usr_kwargs}})

    def test_rpc__call_user_function__missing_rpc_parameters(self):
        ret = self._do_request({'method': IOWrapper.CALL_USER_FUNCTION, 'version': 1,
                                'kwargs': {}})
        self.assertEqual(ret, {'return_code': 400,
                               'reason': "Both the 'user_function' and 'user_kwargs' parameters are needed"})

    def test_rpc__call_user_function__unknown_function(self):
        ret = self._exec_usr_function("def helloworld(toto): pass", "invalid")
        self.assertEqual(ret, {'return_code': 404, 'reason': 'The user function does not exist'})

    def test_rpc__call_user_function__missing_argument(self):
        ret = self._exec_usr_function("def helloworld(toto): pass", "helloworld")
        self.assertEqual(ret, {'return_code': 400,
                               'reason': "helloworld() missing 1 required positional argument: 'toto'"})

    def test_rpc__call_user_function__success(self):
        ret = self._exec_usr_function("def func(): return 'called!'", "func")
        self.assertEqual(ret, {'return_code': 200, 'return_value': "called!"})

    # NOTE: Coverage with the Server in lockdown is provided by the integration test

123 

124 

@skipIf(not LockDown.is_supported(), "SECCOMP is missing")
class LockDownTests(TestCase):
    """Check which operations LockDown allows and which it blocks.

    Every ``check_fail__*`` method must succeed without lockdown and fail
    with it; every ``check_pass__*`` method must succeed with lockdown.
    Each operation runs in a forked child so the lockdown never applies to
    the test runner itself.
    """

    def test_make_coverage_happy(self):
        # This is already properly tested, but in a forked process which means
        # coverage does not get access to it
        from seccomplite import ALLOW, Filter

        ld = LockDown()
        ld.add_rule(ALLOW, "read")
        # Replace the filter with an allow-everything one (presumably the
        # filter start() loads) so this process is not actually restricted
        ld.f = Filter(def_action=ALLOW)
        ld.start()

    def _test_operation(self, method, with_lockdown=False):
        """Run ``method`` in a forked child, optionally under LockDown.

        Returns the raw wait status from ``os.waitpid()``: 0 when the child
        ran ``method`` to completion, non-zero when it raised (exit code 1)
        or was killed by a signal.
        """
        pid = os.fork()
        if pid == 0:
            # The child must never return into the test runner, otherwise an
            # exception raised by `method` would make it go on running the
            # rest of the test suite as a second process. Always leave via
            # os._exit(), which also skips any cleanup inherited from the parent.
            exit_code = 1  # pragma: no cover
            try:  # pragma: no cover
                if with_lockdown:  # pragma: no cover
                    LockDown().start()  # pragma: no cover
                method()  # pragma: no cover
                exit_code = 0  # pragma: no cover
            finally:  # pragma: no cover
                os._exit(exit_code)  # pragma: no cover
        return os.waitpid(pid, 0)[1]

    # ---------- Operations that should fail in lockdown mode but succeed otherwise ----------
    def check_fail__read_file(self):
        with open("/etc/resolv.conf", "r") as f:  # pragma: no cover
            f.readlines()  # pragma: no cover

    def check_fail__write_file(self):
        with open("/tmp/foo", "w") as f:  # pragma: no cover
            f.write("Short message")  # pragma: no cover

    def check_fail__stat_file(self):
        os.stat('/etc/resolv.conf')  # pragma: no cover

    def check_fail__reset_sandbox(self):
        from seccomplite import Filter, ALLOW  # pragma: no cover
        f = Filter(def_action=ALLOW)  # pragma: no cover
        f.load()  # pragma: no cover

    # ---------- Operations that should always succeed ----------
    def check_pass__big_alloc(self):
        x = "*" * 1000000  # pragma: no cover
        del x  # pragma: no cover

    def check_pass__import_standard_library(self):
        import re  # noqa # pragma: no cover
        import sys  # noqa # pragma: no cover
        import os  # noqa # pragma: no cover

    def test_operations(self):
        for op_name in [o for o in dir(self) if o.startswith('check_')]:
            operation = getattr(self, op_name)

            if op_name.startswith('check_fail__'):
                with self.subTest(msg="Checking operation {}: free mode".format(op_name)):
                    self.assertEqual(self._test_operation(operation, False), 0)
                with self.subTest(msg="Checking operation {}: lockdown mode".format(op_name)):
                    self.assertNotEqual(self._test_operation(operation, True), 0)
            elif op_name.startswith('check_pass__'):
                with self.subTest(msg="Checking operation {}: lockdown mode".format(op_name)):
                    self.assertEqual(self._test_operation(operation, True), 0)

188 

189 

class UserFunctionCallErrorTests(TestCase):
    """Checks the attributes and string form of Client.UserFunctionCallError."""

    def test_exception(self):
        code, reason = 400, 'reason'
        error = Client.UserFunctionCallError(code, reason)

        # Both constructor arguments are exposed as attributes...
        self.assertEqual(error.return_code, code)
        self.assertEqual(error.reason, reason)
        # ...and combined into the human-readable message
        self.assertEqual(str(error), "UserFunctionCallError 400: reason")
196 self.assertEqual(str(exc), "UserFunctionCallError 400: reason") 

197 

198 

class ClientTests(TestCase):
    """Unit tests for Client, with the server subprocess replaced by in-process pipes."""

    SCRIPT = "def helloworld(): return 'OK'"

    @patch('subprocess.Popen')
    @patch('subprocess.check_output')
    def setUp(self, chk_out_mock, popen_mocked):
        self.pc_r, self.pc_w = create_pipe()  # For the client -> server communication
        self.ps_r, self.ps_w = create_pipe()  # For server -> client communication

        # Close every pipe end when the test finishes. Cleanups run in LIFO
        # order, so each writer is closed before its matching reader.
        for stream in (self.pc_r, self.pc_w, self.ps_r, self.ps_w):
            self.addCleanup(stream.close)

        self.server_io = IOWrapper(stream_in=self.pc_r, stream_out=self.ps_w)
        # The mocked "server process" exposes our pipes as its stdout/stdin
        popen_mocked.return_value = MagicMock(stdout=self.ps_r, stdin=self.pc_w,
                                              kill=MagicMock())
        chk_out_mock.return_value = sys.executable.encode()

        # Pre-send a success for the first rpc call
        self.server_io.send('{"return_code": 200}')
        self.client = Client(self.SCRIPT)
        self.first_request = json.loads(self.server_io.read())

    @patch('CIResults.sandbox.io.Client.__init__', return_value=None)
    def test_get_or_create_instance(self, client_mocked):
        SCRIPT1 = 'hello'
        SCRIPT2 = 'hello2'

        script1 = Client.get_or_create_instance(SCRIPT1)
        self.assertEqual(Client.get_or_create_instance(SCRIPT1), script1)

        script2 = Client.get_or_create_instance(SCRIPT2)
        self.assertEqual(Client.get_or_create_instance(SCRIPT2), script2)

        self.assertNotEqual(script2, script1)

    def test_init_sequence(self):
        self.assertEqual(self.client.usr_script, self.SCRIPT)
        self.assertEqual(self.first_request,
                         {"method": "exec_user_script",
                          "kwargs": {"script": "def helloworld(): return 'OK'"},
                          "version": 1})

    def test_init_sequence_with_bad_script(self):
        RETURN_CODE = 101
        REASON = "Generic obscure reason"
        MSG = 'Error {} - Cannot set the user script: {}'.format(RETURN_CODE, REASON)

        self.client.rpc_call = MagicMock(return_value={'return_code': RETURN_CODE, 'reason': REASON})
        self.assertRaisesMessage(ValueError, MSG, self.client._restart_server)

    def test_rpc_call__retries(self):
        self.client._restart_server = MagicMock()

        for i in range(10):
            self.server_io.send('invalid {}'.format(i))
        self.assertRaisesMessage(IOError,
                                 "Failed to make an RPC call: method='remote_func', kwargs={}. version=1",
                                 self.client.rpc_call, 'remote_func')
        self.assertEqual(self.client._restart_server.call_count, 3)

    def test_rpc_call__retries_disabled(self):
        self.client._restart_server = MagicMock()

        for i in range(10):
            self.server_io.send('invalid {}'.format(i))
        self.assertRaisesMessage(IOError,
                                 "Failed to make an RPC call: method='remote_func', kwargs={}. version=1",
                                 self.client.rpc_call, 'remote_func', retry=False)
        self.assertEqual(self.client._restart_server.call_count, 0)

    def test_call_user_function__success(self):
        FUNC_NAME = 'custom'
        FUNC_ARGS = {'arg1': 'val1', 'arg2': 'val2'}
        FUNC_RETURN = 1234

        self.client.rpc_call = MagicMock(return_value={'return_code': 200, 'return_value': FUNC_RETURN})
        self.assertEqual(self.client.call_user_function(FUNC_NAME, FUNC_ARGS), FUNC_RETURN)
        self.client.rpc_call.assert_called_with(IOWrapper.CALL_USER_FUNCTION, {'user_function': FUNC_NAME,
                                                                               'user_kwargs': FUNC_ARGS})

    def test_call_user_function__failure(self):
        RETURN_CODE = 101
        REASON = "Generic obscure reason"

        self.client.rpc_call = MagicMock(return_value={'return_code': RETURN_CODE, 'reason': REASON})

        # The call must raise, and the exception must carry the server's answer
        with self.assertRaises(Client.UserFunctionCallError) as ctx:
            self.client.call_user_function('custom', {})
        self.assertEqual(ctx.exception.return_code, RETURN_CODE)
        self.assertEqual(ctx.exception.reason, REASON)

289 

290 

class IntegrationTests(TestCase):
    """End-to-end test: a real Client talking to a real server subprocess."""

    def test_call_user_function(self):
        client = Client(ClientTests.SCRIPT)
        # ClientTests.SCRIPT defines helloworld() as returning 'OK'; check that
        # the value makes the full round trip instead of just not raising
        self.assertEqual(client.call_user_function('helloworld', {}), 'OK')