~azzar1/unity/add-show-desktop-key

« back to all changes in this revision

Viewing changes to www/apps/tutorialservice/test/TestFramework.py

  • Committer: mattgiuca
  • Date: 2008-01-24 23:57:26 UTC
  • Revision ID: svn-v3-trunk0:2b9c9e99-6f39-0410-b283-7f802c844ae2:trunk:294
Added application: tutorialservice. Will be used as the Ajax backend for
    tutorial (currently empty).
Moved tutorial/test to tutorialservice/test.
Reason: The testing framework will not be used by the tutorial HTML-side app.
    Only by the ajax backend.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# IVLE - Informatics Virtual Learning Environment
 
2
# Copyright (C) 2007-2008 The University of Melbourne
 
3
#
 
4
# This program is free software; you can redistribute it and/or modify
 
5
# it under the terms of the GNU General Public License as published by
 
6
# the Free Software Foundation; either version 2 of the License, or
 
7
# (at your option) any later version.
 
8
#
 
9
# This program is distributed in the hope that it will be useful,
 
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
# GNU General Public License for more details.
 
13
#
 
14
# You should have received a copy of the GNU General Public License
 
15
# along with this program; if not, write to the Free Software
 
16
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 
17
 
 
18
# Module: TestFramework
 
19
# Author: Dilshan Angampitiya
 
20
# Date:   24/1/2008
 
21
 
 
22
# Brief description of the Module# define custom exceptions
 
23
# use exceptions for all errors found in testing
 
24
 
 
25
import sys, StringIO, copy
 
26
 
 
27
# student error
 
28
class FunctionNotFoundError(Exception):
 
29
    """This error is returned when a function was expected in student
 
30
    code but was not found"""
 
31
    def __init__(self, function_name):
 
32
        self.function_name = function_name
 
33
 
 
34
    def __str__(self):
 
35
        return "Function " + self.function_name + " not found"
 
36
 
 
37
# author error
 
38
class TestCreationError(Exception):
 
39
    """An error occured while creating the test suite or one of its components"""
 
40
    def __init__(self, reason):
 
41
        self._reason = reason
 
42
        
 
43
    def __str__(self):
 
44
        return self._reason
 
45
 
 
46
# author error
 
47
class SolutionError(Exception):
 
48
    """Error in the provided solution"""
 
49
    def __init__(self, exc_info):
 
50
        cla, exc, trbk = exc_info
 
51
        self.name = cla.__name__
 
52
        self._detail = str(exc)
 
53
 
 
54
    def __str__(self):
 
55
        return "Error running solution: %s" %str(self._detail)
 
56
 
 
57
# author error
 
58
class TestError(Exception):
 
59
    """Runtime error in the testing framework outside of the provided or student code"""
 
60
    def __init__(self, exc_info):
 
61
        cla, exc, trbk = exc_info
 
62
        self.name = cla.__name__
 
63
        self._detail = str(exc)
 
64
 
 
65
    def __str__(self):
 
66
        return "Error testing solution against attempt: %s" %str(self._detail)
 
67
 
 
68
# student error
 
69
class AttemptError(Exception):
 
70
    """Runtime error in the student code"""
 
71
    def __init__(self, exc_info):
 
72
        cla, exc, trbk = exc_info
 
73
        self._name = cla.__name__
 
74
        self._detail = str(exc)
 
75
 
 
76
    def is_critical(self):
 
77
        if (    self._name == 'FunctionNotFoundError'
 
78
            or  self._name == 'SyntaxError'
 
79
            or  self._name == 'IndentationError'):
 
80
            return True
 
81
        else:
 
82
            return False
 
83
 
 
84
    def to_dict(self):
 
85
        return {'name': self._name,
 
86
                'detail': self._detail,
 
87
                'critical': self.is_critical()
 
88
                }
 
89
 
 
90
    def __str__(self):
 
91
        return self._name + " - " + str(self._detail)
 
92
 
 
93
class TestCasePart:
 
94
    """
 
95
    A part of a test case which compares a subset of the input files or file streams.
 
96
    This can be done either with a comparision function, or by comparing directly, after
 
97
    applying normalisations.
 
98
    """
 
