Coverage for CIResults/sandbox/io.py: 100%

134 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 13:11 +0000

1import os 

2import sys 

3import json 

4import array 

5import struct 

6import subprocess 

7from threading import RLock 

8from subprocess import PIPE 

9from types import ModuleType 

10 

11# Make sure the sandbox module is available 

12sys.path.append(os.path.dirname(os.path.normpath(__file__))) 

13from lockdown import LockDown # noqa 

14 

15 

class IOWrapper:
    """Length-prefixed message framing over a pair of binary streams.

    Every message is a 4-byte big-endian unsigned length header followed by
    exactly that many bytes of payload. Received payloads are decoded as
    UTF-8 text.
    """

    # Commands
    EXEC_USER_SCRIPT = "exec_user_script"
    CALL_USER_FUNCTION = "call_user_function"

    # 32-bit, big-endian, unsigned length header
    _HEADER = struct.Struct(">L")

    def __init__(self, stream_in=None, stream_out=None):
        """Wrap the given binary streams.

        :param stream_in: readable binary stream messages are read from
                          (defaults to ``sys.stdin.buffer``)
        :param stream_out: writable binary stream messages are written to
                           (defaults to ``sys.stdout.buffer``)
        """
        self.stream_in = stream_in if stream_in is not None else sys.stdin.buffer
        self.stream_out = stream_out if stream_out is not None else sys.stdout.buffer

    def send(self, data):
        """Send one message and flush the output stream.

        :param data: bytes-like payload sent as-is; any other value is
                     converted with ``str()`` and UTF-8 encoded
        :raises struct.error: if the payload does not fit a 32-bit length
        """
        if isinstance(data, (bytearray, memoryview)):
            data = bytes(data)
        elif not isinstance(data, bytes):
            data = str(data).encode()

        # One write call, so the header and its payload always go out together
        self.stream_out.write(self._HEADER.pack(len(data)) + data)
        self.stream_out.flush()

    def _read_exactly(self, size):
        """Read *size* bytes, looping over short reads; returns less only at EOF."""
        chunks = []
        remaining = size
        while remaining > 0:
            # Raw streams / pipes may legitimately return fewer bytes than asked
            chunk = self.stream_in.read(remaining)
            if not chunk:
                break  # EOF (or no data available on a non-blocking stream)
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def read(self):
        """Read one message and return its payload decoded as UTF-8 text.

        :raises IOError: "Invalid message format" if the length header is
                         incomplete, or "The message read is shorter than
                         expected" if the stream ends mid-payload
        :raises UnicodeDecodeError: if the payload is not valid UTF-8
        """
        header = self._read_exactly(self._HEADER.size)
        try:
            (length,) = self._HEADER.unpack(header)
        except struct.error as err:
            raise IOError("Invalid message format") from err

        content = self._read_exactly(length)
        if len(content) != length:
            raise IOError("The message read is shorter than expected")

        return content.decode()

46 

47 

