Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set fork as the starting method of a sub-process #215

Merged
merged 12 commits into from
Jan 7, 2024
37 changes: 18 additions & 19 deletions RLTest/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time
import shlex
import json
from multiprocessing import Process, Queue
from multiprocessing import Process, Queue, set_start_method

from RLTest.env import Env, TestAssertionFailure, Defaults
from RLTest.utils import Colors, fix_modules, fix_modulesArgs
Expand Down Expand Up @@ -399,7 +399,7 @@ def __exit__(self, exc_type, exc_value, traceback):
self.is_done = True
self.condition.notify(1)
self.condition.release()


class RLTest:
def __init__(self):
Expand Down Expand Up @@ -671,7 +671,7 @@ def handleFailure(self, testFullName=None, exception=None, prefix='', testname=N
else:
self.addFailure(testname, '<No exception or environment>')

def _runTest(self, test, numberOfAssertionFailed=0, prefix='', before=None, after=None):
def _runTest(self, test, numberOfAssertionFailed=0, prefix='', before=lambda x=None: None, after=lambda x=None: None):
test.initialize()

msgPrefix = test.name
Expand All @@ -696,17 +696,16 @@ def _runTest(self, test, numberOfAssertionFailed=0, prefix='', before=None, afte
return 0

fn = lambda: test.target(env)
before_func = (lambda: before(env)) if before is not None else None
after_func = (lambda: after(env)) if after is not None else None
before_func = lambda: before(env)
after_func = lambda: after(env)
else:
fn = test.target
before_func = before
after_func = after

hasException = False
try:
if before_func:
before_func()
before_func()
fn()
passed = True
except unittest.SkipTest:
Expand All @@ -721,16 +720,15 @@ def _runTest(self, test, numberOfAssertionFailed=0, prefix='', before=None, afte
except Exception as err:
if self.args.exit_on_failure:
self.takeEnvDown(fullShutDown=True)
after = None
after_func = lambda x=None: None
raise

self.handleFailure(testFullName=testFullName, exception=err, prefix=msgPrefix,
testname=test.name, env=self.currEnv)
hasException = True
passed = False
finally:
if after_func:
after_func()
after_func()

numFailed = 0
if self.currEnv:
Expand Down Expand Up @@ -771,7 +769,7 @@ def printPass(self, name):

def envScopeGuard(self):
return EnvScopeGuard(self)

def killEnvWithSegFault(self):
if self.currEnv and Defaults.print_verbose_information_on_failure:
try:
Expand All @@ -786,7 +784,7 @@ def killEnvWithSegFault(self):
print('Failed %s' % str(e))
else:
self.stopEnvWithSegFault()

def run_single_test(self, test, on_timeout_func):
done = 0
with self.envScopeGuard():
Expand All @@ -807,8 +805,8 @@ def run_single_test(self, test, on_timeout_func):
return 0

failures = 0
before = getattr(obj, 'setUp', None)
after = getattr(obj, 'tearDown', None)
before = getattr(obj, 'setUp', lambda x=None: None)
after = getattr(obj, 'tearDown', lambda x=None: None)
for subtest in test.get_functions(obj):
with TestTimeLimit(self.args.test_timeout, on_timeout_func):
failures += self._runTest(subtest, prefix='\t',
Expand All @@ -831,7 +829,7 @@ def run_single_test(self, test, on_timeout_func):
verboseInfo['after_dispose'] = lastEnv.getInformationAfterDispose()
lastEnv.debugPrint(json.dumps(verboseInfo, indent=2).replace('\\n', '\n'), force=True)
return done

def print_failures(self):
for group, failures in self.testsFailed.items():
print('\t' + Colors.Bold(group))
Expand All @@ -842,7 +840,7 @@ def print_failures(self):

def disable_progress_bar(self):
return self.args.no_output_catch or self.args.no_progress or not sys.stdout.isatty()

def progressbar(self, num_elements):
bar = None
if not self.disable_progress_bar():
Expand Down Expand Up @@ -914,12 +912,11 @@ def run_jobs(jobs, results, summary, port):
except Exception as e:
break


output = io.StringIO()
with redirect_stdout(output):
def on_timeout():
nonlocal done
try:
try:
done += 1
self.killEnvWithSegFault()
self.handleFailure(testFullName=test.name, testname=test.name, error_msg=Colors.Bred('Test timeout'))
Expand Down Expand Up @@ -1003,9 +1000,11 @@ def on_timeout():


def main():
# Aviod "UnicodeEncodeError: 'ascii' codec can't encode character" errors
# Avoid "UnicodeEncodeError: 'ascii' codec can't encode character" errors
sys.stdout = io.open(sys.stdout.fileno(), 'w', encoding='utf8')
sys.stderr = io.open(sys.stderr.fileno(), 'w', encoding='utf8')
# Set multiprocessing start method to fork, we have unserializable objects in the env
set_start_method('fork')
RLTest().execute()


Expand Down