99
    # how to make this work? atm they seem to get passed the class as a first arg
 
100
    ident =lambda x: x
 
101
    ignore = lambda x: None
 
102
    match = lambda x,y: x==y
 
103
    always_match = lambda x,y: True
 
104
    true = lambda *x: True
 
105
    false = lambda *x: False
 
106
 
 
107
    def __init__(self, desc, default='match'):
 
108
        """Initialise with a description and a default behavior for output
 
109
        If default is match, unspecified files are matched exactly
 
110
        If default is ignore, unspecified files are ignored
 
111
        The default default is match.
 
112
        """
 
113
        self._desc = desc
 
114
        self._default = default
 
115
        if default == 'ignore':
 
116
            self._default_func = lambda *x: True
 
117
        else:
 
118
            self._default_func = lambda x,y: x==y
 
119
 
 
120
        self._file_tests = {}
 
121
        self._stdout_test = ('check', self._default_func)
 
122
        self._stderr_test = ('check', self._default_func)
 
123
        self._result_test = ('check', self._default_func)
 
124
 
 
125
    def get_description(self):
 
126
        "Getter for description"
 
127
        return self._desc
 
128
 
 
129
    def _set_default_function(self, function, test_type):
 
130
        """"Ensure test type is valid and set function to a default
 
131
        if not specified"""
 
132
        
 
133
        if test_type not in ['norm', 'check']:
 
134
            raise TestCreationError("Invalid test type in %s" %self._desc)
 
135
        
 
136
        if function == '':
 
137
            if test_type == 'norm': function = lambda x: x
 
138
            else: function = lambda x,y: x==y
 
139
 
 
140
        return function
 
141
 
 
142
    def _validate_function(self, function, included_code):
 
143
        """Create a function object from the given string.
 
144
        If a valid function object cannot be created, raise and error.
 
145
        """
 
146
        if not callable(function):
 
147
            try:
 
148
                exec "__f__ = %s" %function in included_code
 
149
            except:
 
150
                raise TestCreationError("Invalid function %s" %function)
 
151
 
 
152
            f = included_code['__f__']
 
153
 
 
154
            if not callable(f):
 
155
                raise TestCreationError("Invalid function %s" %function)    
 
156
        else:
 
157
            f = function
 
158
 
 
159
        return f
 
160
 
 
161
    def validate_functions(self, included_code):
 
162
        """Ensure all functions used by the test cases exist and are callable.
 
163
        Also covert their string representations to function objects.
 
164
        This can only be done once all the include code has been specified.
 
165
        """
 
166
        (test_type, function) = self._stdout_test
 
167
        self._stdout_test = (test_type, self._validate_function(function, included_code))
 
168
        
 
169
        (test_type, function) = self._stderr_test
 
170
        self._stderr_test = (test_type, self._validate_function(function, included_code))
 
171
 
 
172
        for filename, (test_type, function) in self._file_tests.items():
 
173
            self._file_tests[filename] = (test_type, self._validate_function(function, included_code))
 
174
            
 
175
    def add_result_test(self, function, test_type='norm'):
 
176
        "Test part that compares function return values"
 
177
        function = self._set_default_function(function, test_type)
 
178
        self._result_test = (test_type, function)
 
179
 
 
180
            
 
181
    def add_stdout_test(self, function, test_type='norm'):
 
182
        "Test part that compares stdout"
 
183
        function = self._set_default_function(function, test_type)
 
184
        self._stdout_test = (test_type, function)
 
185
        
 
186
 
 
187
    def add_stderr_test(self, function, test_type='norm'):
 
188
        "Test part that compares stderr"
 
189
        function = self._set_default_function(function, test_type)
 
190
        self._stderr_test = (test_type, function)
 
191
 
 
192
    def add_file_test(self, filename, function, test_type='norm'):
 
193
        "Test part that compares the contents of a specified file"
 
194
        function = self._set_default_function(function, test_type)
 
195
        self._file_tests[filename] = (test_type, function)
 
196
 
 
197
    def _check_output(self, solution_output, attempt_output, test_type, f):
 