class Server:
    """Sandbox-side RPC server: loads a user script and runs its functions on request.

    Requests and replies are JSON documents framed by ``IOWrapper``. A request
    is ``{'method': ..., 'kwargs': {...}, 'version': ...}`` and is dispatched
    to the ``rpc__<method>`` method. Every reply carries a ``return_code``.
    """

    def __init__(self, stream_in=None, stream_out=None, lockdown=True):
        """Set up the transport and, optionally, restrict the process' syscalls.

        :param stream_in: binary stream requests are read from (default: stdin)
        :param stream_out: binary stream replies are written to (default: stdout)
        :param lockdown: when True, install ``LockDown`` rules that allow
                         ``read`` only on the input stream's fd and ``write``
                         only on the output stream's and stderr's fds. What
                         happens to every other syscall is decided by LockDown's
                         default policy (presumably denied -- see lockdown.py)
        """
        self.iowrapper = IOWrapper(stream_in=stream_in, stream_out=stream_out)
        self.usr_module = None

        # NOTE: The code is covered, but since we have to run it in a separate
        # process, the coverage does not pick it up
        if lockdown:
            from seccomplite import ALLOW, EQ, Arg  # pragma: no cover

            ld = LockDown()  # pragma: no cover
            ld.add_rule(ALLOW, "read", Arg(0, EQ, self.iowrapper.stream_in.fileno()))  # pragma: no cover
            ld.add_rule(ALLOW, "write", Arg(0, EQ, self.iowrapper.stream_out.fileno()))  # pragma: no cover
            ld.add_rule(ALLOW, "write", Arg(0, EQ, sys.stderr.fileno()))  # pragma: no cover
            ld.start()  # pragma: no cover

    def serve_request(self):
        """Read one RPC call, dispatch it and send back the JSON reply.

        Error replies: 400 for a call missing 'method'/'kwargs'/'version' (or
        whose 'kwargs' is not an object), 404 for an unknown method, 500 when
        the method's result cannot be serialized to JSON.

        :raises IOError: propagated from the transport, e.g. once the client
                         closed the pipe (this is what ends ``serve_forever``)
        :raises ValueError: if the received message is not valid JSON
        """
        rpc_call = json.loads(self.iowrapper.read())

        try:
            method = rpc_call['method']
            kwargs = rpc_call['kwargs']
            version = rpc_call['version']
            well_formed = isinstance(kwargs, dict)
        except (KeyError, TypeError):
            well_formed = False
        if not well_formed:
            self.iowrapper.send(json.dumps({'return_code': 400, 'reason': 'Malformed RPC call'}))
            return

        # NOTE: All the callable methods are prefixed with rpc__, so only those
        # can be reached from the client. Only the lookup is guarded: an
        # AttributeError raised *inside* a method must not look like a 404.
        func = getattr(self, "rpc__{}".format(method), None)
        if func is None:
            self.iowrapper.send(json.dumps({'return_code': 404}))
            return

        ret = func(kwargs, version)
        try:
            reply = json.dumps(ret)
        except (TypeError, ValueError) as e:
            # e.g. a user function returned a set: report it instead of crashing
            reply = json.dumps({'return_code': 500,
                                'reason': 'Cannot serialize the return value: {}'.format(e)})
        self.iowrapper.send(reply)

    def serve_forever(self):
        """Serve requests until one fails (e.g. the client closes the pipe)."""
        while True:  # pragma: no cover
            self.serve_request()  # pragma: no cover

    # ----- RPC methods (need to be prefixed with 'rpc__') -----

    def rpc__exec_user_script(self, kwargs, version):
        """Replace the user module with a fresh one built from kwargs['script'].

        Users can send new scripts without restarting the service. The new
        module is only installed once the whole script executed successfully;
        on any failure no user module is loaded.

        :param kwargs: dict whose 'script' entry holds the Python source
        :param version: RPC protocol version (currently unused)
        :return: ``{'return_code': 200}``, or 400 with a 'reason'
        """
        # Drop the current user module. Reset it rather than `del` the
        # attribute: after a `del`, the next call would raise AttributeError.
        self.usr_module = None

        user_script = kwargs.get('script')
        if user_script is None:
            return {'return_code': 400, 'reason': "The 'script' parameter is missing"}

        # Create a dynamic module.
        # SECURITY: exec() of client-provided code is the purpose of this
        # server; containment relies on the LockDown rules set in __init__.
        module = ModuleType('usr_script')
        try:
            code = compile(user_script, "<user script>", 'exec')
            exec(code, module.__dict__)
        except Exception as e:
            return {'return_code': 400, 'reason': 'Compilation error: {}'.format(e)}

        self.usr_module = module
        return {'return_code': 200}

    def rpc__call_user_function(self, kwargs, version):
        """Call a function of the user module with keyword arguments.

        :param kwargs: dict with 'user_function' (attribute name in the user
                       module) and 'user_kwargs' (keyword arguments for it)
        :param version: RPC protocol version (currently unused)
        :return: ``{'return_code': 200, 'return_value': ...}``; 404 if the
                 function does not exist (or no script is loaded); 400 with a
                 'reason' for missing parameters or if the call raised
        """
        if 'user_function' not in kwargs or 'user_kwargs' not in kwargs:
            return {'return_code': 400,
                    'reason': "Both the 'user_function' and 'user_kwargs' parameters are needed"}

        try:
            # TypeError: the function name is not a string
            func = getattr(self.usr_module, kwargs['user_function'])
        except (AttributeError, TypeError):
            return {'return_code': 404, 'reason': 'The user function does not exist'}

        try:
            ret = func(**kwargs['user_kwargs'])
            return {'return_code': 200, 'return_value': ret}
        except Exception as e:
            return {'return_code': 400, 'reason': str(e)}

114 

115 