198
        """Compare solution output and attempt output using the
 
199
        specified comparision function.
 
200
        """
 
201
        # converts unicode to string
 
202
        if type(solution_output) == unicode:    
 
203
            solution_output = str(solution_output)
 
204
        if type(attempt_output) == unicode:
 
205
            attempt_output = str(attempt_output)
 
206
            
 
207
        if test_type == 'norm':
 
208
            return f(solution_output) == f(attempt_output)
 
209
        else:
 
210
            return f(solution_output, attempt_output)
 
211
 
 
212
    def run(self, solution_data, attempt_data):
 
213
        """Run the tests to compare the solution and attempt data
 
214
        Returns the empty string is the test passes, or else an error message.
 
215
        """
 
216
 
 
217
        # check function return value (None for scripts)
 
218
        (test_type, f) = self._result_test
 
219
        if not self._check_output(solution_data['result'], attempt_data['result'], test_type, f):       
 
220
            return 'function return value does not match'
 
221
 
 
222
        # check stdout
 
223
        (test_type, f) = self._stdout_test
 
224
        if not self._check_output(solution_data['stdout'], attempt_data['stdout'], test_type, f):       
 
225
            return 'stdout does not match'
 
226
 
 
227
        #check stderr
 
228
        (test_type, f) = self._stderr_test
 
229
        if not self._check_output(solution_data['stderr'], attempt_data['stderr'], test_type, f):        
 
230
            return 'stderr does not match'
 
231
 
 
232
 
 
233
        solution_files = solution_data['modified_files']
 
234
        attempt_files = attempt_data['modified_files']
 
235
 
 
236
        # check files indicated by test
 
237
        for (filename, (test_type, f)) in self._file_tests.items():
 
238
            if filename not in solution_files:
 
239
                raise SolutionError('File %s not found' %filename)
 
240
            elif filename not in attempt_files:
 
241
                return filename + ' not found'
 
242
            elif not self._check_output(solution_files[filename], attempt_files[filename], test_type, f):
 
243
                return filename + ' does not match'
 
244
 
 
245
        if self._default == 'ignore':
 
246
            return ''
 
247
 
 
248
        # check files found in solution, but not indicated by test
 
249
        for filename in [f for f in solution_files if f not in self._file_tests]:
 
250
            if filename not in attempt_files:
 
251
                return filename + ' not found'
 
252
            elif not self._check_output(solution_files[filename], attempt_files[filename], 'match', lambda x,y: x==y):
 
253
                return filename + ' does not match'
 
254
 
 
255
        # check if attempt has any extra files
 
256
        for filename in [f for f in attempt_files if f not in solution_files]:
 
257
            return "Unexpected file found: " + filename
 
258
 
 
259
        # Everything passed with no problems
 
260
        return ''
 
261
        
 
262
class TestCase:
 
263
    """
 
264
    A set of tests with a common inputs
 
265
    """
 
266
    def __init__(self, name='', function=None, stdin='', filespace=None, global_space=None):
 
267
        """Initialise with name and optionally, a function to test (instead of the entire script)
 
268
        The inputs stdin, the filespace and global variables can also be specified at
 
269
        initialisation, but may also be set later.
 
270
        """
 
271
        if global_space == None:
 
272
            global_space = {}
 
273
        if filespace == None:
 
274
            filespace = {}
 
275
        
 
276
        self._name = name
 
277
        
 
278
        if function == '': function = None
 
279
        self._function = function
 
280
        self._list_args = []
 
281
        self._keyword_args = {}
 
282
        
 
283
        # stdin must have a newline at the end for raw_input to work properly
 
284
        if stdin[-1:] != '\n': stdin += '\n'
 
285
        
 
286
        self._stdin = stdin
 
287
        self._filespace = TestFilespace(filespace)
 
288
        self._global_space = global_space
 
289
        self._parts = []
 
290
 
 
291
    def set_stdin(self, stdin):
 
292
        """ Set the given string as the stdin for this test case"""
 
293
        self._stdin = stdin
 
294
 
 
295
    def add_file(self, filename, data):
 
296
        """ Insert the given filename-data pair into the filespace for this test case"""
 
297
        self._filespace.add_file(filename, data)
 
298
        
 
299
    def add_variable(self, variable, value):
 
300
        """ Add the given varibale-value pair to the initial global environment
 
301
        for this test case.
 
302
        Throw and exception if thevalue cannot be paresed.
 
303
        """
 
304
        
 
305
        try:
 
306
            self._global_space[variable] = eval(value)
 
307
        except:
 
308
            raise TestCreationError("Invalid value for variable %s: %s" %(variable, value))
 
309
 
 
310
    def add_arg(self, value, name=None):
 
311
        """ Add a value to the argument list. This only applies when testing functions.
 
312
        By default arguments are not named, but if they are, they become keyword arguments.
 
313
        """
 
314
        try:
 
315
            if name == None or name == '':
 
316
                self._list_args.append(eval(value))
 
317
            else:
 
318
                self._keyword_args[name] = value
 
319
        except:
 
320
            raise TestCreationError("Invalid value for function argument: %s" %value)
 
321
        
 
322
    def add_part(self, test_part):
 
323
        """ Add a TestPart to this test case"""
 
324
        self._parts.append(test_part)
 
325
 
 
326
    def validate_functions(self, included_code):
 
327
        """ Validate all the functions in each part in this test case
 
328
        This can only be done once all the include code has been specified.
 
329
        """
 
330
        for part in self._parts:
 
331
            part.validate_functions(included_code)
 
332
 
 
333
    def get_name(self):
 
334
        """ Get the name of the test case """
 
335
        return self._name
 
336
 
 
337
    def run(self, solution, attempt_file):
 
338
        """ Run the solution and the attempt with the inputs specified for this test case.
 
339
        Then pass the outputs to each test part and collate the results.
 
340
        """
 
341
        case_dict = {}
 
342
        case_dict['name'] = self._name
 
343
        
 
344
        # Run solution
 
345
        try:
 
346
            global_space_copy = copy.deepcopy(self._global_space)
 
347
            solution_data = self._execstring(solution, global_space_copy)
 
348
            
 
349
            # if we are just testing a function
 
350
            if not self._function == None:
 
351
                if self._function not in global_space_copy:
 
352
                    raise FunctionNotFoundError(self._function)
 
353
                solution_data = self._run_function(lambda: global_space_copy[self._function](*self._list_args, **self._keyword_args))
 
354
                
 
355
        except:
 
356
            raise SolutionError(sys.exc_info())
 
357
 
 
358
        # Run student attempt
 
359
        try:
 
360
            global_space_copy = copy.deepcopy(self._global_space)
 
361
            attempt_data = self._execfile(attempt_file, global_space_copy)
 
362
            
 
363
            # if we are just testing a function
 
364
            if not self._function == None:
 
365
                if self._function not in global_space_copy:
 
366
                    raise FunctionNotFoundError(self._function)
 
367
                attempt_data = self._run_function(lambda: global_space_copy[self._function](*self._list_args, **self._keyword_args))
 
368
        except:
 
369
            case_dict['exception'] = AttemptError(sys.exc_info()).to_dict()
 
370
            return case_dict
 
371
        
 
372
        results = []
 
373
 
 
374
        # generate results
 
375
        for test_part in self._parts:
 
376
            result = test_part.run(solution_data, attempt_data)
 
377
            result_dict = {}
 
378
            result_dict['description'] = test_part.get_description()
 
379
            result_dict['passed']  = (result == '')
 
380
            if result_dict['passed'] == False:
 
381
                result_dict['error_message'] = result
 
382
                
 
383
            results.append(result_dict)
 
384
 
 
385
        case_dict['parts'] = results
 
386
 
 
387
        return case_dict
 
388
                
 
389
    def _execfile(self, filename, global_space):
 
390
        """ Execute the file given by 'filename' in global_space, and return the outputs. """
 
391
        self._initialise_global_space(global_space)
 
392
        data = self._run_function(lambda: execfile(filename, global_space))
 
393
        return data
 
394
 
 
395
    def _execstring(self, string, global_space):
 
396
        """ Execute the given string in global_space, and return the outputs. """
 
397
        self._initialise_global_space(global_space)
 