class Client:
    """Client-side handle on a sandboxed ``Server`` process running one user script.

    Each instance spawns this very file as a subprocess (which starts a
    ``Server``), loads the user script into it and forwards RPC calls over
    the process' stdin/stdout pipes.
    """

    # user script -> Client instance, filled by get_or_create_instance()
    _instance_cache = {}

    class UserFunctionCallError(Exception):
        """Raised when the sandbox reports that a user-function call failed."""

        def __init__(self, return_code, reason):
            # Forward to Exception so `args` is populated (repr, pickling)
            super().__init__(return_code, reason)
            self.return_code = return_code
            self.reason = reason

        def __str__(self):
            return "UserFunctionCallError {}: {}".format(self.return_code, self.reason)

    @classmethod
    def get_or_create_instance(cls, script):
        """Return the cached client for *script*, creating (and caching) one if needed."""
        instance = cls._instance_cache.get(script)
        if instance is None:
            instance = cls._instance_cache[script] = cls(script)
        return instance

    @property
    def interpreter(self):
        """
        Intuitively, 'sys.executable' would be used for this purpose, but when using
        uwsgi, it clobbers the actual interpreter pointed to by sys.executable, and
        replaces it with the path to the uwsgi executable...because that makes sense.

        One workaround would be to use a shebang, but that won't work, as the env is wiped
        when starting the server process. So the PATH variable won't reflect the updated
        PATH when using a venv (either locally or in the Docker container).

        `which python3` works, as it will return the path to the interpreter based on the
        PATH variable. This works whether using a venv locally or in the docker container,
        as it is executed separately in the current env.

        NOTE: This is not 100% foolproof. If the user is not using the same interpreter
        for execution as is pointed to in their PATH var, then this could call a different
        interpreter. That would be bad practice, but it could happen.

        If following the standard setup guide, or using Docker containers, this
        works perfectly fine though. There could be room for improvement here, as uwsgi
        maintainers are well aware of the 'sys.executable' issue and don't seem to care.
        """
        if hasattr(self, '_interpreter'):
            return self._interpreter

        interp = subprocess.check_output(["which", "python3"])
        self._interpreter = interp.decode().strip()
        return self._interpreter

    def _stop_server(self):
        """Kill the sandbox process (if any), reap it and close its pipes.

        Safe to call repeatedly: the process handle is dropped once stopped,
        so the same process is never signalled twice.
        """
        proc = getattr(self, "sesh", None)
        if proc is None:
            return
        del self.sesh

        proc.kill()
        # Reap the child so it does not linger as a zombie
        proc.wait()
        for stream in (proc.stdin, proc.stdout):
            if stream is not None:
                try:
                    stream.close()
                except OSError:
                    pass  # e.g. flushing into a pipe whose reader is already gone

    def shutdown(self):
        """Stop the sandbox process and drop this client from the instance cache.

        The cache entry is only removed when it refers to this very instance,
        so shutting down a directly-constructed client never evicts another
        client cached for the same script.
        """
        self._stop_server()

        script = getattr(self, "usr_script", None)
        if self._instance_cache.get(script) is self:
            del self._instance_cache[script]

    def _restart_server(self):
        """(Re)spawn the sandbox process and load the user script into it.

        :raises ValueError: if the sandbox refuses the user script
        """
        # Only stop the process: going through shutdown() would also evict
        # this instance from the cache, and get_or_create_instance() would
        # then spawn a duplicate sandbox for the same script.
        self._stop_server()

        # The server is this very file, started with an empty environment;
        # see `interpreter` for why the interpreter path is resolved that way.
        self.sesh = subprocess.Popen([self.interpreter, __file__], stdin=PIPE, stdout=PIPE,
                                     universal_newlines=False, env={})
        self.iowrapper = IOWrapper(stream_in=self.sesh.stdout, stream_out=self.sesh.stdin)

        ret = self.rpc_call(self.iowrapper.EXEC_USER_SCRIPT, {'script': self.usr_script}, retry=False)
        if ret.get('return_code') != 200:
            raise ValueError('Error {} - Cannot set the user script: {}'.format(ret.get('return_code'),
                                                                             ret.get('reason')))

    def __init__(self, usr_script):
        """Spawn a sandbox and load *usr_script* (Python source) into it.

        :raises ValueError: if the sandbox refuses the user script
        """
        self.usr_script = usr_script

        self.lock = RLock()
        self._restart_server()

    def __del__(self):
        self.shutdown()

    def rpc_call(self, method, kwargs=None, version=1, retry=True):
        """Send one RPC call to the sandbox and return its decoded JSON reply.

        Up to 3 attempts are made. After a failed attempt the sandbox is
        restarted when *retry* is True; otherwise the call gives up at once.

        :raises IOError: when no attempt succeeded (chained to the last failure)
        :raises ValueError: if a restart fails to reload the user script
        """
        if kwargs is None:
            kwargs = {}

        last_error = None
        # Make sure we do not have multiple RPC calls at the same time
        with self.lock:
            for _ in range(3):
                try:
                    self.iowrapper.send(json.dumps({'method': method, 'kwargs': kwargs, 'version': version}))
                    return json.loads(self.iowrapper.read())
                except Exception as e:
                    # Broken pipe, truncated or garbled reply...: remember why
                    last_error = e

                # restart the server, since we got an unexpected output
                if retry:
                    self._restart_server()
                else:
                    break

        raise IOError("Failed to make an RPC call: method='{}', kwargs={}. version={}".format(method,
                                                                                          kwargs,
                                                                                          version)) from last_error

    def call_user_function(self, func_name, kwargs):
        """Call *func_name* of the user script with keyword arguments *kwargs*.

        :return: the function's return value, as decoded from JSON
        :raises Client.UserFunctionCallError: if the sandbox reports a failure
        :raises IOError: if the sandbox could not be reached
        """
        ret = self.rpc_call(self.iowrapper.CALL_USER_FUNCTION, {"user_function": func_name, "user_kwargs": kwargs})
        if ret.get('return_code') == 200:
            return ret.get('return_value')
        raise Client.UserFunctionCallError(ret.get('return_code'), ret.get('reason'))

224 

225 

# Entry point of the sandbox process: when executed as a script, act as the server
if __name__ == "__main__":
    # Exercised by test_sandbox.IntegrationTests, but it runs in a separate process,
    # so coverage cannot record these lines. Requests are served over stdin/stdout,
    # with the lockdown rules enabled, until a request fails (e.g. the client closes
    # the pipe) or the process is killed.
    Server(lockdown=True).serve_forever()  # pragma: no cover