398
        # _run_function handles tuples in a special way
 
399
        data = self._run_function((string, global_space))
 
400
        return data
 
401
 
 
402
    def _initialise_global_space(self, global_space):
 
403
        """ Modify the provided global_space so that file, open and raw_input are redefined
 
404
        to use our methods instead.
 
405
        """
 
406
        self._current_filespace_copy = self._filespace.copy()
 
407
        global_space['file'] = lambda filename, mode='r', bufsize=-1: self._current_filespace_copy.openfile(filename, mode)
 
408
        global_space['open'] = global_space['file']
 
409
        global_space['raw_input'] = lambda x=None: raw_input()
 
410
        return global_space
 
411
 
 
412
    def _run_function(self, function):
 
413
        """ Run the provided function with the provided stdin, capturing stdout and stderr
 
414
        and the return value.
 
415
        Return all the output data.
 
416
        """
 
417
        import sys, StringIO
 
418
        sys_stdout, sys_stdin, sys_stderr = sys.stdout, sys.stdin, sys.stderr
 
419
 
 
420
        output_stream, input_stream, error_stream = StringIO.StringIO(), StringIO.StringIO(self._stdin), StringIO.StringIO()
 
421
        sys.stdout, sys.stdin, sys.stderr = output_stream, input_stream, error_stream
 
422
 
 
423
        try:
 
424
            if type(function) == tuple:
 
425
                # very hackish... exec can't be put into a lambda function!
 
426
                # or even with eval
 
427
                exec(function[0], function[1])
 
428
                result = None
 
429
            else:
 
430
                result = function()
 
431
        except:
 
432
            sys.stdout, sys.stdin, sys.stderr = sys_stdout, sys_stdin, sys_stderr
 
433
            raise
 
434
        
 
435
        sys.stdout, sys.stdin, sys.stderr = sys_stdout, sys_stdin, sys_stderr
 
436
 
 
437
        self._current_filespace_copy.flush_all()
 
438
            
 
439
        return {'result': result,
 
440
                'stdout': output_stream.getvalue(),
 
441
                'stderr': output_stream.getvalue(),
 
442
                'modified_files': self._current_filespace_copy.get_modified_files()}
 
443
 
 
444
class TestSuite:
 
445
    """
 
446
    The complete collection of test cases for a given problem
 
447
    """
 
448
    def __init__(self, name, solution=None):
 
449
        """Initialise with the name of the test suite (the problem name) and the solution.
 
450
        The solution may be specified later.
 
451
        """
 
452
        self._solution = solution
 
453
        self._name = name
 
454
        self._tests = []
 
455
        self.add_include_code("")
 
456
 
 
457
    def add_solution(self, solution):
 
458
        " Specifiy the solution script for this problem "
 
459
        self._solution = solution
 
460
 
 
461
    def has_solution(self):
 
462
        " Returns true if a soltion has been provided "
 
463
        return self._solution != None
 
464
 
 
465
    def add_include_code(self, include_code = ''):
 
466
        """ Add include code that may be used by the test cases during
 
467
        comparison of outputs.
 
468
        """
 
469
        
 
470
        # if empty, make sure it can still be executed
 
471
        if include_code == "":
 
472
            include_code = "pass"
 
473
        self._include_code = str(include_code)
 
474
        
 
475
        include_space = {}
 
476
        try:
 
477
            exec self._include_code in include_space
 
478
        except:
 
479
            raise TestCreationError("Bad include code")
 
480
 
 
481
        self._include_space = include_space
 
482
    
 
483
    def add_case(self, test_case):
 
484
        """ Add a TestCase, then validate all functions inside test case
 
485
        now that the include code is known
 
486
        """
 
487
        self._tests.append(test_case)
 
488
        test_case.validate_functions(self._include_space)
 
489
 
 
490
    def run_tests(self, attempt_file):
 
491
        " Run all test cases and collate the results "
 
492
        
 
493
        problem_dict = {}
 
494
        problem_dict['name'] = self._name
 
495
        
 
496
        test_case_results = []
 
497
        for test in self._tests:
 
498
            result_dict = test.run(self._solution, attempt_file)
 
499
            if 'exception' in result_dict and result_dict['exception']['critical']:
 
500
                # critical error occured, running more cases is useless
 
501
                # FunctionNotFound, Syntax, Indentation
 
502
                problem_dict['critical_error'] = result_dict['exception']
 
503
                return problem_dict
 
504
            
 
505
            test_case_results.append(result_dict)
 
506
 
 
507
        problem_dict['cases'] = test_case_results
 
508
        return problem_dict
 
509
 
 
510
    def get_name(self):
 
511
        return self._name
 
512
 
 
513
class TestFilespace:
 
514
    """
 
515
    Our dummy file system which is accessed by code being tested.
 
516
    Implemented as a dictionary which maps filenames to strings
 
517
    """
 
518
    def __init__(self, files=None):
 
519
        "Initialise, optionally with filename-filedata pairs"
 
520
 
 
521
        if files == None:
 
522
            files = {}
 
523
 
 
524
        # dict mapping files to strings
 
525
        self._files = {}
 
526
        self._files.update(files)
 
527
        # set of file names
 
528
        self._modified_files = set([])
 
529
        # dict mapping files to stringIO objects
 
530
        self._open_files = {}
 
531
 
 
532
    def add_file(self, filename, data):
 
533
        " Add a file to the filespace "
 
534
        self._files[filename] = data
 
535
 
 
536
    def openfile(self, filename, mode='r'):
 
537
        """ Open a file from the filespace with the given mode.
 
538
        Return a StringIO subclass object with the file contents.
 
539
        """
 
540
        import re
 
541
 
 
542
        if filename in self._open_files:
 
543
            raise IOError("File already open: %s" %filename)
 
544
 
 
545
        if not re.compile("[rwa][+b]{0,2}").match(mode):
 
546
            raise IOError("invalid mode %s" %mode)
 
547
        
 
548
        ## TODO: validate filename?
 
549
        
 
550
        mode.replace("b",'')
 
551
        
 
552
        # initialise the file properly (truncate/create if required)
 
553
        if mode[0] == 'w':
 
554
            self._files[filename] = ''
 
555
            self._modified_files.add(filename)
 
556
        elif filename not in self._files:
 
557
            if mode[0] == 'a':
 
558
                self._files[filename] = ''
 
559
                self._modified_files.add(filename)
 
560
            else:
 
561
                raise IOError(2, "Access to file denied: %s" %filename)
 
562
 
 
563
        # for append mode, remember the existing data
 
564
        if mode[0] == 'a':
 
565
            existing_data = self._files[filename]
 
566
        else:
 
567
            existing_data = ""
 
568
 
 
569
        # determine what operations are allowed
 
570
        reading_ok = (len(mode) == 2 or mode[0] == 'r')
 
571
        writing_ok = (len(mode) == 2 or mode[0] in 'wa')
 
572
 
 
573
        # for all writing modes, start off with blank file
 
574
        if mode[0] == 'w':
 
575
            initial_data = ''
 
576
        else:
 
577
            initial_data = self._files[filename]
 
578
 
 
579
        file_object = TestStringIO(initial_data, filename, self, reading_ok, writing_ok, existing_data)
 
580
        self._open_files[filename] = file_object
 
581
        
 
582
        return file_object
 
583
 
 
584
    def flush_all(self):
 
585
        """ Flush all open files
 
586
        """
 
587
        for file_object in self._open_files.values():
 
588
            file_object.flush()
 
589
 
 
590
    def updatefile(self,filename, data):
 
591
        """ Callback function used by an open file to inform when it has been updated.
 
592
        """
 
593
        if filename in self._open_files:
 
594
            self._files[filename] = data
 
595
            if self._open_files[filename].is_modified():
 
596
                self._modified_files.add(filename)
 
597
        else:
 
598
            raise IOError(2, "Access to file denied: %s" %filename)
 
599
 
 
600
    def closefile(self, filename):
 
601
        """ Callback function used by an open file to inform when it has been closed.
 
602
        """
 
603
        if filename in self._open_files:
 
604
            del self._open_files[filename]
 
605
 
 
606
    def get_modified_files(self):
 
607
        """" A subset of the filespace containing only those files which have been
 
608
        modified
 
609
        """
 
610
        modified_files = {}
 
611
        for filename in self._modified_files:
 
612
            modified_files[filename] = self._files[filename]
 
613
 
 
614
        return modified_files
 
615
 
 
616
    def get_open_files(self):
 
617
        " Return the names of all open files "
 
618
        return self._open_files.keys()
 
619
            
 
620
    def copy(self):
 
621
        """ Return a copy of the current filespace.
 
622
        Only the files are copied, not the modified or open file lists.
 
623
        """
 
624
        self.flush_all()
 
625
        return TestFilespace(self._files)
 
626
 
 
627
class TestStringIO(StringIO.StringIO):
 
628
    """
 
629
    A subclass of StringIO which acts as a file in our dummy file system
 
630
    """
 
631
    def __init__(self, string, filename, filespace, reading_ok, writing_ok, existing_data):
 
632
        """ Initialise with the filedata, file name and infomation on what ops are
 
633
        acceptable """
 
634
        StringIO.StringIO.__init__(self, string)
 
635
        self._filename = filename
 
636
        self._filespace = filespace
 
637
        self._reading_ok = reading_ok
 
638
        self._writing_ok = writing_ok
 
639
        self._existing_data = existing_data
 
640
        self._modified = False
 
641
        self._open = True
 
642
 
 
643
    # Override all standard file ops. Make sure that they are valid with the given
 
644
    # permissions and if so then call the corresponding method in StringIO
 
645
    
 
646
    def read(self, *args):
 
647
        if not self._reading_ok:
 
648
            raise IOError(9, "Bad file descriptor")
 
649
        else:
 
650
            return StringIO.StringIO.read(self, *args)
 
651
 
 
652
    def readline(self, *args):
 
653
        if not self._reading_ok:
 
654
            raise IOError(9, "Bad file descriptor")
 
655
        else:
 
656
            return StringIO.StringIO.readline(self, *args)
 
657
 
 
658
    def readlines(self, *args):
 
659
        if not self._reading_ok:
 
660
            raise IOError(9, "Bad file descriptor")
 
661
        else:
 
662
            return StringIO.StringIO.readlines(self, *args)
 
663
 
 
664
    def seek(self, *args):
 
665
        if not self._reading_ok:
 
666
            raise IOError(9, "Bad file descriptor")
 
667
        else:
 
668
            return StringIO.StringIO.seek(self, *args)
 
669
 
 
670
    def truncate(self, *args):
 
671
        self._modified = True
 
672
        if not self._writing_ok:
 
673
            raise IOError(9, "Bad file descriptor")
 
674
        else:
 
675
            return StringIO.StringIO.truncate(self, *args)
 
676
        
 
677
    def write(self, *args):
 
678
        self._modified = True
 
679
        if not self._writing_ok:
 
680
            raise IOError(9, "Bad file descriptor")
 
681
        else:
 
682
            return StringIO.StringIO.write(self, *args)
 
683
 
 
684
    def writelines(self, *args):
 
685
        self._modified = True
 
686
        if not self._writing_ok:
 
687
            raise IOError(9, "Bad file descriptor")
 
688
        else:
 
689
            return StringIO.StringIO.writelines(self, *args)
 
690
 
 
691
    def is_modified(self):
 
692
        " Return true if the file has been written to, or truncated"
 
693
        return self._modified
 
694
        
 
695
    def flush(self):
 
696
        " Update the contents of the filespace with the new data "
 
697
        self._filespace.updatefile(self._filename, self._existing_data+self.getvalue())
 
698
        return StringIO.StringIO.flush(self)
 
699
 
 
700
    def close(self):
 
701
        " Flush the file and close it "
 
702
        self.flush()
 
703
        self._filespace.closefile(self._filename)
 
704
        return StringIO.StringIO.close(self)
 
705
 
 
706
##def get_function(filename, function_name):
 
707
##      import compiler
 
708
##      mod = compiler.parseFile(filename)
 
709
##      for node in mod.node.nodes:
 
710
##              if isinstance(node, compiler.ast.Function) and node.name == function_name:
 
711
##                      return node
 
712
##